
Stochastic Multiple Target Sampling Gradient Descent

About

Sampling from an unnormalized target distribution is an essential problem with many applications in probabilistic inference. Stein Variational Gradient Descent (SVGD) has been shown to be a powerful method that iteratively updates a set of particles to approximate the distribution of interest. Furthermore, an analysis of its asymptotic properties shows that SVGD reduces exactly to a single-objective optimization problem, so SVGD can be viewed as a probabilistic version of single-objective optimization. A natural question then arises: "Can we derive a probabilistic version of multi-objective optimization?" To answer this question, we propose Stochastic Multiple Target Sampling Gradient Descent (MT-SGD), which enables sampling from multiple unnormalized target distributions. Specifically, MT-SGD constructs a flow of intermediate distributions that gradually moves toward the multiple target distributions, allowing the sampled particles to reach the joint high-likelihood region of the target distributions. Interestingly, the asymptotic analysis shows that our approach reduces exactly to the multiple-gradient descent algorithm for multi-objective optimization, as expected. Finally, we conduct comprehensive experiments to demonstrate the merit of our approach for multi-task learning.
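The abstract links SVGD to single-objective optimization and MT-SGD to the multiple-gradient descent algorithm (MGDA). The plain-Python toy below illustrates those two ingredients; it is not the authors' MT-SGD update. It computes the standard SVGD direction for each of two 1-D targets, then combines the two directions with the closed-form two-objective MGDA min-norm weight. The RBF kernel with a fixed bandwidth `h`, treating all particles' directions as one Euclidean vector, and the helper names `svgd_direction`, `min_norm_weight` and `multi_target_step` are our own assumptions for illustration.

```python
import math


def rbf(x, y, h=1.0):
    """RBF kernel k(x, y) = exp(-(x - y)^2 / h) for scalar particles."""
    return math.exp(-((x - y) ** 2) / h)


def svgd_direction(particles, score, h=1.0):
    """Standard SVGD direction for 1-D particles.

    phi(x_i) = (1/n) * sum_j [k(x_j, x_i) * score(x_j) + d/dx_j k(x_j, x_i)]
    The first term pulls particles toward high density (score = d/dx log p);
    the second term repels particles from each other.
    """
    n = len(particles)
    phi = []
    for xi in particles:
        total = 0.0
        for xj in particles:
            k = rbf(xj, xi, h)
            grad_k = -2.0 * (xj - xi) / h * k  # derivative of k w.r.t. x_j
            total += k * score(xj) + grad_k
        phi.append(total / n)
    return phi


def min_norm_weight(g1, g2):
    """Two-objective MGDA weight: argmin over a in [0, 1] of ||a*g1 + (1-a)*g2||^2."""
    diff = [a - b for a, b in zip(g1, g2)]
    denom = sum(d * d for d in diff)
    if denom == 0.0:
        return 0.5  # g1 == g2: every weight gives the same direction
    num = -sum(d * b for d, b in zip(diff, g2))  # (g2 - g1) . g2
    return min(1.0, max(0.0, num / denom))


def multi_target_step(particles, scores, step=0.1, h=1.0):
    """Hypothetical two-target step: min-norm combination of two SVGD directions."""
    phi1 = svgd_direction(particles, scores[0], h)
    phi2 = svgd_direction(particles, scores[1], h)
    a = min_norm_weight(phi1, phi2)
    return [x + step * (a * p1 + (1.0 - a) * p2)
            for x, p1, p2 in zip(particles, phi1, phi2)]


if __name__ == "__main__":
    # Toy example (assumed): unit-variance Gaussian targets centred at -1 and +1.
    score_a = lambda x: -(x + 1.0)
    score_b = lambda x: -(x - 1.0)
    xs = [3.0, 3.5, 4.0]
    for _ in range(500):
        xs = multi_target_step(xs, (score_a, score_b))
    print(xs)
```

With particles started to the right of both means, the two directions initially point the same way, so the min-norm weight picks the shorter one; the particles stop moving once the min-norm combination of the two directions is (close to) zero.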

Hoang Phan, Ngoc Tran, Trung Le, Toan Tran, Nhat Ho, Dinh Phung • 2022

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Bottom-right digit classification | Multi-MNIST | ECE (%) | 4 | 5 |
| Bottom-right fashion item classification | Multi-Fashion | ECE (%) | 4.47 | 5 |
| Bottom-right item classification | Multi-Fashion+MNIST | ECE (%) | 3.17 | 5 |
| Top-left digit classification | Multi-MNIST | ECE (%) | 3.28 | 5 |
| Top-left fashion item classification | Multi-Fashion | ECE (%) | 3.8 | 5 |
| Top-left item classification | Multi-Fashion+MNIST | ECE (%) | 4.65 | 5 |
| Binary classification | CelebA subset of 40k images | Accuracy (5S) | 92.6 | 4 |
| Multivariate regression | SARCOS (val) | Task 1 Score | 0.0173 | 3 |
| Multivariate regression | SARCOS (test) | Task 1 Error | 0.0037 | 3 |
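Most rows in the benchmark table report Expected Calibration Error (ECE). For reference, here is a minimal sketch of the standard equal-width-bin ECE estimator: the sample-weighted average gap between accuracy and mean confidence in each confidence bin. It is not the benchmark's evaluation code; the bin count (10) and the left-closed bin boundaries are assumptions, and conventions vary between implementations. The function returns a fraction, so multiply by 100 to compare with the percentages above.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """Equal-width-bin ECE, returned as a fraction in [0, 1].

    ECE = sum over bins B of (|B| / n) * |accuracy(B) - mean_confidence(B)|
    """
    if len(confidences) != len(correct) or not confidences:
        raise ValueError("need equally long, non-empty inputs")
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        # Left-closed bins [k/M, (k+1)/M); a confidence of exactly 1.0 goes in the last bin.
        idx = min(int(c * n_bins), n_bins - 1)
        bins[idx].append((c, ok))
    ece = 0.0
    for b in bins:
        if not b:
            continue  # empty bins contribute nothing
        acc = sum(1 for _, ok in b if ok) / len(b)
        conf = sum(c for c, _ in b) / len(b)
        ece += len(b) / n * abs(acc - conf)
    return ece


if __name__ == "__main__":
    # Two predictions at 0.9 confidence (one correct), two at 0.6 (both correct).
    ece = expected_calibration_error([0.9, 0.9, 0.6, 0.6], [True, False, True, True])
    print(f"ECE = {100 * ece:.1f}%")
```

In the usage example each bin holds half the samples and is off by 0.4 (50% accuracy at 0.9 confidence, 100% accuracy at 0.6 confidence), giving an ECE of 0.4, i.e. 40%.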
