Episodic Training for Domain Generalization
About
Domain generalization (DG) is the challenging and topical problem of learning models that generalize to novel testing domains whose statistics differ from those of a set of known training domains. The simple approach of aggregating data from all source domains and training a single deep neural network end-to-end on all the data provides a surprisingly strong baseline that surpasses many previously published methods. In this paper, we build on this strong baseline by designing an episodic training procedure that trains a single deep network in a way that exposes it to the domain shift that characterises a novel domain at runtime. Specifically, we decompose a deep network into feature extractor and classifier components, and then train each component by simulating its interaction with a partner that is badly tuned for the current domain. This makes both components more robust, ultimately leading our networks to achieve state-of-the-art performance on three DG benchmarks. Furthermore, we consider the pervasive workflow of using an ImageNet-trained CNN as a fixed feature extractor for downstream recognition tasks. Using the Visual Decathlon benchmark, we demonstrate that our episodic-DG training improves the performance of such a general-purpose feature extractor by explicitly training the feature for robustness to novel problems. This shows that DG training can benefit standard practice in computer vision.
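The abstract does not spell out the architecture or losses, so the following is only a minimal toy sketch of the episodic idea under stated assumptions: a linear feature extractor `W` and a linear softmax classifier `V`; per-domain "partner" models (`Wd`, `Vd`) trained only on their own domain, so a partner from domain j is badly tuned for domain i; and equal weighting of the aggregation and episodic losses. The function names (`train_episodic`, `ce_and_grads`) and all hyperparameters are hypothetical illustrations, not the authors' code.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def ce_and_grads(X, Y, W, V):
    """Softmax cross-entropy of classifier V applied to features X @ W,
    plus gradients w.r.t. W (extractor) and V (classifier)."""
    H = X @ W                      # features
    P = softmax(H @ V)             # class probabilities
    n = X.shape[0]
    loss = -np.mean(np.sum(Y * np.log(P + 1e-12), axis=1))
    G = (P - Y) / n                # d loss / d logits
    return loss, X.T @ (G @ V.T), H.T @ G


def train_episodic(domains, d_feat=8, n_classes=3, steps=300, lr=0.1, seed=0):
    """domains: list of (X, y) pairs, one per source domain (hypothetical toy setup)."""
    rng = np.random.default_rng(seed)
    d_in = domains[0][0].shape[1]

    def init(*shape):
        return 0.1 * rng.standard_normal(shape)

    W, V = init(d_in, d_feat), init(d_feat, n_classes)   # shared model, deployed on unseen domains
    Wd = [init(d_in, d_feat) for _ in domains]           # per-domain extractors (partners)
    Vd = [init(d_feat, n_classes) for _ in domains]      # per-domain classifiers (partners)

    for _ in range(steps):
        for i, (X, y) in enumerate(domains):
            Y = np.eye(n_classes)[y]
            # 1) each domain-specific model is trained only on its own domain
            _, gW, gV = ce_and_grads(X, Y, Wd[i], Vd[i])
            Wd[i] -= lr * gW
            Vd[i] -= lr * gV
            # 2) aggregation (AGG) loss: shared model on this domain's data
            _, gW_agg, gV_agg = ce_and_grads(X, Y, W, V)
            # 3) partner from a different domain j != i, i.e. badly tuned for domain i
            j = rng.choice([k for k in range(len(domains)) if k != i])
            # episodic feature loss: shared W + frozen classifier Vd[j]; only the W-gradient is kept
            _, gW_ep, _ = ce_and_grads(X, Y, W, Vd[j])
            # episodic classifier loss: frozen extractor Wd[j] + shared V; only the V-gradient is kept
            _, _, gV_ep = ce_and_grads(X, Y, Wd[j], V)
            W -= lr * (gW_agg + gW_ep)
            V -= lr * (gV_agg + gV_ep)
    return W, V
```

Discarding the gradient that would flow into the partner is what keeps it "badly tuned": the shared extractor must produce features that a classifier from another domain can still use, and the shared classifier must cope with features from another domain's extractor. Only the shared `W` and `V` would be kept for testing on a novel domain.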
Related benchmarks
| Task | Dataset | Metric | Value (%) | Rank |
|---|---|---|---|---|
| Image Classification | PACS (test) | Average Accuracy | 81.5 | 254 |
| Image Classification | PACS | Overall Average Accuracy | 81.5 | 230 |
| Domain Generalization | PACS (test) | Average Accuracy | 81.5 | 225 |
| Domain Generalization | PACS | Accuracy (Art) | 82.1 | 221 |
| Image Classification | DomainNet (test) | Average Accuracy | 63.85 | 209 |
| Domain Generalization | PACS (leave-one-domain-out) | Art Accuracy | 82.1 | 146 |
| Object Recognition | PACS (leave-one-domain-out) | Acc (Art painting) | 82.1 | 112 |
| Image Classification | PACS v1 (test) | Average Accuracy | 81.5 | 92 |
| Image Classification | VLCS (test) | Average Accuracy | 72.9 | 65 |
| Image Classification | PACS (out-of-domain) | Overall Accuracy | 81.5 | 63 |
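Several rows above use the leave-one-domain-out protocol, in which each domain is held out for testing while the model is trained on the rest. As a minimal sketch, assuming the four PACS domains are named after the dataset's folder names (`photo`, `art_painting`, `cartoon`, `sketch`) and that "Average Accuracy" is the mean of the held-out accuracies (the usual PACS convention), the splits and the averaged score could be produced as below. The helper names are hypothetical.

```python
# Assumed domain identifiers (PACS folder names); not taken from this page.
PACS_DOMAINS = ["photo", "art_painting", "cartoon", "sketch"]


def leave_one_domain_out(domains):
    """Yield (train_domains, test_domain) pairs; each domain is held out exactly once."""
    for held_out in domains:
        yield [d for d in domains if d != held_out], held_out


def average_accuracy(per_domain_acc):
    """Mean of held-out accuracies, one entry per split (e.g. {"art_painting": ...})."""
    return sum(per_domain_acc.values()) / len(per_domain_acc)


for train, test in leave_one_domain_out(PACS_DOMAINS):
    print(f"train on {train}, test on {test}")
```

A per-domain column such as "Art Accuracy" would then correspond to the single split whose held-out domain is `art_painting`.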