
Studying Large Language Model Generalization with Influence Functions

About

When trying to gain better visibility into a machine learning model in order to understand and mitigate the associated risks, one potentially valuable source of evidence is: which training examples most contribute to a given behavior? Influence functions aim to answer a counterfactual: how would the model's parameters (and hence its outputs) change if a given sequence were added to the training set? While influence functions have produced insights for small models, they are difficult to scale to large language models (LLMs) due to the difficulty of computing an inverse-Hessian-vector product (IHVP). We use the Eigenvalue-corrected Kronecker-Factored Approximate Curvature (EK-FAC) approximation to scale influence functions to LLMs with up to 52 billion parameters. In our experiments, EK-FAC achieves similar accuracy to traditional influence function estimators despite the IHVP computation being orders of magnitude faster. We investigate two algorithmic techniques to reduce the cost of computing gradients of candidate training sequences: TF-IDF filtering and query batching. We use influence functions to study the generalization patterns of LLMs, including the sparsity of the influence patterns, increasing abstraction with scale, math and programming abilities, cross-lingual generalization, and role-playing behavior. Despite many apparently sophisticated forms of generalization, we identify a surprising limitation: influences decay to near-zero when the order of key phrases is flipped. Overall, influence functions give us a powerful new tool for studying the generalization properties of LLMs.
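The counterfactual described in the abstract reduces to an inverse-Hessian-vector product applied to a query gradient, followed by a dot product with a candidate training example's gradient. Below is a minimal numerical sketch of that estimate, using tiny made-up tensors in place of real per-example LLM gradients; none of these names come from the paper, and the exact matrix inverse here stands in for what EK-FAC approximates at 52B-parameter scale.

```python
import numpy as np

# Toy sketch: influence(z_train, z_query) ~ grad_query^T H^{-1} grad_train,
# where H is a damped curvature matrix for the training loss. At LLM scale
# the explicit solve below is intractable, which is where EK-FAC comes in.

rng = np.random.default_rng(0)
d = 5  # tiny parameter count so the Hessian can be inverted directly

# Stand-ins for per-example gradients (illustrative, not the paper's API).
grad_train = rng.normal(size=d)   # gradient of the loss on a candidate training sequence
grad_query = rng.normal(size=d)   # gradient of the query completion's log-probability

# Damped positive-definite curvature matrix (Gauss-Newton-style).
A = rng.normal(size=(d, d))
hessian = A @ A.T + 0.1 * np.eye(d)

# Inverse-Hessian-vector product (IHVP), computed exactly here.
ihvp = np.linalg.solve(hessian, grad_query)

# Influence score: how much adding this training sequence would move the query output.
influence = float(ihvp @ grad_train)
```

In practice, the IHVP is computed once per query (or once per batch of queries, amortizing the cost), and the dot product is then taken against the gradient of each candidate training sequence, which is why filtering candidates cheaply, e.g. by TF-IDF overlap with the query, reduces the dominant gradient-computation cost.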

Roger Grosse, Juhan Bae, Cem Anil, Nelson Elhage, Alex Tamkin, Amirhossein Tajdini, Benoit Steiner, Dustin Li, Esin Durmus, Ethan Perez, Evan Hubinger, Kamilė Lukošiūtė, Karina Nguyen, Nicholas Joseph, Sam McCandlish, Jared Kaplan, Samuel R. Bowman • 2023

Related benchmarks

Task                                           Dataset                      Metric      Result   Rank
Image Classification                           ANIMAL-10N                   Accuracy    0.8089   43
Image Classification                           CIFAR-100-N                  Accuracy    59.91    41
Image Classification                           CIFAR-10N (test)             Accuracy    91.76    19
Defense against adaptive adversarial attacks   CelebA                       Accuracy    77.36    18
Defense against adaptive adversarial attacks   Bank                         Accuracy    87.38    18
Defense against adaptive adversarial attacks   JigsawToxicity               Accuracy    70.05    18
Image Classification                           CIFAR-10N-r                  Accuracy    90.47    11
Image Classification                           CIFAR-10N-w                  Accuracy    83.25    11
Training Data Attribution                      GPT2-small                   LDS Score   0.3936   10
Backdoor Attribution Retrieval                 CIFAR-10 poisoned (train)    Recall@50   5.8      8

(Showing 10 of 11 rows)
