Revisiting the Shape Convention of Transformer Language Models
About
Dense Transformer language models have largely adhered to one consistent architectural shape: each layer consists of an attention module followed by a feed-forward network (FFN) with a narrow-wide-narrow MLP, allocating most parameters to the MLP at expansion ratios between 2 and 4. Motivated by recent results showing that residual wide-narrow-wide (hourglass) MLPs offer superior function-approximation capabilities, we revisit the long-standing MLP shape convention in Transformers, challenging the necessity of the narrow-wide-narrow design. To study this, we develop a Transformer variant that replaces the conventional FFN with a deeper hourglass-shaped FFN, comprising a stack of hourglass sub-MLPs connected by residual pathways. We posit that a deeper but lighter hourglass FFN can serve as a competitive alternative to the conventional FFN, and that the parameters it saves can be reallocated more effectively, for example by enlarging model hidden dimensions under a fixed budget. We confirm these hypotheses empirically across model scales: hourglass FFNs outperform conventional FFNs at scales up to 400M parameters and achieve comparable performance at larger scales up to 1B parameters; hourglass FFN variants with reduced FFN and increased attention parameters show consistent improvements over conventional configurations at matched budgets. Together, these findings shed new light on recent work and prompt a rethinking of the narrow-wide-narrow MLP convention and the balance between attention and FFN parameters in efficient and expressive modern language models.
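The two FFN shapes contrasted in the abstract can be illustrated with a minimal NumPy sketch. This is not the paper's implementation: the function names, ReLU activation, expansion ratio of 4, bottleneck ratio of 1/2, and number of sub-MLPs are all illustrative assumptions; it only shows the shape difference (narrow-wide-narrow versus a residual stack of wide-narrow-wide blocks).

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def conventional_ffn(x, W_up, W_down):
    # Narrow-wide-narrow: project d_model -> 4*d_model -> d_model
    # (expansion ratio 4 assumed; ratios 2-4 are typical).
    return relu(x @ W_up) @ W_down

def hourglass_ffn(x, blocks):
    # Wide-narrow-wide: a stack of bottleneck sub-MLPs, each
    # projecting d_model -> d_model/2 -> d_model, joined by
    # residual connections (hypothetical shape for illustration).
    h = x
    for W_down, W_up in blocks:
        h = h + relu(h @ W_down) @ W_up
    return h

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, n = 8, 2
    x = rng.standard_normal((n, d))

    # Conventional FFN: ~8*d^2 parameters at ratio 4.
    W_up = rng.standard_normal((d, 4 * d))
    W_down = rng.standard_normal((4 * d, d))
    y_conv = conventional_ffn(x, W_up, W_down)

    # Hourglass FFN: 3 sub-MLPs at bottleneck d/2, ~3*d^2 parameters,
    # i.e. deeper but lighter than the conventional FFN.
    blocks = [(rng.standard_normal((d, d // 2)),
               rng.standard_normal((d // 2, d))) for _ in range(3)]
    y_hg = hourglass_ffn(x, blocks)

    print(y_conv.shape, y_hg.shape)
```

Under these assumed ratios, the three-block hourglass stack carries roughly 3d² weight parameters versus 8d² for the conventional ratio-4 FFN, which is the kind of saving the abstract proposes spending on larger hidden dimensions or more attention parameters.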
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Physical Interaction Question Answering | PIQA | Accuracy | 68.2 | 323 |
| Sentence Completion | HellaSwag | Accuracy | 40.3 | 133 |
| Multiple-choice Question Answering | ARC Easy | Accuracy | 57.7 | 122 |
| Multiple-choice Question Answering | SciQ | Accuracy | 82.5 | 74 |
| Multiple-choice Question Answering | CommonsenseQA (CSQA) | Accuracy | 36.1 | 21 |
| Language Modeling | Pre-training (val) | PPL | 20.082 | 13 |
| Question Answering | Natural Questions | PPL | 1.272 | 9 |
| Question Answering | TriviaQA | PPL | 1.422 | 9 |