
Cosine-Gated Adam-Decay: Drop-In Staleness-Aware Outer Optimization for Decoupled DiLoCo

About

Asynchronous DiLoCo systems may receive pseudo-gradients computed several outer rounds earlier, yet the standard Nesterov outer optimizer does not explicitly condition its update on the age of each pseudo-gradient. This can make the outer momentum buffer brittle under large controlled delays. We propose Cosine-Gated Adam-Decay (CGAD), a simple, drop-in, age-aware outer optimizer that scales each incoming pseudo-gradient of age $\tau$ by $\sigma(\tau) = \gamma(\tau) e^{-\alpha\tau}$ before it enters Adam's first- and second-moment buffers: the exponential factor models information decay, and the cosine gate $\gamma(\tau)$ smoothly zeroes contributions past a chosen cutoff. CGAD reduces to plain Adam at $\tau=0$, adds two hyperparameters (the decay rate $\alpha$ and the cutoff) whose defaults transfer across scales, and extends to partial-sync schedulers via a per-fragment age-aware variant (PA-CGAD). For an idealized gated-adaptive update on smooth non-convex objectives, we prove a non-asymptotic convergence bound whose staleness-bias term depends on $\alpha$ alone rather than on the realized maximum delay $\tau_{\max}$; standard analyses of asynchronous momentum-SGD instead carry a $\tau_{\max}^2$ factor. Empirically, on Llama-style language model pretraining at 25M, 1B, and 7B parameters, CGAD trains stably across the controlled delays we sweep. The cosine cutoff acts as scale insurance: the closest baseline, Adam-Decay (CGAD without the cutoff), is competitive at 25M, but its seed-to-seed standard deviation $\sigma$ at $\tau=8$ grows 27× from 25M to 7B, pushing its single-shot risk (mean + $\sigma$) above the chance-level loss, while CGAD's stays well below. The published Nesterov recipe is the least stable method across the full sweep.
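A minimal sketch of the CGAD outer step described above, assuming a raised-cosine gate $\gamma(\tau) = \tfrac{1}{2}(1 + \cos(\pi\tau/\tau_c))$ for $\tau < \tau_c$ and $0$ otherwise. The abstract does not give the exact form of $\gamma$, so this gate, the names `tau_cut`, `CGADSketch`, `PerFragmentCGADSketch`, and every default value below are hypothetical; this is not the authors' implementation and the defaults are not the paper's. Pseudo-gradients are assumed to follow the usual DiLoCo convention (outer parameters minus a worker's locally trained parameters). The per-fragment class is only one possible reading of PA-CGAD: it keeps separate Adam state and a separate age for each parameter fragment.

```python
import math
from typing import Dict, List


def cosine_gate(tau: float, tau_cut: float) -> float:
    """Assumed raised-cosine gate: 1 at tau=0, smoothly reaching 0 at tau_cut, 0 beyond."""
    if tau >= tau_cut:
        return 0.0
    return 0.5 * (1.0 + math.cos(math.pi * tau / tau_cut))


def staleness_scale(tau: float, alpha: float, tau_cut: float) -> float:
    """sigma(tau) = gamma(tau) * exp(-alpha * tau); equals 1 at tau=0."""
    return cosine_gate(tau, tau_cut) * math.exp(-alpha * tau)


class CGADSketch:
    """Adam over a flat parameter list, with each pseudo-gradient pre-scaled by sigma(tau)."""

    def __init__(self, n: int, lr: float = 0.1, alpha: float = 0.1,
                 tau_cut: float = 16.0, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8) -> None:
        # Placeholder hyperparameters only; not the paper's defaults.
        self.lr, self.alpha, self.tau_cut = lr, alpha, tau_cut
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        self.m = [0.0] * n  # first-moment buffer
        self.v = [0.0] * n  # second-moment buffer
        self.t = 0          # outer steps taken, used for bias correction

    def step(self, params: List[float], pseudo_grad: List[float],
             tau: float) -> List[float]:
        """Apply one outer update from a pseudo-gradient that is `tau` rounds old."""
        s = staleness_scale(tau, self.alpha, self.tau_cut)
        # Literal reading of the abstract: even a fully gated (s == 0) update is
        # fed to Adam as zeros, so the buffers decay instead of the step being skipped.
        self.t += 1
        bc1 = 1.0 - self.beta1 ** self.t
        bc2 = 1.0 - self.beta2 ** self.t
        out = []
        for i, (p, g) in enumerate(zip(params, pseudo_grad)):
            g_s = s * g  # age-scaled pseudo-gradient enters both moment buffers
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g_s
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g_s * g_s
            m_hat, v_hat = self.m[i] / bc1, self.v[i] / bc2
            out.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return out


class PerFragmentCGADSketch:
    """Hypothetical PA-CGAD reading: separate CGAD state and age per parameter fragment."""

    def __init__(self, sizes: Dict[str, int], **kwargs: float) -> None:
        self.opts = {name: CGADSketch(n, **kwargs) for name, n in sizes.items()}

    def step(self, params: Dict[str, List[float]],
             pseudo_grads: Dict[str, List[float]],
             ages: Dict[str, float]) -> Dict[str, List[float]]:
        return {name: self.opts[name].step(params[name], pseudo_grads[name], ages[name])
                for name in params}


if __name__ == "__main__":
    opt = CGADSketch(n=2, lr=0.1)
    # At tau=0, sigma=1, so this is an ordinary first Adam step (about lr per coordinate).
    print(opt.step([1.0, -2.0], [0.5, -0.25], tau=0))  # approximately [0.9, -1.9]
```

Because $\sigma(\tau)$ multiplies the input to both moment buffers, Adam's $\hat m/\sqrt{\hat v}$ ratio largely normalizes away a uniform rescaling of a single isolated update; in this sketch the damping shows up in how much weight a stale pseudo-gradient carries in the running averages relative to fresher ones.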

Vatsal Shah, Jiahao Sun • 2026

Related benchmarks

Task | Dataset | Result | Rank
--- | --- | --- | ---
Language Modeling | Llama decoder 10M | Mean Final Loss: 6.79 | 36
Language Modeling | Language Modeling Dataset | Cross-Entropy Loss: 6.79 | 26
Language Modeling | 25 M decoder | Mean Final Loss: 6.79 | 20
