• Corpus ID: 240353913

Equinox: neural networks in JAX via callable PyTrees and filtered transformations

  • Patrick Kidger and Cristian Garcia
JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. That this seems like a fundamental difference means current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations…
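Here is a minimal sketch of the two ideas named in the title. It uses plain JAX rather than Equinox itself. First, a class registered with `jax.tree_util.register_pytree_node_class` is both a PyTree and a callable, so an instance can be handed straight to a JAX transformation. Second, a hand-rolled `filter_grad` differentiates only the array leaves and holds the other leaves (here, the activation function) fixed. Plain `jax.grad` would raise an error on the function-valued leaf. The helper is hypothetical and is not Equinox's implementation, although Equinox ships a function named `equinox.filter_grad` for this purpose. The class `MLPLayer`, the loss, and all values are also made up for illustration. The code assumes a JAX release recent enough to expose `jax.Array`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class MLPLayer:
    """Hypothetical callable PyTree: a linear map followed by an activation."""

    def __init__(self, weight, bias, activation):
        self.weight = weight          # array leaf
        self.bias = bias              # array leaf
        self.activation = activation  # non-array leaf (a Python callable)

    def __call__(self, x):
        # Calling the instance evaluates the parameterised function, OO-style.
        return self.activation(self.weight @ x + self.bias)

    def tree_flatten(self):
        # Every field is a leaf, including the non-array activation.
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def filter_grad(fun):
    """Hypothetical filtered gradient w.r.t. the array leaves of the first argument."""

    def wrapper(model, *args):
        leaves, treedef = tree_flatten(model)
        is_array = [isinstance(leaf, jax.Array) for leaf in leaves]
        # Partition: array leaves are differentiated, all others are held fixed.
        dynamic = [leaf if flag else None for leaf, flag in zip(leaves, is_array)]
        static = [None if flag else leaf for leaf, flag in zip(leaves, is_array)]

        def inner(dynamic):
            # Recombine the two halves into a full model before calling fun.
            merged = [d if flag else s for d, s, flag in zip(dynamic, static, is_array)]
            return fun(tree_unflatten(treedef, merged), *args)

        grads = jax.grad(inner)(dynamic)
        # Rebuild a model-shaped result; non-array leaves come back as None.
        return tree_unflatten(treedef, grads)

    return wrapper


def loss(model, x, y):
    # Squared error of the model's prediction; the model is just an argument.
    return jnp.sum((model(x) - y) ** 2)


model = MLPLayer(jnp.ones((2, 3)), jnp.zeros(2), jax.nn.relu)
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 1.0])

grads = filter_grad(loss)(model, x, y)
print(grads.weight)  # gradient w.r.t. weight, shape (2, 3)
print(grads.bias)    # gradient w.r.t. bias, shape (2,)
```

The sketch stores the activation as an ordinary leaf and filters it out at transformation time. The alternative is to put such fields in the PyTree's auxiliary data, which makes them static for every transformation rather than on a per-call basis.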

Synthesized Differentiable Programs

  • Computer Science
  • 2022
This paper presents a combined algorithm for synthesizing syntactic programs, compiling them into the weights of a neural network, and then tuning the resulting model to form an efficient algorithm for inducing abstract algorithmic structure and a corresponding local set of desirable complex programs.

Rieoptax: Riemannian Optimization in JAX

It is shown that many differential geometric primitives, such as Riemannian exponential and logarithm maps, are usually faster in Rieoptax than in existing Python frameworks, both on CPU and GPU.

Optical design, analysis, and calibration using ∂Lux

This manuscript explores some of the many ways to harness the potential of these codes, particularly focusing on the application example provided by the Toliman space telescope mission.

Robust Neural Posterior Estimation and Statistical Model Criticism

Robust neural posterior estimation (RNPE) is proposed: an extension of NPE that simultaneously achieves both these aims by explicitly modelling the discrepancies between simulations and the observed data. RNPE performs well across the tasks, whereas naïvely using NPE leads to misleading and erratic posteriors.



PyTorch: An Imperative Style, High-Performance Deep Learning Library

This paper details the principles that drove the implementation of PyTorch and how they are reflected in its architecture, and explains how the careful and pragmatic implementation of the key components of its runtime enables them to work together to achieve compelling performance.

Flux: Elegant machine learning with Julia

  • Mike Innes
  • Computer Science
    J. Open Source Softw.
  • 2018
Flux is a library for machine learning (ML), written in the numerical computing language Julia, that applies automatic differentiation (AD) to seamlessly calculate derivatives and train models.

Fashionable Modelling with Flux

A framework named Flux is presented that shows how further refinement of the core ideas of machine learning, built upon the foundation of the Julia programming language, can yield an environment that is simple, easily modifiable, and performant.

Julia: A Fresh Approach to Numerical Computing

The Julia programming language and its design are introduced: a dance between specialization and abstraction, which recognizes what remains the same after computation and which is best left untouched because it has been built by the experts.

Swift for TensorFlow: A portable, flexible platform for deep learning

The deep learning platform Swift for TensorFlow combines a language-integrated automatic differentiation system and multiple Tensor implementations within a modern ahead-of-time compiled language oriented around mutable value semantics.

Decomposing reverse-mode automatic differentiation

We decompose reverse-mode automatic differentiation into (forward-mode) linearization followed by transposition. Doing so isolates the essential difference between forward- and reverse-mode AD, and…
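The decomposition can be observed directly in JAX, which exposes both halves as public functions. `jax.linearize` performs the forward-mode linearization and returns the Jacobian-vector-product map at a point. `jax.linear_transpose` transposes a linear function. Composing the two gives a vector-Jacobian product, which the sketch below checks against `jax.vjp`. This is a minimal sketch, not code from the paper. The function `f`, the point `x`, and the cotangent `u` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Arbitrary nonlinear map from R^3 to R^2, chosen only for illustration.
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])


x = jnp.array([0.5, 2.0, 3.0])
u = jnp.array([1.0, -1.0])  # cotangent for the output

# Step 1, linearization (forward mode): f_jvp(v) computes J @ v at x.
y, f_jvp = jax.linearize(f, x)

# Step 2, transposition: turn the linear map v -> J @ v into u -> J.T @ u.
# The primal x is passed only to fix the input shape and dtype.
f_transposed = jax.linear_transpose(f_jvp, x)
(cotangent,) = f_transposed(u)

# Reference result from JAX's built-in reverse mode.
_, f_vjp = jax.vjp(f, x)
(reference,) = f_vjp(u)

print(cotangent, reference)
```

Here reverse mode is obtained without calling a VJP function directly: only linearization followed by transposition. The two results agree up to floating-point rounding.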

functorch: JAX-like composable function transforms for PyTorch

  • 2021

Haiku: Sonnet for JAX

  • Version 0.0.3
  • 2020

JAX: composable transformations of Python+NumPy programs. Version 0.2.5

  • 2018

torchtyping. Accessed 2021

  • URL: https://github.com/patrick-kidger/torchtyping
  • 2021