
Progressive Inference: Explaining Decoder-Only Sequence Classification Models Using Intermediate Predictions

About

This paper proposes Progressive Inference, a framework to compute input attributions to explain the predictions of decoder-only sequence classification models. Our work is based on the insight that the classification head of a decoder-only Transformer model can be used to make intermediate predictions by evaluating it at different points in the input sequence. Due to the causal attention mechanism, these intermediate predictions only depend on the tokens seen before the inference point, allowing us to obtain the model's prediction on a masked input sub-sequence with negligible computational overhead. We develop two methods to provide sub-sequence level attributions using this insight. First, we propose Single Pass-Progressive Inference (SP-PI), which computes attributions by taking the difference between consecutive intermediate predictions. Second, we exploit a connection with Kernel SHAP to develop Multi Pass-Progressive Inference (MP-PI). MP-PI uses intermediate predictions from multiple masked versions of the input to compute higher quality attributions. Our studies on a diverse set of models trained on text classification tasks show that SP-PI and MP-PI provide significantly better attributions compared to prior work.
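The difference-of-consecutive-predictions idea behind SP-PI can be illustrated with a minimal sketch. It is not the authors' implementation: the toy causal scorer `intermediate_predictions`, the per-token granularity, and the `prior` value of 0.5 (the prediction before any token is seen) are all assumptions made for illustration. In practice the intermediate predictions would come from applying the model's classification head at each position of a single forward pass.

```python
import math
from typing import List, Sequence


def intermediate_predictions(token_scores: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for a causal classifier.

    The prediction at position t depends only on tokens 0..t, mimicking
    what causal attention guarantees for a decoder-only model.
    """
    preds = []
    running = 0.0
    for score in token_scores:
        running += score
        preds.append(1.0 / (1.0 + math.exp(-running)))  # sigmoid of the prefix sum
    return preds


def sp_pi_attributions(preds: Sequence[float], prior: float) -> List[float]:
    """Attribute to each position the change it causes in the prediction.

    `prior` is an assumed baseline prediction before any input is seen.
    """
    attributions = []
    previous = prior
    for p in preds:
        attributions.append(p - previous)
        previous = p
    return attributions


# Toy example: three tokens with made-up evidence scores.
scores = [0.5, -1.0, 2.0]
preds = intermediate_predictions(scores)
attrs = sp_pi_attributions(preds, prior=0.5)
```

Because the differences telescope, the attributions in this sketch sum to the final prediction minus the assumed prior, so the full change in the model's output is distributed across the input positions.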

Sanjay Kariyappa, Freddy Lécué, Saumitra Mishra, Christopher Pond, Daniele Magazzeni, Manuela Veloso • 2024

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Faithfulness Evaluation | TellMeWhy | AUC π-Soft-NS 0.26 | 67 |
| Faithfulness Evaluation | WikiBio | AUC π-Soft-NS 0.26 | 67 |
| Attribution Alignment | Curated Attribution Dataset (NarrativeQA + SciQ) | DSA (Dependent Sentence Attribution) 2.88 | 40 |
| Attribution Faithfulness | LongRA | Soft-NC Score 1.35 | 40 |
| Causal Attribution | Causal and Downstream Robustness Ablation Suite, averaged over LLaMA-3.1 70B, Phi-3 14B, GPT-J 6B, Qwen2.5 3B | Causal Pass@5 62 | 14 |
| Fact Checking | Causal and Downstream Robustness Ablation Suite, averaged over 4 models | Fact EMΔ 1.5 | 14 |
| Span Extraction | Causal and Downstream Robustness Ablation Suite | Span F1 58 | 14 |
| Tool Use | Causal and Downstream Robustness Ablation Suite, averaged over 4 models | Tool Hit@1Δ 1.6 | 14 |
| Decoding Stability | Causal and Downstream Robustness Ablation Suite, averaged over 4 models | Decoding Δ% 2.7 | 14 |