Compiler-First State Space Duality and Portable $O(1)$ Autoregressive Caching for Inference
Cosmo Santoni

TL;DR
This paper introduces a hardware-agnostic, compiler-based approach to autoregressive inference in state-space models. It keeps the decoding state in an on-device cache whose size is $O(1)$ in sequence length, and it runs with high performance on CPU, GPU, and TPU.
Contribution
It presents a novel compiler-first method that maps Mamba-2's state space duality algorithm onto XLA's fusion and tiling passes, enabling portable inference with on-device caching and no hand-written kernels.
Findings
Achieves near-peak FLOPS on TPU for large models.
Runs unmodified across CPU, GPU, and TPU from a single JAX source.
Matches reference decoding accuracy across platforms.
Abstract
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical $O(1)$ state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated…
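To make the abstract's two central ideas concrete, here is a minimal JAX sketch. It is not the paper's implementation. It assumes a single head with a scalar decay per step and omits discretisation, chunking, and batching. All names (`ssm_step`, `decode_recurrent`, `decode_dual`) and shapes are hypothetical. The recurrent path carries a fixed-size state through `jax.lax.scan` inside one jitted program, so the cache does not grow with the number of tokens and the loop needs no host round trip. The dual path computes the same outputs as masked, attention-like einsums.

```python
import jax
import jax.numpy as jnp


def ssm_step(h, inputs):
    """One decode step. h has shape (P, N) no matter how many tokens came before."""
    a_t, x_t, B_t, C_t = inputs            # scalar, (P,), (N,), (N,)
    h = a_t * h + jnp.einsum("p,n->pn", x_t, B_t)   # decay old state, add new input
    y_t = jnp.einsum("pn,n->p", h, C_t)             # read out through C_t
    return h, y_t


@jax.jit
def decode_recurrent(h0, a, x, B, C):
    # lax.scan compiles the whole loop into one XLA program; the carried
    # state h is the only cache and it stays on the device between steps.
    return jax.lax.scan(ssm_step, h0, (a, x, B, C))


@jax.jit
def decode_dual(a, x, B, C):
    # Dual (quadratic) form: y_t = sum_{s<=t} (prod_{r=s+1..t} a_r) (C_t . B_s) x_s
    cum = jnp.cumsum(jnp.log(a))                    # running sum of log-decays
    diff = cum[:, None] - cum[None, :]              # log of the decay product
    causal = jnp.tril(jnp.ones((a.shape[0], a.shape[0]), dtype=bool))
    L = jnp.exp(jnp.where(causal, diff, -jnp.inf))  # zero above the diagonal
    G = jnp.einsum("tn,sn->ts", C, B)               # pairwise C_t . B_s
    return jnp.einsum("ts,sp->tp", L * G, x)


# Toy sizes chosen for illustration only.
T, P, N = 16, 4, 8
k_a, k_x, k_b, k_c = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.random.uniform(k_a, (T,), minval=0.5, maxval=0.99)  # decays in (0, 1)
x = jax.random.normal(k_x, (T, P))
B = jax.random.normal(k_b, (T, N))
C = jax.random.normal(k_c, (T, N))

h_final, y_scan = decode_recurrent(jnp.zeros((P, N)), a, x, B, C)
y_dual = decode_dual(a, x, B, C)
```

In this sketch `h_final` keeps shape `(P, N)` for any `T`, which is the sense in which the cache is $O(1)$ in sequence length. `y_scan` and `y_dual` should agree up to floating-point error.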
Taxonomy
Topics: Parallel Computing and Optimization Techniques · Embedded Systems Design Techniques · Distributed systems and fault tolerance
