
Conflict-Averse Gradient Descent for Multi-task Learning

About

The goal of multi-task learning is to enable more efficient learning than single-task learning by sharing model structures across a diverse set of tasks. A standard multi-task learning objective is to minimize the average loss across all tasks. While straightforward, this objective often yields much worse final performance on each task than learning the tasks independently. A major challenge in optimizing a multi-task model is conflicting gradients: the gradients of different task objectives are not well aligned, so following the average gradient direction can be detrimental to the performance of specific tasks. Previous work has proposed several heuristics that manipulate the task gradients to mitigate this problem, but most of them lack convergence guarantees and/or could converge to any Pareto-stationary point. In this paper, we introduce Conflict-Averse Gradient descent (CAGrad), which minimizes the average loss function while leveraging the worst local improvement of individual tasks to regularize the algorithm trajectory. CAGrad balances the objectives automatically and still provably converges to a minimum of the average loss. It includes regular gradient descent (GD) and the multiple gradient descent algorithm (MGDA) from the multi-objective optimization (MOO) literature as special cases. On a series of challenging multi-task supervised learning and reinforcement learning tasks, CAGrad achieves improved performance over prior state-of-the-art multi-objective gradient manipulation methods.
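To make "worst local improvement" concrete, here is a minimal NumPy sketch of one CAGrad update direction under stated assumptions. As we read the formulation, the update d maximizes the worst-case first-order improvement min_i <g_i, d> subject to ||d - g_0|| <= c * ||g_0||, where g_0 is the average of the task gradients g_i. Its dual is a small problem over simplex weights w: minimize <g_w, g_0> + c * ||g_0|| * ||g_w||, where g_w = sum_i w_i g_i. The direction is then recovered as d = g_0 + c * ||g_0|| * g_w / ||g_w||. The function names `cagrad_direction` and `project_to_simplex`, the fixed step count, and the step size are hypothetical. The dual is solved with plain projected gradient descent only to keep the sketch self-contained, so this may not be the solver the authors use. Per-task gradients are assumed to be flattened into the rows of one matrix.

```python
import numpy as np


def project_to_simplex(v: np.ndarray) -> np.ndarray:
    """Euclidean projection of v onto the probability simplex (sort-based)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    k = np.arange(1, v.size + 1)
    rho = k[u - (css - 1.0) / k > 0][-1]  # largest k that keeps the entry positive
    theta = (css[rho - 1] - 1.0) / rho
    return np.maximum(v - theta, 0.0)


def cagrad_direction(grads: np.ndarray, c: float = 0.5, steps: int = 500) -> np.ndarray:
    """Hypothetical sketch of a CAGrad-style update direction.

    grads: shape (num_tasks, num_params), one flattened gradient per task.
    c: radius of the ball around the average gradient, relative to ||g0||.
    steps: number of projected-gradient iterations on the dual (assumed value).
    """
    num_tasks = grads.shape[0]
    g0 = grads.mean(axis=0)                  # average gradient = plain GD direction
    radius = c * np.linalg.norm(g0)          # c * ||g0||
    if radius == 0.0:
        return g0                            # c = 0 reduces to gradient descent
    G = grads @ grads.T                      # Gram matrix, G[i, j] = <g_i, g_j>
    Gb = G.mean(axis=1)                      # entry i is <g_i, g0>
    lr = 1.0 / (np.linalg.eigvalsh(G).max() + 1e-12)  # assumed step size
    w = np.full(num_tasks, 1.0 / num_tasks)  # start from uniform task weights
    for _ in range(steps):
        gw_norm = np.sqrt(w @ G @ w + 1e-12)            # ||g_w||
        grad = Gb + radius * (G @ w) / gw_norm          # gradient of the dual objective
        w = project_to_simplex(w - lr * grad)           # stay on the simplex
    gw = w @ grads                           # g_w = sum_i w_i g_i
    return g0 + radius * gw / (np.linalg.norm(gw) + 1e-12)
```

Consider two conflicting task gradients, (1, 0) and (-1, 1). The average gradient (0, 0.5) gives task 1 zero first-order improvement, because <g_1, g_0> = 0. With c = 0.5, the sketch should return a direction close to (0.25, 0.5), which improves both tasks by about 0.25. Setting c = 0 returns the average gradient unchanged, which matches the GD special case mentioned above.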

Bo Liu, Xingchao Liu, Xiaojie Jin, Peter Stone, Qiang Liu · 2021

Related benchmarks

Task | Dataset | Result | Rank
Semantic segmentation | Cityscapes (test) | mIoU 75.16 | 1145
Depth Estimation | NYU v2 (test) | -- | 423
Semantic segmentation | NYU v2 (test) | mIoU 52.04 | 248
Surface Normal Estimation | NYU v2 (test) | Mean Angle Distance (MAD) 23.39 | 206
Image Classification | Office-Home (test) | -- | 199
Depth Estimation | NYU Depth V2 | RMSE 0.595 | 177
Semantic segmentation | NYU Depth V2 (test) | mIoU 39.79 | 172
Surface Normal Prediction | NYU V2 | Mean Error 20.38 | 100
Semantic segmentation | NYUD v2 | mIoU 38.8 | 96
Multi-Label Classification | ChestX-Ray14 (test) | -- | 88
Showing 10 of 42 rows
