Composer: A Search Framework for Hybrid Neural Architecture Design
About
Hybrid model architectures that combine computational primitives (e.g., Attention, MLP) in different ratios have shown promising performance beyond Transformers. Other studies have shown that the interleaving of primitives can also affect model quality. However, prior work has explored the hybrid model architecture design space manually. Because the design space is large and training is costly, discovering hybrid models that combine key computational primitives for pre-training is challenging. In this work, we take a principled approach to designing a modular hybrid model architecture search framework, Composer. Composer explores model architectures at a small scale and extrapolates the top-performing architectures to a larger scale using our proposed scaling strategies. Using Composer, we discover new hybrid LLM architectures that outperform Llama 3.2. Compared to Llama 3.2 and previous state-of-the-art baselines, the new architectures consistently reduce validation loss at parameter scales of 350M-3B and improve evaluation accuracy on downstream tasks by up to 2.8-8.3% (1.1-3.1% on average), while also improving both training and inference efficiency.
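The small-scale-search-then-extrapolate loop described above can be sketched as follows. This is a minimal illustration under stated assumptions, not Composer's implementation: the layer vocabulary (`"attn"`, `"mlp"`), every helper name, the toy proxy loss, and both scaling strategies (`scale_by_stacking`, `scale_by_stretching`) are hypothetical stand-ins. In a real search the proxy would be the validation loss of a small trained model, and the paper's own scaling strategies may differ.

```python
from itertools import product
from typing import Callable, Dict, List, Sequence, Tuple

# A hybrid architecture is modeled (hypothetically) as an ordered tuple of layer primitives.
Pattern = Tuple[str, ...]


def enumerate_patterns(primitives: Sequence[str], n_layers: int,
                       counts: Dict[str, int]) -> List[Pattern]:
    """All layer orderings of length n_layers with exactly counts[p] layers of
    each primitive p, i.e. every interleaving for one fixed ratio."""
    if sum(counts.values()) != n_layers:
        raise ValueError("counts must sum to n_layers")
    return [combo for combo in product(primitives, repeat=n_layers)
            if all(combo.count(p) == c for p, c in counts.items())]


def search_small_scale(patterns: List[Pattern],
                       proxy_loss: Callable[[Pattern], float],
                       k: int) -> List[Pattern]:
    """Rank candidates by a small-scale proxy loss (lower is better) and keep
    the top k. Python's sort is stable, so ties keep enumeration order."""
    return sorted(patterns, key=proxy_loss)[:k]


def scale_by_stacking(pattern: Pattern, target_layers: int) -> Pattern:
    """Hypothetical scaling strategy: repeat the whole small pattern as a block."""
    if target_layers % len(pattern) != 0:
        raise ValueError("target_layers must be a multiple of the pattern length")
    return pattern * (target_layers // len(pattern))


def scale_by_stretching(pattern: Pattern, factor: int) -> Pattern:
    """Hypothetical scaling strategy: repeat each layer in place factor times."""
    return tuple(layer for layer in pattern for _ in range(factor))


def toy_proxy_loss(pattern: Pattern) -> float:
    """Stand-in for training a small model and reading its validation loss:
    it simply counts adjacent layers that use the same primitive."""
    return float(sum(a == b for a, b in zip(pattern, pattern[1:])))


if __name__ == "__main__":
    candidates = enumerate_patterns(["attn", "mlp"], 4, {"attn": 2, "mlp": 2})
    best = search_small_scale(candidates, toy_proxy_loss, k=1)[0]
    print(len(candidates), best)
    print(scale_by_stacking(best, 8))
    print(scale_by_stretching(best, 2))
```

With a 2:2 Attention/MLP ratio over four layers there are six interleavings; under the toy proxy the alternating pattern wins and is then extrapolated to eight layers. Separating enumeration, scoring, and scaling into independent functions mirrors the modular structure the abstract describes, so any one stage can be swapped without touching the others.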
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | -- | -- | 1896 |
| Commonsense Reasoning | WinoGrande | Accuracy | 58.8 | 1442 |
| Question Answering | ARC Easy | Accuracy | 64.73 | 597 |
| Question Answering | ARC-C | Accuracy | 0.3 | 116 |
| Question Answering | ARC Challenge | Normalized Accuracy | 32.25 | 105 |
| Question Answering | ARC-E | Normalized Accuracy (ARC-E) | 61.6 | 59 |
| Language Modeling | Pre-training (val) | Validation Loss | 2.724 | 55 |
| Question Answering | SciQ | Normalized Accuracy | 87.9 | 14 |
| Language Model Evaluation | DCLM Core | DCLM Core Score | 49.3 | 12 |
| Multiple-choice Question Answering | 6 downstream tasks (ARC-Challenge, ARC-Easy, HellaSwag, Winogrande, SciQ, PIQA) | ARC-Challenge Accuracy | 43.6 | 12 |