Generalizable Reasoning through Compositional Energy Minimization
About
Generalization is a key challenge in machine learning, particularly in reasoning tasks, where models are expected to solve problems more complex than those encountered during training. Existing approaches typically train reasoning models end-to-end, directly mapping input instances to solutions. While this allows models to learn useful heuristics from data, it often limits generalization beyond the training distribution. In this work, we propose a novel approach to reasoning generalization: learning energy landscapes over the solution spaces of smaller, more tractable subproblems. At test time, we construct a global energy landscape for a given problem by combining the energy functions of multiple subproblems. This compositional approach allows additional constraints to be incorporated during inference, so that energy landscapes can be constructed for problems of increasing difficulty. To improve the quality of samples drawn from this newly constructed energy landscape, we introduce Parallel Energy Minimization (PEM). We evaluate our approach on a wide range of reasoning problems, including the N-Queens puzzle, 3-SAT, and graph coloring. Our method outperforms existing state-of-the-art methods, demonstrating its ability to generalize to larger and more complex problems. The project website is available at: https://alexoarga.github.io/compositional_reasoning/
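The following minimal sketch illustrates the composition step on graph coloring, under assumptions of our own. Each edge is treated as a subproblem, and its energy is a hand-written indicator (1 if both endpoints share a color, 0 otherwise) that stands in for a learned energy function. The global energy is the sum of the subproblem energies, so any zero-energy assignment is a valid coloring. Adding a constraint at inference time amounts to adding one more term. All names (`edge_energy`, `compose`, `greedy_descent`, `parallel_minimize`) are hypothetical. `parallel_minimize` only runs independent greedy chains from different random starts and keeps the best one. It is not the paper's PEM procedure.

```python
import random
from typing import Callable, Dict, List, Tuple

# An assignment maps each node to a color index in range(k).
Assignment = Dict[int, int]
Energy = Callable[[Assignment], float]


def edge_energy(u: int, v: int) -> Energy:
    """Hypothetical hand-written subproblem energy for one edge (a stand-in
    for a learned model): 1.0 if both endpoints share a color, else 0.0."""
    return lambda x: 1.0 if x[u] == x[v] else 0.0


def compose(energies: List[Energy]) -> Energy:
    """Global energy of the composed problem: the sum of subproblem energies."""
    return lambda x: sum(e(x) for e in energies)


def greedy_descent(energy: Energy, nodes: List[int], k: int,
                   rng: random.Random, max_sweeps: int = 50) -> Tuple[float, Assignment]:
    """One chain: random start, then sweep over the nodes, moving each one to the
    color with the lowest global energy; stop at zero energy or a fixed point."""
    x = {v: rng.randrange(k) for v in nodes}
    current = energy(x)
    for _ in range(max_sweeps):
        improved = False
        for v in nodes:
            best_c, best_e = x[v], current
            for c in range(k):
                x[v] = c
                e = energy(x)
                if e < best_e:
                    best_c, best_e = c, e
            x[v] = best_c
            if best_e < current:
                current, improved = best_e, True
        if current == 0.0 or not improved:
            break
    return current, x


def parallel_minimize(energy: Energy, nodes: List[int], k: int,
                      n_chains: int = 8, seed: int = 0) -> Tuple[float, Assignment]:
    """Hypothetical stand-in for parallel minimization (NOT the paper's PEM):
    run independent chains from different random starts and keep the best."""
    rng = random.Random(seed)
    results = [greedy_descent(energy, nodes, k, random.Random(rng.randrange(2**32)))
               for _ in range(n_chains)]
    return min(results, key=lambda r: r[0])


if __name__ == "__main__":
    # Compose a 5-cycle from its edge energies and minimize with 3 colors.
    cycle = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
    energy = compose([edge_energy(u, v) for u, v in cycle])
    best_e, coloring = parallel_minimize(energy, nodes=list(range(5)), k=3)
    # Every node has degree 2 < 3 colors, so a conflicted node always has a free
    # color and greedy descent cannot get stuck above zero energy.
    print(best_e)  # 0.0

    # Adding a constraint at inference time is just one more energy term.
    valid = {0: 0, 1: 1, 2: 0, 3: 1, 4: 2}
    tighter = compose([edge_energy(u, v) for u, v in cycle + [(0, 2)]])
    print(energy(valid), tighter(valid))  # 0.0 1.0
```

Summing energies keeps each subproblem model small and reusable. The trade-off is that the composed landscape can become harder to minimize as constraints accumulate, and the abstract gives this as the motivation for introducing a dedicated procedure such as PEM.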
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| 3-SAT Solving | 3-SAT (test): 100 sampled formulas, 20 variables, 91 clauses (phase transition 4.258 * n) | Satisfied Clauses Rate | 99.85 | 11 |
| 8-Queens Puzzle Solving | 8-Queens (100 sampled boards) | Correct Solutions | 97 | 7 |
| Graph Coloring | Erdos Renyi (100 sampled tasks) | Average Violations | 8.6 | 6 |
| Graph Coloring | Erdos Renyi 2 (100 sampled tasks) | Average Violations | 29.2 | 6 |
| Graph Coloring | Holme Kim (100 sampled tasks) | Average Violations | 10.6 | 6 |
| Graph Coloring | Regular Expander (100 sampled tasks) | Average Violations | 11 | 6 |
| Graph Coloring | Paley (100 sampled tasks) | Average Violations | 34.8 | 6 |
| Graph Coloring | Complete (100 sampled tasks) | Average Violations | 3.4 | 6 |
| Graph Coloring | Holme Kim 2 (100 sampled tasks) | Average Violations | 59 | 6 |
| Graph Coloring | Regular Expander 2 (100 sampled tasks) | Average Violations | 37.2 | 5 |
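The table uses three metrics, and this page does not state their precise definitions. The sketch below shows one plausible reading of each. It is an assumption, not the benchmarks' official scoring code:

- The satisfied-clauses rate is the percentage of clauses with at least one true literal (higher is better).
- Graph-coloring violations count the monochromatic edges in a task, averaged over the 100 sampled tasks (lower is better).
- A correct 8-Queens solution places the queens so that no two attack each other.

The function names are hypothetical, and clauses are written as DIMACS-style signed integers.

```python
from typing import Dict, List, Sequence, Tuple

Clause = Sequence[int]  # DIMACS-style literals: 3 means x3, -3 means NOT x3


def satisfied_clauses_rate(clauses: List[Clause], assignment: Dict[int, bool]) -> float:
    """Assumed definition: percentage of clauses with at least one true literal."""
    sat = sum(any(assignment[abs(lit)] == (lit > 0) for lit in c) for c in clauses)
    return 100.0 * sat / len(clauses)


def coloring_violations(edges: List[Tuple[int, int]], coloring: Dict[int, int]) -> int:
    """Assumed definition: number of edges whose endpoints share a color."""
    return sum(coloring[u] == coloring[v] for u, v in edges)


def is_valid_queens(cols: Sequence[int]) -> bool:
    """cols[r] is the column of the queen in row r; valid if no two queens attack."""
    n = len(cols)
    if sorted(cols) != list(range(n)):  # exactly one queen per row and per column
        return False
    # No two queens may share a diagonal (r - c) or an anti-diagonal (r + c).
    return (len({r - c for r, c in enumerate(cols)}) == n
            and len({r + c for r, c in enumerate(cols)}) == n)


if __name__ == "__main__":
    clauses = [(1, -2, 3), (-1, 2, -3), (1, 2, 3), (-1, -2, -3)]
    print(satisfied_clauses_rate(clauses, {1: True, 2: True, 3: True}))  # 75.0

    tasks = [([(0, 1), (1, 2)], {0: 0, 1: 0, 2: 1}),          # 1 violation
             ([(0, 1), (1, 2), (2, 0)], {0: 0, 1: 0, 2: 0})]  # 3 violations
    print(sum(coloring_violations(e, c) for e, c in tasks) / len(tasks))  # 2.0

    print(is_valid_queens([0, 4, 7, 5, 2, 6, 1, 3]), is_valid_queens(list(range(8))))  # True False
```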