Equinox: neural networks in JAX via callable PyTrees and filtered transformations
Patrick Kidger, Cristian Garcia

TL;DR
Equinox is a lightweight neural network library that integrates object-oriented programming with JAX's functional approach by representing parameterized functions as PyTrees and by filtering which parts of those PyTrees each JAX transformation acts on.
Contribution
It shows how to combine PyTorch-like class-based neural networks with JAX's functional programming without introducing new abstractions.
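A minimal sketch of what "class-based without new abstractions" can look like, using only plain JAX. This is not Equinox's source: the `Linear` class is hypothetical. It is an ordinary Python dataclass registered as a PyTree, so the object is callable like a PyTorch module while stock `jax.tree_util` functions already understand its parameters.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Linear:
    """Hypothetical layer: a plain dataclass that is also a PyTree."""

    weight: jax.Array
    bias: jax.Array

    def __call__(self, x):
        # Calling the object runs the forward pass with its own parameters.
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Parameters are the PyTree children; there is no static aux data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Linear(weight=jnp.ones((2, 3)), bias=jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = layer(x)  # [6., 6.]

# Stock JAX tree utilities operate on the model directly.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(layer))  # 8
halved = jax.tree_util.tree_map(lambda p: 0.5 * p, layer)  # still a Linear
```

Because the parameters are simply PyTree leaves, no separate parameter container or OO-to-functional conversion is needed to reach them.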
Findings
Enables class-based neural networks in JAX using PyTrees.
Maintains compatibility with JAX transformations like jit and grad.
Provides an open-source implementation on GitHub.
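To illustrate the jit/grad compatibility listed above, here is a sketch of gradient descent in plain JAX with a hypothetical `Linear` layer registered as a PyTree (Equinox's own module class is not used, and the learning rate and data are illustrative). Because the model is a PyTree of arrays, it passes straight through `jax.jit` and `jax.grad`, and the gradient comes back as another `Linear`.

```python
import dataclasses

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value, not from the paper


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Linear:
    """Hypothetical linear layer registered as a PyTree (not Equinox's code)."""

    weight: jax.Array  # shape (out_features, in_features)
    bias: jax.Array  # shape (out_features,)

    def __call__(self, x):
        # x has shape (batch, in_features).
        return x @ self.weight.T + self.bias

    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss_fn(model, x, y):
    # Mean squared error of the model's predictions.
    return jnp.mean((model(x) - y) ** 2)


@jax.jit
def sgd_step(model, x, y):
    # jax.grad differentiates w.r.t. every leaf of `model`, so `grads`
    # is itself a Linear holding d(loss)/d(weight) and d(loss)/d(bias).
    grads = jax.grad(loss_fn)(model, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, model, grads)


# Synthetic regression data: y = x @ w_true.T + 0.3.
x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([[1.0, -2.0, 0.5]]).T + 0.3
model = Linear(weight=jnp.zeros((1, 3)), bias=jnp.zeros(1))

initial_loss = loss_fn(model, x, y)
for _ in range(200):
    model = sgd_step(model, x, y)
final_loss = loss_fn(model, x, y)
```

The update is a single `tree_map` over the model and its gradient, so the training step never needs to know which attributes the layer has.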
Abstract
JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. That this seems like a fundamental difference means current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations, multiple new abstractions, and been limited in the extent to which they integrate with JAX (Flax, Haiku, Objax). Either way this OO/functional difference has been a source of tension. Here, we introduce 'Equinox', a small neural network library showing how a PyTorch-like class-based approach may be admitted without sacrificing JAX-like functional programming. We provide two main ideas. One:…
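The title also names filtered transformations, which matter when a model holds non-array leaves such as an activation function. The sketch below hand-rolls that idea in plain JAX; the helper names `partition`, `combine` and `filtered_grad` are hypothetical and this is not Equinox's implementation. The model is split into an array part and a non-array part, `jax.grad` sees only the arrays, and the two parts are recombined inside the differentiated function. Calling plain `jax.grad` on the same layer would fail, because one of its leaves is a Python callable.

```python
import dataclasses
from typing import Callable

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class MLPLayer:
    """Hypothetical layer mixing array leaves with a non-array (callable) leaf."""

    weight: jax.Array
    bias: jax.Array
    activation: Callable[[jax.Array], jax.Array]

    def __call__(self, x):
        return self.activation(self.weight @ x + self.bias)

    def tree_flatten(self):
        # The activation is deliberately kept as a leaf, so plain jax.grad
        # over the whole layer would reject it as a non-array input.
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def is_array(x):
    return isinstance(x, jax.Array)


def partition(tree):
    # Two trees with the same structure as `tree`: arrays in the first,
    # everything else in the second, with None marking the gaps.
    arrays = jax.tree_util.tree_map(lambda x: x if is_array(x) else None, tree)
    others = jax.tree_util.tree_map(lambda x: None if is_array(x) else x, tree)
    return arrays, others


def combine(arrays, others):
    # Inverse of partition: take each entry from whichever tree is not None.
    return jax.tree_util.tree_map(
        lambda a, o: o if a is None else a,
        arrays,
        others,
        is_leaf=lambda x: x is None,
    )


def filtered_grad(fn):
    # Differentiate fn(model, *args) with respect to the array leaves of
    # `model` only; its non-array leaves are closed over unchanged.
    def wrapper(model, *args):
        arrays, others = partition(model)

        def fn_of_arrays(arrays):
            return fn(combine(arrays, others), *args)

        return jax.grad(fn_of_arrays)(arrays)

    return wrapper


layer = MLPLayer(
    weight=jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    bias=jnp.array([0.0, -20.0]),
    activation=jax.nn.relu,
)


def loss(model, x):
    return jnp.sum(model(x))


grads = filtered_grad(loss)(layer, jnp.array([1.0, 1.0]))
# grads.weight -> [[1., 1.], [0., 0.]], grads.bias -> [1., 0.], grads.activation -> None
```

Equinox packages this pattern as its own filtered transformations; the sketch only shows the underlying mechanism of splitting a PyTree, transforming one part, and merging it back.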
