• Corpus ID: 240353913

Equinox: neural networks in JAX via callable PyTrees and filtered transformations

Patrick Kidger, Cristian Garcia
JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is built around pure functions and functional programming. PyTorch has popularised an object-oriented (OO), class-based syntax for defining parameterised functions, such as neural networks. Because this looks like a fundamental difference, current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations…
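To make the "callable PyTree" idea from the title concrete, here is a minimal sketch in plain JAX. It is not the Equinox API: the `Linear` class and the `loss` function are hypothetical names chosen for illustration, and the PyTree registration uses `jax.tree_util.register_pytree_node_class` directly. The point is only that a class instance can hold parameters, be called like a function, and still be passed through a pure-function transformation such as `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical layer: instances are both callable and PyTrees."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # OO-style forward pass, written as a method.
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # The arrays are the PyTree children; there is no static aux data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    # A pure function of its inputs; the model is just another PyTree argument.
    return jnp.sum((model(x) - y) ** 2)


model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 0.0])

# Differentiating with respect to the model returns another Linear
# whose fields hold the gradients of the corresponding parameters.
grads = jax.grad(loss)(model, x, y)
```

In this sketch every field is an array, so the whole object can be differentiated. A model that also carries non-array fields would need those kept out of the differentiated leaves, which this example does not attempt.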

