Most of us meet JAX through a few magical functions: jit, grad, vmap, pmap. They feel like small decorators you sprinkle on top of plain Python. But in reality, they form a carefully engineered transformation machine that reshapes your functions for differentiation, vectorization, and parallel execution.
In this article, we'll walk through the core API module of JAX and see how it builds that machine. I'm Mahmoud Zalt, and we'll focus on one central idea: you can design a powerful transformation layer by consistently wrapping, flattening, and validating user functions before they ever reach your runtime.
The Scene: One File, Many Transformations
Before we zoom into individual functions, we need to understand the terrain. The file in question, jax/_src/api.py, is the main facade that backs the public symbols you import as jax.jit, jax.grad, jax.vmap, and friends. It doesn't implement autodiff rules or GPU kernels; instead, it orchestrates a stack of interpreters and backends.
jax/_src/
├── core.py (jaxpr, ShapedArray, Tracer abstractions)
├── interpreters/
│ ├── ad.py (autodiff rules and JVP/VJP machinery)
│ ├── batching.py (vmap batching rules)
│ ├── partial_eval.py (pe; linearize, jaxpr tracing)
│ └── pxla.py (pmap/sharding lowering)
├── pjit.py (jit/sharding implementation)
├── dispatch.py (device_put, runtime tokens, primitives)
├── xla_bridge.py (backend and device clients)
└── api.py (this file: user-facing jit/grad/vmap/pmap/... facade)
User code
|
v
jax.jit / jax.grad / jax.vmap / jax.pmap / ...
|
v
jax._src.api (this module)
|
+--> wraps fun with lu.wrap_init, debug_info
+--> flattens PyTrees via tree_util
+--> selects interpreter: ad / batching / pxla / pjit / dispatch
|
v
XLA backends (CPU/GPU/TPU via xla_client/xb)
jax._src.api as a facade layer between user code and the interpreter/backends stack.
So this one module is doing a lot: autodiff entrypoints, vectorization/parallelism (vmap, pmap), device movement (device_put, device_get), runtime utilities, and even NaN/Inf debug hooks. That sounds like a recipe for a ball of mud, yet the file stays surprisingly navigable.
The Pattern: Wrap, Flatten, Dispatch
Once we start looking at individual APIs, we see the same skeleton repeated with small variations. That skeleton is the real star of this file. It looks like this:
- Validate the callable and options.
- Flatten Python containers into PyTrees (nested lists/tuples/dicts with arrays at the leaves) and flatten any axis/device specs to match.
- Wrap the user function with metadata (name stack, debug info, static args) into a lu.WrappedFun.
- Pick the right interpreter (autodiff, batching, pmap, pjit, etc.).
- Post-process back to the original PyTree structure and enforce invariants.
The pay‑off of this pattern is enormous: new transformations can be added by reusing the same wrapping/flattening infrastructure, and users get consistent semantics and error messages across everything.
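To make the skeleton concrete, here is a minimal sketch of the same shape applied to a toy transformation. The name structured_transform is purely illustrative, not a JAX internal; only the tree_util calls are real JAX API.
from jax import tree_util

def structured_transform(fun):
    # 1. Validate the callable up front.
    if not callable(fun):
        raise TypeError(f"Expected a callable, got {fun!r}")

    def wrapped(*args):
        # 2. Flatten arbitrary PyTree arguments into a flat list of leaves.
        leaves, in_tree = tree_util.tree_flatten(args)

        # 3. Wrap the user function so it consumes and produces flat leaves.
        def flat_fun(*flat_args):
            out = fun(*tree_util.tree_unflatten(in_tree, flat_args))
            return tree_util.tree_flatten(out)

        # 4. "Dispatch": a real transformation would hand flat_fun to an interpreter.
        out_leaves, out_tree = flat_fun(*leaves)

        # 5. Post-process back to the original PyTree structure.
        return tree_util.tree_unflatten(out_tree, out_leaves)

    return wrapped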
Example: jit as a Thin Front-End
JIT compilation feels like a heavy operation, but the Python wrapper in api.py is intentionally thin. It normalizes the options and hands everything to pjit:
def jit(
fun: Callable | NotSpecified = NotSpecified(), /, *,
in_shardings: Any = sharding_impls.UNSPECIFIED,
out_shardings: Any = sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
static_argnames: str | Iterable[str] | None = None,
donate_argnums: int | Sequence[int] | None = None,
donate_argnames: str | Iterable[str] | None = None,
keep_unused: bool = False,
device: xc.Device | None = None,
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
...
kwds = dict(
in_shardings=in_shardings, out_shardings=out_shardings,
static_argnums=static_argnums, static_argnames=static_argnames,
donate_argnums=donate_argnums, donate_argnames=donate_argnames,
keep_unused=keep_unused, device=device, backend=backend, inline=inline,
abstracted_axes=abstracted_axes, compiler_options=compiler_options,
use_resource_env=False)
if isinstance(fun, NotSpecified):
return lambda fun: pjit.make_jit(fun, **kwds)
else:
return pjit.make_jit(fun, **kwds)
jax.jit focuses on signature and ergonomics; pjit handles the heavy lifting.
The transformation we care about isn't encoded here at all; it's encoded in pjit and eventually in compiled XLA. This wrapper's job is to define how humans talk to JIT: decorator factory semantics, static/donated args, sharding hints, and consistent boundary tracing via @api_boundary.
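In practice, that ergonomics-first design is what gives jit its familiar dual form, both a direct wrapper and a decorator factory. This sketch assumes a recent JAX where jit accepts keyword-only options without a function, as the NotSpecified default above suggests:
import jax
import jax.numpy as jnp

# Direct wrapping: returns a callable that compiles on first call.
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

fast_loss = jax.jit(loss)

# Decorator-factory form: fun is omitted, options are passed as keywords.
@jax.jit(static_argnames="mode")
def scale(x, mode="train"):
    return x * (0.5 if mode == "train" else 1.0)

print(fast_loss(jnp.ones(3), jnp.arange(3.0)))   # compiled via pjit under the hood
print(scale(jnp.arange(3.0), mode="eval"))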
Autodiff as a First-Class Facade
Nowhere is the transformation-machine idea clearer than in autodiff. Functions like grad, value_and_grad, jacfwd, jacrev, and hessian all build on the same underlying AD interpreters, but the public APIs each express a particular “view” on differentiation.
grad as a Thin View on value_and_grad
grad is often the first thing we call in JAX. It's a perfect example of how this module avoids duplicating logic by composing a more general transformation:
@partial(api_boundary, repro_api_name="jax.grad")
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = ()) -> Callable:
if reduce_axes:
raise NotImplementedError("reduce_axes argument to grad is deprecated")
del reduce_axes
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int)
@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
def grad_f(*args, **kwargs):
_, g = value_and_grad_f(*args, **kwargs)
return g
@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
def grad_f_aux(*args, **kwargs):
(_, aux), g = value_and_grad_f(*args, **kwargs)
return g, aux
return grad_f_aux if has_aux else grad_f
grad doesn't implement differentiation; it reuses value_and_grad and chooses the surface shape of the API.
The interesting work is in value_and_grad. It flattens the arguments, performs detailed dtype validation (holomorphic vs real-valued, integer handling), calls into reverse-mode AD via a helper _vjp, and then reassembles gradients, optionally with auxiliary data.
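A quick usage sketch shows how the two views relate (the loss function here is just an example):
import jax
import jax.numpy as jnp

def loss(w, x):
    err = jnp.dot(w, x) - 1.0
    return jnp.sum(err ** 2), {"residual": err}   # scalar loss plus auxiliary data

w, x = jnp.ones(3), jnp.arange(3.0)

# value_and_grad returns ((value, aux), grads) when has_aux=True ...
(value, aux), grads = jax.value_and_grad(loss, has_aux=True)(w, x)

# ... and grad simply drops the value, returning (grads, aux).
grads_only, aux_again = jax.grad(loss, has_aux=True)(w, x)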
Error Messages as Part of the API
A recurring theme across autodiff helpers is that validation errors are written as teaching moments. For example, the input dtype checks for reverse-mode (_check_input_dtype_revderiv) don't just say “wrong dtype”; they tell you what to do instead:
Reverse-mode input dtype validation snippet
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
dispatch.check_arg(x)
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
if isinstance(aval, ShapedArray):
if (dtypes.issubdtype(aval.dtype, dtypes.extended) or
dtypes.issubdtype(aval.dtype, np.integer) or
dtypes.issubdtype(aval.dtype, np.bool_)):
if not allow_int:
raise TypeError(f"{name} requires real- or complex-valued inputs ... "
"If you want to use Boolean- or integer-valued inputs, use vjp "
"or set allow_int to True.")
The pattern is always the same:
- Check invariants early (scalar outputs for grad, dtype compatibility, PyTree structure).
- Point to alternative APIs when the invariant doesn't hold (vjp, jvp, or flags like holomorphic=True, allow_int=True); see the example below.
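Here's a hedged sketch of both rules in action (the exact error texts vary by version):
import jax
import jax.numpy as jnp

f = lambda x: x ** 2                        # vector -> vector

# jax.grad(f)(jnp.arange(3.0))             # rejected: grad requires a scalar output
jac = jax.jacrev(f)(jnp.arange(3.0))        # the suggested alternative for non-scalar outputs

g = lambda n: jnp.sum(n * 2.0)
# jax.grad(g)(jnp.arange(3))               # rejected: integer inputs need an explicit opt-in
jax.grad(g, allow_int=True)(jnp.arange(3))  # gradient comes back with dtype float0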
Jacobian and Hessian: Composition over Cleverness
jacfwd and jacrev are forward- and reverse-mode Jacobian builders. Rather than inventing custom machinery, they assemble existing parts:
- Wrap the function with debug metadata.
- Partially apply over argnums.
- Use vmap over jvp or vjp on basis vectors produced by _std_basis.
- Unravel the dense Jacobian back into the PyTree block structure.
hessian goes one step further and defines itself as jacfwd(jacrev(...)). Algorithmically, that’s expensive — and the docstring is very explicit about the O(n²) memory — but architecturally, it's beautifully simple. The transformation machine stays composable.
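The std-basis trick is easy to reproduce by hand. This sketch mirrors what jacfwd does for a 1-D input, ignoring the PyTree handling and unraveling that the real implementation adds:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([2.0, 3.0])

# Push each standard basis vector through jvp, batched with vmap:
# each tangent output is one column of the Jacobian.
basis = jnp.eye(x.size)
jac_cols = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(basis)
jac = jnp.transpose(jac_cols)

assert jnp.allclose(jac, jax.jacfwd(f)(x))

# And hessian really is just composition of the two modes.
hess = jax.jacfwd(jax.jacrev(lambda x: jnp.sum(x ** 3)))(x)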
Vectorization and Parallelism Without Losing Your Mind
So far we've focused on transformations of what a function computes (differentiation, linearization). JAX also needs to transform how functions map over data: vectorization with vmap and SPMD parallelism with pmap. The same wrap, flatten, dispatch skeleton shows up again, but the interesting story here is how axis and shape validation is handled.
vmap: Axis Specs as a Contract
The core vmap implementation starts by aggressively validating in_axes and out_axes:
@partial(api_boundary, repro_api_name="jax.vmap")
def vmap(fun: F,
in_axes: int | None | Sequence[Any] = 0,
out_axes: Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None) -> F:
check_callable(fun)
...
if isinstance(in_axes, list):
in_axes = tuple(in_axes)
if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}):
raise TypeError("vmap in_axes must be an int, None, or a tuple ...")
if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(in_axes)):
raise TypeError("vmap in_axes must be an int, None, or (nested) container ...")
if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(out_axes)):
raise TypeError("vmap out_axes must be an int, None, or (nested) container ...")
vmap establishes a strict, but well‑documented, contract for axis specs.
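In user code, that contract reads naturally: in_axes mirrors the argument structure, with None marking arguments that should be broadcast rather than mapped. A small sketch:
import jax
import jax.numpy as jnp

def attend(query, keys):
    return jnp.dot(keys, query)             # (16, 4) @ (4,) -> (16,)

queries = jnp.ones((8, 4))                  # a batch of 8 queries
keys = jnp.ones((16, 4))                    # shared, unbatched

batched = jax.vmap(attend, in_axes=(0, None))
out = batched(queries, keys)                # shape (8, 16)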
Inside the actual vmap_f closure, we see the familiar routine: flatten arguments into a PyTree, wrap the function, flatten it again for vmap, and broadcast/flatten the axis specifications to match the tree. One particularly instructive helper is _mapped_axis_size, used both by vmap and pmap to infer the batch size and to craft detailed mismatch errors.
def _mapped_axis_size(fn, tree, vals, dims, name):
if not vals:
args, kwargs = tree_unflatten(tree, vals)
raise ValueError(
f"{name} wrapped function must be passed at least one argument "
f"containing an array, got empty *args={args} and **kwargs={kwargs}")
...
sizes = core.dedup_referents(_get_axis_size(name, np.shape(x), d)
for x, d in zip(vals, dims) if d is not None)
if len(sizes) == 1:
sz, = sizes
return sz
if not sizes:
raise ValueError(f"{name} must have at least one non-None value in in_axes")
# Build a multi-line, structured mismatch explanation
...
raise ValueError(''.join(msg)[:-2])
_mapped_axis_size separates what went wrong (sizes differ) from a detailed explanation of where and how.
Notice how the core computation (deduplicating axis sizes) is relatively simple, but a large chunk of the function is dedicated to constructing a human-readable error that points to argument names and paths. This is deliberate: vmap failures can be maddening without good diagnostics.
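For instance, a mismatch like the one below is exactly the case _mapped_axis_size is built to explain (the precise wording of the error varies by version):
import jax
import jax.numpy as jnp

xs = jnp.ones((8, 3))
ys = jnp.ones((10, 3))                      # different leading size

# jax.vmap(lambda x, y: x + y)(xs, ys)     # raises ValueError: inconsistent mapped axis sizes,
#                                          # with the offending arguments named in the message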
pmap: Orchestrating Devices Without Owning Them
pmap adds another dimension: actual hardware devices and potentially multiple hosts. The semantics are similar to vmap (“map a function over an axis”), but the implementation has to reason about axis sizes, device lists, backends, and even migration between old and new implementations.
The public pmap function itself follows the same facade philosophy as jit:
- Reject deprecated options (global_arg_shapes).
- Optionally delegate to a newer implementation in jax._src.pmap based on a feature flag (config.pmap_shmap_merge).
- Otherwise, route to the legacy C++ fastpath via _cpp_pmap.
The heavy logic lives in helpers like _prepare_pmap, _shared_code_pmap, and the interaction with pxla and pmap_lib. What's notable from a design perspective is how the API function itself remains readable: you can grasp what pmap promises without understanding every caching and fastpath detail.
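That promise is compact enough to state in a few lines (a sketch assuming the leading axis matches the number of local devices):
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)

# One shard per device, with a named axis available for collectives.
normalized = jax.pmap(lambda v: v / jax.lax.psum(v, axis_name="i"), axis_name="i")(x)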
The report calls out this area as one of the most complex parts of the file, and suggests pushing the preparation/fastpath decision behind a single _pmap_impl helper. That kind of encapsulation is what keeps a central API file from collapsing under its own weight as features evolve.
Owning Device Placement Without Owning Devices
Beyond transformations, api.py also defines how users move data between host and devices. Again, it doesn't actually implement transports; it shapes and validates the contracts around them.
device_put: Sharding, Donation, and Aliasing
The core device_put helper is a great example of balancing flexibility with strict safety. It lets you specify, in PyTree form, target devices/shardings, source shardings, and copy semantics (donation vs aliasing) and then enforces invariants before delegating to dispatch.
def device_put(
x,
device: None | xc.Device | Sharding | P | Format | Any = None,
*, src: None | xc.Device | Sharding | P | Format | Any = None,
donate: bool | Any = False, may_alias: bool | None | Any = None):
with config.explicit_device_put_scope():
x_flat, treedef = tree_flatten(x)
...
if isinstance(donate, bool):
donate_flat = [donate] * len(x_flat)
else:
donate_flat = flatten_axes("device_put donate", treedef, donate)
if isinstance(may_alias, bool):
may_alias_flat = [may_alias] * len(x_flat)
else:
may_alias_flat = flatten_axes("device_put may_alias", treedef, may_alias)
copy_semantics = []
for m, d in zip(may_alias_flat, donate_flat):
if m and d:
raise ValueError('may_alias and donate cannot be True at the same time.')
if m is None:
m = not d
if m and not d:
copy_semantics.append(dispatch.ArrayCopySemantics.REUSE_INPUT)
elif not m and d:
copy_semantics.append(dispatch.ArrayCopySemantics.DONATE_INPUT)
else:
copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY)
dst_avals = []
for xf, d in zip(x_flat, device_flat):
aval = shaped_abstractify(xf)
aval = dispatch.update_dp_aval(aval, d)
dst_avals.append(aval)
_check_sharding(aval, d)
if core.trace_state_clean():
out_flat = dispatch._batched_device_put_impl(...)
else:
out_flat = dispatch.device_put_p.bind(...)
return tree_unflatten(treedef, out_flat)
device_put normalizes PyTrees and copy semantics before delegating to runtime primitives.
A few design lessons emerge here:
- Tree-prefix semantics: many arguments (device, src, donate, may_alias) are allowed to be either scalars or PyTrees that form a prefix of x. The helper flatten_axes enforces this, with good error messages (see the sketch after this list).
- Copy semantics as an explicit enum: instead of encoding semantics in booleans alone, JAX builds an explicit ArrayCopySemantics list. That makes downstream dispatch simpler and easier to extend.
- Validation before tracing: the function checks sharding compatibility, string-dtype rules, and device kinds (_check_string_compatible_sharding) before actually binding primitives when possible.
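Here's how those tree-prefix rules look from the caller's side (a sketch assuming a recent JAX where device may itself be a PyTree prefix of x):
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}
cpu0 = jax.devices("cpu")[0]

# Scalar spec: every leaf of the PyTree goes to the same device.
on_cpu = jax.device_put(params, cpu0)

# Prefix spec: a PyTree of targets matching the structure of x.
on_cpu_again = jax.device_put(params, {"w": cpu0, "b": cpu0})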
device_get and Friends
The inverse operation, device_get, follows the same PyTree-first thinking. It optionally kicks off asynchronous copy_to_host_async calls and then uses tree_map to visit leaves, delegating either to extended dtypes or to __array__ implementations.
Helpers like device_put_sharded and device_put_replicated further specialize the semantics (“stack shards across devices” vs “replicate across devices”), but they still adhere to the same basic pattern: validate tree structure and consistency, construct an abstract aval + sharding spec, and then call into pxla.batched_device_put.
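Their call shapes make the distinction explicit (a small sketch):
import jax
import jax.numpy as jnp

devices = jax.local_devices()

# Replicate: the same value lands on every device.
replicated = jax.device_put_replicated(jnp.ones(4), devices)

# Shard: one entry of the list per device, stacked along a new leading axis.
sharded = jax.device_put_sharded([jnp.full(4, i) for i in range(len(devices))], devices)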
Introspection: Seeing the Program JAX Sees
Transformations are powerful, but debugging them can be opaque. api.py also provides introspection tools like make_jaxpr and eval_shape that let you inspect the traced form of your functions or compute output shapes without doing FLOPs.
make_jaxpr: A JAX IR Inspector
The implementation of make_jaxpr is a nice case study in reusing existing building blocks while maintaining user semantics:
@partial(api_boundary, repro_api_name="jax.make_jaxpr")
def make_jaxpr(
fun: Callable,
static_argnums: int | Iterable[int] = (),
axis_env: Sequence[tuple[AxisName, int]] | None = None,
return_shape: bool = False,
abstracted_axes: Any | None = None,
) -> Callable[...]:
try:
hash(fun)
weakref.ref(fun)
except TypeError:
fun = partial(fun)
@wraps(fun)
@api_boundary
def make_jaxpr_f(*args, **kwargs):
with core.extend_axis_env_nd(axis_env or []):
traced = jit(fun, static_argnums=static_argnums,
abstracted_axes=abstracted_axes).trace(*args, **kwargs)
num_consts = traced._num_consts
if num_consts:
jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, num_consts)
jaxpr = core.ClosedJaxpr(jaxpr_, traced._consts)
else:
jaxpr = traced.jaxpr
if return_shape:
out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals]
return jaxpr, tree_unflatten(tree_structure(traced.out_info), out)
return jaxpr
...
return make_jaxpr_f
make_jaxpr uses jit(...).trace() under the hood, then repairs const handling to match user expectations.
A few noteworthy touches:
- If the function isn't hashable/weakref-able, it's wrapped in functools.partial so it can still serve as a cache key.
- The function uses an axis environment so it can correctly model collectives (pmap axes) when building the jaxpr.
- It corrects for a subtle behavior of jit (moving consts into args) because users of make_jaxpr expect true consts.
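Usage is a one-liner, and the printed jaxpr is often the fastest way to see what JAX actually traced:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + x

print(jax.make_jaxpr(f)(3.0))   # prints the closed jaxpr traced at a float32 scalar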
eval_shape: Abstract Execution Without FLOPs
eval_shape is conceptually very simple: “run my function, but in a mode where values are abstract ShapeDtypeStruct objects instead of real arrays.” In implementation, it reuses jit(...).trace() in the general case, and fast-paths PjitFunction objects.
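A typical call costs no device compute at all (a small sketch):
import jax
import jax.numpy as jnp

def apply(w, x):
    return jnp.tanh(x @ w)

w_spec = jax.ShapeDtypeStruct((1024, 256), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 1024), jnp.float32)

out_spec = jax.eval_shape(apply, w_spec, x_spec)   # ShapeDtypeStruct with shape (32, 256)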
The key takeaway is that both introspection functions are thin adapters: they don't duplicate tracing logic; they control how that logic is exposed.
Operational Lessons: Caches, NaNs, and Metrics
A transformation machine is only useful in production if it can be observed and controlled. This file also exposes runtime utilities and hooks that are easy to overlook but important operationally.
NaN/Inf Debug Hooks: Global but Scoped
At the top of the file we find _nan_check_posthook, a hook that the C++ JIT and PMAP paths can call to check for NaNs/Infs in buffers after a computation. It's wired to config flags debug_nans and debug_infs through a Config object:
@api_boundary
def _nan_check_posthook(fun, args, kwargs, output):
buffers = []
for leaf in tree_leaves(output):
if hasattr(leaf, "addressable_shards"):
buffers.extend([shard.data for shard in leaf.addressable_shards])
try:
dispatch.check_special(pjit.jit_p.name, buffers)
except api_util.InternalFloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value
if hasattr(fun, '_fun'):
f = fun._fun
if getattr(f, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}")
api_util.maybe_recursive_nan_check(e, f, args, kwargs)
raise AssertionError("Unreachable") from e
else:
raise
Configuration hooks update the global or thread-local post-hook whenever debug flags change. The code report flags this as a coupling smell: NaN/Inf handling is mixed into the main API module and uses mutable global state that can be tricky in multithreaded contexts.
The suggested improvement is to extract this into a dedicated, well-documented debug module and keep api.py free from these concerns. The broader lesson: central facades should be very careful about owning global state; it's hard to reason about and test.
Caches and Cleanup
JAX compilation is expensive, and this file offers utilities to manage the lifecycle of compiled artifacts:
- clear_caches() clears Python-level staging caches, C++ compiled executable caches for pjit and pmap, and the internal PjitFunctionCache.
- clear_backends() resets backend clients and caches so new backends can be created later.
- An @atexit-registered clean_up() function calls both, then shuts down the distributed system if present.
From an operator’s perspective, these are escape hatches for long-lived processes (servers, notebooks) that might otherwise accumulate compiled programs and device memory. From a design perspective, they illustrate another pattern: surface global effects behind tiny, explicit functions rather than sprinkling them through the codebase.
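In a long-lived server or notebook, that boils down to an occasional explicit call:
import jax

# Drop compiled executables and staging caches; later calls recompile as needed.
jax.clear_caches()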
What to Measure in the Transformation Layer
Even though this module doesn't emit metrics itself, the analysis suggests a few concrete metrics that align well with the responsibilities we've seen:
- jit_compilation_time_seconds – to catch slow or regressing compilation of JIT/PMAP/PJIT paths.
- num_compilations_per_callable – to detect shape polymorphism or static-arg issues that cause repeated recompilation.
- device_to_host_bytes_per_second – to monitor data transfer throughput when device_put/device_get are used heavily.
- live_arrays_count_by_platform – using live_arrays() to spot potential leaks in device memory.
- pmap_global_axis_size_mismatch_errors – to flag misconfigurations in distributed pmap usage.
None of these require changes to api.py; they can be layered on externally by wrapping jit/pmap in your own observability hooks. But they align tightly with the transformation-machine responsibilities we've been exploring.
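As a sketch of what "layered on externally" can mean, here is a hypothetical wrapper; observed_jit and record_metric are illustrative names, not JAX API, and we assume a JAX recent enough to expose jax.block_until_ready:
import time
import jax

def observed_jit(fun, record_metric, **jit_kwargs):
    jitted = jax.jit(fun, **jit_kwargs)

    def call(*args, **kwargs):
        start = time.perf_counter()
        out = jitted(*args, **kwargs)
        jax.block_until_ready(out)   # wait for async dispatch so timings are honest
        # The first call's wall time is dominated by tracing + compilation.
        record_metric("jit_call_seconds", time.perf_counter() - start)
        return out

    return call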
What We Can Steal for Our Own Code
Walking through jax/_src/api.py as a whole, we see a single, strong narrative: build a transformation machine around user functions by consistently wrapping, flattening, validating, and delegating. Even if you're not building an autodiff library, there are several concrete patterns worth copying.
1. Separate Contracts From Implementations
Functions like jit, grad, vmap, and pmap focus on:
- Signatures and overloads.
- Rich, example-filled docstrings.
- Front-loaded validation with educational error messages.
The actual algorithms live in interpreters like ad, batching, pxla, and pjit. This decoupling makes it easier to change the guts (e.g., migrate pmap to shard_map) without breaking user expectations.
2. Make Complex Structures First-Class (PyTrees, Axes, Shardings)
Instead of fighting the complexity of nested containers and axis specs, JAX embraces them as a first-class abstraction: PyTrees, flatten_axes, tree_flatten_with_path, etc. That lets every transformation share a common vocabulary and behavior for structured inputs and outputs.
In our own systems, we can define and standardize on such “structured value” abstractions instead of handling dicts/lists ad hoc in each function.
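In JAX terms, standardizing on the abstraction is as simple as registering your own containers as PyTree nodes, after which every transformation in this file understands them (ExperimentState is just an example class):
from jax import tree_util

class ExperimentState:
    def __init__(self, params, step):
        self.params = params
        self.step = step

tree_util.register_pytree_node(
    ExperimentState,
    lambda s: ((s.params, s.step), None),              # flatten: (children, aux_data)
    lambda aux, children: ExperimentState(*children),  # unflatten
)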
3. Treat Error Messages as Design Artefacts
Whether it's _mapped_axis_size describing axis mismatches, or autodiff dtype checks suggesting alternate APIs, this file treats errors as an opportunity to teach. The outcome is a much smoother developer experience for very sophisticated features.
4. Keep Global State at the Edges
Where global state is unavoidable (config flags, caches, NaN hooks), the API exposes tiny, explicit helpers (clear_caches, clear_backends) or uses scoped contexts (disable_jit, explicit_device_put_scope). The report suggests going even further by extracting some of these concerns into separate modules—a good reminder to keep central facades small and focused.
5. Design for Composition
Autodiff and vectorization in JAX build on each other: hessian as jacfwd(jacrev(...)), Jacobians using vmap over jvp/vjp, linearize reusing ad.linearize. That composability is only possible because APIs consistently adhere to the wrap/flatten/dispatch pattern and preserve PyTree contracts.
When we design transformation-like layers in our own code—whether that's caching, authorization, or multi-tenant routing—we can aim for the same compositional story: each layer should accept and return the same shape of function, plus metadata, so it can be stacked with others.
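A tiny, non-JAX sketch of that discipline: every layer maps a function to a function with the same calling convention, so layers stack in any order, just as jit, grad, and vmap do.
def with_cache(fun):
    cache = {}
    def wrapped(*args):
        if args not in cache:
            cache[args] = fun(*args)
        return cache[args]
    return wrapped

def with_logging(fun):
    def wrapped(*args):
        print(f"calling {getattr(fun, '__name__', fun)}{args}")
        return fun(*args)
    return wrapped

@with_cache
@with_logging
def expensive(x, y):
    return x ** y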
JAX's core API module is big, yes, and the report rightly calls out some monolithic smells and refactor opportunities. But underneath the size is a remarkably consistent architecture: a user-facing facade that treats functions as data, reshapes them through a series of predictable steps, and delegates the heavy work to well-defined interpreters and backends.
If we take just one lesson away, let it be this: transformation power comes from disciplined boundaries, not from magic. Once we start wrapping, flattening, validating, and dispatching in a consistent way, we can add surprisingly sophisticated capabilities without losing our minds—or our users.



