# Learning to Branch for Multi-Task Learning

## About
Training multiple tasks jointly in one deep network reduces inference latency and improves performance over single-task counterparts by sharing certain layers of the network. However, over-sharing a network can enforce over-generalization, causing negative knowledge transfer across tasks. Prior works rely on human intuition or pre-computed task-relatedness scores to design ad hoc branching structures; these yield sub-optimal results and often demand substantial trial-and-error effort. In this work, we present an automated multi-task learning algorithm that learns where to share or branch within a network, producing an effective network topology that is directly optimized for multiple objectives across tasks. Specifically, we propose a novel tree-structured design space that casts a tree branching operation as a Gumbel-Softmax sampling procedure. This enables differentiable network splitting that is end-to-end trainable. We validate the proposed method on controlled synthetic data, CelebA, and Taskonomy.
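The core of the branching operation is the Gumbel-Softmax trick: perturb learnable branch logits with Gumbel noise, then apply a temperature-scaled softmax so that a near-one-hot branch choice remains differentiable. The sketch below illustrates the trick in plain NumPy; the function and variable names are hypothetical, not taken from the paper's code.

```python
import numpy as np

def gumbel_softmax(logits, tau=1.0, rng=None):
    """Relaxed (differentiable) one-hot sample over branch choices.

    Illustrative sketch of the Gumbel-Softmax trick used for tree
    branching; `logits` stand in for learnable branch scores.
    """
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(1e-9, 1.0, size=len(logits))   # uniform noise in (0, 1)
    g = -np.log(-np.log(u))                        # Gumbel(0, 1) noise
    z = (np.asarray(logits, dtype=float) + g) / tau  # perturb and scale
    e = np.exp(z - z.max())                        # numerically stable softmax
    return e / e.sum()

# A child node choosing among three candidate parent branches: as the
# temperature tau is annealed toward 0, the relaxed sample approaches a
# hard one-hot selection, while gradients still flow through the logits.
probs = gumbel_softmax([2.0, 0.5, -1.0], tau=0.5)
```

In practice the temperature is typically annealed during training, so the network explores soft branch assignments early on and commits to a discrete tree topology by the end.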
## Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | Office-Home (test) | Mean Accuracy | 62.2 | 199 |
| Semantic Segmentation | Pascal Context | mIoU | 61.84 | 111 |
| Multi-Task Adaptation | Pascal Context (test) | -- | -- | 70 |
| Image Classification | Office-Caltech (test) | Average Accuracy | 89.9 | 35 |
| Image Classification | ImageCLEF (test) | Accuracy | 71.6 | 33 |
| Depth Estimation | Taskonomy (test) | Depth Estimation Error | 0.023 | 21 |
| Semantic Segmentation | Taskonomy (test) | mIoU | 52.1 | 16 |
| Keypoint Detection | Taskonomy (test) | Keypoint Detection | 20.2 | 10 |
| Edge Detection | Taskonomy (test) | Edge Detection | 21.7 | 10 |
| Surface Normal Prediction | Taskonomy (test) | Surface Normal Accuracy | 85 | 10 |