GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding
About
Neural network scaling has been critical for improving model quality in many real-world machine learning applications with vast amounts of training data and compute. Although this trend of scaling is affirmed to be a sure-fire approach for better model quality, there are challenges on the path, such as computation cost, ease of programming, and efficient implementation on parallel devices. GShard is a module composed of a set of lightweight annotation APIs and an extension to the XLA compiler. It provides an elegant way to express a wide range of parallel computation patterns with minimal changes to the existing model code. GShard enabled us to scale up a multilingual neural machine translation Transformer model with Sparsely-Gated Mixture-of-Experts beyond 600 billion parameters using automatic sharding. We demonstrate that such a giant model can be trained efficiently on 2048 TPU v3 accelerators in 4 days to achieve far superior quality for translation from 100 languages to English compared to the prior art.
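To make the "conditional computation" idea in the abstract concrete, here is a minimal pure-Python sketch of sparse gating in a Mixture-of-Experts layer: each token is sent to only a few experts, so most expert parameters stay idle for any given token. The abstract does not give routing details, so the choice of `k=2`, the renormalization of gate weights, and all names here (`top_k_gate`, `moe_layer`, the toy `experts`) are assumptions made for illustration. It is not GShard's API and it omits sharding, expert capacity limits, and auxiliary losses.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_gate(logits, k=2):
    """Return [(expert_index, weight), ...] for the k highest-scoring experts.

    Weights are renormalized over the selected experts so they sum to 1
    (an assumption of this sketch, not something stated in the abstract).
    """
    probs = softmax(logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in chosen)
    return [(i, probs[i] / norm) for i in chosen]

def moe_layer(token, gate_logits, experts, k=2):
    """Run only the selected experts on `token` and mix their outputs."""
    out = [0.0] * len(token)
    for idx, weight in top_k_gate(gate_logits, k):
        expert_out = experts[idx](token)
        out = [o + weight * v for o, v in zip(out, expert_out)]
    return out

# Hypothetical toy experts: each one scales its input by a fixed constant.
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0, 4.0)]

gate_logits = [0.0, 5.0, 1.0, 5.0]   # made-up gate scores for one token
routes = top_k_gate(gate_logits, k=2)
mixed = moe_layer([1.0, -1.0], gate_logits, experts, k=2)
```

Only two of the four toy experts are evaluated for this token. That is why parameter count can grow with the number of experts while per-token compute stays roughly fixed.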
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | Accuracy | 72 | 1460 |
| Mathematical Reasoning | GSM8K | Accuracy | 6.4 | 983 |
| Code Generation | HumanEval | Pass@1 | 17.7 | 850 |
| Multi-task Language Understanding | MMLU | Accuracy | 26.3 | 842 |
| Commonsense Reasoning | WinoGrande | Accuracy | 67.6 | 776 |
| Language Understanding | MMLU | Accuracy | 36.7 | 756 |
| Question Answering | ARC Challenge | Accuracy | 45.8 | 749 |
| Commonsense Reasoning | PIQA | Accuracy | 77.6 | 647 |
| Question Answering | ARC Easy | Normalized Acc | 64 | 385 |
| Reading Comprehension | RACE high | Accuracy | 43.5 | 295 |