
Hyperparameter Trajectory Inference with Conditional Lagrangian Optimal Transport

About

Neural networks (NNs) often have critical behavioural trade-offs that are set at design time with hyperparameters, such as reward weights in reinforcement learning or quantile targets in regression. After deployment, however, user preferences can evolve, making the initial settings undesirable and necessitating potentially expensive retraining. To circumvent this, we introduce the task of Hyperparameter Trajectory Inference (HTI): to learn, from observed data, how a NN's conditional output distribution changes with its hyperparameters, and to construct a surrogate model that approximates the NN at unobserved hyperparameter settings. HTI requires extending existing trajectory inference approaches to incorporate conditions, which exacerbates the challenge of ensuring that inferred paths are feasible. We propose an approach based on conditional Lagrangian optimal transport, jointly learning the Lagrangian function governing hyperparameter-induced dynamics along with the associated optimal transport maps and geodesics between observed marginals, which together form the surrogate model. We incorporate inductive biases based on the manifold hypothesis and least-action principles into the learned Lagrangian, improving surrogate model feasibility. We empirically demonstrate that our approach reconstructs NN outputs across various hyperparameter spectra better than alternative approaches.
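For orientation, the standard (unconditional) Lagrangian optimal transport problem that the abstract builds on can be written as below. This is the textbook formulation, not the paper's exact objective: the symbols $L$, $\gamma$, $\mu_0$, $\mu_1$ and $\pi$ are notation assumed here, and the abstract does not specify how the conditioning or the learned Lagrangian enters.

```latex
% Textbook Lagrangian OT; notation is assumed, not taken from the paper.
% Ground cost: least action over paths \gamma joining x_0 to x_1.
c(x_0, x_1) \;=\; \inf_{\substack{\gamma:[0,1]\to\mathcal{X} \\ \gamma(0)=x_0,\ \gamma(1)=x_1}}
  \int_0^1 L\bigl(\gamma(t), \dot{\gamma}(t)\bigr)\,\mathrm{d}t

% Transport problem between two observed marginals \mu_0 and \mu_1,
% minimised over couplings \pi with those marginals.
\mathrm{OT}_c(\mu_0, \mu_1) \;=\; \inf_{\pi \in \Pi(\mu_0, \mu_1)}
  \int c(x_0, x_1)\,\mathrm{d}\pi(x_0, x_1)

% Sanity check: with L(x, v) = \tfrac{1}{2}\lVert v \rVert^2 the minimising
% paths are straight lines and c(x_0, x_1) = \tfrac{1}{2}\lVert x_1 - x_0 \rVert^2.
```

In this picture, the minimising paths $\gamma$ are the geodesics along which intermediate distributions are read off, which is the role the abstract assigns to the surrogate model at unobserved hyperparameter settings.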

Harry Amad, Mihaela van der Schaar • 2026

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Continuous Control | OpenAI Gym Reacher (held-out λc ∈ {2, 3, 4}) | Reward | -6.093 | 8 |
| Generative Modelling | sklearn two moons, dropout settings p ∈ {0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9} (held-out) | WD | 0.06 | 8 |
| Non-Linear Reward Scalarization | Cancer_nl (held-out settings) | Reward | 91.94 | 8 |
| Reward Weighting Adaptation | Cancer (held-out λnk ∈ {1, 2, 3, 4, 6, 7, 8, 9}) | Reward | 102.5 | 8 |
| Time Series Forecasting | ETTm2 (held-out quantiles τ ∈ {0.1, 0.25, 0.5, 0.75, 0.9}) | MSE on ETTm2 | 0.608 | 7 |
