
CARMS: Categorical-Antithetic-REINFORCE Multi-Sample Gradient Estimator

About

Accurately backpropagating the gradient through categorical variables is a challenging task that arises in various domains, such as training discrete latent variable models. To this end, we propose CARMS, an unbiased estimator for categorical random variables based on multiple mutually negatively correlated (jointly antithetic) samples. CARMS combines REINFORCE with copula-based sampling to avoid duplicate samples and reduce variance, while using importance sampling to keep the estimator unbiased. It generalizes both the ARMS antithetic estimator for binary variables, which is CARMS for two categories, and LOORF/VarGrad, the leave-one-out REINFORCE estimator, which is CARMS with independent samples. We evaluate CARMS on several benchmark datasets on a generative modeling task and a structured output prediction task, and find that it outperforms competing methods, including a strong self-control baseline. The code is publicly available.

Alek Dimitriev, Mingyuan Zhou • 2021
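
The abstract describes LOORF/VarGrad as the special case of CARMS in which the samples are drawn independently. The sketch below illustrates only that special case for a single categorical variable parameterized by logits; it does not implement the copula-based antithetic sampling or the importance weighting that CARMS adds. The function names (`softmax`, `loorf_gradient`, `exact_gradient`) and the toy objective are assumptions made for illustration and are not taken from the authors' code.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    shifted = logits - np.max(logits)
    e = np.exp(shifted)
    return e / e.sum()


def loorf_gradient(logits, f, n_samples, rng):
    """Leave-one-out REINFORCE estimate of d/d(logits) E_{z ~ Cat(softmax(logits))}[f(z)].

    Uses n_samples >= 2 *independent* categorical samples (the case the
    abstract identifies with LOORF/VarGrad), not antithetic ones.
    """
    probs = softmax(logits)
    num_categories = probs.shape[0]
    z = rng.choice(num_categories, size=n_samples, p=probs)
    f_vals = np.array([f(k) for k in z], dtype=float)
    # Baseline for sample i: the mean of f over the other n - 1 samples.
    baseline = (f_vals.sum() - f_vals) / (n_samples - 1)
    # Score function of a softmax categorical: one_hot(z) - probs.
    score = np.eye(num_categories)[z] - probs
    return ((f_vals - baseline)[:, None] * score).mean(axis=0)


def exact_gradient(logits, f):
    """Closed-form gradient, available here because the support is tiny."""
    probs = softmax(logits)
    f_all = np.array([f(k) for k in range(len(probs))], dtype=float)
    return probs * (f_all - probs @ f_all)
```

Because each sample's baseline is computed from the other samples only, the estimate stays unbiased; averaging many calls to `loorf_gradient` should therefore approach `exact_gradient` for the same logits and objective.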

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Log-likelihood estimation | MNIST dynamically binarized (test) | Log-Likelihood | -92.97 | 48 |
| Generative Modeling | Dynamic MNIST (train) | Log Likelihood | -92.13 | 30 |
| Generative Modeling | Fashion-MNIST (train) | Log Likelihood (100 samples) | -230.8 | 30 |
| Generative Modeling | Omniglot (train) | Log Likelihood | -108.6 | 30 |
| VAE Log-Likelihood Estimation | Fashion MNIST (test) | Log-Likelihood | -233.4 | 30 |
| Variational Inference | Omniglot (test) | Test Log Likelihood | -112.7 | 30 |
| Conditional estimation | Dynamic MNIST (test) | Test Log Likelihood | 60.01 | 18 |
| Conditional estimation | Dynamic MNIST (train) | Final Log Likelihood | 58.35 | 15 |
| Conditional estimation | Omniglot (train) | Final Training Log Likelihood | 66.94 | 15 |
| Conditional estimation | Omniglot (test) | Test Log Likelihood | 72.88 | 15 |
Showing 10 of 12 rows
