DisARM: An Antithetic Gradient Estimator for Binary Latent Variables
About
Training models with discrete latent variables is challenging due to the difficulty of estimating the gradients accurately. Much of the recent progress has been achieved by taking advantage of continuous relaxations of the system, which are not always available or even possible. The Augment-REINFORCE-Merge (ARM) estimator provides an alternative that, instead of relaxation, uses continuous augmentation. Applying antithetic sampling over the augmenting variables yields a relatively low-variance and unbiased estimator applicable to any model with binary latent variables. However, while antithetic sampling reduces variance, the augmentation process increases variance. We show that ARM can be improved by analytically integrating out the randomness introduced by the augmentation process, guaranteeing substantial variance reduction. Our estimator, DisARM, is simple to implement and has the same computational cost as ARM. We evaluate DisARM on several generative modeling benchmarks and show that it consistently outperforms ARM and a strong independent-sample baseline in terms of both variance and log-likelihood. Furthermore, we propose a local version of DisARM designed for optimizing the multi-sample variational bound, and show that it outperforms VIMCO, the current state-of-the-art method.
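To make the difference between the two estimators concrete, here is a minimal NumPy sketch. It assumes a factorized distribution b_i ~ Bernoulli(σ(θ_i)) over the latent bits and a hypothetical toy objective f. The helper names `antithetic_pair`, `arm_grad`, `disarm_grad` and `exact_grad`, and the toy `theta`/`target` values, are illustrative and do not come from the paper. ARM draws one uniform vector u, forms the antithetic pair b = 1[1 - u < σ(θ)], b̃ = 1[u < σ(θ)], and weights f(b) - f(b̃) by (u - 1/2). The DisARM function is our reading of "integrating out the randomness introduced by the augmentation". Conditioned on the pair (b, b̃), the expected ARM weight on coordinate i becomes ½ (-1)^{b̃_i} 1[b_i ≠ b̃_i] σ(|θ_i|). Treat this as an illustration, not the authors' reference implementation.

```python
import itertools

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def antithetic_pair(theta, u):
    """Antithetic Bernoulli(sigmoid(theta)) pair driven by one uniform draw u of shape (n, d)."""
    b = (u > sigmoid(-theta)).astype(float)       # 1[1 - u < sigmoid(theta)]
    b_tilde = (u < sigmoid(theta)).astype(float)  # 1[u < sigmoid(theta)]
    return b, b_tilde


def arm_grad(f, theta, u):
    """ARM: (f(b) - f(b_tilde)) * (u - 1/2), one gradient estimate per row of u."""
    b, b_tilde = antithetic_pair(theta, u)
    return (f(b) - f(b_tilde))[:, None] * (u - 0.5)


def disarm_grad(f, theta, u):
    """DisARM (as sketched here): ARM with u integrated out given the pair (b, b_tilde)."""
    b, b_tilde = antithetic_pair(theta, u)
    sign = np.where(b_tilde == 1.0, -1.0, 1.0)  # (-1)^{b_tilde_i}
    differ = (b != b_tilde).astype(float)        # 1[b_i != b_tilde_i]
    scale = sigmoid(np.abs(theta))               # sigmoid(|theta_i|)
    return 0.5 * (f(b) - f(b_tilde))[:, None] * sign * differ * scale


def exact_grad(f, theta):
    """Exact gradient of E[f(b)] by enumerating all 2^d configurations (small d only)."""
    p = sigmoid(theta)
    grad = np.zeros_like(theta)
    for bits in itertools.product([0.0, 1.0], repeat=theta.size):
        b = np.array(bits)
        prob = np.prod(np.where(b == 1.0, p, 1.0 - p))
        grad += f(b[None, :])[0] * prob * (b - p)  # d log p(b) / d theta = b - p
    return grad


# Hypothetical toy problem: 3 binary latents, quadratic objective.
theta = np.array([0.5, -1.0, 2.0])
target = np.array([0.45, 0.1, 0.8])


def f(b):
    """Toy objective evaluated row-wise on a batch of binary vectors of shape (n, d)."""
    return np.sum((b - target) ** 2, axis=1)


rng = np.random.default_rng(0)
u = rng.uniform(size=(100_000, theta.size))
g_arm = arm_grad(f, theta, u)
g_dis = disarm_grad(f, theta, u)

print("exact  :", exact_grad(f, theta))
print("ARM    :", g_arm.mean(axis=0), "var", g_arm.var(axis=0))
print("DisARM :", g_dis.mean(axis=0), "var", g_dis.var(axis=0))
```

Both estimators reuse the same uniform draws. This lets you check their means against the enumerated gradient and compare their per-coordinate variances directly. Because DisARM is the conditional expectation of ARM given (b, b̃), its variance should not exceed ARM's. In this sketch, a coordinate with b_i = b̃_i contributes exactly zero under DisARM. Under ARM, the same coordinate still picks up noise through the shared factor f(b) - f(b̃).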
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Log-likelihood estimation | MNIST dynamically binarized (test) | Log-Likelihood | -101.6 | 48 |
| Binary Latent VAE Training | MNIST (train) | Avg ELBO | 668 | 14 |
| Binary Latent VAE Training | Fashion-MNIST (train) | Average ELBO | 182.7 | 14 |
| Binary Latent VAE Training | Omniglot (train) | Average ELBO | 446.6 | 14 |
| Generative Modeling | Dynamically binarized MNIST (test) | NELBO | -97.56 | 13 |
| Generative Modeling | Omniglot dynamically binarized (train) | Training ELBO | -108.6 | 9 |
| Generative Modeling | OMNIGLOT dynamically binarized (test) | Log-Likelihood Bound (100-point) | -107.3 | 9 |
| Generative Modeling | MNIST dynamically binarized (train) | Training ELBO | -97.95 | 9 |
| Generative Modeling | Fashion-MNIST dynamically binarized (train) | ELBO (Train) | -234.4 | 9 |
| Generative Modeling | Fashion-MNIST dynamically binarized (test) | Test Log-Likelihood Bound | -234.5 | 9 |