Backend 2
JAX
JAX keeps a NumPy-like style while adding automatic differentiation and compilation tools.
When to use JAX
Use the official JAX documentation when you need installation notes, API details, or examples.
Use JAX when the same array workflow needs gradients, just-in-time compilation, or accelerator-backed computation.
It is useful for optimization, differentiable simulations, and research code where gradients are part of the method.
Minimal example
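import jax.numpy as jnp
from jax import grad

def loss(values):
    # Sum of squared distances from 1.0.
    return jnp.sum((values - 1.0) ** 2)

values = jnp.array([0.0, 1.5, 2.0])
# grad(loss) returns a new function that evaluates d(loss)/d(values).
# For this loss the gradient is 2 * (values - 1.0), i.e. [-2, 1, 2].
gradient = grad(loss)(values)
print(gradient)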
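As a possible next step, the sketch below reuses the same loss and adds just-in-time compilation with `jax.jit`, using `jax.value_and_grad` to get the loss and its gradient in one call. The helper name `step`, the learning rate of 0.1, and the 50 iterations are illustrative assumptions, not part of the original example.

```python
import jax
import jax.numpy as jnp


def loss(values):
    # Same loss as the minimal example: squared distance from 1.0.
    return jnp.sum((values - 1.0) ** 2)


@jax.jit
def step(values, learning_rate):
    # Hypothetical helper: one plain gradient-descent update.
    current_loss, gradient = jax.value_and_grad(loss)(values)
    return values - learning_rate * gradient, current_loss


values = jnp.array([0.0, 1.5, 2.0])
for _ in range(50):
    # learning_rate=0.1 is an assumed value chosen for illustration.
    values, current_loss = step(values, 0.1)

print(values)        # every entry should end up close to 1.0
print(current_loss)  # loss measured just before the final update
```

Each update shrinks the distance to 1.0 by a constant factor, so the values converge toward the minimizer of the loss. Whether compiling a function this small pays off depends on the workload; the point here is only the shape of the code.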
Working habit
Start from plain array code, then add JAX features only when they solve a concrete problem. This keeps debugging easier.
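One way this habit might look in practice, sketched under assumptions: write and check the function with NumPy first, then switch to `jax.numpy` once a gradient is actually needed. The function names `mean_squared_error_np` and `mean_squared_error_jax` and the sample data are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def mean_squared_error_np(weights, inputs, targets):
    # Plain NumPy version: easy to inspect and debug.
    predictions = inputs @ weights
    return np.mean((predictions - targets) ** 2)


def mean_squared_error_jax(weights, inputs, targets):
    # Same computation with jax.numpy so it can be differentiated.
    predictions = inputs @ weights
    return jnp.mean((predictions - targets) ** 2)


# Illustrative data, not from the original text.
inputs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = np.array([1.0, 2.0, 3.0])
weights = np.array([0.1, 0.2])

# Check that the JAX version agrees with the NumPy reference.
reference = mean_squared_error_np(weights, inputs, targets)
jax_args = (jnp.asarray(weights), jnp.asarray(inputs), jnp.asarray(targets))
jax_value = mean_squared_error_jax(*jax_args)
np.testing.assert_allclose(jax_value, reference, rtol=1e-5)

# Only now add the JAX feature: the gradient with respect to weights.
gradient = jax.grad(mean_squared_error_jax)(*jax_args)
print(gradient)
```

Keeping the NumPy version around as a reference gives a simple check to rerun whenever another JAX transformation is added.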