The Transformations Engine Behind JAX

By Mahmoud Zalt

We’re examining how JAX wires together its core transformations – gradients, JIT compilation, vectorization, and device movement – through a single public entry point: jax/_src/api.py. This file is the facade that turns a deep stack of interpreters and XLA backends into the familiar jax.jit, jax.grad, jax.vmap, and jax.device_put functions.

JAX itself is a numerical computing library that lets you take pure Python and NumPy-style code and transform it: compile it, differentiate it, batch it, shard it. I’m Mahmoud Zalt, an AI solutions architect, and we’ll treat api.py as our lab specimen to understand one central idea: how a single, carefully designed “transformations engine” – built around pytrees and a shared IR (jaxpr) – lets you freely compose transformations while hiding enormous complexity.

We’ll map where api.py sits in the architecture, see how core transformations are layered on top of each other, look at vectorization as axis bookkeeping, study how data movement is exposed without surprises, and close with what this design implies for performance, observability, and your own APIs.

Where api.py Fits

api.py is the public door into JAX’s transformation system. Everything users call as jax.jit, jax.grad, jax.vmap, jax.device_put, jax.eval_shape, and similar flows through this file, then fans out into lower-level interpreters and backends.

jax/ (project root)
  |_ _src/
      |_ core.py          # JAX IR (jaxpr), avals, primitives
      |_ interpreters/
      |    |_ ad.py       # Autodiff interpreter
      |    |_ batching.py # vmap/batching interpreter
      |    |_ partial_eval.py
      |    |_ pxla.py     # Pjit / multi-device
      |_ dispatch.py      # Compilation & execution interface
      |_ tree_util.py     # PyTree definitions & helpers
      |_ api.py           # <== Public transformations & device APIs
      |_ sharding_impls.py
      |_ xla_bridge.py
api.py as facade: user calls enter here, then dispatch to interpreters and backends.

Architecturally, api.py is a Facade: it exposes friendly, documented functions and delegates to internal components like autodiff (ad), batching (batching), partial evaluation (pe), and compilation/dispatch (dispatch, pjit, xb, xc). It also owns user ergonomics: pytrees, configuration-driven behavior, and most error messages.

Once we see api.py as a facade over a shared transformations engine, it becomes clear why composability, error messages, and performance all have to be coordinated here, around one intermediate representation: jaxpr.

How Transformations Stack

The interesting part of api.py isn’t that jit or grad exist; it’s how they are implemented as small, predictable layers on top of jaxpr, and how they deliberately reuse one another. That’s what makes compositions like jit(vmap(grad(f))) behave sensibly.
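To make that composition concrete, here is a minimal sketch using only the public API. The loss function and the data are made up for illustration; the point is that each transformation wraps the previous one without knowing anything about it.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar loss of parameters w on a single example x.
    return jnp.sum((w * x) ** 2)

# grad differentiates with respect to w (argument 0),
# vmap maps over a batch of x while sharing w,
# jit compiles the whole composition.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
grads = per_example_grads(w, xs)  # one gradient row per example
print(grads.shape)
```

The analytic gradient here is 2 * w * x**2, so each row of grads can be checked by hand.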

jit: A Thin Public Shell

JIT compilation is surfaced as jax.jit, but the wrapper in api.py is intentionally thin. It normalizes options (static args, sharding, device/backend selection, donation) and hands everything to pjit.make_jit:

def jit(
  fun: Callable | NotSpecified = NotSpecified(), /, *,
  in_shardings: Any = sharding_impls.UNSPECIFIED,
  out_shardings: Any = sharding_impls.UNSPECIFIED,
  static_argnums: int | Sequence[int] | None = None,
  static_argnames: str | Iterable[str] | None = None,
  donate_argnums: int | Sequence[int] | None = None,
  donate_argnames: str | Iterable[str] | None = None,
  keep_unused: bool = False,
  device: xc.Device | None = None,
  backend: str | None = None,
  inline: bool = False,
  compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
  kwds = dict(
      in_shardings=in_shardings, out_shardings=out_shardings,
      static_argnums=static_argnums, static_argnames=static_argnames,
      donate_argnums=donate_argnums, donate_argnames=donate_argnames,
      keep_unused=keep_unused, device=device, backend=backend, inline=inline,
      compiler_options=compiler_options, use_resource_env=False)
  if isinstance(fun, NotSpecified):
    return lambda fun: pjit.make_jit(fun, **kwds)
  else:
    return pjit.make_jit(fun, **kwds)
jit focuses on signature and options; compilation logic lives in pjit and backends.

The design choice here is restraint: the public wrapper stays “dumb”. It owns user-facing semantics (decorator behavior, argument interpretation), then forwards to a focused implementation. That separation keeps the JIT contract stable even as compilation internals evolve.
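The same “thin shell” shape is easy to reuse outside JAX. The sketch below is plain Python with hypothetical names (traced, _make_traced, label); it only mirrors the pattern in the excerpt, where a sentinel default lets one function act both as a bare decorator and as a decorator factory before forwarding to a single implementation.

```python
from typing import Callable

class _NotSpecified:
    """Sentinel type meaning 'no function was passed positionally'."""

def _make_traced(fun: Callable, *, label: str) -> Callable:
    # Stand-in for the focused implementation (the role pjit.make_jit plays above).
    def wrapped(*args, **kwargs):
        return label, fun(*args, **kwargs)
    return wrapped

def traced(fun=_NotSpecified(), /, *, label: str = "fn"):
    # Public shell: normalize options, then forward everything.
    kwds = dict(label=label)
    if isinstance(fun, _NotSpecified):
        return lambda f: _make_traced(f, **kwds)  # used as @traced(label=...)
    return _make_traced(fun, **kwds)              # used as @traced or traced(f)

@traced
def double(x):
    return 2 * x

@traced(label="square")
def square(x):
    return x * x

print(double(3), square(3))
```

Because the shell holds no logic of its own, the implementation behind it can change without touching the public signature.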

grad Built on value_and_grad Built on vjp

Differentiation is implemented once, then reused. grad delegates to value_and_grad, which in turn builds on vjp (reverse-mode autodiff). The outer layer looks like this:

@partial(api_boundary, repro_api_name="jax.grad")
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
         has_aux: bool = False, holomorphic: bool = False,
         allow_int: bool = False,
         reduce_axes: Sequence[AxisName] = ()) -> Callable:
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to grad is deprecated")
  del reduce_axes
  value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
                                    holomorphic=holomorphic,
                                    allow_int=allow_int)

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f(*args, **kwargs):
    _, g = value_and_grad_f(*args, **kwargs)
    return g

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f_aux(*args, **kwargs):
    (_, aux), g = value_and_grad_f(*args, **kwargs)
    return g, aux

  return grad_f_aux if has_aux else grad_f
grad is “just” packaging; value_and_grad and vjp hold the real logic.

value_and_grad validates dtypes, enforces scalar outputs, calls vjp to build a backward pass, and returns both primal values and gradients. grad wraps that to either drop or expose auxiliary outputs.

The core decision is to treat vjp as the fundamental primitive and implement higher-level conveniences as thin, consistent layers. That keeps the implementation DRY and makes behavior across grad-family APIs easier to reason about.
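From the user’s side, that layering means the three entry points must agree. A small sketch, with an arbitrary example function, that computes the same gradient at each level:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary scalar-valued function used only for illustration.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.1, 0.2, 0.3])

g = jax.grad(f)(x)
value, g_from_vg = jax.value_and_grad(f)(x)

# vjp returns the primal output plus a function that pulls cotangents back.
y, pullback = jax.vjp(f, x)
(g_from_vjp,) = pullback(jnp.ones_like(y))  # cotangent 1.0 for a scalar output

print(g, g_from_vg, g_from_vjp)
```

Seeding the pullback with a cotangent of 1.0 is exactly what turns a vector-Jacobian product into a gradient for scalar outputs.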

vjp, linearize, and the Shared IR

vjp and linearize show the “transformations engine” idea most directly. Both:

  • Flatten pytrees into simple lists of leaves.
  • Call ad.linearize, which traces the Python function into a jaxpr – a compact IR for the computation – plus residuals.
  • Return closures you can reuse multiple times without re-tracing.

Conceptually, they turn a function into an explicit “forward tape + backward player” form over jaxpr. Other transformations don’t need to know how vjp works; they just consume or produce jaxpr. That’s the heart of the engine: pick one IR and make every transformation operate by reading and rewriting that IR.

The primary lesson from this section: centralize your transformations on a single, explicit intermediate representation. JAX uses jaxpr; once that is in place, jit, grad, vmap, and friends become composable layers instead of independent features.
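One way to see the shared IR for yourself is jax.make_jaxpr, which returns the traced jaxpr instead of running the function. The sketch below uses a toy function and simply counts equations; the exact counts depend on the JAX version, so none are asserted here.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy function; any jax.numpy code would do.
    return jnp.sin(x) * x

plain = jax.make_jaxpr(f)(1.5)                                # IR for f
graded = jax.make_jaxpr(jax.grad(f))(1.5)                     # IR for grad(f)
batched = jax.make_jaxpr(jax.vmap(jax.grad(f)))(jnp.ones(4))  # IR for vmap(grad(f))

for name, closed in [("f", plain), ("grad(f)", graded), ("vmap(grad(f))", batched)]:
    print(name, "->", len(closed.jaxpr.eqns), "equations")
```

Every stacked transformation still comes out as a jaxpr, which is why the next transformation can consume it.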

Vectorization as Axis Bookkeeping

With gradients and JIT defined over jaxpr, vectorization (vmap) is where the abstraction is stress‑tested. vmap has to understand pytrees, axis semantics, and even distributed meshes, but present itself as “just batching” to users.

The vmap Contract

Semantically, vmap takes a function f and returns a new function that applies f in parallel across a batch axis. Implementing that over nested containers and sharded devices requires careful normalization of axis specs and shapes before delegating to the batching interpreter.

@partial(api_boundary, repro_api_name="jax.vmap")
def vmap(fun: F,
         in_axes: int | None | Sequence[Any] = 0,
         out_axes: Any = 0,
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
         spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
         sum_match: bool = False
         ) -> F:
  if isinstance(in_axes, list):
    in_axes = tuple(in_axes)

  from jax._src import hijax
  if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}
          or isinstance(in_axes, hijax.MappingSpec)):
    raise TypeError("vmap in_axes must be an int, None, or a tuple ...")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(in_axes)):
    raise TypeError("vmap in_axes must be an int, None, or (nested) container ...")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(out_axes)):
    raise TypeError("vmap out_axes must be an int, None, or (nested) container ...")

  @wraps(fun, docstr=docstr)
  @api_boundary
  def vmap_f(*args, **kwargs):
    nonlocal spmd_axis_name
    if isinstance(in_axes, tuple) and len(in_axes) != len(args):
      raise ValueError("vmap in_axes must be an int, None, or a tuple ...")

    args_flat, in_tree  = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
    dbg = debug_info("vmap", fun, args, kwargs)
    api_util.check_no_transformed_refs_args(lambda: dbg, args_flat)
    f = lu.wrap_init(fun, debug_info=dbg)
    flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)

    if config.mutable_array_checks.value:
      avals = [None if d is None or batching.is_vmappable(x) else core.typeof(x)
               for x, d in zip(args_flat, in_axes_flat)]
      api_util.check_no_aliased_ref_args(lambda: dbg, avals, args_flat)

    axis_size_ = _mapped_axis_size(
        fun, in_tree, args_flat, in_axes_flat, "vmap", axis_size=axis_size)
    explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
    _check_ema_unmapped_args(explicit_mesh_axis, args_flat, in_axes_flat)

    axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
                                  explicit_mesh_axis)
    out_flat, inferred_out_axes = batching.batch(
        flat_fun, axis_data, in_axes_flat,
        lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
        sum_match=sum_match
    ).call_wrapped(*args_flat)

    return tree_unflatten(out_tree(), out_flat)
vmap normalizes axis specs, flattens pytrees, infers axis size, then delegates to the batching interpreter.

The responsibilities here are tightly interleaved:

  • Canonicalizing in_axes/out_axes across simple values and nested containers.
  • Flattening the argument tree so batching operates over lists of leaves.
  • Inferring the batch axis size and producing clear errors when shapes disagree.
  • Reconciling vectorization with sharding meshes via spmd_axis_name and explicit mesh axes.
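The first two responsibilities are the ones users meet most often. As a sketch with a made-up predict function, in_axes can be a flat tuple of per-argument axes or a nested container that mirrors the argument pytree:

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Hypothetical single-example model: dot product plus bias.
    return jnp.dot(params["w"], x) + params["b"]

params = {"w": jnp.ones((3,)), "b": jnp.float32(0.5)}
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 inputs

# None: params are shared across the batch; 0: map over the leading axis of xs.
batched = jax.vmap(predict, in_axes=(None, 0))
out = batched(params, xs)

# in_axes may also mirror the pytree: batch only "b", share "w".
bs = jnp.arange(4.0)
per_example_bias = jax.vmap(predict, in_axes=({"w": None, "b": 0}, 0))
out2 = per_example_bias({"w": params["w"], "b": bs}, xs)

print(out, out2)
```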

Helpful Errors via _mapped_axis_size

When batch dimensions don’t line up, vmap doesn’t just say “sizes don’t match”. _mapped_axis_size walks every argument, determines the implied axis size, and then produces a narrative error that points to specific arguments, shapes, and sizes.

This reflects a broader pattern in api.py: error handling is treated as part of the design, not an afterthought. The logic that explains “what went wrong” is pulled into helpers so the core algorithm can stay focused on transformation semantics.

If you build your own transformation APIs, investing in helpers like _mapped_axis_size pays off. They keep the main path readable and give users precise feedback in the failure cases that matter most.
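To see that error path in action, the sketch below deliberately maps over two arrays whose leading axes disagree and captures the message. The exact wording varies between JAX versions, so it is printed rather than quoted.

```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

try:
    jax.vmap(add)(jnp.ones(3), jnp.ones(4))  # mapped axes have sizes 3 and 4
    message = None
except ValueError as err:
    message = str(err)

print(message)
```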

Data Movement Without Surprises

Transformations are only half the story; at scale, host↔device data movement is often the real bottleneck. api.py owns the public device_put, device_put_sharded, device_put_replicated, and device_get APIs, balancing ergonomics with precise control over sharding and copy semantics.

device_put: Sharding, Donation, Aliasing

You can think of device_put as a shipping dock: values arrive as pytrees on the host; the function flattens them, assigns shardings or devices, decides whether buffers may be reused or must be copied, and hands everything to the dispatcher.

def device_put(
    x,
    device: None | xc.Device | Sharding | P | Format | Any = None,
    *, src: None | xc.Device | Sharding | P | Format | Any = None,
    donate: bool | Any = False, may_alias: bool | None | Any = None):
  with config.explicit_device_put_scope():
    x_flat, treedef = tree_flatten(x)
    x_avals = [shaped_abstractify(x) for x in x_flat]
    ...
    device_flat = map(partial(pspec_to_sharding, 'device_put'), device_flat)
    src_flat = map(partial(pspec_to_sharding, 'device_put'), src_flat)

    if isinstance(donate, bool):
      donate_flat = [donate] * len(x_flat)
    else:
      donate_flat = flatten_axes("device_put donate", treedef, donate)

    if isinstance(may_alias, bool):
      may_alias_flat = [may_alias] * len(x_flat)
    else:
      may_alias_flat = flatten_axes("device_put may_alias", treedef, may_alias)

    copy_semantics = []
    for m, d in zip(may_alias_flat, donate_flat):
      if m and d:
        raise ValueError('may_alias and donate cannot be True at the same time.')
      if m is None:
        m = not d
      if m and not d:
        copy_semantics.append(dispatch.ArrayCopySemantics.REUSE_INPUT)
      elif not m and d:
        copy_semantics.append(dispatch.ArrayCopySemantics.DONATE_INPUT)
      else:
        copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY)

    dst_avals = []
    for x_aval, d in zip(x_avals, device_flat):
      aval = dispatch.update_dp_aval(x_aval, d)
      dst_avals.append(aval)
      _check_sharding(aval, d)

    if core.trace_state_clean():
      out_flat = dispatch._batched_device_put_impl(
          *x_flat, devices=device_flat, srcs=src_flat,
          copy_semantics=copy_semantics, dst_avals=dst_avals)
    else:
      out_flat = dispatch.device_put_p.bind(
          *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat),
          copy_semantics=tuple(copy_semantics))

    return tree_unflatten(treedef, out_flat)
device_put flattens pytrees, infers shardings, enforces copy semantics, and then delegates to dispatch.

Subtleties handled here so users don’t have to think about them on every call include:

  • Pytree matching: device, src, donate, and may_alias can mirror the structure of x; flatten_axes keeps those shapes consistent.
  • Donation vs aliasing: donation means “you may destroy this buffer”; aliasing means “the result may share this buffer, so don’t count on getting an independent copy”. They are mutually exclusive and get encoded as explicit ArrayCopySemantics values.
  • Sharding validation: _check_sharding ensures that shardings are compatible with value shapes and device types (e.g. string arrays pinned to CPU).
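The donation and aliasing rules are easier to check as a decision table. The sketch below restates the loop from the excerpt in standalone Python; CopySemantics and resolve are hypothetical stand-ins for dispatch.ArrayCopySemantics and the inline logic, not JAX APIs.

```python
from enum import Enum

class CopySemantics(Enum):  # hypothetical stand-in for dispatch.ArrayCopySemantics
    ALWAYS_COPY = "always_copy"
    REUSE_INPUT = "reuse_input"
    DONATE_INPUT = "donate_input"

def resolve(may_alias, donate):
    if may_alias and donate:
        raise ValueError("may_alias and donate cannot be True at the same time.")
    if may_alias is None:
        may_alias = not donate  # default: allow aliasing unless donating
    if may_alias and not donate:
        return CopySemantics.REUSE_INPUT
    if donate:
        return CopySemantics.DONATE_INPUT
    return CopySemantics.ALWAYS_COPY

# Enumerate every legal combination of the two flags.
table = {(m, d): resolve(m, d)
         for m in (None, True, False) for d in (False, True)
         if not (m and d)}
for key, semantics in table.items():
    print(key, "->", semantics.name)
```

Note that a copy is only forced when the caller explicitly passes may_alias=False without donating.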

Sharded and Replicated Placement

device_put_sharded and device_put_replicated are higher‑level helpers built on the same ideas:

  • device_put_sharded takes one shard per device, checks compatibility, builds a Mesh and NamedSharding, and uses pxla.batched_device_put underneath.
  • device_put_replicated computes an “unmapped” abstract value for a replica axis, then broadcasts a single host buffer across devices with batched_device_put.
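For comparison, similar placements can be expressed with the general device_put and an explicit sharding. This sketch does not call the two helpers above; it assumes a one-dimensional mesh over whatever devices are available, with the axis name "data" chosen arbitrarily.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))  # 1-D mesh; the axis name is an arbitrary choice

# Two rows per device so the leading axis divides evenly.
rows = devices.size * 2
x = np.arange(rows * 3, dtype=np.float32).reshape(rows, 3)

sharded = jax.device_put(x, NamedSharding(mesh, P("data")))  # split rows across devices
replicated = jax.device_put(x, NamedSharding(mesh, P()))     # full copy on every device

print(sharded.sharding)
print(replicated.sharding)
```

On a single-device machine both placements hold the whole array, but the code path is the same as on a multi-device host.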

Extended dtypes (custom element types) plug into these paths via specialized hooks. Centralizing those hooks in dedicated helpers would keep support for new dtypes consistent across all data‑movement APIs.

device_get: Async Then Sync

On the way back to the host, device_get first starts host copies asynchronously on all leaves (via copy_to_host_async when available), then walks the tree again to materialize each value using __array__() or extended dtype rules.

The same pattern shows up in block_until_ready and effects_barrier: kick off asynchronous work across the tree, then provide explicit synchronization points. From the facade’s perspective, this is where you’d add tracing, logging, or metrics for all host↔device traffic.

Data movement is often the main IO cost in accelerator workloads. By routing all host↔device transfers through a small set of functions in api.py, JAX makes those crossings explicit and observable without leaking backend details.
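A round trip through those crossings might look like the following sketch. The pytree contents are invented; the calls are the public jax.device_put, jax.block_until_ready, and jax.device_get.

```python
import numpy as np
import jax

host_tree = {
    "w": np.ones((2, 3), dtype=np.float32),
    "b": np.zeros(3, dtype=np.float32),
}

on_device = jax.device_put(host_tree)  # pytree in, pytree of device arrays out
result = jax.tree_util.tree_map(lambda a: a * 2, on_device)

result = jax.block_until_ready(result)  # explicit synchronization point
back = jax.device_get(result)           # leaves come back as NumPy arrays

print(type(back["w"]), back["w"].shape)
```

Each of the three calls is a visible boundary, which is what makes it practical to log or time them.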

Performance, Debugging, and Observability

So far we’ve focused on how transformations and data movement are exposed. The same facade also shapes how JAX behaves at scale: how much Python overhead hot paths incur, and where you can attach operational insight.

Hot Paths and Python Overheads

The primary hot entry points are:

  • jit-wrapped functions for training and inference.
  • grad/value_and_grad/vjp for backprop.
  • vmap for batched execution.
  • device_put* and device_get for host↔device transfers.
  • eval_shape and make_jaxpr for meta‑programming and debugging.

On the Python side, these functions mostly perform pytree flattening/unflattening, argument validation, and small allocations, with complexity proportional to the number of leaves or arguments. The heavy work – FLOPs, compilation, large allocations – is delegated to the lower-level interpreters and to XLA.

  • vmap costs roughly O(n_leaves + n_args) per call in Python to normalize axes and pytrees, then defers to the batching interpreter.
  • device_put is O(n_leaves) to build abstract values, shardings, and copy semantics, plus the actual transfer.
  • make_jaxpr and eval_shape are dominated by the cost of tracing the function but avoid real numeric computation.
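The last point is easy to demonstrate: eval_shape accepts shape and dtype descriptions instead of real arrays, so no matrix is ever allocated. A sketch with an invented layer function:

```python
import jax
import jax.numpy as jnp

def layer(w, x):
    # Hypothetical dense layer; only its shapes matter here.
    return jnp.tanh(x @ w)

w_spec = jax.ShapeDtypeStruct((1024, 256), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 1024), jnp.float32)

# Traces layer abstractly and returns the output's shape and dtype.
out_spec = jax.eval_shape(layer, w_spec, x_spec)
print(out_spec.shape, out_spec.dtype)
```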

Metrics That Matter at Scale

For production workloads, several concrete metrics attach naturally at the api.py layer. Treat the names below as suggestions for your own instrumentation, not as metrics JAX exports out of the box:

Metric | What it tells you | Where it hooks
jax_compilation_time_seconds | Time spent tracing and compiling transformed functions. | jit, make_jaxpr, eval_shape
jax_traced_jaxpr_size_nodes | Approximate size of generated IR; reveals oversized traces. | make_jaxpr (via ClosedJaxpr structure)
device_transfer_bytes_total | Total host↔device bytes moved. | device_put*, device_get
jax_array_block_until_ready_calls | Frequency of synchronization barriers. | block_until_ready, effects_barrier
live_arrays_count | Number of live device buffers. | live_arrays() from the backend

Because all user-visible transformations and device APIs enter through api.py, you can instrument these metrics by wrapping the public functions – no need to modify interpreters or backends.
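As a sketch of what such wrapping could look like, the snippet below counts bytes passed through device_put. The metrics dictionary and counted_device_put are hypothetical; a real setup would report to a metrics client instead of a dict, and leaves without an nbytes attribute (such as Python scalars) are counted as zero here.

```python
import numpy as np
import jax

# Hypothetical in-process counter standing in for a real metrics client.
metrics = {"device_transfer_bytes_total": 0}

def counted_device_put(x, *args, **kwargs):
    # Sum the sizes of all array leaves, then forward to the real API unchanged.
    leaves = jax.tree_util.tree_leaves(x)
    metrics["device_transfer_bytes_total"] += sum(
        getattr(leaf, "nbytes", 0) for leaf in leaves)
    return jax.device_put(x, *args, **kwargs)

batch = {
    "x": np.zeros((8, 4), dtype=np.float32),
    "y": np.zeros(8, dtype=np.int32),
}
on_device = counted_device_put(batch)
print(metrics)
```

The wrapper never looks inside dispatch or the backend; everything it needs is visible at the facade.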

Debugging Hooks and Global Config

api.py also wires in runtime configuration for debugging and safety: NaN/Inf checking, JIT disabling, and cache and backend clearing. For example, _nan_check_posthook can be attached to the JIT runtime when config.debug_nans or config.debug_infs is enabled, inspecting buffers after execution and raising detailed floating‑point errors.

Similarly, disable_jit() toggles JIT behavior via global config while leaving primitive‑level compilation intact. That gives you an escape hatch for debugging shape or control‑flow issues without discarding the transformation engine entirely.
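A quick way to observe the difference is to count how often the Python body of a jitted function actually runs. In this sketch the side-effecting list append is only there for illustration: under jit the body runs while tracing and later calls hit the cache, whereas inside disable_jit() it runs on every call.

```python
import jax

trace_log = []

@jax.jit
def scale(x):
    trace_log.append("python body ran")  # side effect: happens only when Python executes
    return x * 2.0

scale(1.0)
scale(2.0)  # same argument type, so the cached compilation is reused
runs_under_jit = len(trace_log)

with jax.disable_jit():
    scale(3.0)
    scale(4.0)
runs_without_jit = len(trace_log) - runs_under_jit

print(runs_under_jit, runs_without_jit)
```

This is also why print statements and debuggers behave as expected inside the disable_jit() block: the values are concrete rather than tracers.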

The overarching operational lesson is to keep the public API thin but give it enough hooks – metrics, debug flags, cache controls – so performance and correctness issues can be understood and influenced at the facade layer.

Conclusion: Building Your Own Transformations Engine

Walking through jax/_src/api.py shows more than a list of functions. It shows how a transformation‑centric design – built around pytrees, a single IR (jaxpr), and carefully layered wrappers – lets JAX expose powerful capabilities as simple, composable APIs.

Key Lessons You Can Reuse

  1. Center everything on one intermediate representation. In JAX, jaxpr is the small, explicit language that all major transformations produce or consume. Adopting a similar IR in your own systems prevents each feature from inventing its own ad‑hoc representation and makes stacking transformations feasible.
  2. Keep public wrappers thin and ergonomic. Functions like jit, grad, vmap, and device_put handle signatures, pytrees, and rich error messages, then defer to focused interpreters and backends. This separation keeps user contracts clear while allowing internals to evolve.
  3. Design for observability from the facade. By routing compilation, transformation, and data movement through a small set of public APIs, JAX gains natural choke‑points for metrics, logging, and debugging controls. Thinking about these from day one makes scale and operations far less painful.

If we treat api.py as JAX’s transformations engine, the central lesson is simple: choose where complexity lives. JAX concentrates it in a shared IR and a handful of interpreters, and keeps the public surface feather‑light but precise. That pattern is broadly reusable, whether you’re building ML libraries, data platforms, or internal tooling that needs to transform user code without overwhelming your users – or your future self.

