JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented Fine-Tuning
Anique Tahir, Lu Cheng, and Huan Liu

TL;DR
This paper introduces JORA, a JAX-based framework for efficient, scalable fine-tuning of large language models such as Llama-2 on retrieval-augmented tasks, significantly reducing memory use and training time.
Contribution
It presents a novel JAX-based tensor-parallel LoRA library that enables efficient, distributed fine-tuning of LLMs for retrieval tasks, addressing memory constraints.
Findings
Over 12x faster runtime compared to existing methods
Uses less than half the VRAM per GPU
Enables fine-tuning on systems with limited GPU resources
Abstract
The scaling of Large Language Models (LLMs) for retrieval-based tasks, particularly in Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning extensive prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs but fall short of accommodating the efficient parameter distribution required for retrieved context. Addressing this gap, we introduce a novel framework for PEFT-compatible fine-tuning of Llama-2 models, leveraging distributed training. Our framework uniquely utilizes JAX's just-in-time (JIT) compilation and tensor-sharding for efficient resource management, thereby enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with…
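The combination the abstract describes (a LoRA adapter, JIT compilation, and tensor sharding) can be illustrated with a minimal JAX sketch. This is not JORA's API: the names `init_lora` and `lora_linear`, the mesh axis name `"model"`, and all shapes are hypothetical, chosen only for illustration. The sketch assumes a column-wise split of the frozen weight across whatever devices are available and also runs on a single device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_lora(key, d_in, d_out, rank):
    """Frozen base weight W plus low-rank factors A and B (hypothetical helper)."""
    k_w, k_a = jax.random.split(key)
    W = jax.random.normal(k_w, (d_in, d_out)) / jnp.sqrt(d_in)
    A = jax.random.normal(k_a, (d_in, rank)) * 0.01
    B = jnp.zeros((rank, d_out))  # zero init: the adapter starts as a no-op
    return W, A, B


@jax.jit
def lora_linear(x, W, A, B, scale):
    # y = x W + scale * (x A) B; the low-rank path never materializes A @ B
    return x @ W + scale * ((x @ A) @ B)


# One-dimensional device mesh; "model" is an assumed axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))
n_dev = devices.size

# Make the output dimension divisible by the device count.
d_in, d_out, rank = 16, 8 * n_dev, 4
W, A, B = init_lora(jax.random.PRNGKey(0), d_in, d_out, rank)

col_sharded = NamedSharding(mesh, P(None, "model"))  # split output columns
replicated = NamedSharding(mesh, P())                # full copy on every device

W = jax.device_put(W, col_sharded)
B = jax.device_put(B, col_sharded)
A = jax.device_put(A, replicated)
x = jax.device_put(jnp.ones((2, d_in)), replicated)

# jit compiles the layer once and follows the input shardings.
y = lora_linear(x, W, A, B, 2.0)
print(y.shape, y.sharding)
```

In this sketch each device holds only its slice of `W` and `B`, which is the general way tensor sharding lowers per-device memory. In a fine-tuning loop only `A` and `B` would receive gradients while `W` stays frozen.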
Taxonomy
Topics: Topic Modeling · Computational Physics and Python Applications · Tensor decomposition and applications
Methods: Attention Is All You Need · Linear Warmup With Linear Decay · Dropout · Byte Pair Encoding · Residual Connection · Linear Layer · Dense Connections · Adam · Attention Dropout
