JaxMARL: Multi-Agent RL Environments and Algorithms in JAX
| Main Authors | |
|---|---|
| Format | Journal Article |
| Language | English |
| Published | 16.11.2023 |
| DOI | 10.48550/arxiv.2311.10090 |
| Summary: | Benchmarks are crucial in the development of machine learning algorithms, with available environments significantly influencing reinforcement learning (RL) research. Traditionally, RL environments run on the CPU, which limits their scalability with typical academic compute. However, recent advancements in JAX have enabled the wider use of hardware acceleration, enabling massively parallel RL training pipelines and environments. While this has been successfully applied to single-agent RL, it has not yet been widely adopted for multi-agent scenarios. In this paper, we present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments and popular baseline algorithms. Our experiments show that, in terms of wall clock time, our JAX-based training pipeline is around 14 times faster than existing approaches, and up to 12500x when multiple training runs are vectorized. This enables efficient and thorough evaluations, potentially alleviating the evaluation crisis in the field. We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. The code is available at https://github.com/flairox/jaxmarl. |
|---|---|
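The abstract's central claim is that writing environment steps as pure JAX functions lets the whole pipeline be vectorized on an accelerator. A minimal sketch of that idea, using only core JAX (`jax.vmap` and `jax.jit`) and a hypothetical toy two-agent environment (`env_step` is illustrative, not JaxMARL's actual API):

```python
import jax
import jax.numpy as jnp

# Illustrative sketch (not JaxMARL's actual API): a toy two-agent
# matrix-game "environment" whose step is a pure function, so JAX can
# vectorize many independent environment instances with jax.vmap.

def env_step(state, actions):
    # state: scalar step counter; actions: (2,) int array, one per agent.
    # Reward 1.0 when the two agents pick the same action, else 0.0.
    reward = jnp.where(actions[0] == actions[1], 1.0, 0.0)
    return state + 1, reward

# Map the single-environment step over a leading batch axis, giving
# thousands of environments stepping in one fused, JIT-compiled call.
batched_step = jax.jit(jax.vmap(env_step))

states = jnp.zeros(4096)                       # 4096 parallel environments
actions = jnp.zeros((4096, 2), dtype=jnp.int32)
new_states, rewards = batched_step(states, actions)
print(new_states.shape, rewards.shape)         # (4096,) (4096,)
```

Because the vectorization is just another `jax.vmap`, the same trick composes one level up: mapping an entire training run over a batch of random seeds is what the abstract's "multiple training runs are vectorized" speedup refers to.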