
Gradient Surgery for Multi-Task Learning

About

While deep learning and deep reinforcement learning (RL) systems have demonstrated impressive results in domains such as image classification, game playing, and robotic control, data efficiency remains a major challenge. Multi-task learning has emerged as a promising approach for sharing structure across multiple tasks to enable more efficient learning. However, the multi-task setting presents a number of optimization challenges, making it difficult to realize large efficiency gains compared to learning tasks independently. The reasons why multi-task learning is so challenging compared to single-task learning are not fully understood. In this work, we identify a set of three conditions of the multi-task optimization landscape that cause detrimental gradient interference, and develop a simple yet general approach for avoiding such interference between task gradients. We propose a form of gradient surgery that projects a task's gradient onto the normal plane of the gradient of any other task that has a conflicting gradient. On a series of challenging multi-task supervised and multi-task RL problems, this approach leads to substantial gains in efficiency and performance. Further, it is model-agnostic and can be combined with previously proposed multi-task architectures for enhanced performance.
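The projection step described in the abstract can be illustrated with a minimal sketch, assuming each task's gradient is a flat list of floats over the same shared parameters. Two gradients are treated as conflicting when their inner product is negative. For task i and each other task j whose gradient conflicts with it, the sketch applies g_i ← g_i − (g_i · g_j / ‖g_j‖²) g_j, which removes the component of g_i along g_j (projecting it onto the normal plane of g_j), and then sums the projected gradients into one update direction. The function name `pcgrad`, the random visiting order over the other tasks, and the pure-Python vector representation are illustrative assumptions, not the authors' released code.

```python
import random
from typing import List, Optional

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def pcgrad(task_grads: List[Vector], rng: Optional[random.Random] = None) -> Vector:
    """Hypothetical sketch of gradient surgery across tasks.

    Each task gradient is projected onto the normal plane of every other
    task gradient it conflicts with (negative inner product); the projected
    gradients are then summed into a single update direction.
    """
    rng = rng or random.Random(0)  # assumption: seeded RNG for a reproducible order
    projected: List[Vector] = []
    for i, g_i in enumerate(task_grads):
        g = list(g_i)  # work on a copy so the original gradients stay intact
        others = [j for j in range(len(task_grads)) if j != i]
        rng.shuffle(others)  # assumption: visit the other tasks in random order
        for j in others:
            g_j = task_grads[j]  # project against the *original* gradient of task j
            d = dot(g, g_j)
            norm_sq = dot(g_j, g_j)
            if d < 0 and norm_sq > 0:  # conflicting gradients
                scale = d / norm_sq
                # Remove the component of g along g_j.
                g = [x - scale * y for x, y in zip(g, g_j)]
        projected.append(g)
    # Sum the projected per-task gradients component-wise.
    return [sum(components) for components in zip(*projected)]


if __name__ == "__main__":
    # Conflicting pair: [1, 0] and [-1, 1] have inner product -1.
    print(pcgrad([[1.0, 0.0], [-1.0, 1.0]]))  # [0.5, 1.5]
    # Non-conflicting pair: gradients are simply summed.
    print(pcgrad([[1.0, 0.0], [1.0, 1.0]]))  # [2.0, 1.0]
```

In the conflicting example, the first gradient becomes [0.5, 0.5] and the second becomes [0.0, 1.0]; each is now orthogonal to the other task's original gradient, so neither task's update pushes directly against the other's. When no pair conflicts, the sketch reduces to an ordinary sum of task gradients.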

Tianhe Yu, Saurabh Kumar, Abhishek Gupta, Sergey Levine, Karol Hausman, Chelsea Finn • 2020

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Semantic segmentation | Cityscapes (test) | mIoU | 75.13 | 1154 |
| Semantic segmentation | Cityscapes | mIoU | 75.13 | 658 |
| Depth Estimation | NYU v2 (test) | -- | -- | 432 |
| Semantic segmentation | NYU v2 (test) | mIoU | 51.77 | 282 |
| Surface Normal Estimation | NYU v2 (test) | Mean Angle Distance (MAD) | 24.31 | 224 |
| Depth Estimation | NYU Depth V2 | RMSE | 0.596 | 209 |
| Image Classification | Office-Home (test) | -- | -- | 199 |
| Semantic segmentation | NYU Depth V2 (test) | mIoU | 38.06 | 183 |
| Semantic segmentation | NYUD v2 | mIoU | 38.61 | 125 |
| Surface Normal Prediction | NYU V2 | Mean Error | 20.5 | 118 |

(10 of 115 rows shown.)
