FedJAX: Federated learning simulation with JAX
Jae Hun Ro, Ananda Theertha Suresh, Ke Wu

TL;DR
FedJAX is an easy-to-use, high-performance JAX-based library designed to simplify and accelerate federated learning research and simulation, supporting rapid experimentation with datasets and models.
Contribution
The paper introduces FedJAX, a new open-source library that streamlines federated learning research with simple APIs, prepackaged datasets, and fast simulation capabilities.
Findings
FedJAX can train models with federated averaging on the EMNIST dataset in a few minutes.
It trains models on the Stack Overflow dataset in roughly an hour.
It supports efficient federated averaging with TPU acceleration.
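Federated averaging alternates between local training on each client and a server step that averages the returned models, weighting each client by its number of examples. The following is a minimal sketch of one such training loop in plain JAX. It assumes a toy linear-regression model, full-batch local gradient descent, and synthetic data. `client_update` and `fedavg_round` are hypothetical names chosen for illustration and are not part of FedJAX's API.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a linear model: y ≈ x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# Compiled gradient of the loss with respect to the params pytree.
grad_fn = jax.jit(jax.grad(loss_fn))


def client_update(params, x, y, lr=0.1, local_steps=5):
    # Hypothetical client step: a few full-batch gradient-descent steps
    # starting from the current server model.
    for _ in range(local_steps):
        grads = grad_fn(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params


def fedavg_round(server_params, client_data):
    # client_data: list of (x, y) arrays, one entry per client.
    client_params, num_examples = [], []
    for x, y in client_data:
        client_params.append(client_update(server_params, x, y))
        num_examples.append(x.shape[0])
    # Weight each client model by its share of the total examples.
    weights = jnp.asarray(num_examples, dtype=jnp.float32)
    weights = weights / weights.sum()

    def weighted_mean(*leaves):
        # Stack one leaf from every client, then contract over the client axis.
        return jnp.tensordot(weights, jnp.stack(leaves), axes=1)

    return jax.tree_util.tree_map(weighted_mean, *client_params)


# Toy simulation: three clients of different sizes share one true model.
key = jax.random.PRNGKey(0)
true_w = jnp.array([2.0, -1.0])
clients = []
for n in (20, 50, 30):
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey, (n, 2))
    clients.append((x, x @ true_w + 0.5))

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
for _ in range(50):
    params = fedavg_round(params, clients)
```

Because every synthetic client here is generated from the same noiseless model, the averaged parameters approach `true_w` and the bias 0.5. With real, heterogeneous client data the clients' local optima differ, and the averaged model generally does not coincide with any one of them.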
Abstract
Federated learning is a machine learning technique that enables training across decentralized data. Recently, federated learning has become an active area of research due to an increased focus on privacy and security. In light of this, a variety of open source federated learning libraries have been developed and released. We introduce FedJAX, a JAX-based open source library for federated learning simulations that emphasizes ease of use in research. With its simple primitives for implementing federated learning algorithms, prepackaged datasets, models and algorithms, and fast simulation speed, FedJAX aims to make developing and evaluating federated algorithms faster and easier for researchers. Our benchmark results show that FedJAX can be used to train models with federated averaging on the EMNIST dataset in a few minutes and the Stack Overflow dataset in roughly an hour with standard…
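The abstract credits FedJAX with fast simulation speed. One generic JAX technique for simulating many clients quickly is to vectorize the per-client update with `jax.vmap` and compile it with `jax.jit`, so all clients in a round run as a single batched computation. The sketch below assumes equal-sized synthetic clients and a toy linear model. `client_update` and `fedavg_round_vmap` are hypothetical names, and this is not a description of how FedJAX itself is implemented.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a linear model: y ≈ x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def client_update(params, x, y, lr=0.1, local_steps=5):
    # A few full-batch gradient-descent steps on one client's data.
    # The Python loop is unrolled when traced by jit/vmap.
    for _ in range(local_steps):
        grads = jax.grad(loss_fn)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params


# Broadcast the server params to every client (in_axes=None) and map over the
# leading client axis of x and y (in_axes=0); compile the batched update once.
batched_client_update = jax.jit(jax.vmap(client_update, in_axes=(None, 0, 0)))


def fedavg_round_vmap(server_params, xs, ys):
    # xs: [num_clients, n, d], ys: [num_clients, n]. All clients hold n
    # examples, so the example-weighted FedAvg mean is a plain mean.
    client_params = batched_client_update(server_params, xs, ys)
    return jax.tree_util.tree_map(lambda p: jnp.mean(p, axis=0), client_params)


num_clients, n, d = 8, 32, 3
xs = jax.random.normal(jax.random.PRNGKey(0), (num_clients, n, d))
true_w = jnp.array([1.0, -2.0, 0.5])
ys = xs @ true_w + 1.0

init_params = {"w": jnp.zeros(d), "b": jnp.array(0.0)}
params = init_params
for _ in range(50):
    params = fedavg_round_vmap(params, xs, ys)
```

The batched update computes the same per-client result as calling `client_update` on each client in a Python loop; it only changes how the work is scheduled. The trade-off is that `vmap` needs identical array shapes across clients, so real federated datasets with uneven client sizes would typically need padding and masking, which this sketch omits.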
Taxonomy
Topics: Privacy-Preserving Technologies in Data · Artificial Intelligence in Healthcare and Education
