Backend 2
JAX
JAX는 NumPy와 비슷한 array 문법에 automatic differentiation과 compilation 도구를 더합니다.
JAX를 쓸 때
설치 안내, API detail, 예제를 확인해야 할 때는 JAX 공식 문서를 사용합니다.
공식 문서같은 array workflow에서 gradient, just-in-time compilation, accelerator 기반 계산이 필요할 때 JAX를 사용합니다.
Optimization, differentiable simulation, gradient가 방법론의 일부인 연구 코드에 유용합니다.
최소 예제
import jax.numpy as jnp
from jax import grad
def loss(values):
return jnp.sum((values - 1.0) ** 2)
values = jnp.array([0.0, 1.5, 2.0])
gradient = grad(loss)(values)
print(gradient)
작업 습관
먼저 평범한 array 코드로 시작하고, 실제로 필요한 지점에서만 JAX 기능을 추가합니다. 이렇게 해야 디버깅이 쉬워집니다.