
Gradient Surgery for Multi-Task Learning

About

While deep learning and deep reinforcement learning (RL) systems have demonstrated impressive results in domains such as image classification, game playing, and robotic control, data efficiency remains a major challenge. Multi-task learning has emerged as a promising approach for sharing structure across multiple tasks to enable more efficient learning. However, the multi-task setting presents a number of optimization challenges, making it difficult to realize large efficiency gains compared to learning tasks independently. The reasons why multi-task learning is so challenging compared to single-task learning are not fully understood. In this work, we identify a set of three conditions of the multi-task optimization landscape that cause detrimental gradient interference, and develop a simple yet general approach for avoiding such interference between task gradients. We propose a form of gradient surgery that projects a task's gradient onto the normal plane of the gradient of any other task that has a conflicting gradient. On a series of challenging multi-task supervised and multi-task RL problems, this approach leads to substantial gains in efficiency and performance. Further, it is model-agnostic and can be combined with previously proposed multi-task architectures for enhanced performance.
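The projection step described above can be illustrated with a small sketch. It assumes that two gradients "conflict" when their dot product is negative, that each task's gradient is projected against the other tasks' original gradients in random order, and that the projected gradients are summed into one update. Plain Python lists stand in for flattened parameter gradients, and the function names `dot` and `project_conflicting` are hypothetical, not taken from the authors' code.

```python
import random
from typing import List, Optional

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def project_conflicting(grads: List[Vector],
                        rng: Optional[random.Random] = None) -> Vector:
    """Hypothetical sketch of gradient surgery across task gradients.

    For each task gradient g_i, visit every other task's original gradient
    g_j in random order. If the (possibly already modified) g_i has a
    negative dot product with g_j, subtract its component along g_j,
    i.e. project it onto the normal plane of g_j. The modified task
    gradients are then summed into a single update direction.
    """
    if rng is None:
        rng = random.Random(0)  # fixed seed so the sketch is reproducible
    dim = len(grads[0])
    total = [0.0] * dim
    for i, g in enumerate(grads):
        g_pc = list(g)  # copy so the original gradients stay untouched
        others = [j for j in range(len(grads)) if j != i]
        rng.shuffle(others)
        for j in others:
            g_j = grads[j]
            d = dot(g_pc, g_j)
            norm_sq = dot(g_j, g_j)
            if d < 0 and norm_sq > 0:  # conflict: remove the component along g_j
                scale = d / norm_sq
                g_pc = [a - scale * b for a, b in zip(g_pc, g_j)]
        total = [t + a for t, a in zip(total, g_pc)]
    return total


if __name__ == "__main__":
    # Conflicting pair: [1, 0] . [-1, 1] = -1 < 0, so both get projected.
    print(project_conflicting([[1.0, 0.0], [-1.0, 1.0]]))
    # Non-conflicting pair: gradients are simply summed.
    print(project_conflicting([[1.0, 0.0], [1.0, 1.0]]))
```

For the conflicting pair, `[1, 0]` becomes `[0.5, 0.5]` (orthogonal to `[-1, 1]`) and `[-1, 1]` becomes `[0, 1]` (orthogonal to `[1, 0]`), so the combined update is `[0.5, 1.5]`; gradients that do not conflict pass through unchanged.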

Tianhe Yu, Saurabh Kumar, Abhishek Gupta, Sergey Levine, Karol Hausman, Chelsea Finn • 2020

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Semantic segmentation | Cityscapes (test) | mIoU | 75.13 | 1145 |
| Depth Estimation | NYU v2 (test) | -- | -- | 423 |
| Semantic segmentation | NYU v2 (test) | mIoU | 51.77 | 248 |
| Surface Normal Estimation | NYU v2 (test) | Mean Angle Distance (MAD) | 24.31 | 206 |
| Image Classification | Office-Home (test) | -- | -- | 199 |
| Depth Estimation | NYU Depth V2 | RMSE | 0.596 | 177 |
| Semantic segmentation | NYU Depth V2 (test) | mIoU | 38.06 | 172 |
| Surface Normal Prediction | NYU V2 | Mean Error | 20.5 | 100 |
| Semantic segmentation | NYUD v2 | mIoU | 38.61 | 96 |
| Multi-Label Classification | ChestX-Ray14 (test) | -- | -- | 88 |
Showing 10 of 78 rows
...
