
DisARM: An Antithetic Gradient Estimator for Binary Latent Variables

About

Training models with discrete latent variables is challenging due to the difficulty of estimating the gradients accurately. Much of the recent progress has been achieved by taking advantage of continuous relaxations of the system, which are not always available or even possible. The Augment-REINFORCE-Merge (ARM) estimator provides an alternative that, instead of relaxation, uses continuous augmentation. Applying antithetic sampling over the augmenting variables yields a relatively low-variance and unbiased estimator applicable to any model with binary latent variables. However, while antithetic sampling reduces variance, the augmentation process increases variance. We show that ARM can be improved by analytically integrating out the randomness introduced by the augmentation process, guaranteeing substantial variance reduction. Our estimator, DisARM, is simple to implement and has the same computational cost as ARM. We evaluate DisARM on several generative modeling benchmarks and show that it consistently outperforms ARM and a strong independent sample baseline in terms of both variance and log-likelihood. Furthermore, we propose a local version of DisARM designed for optimizing the multi-sample variational bound, and show that it outperforms VIMCO, the current state-of-the-art method.
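The contrast between ARM and DisARM can be illustrated with a minimal sketch. It assumes a factorized Bernoulli model with logits θ (so z_i ~ Bernoulli(σ(θ_i))) and a hypothetical toy objective `f`. Both estimators draw one uniform vector u and form the antithetic pair b = 1[u < σ(θ)], b̃ = 1[1 − u < σ(θ)]. ARM weights f(b) − f(b̃) by (u − 1/2). DisARM instead replaces that weight with its conditional expectation given the pair, which depends only on (b_i, b̃_i) and σ(|θ_i|). The objective, logits, sample count, and seeds are illustrative choices, not taken from the paper, and the local DisARM variant for the multi-sample bound is not covered here.

```python
import itertools
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def f(z):
    # Hypothetical non-separable objective, used only for illustration.
    return (sum(z) - 1.3) ** 2


def exact_grad(theta):
    # Enumerate all 2^d states: d/dtheta_i E[f(z)] = E[(z_i - p_i) f(z)].
    p = [sigmoid(t) for t in theta]
    grad = [0.0] * len(theta)
    for z in itertools.product((0, 1), repeat=len(theta)):
        prob = 1.0
        for zi, pi in zip(z, p):
            prob *= pi if zi else 1.0 - pi
        fz = f(z)
        for i, (zi, pi) in enumerate(zip(z, p)):
            grad[i] += prob * (zi - pi) * fz
    return grad


def arm_grad(theta, rng):
    # ARM: one uniform vector u and two antithetic binary samples.
    u = [rng.random() for _ in theta]
    b = [int(1.0 - ui < sigmoid(t)) for ui, t in zip(u, theta)]
    b_tilde = [int(ui < sigmoid(t)) for ui, t in zip(u, theta)]
    diff = f(b) - f(b_tilde)
    return [diff * (ui - 0.5) for ui in u]


def disarm_grad(theta, rng):
    # DisARM: same antithetic pair, with u integrated out analytically,
    # leaving a per-dimension weight that depends only on (b_i, b_tilde_i).
    u = [rng.random() for _ in theta]
    b = [int(ui < sigmoid(t)) for ui, t in zip(u, theta)]
    b_tilde = [int(1.0 - ui < sigmoid(t)) for ui, t in zip(u, theta)]
    half_diff = 0.5 * (f(b) - f(b_tilde))
    grad = []
    for bi, bti, t in zip(b, b_tilde, theta):
        if bi == bti:
            grad.append(0.0)  # Dimensions where the pair agrees contribute nothing.
        else:
            sign = -1.0 if bti else 1.0  # (-1)^{b_tilde_i}
            grad.append(half_diff * sign * sigmoid(abs(t)))
    return grad


def mc_stats(estimator, theta, n, seed):
    # Monte Carlo mean and variance of each gradient coordinate.
    rng = random.Random(seed)
    d = len(theta)
    s = [0.0] * d
    s2 = [0.0] * d
    for _ in range(n):
        g = estimator(theta, rng)
        for i in range(d):
            s[i] += g[i]
            s2[i] += g[i] ** 2
    mean = [si / n for si in s]
    var = [s2i / n - m ** 2 for s2i, m in zip(s2, mean)]
    return mean, var


theta = [-1.0, -0.2, 0.3, 1.5]
exact = exact_grad(theta)
arm_mean, arm_var = mc_stats(arm_grad, theta, 20000, seed=0)
dis_mean, dis_var = mc_stats(disarm_grad, theta, 20000, seed=1)
for i in range(len(theta)):
    print(f"dim {i}: exact={exact[i]:+.4f}  "
          f"ARM={arm_mean[i]:+.4f} (var {arm_var[i]:.4f})  "
          f"DisARM={dis_mean[i]:+.4f} (var {dis_var[i]:.4f})")
```

Both estimators should agree with the enumerated gradient up to Monte Carlo error. Because DisARM is the conditional expectation of ARM given the pair (b, b̃), its per-coordinate variance cannot exceed ARM's. In this toy setting, the reduction is largest for dimensions where b_i = b̃_i is common: there, ARM still injects noise through (u_i − 1/2), while DisARM contributes exactly zero.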

Zhe Dong, Andriy Mnih, George Tucker · 2020

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Log-likelihood estimation | MNIST dynamically binarized (test) | Log-Likelihood | -101.6 | 48 |
| Binary Latent VAE Training | MNIST (train) | Avg ELBO | 668 | 14 |
| Binary Latent VAE Training | Fashion-MNIST (train) | Average ELBO | 182.7 | 14 |
| Binary Latent VAE Training | Omniglot (train) | Average ELBO | 446.6 | 14 |
| Generative Modeling | Dynamically binarized MNIST (test) | NELBO | -97.56 | 13 |
| Generative Modeling | Omniglot dynamically binarized (train) | Training ELBO | -108.6 | 9 |
| Generative Modeling | OMNIGLOT dynamically binarized (test) | Log-Likelihood Bound (100-point) | -107.3 | 9 |
| Generative Modeling | MNIST dynamically binarized (train) | Training ELBO | -97.95 | 9 |
| Generative Modeling | Fashion-MNIST dynamically binarized (train) | ELBO (Train) | -234.4 | 9 |
| Generative Modeling | Fashion-MNIST dynamically binarized (test) | Test Log-Likelihood Bound | -234.5 | 9 |
Showing 10 of 15 rows
