
NeuronMLP: Efficient LLM Inference via Singular Value Decomposition Compression and Tiling on AWS Trainium

About

Emerging AI accelerators have started to gain attention and offer new opportunities for efficient inference of large language models (LLMs). Trainium, an AI accelerator recently developed by Amazon Web Services (AWS), provides an attractive option for LLM inference through its heterogeneous architecture. However, achieving high performance on Trainium can be challenging because of its systolic array architecture and special data-layout requirements. In this paper, we propose NeuronMLP, an efficient LLM inference method based on Singular Value Decomposition (SVD) compression and tiling on AWS Trainium. We introduce a series of Trainium-specific techniques based on kernel fusion and novel caching strategies that reduce data movement across the software-managed memory hierarchy, maximize SRAM bandwidth, and avoid expensive matrix transposes. The proposed method is specifically optimized for the multi-layer perceptron (MLP) layers in LLMs, which serve as a critical computational kernel for inference on Trainium. Through evaluation on nine datasets and six recent LLMs, we show that NeuronMLP significantly outperforms the state-of-the-art Neuron Kernel Interface (NKI)-based matrix multiplication (matmul) kernel implemented by AWS on Trainium: at the kernel level, it achieves an average 1.35x speedup, which translates to an average 1.21x speedup for end-to-end LLM inference, under a compression ratio of 0.05.
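The abstract combines two ideas: replacing an MLP weight matrix W with low-rank SVD factors, and computing the resulting matmuls tile by tile. The NumPy sketch below illustrates both under stated assumptions; it is not NeuronMLP's kernel and does not use NKI. The helper names (`svd_compress`, `tiled_matmul`, `fused_lowrank_matmul`) are hypothetical. We assume "compression ratio" means the fraction of parameters removed, so 0.05 keeps about 95% of W's parameters; the paper may define it differently. The tile sizes (128, 128, 512) are likewise an assumption, loosely modeled on Trainium's 128-partition systolic array and NKI's matmul tile limits.

```python
import numpy as np

# ASSUMPTION: tile sizes loosely modeled on Trainium/NKI matmul limits
# (128-wide contraction, up to 512 output columns per tile).
TILE_M, TILE_K, TILE_N = 128, 128, 512


def svd_compress(w, removed_ratio):
    """Hypothetical helper: factor w (m x n) into u_r (m x r) and v_r (r x n).

    ASSUMPTION: removed_ratio is the fraction of parameters removed, so the
    rank r is chosen so that r * (m + n) is about (1 - removed_ratio) * m * n.
    """
    m, n = w.shape
    r = max(1, int((1.0 - removed_ratio) * m * n / (m + n)))
    r = min(r, m, n)  # rank cannot exceed the smaller dimension
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    u_r = u[:, :r] * s[:r]  # fold singular values into the left factor: U_r diag(s_r)
    v_r = vt[:r, :]
    return u_r, v_r


def tiled_matmul(a, b, tile_m=TILE_M, tile_k=TILE_K, tile_n=TILE_N):
    """Blocked matmul: each output tile accumulates partial products over K tiles."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    out = np.zeros((m, n), dtype=np.result_type(a, b))
    for i in range(0, m, tile_m):
        for j in range(0, n, tile_n):
            # Accumulator for one output tile (edge tiles may be smaller).
            acc = np.zeros((min(tile_m, m - i), min(tile_n, n - j)), dtype=out.dtype)
            for p in range(0, k, tile_k):
                acc += a[i:i + tile_m, p:p + tile_k] @ b[p:p + tile_k, j:j + tile_n]
            out[i:i + tile_m, j:j + tile_n] = acc
    return out


def fused_lowrank_matmul(x, u_r, v_r, tile_m=TILE_M):
    """Compute x @ (u_r @ v_r) as (x @ u_r) @ v_r, one row tile at a time.

    The rank-r intermediate for a row tile is consumed immediately instead of
    materializing the full (rows x r) intermediate, a simple analogue of
    keeping it on-chip. This is an illustration, not NeuronMLP's fusion scheme.
    """
    out = np.empty((x.shape[0], v_r.shape[1]), dtype=np.result_type(x, u_r, v_r))
    for i in range(0, x.shape[0], tile_m):
        t = tiled_matmul(x[i:i + tile_m], u_r)  # (tile rows x r) intermediate
        out[i:i + tile_m] = tiled_matmul(t, v_r)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d_model, d_ff = 256, 1024  # toy MLP projection shape (hypothetical)
    w = rng.standard_normal((d_model, d_ff))
    x = rng.standard_normal((300, d_model))  # 300 tokens

    u_r, v_r = svd_compress(w, removed_ratio=0.05)
    y = fused_lowrank_matmul(x, u_r, v_r)
    kept = (u_r.size + v_r.size) / w.size
    rel_err = np.linalg.norm(y - x @ w) / np.linalg.norm(x @ w)
    print(f"rank={u_r.shape[1]} params kept={kept:.3f} rel. error vs dense={rel_err:.3f}")
```

Folding the singular values into U_r keeps two factors instead of three. Note that a random Gaussian matrix, as in the demo, compresses poorly; real LLM weights typically have faster-decaying spectra. The paper's actual caching and transpose-avoidance strategies are not modeled here.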

Dinghong Song, Jierui Xu, Weichu Yang, Pengfei Su, Dong Li • 2025

Related benchmarks

Task | Dataset | Metric | Result | Rank
Language Modeling | C4 | Perplexity | 13.67 | 1688
Language Modeling | PTB | Perplexity | 11.43 | 1234
Question Answering | ARC Challenge | Accuracy (ARC) | 52 | 598
Language Modeling | Wiki2 | PPL | 8.56 | 326
Question Answering | OpenBookQA | Accuracy | 40 | 305
Multiple-choice Question Answering | HellaSwag | Accuracy | 54 | 196
Commonsense Question Answering | WinoGrande | Accuracy | 72 | 73
Question Answering | MathQA | Accuracy | 45 | 36
Inference Efficiency and Reasoning | Nine Datasets (Aggregate) | mAcc | 56 | 24
Matrix Multiplication | Simulated Matrix Multiplication (sequence lengths 1K-32K) | Latency (ms) | 27.63 | 3
