Episodic Multi-Task Learning with Heterogeneous Neural Processes

About

This paper focuses on the data-insufficiency problem in multi-task learning within an episodic training setup. Specifically, we explore the potential of heterogeneous information across tasks and meta-knowledge among episodes to effectively tackle each task with limited data. Existing meta-learning methods often fail to take advantage of crucial heterogeneous information within a single episode, while multi-task learning models neglect to reuse experience from earlier episodes. To address the problem of insufficient data, we develop Heterogeneous Neural Processes (HNPs) for the episodic multi-task setup. Within the framework of hierarchical Bayes, HNPs effectively capitalize on prior experiences as meta-knowledge and capture task-relatedness among heterogeneous tasks, mitigating data insufficiency. Meanwhile, transformer-structured inference modules are designed to enable efficient inference of meta-knowledge and task-relatedness. In this way, HNPs can learn more powerful functional priors for adapting to novel heterogeneous tasks in each meta-test episode. Experimental results show the superior performance of the proposed HNPs over typical baselines, and ablation studies verify the effectiveness of the designed inference modules.
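
The episodic multi-task setup and the idea of using attention to capture task-relatedness can be illustrated with a toy sketch. This is not the authors' HNP implementation: the `Task`/`Episode` containers, the mean-pooling encoder, and the parameter-free cross-task attention (identity query/key/value projections, no learned weights, no latent variables) are hypothetical simplifications chosen only to show the data layout and the attention step.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple

Vector = List[float]


@dataclass
class Task:
    """Hypothetical container for one task in an episode: a few labelled
    context points and the target inputs whose outputs should be predicted."""
    context: List[Tuple[Vector, float]]
    targets: List[Vector]


@dataclass
class Episode:
    """Hypothetical container: an episode groups several heterogeneous tasks,
    each with limited data."""
    tasks: List[Task]


def summarize_context(task: Task) -> Vector:
    """Mean-pool the concatenated [x..., y] context rows of one task.

    Stand-in for a learned neural-process encoder; no parameters involved.
    """
    rows = [list(x) + [y] for x, y in task.context]
    return [sum(col) / len(rows) for col in zip(*rows)]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_task_attention(summaries: List[Vector]) -> Tuple[List[Vector], List[Vector]]:
    """Scaled dot-product self-attention across task summaries.

    Queries, keys and values are the summaries themselves (identity
    projections), so each task's output is a similarity-weighted mix of all
    tasks' summaries. Returns (attended representations, attention weights).
    """
    d = len(summaries[0])
    outputs, weights = [], []
    for q in summaries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in summaries]
        w = softmax(scores)
        weights.append(w)
        outputs.append([sum(w_j * v[i] for w_j, v in zip(w, summaries)) for i in range(d)])
    return outputs, weights


# Toy episode: tasks 0 and 1 have similar context data, task 2 differs.
episode = Episode(tasks=[
    Task(context=[([0.0], 1.0), ([1.0], 3.0)], targets=[[0.5]]),
    Task(context=[([0.0], 0.9), ([1.0], 3.1)], targets=[[0.5]]),
    Task(context=[([5.0], -2.0)], targets=[[4.0]]),
])
summaries = [summarize_context(t) for t in episode.tasks]
attended, weights = cross_task_attention(summaries)
```

In this toy run, tasks 0 and 1 split their attention almost evenly between each other and put very little on task 2, while task 2 attends almost entirely to itself. In the paper the inference modules are transformer-structured and infer latent variables for meta-knowledge and task-relatedness; this sketch omits learned projections and latents entirely.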

Jiayi Shen, Xiantong Zhen, Qi (Cheems) Wang, Marcel Worring • 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
Image Classification | Office-Home (test) | Mean Accuracy | 70.9 | 199
Image Classification | Office-31 (test) | Avg Accuracy | 71.89 | 93
Episodic multi-task classification | Office-Home meta (test) | Avg Accuracy | 80.8 | 36
Episodic multi-task classification | DomainNet meta (test) | Accuracy | 69.38 | 36
Image Classification | Office-Caltech (test) | Average Accuracy | 95.8 | 35
Image Classification | ImageCLEF (test) | Accuracy | 80.9 | 33
Multi-Domain Classification | Office-Home (test) |  | -- | 20
Rotation angle estimation | Rotated MNIST (test) | Average NMSE | 0.103 | 7
Episodic Multi-Task Regression | 1-D Gaussian Process Regression, target points from all tasks (toy) | Avg. NLL | -0.5207 | 4
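
The table reports Average NMSE for rotation angle estimation and Avg. NLL for the 1-D regression toy, but the page does not define either metric. The sketch below uses common definitions as assumptions: NMSE as mean squared error normalised by the variance of the targets, and NLL under a Gaussian predictive distribution averaged over target points. The paper's exact normalisation and averaging may differ, and the function names are hypothetical.

```python
import math
from typing import Sequence


def nmse(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error divided by the (population) variance of the targets.

    Assumed definition: 0 means perfect predictions, 1 matches always
    predicting the target mean.
    """
    n = len(targets)
    mean_t = sum(targets) / n
    mse = sum((p - t) ** 2 for p, t in zip(predictions, targets)) / n
    var = sum((t - mean_t) ** 2 for t in targets) / n
    return mse / var


def gaussian_nll(y: float, mu: float, sigma: float) -> float:
    """Negative log-likelihood of y under a Gaussian N(mu, sigma^2)."""
    return 0.5 * math.log(2 * math.pi * sigma ** 2) + (y - mu) ** 2 / (2 * sigma ** 2)


def average_nll(ys: Sequence[float], mus: Sequence[float], sigmas: Sequence[float]) -> float:
    """Average Gaussian NLL over target points (e.g. pooled from all tasks)."""
    return sum(gaussian_nll(y, m, s) for y, m, s in zip(ys, mus, sigmas)) / len(ys)
```

A sharp, well-centred predictive density gives a negative NLL, which is why a negative Avg. NLL such as -0.5207 is possible: for example, `gaussian_nll(0.0, 0.0, 0.1)` is about -1.38.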
