Embedding Trajectory for Out-of-Distribution Detection in Mathematical Reasoning
About
Real-world data that deviates from the independent and identically distributed (i.i.d.) assumption of in-distribution training data poses security threats to deep networks, which has motivated progress in out-of-distribution (OOD) detection algorithms. Detection methods for generative language models (GLMs) mainly focus on uncertainty estimation and embedding distance measurement, with the latter proven most effective in traditional linguistic tasks such as summarization and translation. However, another complex generative scenario, mathematical reasoning, poses significant challenges to embedding-based methods because its output space is high-density; at the same time, this feature causes larger discrepancies in the embedding shift trajectory between different samples in latent spaces. Hence, we propose a trajectory-based method, TV score, which uses trajectory volatility for OOD detection in mathematical reasoning. Experiments show that our method outperforms all traditional algorithms on GLMs under mathematical reasoning scenarios and can be extended to more applications with high-density features in output spaces, such as multiple-choice questions.
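To make the idea of trajectory volatility concrete, the following is a minimal sketch and not the paper's exact TV score formulation. It assumes each sample is represented by one embedding per layer, fits a Gaussian per layer on in-distribution (ID) embeddings, and scores a sample by how much its per-layer Mahalanobis distance fluctuates between adjacent layers. The function names, the ridge term, and the use of Mahalanobis distance are all illustrative assumptions.

```python
import numpy as np


def fit_layer_gaussians(id_embeddings, ridge=1e-3):
    """Fit one Gaussian per layer on ID data (hypothetical helper).

    id_embeddings: array of shape (n_samples, n_layers, dim).
    Returns per-layer means (n_layers, dim) and inverse covariances
    (n_layers, dim, dim). The ridge term is an assumption added to keep
    the covariance invertible.
    """
    n, n_layers, dim = id_embeddings.shape
    means = id_embeddings.mean(axis=0)
    inv_covs = np.empty((n_layers, dim, dim))
    for layer in range(n_layers):
        centered = id_embeddings[:, layer, :] - means[layer]
        cov = centered.T @ centered / max(n - 1, 1) + ridge * np.eye(dim)
        inv_covs[layer] = np.linalg.inv(cov)
    return means, inv_covs


def layer_distances(trajectory, means, inv_covs):
    """Mahalanobis distance of a sample to the ID Gaussian at every layer.

    trajectory: array of shape (n_layers, dim).
    """
    diff = trajectory - means
    squared = np.einsum("ld,lde,le->l", diff, inv_covs, diff)
    return np.sqrt(np.maximum(squared, 0.0))


def volatility_score(trajectory, means, inv_covs):
    """Mean absolute change in distance between adjacent layers.

    Under the assumptions of this sketch, a larger value means a more
    erratic trajectory, which is treated as more likely OOD.
    """
    dists = layer_distances(trajectory, means, inv_covs)
    return float(np.mean(np.abs(np.diff(dists))))
```

In practice, the per-layer embeddings would come from the hidden states of the language model, and a decision threshold on the score would be chosen on held-out ID data; both steps are outside the scope of this sketch.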
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Offline OOD Detection | Mathematical Reasoning Near-shift OOD | AUROC | 98.76 | 14 |
| Offline OOD Detection | Mathematical Reasoning Far-shift OOD | AUROC | 96.54 | 14 |
| OOD Quality Estimation | Mathematical Reasoning Far-shift OOD | Kendall's Tau | 0.161 | 12 |
| OOD Quality Estimation | Mathematical Reasoning OOD (Near-shift) | Kendall's Tau | 0.159 | 12 |
| Online Out-of-Distribution Detection | Algebra Far-shift OOD | Accuracy | 93.88 | 3 |
| Online Out-of-Distribution Detection | Geometry Far-shift OOD | Accuracy | 0.9447 | 3 |
| Online Out-of-Distribution Detection | Cnt.&Prob (Far-shift OOD) | Accuracy | 93.74 | 3 |
| Online Out-of-Distribution Detection | Num. Theory (Far-shift OOD) | Accuracy | 92.08 | 3 |
| Online Out-of-Distribution Detection | Precalculus Far-shift OOD | Accuracy | 99.28 | 3 |
| Online Out-of-Distribution Detection | GSM8K Near-shift OOD | Accuracy | 93.39 | 3 |