Flash Multi-Head Feed-Forward Network
Minshen Zhang, Xiang Hu, Jianguo Li, Wei Wu, Kewei Tu

TL;DR
This paper introduces FlashMHF, an efficient multi-head feed-forward network (FFN) for Transformers that improves scalability, reduces memory usage, and improves perplexity and downstream accuracy through an I/O-aware fused kernel and dynamically weighted sub-networks.
Contribution
It proposes FlashMHF, a novel multi-head FFN architecture that combines fused kernel computation with dynamically weighted sub-networks to address the scalability and efficiency challenges of multi-head FFNs in Transformer models.
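To make the multi-head FFN idea concrete, the following is a minimal sketch of a naive MH-FFN, not the FlashMHF kernel. It assumes, hypothetically, that the $d_{model}$-wide token vector is split into H contiguous heads of width $d_h$ with no extra input or output projection, and that every head has its own SwiGLU weights with the full intermediate size $d_{ff}$. All function names are illustrative. Under these assumptions the number of intermediate activations materialized per token is $H \cdot d_{ff}$, which is one way the head-count memory scaling described in the abstract can arise.

```python
import math
import random

def silu(v):
    # v * sigmoid(v); the tanh form avoids overflow for large |v|
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))

def swiglu_head(x_h, w_gate, w_up, w_down):
    """SwiGLU FFN on one head's slice x_h (length d_h).

    w_gate, w_up: d_ff rows of length d_h (the attention-like "keys");
    w_down: d_ff rows of length d_h (the "values").
    Returns the head output (length d_h) and its d_ff intermediate activations.
    """
    hidden = [silu(sum(g * xi for g, xi in zip(g_row, x_h)))
              * sum(u * xi for u, xi in zip(u_row, x_h))
              for g_row, u_row in zip(w_gate, w_up)]
    out = [0.0] * len(x_h)
    for h, v_row in zip(hidden, w_down):
        for j, v in enumerate(v_row):
            out[j] += h * v
    return out, hidden

def naive_multi_head_ffn(x, heads):
    """Hypothetical naive MH-FFN: split x into len(heads) equal slices,
    run one SwiGLU per slice, and concatenate the head outputs.

    Also returns how many intermediate activations were materialized for
    this token: every head keeps its own d_ff-wide hidden vector.
    """
    d_h = len(x) // len(heads)
    y, n_materialized = [], 0
    for i, (w_gate, w_up, w_down) in enumerate(heads):
        y_h, hidden = swiglu_head(x[i * d_h:(i + 1) * d_h], w_gate, w_up, w_down)
        y.extend(y_h)
        n_materialized += len(hidden)
    return y, n_materialized

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_model, n_heads, d_ff = 8, 4, 16
d_h = d_model // n_heads
heads = [tuple(rand_matrix(d_ff, d_h, rng) for _ in range(3)) for _ in range(n_heads)]
x = [rng.gauss(0.0, 1.0) for _ in range(d_model)]

y, n_materialized = naive_multi_head_ffn(x, heads)
print(len(y), n_materialized)  # 8 64: four heads x 16 activations, vs 16 for one head
```

Growing H while keeping $d_{ff}$ per head fixed multiplies the activation count, much as attention scores grow with the number of heads.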
Findings
Reduces peak memory usage by 3-5x
Improves perplexity and task accuracy over SwiGLU FFNs
Accelerates inference by up to 1.08x
Abstract
We explore Multi-Head FFN (MH-FFN) as a replacement for the FFN in the Transformer architecture, motivated by the structural similarity between single-head attention and the FFN. While multi-head mechanisms enhance expressivity in attention, naively applying them to FFNs faces two challenges: memory consumption that scales with the head count, and an imbalanced ratio between the growing intermediate size and the fixed head dimension as models scale, which degrades scalability and expressive power. To address these challenges, we propose Flash Multi-Head FFN (FlashMHF), with two key innovations: an I/O-aware fused kernel that computes outputs online in SRAM, akin to FlashAttention, and a design that uses dynamically weighted parallel sub-networks to maintain a balanced ratio between the intermediate and head dimensions. Validated on models from 128M to 1.3B parameters, FlashMHF consistently improves perplexity…
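The idea of computing outputs online can be illustrated on a single head. The sketch below is an assumption-laden stand-in for the I/O-aware kernel: pure Python instead of a GPU kernel, a single token, and a tile size chosen for illustration; the function names are hypothetical. It walks the intermediate dimension in tiles and adds each tile's down-projected contribution to the output immediately, so the full intermediate vector is never stored, and the result matches a reference that materializes everything.

```python
import math
import random

def silu(v):
    # v * sigmoid(v); the tanh form avoids overflow for large |v|
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

def ffn_materialized(x, w_gate, w_up, w_down):
    """Reference SwiGLU: build all d_ff intermediate activations, then project down."""
    hidden = [silu(dot(g, x)) * dot(u, x) for g, u in zip(w_gate, w_up)]
    out = [0.0] * len(w_down[0])
    for h, v_row in zip(hidden, w_down):
        for j, v in enumerate(v_row):
            out[j] += h * v
    return out, len(hidden)

def ffn_online(x, w_gate, w_up, w_down, tile):
    """Tiled SwiGLU: visit the intermediate dimension `tile` rows at a time and
    fold each tile's down projection straight into the output accumulator, so
    at most `tile` intermediate activations are alive at any moment."""
    d_ff = len(w_gate)
    out = [0.0] * len(w_down[0])
    peak_live = 0
    for start in range(0, d_ff, tile):
        stop = min(start + tile, d_ff)
        hidden_tile = [silu(dot(g, x)) * dot(u, x)
                       for g, u in zip(w_gate[start:stop], w_up[start:stop])]
        peak_live = max(peak_live, len(hidden_tile))
        for h, v_row in zip(hidden_tile, w_down[start:stop]):
            for j, v in enumerate(v_row):
                out[j] += h * v
        # hidden_tile is dropped here; only `out` is carried to the next tile
    return out, peak_live

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_h, d_ff, tile = 4, 64, 16
w_gate, w_up, w_down = [rand_matrix(d_ff, d_h, rng) for _ in range(3)]
x = [rng.gauss(0.0, 1.0) for _ in range(d_h)]

ref, n_full = ffn_materialized(x, w_gate, w_up, w_down)
out, peak = ffn_online(x, w_gate, w_up, w_down, tile)
print(n_full, peak)  # 64 16
```

Because SwiGLU applies no normalization across the intermediate dimension, tile contributions simply add. An online softmax, as in FlashAttention, would additionally need a running maximum and a rescaling of the accumulator. In a real kernel the tile would sit in on-chip SRAM and many tokens would be processed per tile; this version only mirrors the accumulation order.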
Peer Reviews
Decision · Submitted to ICLR 2026
The motivation of the paper is well justified by two problems with naive multi-head FFNs. There are proper ablations, such as on head dimension and model scale, and the downstream task evaluations are standard. The idea is straightforward, using sub-networks to group different heads to solve these problems, yet the results are impressive.
1. In Section 3.2.1 the authors say FlashMHF functions like a dense MoE; however, there is no direct comparison against a dense MoE architecture. 2. There are no ablations for “Flash”, so it is hard to separate the memory savings due to the architectural change from those due to the kernel optimization. 3. Large-scale experiments verifying the scaling behavior are lacking; the largest model has 1.3B parameters. 4. Regarding presentation, Figure 3a does not show the multi-head structure, which is confusing. Also, the biggest innovation of it seems t…
1. The paper is well-motivated and clearly written. It identifies two key challenges of multi-head FFNs and proposes corresponding solutions, which are empirically validated. 2. FlashMHF achieves lower PPL and better downstream performance than SwiGLU and other baselines. The architectural design choices are well supported by effective ablation studies, including the multi-head mechanism, the SwiGLU component, and the subnetwork structure. 3. It is implemented with a kernel design analogous to FlashAtte…
1. **Source of subnetwork advantages.** The authors claim that the benefit of the subnetwork design arises mainly from a more balanced expansion ratio. However, for a given head, the parallel subnetwork computation differs from a dense FFN only by an additional **blockwise gating** applied to the intermediate activations. Once the heads are concatenated, this does not effectively control the expansion ratio, which ultimately increases by a factor of $d_{model}/d_h$ compared to a standard SwiGLU. I suspect the improveme…
- The main motivations behind the architectural modifications are reasonably well justified.
- The analysis is convincing, and the experiments are overall complete (although some results could be presented better).
- Novelty is limited: both core methodologies (mirroring multi-head attention and a tiled kernel implementation) have already been proposed.
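The review point on the source of subnetwork advantages, together with the dense-MoE remark, can be made concrete with a small sketch. It assumes, for illustration only, that each head's sub-networks are SwiGLU blocks of intermediate width d_e whose outputs are mixed by softmax gates from a linear router on the head input; the paper's actual weighting function and parameter sharing are not given here, and every name below is hypothetical. Because the down projection is linear, mixing sub-network outputs is identical to scaling each sub-network's block of intermediate activations before one wide down projection, which is the blockwise-gating observation. The concatenated intermediate width is n_sub * d_e even though each sub-network keeps a d_e : d_h ratio of 2 in this example.

```python
import math
import random

def silu(v):
    # v * sigmoid(v); the tanh form avoids overflow for large |v|
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def down_project(hidden, w_down):
    # sum_i hidden[i] * w_down[i], i.e. hidden @ w_down
    return [sum(h * row[j] for h, row in zip(hidden, w_down))
            for j in range(len(w_down[0]))]

def weighted_subnetworks(x_h, subnets, w_router):
    """Dense-MoE view: every SwiGLU sub-network runs; outputs are mixed by gates."""
    gates = softmax([dot(r, x_h) for r in w_router])
    out = [0.0] * len(x_h)
    for g_e, (wg, wu, wd) in zip(gates, subnets):
        hidden = [silu(dot(g, x_h)) * dot(u, x_h) for g, u in zip(wg, wu)]
        for j, v in enumerate(down_project(hidden, wd)):
            out[j] += g_e * v
    return out

def blockwise_gated_ffn(x_h, subnets, w_router):
    """Blockwise-gating view: one wide SwiGLU (sub-network weights concatenated)
    whose intermediate activations are scaled block by block before projecting."""
    gates = softmax([dot(r, x_h) for r in w_router])
    d_e = len(subnets[0][0])
    w_gate = [row for wg, _, _ in subnets for row in wg]
    w_up = [row for _, wu, _ in subnets for row in wu]
    w_down = [row for _, _, wd in subnets for row in wd]
    hidden = [silu(dot(g, x_h)) * dot(u, x_h) for g, u in zip(w_gate, w_up)]
    hidden = [h * gates[i // d_e] for i, h in enumerate(hidden)]  # row i is in block i // d_e
    return down_project(hidden, w_down)

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_h, n_sub, d_e = 4, 3, 8
subnets = [tuple(rand_matrix(d_e, d_h, rng) for _ in range(3)) for _ in range(n_sub)]
w_router = rand_matrix(n_sub, d_h, rng)
x_h = [rng.gauss(0.0, 1.0) for _ in range(d_h)]

a = weighted_subnetworks(x_h, subnets, w_router)
b = blockwise_gated_ffn(x_h, subnets, w_router)
print(max(abs(p - q) for p, q in zip(a, b)))  # ~0, up to floating-point rounding
```

The sketch shows only the algebraic identity; whether it accounts for the empirical gains is exactly what the review questions.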
Taxonomy
Topics: Parallel Computing and Optimization Techniques · Ferroelectric and Negative Capacitance Devices · Low-power high-performance VLSI design
