MUX-PLMs: Data Multiplexing for High-throughput Language Models
Vishvak Murahari, Ameet Deshpande, Carlos E. Jimenez, Izhak Shafran, Mingqiu Wang, Yuan Cao, Karthik Narasimhan

TL;DR
This paper introduces MUX-PLMs, a new class of high-throughput language models trained with data multiplexing, achieving 2-5x inference speedup with minimal performance loss across various tasks.
Contribution
The paper develops novel multiplexing and demultiplexing modules enabling high-performance, high-throughput PLMs trained with data multiplexing, which are competitive with standard models.
Findings
Achieves 2x inference speedup on multiple tasks.
Achieves 5x inference speedup with minimal 1-4% performance drop.
Demonstrates broad applicability across different NLP tasks.
| Model | N | GLUE Mean (std) | GLUE Max | Token Mean (std) | Token Max | Throughput (×) |
|---|---|---|---|---|---|---|
| BERT | 1 | 85.4 (0.0) | 85.4 | 95.8 (0.0) | 95.8 | 1.0 |
| ELECTRA | 1 | 82.1 (0.0) | 82.1 | 95.3 (0.0) | 95.3 | 1.0 |
| T-MUX | 2 | 60.4 (0.6) | 61.8 | 81.4 (0.1) | 81.5 | 1.9 |
| MUX-BERT‡ | 2 | 82.5 (0.6) | 83.4 | 95.2 (0.1) | 95.4 | 2.0 |
| MUX-ELEC‡ | 2 | 82.5 (0.4) | 83.1 | 95.0 (0.0) | 95.1 | 2.0 |
| T-MUX | 5 | 59.7 (0.6) | 60.6 | 81.3 (0.2) | 81.5 | 4.4 |
| MUX-BERT‡ | 5 | 80.3 (0.4) | 80.9 | 93.6 (0.1) | 93.6 | 4.9 |
| MUX-ELEC‡ | 5 | 79.8 (0.6) | 80.5 | 93.4 (0.0) | 93.5 | 4.9 |
| T-MUX | 10 | 58.1 (0.5) | 59.1 | 79.7 (0.2) | 80.0 | 8.4 |
| MUX-BERT‡ | 10 | 77.8 (0.6) | 78.8 | 91.6 (0.1) | 91.8 | 9.8 |
| MUX-ELEC‡ | 10 | 78.2 (0.6) | 79.0 | 91.7 (0.1) | 91.8 | 9.7 |
| Model | QNLI | QQP | SST2 |
|---|---|---|---|
| BERT | 90.5 | 91.2 | 91.7 |
| MUX-BERT (N=2) | 88.2 | 90.4 | 90.6 |
| MUX-BERT (N=5) | 85.6 | 88.8 | 86.9 |
| Use additional unlabelled or task-specific data | | | |
| DistilBERT6 | 89.2 | 88.5 | 91.3 |
| Block Pruning | 89.7 | - | 91.2 |
| Prune OFA | 90.3 | 91.2 | 91.5 |
| Hybrid Approaches | | | |
| TinyBERT6 | 91.1 | 91.1 | 93.0 |
| CoFi | 91.3 | - | 93.0 |
| AutoTinyBERT | 89.7 | 89.9 | 91.4 |
| MobileBERT | 91.0 | - | 92.1 |
| Config | Model | GLUE | Token | Throughput (×) |
|---|---|---|---|---|
| Small | BERT | 80.6 | 94.0 | 5.9 |
| Small | T-MUX | 59.5 | 81.8 | 8.7 |
| Small | MUX-BERT‡ | 79.0 | 93.3 | 11.5 |
| Base | BERT | 85.4 | 95.8 | 1.0 |
| Base | T-MUX | 60.4 | 81.4 | 1.9 |
| Base | MUX-BERT‡ | 82.5 | 95.2 | 2.0 |
| Large | BERT | 85.8 | 95.6 | 0.3 |
| Large | T-MUX | 61.7 | 80.9 | 0.6 |
| Large | MUX-BERT‡ | 84.1 | 95.2 | 0.6 |
| Model | Mux (N) | MNLI No Ens | MNLI Ens | MNLI Δ | QQP No Ens | QQP Ens | QQP Δ |
|---|---|---|---|---|---|---|---|
| MUX-BERT | 2 | 80.6 | 81.2 | +0.6 | 90.4 | 90.8 | +0.4 |
| MUX-BERT | 5 | 77.2 | 78.8 | +1.6 | 88.8 | 89.7 | +0.9 |
| MUX-BERT | 10 | 73.6 | 74.8 | +1.2 | 86.9 | 87.7 | +0.8 |
| MUX-ELEC | 2 | 80.3 | 80.8 | +0.5 | 90.6 | 90.9 | +0.3 |
| MUX-ELEC | 5 | 77.0 | 78.4 | +1.4 | 89.1 | 89.9 | +0.8 |
| MUX-ELEC | 10 | 74.6 | 76.0 | +1.4 | 87.6 | 88.3 | +0.7 |
| Mux (N) | Model | Mux | Demux | GLUE | Token |
|---|---|---|---|---|---|
| 2 | MUX-BERT | Non-contextual | RSA-DeMUX | 82.5 | 95.2 |
| 2 | Ablation 1 | Non-contextual | Prefix | 83.2 | 95.3 |
| 2 | Ablation 2 | Contextual | RSA-DeMUX | 82.3 | 95.3 |
| 5 | MUX-BERT | Non-contextual | RSA-DeMUX | 80.3 | 93.6 |
| 5 | Ablation 1 | Non-contextual | Prefix | 78.6 | 38.9 |
| 5 | Ablation 2 | Contextual | RSA-DeMUX | 76.8 | 94.2 |
| 10 | MUX-BERT | Non-contextual | RSA-DeMUX | 77.8 | 91.6 |
| 10 | Ablation 1 | Non-contextual | Prefix | 76.6 | 25.6 |
| 10 | Ablation 2 | Contextual | RSA-DeMUX | 76.0 | 93.3 |
| N | MUX-ELECTRA Best ticket | MUX-ELECTRA Worst ticket | MUX-ELECTRA Δ | MUX-BERT Best ticket | MUX-BERT Worst ticket | MUX-BERT Δ |
|---|---|---|---|---|---|---|
| 2 | 83.1 | 82.0 | 1.1 | 83.4 | 81.8 | 1.6 |
| 5 | 80.5 | 78.9 | 1.6 | 80.9 | 79.7 | 1.2 |
| 10 | 79.0 | 77.3 | 1.7 | 78.8 | 77.0 | 1.8 |
| Hyperparameter | MUX-BERT Small | MUX-BERT Base | MUX-BERT Large | MUX-ELECTRA Base |
|---|---|---|---|---|
| Number of layers | 4 | 12 | 24 | 12 |
| Hidden Size | 512 | 768 | 1024 | 768 |
| FFN intermediate hidden size | 2048 | 3072 | 4096 | 3072 |
| Attention heads | 8 | 12 | 16 | 12 |
| Attention head size | 64 | 64 | 64 | 64 |
| Mask percent | 15 | 15 | 15 | N/A |
| Learning Rate Decay | Linear | Linear | Linear | Linear |
| Warmup steps | 10000 | 10000 | 10000 | 10000 |
| Learning Rate | [1e-4, 5e-5] | [1e-4, 5e-5] | [1e-4, 5e-5] | [1e-4, 5e-5] |
| Adam ε | 1e-6 | 1e-6 | 1e-6 | 1e-6 |
| Adam β1 | 0.9 | 0.9 | 0.9 | 0.9 |
| Adam β2 | 0.999 | 0.999 | 0.999 | 0.999 |
| Attention Dropout | 0.1 | 0.1 | 0.1 | 0.1 |
| Dropout | 0.1 | 0.1 | 0.1 | 0.1 |
| Batch Size | 256 | 256 | 256 | 256 |
| Sequence Length | 512 | 512 | 512 | 512 |
| Train Steps | 1M | 1M | 1M | 1M |
| Hyperparameter | Value |
|---|---|
| Learning Rate | [2e-5, 5e-5] |
| Adam ε | 1e-8 |
| Adam β1 | 0.9 |
| Adam β2 | 0.999 |
| Learning rate decay | Linear |
| Warmup fraction | 0.1 |
| Attention Dropout | 0.1 |
| Dropout | 0.1 |
| Weight Decay | 0 |
| Batch Size | [32, 128] for Small/Base, [16, 64] for Large |
| Train Steps | 2000 for RTE and WNLI; 10000 for MRPC, COLA and STSB; 20000 for NER, SST2, QNLI and POS; [20000, 100000] for MNLI and QQP |
| Sequence Length | 128 |
| Model Size | N | MNLI | QQP | QNLI | MRPC | WNLI | STSB | RTE | SST2 | COLA | GLUE (all 9 tasks) | GLUE (excl. WNLI, COLA) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Small | 1 | 77.86±0.0 | 88.99±0.0 | 84.00±0.0 | 77.70±0.0 | 56.34±0.0 | 84.25±0.0 | 62.45±0.0 | 88.88±0.0 | 43.48±0.0 | 73.77 | 80.59 |
| Small | 2 | 75.09±0.1 | 88.88±0.1 | 84.31±0.2 | 79.75±0.7 | 50.99±8.1 | 82.65±0.3 | 55.52±1.5 | 87.04±0.7 | 30.64±1.7 | 70.54 | 79.03 |
| Small | 5 | 70.50±0.1 | 86.39±0.1 | 81.23±0.2 | 74.26±1.0 | 54.65±3.3 | 79.90±0.2 | 58.56±1.9 | 82.57±0.3 | 12.78±1.6 | 66.76 | 76.20 |
| Small | 10 | 61.98±0.1 | 80.85±0.1 | 63.47±0.3 | 70.69±0.9 | 56.62±4.3 | 36.93±1.0 | 53.57±1.8 | 80.39±0.4 | 1.10±2.2 | 56.18 | 63.98 |
| Base | 1 | 84.24±0.0 | 91.19±0.0 | 90.54±0.0 | 87.75±0.0 | 56.34±0.0 | 89.18±0.0 | 63.18±0.0 | 91.74±0.0 | 58.79±0.0 | 79.22 | 85.40 |
| Base | 2 | 80.59±0.1 | 90.36±0.1 | 88.17±0.1 | 83.77±1.4 | 50.70±7.0 | 85.84±0.1 | 58.19±1.6 | 90.62±0.6 | 55.61±1.6 | 75.98 | 82.51 |
| Base | 5 | 77.18±0.2 | 88.79±0.1 | 85.58±0.1 | 80.10±0.6 | 53.52±2.5 | 84.28±0.2 | 59.13±1.2 | 86.88±0.4 | 12.33±2.4 | 69.75 | 80.28 |
| Base | 10 | 73.62±0.3 | 86.94±0.1 | 82.08±0.3 | 78.63±0.6 | 52.68±6.0 | 81.62±0.2 | 58.27±2.4 | 83.44±0.6 | 0.00±0.0 | 66.36 | 77.80 |
| Large | 1 | 85.79±0.0 | 91.46±0.0 | 92.29±0.0 | 83.82±0.0 | 56.34±0.0 | 89.53±0.0 | 66.06±0.0 | 91.40±0.0 | 57.79±0.0 | 79.39 | 85.76 |
| Large | 2 | 83.23±0.2 | 90.85±0.1 | 90.66±0.2 | 84.90±0.8 | 56.34±0.0 | 88.22±0.2 | 59.21±0.9 | 91.38±0.4 | 57.89±1.5 | 78.08 | 84.06 |
| Large | 5 | 79.55±0.2 | 89.37±0.1 | 87.41±0.2 | 83.77±1.1 | 54.93±0.0 | 85.86±0.3 | 57.26±2.0 | 88.65±0.7 | 46.66±0.9 | 74.83 | 81.70 |
| Large | 10 | 35.45±0.0 | 63.18±0.0 | 50.54±0.0 | 68.38±0.0 | 56.90±5.2 | 82.81±0.2 | 52.13±1.9 | 50.92±0.0 | 1.87±4.6 | 51.35 | 57.63 |
| Model Size | N | MNLI | QQP | QNLI | MRPC | WNLI | STSB | RTE | SST2 | COLA | GLUE (all 9 tasks) | GLUE (excl. WNLI, COLA) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Small | 1 | 77.86 | 88.99 | 84.00 | 77.70 | 56.34 | 84.25 | 62.45 | 88.88 | 43.48 | 73.77 | 80.59 |
| Small | 2 | 75.21 | 89.01 | 84.61 | 80.64 | 61.97 | 82.97 | 58.12 | 87.84 | 33.08 | 72.61 | 79.77 |
| Small | 5 | 70.66 | 86.46 | 81.60 | 75.74 | 61.97 | 80.24 | 60.65 | 83.49 | 15.57 | 68.49 | 76.98 |
| Small | 10 | 62.17 | 80.93 | 63.85 | 71.81 | 63.38 | 38.20 | 55.96 | 80.96 | 2.63 | 57.77 | 64.84 |
| Base | 1 | 84.24 | 91.19 | 90.54 | 87.75 | 56.34 | 89.18 | 63.18 | 91.74 | 58.79 | 79.22 | 85.40 |
| Base | 2 | 80.82 | 90.47 | 88.28 | 86.03 | 66.20 | 86.06 | 60.65 | 91.51 | 56.93 | 78.55 | 83.40 |
| Base | 5 | 77.66 | 88.89 | 85.70 | 81.13 | 59.15 | 84.47 | 60.65 | 87.50 | 15.79 | 71.22 | 80.86 |
| Base | 10 | 74.04 | 87.03 | 82.45 | 79.41 | 63.38 | 81.89 | 62.45 | 84.29 | 0.00 | 68.33 | 78.79 |
| Large | 1 | 85.79 | 91.46 | 92.29 | 83.82 | 56.34 | 89.53 | 66.06 | 91.40 | 57.79 | 79.39 | 85.76 |
| Large | 2 | 83.40 | 90.94 | 90.96 | 86.27 | 56.34 | 88.50 | 60.29 | 91.86 | 60.50 | 78.78 | 84.60 |
| Large | 5 | 79.69 | 89.43 | 87.81 | 84.80 | 57.75 | 86.49 | 60.65 | 89.45 | 47.56 | 75.96 | 82.62 |
| Large | 10 | 35.46 | 63.18 | 50.89 | 68.38 | 61.97 | 83.04 | 55.60 | 50.92 | 7.55 | 53.00 | 58.21 |
| N | MNLI | QQP | QNLI | MRPC | WNLI | STSB | RTE | SST2 | COLA | GLUE (all 9 tasks) | GLUE (excl. WNLI, COLA) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 81.49±0.0 | 90.73±0.0 | 89.73±0.0 | 75.98±0.0 | 56.34±0.0 | 87.73±0.0 | 57.76±0.0 | 91.51±0.0 | 56.79±0.0 | 76.45 | 82.13 |
| 2 | 80.29±0.2 | 90.58±0.1 | 88.39±0.2 | 83.73±0.7 | 57.18±2.1 | 86.80±0.1 | 58.77±1.1 | 88.65±0.4 | 51.92±1.7 | 76.26 | 82.46 |
| 5 | 76.99±0.2 | 89.08±0.0 | 85.40±0.3 | 80.25±1.6 | 56.90±4.5 | 84.27±0.2 | 57.26±1.0 | 85.09±1.0 | 26.89±1.2 | 71.35 | 79.76 |
| 10 | 74.62±0.2 | 87.63±0.1 | 82.70±0.2 | 77.89±0.7 | 50.99±4.9 | 81.96±0.5 | 59.86±2.1 | 82.71±0.5 | 27.76±2.3 | 69.57 | 78.20 |
| N | Retrieval Rate | MNLI | QQP | QNLI | MRPC | WNLI | STSB | RTE | SST2 | COLA | GLUE (all 9 tasks) | GLUE (excl. WNLI, COLA) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2 | 0.0 | 83.23±0.2 | 90.85±0.1 | 90.66±0.2 | 84.90±0.8 | 56.34±0.0 | 88.22±0.2 | 59.21±0.9 | 91.38±0.4 | 57.89±1.5 | 78.08 | 84.06 |
| 2 | 0.1 | 83.55±0.3 | 90.90±0.1 | 90.58±0.2 | 85.49±1.1 | 56.34±0.0 | 88.28±0.2 | 57.76±1.4 | 90.69±0.8 | 59.36±1.4 | 78.11 | 83.89 |
| 2 | 0.2 | 83.50±0.1 | 90.96±0.1 | 90.69±0.2 | 84.95±0.5 | 56.34±0.0 | 88.28±0.2 | 58.34±1.6 | 90.69±0.5 | 59.17±1.5 | 78.10 | 83.92 |
| 2 | 0.5 | 83.41±0.2 | 90.91±0.0 | 90.47±0.1 | 85.25±0.5 | 56.34±0.0 | 88.02±0.1 | 59.35±1.6 | 89.52±0.6 | 59.41±2.0 | 78.08 | 83.85 |
| 5 | 0.0 | 79.55±0.2 | 89.37±0.1 | 87.41±0.2 | 83.77±1.1 | 54.93±0.0 | 85.86±0.3 | 57.26±2.0 | 88.65±0.7 | 46.66±0.9 | 74.83 | 81.70 |
| 5 | 0.1 | 79.49±0.1 | 89.34±0.1 | 87.25±0.3 | 81.81±1.3 | 53.24±1.6 | 85.80±0.2 | 55.60±2.4 | 88.19±0.7 | 47.60±1.0 | 74.26 | 81.07 |
| 5 | 0.2 | 79.37±0.1 | 89.42±0.1 | 87.23±0.3 | 82.40±1.1 | 54.93±0.0 | 85.85±0.2 | 55.38±2.6 | 87.84±0.8 | 43.58±1.2 | 74.00 | 81.07 |
| 5 | 0.5 | 79.24±0.1 | 89.30±0.1 | 87.21±0.3 | 82.06±1.7 | 56.34±0.0 | 85.97±0.2 | 52.27±4.0 | 88.58±0.6 | 47.01±2.3 | 74.22 | 80.66 |
| 10 | 0.0 | 35.45±0.0 | 63.18±0.0 | 50.54±0.0 | 68.38±0.0 | 56.90±5.2 | 82.81±0.2 | 52.13±1.9 | 50.92±0.0 | 1.87±4.6 | 51.35 | 57.63 |
| 10 | 0.1 | 35.45±0.0 | 63.18±0.0 | 50.65±0.2 | 68.38±0.0 | 54.93±5.0 | 4.45±1.5 | 51.48±2.4 | 50.92±0.0 | 1.34±1.8 | 42.31 | 46.36 |
| 10 | 0.2 | 35.45±0.0 | 63.18±0.0 | 50.21±0.5 | 68.43±0.8 | 54.65±4.2 | 0.23±1.5 | 52.35±2.0 | 51.72±0.4 | 0.29±2.7 | 41.83 | 45.94 |
| 10 | 0.5 | 35.45±0.0 | 63.18±0.0 | 50.43±0.4 | 68.38±0.0 | 56.06±0.6 | 82.01±0.6 | 52.71±0.0 | 50.92±0.0 | 1.51±1.7 | 51.18 | 57.58 |
| N | Mux Strategy | MNLI | QQP | QNLI | MRPC | WNLI | STSB | RTE | SST2 | COLA | GLUE (all 9 tasks) | GLUE (excl. WNLI, COLA) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2 | MUX-BERT | 80.59±0.1 | 90.36±0.1 | 88.17±0.1 | 83.77±1.4 | 50.70±7.0 | 85.84±0.1 | 58.19±1.6 | 90.62±0.6 | 55.61±1.6 | 75.98 | 82.51 |
| 2 | DataMUX | 81.64±0.2 | 90.67±0.1 | 88.39±0.2 | 84.17±0.4 | 56.34±0.0 | 86.36±0.2 | 60.87±0.7 | 90.50±0.4 | 53.74±1.0 | 76.96 | 83.23 |
| 2 | Attention | 81.32±0.2 | 90.65±0.0 | 88.77±0.1 | 80.88±0.6 | 56.34±0.0 | 86.25±0.1 | 56.90±1.2 | 91.06±0.2 | 47.15±1.1 | 75.48 | 82.26 |
| 5 | MUX-BERT | 77.18±0.2 | 88.79±0.1 | 85.58±0.1 | 80.10±0.6 | 53.52±2.5 | 84.28±0.2 | 59.13±1.2 | 86.88±0.4 | 12.33±2.4 | 69.75 | 80.28 |
| 5 | DataMUX | 76.32±0.1 | 89.13±0.1 | 84.22±0.3 | 78.38±0.9 | 59.44±3.5 | 81.78±0.4 | 54.15±1.3 | 86.17±0.4 | 28.32±0.8 | 70.88 | 78.59 |
| 5 | Attention | 77.16±0.1 | 88.71±0.0 | 84.33±0.1 | 70.49±0.6 | 54.08±3.2 | 80.37±0.3 | 54.44±2.5 | 81.95±0.3 | 34.67±1.2 | 69.58 | 76.78 |
| 10 | MUX-BERT | 73.62±0.3 | 86.94±0.1 | 82.08±0.3 | 78.63±0.6 | 52.68±6.0 | 81.62±0.2 | 58.27±2.4 | 83.44±0.6 | 0.00±0.0 | 66.36 | 77.80 |
| 10 | DataMUX | 72.74±0.1 | 87.88±0.1 | 82.28±0.2 | 77.30±0.5 | 56.34±0.0 | 78.07±0.4 | 55.31±1.2 | 82.36±0.3 | 13.56±3.0 | 67.32 | 76.56 |
| 10 | Attention | 71.83±0.2 | 88.00±0.0 | 81.46±0.2 | 73.53±0.5 | 53.24±5.4 | 82.95±0.2 | 52.71±0.0 | 81.28±0.4 | 32.84±0.6 | 68.65 | 75.97 |
| Model Size | N | MNLI | QQP | QNLI | MRPC | WNLI | STSB | RTE | SST2 | COLA | GLUE (all 9 tasks) | GLUE (excl. WNLI, COLA) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Small | 2 | 61.48±0.2 | 80.33±0.0 | 60.05±0.2 | 68.43±0.5 | 56.34±0.0 | 15.02±0.4 | 51.12±0.6 | 79.75±0.3 | 8.22±0.7 | 53.42 | 59.45 |
| Small | 5 | 58.35±0.2 | 77.50±0.1 | 57.17±0.3 | 68.38±0.0 | 56.34±0.0 | 11.31±0.3 | 51.70±1.3 | 77.78±0.3 | 6.02±0.7 | 51.62 | 57.46 |
| Small | 10 | 53.63±0.2 | 77.03±0.1 | 51.22±0.3 | 68.38±0.0 | 57.46±6.3 | 12.40±1.3 | 52.35±2.7 | 50.92±0.0 | 0.00±0.0 | 47.04 | 52.28 |
| Base | 2 | 63.29±0.3 | 81.42±0.1 | 60.35±0.4 | 68.38±0.2 | 56.90±5.8 | 17.65±1.0 | 51.19±1.7 | 80.78±0.5 | 9.62±1.5 | 54.40 | 60.44 |
| Base | 5 | 60.67±0.2 | 79.42±0.1 | 59.77±0.2 | 69.61±0.8 | 53.80±7.3 | 14.92±1.8 | 52.71±0.8 | 81.15±0.6 | 10.35±1.7 | 53.60 | 59.75 |
| Base | 10 | 59.07±0.2 | 78.22±0.1 | 57.99±0.5 | 68.38±0.0 | 60.28±3.0 | 11.83±0.6 | 53.07±1.1 | 78.35±1.1 | 7.40±1.7 | 52.73 | 58.13 |
| Large | 2 | 64.64±0.2 | 82.10±0.1 | 60.21±0.2 | 69.95±0.9 | 56.34±0.0 | 21.62±0.4 | 52.71±0.0 | 80.34±0.9 | 8.72±2.1 | 55.18 | 61.65 |
| Large | 5 | 60.78±0.3 | 78.56±0.1 | 60.19±0.3 | 69.51±0.5 | 56.34±0.0 | 17.33±1.1 | 52.71±0.0 | 78.28±0.8 | 10.63±2.7 | 53.81 | 59.62 |
| Large | 10 | 48.79±0.6 | 68.41±0.1 | 55.76±0.8 | 68.58±0.6 | 58.59±3.3 | 8.38±1.1 | 54.95±0.9 | 64.82±1.0 | 3.48±3.9 | 47.97 | 52.81 |
Vishvak Murahari¹, Ameet Deshpande¹, Carlos E. Jimenez¹, Izhak Shafran², Mingqiu Wang², Yuan Cao², Karthik Narasimhan¹
¹Princeton University, ²Google Brain
Abstract
The widespread adoption of large language models such as ChatGPT and Bard has led to unprecedented demand for these technologies. The burgeoning cost of inference for ever-increasing model sizes, coupled with hardware shortages, has limited affordable access and poses a pressing need for efficiency approaches geared towards high throughput and performance. Multi-input multi-output (MIMO) algorithms such as data multiplexing offer a promising solution with a many-fold increase in throughput by performing inference for multiple inputs at the cost of a single input. Yet these approaches are not currently performant enough to be deployed in modern systems. We change that by developing MUX-PLMs, a class of high-throughput pre-trained language models (PLMs) trained with data multiplexing, that can be fine-tuned for any downstream task to yield high throughput and high performance. Our novel multiplexing and demultiplexing modules proficiently entangle and disentangle inputs, and enable high-performance, high-throughput MUX-PLMs that are competitive with vanilla PLMs while achieving 2x/5x inference speedup with only a 1-4% performance drop on a broad suite of tasks. Code + Models: https://github.com/princeton-nlp/datamux-pretraining/.
1 Introduction
Language models like ChatGPT OpenAI (2023), PaLM Chowdhery et al. (2022), T5 Raffel et al. (2020), and CM3 Aghajanyan et al. (2022), have seen unprecedented adoption in diverse sectors ranging from education and healthcare to manufacturing and marketing. The proficiency of these tools has led to unprecedented demand for these models, with users facing frequent outages and capacity limits. Additionally, ever-increasing model sizes and hardware shortages have constrained models’ ability to handle a very high load of requests, thus limiting large-scale affordable access to these models. These trends bring into focus the need for high-throughput, high-performance, efficient, and environmentally responsible models that can be deployed at scale to meet the quickly growing demand.
Multi-input Multi-output architectures (MIMO) Havasi et al. (2021); Ramé et al. (2021); Murahari et al. (2022) are a promising hardware-agnostic and architecture-agnostic paradigm that performs inference for multiple inputs simultaneously at the cost of a single input. This efficiency paradigm is natively geared towards yielding high-throughput models, in addition to being complementary in approach and motivation to current efficiency methods such as pruning, quantization, and distillation. Interestingly, MIMO approaches are partly inspired by the human brain’s extraordinary ability to process multiple inputs and propagate information at a high bandwidth with a few neural codes Blumhagen et al. (2011); Akam and Kullmann (2014); Pirschel and Kretzberg (2016); Hong et al. (2016); Friedrich et al. (2004).
Murahari et al. (2022) introduced data multiplexing, a MIMO technique that can enable a many-fold increase in throughput. The method compresses different instances into a single “multiplexed” hidden representation before decompressing it into independent predictions. While they show the plausibility of MIMO training, their method leads to a significant drop in performance compared to state-of-the-art models.
In this work, we introduce MUX-PLMs, a class of high-throughput pre-trained language models trained in a MIMO fashion with data multiplexing to process multiple inputs (2-10) simultaneously with a forward pass over a single instance. MUX-PLMs offer a many-fold improvement in throughput over baseline pre-trained models while remaining within a few points of them on text classification and token classification tasks. For example, in Table 1, MUX-BERT at $N = 5$ reaches 4.9x the throughput of BERT while scoring 80.3 vs. 85.4 on GLUE and 93.6 vs. 95.8 on token classification. MUX-PLMs, like other pre-trained language models, provide general model initialization that can be fine-tuned for any downstream task. We demonstrate the effectiveness and generality of our MUX-PLMs class of pre-trained models by training MUX-BERT and MUX-ELECTRA models, which are trained with pre-training objectives adapted from BERT Devlin et al. (2019) and ELECTRA Clark et al. (2020) respectively, although in a MIMO fashion with data multiplexing.
Our work is the first to introduce MIMO architectures to PLMs. To enable this, we first develop a new demultiplexing module, RSA-demux (Figure 2), that randomly initializes and learns private key vectors to recover the multiple outputs from a multiplexed representation. Secondly, we introduce a new Contextual Multiplexer module (Figure 3) that uses a cross-instance attention-based mechanism to aggregate context across the set of multiplexed instances, which seems to be particularly effective for token-level tasks. Thirdly, our three-stage training algorithm (Figure 1) enables stable and efficient training of MUX-PLMs.
Importantly, MUX-PLMs are complementary to existing state-of-the-art model compression techniques. We hope our work validates MIMO architectures as a promising complementary direction to existing efficiency techniques. Consequently, we hope future research develops MIMO architectures in tandem with other efficiency approaches, leveraging the best of both paradigms. We publicly release our models and code to promote open-source research on the next generation of MIMO architectures for large language models.
2 Related Work
Efficient Inference with Transformers
Recent methods in NLP rely heavily on transfer learning through Transformer-based Vaswani et al. (2017) language models trained on large text corpora using self-supervised objectives, such as autoregressive Radford and Narasimhan (2018) or masked language modeling Devlin et al. (2019). Prior analysis on pre-training language models has observed power-law scaling of model performance with respect to model size Kaplan et al. (2020), leading the community to develop ever-larger language models. It is also generally recognized that pre-trained language models are significantly over-parameterized, effectively learning a subnetwork that utilizes only a relatively small number of their total parameters Voita et al. (2019); Michel et al. (2019); Gordon et al. (2020); Prasanna et al. (2020).
The ubiquity of pre-trained language models, their growing size, and their over-parameterization have inspired extensive research on improving inference efficiency. This includes methods such as structured pruning Liu et al. (2019); Wang et al. (2020); Lagunas et al. (2021); Xia et al. (2022); Yang et al. (2022), knowledge distillation Hinton et al. (2015); Sanh et al. (2019); Sun et al. (2020); Jiao et al. (2020); Yin et al. (2021), quantization Zafrir et al. (2019); Shen et al. (2020), and data multiplexing Murahari et al. (2022). These approaches assume that PLMs are highly over-parameterized and attempt to approximate a large function by learning a smaller, compressed version of the original model. Past work has also focused on unstructured pruning for both task fine-tuning Chen et al. (2020); Sanh et al. (2020) and pre-trained Zafrir et al. (2021); Jiang et al. (2022) language model settings, but these methods do not increase model throughput due to hardware limits.
Multi-input Multi-output Models
While pruning, quantization, and distillation seek to reduce overparameterization by reducing models’ representational capacity, other lines of work seek to exploit overparameterization in other ways. Multi-input Multi-output (MIMO) architectures Havasi et al. (2021); Ramé et al. (2021); Murahari et al. (2022) train models using mixed-instance representations, i.e. Zhang et al. (2018), in order to obtain predictions for multiple instances simultaneously. Unlike efficiency methods, Havasi et al. (2021) and Ramé et al. (2021) try to obtain better performance by inducing multiple subnetworks in a single convolutional model to perform “ensembling for free” during inference. Data multiplexing, introduced in DataMUX Murahari et al. (2022), aims to improve model efficiency by training Transformer models with mixed-instance representations to perform simultaneous inference for language tasks, thereby improving inference throughput many-fold. Currently, MIMO architectures have only been used in a limited setting, achieving middling performance. Our work, which trains PLMs with data multiplexing, dramatically improves inference throughput while preserving high accuracy for downstream tasks.
3 Methodology
We briefly introduce readers to the data multiplexing MIMO architecture Murahari et al. (2022), which we denote T-MUX. We then detail our novel approach to train MUX-PLMs to yield high-throughput and performant language models.
3.1 T-MUX: Data multiplexing with Transformer
Data multiplexing allows parallel processing of multiple sequences with a single forward or backward pass through the model ($M$) and requires two crucial components: a multiplexer and a demultiplexer.
Multiplexer
The multiplexer module (MUX) combines an ordered set of multiple inputs $X^{1:N} = (\mathbf{x}^1, \dots, \mathbf{x}^N)$ into a single superimposed representation ($\mathbf{x}^{\text{MUX}}$). If $\mathbf{x}^i \in \mathbb{R}^d$, the multiplexer is a transformation ($\text{MUX}: \mathbb{R}^{N \times d} \rightarrow \mathbb{R}^d$) such that $\mathbf{x}^{\text{MUX}} = \text{MUX}\left(X^{1:N}\right)$.
To ensure MUX is an order-preserving transformation, T-MUX samples a vector ($\mathbf{v}^i \in \mathbb{R}^d$) from a standard multivariate Gaussian and applies the Hadamard product (element-wise multiplication) with the corresponding input representation ($\mathbf{x}^i$) before summing up vectors for all positions.
[TABLE]
The model processes the multiplexed representation and emits a multiplexed hidden state, $\mathbf{h}^{\text{MUX}} = M\left(\mathbf{x}^{\text{MUX}}\right)$. To multiplex Transformers’ sequenced inputs of length $L$, T-MUX applies the same Gaussian vector $\mathbf{v}^i$ to all $L$ positions of instance $i$.
[TABLE]
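The multiplexing step described above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the authors' implementation: the function name `multiplex`, the tensor sizes, the choice to sample the Gaussian vectors once and hold them fixed, and the averaging over instances (instead of a plain sum) are all assumptions made for the example.

```python
import numpy as np

def multiplex(x, v):
    """Superimpose N instances into one sequence.

    x: (N, L, d) input representations, one length-L sequence per instance.
    v: (N, d) Gaussian vectors, one per instance index.
    Returns an (L, d) multiplexed representation.
    """
    # Broadcast v[i] over all L positions of instance i (Hadamard product),
    # then average over the N instances (averaging is an assumption here).
    return (x * v[:, None, :]).mean(axis=0)

rng = np.random.default_rng(0)
N, L, d = 2, 4, 8                    # illustrative sizes, not the paper's
v = rng.standard_normal((N, d))      # assumed: sampled once, then held fixed
x = rng.standard_normal((N, L, d))

x_mux = multiplex(x, v)
x_swapped = multiplex(x[::-1], v)    # same instances, opposite order
print(x_mux.shape)
```

Because each instance index has its own vector, feeding the same instances in a different order yields a different superimposed representation, which is what makes the transformation order-preserving.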
Demultiplexer
A prediction needs to be made for each instance in $X^{1:N}$, and T-MUX’s demultiplexer module (DeMUX) achieves this by separating the superimposed output ($\mathbf{h}^{\text{MUX}}$) into $N$ output representations corresponding to the input ($\mathbf{h}^1, \dots, \mathbf{h}^N$).
[TABLE]
The vector $\mathbf{p}^i$ used for demultiplexing is dynamically generated for each instance ($i$) with the help of a prefix that is added to the input and re-used for all positions in the instance. They add a $\text{prefix}^i$ to $\mathbf{x}^i$, represented by the following pattern, where $\epsilon^i$ is a special token, and $\mathbf{p}^i$ is set to be the output corresponding to token $i$ in the prefix.
[TABLE]
3.2 MUX-PLMs: Data multiplexing for high-throughput language models
We propose MUX-PLMs, a class of high-throughput pre-trained Transformer-based language models trained in a MIMO fashion with data multiplexing. To demonstrate the viability and the generality of this class of models, we pre-train Transformer models with objectives based on BERT and ELECTRA, to get MUX-BERT and MUX-ELECTRA respectively. MUX-PLMs are trained with our three-stage training algorithm (Figure 1). First, MUX-PLMs are trained with the token retrieval task in T-MUX, which is an auto-encoding setup to decode all the tokens in the multiplexed input. This simple auto-encoding task is critical to prime the model for MIMO-style data multiplexing. Second, the MUX-PLMs are pre-trained with standard pre-training objectives, adapted to MIMO-fashioned training with data multiplexing. MUX-PLMs show significant throughput improvement over standard pre-trained LMs while staying close to their downstream task accuracies. Finally, MUX-PLMs, like other pre-trained language models, provide general model initialization that can be fine-tuned for any downstream task.
Contextual multiplexer
T-MUX’s multiplexer multiplexes tokens independently of 1) tokens in the same position in other instances and 2) other tokens in the same instance, which could lead to suboptimal representations. We therefore explore a contextual multiplexing scheme that aggregates context both from tokens in the same instance and tokens in the same position of other instances (Figure 3). We first use a single transformer layer to get contextual representations of each instance, of length $L$. We then apply a Hadamard product with a multivariate Gaussian vector $\mathbf{v}^i$ to all $L$ positions.
[TABLE]
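We generate multiplexed representations, $\mathbf{x}^{\text{MUX}}$, by applying another transformer layer across the encoded representations from the $N$ instances at each position from $1$ to $L$. This is done by transposing the contextual representations, so that attention runs across instances rather than across positions, and applying the second transformer layer.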
[TABLE]
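The two-step contextual multiplexer can be sketched as follows. This is a hedged illustration only: a bare single-head self-attention stands in for each transformer layer, the helper names and weight shapes are hypothetical, and the final average over instances is an assumed reduction, since the exact form is not recoverable from the text above.

```python
import numpy as np

def self_attention(h, wq, wk, wv):
    """Single-head scaled dot-product self-attention over the second-to-last axis."""
    q, k, val = h @ wq, h @ wk, h @ wv
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ val

def contextual_multiplex(x, v, ctx_params, inst_params):
    """x: (N, L, d) inputs, v: (N, d) Gaussian vectors. Returns (L, d)."""
    h = self_attention(x, *ctx_params)          # context across the L positions of each instance
    g = h * v[:, None, :]                       # Hadamard product at every position
    g_t = np.swapaxes(g, 0, 1)                  # transpose to (L, N, d): group by position
    mixed = self_attention(g_t, *inst_params)   # context across the N instances at each position
    return mixed.mean(axis=1)                   # assumed reduction over instances

rng = np.random.default_rng(0)
N, L, d = 3, 5, 8                               # illustrative sizes
x = rng.standard_normal((N, L, d))
v = rng.standard_normal((N, d))
ctx_params = [rng.standard_normal((d, d)) * 0.1 for _ in range(3)]
inst_params = [rng.standard_normal((d, d)) * 0.1 for _ in range(3)]

x_mux_ctx = contextual_multiplex(x, v, ctx_params, inst_params)
print(x_mux_ctx.shape)
```

The transpose is the key design step: the first attention call mixes information along the sequence axis, and after swapping axes the second call mixes information along the instance axis at each position.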
RSA demultiplexer
The demultiplexer in T-MUX requires a prefix whose length scales linearly with the number of instances ($N$), thus reducing the effective context length during pre-training, which degrades performance Ainslie et al. (2020); Zaheer et al. (2020); Beltagy et al. (2020). Furthermore, it decreases throughput during inference for large $N$ because the model must process an extra prefix of length $N$ for each of the $N$ instances. To address these issues, we draw inspiration from the RSA cryptosystem Rivest et al. (1978) to randomly initialize and learn $N$ (private) key vectors ($\mathbf{k}^1, \dots, \mathbf{k}^N$, with $\mathbf{k}^i \in \mathbb{R}^d$), which are keys that can be used to demultiplex the output representation (Figure 2).
[TABLE]
Akin to RSA, $\mathbf{v}^i$ and $\mathbf{k}^i$ can be treated as the keys for multiplexing (encryption) and demultiplexing (decryption), respectively. Unlike T-MUX, this design leaves the input sequence length unchanged, which improves throughput. Importantly, this architecture ensures that the keys transfer better across the different stages of training, as they are no longer conditioned on the input instances.
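A rough sketch of key-based demultiplexing is given below. The text specifies only that learned key vectors recover the per-instance outputs; the concrete form used here (a two-layer MLP applied to the concatenation of the multiplexed hidden state and a key), the ReLU activation, and all names and sizes are assumptions for illustration.

```python
import numpy as np

def rsa_demux(h_mux, keys, w1, b1, w2, b2):
    """Recover N per-instance outputs from one multiplexed hidden state.

    h_mux: (L, d) multiplexed hidden states; keys: (N, d) learned key vectors.
    Returns (N, L, d).
    """
    n = keys.shape[0]
    seq_len, dim = h_mux.shape
    h = np.broadcast_to(h_mux, (n, seq_len, dim))
    k = np.broadcast_to(keys[:, None, :], (n, seq_len, dim))
    z = np.concatenate([h, k], axis=-1)       # (N, L, 2d): same state, different key
    hidden = np.maximum(z @ w1 + b1, 0.0)     # assumed MLP with ReLU
    return hidden @ w2 + b2

rng = np.random.default_rng(0)
N, L, d, hidden_dim = 5, 4, 8, 16             # illustrative sizes
h_mux = rng.standard_normal((L, d))
keys = rng.standard_normal((N, d))            # randomly initialized; learned in practice
w1, b1 = rng.standard_normal((2 * d, hidden_dim)), np.zeros(hidden_dim)
w2, b2 = rng.standard_normal((hidden_dim, d)), np.zeros(d)

h_out = rsa_demux(h_mux, keys, w1, b1, w2, b2)

# Sequence length seen by the model: a prefix of length N is needed for the
# prefix-based demultiplexer, none for key-based demultiplexing.
prefix_based_len = L + N
key_based_len = L
print(h_out.shape, prefix_based_len, key_based_len)
```

Since the keys are free parameters rather than outputs computed from a prefix, they do not depend on the input instances, which matches the transfer argument made above.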
4 Experimental Setup
Datasets
We pre-train all models on Wikipedia Foundation and Bookscorpus Zhu et al. (2015). We evaluate on all datasets from the GLUE benchmark Wang et al. (2018), which are CoLA Warstadt et al. (2019), SST-2 Socher et al. (2013), MRPC Dolan and Brockett (2005), QQP (Quora), STS-B Cer et al. (2017), MNLI Williams et al. (2018), QNLI Wang et al. (2018), RTE Wang et al. (2018), and WNLI Levesque et al. (2012). We also evaluate on token classification tasks: named entity recognition Sang and Meulder (2003) and POS tagging Grünewald et al. (2021). We exclude the two smallest tasks in GLUE, WNLI and CoLA, from the reported GLUE average, as we observe high variance across seeds. All numbers are reported on the dev split. We report scores for all tasks in Appendix E.
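The effect of leaving WNLI and CoLA out of the average can be checked against the per-task scores for the Small configuration at N = 1 reported on this page. The helper name `glue_average` is introduced here only for illustration.

```python
# Per-task dev scores for the Small model at N = 1, copied from the per-task table.
scores = {
    "MNLI": 77.86, "QQP": 88.99, "QNLI": 84.00, "MRPC": 77.70, "WNLI": 56.34,
    "STSB": 84.25, "RTE": 62.45, "SST2": 88.88, "COLA": 43.48,
}

def glue_average(task_scores, exclude=()):
    """Unweighted mean over tasks, optionally skipping some of them."""
    kept = [value for task, value in task_scores.items() if task not in exclude]
    return sum(kept) / len(kept)

avg_all = glue_average(scores)
avg_main = glue_average(scores, exclude=("WNLI", "COLA"))
print(round(avg_all, 2), round(avg_main, 2))  # 73.77 80.59
```

The two results match the two GLUE columns of that row, which confirms that the second column is the average without WNLI and CoLA.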
Models
We experiment with ELECTRA and BERT pre-training objectives and present the pre-trained multiplexed models MUX-BERT and MUX-ELECTRA for $N = 2$, $5$, and $10$. To simplify training, we use a random generator to train MUX-ELECTRA models, presented as an ablation in Clark et al. (2020), instead of using a smaller masked LM. Except where otherwise noted, we use the non-contextual multiplexer rather than the contextual MUX module, together with the RSA demultiplexing module. Refer to Appendix B and C for implementation details.
Baselines
We run experiments to compare our models against T-MUX from Murahari et al. (2022) and baseline PLMs (ELECTRA and BERT) across three different model configurations (small, base, and large). We also provide a comparison to results reported in recent PLM pruning and distillation papers in Section 5.2.
Multi-run evaluation
We evaluate all models across multiple random seeds to reduce the variance on smaller datasets, which is caused by the randomized order in which we multiplex instances in the batch. We report both the average and maximum scores across seeds in Table 1 to understand the importance of ordering the multiplexed instances, and report the average across seeds for all other results.
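The source of seed-dependent variance, and the mean/max reporting, can be illustrated with a short sketch. The grouping function, the seeds, and the per-seed scores below are hypothetical and serve only to show the mechanics; they are not results from the paper.

```python
import random
import statistics

def multiplex_groups(instance_ids, n, seed):
    """Shuffle instances with a seeded RNG and chunk them into groups of n."""
    ids = list(instance_ids)
    random.Random(seed).shuffle(ids)
    return [tuple(ids[i:i + n]) for i in range(0, len(ids), n)]

def summarize(seed_scores):
    """Return (mean, max) over per-seed scores, as reported in Table 1."""
    return statistics.mean(seed_scores), max(seed_scores)

# Different seeds decide which instances share a multiplexed forward pass.
groups_a = multiplex_groups(range(8), n=2, seed=0)
groups_b = multiplex_groups(range(8), n=2, seed=1)

hypothetical_scores = [82.0, 82.5, 83.0]   # made-up per-seed scores
mean_score, max_score = summarize(hypothetical_scores)
print(groups_a, groups_b, mean_score, max_score)
```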
5 Results
5.1 Comparing MUX-PLMs with PLMs and T-MUX
Table 1 shows that both MUX-BERT and MUX-ELECTRA outperform T-MUX at all levels of multiplexing ($N = 2, 5, 10$), with improvements of roughly 20 to 22 points on GLUE and 12 to 14 points on token-classification tasks. Furthermore, MUX-PLMs’ efficient RSA-inspired demultiplexing method allows them to achieve higher throughput than T-MUX, increasing it by roughly 15 to 17% for $N = 10$ (9.7x to 9.8x vs. 8.4x).
Moreover, MUX-PLMs provide a significant boost in throughput (roughly 2 to 10 times faster) when compared to PLMs, without a significant loss in performance. For example, MUX-ELECTRA ($N = 2$) is 0.4 points better and only 0.3 points worse than ELECTRA for GLUE and TOKEN tasks respectively, while being 2x faster. Similarly, MUX-BERT ($N = 2$) is within 2.9 and 0.6 points of BERT for GLUE and TOKEN tasks respectively, while being significantly faster. We also observe that as $N$ increases, MUX-PLMs’ throughput improves significantly, though performance compared to PLMs can degrade. This is because a larger $N$ means that MUX-PLMs must process more instances in parallel and therefore share network parameters and activations among more instances, which improves throughput but degrades performance. For example, the gap between ELECTRA and MUX-ELECTRA on TOKEN is 0.3 points for $N = 2$ and increases to 3.6 points for $N = 10$, which shows that $N$ serves as a parameter to control the performance-throughput trade-off. We explore this further in Section 5.3 and Figure 4.
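The trade-off can be read directly off Table 1. The sketch below recomputes the gap to ELECTRA at each level of multiplexing from the mean scores in that table; it assumes the table's last column is throughput relative to the unmultiplexed baseline, and the dictionary layout is just a convenient encoding for the example.

```python
# (GLUE mean, Token mean, relative throughput) copied from Table 1.
electra = (82.1, 95.3, 1.0)
mux_electra = {
    2: (82.5, 95.0, 2.0),
    5: (79.8, 93.4, 4.9),
    10: (78.2, 91.7, 9.7),
}

tradeoff = {}
for n, (glue, token, throughput) in mux_electra.items():
    # Positive gap means the multiplexed model is worse than ELECTRA.
    tradeoff[n] = (round(electra[0] - glue, 1), round(electra[1] - token, 1), throughput)
    print(n, tradeoff[n])
```

At N = 2 the GLUE gap is negative (MUX-ELECTRA is slightly ahead), while the token-classification gap grows from 0.3 to 3.6 points as throughput rises from 2.0x to 9.7x.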
5.2 Comparing MUX-PLMs with recent model compression methods
References
- (1) Quora. data.quora.com/First-Quora-Dataset-Release-Question-Pairs. Accessed: 2022-10-15.
- Aghajanyan et al. (2022) Armen Aghajanyan, Bernie Huang, Candace Ross, Vladimir Karpukhin, Hu Xu, Naman Goyal, Dmytro Okhonko, Mandar Joshi, Gargi Ghosh, Mike Lewis, et al. 2022. CM3: A causal masked multimodal model of the internet. arXiv preprint arXiv:2201.07520.
- Ainslie et al. (2020) Joshua Ainslie, Santiago Ontañón, Chris Alberti, Philip Pham, Anirudh Ravula, and Sumit Sanghai. 2020. ETC: Encoding long and structured data in transformers. CoRR, abs/2004.08483.
- Akam and Kullmann (2014) Thomas Akam and Dimitri M Kullmann. 2014. Oscillatory multiplexing of population codes for selective communication in the mammalian brain. Nature Reviews Neuroscience, 15(2):111–122.
- Beltagy et al. (2020) Iz Beltagy, Matthew E Peters, and Arman Cohan. 2020. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150.
- Blumhagen et al. (2011) Francisca Blumhagen, Peixin Zhu, Jennifer Shum, Yan-Ping Zhang Schärer, Emre Yaksi, Karl Deisseroth, and Rainer W Friedrich. 2011. Neuronal filtering of multiplexed odour representations. Nature, 479(7374):493–498.
- Cer et al. (2017) Daniel Cer, Mona Diab, Eneko Agirre, Iñigo Lopez-Gazpio, and Lucia Specia. 2017. SemEval-2017 task 1: Semantic textual similarity multilingual and crosslingual focused evaluation. In Proceedings of the 11th International Workshop on Semantic Evaluation (SemEval-2017), pages 1–14.
- Chen et al. (2020) Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Zhangyang Wang, and Michael Carbin. 2020. The lottery ticket hypothesis for pre-trained BERT networks. In Advances in Neural Information Processing Systems, volume 33, pages 15834–15846. Curran Associates, Inc.
