Auxiliary Losses for Learning Generalizable Concept-based Models
About
The increasing use of neural networks in various applications has led to growing apprehension, underscoring the need to understand their operations beyond their final predictions. As a way to enhance model transparency, Concept Bottleneck Models (CBMs) have gained popularity since their introduction. CBMs constrain the latent space of a model to human-understandable, high-level concepts. While beneficial, CBMs have been reported to often learn irrelevant concept representations that consequently damage model performance. To overcome this performance trade-off, we propose the cooperative-Concept Bottleneck Model (coop-CBM). The concept representation of our model is particularly meaningful when fine-grained concept labels are absent. Furthermore, we introduce the concept orthogonal loss (COL) to encourage separation between concept representations and to reduce intra-concept distance. This paper presents extensive experiments on real-world image classification datasets, namely CUB, AwA2, CelebA and TIL. We also study the performance of coop-CBM models under various distributional shift settings. We show that our proposed method achieves higher accuracy in all distributional shift settings, even compared to the black-box models with the highest concept accuracy.
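The abstract describes the goal of COL (separate representations of different concepts, pull together representations of the same concept) but does not give its formula. As a rough illustration only, the sketch below implements a loss loosely modeled on orthogonal projection loss: L2-normalized vectors sharing a concept label are pushed toward cosine similarity 1, and vectors with different labels are pushed toward cosine similarity 0 (orthogonality). The function name `orthogonality_style_loss`, the weight `gamma`, the pairwise averaging, and the idea that each vector carries a single concept label are assumptions for illustration, not the paper's definition of COL.

```python
import math
from itertools import combinations
from typing import Sequence


def _normalize(v: Sequence[float]) -> list[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("zero vector cannot be normalized")
    return [x / norm for x in v]


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product; equals cosine similarity for unit vectors."""
    return sum(x * y for x, y in zip(a, b))


def orthogonality_style_loss(
    embeddings: Sequence[Sequence[float]],
    labels: Sequence[int],
    gamma: float = 0.5,  # hypothetical weight on the inter-concept term
) -> float:
    """Hypothetical orthogonality-style loss (NOT the paper's exact COL).

    s = mean cosine similarity over pairs with the same label (intra-concept)
    d = mean cosine similarity over pairs with different labels (inter-concept)
    loss = (1 - s) + gamma * |d|
    """
    if len(embeddings) != len(labels):
        raise ValueError("embeddings and labels must have the same length")
    unit = [_normalize(e) for e in embeddings]
    same: list[float] = []
    diff: list[float] = []
    for i, j in combinations(range(len(unit)), 2):
        sim = _dot(unit[i], unit[j])
        (same if labels[i] == labels[j] else diff).append(sim)
    # With no pairs of a given kind, that term contributes nothing.
    s = sum(same) / len(same) if same else 1.0
    d = sum(diff) / len(diff) if diff else 0.0
    return (1.0 - s) + gamma * abs(d)
```

For example, two tight groups lying on orthogonal axes, `orthogonality_style_loss([[1, 0], [2, 0], [0, 1], [0, 3]], [0, 0, 1, 1])`, give a loss of 0.0, while collapsing two differently labeled vectors onto the same direction is penalized through the `gamma * |d|` term. Using absolute cosine similarity for the inter-concept term targets orthogonality rather than opposite directions.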
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | CIFAR100 | Accuracy | 84.66 | 331 |
| Image Classification | CUB | Accuracy | 82.1 | 249 |
| Image Classification | CUB | -- | -- | 89 |
| Classification | CUB | Accuracy | 79.154 | 85 |
| Image Classification | ImageNet | Accuracy | 82.73 | 47 |
| Image Classification | CUB (test) | Top-1 Accuracy | 84.1 | 31 |
| Image Classification | Caltech-UCSD Birds (CUB-200-2011) (test) | Accuracy | 84.1 | 22 |
| Classification | AWA2 (test) | -- | -- | 22 |
| Cancer Cell Classification | TIL (test) | Accuracy | 54.2 | 8 |
| Fine-grained Image Classification | CUB Out-of-distribution Background Spurious Correlation (test) | Accuracy | 0.362 | 8 |