TL;DR
The paper introduces Eventax, a JAX-based framework that enables exact gradient training of diverse event-based neural networks using differentiable ODE solvers, overcoming previous limitations in neuron model complexity and spike-time resolution.
Contribution
Eventax combines differentiable numerical ODE solvers with event-based spike handling to support flexible, exact gradient training for a wide range of neuron models in neural networks.
Findings
Eventax successfully trains various neuron models including LIF, QIF, EIF, Izhikevich, and EGRU.
It achieves accurate gradient computation for complex neuron dynamics.
It demonstrates effectiveness on benchmarks such as Yin-Yang and MNIST.
Abstract
Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as Leaky Integrate and Fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our framework uses Diffrax ODE solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API where users can specify just the neuron dynamics, spike conditions, and reset rules. Eventax prioritises…
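To make the idea of "gradients that are exact with respect to the forward simulation" concrete, here is a minimal sketch in plain JAX. It is not Eventax's API and does not use Diffrax: it assumes a single LIF neuron with constant input current, substitutes a hand-rolled fixed-step RK4 integrator for the ODE solver, locates the threshold crossing by bisection, and then applies one Newton step so that the spike-time gradient follows from the implicit function theorem. All function and parameter names (`lif_rhs`, `membrane_at`, `spike_time`, `theta`, `t_max`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def lif_rhs(v, current, tau):
    """LIF membrane dynamics dv/dt = (current - v) / tau."""
    return (current - v) / tau


def membrane_at(t, current, tau, n_steps=200):
    """Integrate v from v(0) = 0 up to time t with fixed-step RK4."""
    h = t / n_steps

    def rk4_step(v, _):
        k1 = lif_rhs(v, current, tau)
        k2 = lif_rhs(v + 0.5 * h * k1, current, tau)
        k3 = lif_rhs(v + 0.5 * h * k2, current, tau)
        k4 = lif_rhs(v + h * k3, current, tau)
        return v + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    v_end, _ = jax.lax.scan(rk4_step, jnp.zeros_like(h), None, length=n_steps)
    return v_end


def spike_time(current, tau, theta=1.0, t_max=10.0):
    """First time the simulated membrane potential reaches theta."""
    current = jnp.asarray(current, dtype=jnp.float32)
    tau = jnp.asarray(tau, dtype=jnp.float32)
    # The event search itself carries no gradient (assumption of this sketch).
    cur, ta = jax.lax.stop_gradient((current, tau))

    def bisect(bounds, _):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        crossed = membrane_at(mid, cur, ta) >= theta
        return (jnp.where(crossed, lo, mid), jnp.where(crossed, mid, hi)), None

    init = (jnp.float32(0.0), jnp.float32(t_max))
    (lo, hi), _ = jax.lax.scan(bisect, init, None, length=30)
    t0 = 0.5 * (lo + hi)

    # One Newton step around the located root: its derivative with respect to
    # the parameters is -(dv/dparam) / (dv/dt), the implicit function theorem.
    v_dot = lif_rhs(membrane_at(t0, cur, ta), cur, ta)
    return t0 - (membrane_at(t0, current, tau) - theta) / v_dot


# Spike time and its gradients with respect to input current and time constant.
t_star, (dt_dI, dt_dtau) = jax.value_and_grad(spike_time, argnums=(0, 1))(1.5, 1.0)
print(t_star, dt_dI, dt_dtau)
```

For this particular case the spike time has the closed form tau * ln(I / (I - theta)), so the printed values can be checked against it. The same pattern needs no closed form, which is the property the abstract attributes to solver-based event handling for neuron models beyond LIF.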
