Backend 2

JAX

JAX keeps a NumPy-like style while adding automatic differentiation and compilation tools.

Official documentation

Use the official JAX documentation when you need installation notes, API details, or examples.

When to use JAX

Use JAX when the same array workflow needs gradients, just-in-time compilation, or accelerator-backed computation.

It is useful for optimization, differentiable simulations, and research code where gradients are part of the method.
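As a rough illustration of gradients and just-in-time compilation working on the same array code, the sketch below fits a tiny linear model by gradient descent. The function name mean_squared_error, the toy data, and the step size are assumptions made up for this example; only jax.grad, jax.jit, and jax.numpy come from JAX itself.

```python
import jax
import jax.numpy as jnp

def mean_squared_error(weights, inputs, targets):
    # Hypothetical linear model, used only for illustration.
    predictions = inputs @ weights
    return jnp.mean((predictions - targets) ** 2)

# Compose transformations: differentiate with respect to the first
# argument (weights), then compile the resulting gradient function.
fast_gradient = jax.jit(jax.grad(mean_squared_error))

# Toy data chosen so that an exact linear fit exists.
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = jnp.array([1.0, 2.0, 3.0])
weights = jnp.zeros(2)

# Plain gradient descent; the step size is an arbitrary small value.
step_size = 0.01
initial_loss = mean_squared_error(weights, inputs, targets)
for _ in range(100):
    weights = weights - step_size * fast_gradient(weights, inputs, targets)
final_loss = mean_squared_error(weights, inputs, targets)

print(initial_loss, final_loss)
```

The loss function itself is ordinary array code. The gradient and the compilation are added by wrapping it, which is why the same workflow can start without either.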

Minimal example

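import jax.numpy as jnp
from jax import grad

def loss(values):
    # Sum of squared distances from 1.0
    return jnp.sum((values - 1.0) ** 2)

values = jnp.array([0.0, 1.5, 2.0])
# grad(loss) returns a new function that computes d(loss)/d(values)
gradient = grad(loss)(values)

print(gradient)  # equals 2 * (values - 1.0)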
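One way to sanity-check a gradient like this is to compare it with a hand-derived formula. The sketch below repeats the same loss and uses jax.value_and_grad to get the loss value and the gradient in one call. The variable name expected is introduced here for illustration.

```python
import jax
import jax.numpy as jnp

def loss(values):
    return jnp.sum((values - 1.0) ** 2)

values = jnp.array([0.0, 1.5, 2.0])

# value_and_grad returns the loss and its gradient from a single call.
loss_value, gradient = jax.value_and_grad(loss)(values)

# Hand-derived gradient of sum((v - 1)^2) is 2 * (v - 1).
expected = 2.0 * (values - 1.0)

print(loss_value)
print(jnp.allclose(gradient, expected))
```

A check of this kind is cheap on small inputs and can catch a loss function that does not compute what was intended.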

Working habit

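Start from plain array code, then add JAX features such as grad or jit only when they solve a concrete problem. This keeps debugging easier: plain array code can be inspected with ordinary print statements and debuggers, while code inside jit is traced and compiled before it runs.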
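A minimal sketch of that habit, assuming a small normalization function as the starting point: write it in NumPy, port it to jax.numpy, confirm the two agree, and only then wrap it in jit. The function names and the sample data are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

def normalize_np(values):
    # Step 1: plain NumPy version, easy to inspect and debug.
    return (values - values.mean()) / values.std()

def normalize_jnp(values):
    # Step 2: the same computation written against jax.numpy.
    return (values - jnp.mean(values)) / jnp.std(values)

data = [1.0, 2.0, 4.0, 7.0]
reference = normalize_np(np.array(data))
ported = normalize_jnp(jnp.array(data))

# JAX uses 32-bit floats by default, so compare with a tolerance.
np.testing.assert_allclose(np.asarray(ported), reference, rtol=1e-5, atol=1e-6)

# Step 3: add compilation only once the ported version is trusted.
normalize_fast = jax.jit(normalize_jnp)
compiled = normalize_fast(jnp.array(data))
np.testing.assert_allclose(np.asarray(compiled), reference, rtol=1e-5, atol=1e-6)
```

Keeping the NumPy version around as a reference gives each later step something concrete to be compared against.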