MADiff: Offline Multi-agent Learning with Diffusion Models

About

Offline reinforcement learning (RL) aims to learn policies from pre-existing datasets without further interactions, making it a challenging task. Q-learning algorithms struggle with extrapolation errors in offline settings, while supervised learning methods are constrained by model expressiveness. Recently, diffusion models (DMs) have shown promise in overcoming these limitations in single-agent learning, but their application in multi-agent scenarios remains unclear. Generating trajectories for each agent with independent DMs may impede coordination, while concatenating all agents' information can lead to low sample efficiency. Accordingly, we propose MADiff, which uses an attention-based diffusion model to capture the complex coordination among the behaviors of multiple agents. To our knowledge, MADiff is the first diffusion-based multi-agent learning framework, functioning as both a decentralized policy and a centralized controller. During decentralized execution, MADiff simultaneously performs teammate modeling, and the centralized controller can also be applied to multi-agent trajectory prediction. Our experiments demonstrate that MADiff outperforms baseline algorithms across various multi-agent learning tasks, highlighting its effectiveness in modeling complex multi-agent interactions. Our code is available at https://github.com/zbzhu99/madiff.
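The abstract's central design choice is to let each agent's representation attend to every other agent's inside the diffusion model, rather than running independent DMs per agent or concatenating all agents into one input. The sketch below illustrates only that mixing step. It is a minimal sketch under loudly stated assumptions: one hypothetical feature vector per agent at a single denoising step, single-head scaled dot-product attention, and identity query/key/value projections. It is not MADiff's actual architecture, which is defined in the linked repository.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_agent_attention(agent_feats: List[Vector]) -> List[Vector]:
    """Mix information across agents with single-head dot-product attention.

    ASSUMPTION (illustration only): agent_feats[i] is a hypothetical feature
    vector for agent i's noisy trajectory at one denoising step. Queries, keys
    and values are the features themselves (identity projections); a real
    model would learn separate projection matrices.
    """
    d = len(agent_feats[0])
    scale = 1.0 / math.sqrt(d)
    mixed = []
    for q in agent_feats:
        # Score every agent (including itself) against this agent's query.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in agent_feats]
        weights = softmax(scores)
        # Output is a convex combination of all agents' value vectors.
        mixed.append(
            [sum(w * v[j] for w, v in zip(weights, agent_feats)) for j in range(d)]
        )
    return mixed


if __name__ == "__main__":
    feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three agents, 2-D features
    for row in cross_agent_attention(feats):
        print([round(x, 3) for x in row])
```

Because the same computation is applied to every agent, the operation is permutation-equivariant (reordering agents only reorders the outputs) and works for any number of agents, which is one reason attention is a natural fit for coordinating a team without fixing an input layout.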

Zhengbang Zhu, Minghuan Liu, Liyuan Mao, Bingyi Kang, Minkai Xu, Yong Yu, Stefano Ermon, Weinan Zhang• 2023

Related benchmarks

Task                               Dataset                                Metric         Result   Rank
Multi-agent Trajectory Prediction  NBA dataset                            ADE            7.92     26
Spread                             MPE Spread offline (Expert)            Average Score  116.7    7
Spread                             MPE Spread offline (Md-Replay)         Average Score  42.2     7
Tag                                MPE Tag offline (Expert)               Average Score  1.676    7
3m                                 SMAC 3m offline (Good)                 Average Score  19.9     6
8m                                 SMAC 8m offline (Poor)                 Average Score  5.1      6
2halfcheetah                       MA Mujoco 2halfcheetah offline (Good)  Average Score  8.51e+3  5
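ADE (average displacement error), the metric reported for the NBA trajectory-prediction task, is the mean Euclidean distance between predicted and ground-truth positions over the predicted time steps; lower is better. A minimal sketch, assuming 2-D positions and averaging over both agents and time steps. The exact averaging convention and units behind the 7.92 figure are not stated on this page, so treat this as an illustration of the metric rather than the benchmark's evaluation code.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def average_displacement_error(
    pred: List[List[Point]], truth: List[List[Point]]
) -> float:
    """ADE averaged over all agents and time steps.

    pred[a][t] and truth[a][t] are the (x, y) positions of agent a at step t.
    ASSUMPTION: every (agent, step) pair is weighted equally.
    """
    if len(pred) != len(truth):
        raise ValueError("pred and truth must cover the same agents")
    total, count = 0.0, 0
    for p_traj, t_traj in zip(pred, truth):
        if len(p_traj) != len(t_traj):
            raise ValueError("trajectories must have the same length")
        for (px, py), (tx, ty) in zip(p_traj, t_traj):
            # Euclidean distance between predicted and true position.
            total += math.hypot(px - tx, py - ty)
            count += 1
    if count == 0:
        raise ValueError("no positions to compare")
    return total / count


if __name__ == "__main__":
    predicted = [[(0.0, 0.0), (3.0, 4.0)]]
    observed = [[(0.0, 0.0), (0.0, 0.0)]]
    print(average_displacement_error(predicted, observed))  # prints 2.5
```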
