ReBaPL: Repulsive Bayesian Prompt Learning
About
Prompt learning has emerged as an effective technique for fine-tuning large-scale foundation models for downstream tasks. However, conventional prompt learning methods are prone to overfitting and can struggle with out-of-distribution generalization. To address these limitations, Bayesian prompt learning has been proposed, which frames prompt optimization as a Bayesian inference problem to enhance robustness. This paper introduces Repulsive Bayesian Prompt Learning (ReBaPL), a novel method for Bayesian prompt learning, designed to efficiently explore the complex and often multimodal posterior landscape of prompts. Our method integrates a cyclical step-size schedule with a stochastic gradient Hamiltonian Monte Carlo (SGHMC) algorithm, enabling alternating phases of exploration to discover new modes, and exploitation to refine existing modes. Furthermore, we introduce a repulsive force derived from a potential function over probability metrics (including Maximum Mean Discrepancy and Wasserstein distance) computed on the distributions of representations produced by different prompts. This representation-space repulsion diversifies exploration and prevents premature collapse to a single mode. Our approach allows for a more comprehensive characterization of the prompt posterior distribution, leading to improved generalization. In contrast to prior Bayesian prompt learning methods, our method provides a modular plug-and-play Bayesian extension of any existing prompt learning method based on maximum likelihood estimation. We demonstrate the efficacy of ReBaPL on several benchmark datasets, showing superior performance over state-of-the-art prompt learning methods.
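To make two ingredients of the abstract concrete, the sketch below illustrates (a) a cyclical step-size schedule and (b) a Maximum Mean Discrepancy estimate between two sets of representations. It is a minimal illustration under stated assumptions, not the paper's implementation: the cosine form of the schedule is the one commonly used in cyclical stochastic-gradient MCMC and may differ from the schedule ReBaPL uses, the RBF kernel and its bandwidth are arbitrary choices, and the names `cyclical_step_size` and `rbf_mmd2` are hypothetical.

```python
import math
import numpy as np


def cyclical_step_size(k, total_steps, num_cycles, eps0):
    """Cosine cyclical schedule (assumed form, not taken from the paper).

    The step size restarts at eps0 at the beginning of every cycle
    (large steps: exploration) and decays toward zero by the end of
    the cycle (small steps: exploitation / sampling near a mode).
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    pos = (k % cycle_len) / cycle_len  # position within the cycle, in [0, 1)
    return 0.5 * eps0 * (math.cos(math.pi * pos) + 1.0)


def rbf_mmd2(x, y, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD with an RBF kernel.

    x: array of shape (n, d), representations produced by one prompt.
    y: array of shape (m, d), representations produced by another prompt.
    The kernel choice and bandwidth are illustrative assumptions.
    """
    def kernel(a, b):
        # Pairwise squared Euclidean distances, clipped at 0 for numerical safety.
        sq = (a ** 2).sum(axis=1)[:, None] + (b ** 2).sum(axis=1)[None, :] - 2.0 * a @ b.T
        return np.exp(-np.maximum(sq, 0.0) / (2.0 * bandwidth ** 2))

    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()


if __name__ == "__main__":
    # Step sizes over 2 cycles of 5 steps each: high at each restart, then decaying.
    print([round(cyclical_step_size(k, 10, 2, 0.1), 4) for k in range(10)])

    # Representations from two hypothetical prompts: the second set is shifted.
    rng = np.random.default_rng(0)
    feats_a = rng.normal(size=(64, 8))
    feats_b = feats_a + 2.0
    print(rbf_mmd2(feats_a, feats_a), rbf_mmd2(feats_a, feats_b))
```

A repulsive term in the spirit of the abstract would push prompts apart when such a discrepancy between their representation distributions is small; the exact potential function is not specified in this summary, so it is left out of the sketch.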
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Image Classification | ImageNet 1k (test) | Top-1 Accuracy: 71 | 848 |
| Image Classification | FGVC-Aircraft (test) | -- | 305 |
| Image Classification | ImageNet V2 (test) | Top-1 Accuracy: 62.5 | 216 |
| Image Classification | ImageNet-A (test) | -- | 175 |
| Image Classification | Caltech101 (test) | -- | 159 |
| Image Classification | EuroSAT (test) | -- | 141 |
| Image Classification | ImageNet-R (test) | Accuracy: 75.63 | 118 |
| Image Classification | UCF101 | Base Classes Acc: 88.33 | 100 |
| Image Classification | Food101 (test) | -- | 91 |
| Image Classification | StanfordCars (test) | Base Accuracy: 81.2 | 20 |