Episodic Multi-Task Learning with Heterogeneous Neural Processes
About
This paper focuses on the data-insufficiency problem in multi-task learning within an episodic training setup. Specifically, we explore the potential of heterogeneous information across tasks and meta-knowledge among episodes to effectively tackle each task with limited data. Existing meta-learning methods often fail to take advantage of crucial heterogeneous information within a single episode, while multi-task learning models neglect to reuse experience from earlier episodes. To address the problem of insufficient data, we develop Heterogeneous Neural Processes (HNPs) for the episodic multi-task setup. Within the framework of hierarchical Bayes, HNPs effectively capitalize on prior experiences as meta-knowledge and capture task-relatedness among heterogeneous tasks, mitigating data insufficiency. Meanwhile, transformer-structured inference modules are designed to enable efficient inference of meta-knowledge and task-relatedness. In this way, HNPs can learn more powerful functional priors for adapting to novel heterogeneous tasks in each meta-test episode. Experimental results show the superior performance of the proposed HNPs over typical baselines, and ablation studies verify the effectiveness of the designed inference modules.
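To make the episodic setup concrete, below is a minimal sketch of a neural-process-style forward pass for one episode: context pairs are encoded, aggregated with a single self-attention step (standing in for the transformer-structured inference modules), and the resulting task summary conditions a decoder on target inputs. All weights, dimensions, and function names here are illustrative assumptions for the sketch, not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(x, w1, b1, w2, b2):
    # two-layer perceptron with tanh hidden activation
    return np.tanh(x @ w1 + b1) @ w2 + b2

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

d_in, d_hid, d_lat = 2, 16, 8  # (x, y) pair -> hidden -> latent sizes (arbitrary)

# hypothetical encoder / decoder weights, randomly initialised for the sketch
enc = (rng.normal(size=(d_in, d_hid)), np.zeros(d_hid),
       rng.normal(size=(d_hid, d_lat)), np.zeros(d_lat))
dec = (rng.normal(size=(d_lat + 1, d_hid)), np.zeros(d_hid),
       rng.normal(size=(d_hid, 1)), np.zeros(1))

def infer_latent(xc, yc):
    # encode each context pair, then aggregate with one self-attention
    # step as a stand-in for the transformer inference module
    r = mlp(np.concatenate([xc, yc], axis=1), *enc)   # (n_ctx, d_lat)
    attn = softmax(r @ r.T / np.sqrt(d_lat))          # (n_ctx, n_ctx)
    return (attn @ r).mean(axis=0)                    # permutation-invariant summary

def predict(z, xt):
    # condition the decoder on the task summary z for every target input
    zt = np.concatenate([np.tile(z, (len(xt), 1)), xt], axis=1)
    return mlp(zt, *dec)

# one toy 1-D regression episode: 5 context points, 3 target points
xc, yc = rng.normal(size=(5, 1)), rng.normal(size=(5, 1))
xt = rng.normal(size=(3, 1))
mu = predict(infer_latent(xc, yc), xt)
print(mu.shape)  # one prediction per target point
```

The full model would additionally treat the summary as a latent distribution (hierarchical Bayes over episodes and tasks) and train by maximising a variational objective; the sketch only shows the deterministic data flow of a single episode.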
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | Office-Home (test) | Mean Accuracy | 70.9 | 199 |
| Image Classification | Office-31 (test) | Avg Accuracy | 71.89 | 93 |
| Episodic multi-task classification | Office-Home meta (test) | Avg Accuracy | 80.8 | 36 |
| Episodic multi-task classification | DomainNet meta (test) | Accuracy | 69.38 | 36 |
| Image Classification | Office-Caltech (test) | Average Accuracy | 95.8 | 35 |
| Image Classification | ImageCLEF (test) | Accuracy | 80.9 | 33 |
| Multi-Domain Classification | Office-Home (test) | -- | -- | 20 |
| Rotation angle estimation | Rotated MNIST (test) | Average NMSE | 0.103 | 7 |
| Episodic Multi-Task Regression | 1-D Gaussian Process Regression, target points from all tasks (toy) | Avg. NLL | -0.5207 | 4 |