Prototype Transformer: Towards Language Model Architectures Interpretable by Design
About
While state-of-the-art language models (LMs) surpass most humans in certain domains, their reasoning remains largely opaque, reducing trust and increasing the risk of deception and hallucination. We introduce the Prototype Transformer (ProtoT), an autoregressive LM architecture that replaces the quadratic-cost self-attention module of the Transformer with a linear-cost module based on prototypes, which are learned parameter vectors. In ProtoT, prototypes create communication channels that aggregate contextual information at different time scales. We show that this structure leads prototypes to automatically capture nameable concepts, such as "woman", during training, offering a path toward interpreting model reasoning and making targeted edits to model behavior. Compared with baselines, ProtoT scales well with model and data size, is robust to input perturbations, and performs well on text generation and downstream tasks, including GLUE. These results suggest that ProtoT is a promising step toward autoregressive language models that are more interpretable by design.
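The abstract does not give the prototype module's equations. What follows is a minimal sketch of one way a causal, prototype-based mixer could run in linear time. It rests on several assumptions that are not the paper's definition. First, each token is softly routed to the prototypes by a scaled dot product. Second, each prototype keeps a running state that decays at its own rate, which is one reading of "aggregate contextual information at different time scales". Third, each token reads back a routing-weighted mix of those states. The names `prototype_mixing` and `decays` are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def prototype_mixing(tokens: List[Vector], prototypes: List[Vector],
                     decays: List[float]) -> List[Vector]:
    """Hypothetical causal token mixer (an assumption, not ProtoT's exact module).

    Cost is O(T * P * d) for T tokens, P prototypes and width d, i.e. linear
    in sequence length, versus O(T^2 * d) for full self-attention.
    """
    d = len(prototypes[0])
    # One running state per prototype; each acts as a communication channel.
    states = [[0.0] * d for _ in prototypes]
    outputs = []
    for x in tokens:  # tokens are processed left to right, so the mixer is causal
        # 1. Route: how strongly this token matches each learned prototype.
        weights = softmax([dot(x, p) / math.sqrt(d) for p in prototypes])
        # 2. Write: decay each channel at its own rate (assumed time scale),
        #    then add the routed token.
        for k, (w, lam) in enumerate(zip(weights, decays)):
            states[k] = [lam * s + (1.0 - lam) * w * xi
                         for s, xi in zip(states[k], x)]
        # 3. Read: mix the channel states back using the same routing weights.
        out = [0.0] * d
        for w, state in zip(weights, states):
            out = [o + w * s for o, s in zip(out, state)]
        outputs.append(out)
    return outputs


# Usage: two prototypes, one fast-decaying and one slow-decaying channel.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
prototypes = [[1.0, 0.0], [0.0, 1.0]]
decays = [0.5, 0.9]
ys = prototype_mixing(tokens, prototypes, decays)
print(len(ys), len(ys[0]))  # 3 2
```

Because every token only touches P fixed-size states, memory per step does not grow with context length. In this sketch, a prototype with a decay close to 1 behaves like a long-range channel, and one with a small decay tracks only recent tokens.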
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | FineWeb-Edu (test) | Perplexity (test) | 29.5 | 58 |
| Robustness Evaluation | Lexical Variation (abbr.) | Jensen-Shannon Divergence | 0.0498 | 8 |
| Open-ended Text Generation | Chatbot Arena inspired qualitative prompts (val) | Elo | 1.02e+3 | 4 |
| Robustness Evaluation | Lexical Variation (punctuation) | Jensen-Shannon Divergence | 0.3982 | 4 |
| Robustness Evaluation | Lexical Variation (spelling) | Jensen-Shannon Divergence | 0.026 | 4 |
| Robustness Evaluation | Lexical Variation (synonym) | Jensen-Shannon Divergence | 0.1132 | 4 |
| Robustness Evaluation | Lexical Variation (typos) | Jensen-Shannon Divergence | 0.2074 | 4 |
| Natural Language Understanding | GLUE downstream fine-tuning | CoLA score | 27.7 | 4 |
| Robustness Evaluation | Lexical Variation (contraction) | Jensen-Shannon Divergence | 0.0823 | 4 |
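Two of the metrics in the table are standard quantities and can be computed as follows. This is a minimal sketch, and its specifics are assumptions. The page does not say which distributions the robustness evaluation compares. This sketch assumes the comparison is between the model's next-token distributions on a clean prompt and a perturbed prompt. It also assumes log base 2, so the divergence lies in [0, 1]. Lower is better for both metrics. The function names are illustrative.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp of the mean negative log-likelihood (natural log) over tokens."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) in bits; terms with p_i == 0 contribute 0 by convention."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence: symmetric, in [0, 1] with log base 2 (assumed)."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]  # mixture distribution
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


# A model that gives every one of 4 tokens probability 0.25 has perplexity 4.
print(perplexity([math.log(0.25)] * 4))  # ≈ 4

# Hypothetical next-token distributions for a clean and a perturbed prompt.
clean = [0.7, 0.2, 0.1]
perturbed = [0.6, 0.25, 0.15]
print(js_divergence(clean, perturbed))  # small positive value
print(js_divergence(clean, clean))      # 0.0
```

Each mixture entry `m_i` is positive whenever `p_i` is, so the sketch never divides by zero. This symmetric, bounded behavior is why Jensen-Shannon divergence is a convenient robustness score, whereas raw KL divergence can be infinite.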