
SE(3) Equivariant Augmented Coupling Flows

About

Coupling normalizing flows allow for fast sampling and density evaluation, making them the tool of choice for probabilistic modeling of physical systems. However, the standard coupling architecture precludes endowing flows that operate on the Cartesian coordinates of atoms with the SE(3) and permutation invariances of physical systems. This work proposes a coupling flow that preserves SE(3) and permutation equivariance by performing coordinate splits along additional augmented dimensions. At each layer, the flow maps atoms' positions into learned SE(3) invariant bases, where we apply standard flow transformations, such as monotonic rational-quadratic splines, before returning to the original basis. Crucially, our flow preserves fast sampling and density evaluation, and may be used to produce unbiased estimates of expectations with respect to the target distribution via importance sampling. When trained on the DW4, LJ13, and QM9-positional datasets, our flow is competitive with equivariant continuous normalizing flows and diffusion models, while allowing sampling more than an order of magnitude faster. Moreover, to the best of our knowledge, we are the first to learn the full Boltzmann distribution of alanine dipeptide by only modeling the Cartesian positions of its atoms. Lastly, we demonstrate that our flow can be trained to approximately sample from the Boltzmann distribution of the DW4 and LJ13 particle systems using only their energy functions.
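The central step described above is to express atom positions in an SE(3)-invariant basis, apply an elementwise bijection there, and map back. The minimal pure-Python sketch below illustrates that step under stated assumptions. The basis is built by Gram-Schmidt from three reference points, which stand in for points the paper derives from the other coordinate group (for example, the augmented coordinates). A fixed affine transform replaces the learned rational-quadratic spline. The names `frame`, `coupling_forward`, `coupling_inverse`, and `rotate` are hypothetical and do not come from the authors' code.

```python
import math
from typing import Tuple

Vec = Tuple[float, float, float]


def add(u: Vec, v: Vec) -> Vec:
    return (u[0] + v[0], u[1] + v[1], u[2] + v[2])


def sub(u: Vec, v: Vec) -> Vec:
    return (u[0] - v[0], u[1] - v[1], u[2] - v[2])


def scale(c: float, u: Vec) -> Vec:
    return (c * u[0], c * u[1], c * u[2])


def dot(u: Vec, v: Vec) -> float:
    return u[0] * v[0] + u[1] * v[1] + u[2] * v[2]


def cross(u: Vec, v: Vec) -> Vec:
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])


def normalize(u: Vec) -> Vec:
    return scale(1.0 / math.sqrt(dot(u, u)), u)


def frame(a1: Vec, a2: Vec, a3: Vec):
    """Right-handed orthonormal basis from three non-collinear points (Gram-Schmidt), origin at a1."""
    e1 = normalize(sub(a2, a1))
    v = sub(a3, a1)
    e2 = normalize(sub(v, scale(dot(v, e1), e1)))
    e3 = cross(e1, e2)  # the cross product fixes handedness, so this is SE(3)- rather than E(3)-equivariant
    return a1, (e1, e2, e3)


def to_local(x: Vec, origin: Vec, basis) -> Vec:
    d = sub(x, origin)
    return (dot(d, basis[0]), dot(d, basis[1]), dot(d, basis[2]))


def to_global(z: Vec, origin: Vec, basis) -> Vec:
    out = origin
    for zi, e in zip(z, basis):
        out = add(out, scale(zi, e))
    return out


def coupling_forward(x, refs, log_scale, shift):
    """Elementwise affine map applied in the invariant frame; returns (y, log|det J|).

    refs must not depend on x (they come from the conditioning split), and
    log_scale/shift must be SE(3)-invariant; here they are fixed constants (hypothetical).
    """
    origin, basis = frame(*refs)
    z = to_local(x, origin, basis)
    z_new = tuple(zi * math.exp(s) + t for zi, s, t in zip(z, log_scale, shift))
    # The change of basis is a rotation (|det| = 1), so only the scales contribute.
    return to_global(z_new, origin, basis), sum(log_scale)


def coupling_inverse(y, refs, log_scale, shift):
    """Exact inverse of coupling_forward; returns (x, log|det J^-1|)."""
    origin, basis = frame(*refs)
    z = to_local(y, origin, basis)
    z_old = tuple((zi - t) * math.exp(-s) for zi, s, t in zip(z, log_scale, shift))
    return to_global(z_old, origin, basis), -sum(log_scale)


def rotate(v: Vec, axis: Vec, angle: float) -> Vec:
    """Rodrigues' rotation of v about the (normalized) axis; used only to check equivariance."""
    k = normalize(axis)
    c, s = math.cos(angle), math.sin(angle)
    return add(add(scale(c, v), scale(s, cross(k, v))), scale(dot(k, v) * (1.0 - c), k))
```

Rotating and translating `x` and all three reference points together rotates and translates the basis in the same way. The local coordinates therefore do not change, and `coupling_forward` returns the correspondingly rotated and translated output. `coupling_inverse` recovers `x` exactly. Because the basis is orthonormal, the log-determinant reduces to the sum of the log-scales.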

Laurence I. Midgley, Vincent Stimper, Javier Antorán, Emile Mathieu, Bernhard Schölkopf, José Miguel Hernández-Lobato • 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
Density Estimation | LJ13 | Negative Log-Likelihood | 30.19 | 7
Density Estimation | QM9 positional | Negative Log-Likelihood | -165.7 | 7
Boltzmann distribution approximation | Alanine dipeptide (test) | KLD | 0.0026 | 7
Density Estimation | DW4 | Negative Log-Likelihood | 8.61 | 7
Boltzmann distribution modeling | DW4 (test) | Rev ESS | 84.29 | 4
Boltzmann distribution modeling | LJ13 (test) | Rev ESS | 0.6209 | 4
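The Rev ESS rows report a reverse effective sample size. The sketch below assumes the common convention: the Kish effective sample size of importance weights computed on samples drawn from the flow, divided by the number of samples. The table does not state the scale, so it is not assumed here whether the listed values are fractions or percentages. The function names are hypothetical.

```python
import math
from typing import Sequence


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x))) using the max-shift trick."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normalized_ess(log_weights: Sequence[float]) -> float:
    """Kish ESS, (sum w)^2 / sum w^2, divided by the sample count n.

    The result lies between 1/n and 1: it equals 1 when all weights are equal and
    approaches 1/n when a single weight dominates. It is computed in log space so
    that very large or very small weights do not overflow.
    """
    n = len(log_weights)
    log_ess = 2.0 * log_sum_exp(log_weights) - log_sum_exp([2.0 * lw for lw in log_weights])
    return math.exp(log_ess) / n
```

For reverse ESS, each log weight would be log p̃(x_i) - log q(x_i) for x_i drawn from the flow q. For an augmented flow, the weight would presumably use the joint density of positions and augmented variables, which is an assumption rather than something this page states. Adding the same constant to every log weight leaves the result unchanged, so the target only needs to be known up to its normalizing constant, as it is when training from an energy function alone.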

Other info

Code
