Snippet: JAX / Flax Best Practices
Domain Context
Framework-specific rules for JAX-based ML development (Flax, Optax, Orbax). JAX's transformations assume functionally pure code, so embrace the paradigm instead of fighting it.
Core JAX Principles
- Pure functions: JAX transformations (`jit`, `grad`, `vmap`) require functions with no side effects; under `jit`, Python side effects such as `print` or mutating a global run only while tracing
- Explicit state: model parameters are explicit pytrees, not hidden in objects
- PRNG management: always split keys explicitly; never reuse a PRNG key
- Prefer `jnp` over `np` for any computation that should be JIT-compilable
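The first and third principles can be illustrated with core JAX alone. This minimal sketch uses a hypothetical function, `scaled_noise`, that is pure: its output depends only on its arguments. Randomness comes from keys produced by an explicit `jax.random.split`.

```python
import jax
import jax.numpy as jnp

# Explicit PRNG handling: derive independent subkeys instead of reusing one key.
root_key = jax.random.PRNGKey(0)
init_key, dropout_key = jax.random.split(root_key)

# A pure function: no globals, no mutation, output depends only on the inputs.
@jax.jit
def scaled_noise(key, scale):
    return scale * jax.random.normal(key, (3,))

a = scaled_noise(init_key, 2.0)
b = scaled_noise(init_key, 2.0)     # same key -> exactly the same draw
c = scaled_noise(dropout_key, 2.0)  # independent subkey -> a different draw
```

Calling the function again with `init_key` reproduces the draw exactly. That is why a key must be split, not reused, whenever independent samples are wanted.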
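Model Definition (Flax)
- Use `nn.Module` (Flax Linen) with `@nn.compact` for inline submodule definition
- Init params explicitly: `params = model.init(rng_key, dummy_input)["params"]` (`init` returns a variables dict whose trainable weights sit under the `"params"` key)
- Separate model definition from training state: params, optimizer state, and model are distinct objects
- Use `TrainState` from `flax.training.train_state` for clean state management; it bundles `apply_fn`, `params`, the Optax transform `tx`, its `opt_state`, and a `step` counter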
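JIT Compilation
- `@jax.jit` all training and evaluation functions; without it, JAX dispatches each operation separately, which is typically much slower
- Pass arguments whose buffers are no longer needed after the call (e.g., the old params or train state that the step replaces) via `donate_argnums`, so XLA can reuse their memory for the outputs
- Use `static_argnums` / `static_argnames` for hashable Python values such as config flags, not for array data; every new static value triggers a recompilation
- Measure JIT compilation time: the first call traces and compiles, and later calls with the same shapes and dtypes should be fast, so profile both (call `.block_until_ready()` when timing, because JAX dispatches asynchronously)
- Avoid Python control flow on traced values inside JIT; use `jax.lax.cond` and `jax.lax.scan` (or `jax.lax.fori_loop`) instead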
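The JIT rules can be sketched as three small functions. The names `layer`, `safe_log`, and `sgd_update`, along with their shapes, are hypothetical. The point is where `static_argnames`, `jax.lax.cond`, and `donate_argnums` go.

```python
import functools
import jax
import jax.numpy as jnp

# `use_relu` is a Python bool, so it is static: each distinct value compiles once.
@functools.partial(jax.jit, static_argnames="use_relu")
def layer(w, x, use_relu):
    y = x @ w
    if use_relu:  # fine: branches on a static Python value at trace time
        y = jax.nn.relu(y)
    return y

# Branching on a *traced* value needs lax.cond instead of a Python `if`.
@jax.jit
def safe_log(x):
    return jax.lax.cond(x > 0, jnp.log, jnp.zeros_like, x)

# donate_argnums=0 tells XLA it may reuse the buffer of `params` for the result.
@functools.partial(jax.jit, donate_argnums=0)
def sgd_update(params, grads, lr=0.1):
    return params - lr * grads
```

Donation is only a hint. On backends that cannot reuse the buffer, JAX emits a warning and allocates a new one. Either way, treat a donated input as invalid after the call.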
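Training Patterns
- Use `optax` for optimizers: chain transformations (e.g., `optax.chain(optax.clip_by_global_norm(1.0), optax.adam(lr))`)
- Gradient computation: `jax.grad` returns a function; build it once (e.g., inside your jitted train step) and apply it, rather than re-wrapping the loss on every loop iteration
- `jax.value_and_grad` when you need both the loss value and the gradients (one forward pass)
- `jax.vmap` for batched operations: write per-example code and let `vmap` add the batch dimension, which is cleaner than threading batch axes by hand and far faster than a Python loop over examples
- Scan for sequential operations: `jax.lax.scan` instead of Python for-loops, which JIT unrolls at trace time and which therefore inflate compile time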
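Multi-Device Training
- `jax.pmap` for data parallelism across devices (legacy but stable)
- `jax.experimental.shard_map` or named sharding (`jax.sharding.NamedSharding`) for modern multi-device training (preferred for new code)
- Replicate params for `pmap`: `jax.device_put_replicated(params, jax.local_devices())`, since `pmap` maps over the local devices of each host
- All-reduce gradients inside the pmapped function: `jax.lax.pmean(grads, axis_name='batch')`, where `axis_name` must match the name passed to `jax.pmap`
- Check device count: `jax.device_count()` (global) and `jax.local_device_count()` (this host); log both at startup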
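Here is a minimal `pmap` data-parallel sketch that runs on any number of local devices, including a single CPU. The loss, the shapes, and the `grad_step` name are hypothetical. `jax.lax.pmean` reduces over the `"batch"` axis name that is passed to `jax.pmap`.

```python
import functools
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
print(f"devices: global={jax.device_count()} local={n_dev}")

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def grad_step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # All-reduce: average the per-device gradients over the "batch" axis.
    return jax.lax.pmean(grads, axis_name="batch")

w = jnp.ones((3,))
w_repl = jax.device_put_replicated(w, jax.local_devices())  # shape (n_dev, 3)

x = jnp.ones((n_dev * 4, 3))  # hypothetical global batch
y = jnp.zeros((n_dev * 4,))
# Split the leading batch axis into (devices, per-device batch).
x_sh = x.reshape(n_dev, -1, 3)
y_sh = y.reshape(n_dev, -1)

grads = grad_step(w_repl, x_sh, y_sh)  # identical averaged grads on every device
```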
Checkpointing
- Use Orbax for checkpointing: supports async saving, sharded checkpoints, and atomic operations
- Save complete state: params, optimizer state, step count, config
- Restore with shape checking: validate parameter shapes match the current model definition
- For large models: use sharded checkpointing to avoid OOM during save/load
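Orbax performs the actual save and restore, but the shape check itself can be done with plain `jax.tree_util`. The sketch below assumes the restored checkpoint is a pytree of arrays. It compares that pytree against `jax.eval_shape` of the model's init function, so no real parameters are allocated for the reference. `check_shapes` and `init_params` are hypothetical helpers, not Orbax APIs.

```python
import jax
import jax.numpy as jnp

def check_shapes(restored, reference):
    """Hypothetical helper: raise ValueError if `restored` does not match `reference`.

    `reference` may hold real arrays or jax.ShapeDtypeStruct leaves (e.g. from
    jax.eval_shape), since only .shape and .dtype are inspected.
    """
    if jax.tree_util.tree_structure(restored) != jax.tree_util.tree_structure(reference):
        raise ValueError("checkpoint tree structure differs from the model definition")
    mismatches = []
    restored_leaves = jax.tree_util.tree_leaves_with_path(restored)
    for (path, got), want in zip(restored_leaves, jax.tree_util.tree_leaves(reference)):
        if got.shape != want.shape or got.dtype != want.dtype:
            mismatches.append(
                f"{jax.tree_util.keystr(path)}: got {got.shape}/{got.dtype}, "
                f"expected {want.shape}/{want.dtype}"
            )
    if mismatches:
        raise ValueError("checkpoint mismatch: " + "; ".join(mismatches))

def init_params(key):
    # Hypothetical stand-in for a model's parameter initializer.
    return {"dense": {"kernel": jax.random.normal(key, (4, 2)), "bias": jnp.zeros(2)}}

# Abstract shapes of the *current* model definition; nothing is computed.
expected = jax.eval_shape(init_params, jax.random.PRNGKey(0))
```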
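Debugging
- Disable JIT during debugging: `with jax.disable_jit():` or set the environment variable `JAX_DISABLE_JIT=1`
- NaN checking: `jax.config.update("jax_debug_nans", True)` during development (it adds runtime overhead, so turn it off for production runs)
- Shape errors: use `jax.eval_shape` to compute output shapes and dtypes without running the computation
- Memory profiling: `jax.profiler.device_memory_profile()` (or `jax.profiler.save_device_memory_profile(path)`) captures a pprof-format snapshot for TPU/GPU memory analysis
- Print inside JIT: use `jax.debug.print()`, not `print()`; a regular `print` runs only at trace time and shows abstract tracers, not values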
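This small sketch covers the print and shape-inspection tips; `normalize` is a hypothetical function. The plain `print` fires only while the function is being traced, whereas `jax.debug.print` fires on every execution.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Plain print runs once, during tracing, and shows an abstract tracer.
    print("tracing normalize with", x)
    # jax.debug.print is staged into the compiled program and runs on every call.
    jax.debug.print("mean = {m}", m=jnp.mean(x))
    return (x - jnp.mean(x)) / (jnp.std(x) + 1e-6)

out = normalize(jnp.arange(4.0))

# With JIT disabled, the function runs eagerly and plain print shows concrete values.
with jax.disable_jit():
    normalize(jnp.arange(4.0))

# eval_shape traces abstractly: it reports output shape and dtype without any FLOPs.
spec = jax.eval_shape(normalize, jax.ShapeDtypeStruct((128, 64), jnp.float32))
```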
Common Pitfalls
- Mutating arrays: JAX arrays are immutable, so use `x = x.at[i].set(v)`, not `x[i] = v`
- PRNG key reuse: the same key always yields the same random numbers, so reusing it produces identical or correlated samples; always split before each use
- JIT recompilation: changing array shapes causes costly recompilation — pad to fixed shapes
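- Forgetting `jax.tree_util` operations: params are nested dicts, so use `jax.tree_util.tree_map` instead of manual traversal
- NumPy/JAX mixing: `np.array` and `jnp.array` behave differently (NumPy arrays are mutable, stay on the host, and default to float64; JAX arrays are immutable, live on the default device, and default to float32 unless `jax_enable_x64` is set), so be explicit about which one you use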
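The first, fourth, and fifth pitfalls are shown in code below. The `params` dict and the 0.1 step size are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.zeros(4)
# x[1] = 5.0 would raise a TypeError: JAX arrays are immutable.
y = x.at[1].set(5.0)  # returns a new array; x is unchanged
z = y.at[1].add(1.0)

params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones(2)}}
grads = jax.tree_util.tree_map(jnp.ones_like, params)
# Apply one function across every leaf instead of walking the nested dict by hand.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# NumPy defaults to float64; JAX defaults to float32 unless jax_enable_x64 is set.
print(np.array([1.0]).dtype, jnp.array([1.0]).dtype)
```

Under `jit`, the `.at[...]` form is usually compiled into an in-place update when the old array is no longer used. It therefore does not cost a full copy.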