We’re examining how JAX wires together its core transformations – gradients, JIT compilation, vectorization, and device movement – through a single public entry point: jax/_src/api.py. This file is the facade that turns a deep stack of interpreters and XLA backends into the familiar jax.jit, jax.grad, jax.vmap, and jax.device_put functions.
JAX itself is a numerical computing library that lets you take pure Python+NumPy style code and transform it: compile it, differentiate it, batch it, shard it. I’m Mahmoud Zalt, an AI solutions architect, and we’ll treat api.py as our lab specimen to understand one central idea: how a single, carefully designed “transformations engine” – built around pytrees and a shared IR (jaxpr) – lets you freely compose transformations while hiding enormous complexity.
We’ll map where api.py sits in the architecture, see how core transformations are layered on top of each other, look at vectorization as axis bookkeeping, study how data movement is exposed without surprises, and close with what this design implies for performance, observability, and your own APIs.
Where api.py Fits
api.py is the public door into JAX’s transformation system. Everything users call as jax.jit, jax.grad, jax.vmap, jax.device_put, jax.eval_shape, and similar flows through this file, then fans out into lower-level interpreters and backends.
jax/ (project root)
|_ _src/
   |_ core.py              # JAX IR (jaxpr), avals, primitives
   |_ interpreters/
   |  |_ ad.py             # Autodiff interpreter
   |  |_ batching.py       # vmap/batching interpreter
   |  |_ partial_eval.py
   |  |_ pxla.py           # Pjit / multi-device
   |_ dispatch.py          # Compilation & execution interface
   |_ tree_util.py         # PyTree definitions & helpers
   |_ api.py               # <== Public transformations & device APIs
   |_ sharding_impls.py
   |_ xla_bridge.py
api.py as facade: user calls enter here, then dispatch to interpreters and backends.
Architecturally, api.py is a Facade: it exposes friendly, documented functions and delegates to internal components like autodiff (ad), batching (batching), partial evaluation (pe), and compilation/dispatch (dispatch, pjit, xb, xc). It also owns user ergonomics: pytrees, configuration-driven behavior, and most error messages.
Once we see api.py as a facade over a shared transformations engine, it becomes clear why composability, error messages, and performance all have to be coordinated here, around one intermediate representation: jaxpr.
How Transformations Stack
The interesting part of api.py isn’t that jit or grad exist; it’s how they are implemented as small, predictable layers on top of jaxpr, and how they deliberately reuse one another. That’s what makes compositions like jit(vmap(grad(f))) behave sensibly.
jit: A Thin Public Shell
JIT compilation is surfaced as jax.jit, but the wrapper in api.py is intentionally thin. It normalizes options (static args, sharding, device/backend selection, donation) and hands everything to pjit.make_jit:
def jit(
    fun: Callable | NotSpecified = NotSpecified(), /, *,
    in_shardings: Any = sharding_impls.UNSPECIFIED,
    out_shardings: Any = sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    static_argnames: str | Iterable[str] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
  kwds = dict(
      in_shardings=in_shardings, out_shardings=out_shardings,
      static_argnums=static_argnums, static_argnames=static_argnames,
      donate_argnums=donate_argnums, donate_argnames=donate_argnames,
      keep_unused=keep_unused, device=device, backend=backend, inline=inline,
      compiler_options=compiler_options, use_resource_env=False)
  if isinstance(fun, NotSpecified):
    return lambda fun: pjit.make_jit(fun, **kwds)
  else:
    return pjit.make_jit(fun, **kwds)
jit focuses on signature and options; compilation logic lives in pjit and backends.
The design choice here is restraint: the public wrapper stays “dumb”. It owns user-facing semantics (decorator behavior, argument interpretation), then forwards to a focused implementation. That separation keeps the JIT contract stable even as compilation internals evolve.
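To make the "decorator or plain function" behavior concrete, here is a minimal pure-Python sketch of a thin shell with the same shape as jit. It assumes a hypothetical _make_jit that stands in for pjit.make_jit and merely records the normalized options; nothing is compiled. The only jobs of the public wrapper are normalizing arguments and choosing between "return a decorator" and "wrap now".

```python
from __future__ import annotations

from functools import wraps
from typing import Any, Callable


class NotSpecified:
    """Sentinel meaning 'called as @my_jit(...) without a function yet'."""


def _make_jit(fun: Callable, **options: Any) -> Callable:
    # Hypothetical stand-in for the real implementation (pjit.make_jit in JAX).
    # It only records the options; no tracing or compilation happens here.
    @wraps(fun)
    def wrapped(*args, **kwargs):
        return fun(*args, **kwargs)
    wrapped.options = options
    return wrapped


def my_jit(fun: Callable | NotSpecified = NotSpecified(), /, *,
           static_argnums: int | tuple[int, ...] | None = None,
           donate_argnums: int | tuple[int, ...] | None = None) -> Any:
    # Normalize user-facing options once, in one place.
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    if isinstance(donate_argnums, int):
        donate_argnums = (donate_argnums,)
    kwds = dict(static_argnums=static_argnums, donate_argnums=donate_argnums)
    if isinstance(fun, NotSpecified):
        # Used as @my_jit(static_argnums=...): return a decorator.
        return lambda f: _make_jit(f, **kwds)
    # Used as @my_jit or my_jit(f): wrap immediately.
    return _make_jit(fun, **kwds)


@my_jit
def add(x, y):
    return x + y


@my_jit(static_argnums=1)
def scale(x, k):
    return x * k
```

Because option normalization (here, turning a bare int into a tuple) happens once in the shell, whatever sits behind it only ever sees one canonical form.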
grad Built on value_and_grad Built on vjp
Differentiation is implemented once, then reused. grad delegates to value_and_grad, which in turn builds on vjp (reverse-mode autodiff). The outer layer looks like this:
@partial(api_boundary, repro_api_name="jax.grad")
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
         has_aux: bool = False, holomorphic: bool = False,
         allow_int: bool = False,
         reduce_axes: Sequence[AxisName] = ()) -> Callable:
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to grad is deprecated")
  del reduce_axes
  value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
                                    holomorphic=holomorphic,
                                    allow_int=allow_int)
  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f(*args, **kwargs):
    _, g = value_and_grad_f(*args, **kwargs)
    return g
  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f_aux(*args, **kwargs):
    (_, aux), g = value_and_grad_f(*args, **kwargs)
    return g, aux
  return grad_f_aux if has_aux else grad_f
grad is “just” packaging; value_and_grad and vjp hold the real logic.
value_and_grad validates dtypes, enforces scalar outputs, calls vjp to build a backward pass, and returns both primal values and gradients. grad wraps that to either drop or expose auxiliary outputs.
The core decision is to treat vjp as the fundamental primitive and implement higher-level conveniences as thin, consistent layers. That keeps the implementation DRY and makes behavior across grad-family APIs easier to reason about.
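The same layering can be reproduced in miniature. The sketch below is a toy scalar reverse-mode autodiff (a Var class that records its parents), not JAX's implementation, and it supports only +, * and a sin helper. vjp is the only function that knows about backpropagation; value_and_grad seeds it with a cotangent of 1.0, which is why it only makes sense for scalar outputs; grad just drops the value.

```python
from __future__ import annotations

import math
from typing import Callable


class Var:
    """A scalar value that remembers how it was computed (toy reverse-mode node)."""

    def __init__(self, value: float, parents: tuple = ()):
        self.value = value
        self.parents = parents  # pairs of (parent Var, local partial derivative)

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __rmul__ = __mul__


def sin(x: Var) -> Var:
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))


def vjp(f: Callable, *primals: float):
    """The one real primitive: run f forward, return (output, backward closure)."""
    inputs = [Var(p) for p in primals]
    out = f(*inputs)
    order, seen = [], set()  # topological order: inputs before their consumers

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(out)

    def f_vjp(cotangent: float) -> tuple:
        grads = {id(out): cotangent}
        # Reverse order visits consumers before inputs, so each grad is complete.
        for node in reversed(order):
            g = grads.get(id(node), 0.0)
            for parent, local in node.parents:
                grads[id(parent)] = grads.get(id(parent), 0.0) + g * local
        return tuple(grads.get(id(x), 0.0) for x in inputs)

    return out.value, f_vjp


def value_and_grad(f: Callable, argnums: int = 0) -> Callable:
    def vg(*args):
        value, f_vjp = vjp(f, *args)
        return value, f_vjp(1.0)[argnums]  # seeding with 1.0 assumes a scalar output
    return vg


def grad(f: Callable, argnums: int = 0) -> Callable:
    vg = value_and_grad(f, argnums)

    def g(*args):
        _, dx = vg(*args)  # packaging only: drop the value, keep the gradient
        return dx
    return g
```

Any fix to the backward pass lands in vjp and is inherited by both outer layers, which is the consistency argument made above.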
vjp, linearize, and the Shared IR
vjp and linearize show the “transformations engine” idea most directly. Both:
- Flatten pytrees into simple lists of leaves.
- Call ad.linearize, which traces the Python function into a jaxpr (a compact IR for the computation) plus residuals.
- Return closures you can reuse multiple times without re-tracing.
Conceptually, they turn a function into an explicit “forward tape + backward player” form over jaxpr. Other transformations don’t need to know how vjp works; they just consume or produce jaxpr. That’s the heart of the engine: pick one IR and make every transformation operate by reading and rewriting that IR.
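A toy version of that idea shows why one IR makes transformations composable. Here a Tracer records + and * into a list of equations (a drastically simplified, hypothetical stand-in for a jaxpr), and two independent interpreters consume the same list: one evaluates it, the other computes forward-mode derivatives. None of these names are JAX APIs.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Eqn:
    prim: str      # primitive name: "add" or "mul"
    inputs: tuple  # variable names (str) or Python constants
    out: str       # name of the variable this equation defines


@dataclass
class ToyJaxpr:
    invars: list
    eqns: list = field(default_factory=list)
    outvar: str = ""


class Tracer:
    """Stands in for a value during tracing; records operations instead of computing."""

    def __init__(self, name: str, jaxpr: ToyJaxpr):
        self.name, self.jaxpr = name, jaxpr

    def _emit(self, prim: str, other) -> Tracer:
        rhs = other.name if isinstance(other, Tracer) else other
        out = f"v{len(self.jaxpr.invars) + len(self.jaxpr.eqns)}"
        self.jaxpr.eqns.append(Eqn(prim, (self.name, rhs), out))
        return Tracer(out, self.jaxpr)

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)


def make_toy_jaxpr(f, n_args: int) -> ToyJaxpr:
    jaxpr = ToyJaxpr(invars=[f"a{i}" for i in range(n_args)])
    out = f(*[Tracer(v, jaxpr) for v in jaxpr.invars])
    jaxpr.outvar = out.name
    return jaxpr


PRIMS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}


def eval_toy_jaxpr(jaxpr: ToyJaxpr, *args):
    """Interpreter 1: plain evaluation."""
    env = dict(zip(jaxpr.invars, args))
    read = lambda v: env[v] if isinstance(v, str) else v
    for e in jaxpr.eqns:
        env[e.out] = PRIMS[e.prim](*map(read, e.inputs))
    return env[jaxpr.outvar]


def jvp_toy_jaxpr(jaxpr: ToyJaxpr, primals, tangents):
    """Interpreter 2: forward-mode derivatives over the same equations."""
    env = dict(zip(jaxpr.invars, primals))
    tan = dict(zip(jaxpr.invars, tangents))
    for e in jaxpr.eqns:
        a, b = (env[v] if isinstance(v, str) else v for v in e.inputs)
        ta, tb = (tan[v] if isinstance(v, str) else 0.0 for v in e.inputs)
        env[e.out] = PRIMS[e.prim](a, b)
        tan[e.out] = ta + tb if e.prim == "add" else ta * b + a * tb
    return env[jaxpr.outvar], tan[jaxpr.outvar]
```

Tracing happens once in make_toy_jaxpr; both interpreters then reuse the recorded equations with new inputs, which is the same reason vjp and linearize can hand back closures that run without re-tracing.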
The primary lesson from this section: centralize your transformations on a single, explicit intermediate representation. JAX uses jaxpr; once that is in place, jit, grad, vmap, and friends become composable layers instead of independent features.
Vectorization as Axis Bookkeeping
With gradients and JIT defined over jaxpr, vectorization (vmap) is where the abstraction is stress‑tested. vmap has to understand pytrees, axis semantics, and even distributed meshes, but present itself as “just batching” to users.
The vmap Contract
Semantically, vmap takes a function f and returns a new function that applies f in parallel across a batch axis. Implementing that over nested containers and sharded devices requires careful normalization of axis specs and shapes before delegating to the batching interpreter.
@partial(api_boundary, repro_api_name="jax.vmap")
def vmap(fun: F,
         in_axes: int | None | Sequence[Any] = 0,
         out_axes: Any = 0,
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
         spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
         sum_match: bool = False
         ) -> F:
  if isinstance(in_axes, list):
    in_axes = tuple(in_axes)
  from jax._src import hijax
  if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}
          or isinstance(in_axes, hijax.MappingSpec)):
    raise TypeError("vmap in_axes must be an int, None, or a tuple ...")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(in_axes)):
    raise TypeError("vmap in_axes must be an int, None, or (nested) container ...")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(out_axes)):
    raise TypeError("vmap out_axes must be an int, None, or (nested) container ...")
  @wraps(fun, docstr=docstr)
  @api_boundary
  def vmap_f(*args, **kwargs):
    nonlocal spmd_axis_name
    if isinstance(in_axes, tuple) and len(in_axes) != len(args):
      raise ValueError("vmap in_axes must be an int, None, or a tuple ...")
    args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
    dbg = debug_info("vmap", fun, args, kwargs)
    api_util.check_no_transformed_refs_args(lambda: dbg, args_flat)
    f = lu.wrap_init(fun, debug_info=dbg)
    flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
    if config.mutable_array_checks.value:
      avals = [None if d is None or batching.is_vmappable(x) else core.typeof(x)
               for x, d in zip(args_flat, in_axes_flat)]
      api_util.check_no_aliased_ref_args(lambda: dbg, avals, args_flat)
    axis_size_ = _mapped_axis_size(
        fun, in_tree, args_flat, in_axes_flat, "vmap", axis_size=axis_size)
    explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
    _check_ema_unmapped_args(explicit_mesh_axis, args_flat, in_axes_flat)
    axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
                                  explicit_mesh_axis)
    out_flat, inferred_out_axes = batching.batch(
        flat_fun, axis_data, in_axes_flat,
        lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
        sum_match=sum_match
    ).call_wrapped(*args_flat)
    return tree_unflatten(out_tree(), out_flat)
vmap normalizes axis specs, flattens pytrees, infers axis size, then delegates to the batching interpreter.
The responsibilities here are tightly interleaved:
- Canonicalizing in_axes/out_axes across simple values and nested containers.
- Flattening the argument tree so batching operates over lists of leaves.
- Inferring the batch axis size and producing clear errors when shapes disagree.
- Reconciling vectorization with sharding meshes via spmd_axis_name and explicit mesh axes.
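A stripped-down model of that contract helps separate the bookkeeping from the batching itself. The sketch below is hypothetical and far simpler than jax.vmap: the batch axis is the outer Python list, in_axes entries may only be 0 or None, and there is no pytree flattening, out_axes, or sharding. It still walks through the normalization steps listed above: lists become tuples, a single spec is broadcast to every argument, and the batch size is inferred from the mapped arguments.

```python
from __future__ import annotations

from typing import Any, Callable, Sequence


def toy_vmap(f: Callable, in_axes: int | None | Sequence[int | None] = 0) -> Callable:
    """Map f over the outer list of every argument whose in_axes entry is 0.

    Arguments with in_axes=None are passed unchanged to every call (broadcast).
    """
    if isinstance(in_axes, list):  # mirror jax.vmap: a list spec becomes a tuple
        in_axes = tuple(in_axes)

    def vmapped(*args: Any) -> list:
        # A single int/None applies to every positional argument.
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError(f"in_axes has {len(axes)} entries for {len(args)} args")
        if any(a not in (0, None) for a in axes):
            raise ValueError("this sketch only supports in_axes entries of 0 or None")
        sizes = {len(x) for x, a in zip(args, axes) if a == 0}
        if len(sizes) != 1:
            raise ValueError(f"inconsistent or missing batch sizes: {sorted(sizes)}")
        (n,) = sizes
        return [f(*(x[i] if a == 0 else x for x, a in zip(args, axes)))
                for i in range(n)]

    return vmapped
```

For example, toy_vmap(lambda x, y: x * y, in_axes=(0, None))([1, 2, 3], 10) multiplies each element of the mapped list by the broadcast scalar.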
Helpful Errors via _mapped_axis_size
When batch dimensions don’t line up, vmap doesn’t just say “sizes don’t match”. _mapped_axis_size walks every argument, determines the implied axis size, and then produces a narrative error that points to specific arguments, shapes, and sizes.
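Here is a sketch of that diagnostic style, assuming each argument is described by a name, a shape tuple, and a mapped axis (None for unmapped). It is not JAX's _mapped_axis_size and its message wording is invented; it only illustrates grouping arguments by the size they imply, so the error names every culprit instead of merely reporting a mismatch.

```python
from __future__ import annotations

from collections import defaultdict


def mapped_axis_size(names, shapes, in_axes, fname="vmap"):
    """Return the common mapped-axis size, or raise a ValueError that explains why not.

    names, shapes and in_axes are parallel lists; an axis of None means 'not mapped'.
    """
    by_size = defaultdict(list)  # implied size -> descriptions of the arguments
    for name, shape, axis in zip(names, shapes, in_axes):
        if axis is None:
            continue
        if not -len(shape) <= axis < len(shape):
            raise ValueError(f"{fname} got in_axes={axis} for argument {name!r} "
                             f"of rank {len(shape)} (shape {shape})")
        by_size[shape[axis]].append(f"{name!r} of shape {shape} (axis {axis})")
    if not by_size:
        raise ValueError(f"{fname} needs at least one mapped argument")
    if len(by_size) == 1:
        return next(iter(by_size))
    lines = [f"{fname} got inconsistent sizes for the mapped axis:"]
    for size, described in sorted(by_size.items()):
        lines.append(f"  size {size}: " + ", ".join(described))
    raise ValueError("\n".join(lines))
```

Keeping this in a helper means the main transformation only ever calls one function and receives either a size or a finished, user-facing explanation.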
This reflects a broader pattern in api.py: error handling is treated as part of the design, not an afterthought. The logic that explains “what went wrong” is pulled into helpers so the core algorithm can stay focused on transformation semantics.
If you build your own transformation APIs, investing in helpers like _mapped_axis_size pays off. They keep the main path readable and give users precise feedback in the failure cases that matter most.
Data Movement Without Surprises
Transformations are only half the story; at scale, host↔device data movement is often the real bottleneck. api.py owns the public device_put, device_put_sharded, device_put_replicated, and device_get APIs, balancing ergonomics with precise control over sharding and copy semantics.
device_put: Sharding, Donation, Aliasing
You can think of device_put as a shipping dock: values arrive as pytrees on the host; the function flattens them, assigns shardings or devices, decides whether buffers may be reused or must be copied, and hands everything to the dispatcher.
def device_put(
    x,
    device: None | xc.Device | Sharding | P | Format | Any = None,
    *, src: None | xc.Device | Sharding | P | Format | Any = None,
    donate: bool | Any = False, may_alias: bool | None | Any = None):
  with config.explicit_device_put_scope():
    x_flat, treedef = tree_flatten(x)
    x_avals = [shaped_abstractify(x) for x in x_flat]
    ...
    device_flat = map(partial(pspec_to_sharding, 'device_put'), device_flat)
    src_flat = map(partial(pspec_to_sharding, 'device_put'), src_flat)
    if isinstance(donate, bool):
      donate_flat = [donate] * len(x_flat)
    else:
      donate_flat = flatten_axes("device_put donate", treedef, donate)
    if isinstance(may_alias, bool):
      may_alias_flat = [may_alias] * len(x_flat)
    else:
      may_alias_flat = flatten_axes("device_put may_alias", treedef, may_alias)
    copy_semantics = []
    for m, d in zip(may_alias_flat, donate_flat):
      if m and d:
        raise ValueError('may_alias and donate cannot be True at the same time.')
      if m is None:
        m = not d
      if m and not d:
        copy_semantics.append(dispatch.ArrayCopySemantics.REUSE_INPUT)
      elif not m and d:
        copy_semantics.append(dispatch.ArrayCopySemantics.DONATE_INPUT)
      else:
        copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY)
    dst_avals = []
    for x_aval, d in zip(x_avals, device_flat):
      aval = dispatch.update_dp_aval(x_aval, d)
      dst_avals.append(aval)
      _check_sharding(aval, d)
    if core.trace_state_clean():
      out_flat = dispatch._batched_device_put_impl(
          *x_flat, devices=device_flat, srcs=src_flat,
          copy_semantics=copy_semantics, dst_avals=dst_avals)
    else:
      out_flat = dispatch.device_put_p.bind(
          *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat),
          copy_semantics=tuple(copy_semantics))
    return tree_unflatten(treedef, out_flat)
device_put flattens pytrees, infers shardings, enforces copy semantics, and then delegates to dispatch.
Subtleties handled here so users don’t have to think about them on every call include:
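- Pytree matching: device, src, donate, and may_alias can mirror the structure of x. A bare bool for donate or may_alias applies to every leaf; other settings are expanded to one entry per leaf by flatten_axes, which keeps their structure consistent with x.
- Donation vs. aliasing: donation means "you may destroy this input buffer"; aliasing means "the output may share the input's buffer instead of copying it". Setting both to True for the same leaf raises a ValueError, and leaving may_alias as None makes it default to the opposite of donate. Each leaf's combination is encoded as an explicit ArrayCopySemantics value: REUSE_INPUT, DONATE_INPUT, or ALWAYS_COPY.
- Sharding validation: _check_sharding ensures that shardings are compatible with value shapes and device types (e.g. string arrays pinned to CPU).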
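The copy-semantics branch of the device_put excerpt can be isolated into a small, testable function. This sketch mirrors those branches using a local CopySemantics enum (a stand-in for dispatch.ArrayCopySemantics) and a simplified, hypothetical broadcast_flag helper that accepts only a bare flag or a flat per-leaf list, whereas JAX's flatten_axes handles arbitrary pytree prefixes.

```python
from __future__ import annotations

import enum


class CopySemantics(enum.Enum):
    # Local stand-in for the three dispatch.ArrayCopySemantics values in the excerpt.
    REUSE_INPUT = "reuse_input"
    DONATE_INPUT = "donate_input"
    ALWAYS_COPY = "always_copy"


def broadcast_flag(flag, n_leaves: int) -> list:
    """A bare flag (True/False/None) applies to every leaf; otherwise expect one per leaf.

    Simplification: JAX's flatten_axes accepts any pytree prefix, not just a flat list.
    """
    if flag is None or isinstance(flag, bool):
        return [flag] * n_leaves
    flags = list(flag)
    if len(flags) != n_leaves:
        raise ValueError(f"expected {n_leaves} per-leaf flags, got {len(flags)}")
    return flags


def resolve_copy_semantics(may_alias_flat: list, donate_flat: list) -> list:
    out = []
    for m, d in zip(may_alias_flat, donate_flat):
        if m and d:
            raise ValueError("may_alias and donate cannot be True at the same time.")
        if m is None:
            m = not d  # default: alias unless this leaf is donated
        if m and not d:
            out.append(CopySemantics.REUSE_INPUT)
        elif not m and d:
            out.append(CopySemantics.DONATE_INPUT)
        else:
            out.append(CopySemantics.ALWAYS_COPY)
    return out
```

Laid out like this, the defaults are easy to read off: with may_alias=None a non-donated leaf is aliased and a donated leaf is donated, and a guaranteed copy only happens when you pass may_alias=False without donating.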
Sharded and Replicated Placement
device_put_sharded and device_put_replicated are higher‑level helpers built on the same ideas:
- device_put_sharded takes one shard per device, checks compatibility, builds a Mesh and NamedSharding, and uses pxla.batched_device_put underneath.
- device_put_replicated computes an "unmapped" abstract value for a replica axis, then broadcasts a single host buffer across devices with batched_device_put.
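The difference in shape bookkeeping between the two helpers can be modeled without any devices. In this toy sketch, devices are just names, a "buffer" is a Python list held in a dict, and shards are 1-D lists; ToyShardedArray and both functions are hypothetical. Both results gain a leading axis with one entry per device: in the sharded case it indexes distinct shards, in the replicated case it indexes identical copies.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyShardedArray:
    shape: tuple   # global shape; the leading axis has one entry per device
    buffers: dict  # device name -> that device's local data (a plain list here)


def toy_put_sharded(shards: list, devices: list) -> ToyShardedArray:
    """One 1-D shard per device, stacked along a new leading axis."""
    if len(shards) != len(devices):
        raise ValueError(f"got {len(shards)} shards for {len(devices)} devices")
    lengths = {len(s) for s in shards}
    if len(lengths) != 1:
        raise ValueError(f"all shards must have the same shape; got lengths {sorted(lengths)}")
    (n,) = lengths
    return ToyShardedArray((len(devices), n),
                           {d: list(s) for d, s in zip(devices, shards)})


def toy_put_replicated(x: list, devices: list) -> ToyShardedArray:
    """One host value; each device gets its own copy along a new replica axis."""
    return ToyShardedArray((len(devices), len(x)), {d: list(x) for d in devices})
```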
Extended dtypes (custom element types) plug into these paths via specialized hooks. The report recommends centralizing those hooks in dedicated helpers so that support for new dtypes stays consistent across all data‑movement APIs.
device_get: Async Then Sync
On the way back to the host, device_get first starts host copies asynchronously on all leaves (via copy_to_host_async when available), then walks the tree again to materialize each value using __array__() or extended dtype rules.
The same pattern shows up in block_until_ready and effects_barrier: kick off asynchronous work across the tree, then provide explicit synchronization points. From the facade’s perspective, this is where you’d add tracing, logging, or metrics for all host↔device traffic.
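The async-then-sync pattern is easy to demonstrate with the standard library. In this sketch a thread pool stands in for the runtime and fake_copy_to_host is a hypothetical transfer that just sleeps; the real device_get relies on copy_to_host_async on each leaf where available. Starting every transfer before waiting on any of them lets total latency approach that of the slowest single copy rather than the sum of all copies.

```python
from __future__ import annotations

import concurrent.futures as cf
import time


def fake_copy_to_host(buf: list) -> list:
    """Hypothetical device->host transfer: pretend each copy takes 50 ms."""
    time.sleep(0.05)
    return list(buf)


def toy_device_get(leaves: list, pool: cf.Executor) -> list:
    # Phase 1 (async): start every transfer before waiting on any of them.
    futures = [pool.submit(fake_copy_to_host, leaf) for leaf in leaves]
    # Phase 2 (sync): block on each result, preserving the original leaf order.
    return [f.result() for f in futures]
```

Phase 2 is also the natural place to hang a timer or byte counter, because it is where the caller actually waits.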
Data movement is often the main IO cost in accelerator workloads. By routing all host↔device transfers through a small set of functions in api.py, JAX makes those crossings explicit and observable without leaking backend details.
Performance, Debugging, and Observability
So far we’ve focused on how transformations and data movement are exposed. The same facade also shapes how JAX behaves at scale: how much Python overhead hot paths incur, and where you can attach operational insight.
Hot Paths and Python Overheads
The primary hot entry points described in the report are:
- jit-wrapped functions for training and inference.
- grad/value_and_grad/vjp for backprop.
- vmap for batched execution.
- device_put* and device_get for host↔device transfers.
- eval_shape and make_jaxpr for meta-programming and debugging.
On the Python side, these functions mostly perform pytree flattening/unflattening, argument validation, and small allocations, with complexity proportional to the number of leaves or arguments. The heavy work – FLOPs, compilation, large allocations – is delegated to compiled interpreters and XLA.
- vmap costs roughly O(n_leaves + n_args) per call in Python to normalize axes and pytrees, then defers to the batching interpreter.
- device_put is O(n_leaves) to build abstract values, shardings, and copy semantics, plus the cost of the actual transfer.
- make_jaxpr and eval_shape are dominated by the cost of tracing the function but avoid real numeric computation.
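The O(n_leaves) Python cost comes from walking the container structure. A toy flatten/unflatten pair makes that walk explicit; it handles only dicts, lists, and tuples and uses its own treedef encoding, so it is a sketch of the idea rather than jax.tree_util. Like JAX, it orders dict entries by sorted key.

```python
from __future__ import annotations

from typing import Any


def toy_tree_flatten(tree: Any) -> tuple[list, Any]:
    """Split dicts/lists/tuples into a flat leaf list plus a treedef.

    Each node is visited once, so apart from sorting dict keys the cost grows
    linearly with the size of the tree.
    """
    leaves: list = []

    def walk(node):
        if isinstance(node, dict):
            keys = tuple(sorted(node))  # sorted keys, matching JAX's dict ordering
            return (("dict", keys), [walk(node[k]) for k in keys])
        if isinstance(node, (list, tuple)):
            return ((type(node).__name__, len(node)), [walk(c) for c in node])
        leaves.append(node)
        return None  # None is this toy's treedef for "a single leaf"

    treedef = walk(tree)
    return leaves, treedef


def toy_tree_unflatten(treedef: Any, leaves: list) -> Any:
    it = iter(leaves)

    def build(d):
        if d is None:
            return next(it)
        (tag, meta), child_defs = d
        children = [build(c) for c in child_defs]
        if tag == "dict":
            return dict(zip(meta, children))
        return tuple(children) if tag == "tuple" else children

    return build(treedef)
```

Transformations only ever operate on the flat leaf list; the treedef is what lets the result be reassembled into the caller's original structure.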
Metrics That Matter at Scale
For production workloads, the report highlights several concrete metrics that naturally attach at the api.py layer:
| Metric | What it tells you | Where it hooks |
|---|---|---|
| jax_compilation_time_seconds | Time spent tracing and compiling transformed functions. | jit, make_jaxpr, eval_shape |
| jax_traced_jaxpr_size_nodes | Approximate size of the generated IR; reveals oversized traces. | make_jaxpr (via ClosedJaxpr structure) |
| device_transfer_bytes_total | Total host↔device bytes moved. | device_put*, device_get |
| jax_array_block_until_ready_calls | Frequency of synchronization barriers. | block_until_ready, effects_barrier |
| live_arrays_count | Number of live device buffers. | live_arrays() from the backend |
Because all user-visible transformations and device APIs enter through api.py, you can instrument these metrics by wrapping the public functions – no need to modify interpreters or backends.
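Such a wrapper can be as small as a decorator. The sketch below keeps counters in module-level dicts instead of a real metrics client, and fake_transfer is a hypothetical stand-in; in practice you would apply the decorator to your own thin wrappers around functions such as jax.device_put or a jitted step.

```python
from __future__ import annotations

import functools
import time
from collections import Counter, defaultdict

CALLS: Counter = Counter()                 # metric name -> number of calls
SECONDS: defaultdict = defaultdict(float)  # metric name -> total wall-clock seconds


def instrumented(metric: str):
    """Count calls and accumulate wall-clock time for the wrapped public function."""
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:  # record even if the call raises
                CALLS[metric] += 1
                SECONDS[metric] += time.perf_counter() - start
        return wrapper
    return decorate


@instrumented("fake_transfer")
def fake_transfer(payload: bytes) -> int:
    """Hypothetical stand-in for a host-to-device transfer."""
    return len(payload)
```

One caveat when pointing this at real JAX calls: JAX dispatches work asynchronously, so timing a call this way measures dispatch overhead unless the result is blocked on.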
Debugging Hooks and Global Config
api.py also wires in runtime configuration for debugging and safety: NaN/Inf checking, disabling JIT, and clearing caches and backends. For example, when config.debug_nans or config.debug_infs is enabled, _nan_check_posthook can be attached to the JIT runtime; it inspects output buffers after execution and raises a detailed floating-point error.
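The post-hook idea can be sketched as a decorator that inspects outputs after the wrapped call returns. Here CONFIG is a hypothetical dict standing in for the debug_nans/debug_infs options, and outputs are plain Python floats; in real JAX you would enable the built-in check with jax.config.update("jax_debug_nans", True) rather than write your own.

```python
from __future__ import annotations

import functools
import math

CONFIG = {"debug_nans": False, "debug_infs": False}  # hypothetical config stand-in


def nan_check_posthook(fn):
    """Run fn, then (only if enabled) scan float outputs for NaN/Inf and raise."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if not (CONFIG["debug_nans"] or CONFIG["debug_infs"]):
            return out  # fast path: no inspection cost when checks are off
        outs = out if isinstance(out, (list, tuple)) else [out]
        for i, v in enumerate(outs):
            if not isinstance(v, float):
                continue
            if CONFIG["debug_nans"] and math.isnan(v):
                raise FloatingPointError(f"{fn.__name__}: NaN in output {i}")
            if CONFIG["debug_infs"] and math.isinf(v):
                raise FloatingPointError(f"{fn.__name__}: Inf in output {i}")
        return out
    return wrapper


@nan_check_posthook
def scaled(x: float, k: float) -> float:
    return x * k
```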
Similarly, disable_jit() toggles JIT behavior via global config while leaving primitive‑level compilation intact. That gives you an escape hatch for debugging shape or control‑flow issues without discarding the transformation engine entirely.
The overarching operational lesson is to keep the public API thin but give it enough hooks – metrics, debug flags, cache controls – so performance and correctness issues can be understood and influenced at the facade layer.
Conclusion: Building Your Own Transformations Engine
Walking through jax/_src/api.py shows more than a list of functions. It shows how a transformation‑centric design – built around pytrees, a single IR (jaxpr), and carefully layered wrappers – lets JAX expose powerful capabilities as simple, composable APIs.
Key Lessons You Can Reuse
- Center everything on one intermediate representation. In JAX, jaxpr is the small, explicit language that all major transformations produce or consume. Adopting a similar IR in your own systems prevents each feature from inventing its own ad-hoc representation and makes stacking transformations feasible.
- Keep public wrappers thin and ergonomic. Functions like jit, grad, vmap, and device_put handle signatures, pytrees, and rich error messages, then defer to focused interpreters and backends. This separation keeps user contracts clear while allowing internals to evolve.
- Design for observability from the facade. By routing compilation, transformation, and data movement through a small set of public APIs, JAX gains natural choke points for metrics, logging, and debugging controls. Thinking about these from day one makes scale and operations far less painful.
If we treat api.py as JAX’s transformations engine, the central lesson is simple: choose where complexity lives. JAX concentrates it in a shared IR and a handful of interpreters, and keeps the public surface feather‑light but precise. That pattern is broadly reusable, whether you’re building ML libraries, data platforms, or internal tooling that needs to transform user code without overwhelming your users – or your future self.
