
Association Graph Learning for Multi-Task Classification with Category Shifts

About

In this paper, we focus on multi-task classification, where related classification tasks share the same label space and are learned simultaneously. In particular, we tackle a new setting, more realistic than those currently addressed in the literature, in which categories shift from training to test data. As a result, individual tasks do not contain complete training data for the categories in the test set. To generalize to such test data, it is crucial for individual tasks to leverage knowledge from related tasks. To this end, we propose learning an association graph to transfer knowledge among tasks for missing classes. We construct the association graph with nodes representing tasks, classes and instances, and encode the relationships among the nodes in the edges to guide their mutual knowledge transfer. By message passing on the association graph, our model enhances the categorical information of each instance, making it more discriminative. To avoid spurious correlations between task and class nodes in the graph, we introduce an assignment entropy maximization that encourages each class node to balance its edge weights. This enables all tasks to fully utilize the categorical information from related tasks. An extensive evaluation on three general benchmarks and a medical dataset for skin lesion classification shows that our method consistently outperforms representative baselines.
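The graph idea in the abstract can be illustrated with a small sketch. This is not the authors' implementation: it assumes dot-product similarities turned into edge weights with a softmax, a single round of residual message passing (task nodes to class nodes, then class nodes to instance nodes), and an assignment entropy computed over each class node's edges to the task nodes. All function names (edge_weights, aggregate, assignment_entropy, association_step) are hypothetical and chosen for illustration only.

```python
import math
from typing import List, Tuple

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors (assumed similarity measure)."""
    return sum(a * b for a, b in zip(u, v))


def softmax(scores: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax; turns similarity scores into edge weights."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def edge_weights(sources: List[Vector], targets: List[Vector],
                 temperature: float = 1.0) -> List[List[float]]:
    """Row i holds the normalized weights of the edges from sources[i] to every target."""
    return [softmax([dot(s, t) for t in targets], temperature) for s in sources]


def aggregate(nodes: List[Vector], neighbors: List[Vector],
              weights: List[List[float]]) -> List[Vector]:
    """One message-passing step: each node adds the weighted sum of its neighbors."""
    dim = len(nodes[0])
    updated = []
    for node, row in zip(nodes, weights):
        message = [sum(w * nb[d] for w, nb in zip(row, neighbors)) for d in range(dim)]
        updated.append([x + m for x, m in zip(node, message)])  # residual update
    return updated


def assignment_entropy(weights: List[List[float]]) -> float:
    """Mean entropy of each class node's edge weights over the task nodes.

    Maximizing this pushes every class node to spread its weight across
    tasks instead of latching onto a single one.
    """
    entropies = [-sum(w * math.log(w) for w in row if w > 0.0) for row in weights]
    return sum(entropies) / len(entropies)


def association_step(instances: List[Vector], classes: List[Vector],
                     tasks: List[Vector], temperature: float = 1.0
                     ) -> Tuple[List[Vector], List[Vector], float]:
    """Hypothetical single round over task, class and instance nodes."""
    # Class nodes gather information from task nodes.
    class_task_w = edge_weights(classes, tasks, temperature)
    classes = aggregate(classes, tasks, class_task_w)
    # Instance nodes gather categorical information from the updated class nodes.
    inst_class_w = edge_weights(instances, classes, temperature)
    instances = aggregate(instances, classes, inst_class_w)
    # Negated entropy, to be added to a classification loss and minimized.
    entropy_loss = -assignment_entropy(class_task_w)
    return instances, classes, entropy_loss


if __name__ == "__main__":
    tasks = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    classes = [[1.0, 1.0], [1.0, -1.0]]
    instances = [[0.2, 0.9], [0.8, -0.1]]
    new_instances, new_classes, loss = association_step(instances, classes, tasks)
    print(new_instances, loss)
```

Because each class node's weights over the tasks form a probability distribution, the entropy term peaks at log(number of tasks) when a class node weights all tasks equally, which is one simple way to read "balance its edge weights" in the abstract.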

Jiayi Shen, Zehao Xiao, Xiantong Zhen, Cees G. M. Snoek, Marcel Worring• 2022

Related benchmarks

Task | Dataset | Result | Rank
Image Classification | Office-Home (test) | -- | 199
Image Classification | Office-Caltech (test) | -- | 35
Image Classification | ImageCLEF (test) | -- | 33
Multi-task Classification | Office-Home | Metric Ao: 87.16 | 20
Multi-task Classification | Office-Caltech | Accuracy (Task o): 98.51 | 20
Multi-task Classification | ImageCLEF | Ao: 87.08 | 20
Medical Image Classification | Skin-Lesion gamma = 67% (test) | Metric Am: 10.82 | 5
Medical Image Classification | Skin-Lesion gamma = 33% (test) | Sensitivity: 16.58 | 5
Medical Image Classification | Skin-Lesion gamma = 0% (test) | Accuracy: 85.98 | 5
Multi-task Classification | Skin-Lesion gamma = 67% | Metric Am: 10.82 | 5
Showing 10 of 12 rows
