
Learning Longer-term Dependencies in RNNs with Auxiliary Losses

About

Despite recent advances in training recurrent neural networks (RNNs), capturing long-term dependencies in sequences remains a fundamental challenge. Most approaches use backpropagation through time (BPTT), which is difficult to scale to very long sequences. This paper proposes a simple method that improves the ability of RNNs to capture long-term dependencies by adding an unsupervised auxiliary loss to the original objective. This auxiliary loss forces the RNN to either reconstruct previous events or predict next events in a sequence, which makes truncated backpropagation feasible for long sequences and also improves full BPTT. We evaluate the method in a variety of settings, including pixel-by-pixel image classification with sequence lengths of up to 16,000 and a real document classification benchmark. Our results show strong performance and resource efficiency compared with competitive baselines, including other recurrent models and a comparably sized Transformer. Further analyses reveal beneficial effects of the auxiliary loss on optimization and regularization, as well as its behavior in extreme cases where there is little to no backpropagation.
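The auxiliary objective described above can be sketched roughly as follows. This is a minimal, hypothetical illustration, not the authors' code: it assumes anchor positions are sampled uniformly at random, that each auxiliary target is a fixed-length segment either just before the anchor ("reconstruct previous events") or just after it ("predict next events"), and that the auxiliary term is a squared error added to the supervised loss with a weight. The names `sample_anchor_segments`, `segment_mse`, `total_loss`, `segment_len`, `aux_weight` and the `decode` callable are all illustrative assumptions. Only index bookkeeping and loss arithmetic are shown; the recurrent encoder, the decoder network and gradient computation are left out.

```python
import random
from typing import Callable, List, Sequence, Tuple


def sample_anchor_segments(
    seq_len: int,
    segment_len: int,
    num_anchors: int,
    mode: str,
    rng: random.Random,
) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: return (anchor, start, end) triples.

    The auxiliary target for each anchor is x[start:end]. Uniform anchor
    sampling and a fixed segment length are illustrative assumptions.
    """
    if mode not in ("reconstruct", "predict"):
        raise ValueError(f"unknown mode: {mode!r}")
    if not 1 <= segment_len < seq_len:
        raise ValueError("segment_len must be in [1, seq_len)")
    triples = []
    for _ in range(num_anchors):
        if mode == "reconstruct":
            # Previous events: segment x[t - segment_len : t] ends at the anchor t.
            t = rng.randint(segment_len, seq_len - 1)
            triples.append((t, t - segment_len, t))
        else:
            # Next events: segment x[t + 1 : t + 1 + segment_len] follows the anchor t.
            t = rng.randint(0, seq_len - segment_len - 1)
            triples.append((t, t + 1, t + 1 + segment_len))
    return triples


def segment_mse(
    seq: Sequence[float], start: int, end: int, decode: Callable[[int], List[float]]
) -> float:
    """Squared-error auxiliary loss on one segment.

    `decode(n)` is a hypothetical stand-in for a decoder network that would be
    initialised from the main RNN's hidden state at the anchor and emit n values.
    """
    target = seq[start:end]
    pred = decode(len(target))
    return sum((p - y) ** 2 for p, y in zip(pred, target)) / len(target)


def total_loss(main_loss: float, aux_losses: Sequence[float], aux_weight: float) -> float:
    """Supervised loss plus the weighted mean of the auxiliary losses (assumed weighting)."""
    if not aux_losses:
        return main_loss
    return main_loss + aux_weight * sum(aux_losses) / len(aux_losses)


if __name__ == "__main__":
    rng = random.Random(0)
    pixels = [rng.random() for _ in range(784)]  # toy 28x28 image, flattened to a sequence
    segments = sample_anchor_segments(len(pixels), 50, 2, "reconstruct", rng)
    # A constant "decoder" stands in for a real network, purely to exercise the plumbing.
    aux = [segment_mse(pixels, s, e, lambda n: [0.5] * n) for _, s, e in segments]
    print(total_loss(main_loss=0.7, aux_losses=aux, aux_weight=1.0))
```

Keeping each auxiliary segment short is, under these assumptions, what would bound the backpropagation cost of each auxiliary term. That matches the intuition behind the claim that truncated backpropagation becomes feasible, although the sketch itself computes no gradients.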

Trieu H. Trinh, Andrew M. Dai, Minh-Thang Luong, Quoc V. Le · 2018

Related benchmarks

| Task | Dataset | Metric | Result (%) | Rank |
|---|---|---|---|---|
| Image Classification | MNIST (test) | Accuracy | 98.4 | 882 |
| Pixel-by-pixel Image Classification | Permuted Sequential MNIST (pMNIST) (test) | Accuracy | 97.9 | 79 |
| Sequential Image Classification | PMNIST (test) | Accuracy (Test) | 97.9 | 77 |
| Sequential Image Classification | S-MNIST (test) | Accuracy | 98.9 | 70 |
| Image Classification | permuted MNIST (pMNIST) (test) | Accuracy | 95.2 | 63 |
| Pixel-level 1-D image classification | Sequential MNIST (test) | Accuracy | 98.4 | 53 |
| Permuted Sequential Image Classification | MNIST Permuted Sequential | Test Accuracy Mean | 95.2 | 50 |
| Sequential Image Classification | Sequential CIFAR10 | Accuracy | 72.2 | 48 |
| 1-D Pixel-level Image Classification | sCIFAR (test) | Accuracy | 72.2 | 46 |
| Pixel-by-pixel Image Classification | CIFAR-10 sequential (test) | Accuracy | 72.2 | 37 |

Showing 10 of 28 rows.
