Gradient Surgery for Multi-Task Learning
About
While deep learning and deep reinforcement learning (RL) systems have demonstrated impressive results in domains such as image classification, game playing, and robotic control, data efficiency remains a major challenge. Multi-task learning has emerged as a promising approach for sharing structure across multiple tasks to enable more efficient learning. However, the multi-task setting presents a number of optimization challenges, making it difficult to realize large efficiency gains compared to learning tasks independently. The reasons why multi-task learning is so challenging compared to single-task learning are not fully understood. In this work, we identify a set of three conditions of the multi-task optimization landscape that cause detrimental gradient interference, and develop a simple yet general approach for avoiding such interference between task gradients. We propose a form of gradient surgery that projects a task's gradient onto the normal plane of the gradient of any other task that has a conflicting gradient. On a series of challenging multi-task supervised and multi-task RL problems, this approach leads to substantial gains in efficiency and performance. Further, it is model-agnostic and can be combined with previously-proposed multi-task architectures for enhanced performance.
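The projection step described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It rests on several assumptions that the abstract does not state: two gradients are treated as conflicting when their dot product is negative, the other tasks are visited in random order, and the projected gradients are summed into one update. The function names `project_conflicting` and `gradient_surgery` are hypothetical, and each gradient is assumed to be flattened into a 1-D vector.

```python
import numpy as np


def project_conflicting(g_i, g_j):
    """Return g_i projected onto the normal plane of g_j when the two conflict.

    Assumption: "conflicting" means a negative dot product.
    """
    dot = float(np.dot(g_i, g_j))
    if dot >= 0.0:
        # No conflict (or g_j is zero): leave g_i unchanged.
        return g_i.copy()
    # Remove the component of g_i that points against g_j.
    return g_i - (dot / float(np.dot(g_j, g_j))) * g_j


def gradient_surgery(grads, rng=None):
    """Combine per-task gradients after projecting away pairwise conflicts.

    grads: sequence of 1-D arrays, one flattened gradient per task.
    Returns the sum of the projected gradients (an assumed aggregation).
    """
    rng = np.random.default_rng() if rng is None else rng
    grads = [np.asarray(g, dtype=float) for g in grads]
    projected = []
    for i, g in enumerate(grads):
        g_new = g.copy()
        others = [j for j in range(len(grads)) if j != i]
        # Visit the other tasks in random order (an assumed choice).
        for j in rng.permutation(others):
            g_new = project_conflicting(g_new, grads[j])
        projected.append(g_new)
    return np.sum(projected, axis=0)


if __name__ == "__main__":
    g1 = np.array([1.0, 0.0])
    g2 = np.array([-1.0, 1.0])  # conflicts with g1 (dot product is -1)
    print(gradient_surgery([g1, g2]))
```

After projection, a gradient that conflicted with another task's gradient is orthogonal to it, so the combined update no longer pushes directly against that task. Gradients that do not conflict pass through unchanged.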
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Semantic segmentation | Cityscapes (test) | mIoU 75.13 | 1145 |
| Depth Estimation | NYU v2 (test) | -- | 423 |
| Semantic segmentation | NYU v2 (test) | mIoU 51.77 | 248 |
| Surface Normal Estimation | NYU v2 (test) | Mean Angle Distance (MAD) 24.31 | 206 |
| Image Classification | Office-Home (test) | -- | 199 |
| Depth Estimation | NYU Depth V2 | RMSE 0.596 | 177 |
| Semantic segmentation | NYU Depth V2 (test) | mIoU 38.06 | 172 |
| Surface Normal Prediction | NYU V2 | Mean Error 20.5 | 100 |
| Semantic segmentation | NYUD v2 | mIoU 38.61 | 96 |
| Multi-Label Classification | ChestX-Ray14 (test) | -- | 88 |