Neural Prototype Trees for Interpretable Fine-grained Image Recognition
About
Prototype-based methods use interpretable representations to address the black-box nature of deep learning models, in contrast to post-hoc explanation methods that only approximate such models. We propose the Neural Prototype Tree (ProtoTree), an intrinsically interpretable deep learning method for fine-grained image recognition. ProtoTree combines prototype learning with decision trees, and thus results in a globally interpretable model by design. Additionally, ProtoTree can locally explain a single prediction by outlining a decision path through the tree. Each node in our binary tree contains a trainable prototypical part. The presence or absence of this learned prototype in an image determines the routing through a node. Decision making is therefore similar to human reasoning: Does the bird have a red throat? And an elongated beak? Then it's a hummingbird! We tune the accuracy-interpretability trade-off using ensemble methods, pruning and binarizing. We apply pruning without sacrificing accuracy, resulting in a small tree with only 8 learned prototypes along a path to classify a bird from 200 species. An ensemble of 5 ProtoTrees achieves competitive accuracy on the CUB-200-2011 and Stanford Cars data sets. Code is available at https://github.com/M-Nauta/ProtoTree
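The routing idea above (prototype presence decides which child an image goes to) can be sketched in a few lines. This is a minimal NumPy illustration under stated assumptions, not the code from the repository: it assumes a complete binary tree stored in heap order, a similarity of exp(-distance) between a prototype and its closest latent image patch, that the right child means "prototype present", and that each leaf holds class logits turned into a distribution with softmax. The names `prototype_similarity`, `soft_tree_predict` and `greedy_path`, the 0.5 threshold, and all array shapes are hypothetical.

```python
import numpy as np

def prototype_similarity(patches, prototype):
    # Assumed similarity: exp(-L2 distance) to the closest latent patch,
    # so a value near 1 means "this prototypical part is present".
    dists = np.linalg.norm(patches - prototype, axis=1)
    return float(np.exp(-dists.min()))

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def soft_tree_predict(patches, prototypes, leaf_logits):
    """Soft routing through a complete binary tree in heap order.

    Internal node i has children 2i+1 (prototype absent) and 2i+2 (present).
    prototypes: (num_internal, dim); leaf_logits: (num_internal + 1, num_classes).
    """
    num_internal = prototypes.shape[0]
    num_leaves = leaf_logits.shape[0]
    assert num_leaves == num_internal + 1
    reach = np.zeros(num_internal + num_leaves)  # probability of reaching each node
    reach[0] = 1.0  # the root is always reached
    for i in range(num_internal):  # parents always precede children in heap order
        s = prototype_similarity(patches, prototypes[i])
        reach[2 * i + 1] = reach[i] * (1.0 - s)  # left: prototype absent
        reach[2 * i + 2] = reach[i] * s          # right: prototype present
    leaf_reach = reach[num_internal:]
    leaf_dists = np.array([softmax(l) for l in leaf_logits])
    return leaf_reach @ leaf_dists  # mixture of leaf class distributions

def greedy_path(patches, prototypes, num_internal, threshold=0.5):
    """Deterministic variant: follow one root-to-leaf path (a local explanation)."""
    i, path = 0, []
    while i < num_internal:
        present = prototype_similarity(patches, prototypes[i]) > threshold
        path.append((i, present))
        i = 2 * i + 2 if present else 2 * i + 1
    return i - num_internal, path  # leaf index and the decisions taken

# Usage with random data (illustrative shapes only).
rng = np.random.default_rng(0)
depth = 3
num_internal = 2 ** depth - 1  # 7 prototypes in a depth-3 tree
dim, num_classes = 4, 5
prototypes = rng.normal(size=(num_internal, dim))
leaf_logits = rng.normal(size=(num_internal + 1, num_classes))
patches = rng.normal(size=(9, dim))  # e.g. a 3x3 grid of latent patches
probs = soft_tree_predict(patches, prototypes, leaf_logits)
leaf, path = greedy_path(patches, prototypes, num_internal)
print(probs.round(3), leaf, path)  # len(path) equals depth
```

Soft routing keeps every leaf reachable, so the prediction is a mixture of leaf distributions weighted by the product of edge probabilities along each path. The greedy variant instead commits to a single root-to-leaf path, which is the kind of decision path a local explanation outlines; how the actual ProtoTree implementation binarizes or prunes the tree is not reproduced here.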
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | CUB-200-2011 (test) | Top-1 Acc | 77.22 | 276 |
| Image Classification | CUB-200 2011 | Accuracy | 20.88 | 257 |
| Image Classification | ImageNet-1k (val) | Accuracy | 9.07 | 189 |
| Image Classification | Caltech101 | Base Accuracy | 87.72 | 129 |
| Image Classification | Caltech101 (test) | Accuracy | 86.02 | 121 |
| Image Classification | CUB-200 (test) | Accuracy | 82.2 | 62 |
| Image Classification | CARS196 (test) | -- | -- | 38 |
| Classification | AWA2 (test) | -- | -- | 22 |
| Diagnostic Classification | TBX11K | F1 Score | 94 | 12 |
| Image Classification | CUB200 | Accuracy | 82.2 | 9 |