ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variables

About

To address the challenge of backpropagating gradients through categorical variables, we propose the augment-REINFORCE-swap-merge (ARSM) gradient estimator, which is unbiased and has low variance. ARSM first uses variable augmentation, REINFORCE, and Rao-Blackwellization to re-express the gradient as an expectation under the Dirichlet distribution. It then uses variable swapping to construct differently expressed but equivalent expectations, and finally shares common random numbers between these expectations to achieve significant variance reduction. Experimental results show that ARSM performs nearly as well as the true gradient for optimization in univariate settings; outperforms existing estimators by a large margin when applied to categorical variational auto-encoders; and provides a "try-and-see self-critic" variance reduction method for discrete-action policy gradient, which removes the need to estimate baselines by generating a random number of pseudo actions and estimating their action-value functions.
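To make the pipeline above concrete, here is a minimal NumPy sketch based on our reading of the estimator, not the authors' released code. It assumes a single categorical variable z ~ Cat(softmax(phi)) with C categories, sampled through the augmentation z = argmin_i (ln pi_i - phi_i) with pi ~ Dir(1_C). `ar_gradient` is the augment-REINFORCE (Rao-Blackwellized) estimate f(z)(1 - C pi_c). `arsm_gradient` swaps entries of one shared Dirichlet draw and merges the swapped expectations. The function names, the logits `phi`, and the toy objective `f` are hypothetical illustrations.

```python
import numpy as np


def softmax(phi):
    e = np.exp(phi - phi.max())
    return e / e.sum()


def sample_z(log_pi, phi):
    # Augmented sampler: with pi ~ Dir(1_C), z = argmin_i (ln pi_i - phi_i)
    # is distributed as Categorical(softmax(phi)).
    return int(np.argmin(log_pi - phi))


def ar_gradient(phi, f, rng):
    """Single-sample augment-REINFORCE (AR) estimate: g_c = f(z) (1 - C pi_c)."""
    C = phi.shape[0]
    pi = rng.dirichlet(np.ones(C))
    return f(sample_z(np.log(pi), phi)) * (1.0 - C * pi)


def arsm_gradient(phi, f, rng):
    """Single-sample ARSM estimate of d/dphi E_{z ~ Cat(softmax(phi))}[f(z)].

    g_c = sum_j (f(z^{c<->j}) - fbar_j) (1/C - pi_j), where pi^{c<->j} swaps
    entries c and j of pi, and fbar_j = (1/C) sum_m f(z^{m<->j}).
    """
    C = phi.shape[0]
    pi = rng.dirichlet(np.ones(C))  # one shared draw: common random numbers
    log_pi = np.log(pi)
    F = np.empty((C, C))  # F[m, j] = f(z^{m<->j})
    for m in range(C):
        for j in range(C):
            swapped = log_pi.copy()
            swapped[m], swapped[j] = log_pi[j], log_pi[m]
            F[m, j] = f(sample_z(swapped, phi))
    f_bar = F.mean(axis=0)  # f_bar[j] averages over all swaps involving j
    return (F - f_bar) @ (1.0 / C - pi)  # merge: sum over reference index j


def true_gradient(phi, f):
    # Exact gradient: d/dphi_c sum_i p_i f(i) = p_c (f(c) - E[f]).
    p = softmax(phi)
    fv = np.array([f(i) for i in range(phi.shape[0])], dtype=float)
    return p * (fv - p @ fv)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    phi = np.array([0.5, -0.2, 0.1, 0.0])  # hypothetical logits, C = 4
    f = lambda z: (z - 1.7) ** 2  # hypothetical toy objective
    ar = np.array([ar_gradient(phi, f, rng) for _ in range(5000)])
    arsm = np.array([arsm_gradient(phi, f, rng) for _ in range(5000)])
    print("true:", true_gradient(phi, f))
    print("ARSM:", arsm.mean(axis=0), "total var", arsm.var(axis=0).sum())
    print("AR:  ", ar.mean(axis=0), "total var", ar.var(axis=0).sum())
```

Both estimators are unbiased, so their Monte Carlo means should approach `true_gradient`. The swap-and-merge step subtracts the per-j average `f_bar` as a built-in baseline. As a result, each ARSM sample sums to zero across categories, and terms where every swap leaves z unchanged contribute nothing. On toy problems like this one, that should give a noticeably smaller spread than the plain AR estimate.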

Mingzhang Yin, Yuguang Yue, Mingyuan Zhou · 2019

Related benchmarks

Task | Dataset (split) | Metric | Result | Rank
Log-likelihood estimation | MNIST dynamically binarized (test) | Log-Likelihood | -98.73 | 48
Generative Modeling | Dynamic MNIST (train) | Log Likelihood | -97.76 | 30
Generative Modeling | Fashion-MNIST (train) | Log Likelihood (100 samples) | -235.9 | 30
VAE Log-Likelihood Estimation | Fashion MNIST (test) | Log-Likelihood | -238.6 | 30
Generative Modeling | Omniglot (train) | Log Likelihood | -115.1 | 30
Variational Inference | Omniglot (test) | Test Log Likelihood | -116.6 | 30
Conditional estimation | Dynamic MNIST (test) | Test Log Likelihood | 60.92 | 18
Conditional estimation | Dynamic MNIST (train) | Final Log Likelihood | 60.22 | 15
Conditional estimation | Fashion-MNIST (train) | Final Training Log Likelihood | 134.6 | 15
Conditional estimation | Omniglot (train) | Final Training Log Likelihood | 68.35 | 15
