Joint Prompt Optimization of Stacked LLMs using Variational Inference
Alessandro Sordoni, Xingdi Yuan, Marc-Alexandre Côté, Matheus Pereira, Adam Trischler, Ziang Xiao, Arian Hosseini, Friederike Niedtner, Nicolas Le Roux

TL;DR
This paper introduces a novel method for optimizing prompts in stacked large language models using variational inference, enabling multi-layer prompt learning that improves performance on reasoning and understanding tasks.
Contribution
It presents a new approach to joint prompt optimization in stacked LLMs using variational inference, extending prompt optimization from single-layer to multi-layer language networks.
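For the two-layer case, the variational idea can be sketched as a standard evidence lower bound. The notation here is assumed for illustration and does not come from this page: x is the input, h is the first layer's output (treated as a latent variable), y is the final output, π₀ and π₁ are the two prompts, and q is a proposal distribution over h. The paper's exact objective and its choice of q may differ.

```latex
\log p(y \mid x; \pi_0, \pi_1)
  = \log \sum_{h} p(y \mid h; \pi_1)\, p(h \mid x; \pi_0)
  \;\ge\; \mathbb{E}_{q(h)}\!\left[\log p(y \mid h; \pi_1) + \log p(h \mid x; \pi_0) - \log q(h)\right]
```

By Jensen's inequality the bound holds for any q whose support covers the terms of the sum, and it is tight when q(h) equals the true posterior p(h | x, y; π₀, π₁). Because summing over every possible string h is intractable, maximizing such a bound jointly over π₀ and π₁ is a natural surrogate for maximizing the marginal likelihood.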
Findings
DLN-1 performs well on reasoning tasks
DLN-2 outperforms single-layer models
DLN-2 is sometimes comparable to few-shot GPT-4, even when each LLM in the network is smaller and less powerful
Abstract
Large language models (LLMs) can be seen as atomic units of computation mapping sequences to a distribution over sequences. Thus, they can be seen as stochastic language layers in a language network, where the learnable parameters are the natural language prompts at each layer. By stacking two such layers and feeding the output of one layer to the next, we obtain a Deep Language Network (DLN). We first show how to effectively perform prompt optimization for a 1-Layer language network (DLN-1). Then, we present an extension that applies to 2-layer DLNs (DLN-2), where two prompts must be learned. The key idea is to consider the output of the first layer as a latent variable, which requires inference, and prompts to be learned as the parameters of the generative distribution. We first test the effectiveness of DLN-1 in multiple reasoning and natural language understanding tasks. Then, we…
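The abstract's framing, where each LLM is a stochastic layer parameterized by its prompt and the first layer's output is a latent variable, can be illustrated with a toy, self-contained sketch. Everything in it is hypothetical. `toy_layer0` and `toy_layer1` stand in for real LLM calls and use hard-coded output distributions. The candidate prompts and the task are made up. Prompts are chosen by brute-force search over a small candidate set using an exact ELBO, with q defaulting to the first layer's own output distribution. This is not the authors' implementation, their proposal distribution, or their search procedure.

```python
import math
from itertools import product
from typing import Callable, Dict, List, Optional, Tuple

Dist = Dict[str, float]              # output string -> probability
Layer = Callable[[str, str], Dist]   # (prompt, input) -> distribution over outputs


def _dist(pairs: List[Tuple[str, float]]) -> Dist:
    """Build a distribution, merging duplicate outputs."""
    d: Dist = {}
    for out, p in pairs:
        d[out] = d.get(out, 0.0) + p
    return d


def toy_layer0(prompt: str, x: str) -> Dist:
    """Hypothetical stand-in for an LLM call: rewrites an integer per the prompt."""
    n = int(x)
    if prompt == "add one":
        return _dist([(str(n + 1), 0.9), (str(n), 0.1)])
    if prompt == "double it":
        return _dist([(str(2 * n), 0.9), (str(n), 0.1)])
    return {str(n): 1.0}  # unrecognized prompt: copy the input


def toy_layer1(prompt: str, h: str) -> Dist:
    """Hypothetical stand-in for an LLM call: answers a yes/no question about h."""
    n = int(h)
    if prompt == "is it even?":
        answer = "yes" if n % 2 == 0 else "no"
    elif prompt == "is it positive?":
        answer = "yes" if n > 0 else "no"
    else:
        return {"yes": 0.5, "no": 0.5}
    other = "no" if answer == "yes" else "yes"
    return {answer: 0.95, other: 0.05}


def log_marginal(x, y, pi0, pi1, layer0: Layer, layer1: Layer) -> float:
    """Exact log p(y | x) = log sum_h p(y | h; pi1) p(h | x; pi0)."""
    return math.log(sum(p_h * layer1(pi1, h).get(y, 0.0)
                        for h, p_h in layer0(pi0, x).items()))


def posterior(x, y, pi0, pi1, layer0: Layer, layer1: Layer) -> Dist:
    """Exact posterior p(h | x, y), feasible only because the toy h-space is tiny."""
    joint = {h: p_h * layer1(pi1, h).get(y, 0.0) for h, p_h in layer0(pi0, x).items()}
    z = sum(joint.values())
    return {h: v / z for h, v in joint.items()}


def elbo(x, y, pi0, pi1, layer0: Layer, layer1: Layer, q: Optional[Dist] = None) -> float:
    """E_q[log p(y|h; pi1) + log p(h|x; pi0) - log q(h)]; q defaults to p(h|x; pi0)."""
    prior = layer0(pi0, x)
    q = prior if q is None else q
    total = 0.0
    for h, q_h in q.items():
        if q_h == 0.0:
            continue
        log_lik = math.log(layer1(pi1, h).get(y, 1e-12))  # floor avoids log(0)
        log_prior = math.log(prior.get(h, 1e-12))
        total += q_h * (log_lik + log_prior - math.log(q_h))
    return total


def search_prompts(data, cands0, cands1, layer0: Layer, layer1: Layer):
    """Jointly pick (pi0, pi1) maximizing the summed ELBO by brute force."""
    best, best_score = None, -math.inf
    for pi0, pi1 in product(cands0, cands1):
        score = sum(elbo(x, y, pi0, pi1, layer0, layer1) for x, y in data)
        if score > best_score:
            best, best_score = (pi0, pi1), score
    return best, best_score


# Hypothetical task: answer "yes" iff x + 1 is even.
DATA = [("1", "yes"), ("2", "no"), ("3", "yes"), ("4", "no")]
CANDS0 = ["double it", "add one"]
CANDS1 = ["is it even?", "is it positive?"]

if __name__ == "__main__":
    best, score = search_prompts(DATA, CANDS0, CANDS1, toy_layer0, toy_layer1)
    print(best, round(score, 3))
```

With q set to the first layer's own distribution, the log-prior and log-q terms cancel, so the bound reduces to the expected log-likelihood of the second layer. Passing the exact `posterior` as q instead makes the bound equal to `log_marginal`. That is only possible here because the toy hidden space has two outcomes per input. With real LLMs, h ranges over free text, so q has to be approximated, for example by sampling.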
Taxonomy
Topics: Topic Modeling · Natural Language Processing Techniques · Speech Recognition and Synthesis
Methods: Attention Is All You Need · Linear Layer · Position-Wise Feed-Forward Layer · Absolute Position Encodings · Multi-Head Attention · Layer Normalization · Label Smoothing · Adam · Byte Pair Encoding · Residual Connection
