
Order Matters: Sequence to sequence for sets

About

Sequences have become first class citizens in supervised learning thanks to the resurgence of recurrent neural networks. Many complex tasks that require mapping from or to a sequence of observations can now be formulated with the sequence-to-sequence (seq2seq) framework which employs the chain rule to efficiently represent the joint probability of sequences. In many cases, however, variable sized inputs and/or outputs might not be naturally expressed as sequences. For instance, it is not clear how to input a set of numbers into a model where the task is to sort them; similarly, we do not know how to organize outputs when they correspond to random variables and the task is to model their unknown joint probability. In this paper, we first show using various examples that the order in which we organize input and/or output data matters significantly when learning an underlying model. We then discuss an extension of the seq2seq framework that goes beyond sequences and handles input sets in a principled way. In addition, we propose a loss which, by searching over possible orders during training, deals with the lack of structure of output sets. We show empirical evidence of our claims regarding ordering, and on the modifications to the seq2seq framework on benchmark language modeling and parsing tasks, as well as two artificial tasks -- sorting numbers and estimating the joint probability of unknown graphical models.

Oriol Vinyals, Samy Bengio, Manjunath Kudlur • 2015
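As a rough illustration of two ideas named in the abstract, the chain-rule factorization of a sequence's joint probability and a loss that searches over possible output orders, here is a minimal Python sketch. It is not the authors' method: the abstract does not describe the search procedure, so this version simply tries every permutation, which is only feasible for very small sets. The function names and `toy_cond_prob` (a hypothetical stand-in for a trained model, returning unnormalized scores) are assumptions made for the example.

```python
import itertools
import math


def sequence_log_prob(order, cond_prob):
    """Chain rule: log p(y_1..y_n) = sum over t of log p(y_t | y_1..y_{t-1})."""
    total = 0.0
    for t, item in enumerate(order):
        total += math.log(cond_prob(item, tuple(order[:t])))
    return total


def best_order_loss(items, cond_prob):
    """Brute-force search: return (loss, order) with the lowest negative log-likelihood."""
    best = None
    for perm in itertools.permutations(items):
        loss = -sequence_log_prob(perm, cond_prob)
        if best is None or loss < best[0]:
            best = (loss, perm)
    return best


def toy_cond_prob(item, prefix):
    """Hypothetical model (unnormalized scores): prefers each item to exceed the previous one."""
    if not prefix or item > prefix[-1]:
        return 0.9
    return 0.1


# The same set scores differently depending on the order in which it is emitted.
loss, order = best_order_loss([3, 1, 2], toy_cond_prob)
print(order, round(loss, 4))
```

With this toy scorer, the ascending order is the only one in which every step receives the high score, so the search selects it. A real model would supply learned conditional probabilities, and a practical training loop would need a cheaper search than full enumeration.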

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Graph Classification | PROTEINS | Accuracy 74.29 | 742 |
| Graph Classification | MUTAG | Accuracy 73.7 | 697 |
| Graph Classification | NCI1 | Accuracy 66.97 | 460 |
| Graph Classification | COLLAB | Accuracy 79.6 | 329 |
| Graph Classification | IMDB-B | Accuracy 72.2 | 322 |
| Graph Classification | ENZYMES | Accuracy 60.15 | 305 |
| Graph Classification | NCI109 | Accuracy 61.04 | 223 |
| Graph Classification | Mutag (test) | Accuracy 80.84 | 217 |
| Graph Classification | MUTAG (10-fold cross-validation) | Accuracy 73 | 206 |
| Graph Classification | PROTEINS (10-fold cross-validation) | Accuracy 72 | 197 |
Showing 10 of 27 rows
