
Auxiliary Losses for Learning Generalizable Concept-based Models

About

The increasing use of neural networks across applications has led to growing concerns, underscoring the need to understand how these models operate beyond their final predictions. Concept Bottleneck Models (CBMs), a way to make models more transparent, have gained popularity since their introduction. CBMs constrain a model's latent space to human-understandable, high-level concepts. While beneficial, CBMs have been reported to often learn irrelevant concept representations that consequently degrade model performance. To overcome this performance trade-off, we propose the cooperative Concept Bottleneck Model (coop-CBM). The concept representations learned by our model are particularly meaningful when fine-grained concept labels are absent. Furthermore, we introduce the concept orthogonal loss (COL), which encourages separation between the representations of different concepts and reduces the intra-concept distance. We present extensive experiments on real-world image classification datasets, namely CUB, AwA2, CelebA, and TIL, and we also study the performance of coop-CBM under various distributional shift settings. We show that our proposed method achieves higher accuracy in all distributional shift settings, even compared with the black-box models with the highest concept accuracy.
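The abstract states only that COL separates the representations of different concepts and reduces intra-concept distance. As an illustration, here is a minimal sketch of one loss with those two properties, modeled on an orthogonal-projection-style objective. Embedding pairs that belong to the same concept are pushed toward cosine similarity 1, and pairs from different concepts toward 0 (orthogonality). The function name `concept_orthogonal_loss`, the weight `gamma`, and the exact form (1 - s) + gamma * |d| are hypothetical choices for illustration, not the paper's published formulation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _normalize(v: Vector) -> List[float]:
    """Scale v to unit L2 norm; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def _cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two vectors."""
    return sum(x * y for x, y in zip(_normalize(a), _normalize(b)))


def concept_orthogonal_loss(
    embeddings: Sequence[Vector],
    concept_ids: Sequence[int],
    gamma: float = 0.5,
) -> float:
    """Hypothetical COL-style loss over a batch of concept embeddings.

    Assumption: embeddings[i] is a representation of concept concept_ids[i].
    s = mean cosine similarity over same-concept pairs (want 1: compact concepts).
    d = mean cosine similarity over different-concept pairs (want 0: orthogonal).
    gamma (an assumed hyperparameter) weights the separation term.
    """
    same: List[float] = []
    diff: List[float] = []
    n = len(embeddings)
    for i in range(n):
        for j in range(i + 1, n):
            sim = _cosine(embeddings[i], embeddings[j])
            if concept_ids[i] == concept_ids[j]:
                same.append(sim)
            else:
                diff.append(sim)
    # With no pairs of a given kind, that term contributes nothing.
    s = sum(same) / len(same) if same else 1.0
    d = sum(diff) / len(diff) if diff else 0.0
    return (1.0 - s) + gamma * abs(d)


if __name__ == "__main__":
    ids = [0, 0, 1, 1]
    # Compact within each concept and orthogonal across concepts: zero loss.
    separated = [[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 3.0]]
    print(concept_orthogonal_loss(separated, ids))  # 0.0
    # All concepts collapsed onto one direction: only the separation term fires.
    collapsed = [[1.0, 0.0]] * 4
    print(concept_orthogonal_loss(collapsed, ids))  # 0.5
```

The two calls show why both terms are needed. Pulling same-concept embeddings together alone would be satisfied by the collapsed batch, and only the inter-concept term penalizes it.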

Ivaxi Sheth, Samira Ebrahimi Kahou • 2023

Related benchmarks

Task | Dataset | Result | Rank
Image Classification | CIFAR100 | Accuracy 84.66 | 331
Image Classification | CUB | Accuracy 82.1 | 249
Image Classification | CUB | -- | 89
Classification | CUB | Accuracy 79.154 | 85
Image Classification | ImageNet | Accuracy 82.73 | 47
Image Classification | CUB (test) | Top-1 Accuracy 84.1 | 31
Image Classification | Caltech-UCSD Birds (CUB-200-2011) (test) | Accuracy 84.1 | 22
Classification | AWA2 (test) | -- | 22
Cancer Cell Classification | TIL (test) | Accuracy 54.2 | 8
Fine-grained Image Classification | CUB Out-of-distribution Background Spurious Correlation (test) | Accuracy 0.362 | 8

Showing 10 of 24 rows.

