Rieoptax: Riemannian Optimization in JAX
Saiteja Utpala, Andi Han, Pratik Jawanpuria, Bamdev Mishra

TL;DR
Rieoptax is an open source Python library for efficient Riemannian optimization in JAX. It supports stochastic optimization algorithms and differentially private methods, and its geometric primitives are usually faster than those of existing Python frameworks.
Contribution
The paper introduces Rieoptax, a new library that offers faster Riemannian geometric primitives and supports advanced stochastic and private optimization methods in JAX.
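To make "geometric primitives" concrete, here is a minimal JAX sketch of the Riemannian exponential and logarithm maps on the unit sphere, chosen only as an illustrative manifold. The closed-form expressions are the standard ones for the sphere. The function names `sphere_exp` and `sphere_log` are hypothetical and are not Rieoptax's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration only; not the Rieoptax API.
_EPS = 1e-12


@jax.jit
def sphere_exp(x, v):
    """Exponential map on the unit sphere: follow the geodesic from x with initial velocity v."""
    norm_v = jnp.linalg.norm(v)
    # Avoid dividing by zero when v is (numerically) the zero tangent vector.
    safe_norm = jnp.where(norm_v > _EPS, norm_v, 1.0)
    moved = jnp.cos(norm_v) * x + jnp.sin(norm_v) * v / safe_norm
    return jnp.where(norm_v > _EPS, moved, x)


@jax.jit
def sphere_log(x, y):
    """Logarithm map: tangent vector at x whose exponential reaches y (defined for y != -x)."""
    # Component of y orthogonal to x, i.e. its projection onto the tangent space at x.
    u = y - jnp.dot(x, y) * x
    norm_u = jnp.linalg.norm(u)
    # Geodesic distance between x and y; clipping guards arccos against round-off.
    theta = jnp.arccos(jnp.clip(jnp.dot(x, y), -1.0, 1.0))
    safe_norm = jnp.where(norm_u > _EPS, norm_u, 1.0)
    return jnp.where(norm_u > _EPS, theta * u / safe_norm, jnp.zeros_like(x))


# Usage: map y into the tangent space at x and back again.
x = jnp.array([1.0, 0.0, 0.0])
y = jnp.array([0.0, 1.0, 0.0])
v = sphere_log(x, y)       # approximately [0, pi/2, 0]
y_back = sphere_exp(x, v)  # approximately y

# Batched version: many (base point, target) pairs at once.
batched_log = jax.vmap(sphere_log)
xs = jnp.stack([x, x])
ys = jnp.stack([y, -y])
vs = batched_log(xs, ys)
```

Because both functions are pure `jax.numpy` code, they can be compiled with `jax.jit` and vectorized with `jax.vmap`. These are the kinds of JAX features a library like this can build on for speed, though the sketch makes no claim about how Rieoptax itself is implemented.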
Findings
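Usually faster computation of Riemannian exponential and logarithm maps than in existing Python frameworks, on both CPU and GPU.
Supports basic and advanced stochastic solvers, including Riemannian stochastic gradient, stochastic variance reduction, and adaptive gradient methods.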
Includes differentially private Riemannian optimization.
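The simplest of the stochastic solvers listed above, Riemannian stochastic gradient descent, can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not Rieoptax's implementation:

- The manifold is the unit sphere.
- The Euclidean gradient is converted to a Riemannian gradient by orthogonal projection onto the tangent space.
- Each step follows the exponential map.
- The toy objective is the negative Rayleigh quotient, whose minimizer on the sphere is the leading eigenvector of the data covariance.

The objective and the names `rsgd`, `loss`, and `riemannian_grad` are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical sketch of Riemannian SGD on the unit sphere; not the Rieoptax API.


def sphere_exp(x, v):
    """Exponential map on the unit sphere."""
    norm_v = jnp.linalg.norm(v)
    safe_norm = jnp.where(norm_v > 1e-12, norm_v, 1.0)
    moved = jnp.cos(norm_v) * x + jnp.sin(norm_v) * v / safe_norm
    return jnp.where(norm_v > 1e-12, moved, x)


def loss(x, batch):
    """Negative Rayleigh quotient: minimized by the top eigenvector of the batch covariance."""
    return -jnp.mean((batch @ x) ** 2)


def riemannian_grad(x, egrad):
    """Project the Euclidean gradient onto the tangent space at x."""
    return egrad - jnp.dot(x, egrad) * x


@functools.partial(jax.jit, static_argnames=("batch_size", "num_steps"))
def rsgd(x0, data, key, lr=0.01, batch_size=32, num_steps=500):
    def step(x, step_key):
        # Sample a minibatch (with replacement) for the stochastic gradient.
        idx = jax.random.randint(step_key, (batch_size,), 0, data.shape[0])
        egrad = jax.grad(loss)(x, data[idx])
        # Move along the geodesic in the negative Riemannian gradient direction.
        return sphere_exp(x, -lr * riemannian_grad(x, egrad)), None

    keys = jax.random.split(key, num_steps)
    x_final, _ = jax.lax.scan(step, x0, keys)
    return x_final


# Usage: recover the leading principal direction of synthetic data.
k_data, k_init, k_opt = jax.random.split(jax.random.PRNGKey(0), 3)
data = jax.random.normal(k_data, (500, 5)) * jnp.array([3.0, 1.0, 1.0, 1.0, 1.0])
x0 = jax.random.normal(k_init, (5,))
x0 = x0 / jnp.linalg.norm(x0)
x_hat = rsgd(x0, data, k_opt)

# Compare with the top eigenvector of the empirical covariance (eigh sorts eigenvalues ascending).
_, eigvecs = jnp.linalg.eigh(data.T @ data / data.shape[0])
alignment = jnp.abs(jnp.dot(x_hat, eigvecs[:, -1]))
```

Variance-reduced and adaptive methods change how the update direction is computed. Note that on a general manifold they also need to move information between tangent spaces (for example via parallel transport or vector transport), which this plain SGD sketch does not require.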
Abstract
We present Rieoptax, an open source Python library for Riemannian optimization in JAX. We show that many differential geometric primitives, such as Riemannian exponential and logarithm maps, are usually faster in Rieoptax than in existing Python frameworks, both on CPU and GPU. We support a wide range of basic and advanced stochastic optimization solvers, such as Riemannian stochastic gradient, stochastic variance reduction, and adaptive gradient methods. A distinguishing feature of the proposed toolbox is that we also support differentially private optimization on Riemannian manifolds.
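Differentially private optimization on a manifold can be illustrated by adapting the familiar DP-SGD recipe (per-example gradient clipping plus Gaussian noise) to the tangent space. The sketch below is a minimal illustration under explicit assumptions, not the method or API of Rieoptax. It uses the unit sphere and proceeds in four steps:

- Compute each per-example Riemannian gradient.
- Clip it to norm `clip_norm`.
- Add isotropic Gaussian noise in the ambient space and project the result onto the tangent space.
- Take an exponential-map step.

The names `dp_rsgd`, `clip_to_norm`, and `per_example_rgrad`, and all parameter values, are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical DP Riemannian SGD sketch on the unit sphere; not the Rieoptax API.


def sphere_exp(x, v):
    """Exponential map on the unit sphere."""
    norm_v = jnp.linalg.norm(v)
    safe_norm = jnp.where(norm_v > 1e-12, norm_v, 1.0)
    moved = jnp.cos(norm_v) * x + jnp.sin(norm_v) * v / safe_norm
    return jnp.where(norm_v > 1e-12, moved, x)


def tangent_project(x, v):
    """Orthogonal projection of an ambient vector onto the tangent space at x."""
    return v - jnp.dot(x, v) * x


def per_example_rgrad(x, z):
    """Riemannian gradient of the single-example loss -(z . x)^2."""
    egrad = jax.grad(lambda p: -jnp.dot(z, p) ** 2)(x)
    return tangent_project(x, egrad)


def clip_to_norm(v, clip_norm):
    """Scale v down so that its norm is at most clip_norm."""
    norm = jnp.linalg.norm(v)
    return v * jnp.minimum(1.0, clip_norm / jnp.maximum(norm, 1e-12))


@functools.partial(jax.jit, static_argnames=("batch_size", "num_steps"))
def dp_rsgd(x0, data, key, lr=0.05, clip_norm=1.0, noise_multiplier=1.0,
            batch_size=64, num_steps=300):
    def step(x, step_key):
        k_batch, k_noise = jax.random.split(step_key)
        idx = jax.random.randint(k_batch, (batch_size,), 0, data.shape[0])
        # Per-example Riemannian gradients, each clipped to bound one example's influence.
        grads = jax.vmap(per_example_rgrad, in_axes=(None, 0))(x, data[idx])
        clipped = jax.vmap(clip_to_norm, in_axes=(0, None))(grads, clip_norm)
        # Gaussian noise scaled to the clipping norm, then projected onto the tangent space.
        noise = noise_multiplier * clip_norm * jax.random.normal(k_noise, x.shape)
        noisy_grad = tangent_project(x, clipped.sum(axis=0) + noise) / batch_size
        return sphere_exp(x, -lr * noisy_grad), None

    keys = jax.random.split(key, num_steps)
    x_final, _ = jax.lax.scan(step, x0, keys)
    return x_final


# Usage on synthetic data.
k_data, k_init, k_opt = jax.random.split(jax.random.PRNGKey(1), 3)
data = jax.random.normal(k_data, (1000, 5)) * jnp.array([3.0, 1.0, 1.0, 1.0, 1.0])
x0 = jax.random.normal(k_init, (5,))
x0 = x0 / jnp.linalg.norm(x0)
x_private = dp_rsgd(x0, data, k_opt)
```

Clipping bounds each example's contribution to the summed gradient by `clip_norm`. The noise is added before projecting, so the projection is post-processing of the noisy sum. Turning `noise_multiplier`, the sampling scheme, and `num_steps` into an (epsilon, delta) guarantee requires a privacy accountant, which this sketch omits. Standard subsampling analyses also typically assume Poisson or without-replacement sampling rather than the with-replacement sampling used here.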
Taxonomy
Topics: Stochastic Gradient Optimization Techniques
Methods: Lib
