Bridging Discrete and Backpropagation: Straight-Through and Beyond

About

Backpropagation, the cornerstone of deep learning, is limited to computing gradients for continuous variables. This limitation poses challenges for problems involving discrete latent variables. To address this issue, we propose a novel approach to approximate the gradient of parameters involved in generating discrete latent variables. First, we examine the widely used Straight-Through (ST) heuristic and demonstrate that it works as a first-order approximation of the gradient. Guided by our findings, we propose ReinMax, which achieves second-order accuracy by integrating Heun's method, a second-order numerical method for solving ODEs. ReinMax does not require Hessian or other second-order derivatives, thus having negligible computation overheads. Extensive experimental results on various tasks demonstrate the superiority of ReinMax over the state of the art. Implementations are released at https://github.com/microsoft/ReinMax.
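The claim that Straight-Through behaves as a first-order approximation can be checked numerically on a toy problem. The sketch below is an illustration, not the released ReinMax code. It assumes a categorical variable D with probabilities softmax(theta), and the ST variant that backpropagates the gradient of f at the sampled one-hot vector through the softmax probabilities. The function names and the test functions `linear` and `quadratic` are hypothetical, chosen only for this example. Under these assumptions the expected ST gradient matches the exact gradient of E[f(D)] when f is linear, and generally differs when f is quadratic.

```python
import math

def softmax(theta):
    # Numerically stable softmax over a list of logits.
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(pi):
    # jac[i][k] = d pi_i / d theta_k = pi_i * (delta_ik - pi_k)
    n = len(pi)
    return [[pi[i] * ((1.0 if i == k else 0.0) - pi[k]) for k in range(n)]
            for i in range(n)]

def one_hot(i, n):
    return [1.0 if j == i else 0.0 for j in range(n)]

def exact_gradient(theta, f):
    # d/d theta_k E[f(D)] = sum_i f(e_i) * d pi_i / d theta_k
    pi = softmax(theta)
    jac = softmax_jacobian(pi)
    n = len(theta)
    values = [f(one_hot(i, n)) for i in range(n)]
    return [sum(values[i] * jac[i][k] for i in range(n)) for k in range(n)]

def expected_st_gradient(theta, grad_f):
    # Assumed ST estimate for a sample e_j: grad_f(e_j) pushed through the
    # softmax Jacobian. Here we average it exactly over all outcomes j.
    pi = softmax(theta)
    jac = softmax_jacobian(pi)
    n = len(theta)
    grad = [0.0] * n
    for j in range(n):
        g = grad_f(one_hot(j, n))
        for k in range(n):
            grad[k] += pi[j] * sum(g[i] * jac[i][k] for i in range(n))
    return grad

# Hypothetical test functions on a two-class variable.
W = [1.0, 3.0]

def linear(d):
    return sum(w * x for w, x in zip(W, d))

def linear_grad(d):
    return list(W)

def quadratic(d):
    return linear(d) ** 2

def quadratic_grad(d):
    return [2.0 * linear(d) * w for w in W]

if __name__ == "__main__":
    theta = [0.0, 1.0]
    print("linear    exact:", exact_gradient(theta, linear))
    print("linear    ST   :", expected_st_gradient(theta, linear_grad))
    print("quadratic exact:", exact_gradient(theta, quadratic))
    print("quadratic ST   :", expected_st_gradient(theta, quadratic_grad))
```

The gap in the quadratic case is the kind of error a second-order scheme is meant to reduce. For the actual ReinMax estimator, refer to the released implementation linked in the abstract.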

Liyuan Liu, Chengyu Dong, Xiaodong Liu, Bin Yu, Jianfeng Gao • 2023

Related benchmarks

Task | Dataset | Result | Rank
--- | --- | --- | ---
Generative Modeling | MNIST (train) | Neg. ELBO 93.44 | 42
Generative Modeling | MNIST (test) | N-ELBO 100.6 | 35
Generative Modeling | Fashion-MNIST (train) | -- | 30
Generative Modeling | Omniglot (train) | -- | 30
Unsupervised Parsing | ListOps (val) | Accuracy 67.65 | 5
Unsupervised Parsing | ListOps (test) | Accuracy 68.07 | 5
Image Classification | NATS-Bench CIFAR-10 1.0 (val) | Accuracy 90.01 | 2
Image Classification | NATS-Bench CIFAR-10 1.0 (test) | Accuracy 93.44 | 2
Image Classification | NATS-Bench CIFAR-100 1.0 (val) | Accuracy 69.29 | 2
Image Classification | NATS-Bench CIFAR-100 1.0 (test) | Accuracy 69.41 | 2
Showing 10 of 12 rows
