MapFormer: Self-Supervised Learning of Cognitive Maps with Input-Dependent Positional Embeddings
About
A cognitive map is an internal model that encodes the abstract relationships among entities in the world, giving humans and animals the flexibility to adapt to new situations, with strong out-of-distribution (OOD) generalization that current AI systems still lack. To bridge this gap, we introduce $\textit{MapFormers}$, new Transformer-based architectures that learn cognitive maps from observational data and perform path integration without supervision. The model learns cognitive maps by disentangling structural relationships in the inputs from their specific content. This property is achieved by updating positional encodings with input-dependent matrices, built as exponentials of learned combinations of Lie-algebra generators. We developed two variants of $\textit{MapFormers}$ that unify absolute and relative positional encoding to model episodic memory (EM) and working memory (WM), respectively. We tested $\textit{MapFormers}$ on several formal tasks targeting distinct cognitive capacities, including gating, 2D navigation, and nested hierarchies (Dyck languages). Our results demonstrate that $\textit{MapFormers}$ significantly outperform current AI architectures, achieving near-perfect OOD generalization where standard models fail. Furthermore, we show that $\textit{MapFormers}$ are scalable: evaluations on naturalistic data yield perplexity improvements over baselines, suggesting that these principles extend to large-scale, real-world domains. These results are obtained through efficient parallel computation on commutative maps, though our models can also learn non-commutative cognitive maps via sequential path integration. Overall, these results suggest that input-dependent matrices provide a critical structural bias: they disentangle abstract relations from content, which drives robust OOD generalization.
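The commutative case mentioned above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: it assumes a single so(2) generator (a 2D rotation), and the names `omega`, `tokens`, `rotation`, `path_integrate_parallel`, and `path_integrate_sequential` are hypothetical, with random values standing in for learned parameters. Each token selects a step angle, the position matrix is the exponential of that angle times the generator, and because rotations in one plane commute, the running product of matrices can be replaced by a cumulative sum of angles.

```python
import numpy as np

def rotation(theta):
    """Matrix exponential exp(theta * G) for the so(2) generator G = [[0, -1], [1, 0]]."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def path_integrate_parallel(angles):
    """Commutative case: composing rotations equals rotating by the summed angle,
    so all positions follow from one cumulative sum."""
    return np.cumsum(angles)

def path_integrate_sequential(angles):
    """General case: update the position matrix one step at a time."""
    pos = np.eye(2)
    out = []
    for a in angles:
        pos = rotation(a) @ pos
        out.append(pos)
    return out

rng = np.random.default_rng(0)
vocab_size = 4
# Hypothetical per-token step sizes; random values stand in for learned ones.
omega = rng.normal(size=vocab_size)
tokens = np.array([0, 2, 2, 1, 3])
angles = omega[tokens]  # input-dependent increments

cumulative = path_integrate_parallel(angles)
sequential = path_integrate_sequential(angles)

# Largest elementwise gap between the two ways of computing each position.
max_err = max(np.abs(rotation(c) - m).max() for c, m in zip(cumulative, sequential))
```

Under these assumptions `max_err` stays at floating-point noise, which is why the commutative setting admits parallel computation. With non-commuting generators the cumulative-sum shortcut no longer holds, and only the sequential update remains valid.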
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Grid Navigation | 1D Grid Navigation Sequence length 128, grid width 64 (IID) | Accuracy: 100 | 12 |
| Grid Navigation | 1D Grid Navigation 64/32/0.2 (OOD-dense D) | Accuracy: 100 | 12 |
| Grid Navigation | 1D Grid Navigation OOD-sparse S 256/128/0.8 | Accuracy: 100 | 12 |
| Grid Navigation | 2D Grid Navigation Sequence length 128, grid width 64 (IID) | Accuracy: 100 | 12 |
| Grid Navigation | 2D Grid Navigation D 64/32/0.2 (OOD-dense) | Accuracy: 100 | 12 |
| Grid Navigation | 2D Grid Navigation S 256/128/0.8 (OOD-sparse) | Accuracy: 100 | 12 |
| Selective-Copy Task | Selective-Copy OOD Dense split: 64/128 blank/non-blank | Accuracy: 100 | 8 |
| Selective-Copy Task | Selective-Copy OOD Sparse 256/128 blank/non-blank | Accuracy: 100 | 8 |
| Selective-Copy Task | Selective-Copy IID 128/128 blank/non-blank | Accuracy: 100 | 8 |
| Family tree navigation | Family tree navigation L=64, G=5 (IID) | Accuracy: 100 | 7 |