A Gentle Introduction to Deep Reinforcement Learning in JAX

Recent advances in Reinforcement Learning (RL), such as Waymo's autonomous taxis or DeepMind's superhuman chess-playing agents, complement classical RL with Deep Learning components such as Neural Networks and gradient-based optimization methods.
Building on the foundations and coding principles introduced in one of my previous stories, we'll learn how to implement Deep Q-Networks (DQN) and replay buffers to solve OpenAI's CartPole environment, all in under a second using JAX!
For an introduction to JAX, vectorized environments, and Q-learning, please refer to this story:
Vectorize and Parallelize RL Environments with JAX: Q-learning at the Speed of Light⚡
Our framework of choice for deep learning will be DeepMind's Haiku library, which I recently introduced in the context of Transformers:
Implementing a Transformer Encoder from Scratch with JAX and Haiku