
Zalt Blog

Deep Dives into Code & Architecture at Scale

How JAX Turns Ordinary Python Into a Transformation Machine

By Mahmoud Zalt
Code Cracking
35m read

What does it really mean for JAX to turn ordinary Python into a transformation machine? This article walks through how that shift changes the way you think about code.


Most of us meet JAX through a few magical functions: jit, grad, vmap, pmap. They feel like small decorators you sprinkle on top of plain Python. But in reality, they form a carefully engineered transformation machine that reshapes your functions for differentiation, vectorization, and parallel execution.

In this article, we'll walk through the core API module of JAX and see how it builds that machine. I'm Mahmoud Zalt, and we'll focus on one central idea: you can design a powerful transformation layer by consistently wrapping, flattening, and validating user functions before they ever reach your runtime.

The Scene: One File, Many Transformations

Before we zoom into individual functions, we need to understand the terrain. The file in question, jax/_src/api.py, is the main facade that backs the public symbols you import as jax.jit, jax.grad, jax.vmap, and friends. It doesn't implement autodiff rules or GPU kernels; instead, it orchestrates a stack of interpreters and backends.

jax/_src/
├── core.py            (jaxpr, ShapedArray, Tracer abstractions)
├── interpreters/
│   ├── ad.py          (autodiff rules and JVP/VJP machinery)
│   ├── batching.py    (vmap batching rules)
│   ├── partial_eval.py (pe; linearize, jaxpr tracing)
│   └── pxla.py        (pmap/sharding lowering)
├── pjit.py            (jit/sharding implementation)
├── dispatch.py        (device_put, runtime tokens, primitives)
├── xla_bridge.py      (backend and device clients)
└── api.py             (this file: user-facing jit/grad/vmap/pmap/... facade)

User code
   |
   v
jax.jit / jax.grad / jax.vmap / jax.pmap / ...
   |
   v
jax._src.api (this module)
   |
   +--> wraps fun with lu.wrap_init, debug_info
   +--> flattens PyTrees via tree_util
   +--> selects interpreter: ad / batching / pxla / pjit / dispatch
            |
            v
        XLA backends (CPU/GPU/TPU via xla_client/xb)
jax._src.api as a facade layer between user code and the interpreter/backends stack.

So this one module is doing a lot: autodiff entrypoints, vectorization/parallelism (vmap, pmap), device movement (device_put, device_get), runtime utilities, and even NaN/Inf debug hooks. That sounds like a recipe for a ball of mud, yet the file stays surprisingly navigable.

The Pattern: Wrap, Flatten, Dispatch

Once we start looking at individual APIs, we see the same skeleton repeated with small variations. That skeleton is the real star of this file. It looks like this:

  1. Validate the callable and options.
  2. Flatten Python containers into PyTrees (nested lists/tuples/dicts with arrays at the leaves) and flatten any axis/device specs to match.
  3. Wrap the user function with metadata (name stack, debug info, static args) into a lu.WrappedFun.
  4. Pick the right interpreter (autodiff, batching, pmap, pjit, etc.).
  5. Post-process back to the original PyTree structure and enforce invariants.

The pay‑off of this pattern is enormous: new transformations can be added by reusing the same wrapping/flattening infrastructure, and users get consistent semantics and error messages across everything.
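
To make that skeleton concrete, here is a minimal, purely illustrative sketch of a transformation built the same way. None of it is real JAX internals; the helper name toy_transform and its behavior are invented for this example, and only the tree_util calls are actual JAX APIs.

import jax
from jax import tree_util

def toy_transform(fun):
  # A toy transformation following the same skeleton (not a real JAX API).
  if not callable(fun):                                  # 1. validate the callable
    raise TypeError(f"Expected a callable, got {type(fun)}")

  def wrapped(*args):
    leaves, treedef = tree_util.tree_flatten(args)       # 2. flatten PyTrees
    # 3./4. a real transform would wrap `fun` with metadata and hand the flat
    # representation to an interpreter; here we just call it directly.
    out = fun(*tree_util.tree_unflatten(treedef, leaves))
    tree_util.tree_flatten(out)                          # 5. enforce output invariants
    return out

  return wrapped

double_all = toy_transform(lambda t: jax.tree_util.tree_map(lambda x: 2 * x, t))
print(double_all({"a": 1.0, "b": (2.0, 3.0)}))           # {'a': 2.0, 'b': (4.0, 6.0)}
A hypothetical transformation built from the same wrap/flatten/validate steps.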

Example: jit as a Thin Front-End

JIT compilation feels like a heavy operation, but the Python wrapper in api.py is intentionally thin. It normalizes the options and hands everything to pjit:

def jit(
  fun: Callable | NotSpecified = NotSpecified(), /, *,
  in_shardings: Any = sharding_impls.UNSPECIFIED,
  out_shardings: Any = sharding_impls.UNSPECIFIED,
  static_argnums: int | Sequence[int] | None = None,
  static_argnames: str | Iterable[str] | None = None,
  donate_argnums: int | Sequence[int] | None = None,
  donate_argnames: str | Iterable[str] | None = None,
  keep_unused: bool = False,
  device: xc.Device | None = None,
  backend: str | None = None,
  inline: bool = False,
  abstracted_axes: Any | None = None,
  compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
  ...
  kwds = dict(
      in_shardings=in_shardings, out_shardings=out_shardings,
      static_argnums=static_argnums, static_argnames=static_argnames,
      donate_argnums=donate_argnums, donate_argnames=donate_argnames,
      keep_unused=keep_unused, device=device, backend=backend, inline=inline,
      abstracted_axes=abstracted_axes, compiler_options=compiler_options,
      use_resource_env=False)
  if isinstance(fun, NotSpecified):
    return lambda fun: pjit.make_jit(fun, **kwds)
  else:
    return pjit.make_jit(fun, **kwds)
jax.jit focuses on signature and ergonomics; pjit handles the heavy lifting.

The transformation we care about isn't encoded here at all; it's encoded in pjit and eventually in compiled XLA. This wrapper's job is to define how humans talk to JIT: decorator factory semantics, static/donated args, sharding hints, and consistent boundary tracing via @api_boundary.
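
From the user's side, both call forms resolve to the same machinery. A quick sketch using only the public jax.jit API:

import jax
import jax.numpy as jnp

@jax.jit                        # plain decorator: jit(fun) returns a JitWrapped
def scale(x, factor):
  return x * factor

@jax.jit(static_argnums=1)      # decorator factory: jit(**opts) returns a decorator
def power(x, n):
  return x ** n

print(scale(jnp.arange(3.0), 2.0))   # traced on first call, compiled via pjit
print(power(jnp.arange(3.0), 3))     # n is a static (compile-time) argument
Both the decorator and decorator-factory forms route to pjit.make_jit.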

Autodiff as a First-Class Facade

Nowhere is the transformation-machine idea clearer than in autodiff. Functions like grad, value_and_grad, jacfwd, jacrev, and hessian all build on the same underlying AD interpreters, but the public APIs each express a particular “view” on differentiation.

grad as a Thin View on value_and_grad

grad is often the first thing we call in JAX. It's a perfect example of how this module avoids duplicating logic by composing a more general transformation:

@partial(api_boundary, repro_api_name="jax.grad")
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
         has_aux: bool = False, holomorphic: bool = False,
         allow_int: bool = False,
         reduce_axes: Sequence[AxisName] = ()) -> Callable:
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to grad is deprecated")
  del reduce_axes
  value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
                                    holomorphic=holomorphic,
                                    allow_int=allow_int)

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f(*args, **kwargs):
    _, g = value_and_grad_f(*args, **kwargs)
    return g

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f_aux(*args, **kwargs):
    (_, aux), g = value_and_grad_f(*args, **kwargs)
    return g, aux

  return grad_f_aux if has_aux else grad_f
grad doesn't implement differentiation; it reuses value_and_grad and chooses the surface shape of the API.

The interesting work is in value_and_grad. It flattens the arguments, performs detailed dtype validation (holomorphic vs real-valued, integer handling), calls into reverse-mode AD via a helper _vjp, and then reassembles gradients, optionally with auxiliary data.
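
As a quick illustration of that relationship, here is standard usage of both entry points on a toy least-squares loss (the loss function itself is invented for this example):

import jax
import jax.numpy as jnp

def loss(w, x, y):
  pred = x @ w
  return jnp.mean((pred - y) ** 2)

w = jnp.ones(3)
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.array([1.0, 2.0])

g = jax.grad(loss)(w, x, y)                  # just the gradient w.r.t. w
val, g2 = jax.value_and_grad(loss)(w, x, y)  # value and gradient in one pass
grad is literally value_and_grad with the value thrown away.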

Error Messages as Part of the API

A recurring theme across autodiff helpers is that validation errors are written as teaching moments. For example, input dtype checks for reverse-mode ( _check_input_dtype_revderiv ) don't just say “wrong dtype” — they tell you what to do instead:

Reverse-mode input dtype validation snippet
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
  dispatch.check_arg(x)
  aval = core.get_aval(x)
  if holomorphic:
    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
      raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
                      f"but got {aval.dtype.name}.")
  if isinstance(aval, ShapedArray):
    if (dtypes.issubdtype(aval.dtype, dtypes.extended) or
        dtypes.issubdtype(aval.dtype, np.integer) or
        dtypes.issubdtype(aval.dtype, np.bool_)):
      if not allow_int:
        raise TypeError(f"{name} requires real- or complex-valued inputs ... "
                        "If you want to use Boolean- or integer-valued inputs, use vjp "
                        "or set allow_int to True.")

The pattern is always the same:

  • Check invariants early (scalar outputs for grad, dtype compatibility, PyTree structure).
  • Point to alternative APIs when the invariant doesn’t hold (vjp, jvp, or flags like holomorphic=True, allow_int=True).

Jacobian and Hessian: Composition over Cleverness

jacfwd and jacrev are forward- and reverse-mode Jacobian builders. Rather than inventing custom machinery, they assemble existing parts:

  • Wrap the function with debug metadata.
  • Partially apply over argnums.
  • Use vmap over jvp or vjp on basis vectors produced by _std_basis.
  • Unravel the dense Jacobian back into the PyTree block structure.

hessian goes one step further and defines itself as jacfwd(jacrev(...)). Algorithmically, that’s expensive — and the docstring is very explicit about the O(n²) memory — but architecturally, it's beautifully simple. The transformation machine stays composable.
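
You can check that composition yourself in a few lines of standard JAX (f here is just an arbitrary smooth function chosen for the example):

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(jnp.sin(x) * x)

x = jnp.array([0.1, 0.2, 0.3])

h1 = jax.hessian(f)(x)                 # forward-over-reverse under the hood
h2 = jax.jacfwd(jax.jacrev(f))(x)      # the same composition, written by hand
print(jnp.allclose(h1, h2))            # True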

Vectorization and Parallelism Without Losing Your Mind

So far we've focused on scalar-like transformations over function behavior (differentiate, linearize). JAX also needs to transform how functions map over data: vectorization with vmap and SPMD parallelism with pmap. The same skeleton—wrap, flatten, dispatch—shows up again, but the interesting story here is how axis and shape validation is handled.

vmap: Axis Specs as a Contract

The core vmap implementation starts by aggressively validating in_axes and out_axes:

@partial(api_boundary, repro_api_name="jax.vmap")
def vmap(fun: F,
         in_axes: int | None | Sequence[Any] = 0,
         out_axes: Any = 0,
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
         spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None) -> F:
  check_callable(fun)
  ...
  if isinstance(in_axes, list):
    in_axes = tuple(in_axes)

  if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}):
    raise TypeError("vmap in_axes must be an int, None, or a tuple ...")
  if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(in_axes)):
    raise TypeError("vmap in_axes must be an int, None, or (nested) container ...")
  if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(out_axes)):
    raise TypeError("vmap out_axes must be an int, None, or (nested) container ...")
vmap establishes a strict, but well‑documented, contract for axis specs.

Inside the actual vmap_f closure, we see the familiar routine: flatten arguments into a PyTree, wrap the function, flatten it again for vmap, and broadcast/flatten the axis specifications to match the tree. One particularly instructive helper is _mapped_axis_size, used both by vmap and pmap to infer the batch size and to craft detailed mismatch errors.

def _mapped_axis_size(fn, tree, vals, dims, name):
  if not vals:
    args, kwargs = tree_unflatten(tree, vals)
    raise ValueError(
        f"{name} wrapped function must be passed at least one argument "
        f"containing an array, got empty *args={args} and **kwargs={kwargs}")
  ...
  sizes = core.dedup_referents(_get_axis_size(name, np.shape(x), d)
                               for x, d in zip(vals, dims) if d is not None)
  if len(sizes) == 1:
    sz, = sizes
    return sz
  if not sizes:
    raise ValueError(f"{name} must have at least one non-None value in in_axes")

  # Build a multi-line, structured mismatch explanation
  ...
  raise ValueError(''.join(msg)[:-2])
_mapped_axis_size separates what went wrong (sizes differ) from a detailed explanation of where and how.

Notice how the core computation (deduplicating axis sizes) is relatively simple, but a large chunk of the function is dedicated to constructing a human-readable error that points to argument names and paths. This is deliberate: vmap failures can be maddening without good diagnostics.
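
For context, here is the contract from the user's side: in_axes names which axis of each argument is mapped, and a batch-size mismatch is exactly what triggers the error path above (the affine example is invented for illustration):

import jax
import jax.numpy as jnp

def affine(w, x):
  return w @ x

w = jnp.ones((4, 3))
xs = jnp.ones((8, 3))          # batch of 8 inputs

# Map only over x (axis 0); broadcast w to every element of the batch.
ys = jax.vmap(affine, in_axes=(None, 0))(w, xs)
print(ys.shape)                # (8, 4)

# Mismatched batch sizes trigger the structured error built by _mapped_axis_size:
# jax.vmap(affine, in_axes=(0, 0))(jnp.ones((5, 4, 3)), xs)  # ValueError: sizes differ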

pmap: Orchestrating Devices Without Owning Them

pmap adds another dimension: actual hardware devices and potentially multiple hosts. The semantics are similar to vmap (“map a function over an axis”), but the implementation has to reason about axis sizes, device lists, backends, and even migration between old and new implementations.
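
Before looking at how the call is routed, a quick reminder of the user-facing contract (standard pmap usage; the leading axis of the input must match the local device count):

import jax
from jax import lax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n * 4.0).reshape(n, 4)   # one slice per local device

# Each device runs the function on its slice; axis_name="i" enables collectives.
center = jax.pmap(lambda x: x - lax.pmean(x, axis_name="i"), axis_name="i")
print(center(xs).shape)                  # (n, 4)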

The public pmap function itself follows the same facade philosophy as jit:

  • Reject deprecated options (global_arg_shapes).
  • Optionally delegate to a newer implementation in jax._src.pmap based on a feature flag (config.pmap_shmap_merge).
  • Otherwise, route to the legacy C++ fastpath via _cpp_pmap.

The heavy logic lives in helpers like _prepare_pmap, _shared_code_pmap, and the interaction with pxla and pmap_lib. What's notable from a design perspective is how the API function itself remains readable: you can grasp what pmap promises without understanding every caching and fastpath detail.

The report calls out this area as one of the most complex parts of the file, and suggests pushing the preparation/fastpath decision behind a single _pmap_impl helper. That kind of encapsulation is what keeps a central API file from collapsing under its own weight as features evolve.

Owning Device Placement Without Owning Devices

Beyond transformations, api.py also defines how users move data between host and devices. Again, it doesn't actually implement transports; it shapes and validates the contracts around them.

device_put: Sharding, Donation, and Aliasing

The core device_put helper is a great example of balancing flexibility with strict safety. It lets you specify, in PyTree form, target devices/shardings, source shardings, and copy semantics (donation vs aliasing) and then enforces invariants before delegating to dispatch.

def device_put(
    x,
    device: None | xc.Device | Sharding | P | Format | Any = None,
    *, src: None | xc.Device | Sharding | P | Format | Any = None,
    donate: bool | Any = False, may_alias: bool | None | Any = None):
  with config.explicit_device_put_scope():
    x_flat, treedef = tree_flatten(x)
    ...
    if isinstance(donate, bool):
      donate_flat = [donate] * len(x_flat)
    else:
      donate_flat = flatten_axes("device_put donate", treedef, donate)

    if isinstance(may_alias, bool):
      may_alias_flat = [may_alias] * len(x_flat)
    else:
      may_alias_flat = flatten_axes("device_put may_alias", treedef, may_alias)

    copy_semantics = []
    for m, d in zip(may_alias_flat, donate_flat):
      if m and d:
        raise ValueError('may_alias and donate cannot be True at the same time.')
      if m is None:
        m = not d
      if m and not d:
        copy_semantics.append(dispatch.ArrayCopySemantics.REUSE_INPUT)
      elif not m and d:
        copy_semantics.append(dispatch.ArrayCopySemantics.DONATE_INPUT)
      else:
        copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY)

    dst_avals = []
    for xf, d in zip(x_flat, device_flat):
      aval = shaped_abstractify(xf)
      aval = dispatch.update_dp_aval(aval, d)
      dst_avals.append(aval)
      _check_sharding(aval, d)
    if core.trace_state_clean():
      out_flat = dispatch._batched_device_put_impl(...)
    else:
      out_flat = dispatch.device_put_p.bind(...)
    return tree_unflatten(treedef, out_flat)
device_put normalizes PyTrees and copy semantics before delegating to runtime primitives.

A few design lessons emerge here:

  • Tree-prefix semantics: many arguments (device, src, donate, may_alias) are allowed to be either scalars or PyTrees that form a prefix of x. The helper flatten_axes enforces this, with good error messages.
  • Copy semantics as an explicit enum: instead of encoding semantics in booleans alone, JAX builds an explicit ArrayCopySemantics list. That makes downstream dispatch simpler and easier to extend.
  • Validation before tracing: the function checks sharding compatibility, string-dtype rules, and device kinds (_check_string_compatible_sharding) before actually binding primitives when possible.
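
A small sketch of the tree-prefix behavior with plain devices (shardings work the same way); the params dict is invented for illustration:

import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
params = {"w": jnp.ones((4, 3)), "b": jnp.zeros(4)}

# A single device acts as a prefix of the whole PyTree: every leaf goes there.
on_cpu = jax.device_put(params, cpu)

# Or pass a PyTree of devices/shardings matching the structure, leaf by leaf.
on_cpu2 = jax.device_put(params, {"w": cpu, "b": cpu})
print(jax.tree_util.tree_map(lambda a: a.sharding, on_cpu))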

device_get and Friends

The inverse operation, device_get, follows the same PyTree-first thinking. It optionally kicks off asynchronous copy_to_host_async calls and then uses tree_map to visit leaves, delegating either to extended dtypes or to __array__ implementations.

Helpers like device_put_sharded and device_put_replicated further specialize the semantics (“stack shards across devices” vs “replicate across devices”), but they still adhere to the same basic pattern: validate tree structure and consistency, construct an abstract aval + sharding spec, and then call into pxla.batched_device_put.

Introspection: Seeing the Program JAX Sees

Transformations are powerful, but debugging them can be opaque. api.py also provides introspection tools like make_jaxpr and eval_shape that let you inspect the traced form of your functions or compute output shapes without doing FLOPs.

make_jaxpr: A jaxpr Inspector

The implementation of make_jaxpr is a nice case study in reusing existing building blocks while maintaining user semantics:

@partial(api_boundary, repro_api_name="jax.make_jaxpr")
def make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[...]:
  try:
    hash(fun)
    weakref.ref(fun)
  except TypeError:
    fun = partial(fun)

  @wraps(fun)
  @api_boundary
  def make_jaxpr_f(*args, **kwargs):
    with core.extend_axis_env_nd(axis_env or []):
      traced = jit(fun, static_argnums=static_argnums,
                   abstracted_axes=abstracted_axes).trace(*args, **kwargs)
    num_consts = traced._num_consts
    if num_consts:
      jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, num_consts)
      jaxpr = core.ClosedJaxpr(jaxpr_, traced._consts)
    else:
      jaxpr = traced.jaxpr
    if return_shape:
      out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals]
      return jaxpr, tree_unflatten(tree_structure(traced.out_info), out)
    return jaxpr
  ...
  return make_jaxpr_f
make_jaxpr uses jit(...).trace() under the hood, then repairs const handling to match user expectations.

A few noteworthy touches:

  • If the function isn't hashable/weakref-able, it's wrapped in functools.partial to still serve as a cache key.
  • The function uses an axis environment so it can correctly model collectives (pmap axes) when building the jaxpr.
  • It corrects for a subtle behavior of jit (moving consts into args) because users of make_jaxpr expect true consts.
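
Seeing it in action only takes a couple of lines (the printed jaxpr below is abbreviated and may differ slightly across JAX versions):

import jax
import jax.numpy as jnp

def f(x, y):
  return jnp.dot(x, y) + 1.0

x = jnp.ones((2, 3))
y = jnp.ones((3,))

print(jax.make_jaxpr(f)(x, y))
# roughly:
# { lambda ; a:f32[2,3] b:f32[3]. let
#     c:f32[2] = dot_general[...] a b
#     d:f32[2] = add c 1.0
#   in (d,) }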

eval_shape: Abstract Execution Without FLOPs

eval_shape is conceptually very simple: “run my function, but in a mode where values are abstract ShapeDtypeStruct objects instead of real arrays.” In implementation, it reuses jit(...).trace() in the general case, and fast-paths PjitFunction objects.

The key takeaway is that both introspection functions are thin adapters: they don't duplicate tracing logic; they control how that logic is exposed.
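
A minimal sketch: describe inputs as ShapeDtypeStructs and get output shapes back without allocating or computing anything (the predict function is invented for this example):

import jax
import jax.numpy as jnp

def predict(w, x):
  return jnp.tanh(x @ w)

w_spec = jax.ShapeDtypeStruct((128, 10), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)

out = jax.eval_shape(predict, w_spec, x_spec)
print(out.shape, out.dtype)   # (32, 10) float32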

Operational Lessons: Caches, NaNs, and Metrics

A transformation machine is only useful in production if it can be observed and controlled. This file also exposes runtime utilities and hooks that are easy to overlook but important operationally.

NaN/Inf Debug Hooks: Global but Scoped

At the top of the file we find _nan_check_posthook, a hook that the C++ JIT and PMAP paths can call to check for NaNs/Infs in buffers after a computation. It's wired to config flags debug_nans and debug_infs through a Config object:

@api_boundary
def _nan_check_posthook(fun, args, kwargs, output):
  buffers = []
  for leaf in tree_leaves(output):
    if hasattr(leaf, "addressable_shards"):
      buffers.extend([shard.data for shard in leaf.addressable_shards])

  try:
    dispatch.check_special(pjit.jit_p.name, buffers)
  except api_util.InternalFloatingPointError as e:
    assert config.debug_nans.value or config.debug_infs.value
    if hasattr(fun, '_fun'):
      f = fun._fun
      if getattr(f, '_apply_primitive', False):
        raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}")
      api_util.maybe_recursive_nan_check(e, f, args, kwargs)
      raise AssertionError("Unreachable") from e
    else:
      raise
The NaN/Inf posthook inspects shards of the output and raises rich errors tied back to the original Python function.
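
From the user's side, all of this is activated by two config flags; a minimal sketch (log of a negative number is just a convenient way to produce a NaN):

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)   # the companion flag is "jax_debug_infs"

@jax.jit
def f(x):
  return jnp.log(x)

# f(jnp.array(-1.0))   # would raise FloatingPointError via the posthook above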

Configuration hooks update the global or thread-local post-hook whenever debug flags change. The code report flags this as a coupling smell: NaN/Inf handling is mixed into the main API module and uses mutable global state that can be tricky in multithreaded contexts.

The suggested improvement is to extract this into a dedicated, well-documented debug module and keep api.py free from these concerns. The broader lesson: central facades should be very careful about owning global state; it's hard to reason about and test.

Caches and Cleanup

JAX compilation is expensive, and this file offers utilities to manage the lifecycle of compiled artifacts:

  • clear_caches() clears Python-level staging caches, C++ compiled executable caches for pjit and pmap, and the internal PjitFunctionCache.
  • clear_backends() resets backend clients and caches so new backends can be created later.
  • An @atexit-registered clean_up() function calls both, then shuts down the distributed system if present.

From an operator’s perspective, these are escape hatches for long-lived processes (servers, notebooks) that might otherwise accumulate compiled programs and device memory. From a design perspective, they illustrate another pattern: surface global effects behind tiny, explicit functions rather than sprinkling them through the codebase.
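
In practice that means a long-lived service can call these explicitly at safe points; a minimal sketch using only the public helpers:

import jax

# Drop Python-level tracing caches and compiled executables (pjit/pmap, PjitFunctionCache).
jax.clear_caches()

# Heavier hammer: tear down backend clients so they are re-created on next use.
# Existing device arrays become unusable, so call this only when nothing is in flight.
# jax.clear_backends()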

What to Measure in the Transformation Layer

Even though this module doesn't emit metrics itself, the analysis suggests a few concrete metrics that align well with the responsibilities we've seen:

  • jit_compilation_time_seconds – to catch slow or regressing compilation of JIT/PMAP/PJIT paths.
  • num_compilations_per_callable – to detect shape polymorphism or static-arg issues that cause repeated recompilation.
  • device_to_host_bytes_per_second – to monitor data transfer throughput when device_put/device_get are used heavily.
  • live_arrays_count_by_platform – using live_arrays() to spot potential leaks in device memory.
  • pmap_global_axis_size_mismatch_errors – to flag misconfigurations in distributed pmap usage.

None of these require changes to api.py; they can be layered on externally by wrapping jit/pmap in your own observability hooks. But they align tightly with the transformation-machine responsibilities we've been exploring.
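
For example, a hypothetical timed_jit helper (the name and the reporting are invented here) could use jit's public ahead-of-time lower/compile API to measure compilation time explicitly:

import time
import jax
import jax.numpy as jnp

def timed_jit(fun, **jit_kwargs):
  # Hypothetical observability helper: time lowering + compilation via the AOT API,
  # then hand back the compiled executable for reuse.
  jitted = jax.jit(fun, **jit_kwargs)

  def compile_with_timing(*args, **kwargs):
    start = time.perf_counter()
    compiled = jitted.lower(*args, **kwargs).compile()
    elapsed = time.perf_counter() - start
    print(f"compiled {getattr(fun, '__name__', 'fn')} in {elapsed:.3f}s")  # or your metrics sink
    return compiled

  return compile_with_timing

compile_fn = timed_jit(lambda x: jnp.sin(x) ** 2)
f_compiled = compile_fn(jnp.ones(8))      # reports compilation time once
print(f_compiled(jnp.ones(8)))            # subsequent calls reuse the executable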

What We Can Steal for Our Own Code

Walking through jax/_src/api.py as a whole, we see a single, strong narrative: build a transformation machine around user functions by consistently wrapping, flattening, validating, and delegating. Even if you're not building an autodiff library, there are several concrete patterns worth copying.

1. Separate Contracts From Implementations

Functions like jit, grad, vmap, and pmap focus on:

  • Signatures and overloads.
  • Rich, example-filled docstrings.
  • Front-loaded validation with educational error messages.

The actual algorithms live in interpreters like ad, batching, pxla, and pjit. This decoupling makes it easier to change the guts (e.g., migrate pmap to shard_map) without breaking user expectations.

2. Make Complex Structures First-Class (PyTrees, Axes, Shardings)

Instead of fighting the complexity of nested containers and axis specs, JAX embraces them as a first-class abstraction: PyTrees, flatten_axes, tree_flatten_with_path, etc. That lets every transformation share a common vocabulary and behavior for structured inputs and outputs.

In our own systems, we can define and standardize on such “structured value” abstractions instead of handling dicts/lists ad hoc in each function.

3. Treat Error Messages as Design Artefacts

Whether it's _mapped_axis_size describing axis mismatches, or autodiff dtype checks suggesting alternate APIs, this file treats errors as an opportunity to teach. The outcome is a much smoother developer experience for very sophisticated features.

4. Keep Global State at the Edges

Where global state is unavoidable (config flags, caches, NaN hooks), the API exposes tiny, explicit helpers (clear_caches, clear_backends) or uses scoped contexts (disable_jit, explicit_device_put_scope). The report suggests going even further by extracting some of these concerns into separate modules—a good reminder to keep central facades small and focused.

5. Design for Composition

Autodiff and vectorization in JAX build on each other: hessian as jacfwd(jacrev(...)), Jacobians using vmap over jvp/vjp, linearize reusing ad.linearize. That composability is only possible because APIs consistently adhere to the wrap/flatten/dispatch pattern and preserve PyTree contracts.

When we design transformation-like layers in our own code—whether that's caching, authorization, or multi-tenant routing—we can aim for the same compositional story: each layer should accept and return the same shape of function, plus metadata, so it can be stacked with others.

JAX's core API module is big, yes, and the report rightly calls out some monolithic smells and refactor opportunities. But underneath the size is a remarkably consistent architecture: a user-facing facade that treats functions as data, reshapes them through a series of predictable steps, and delegates the heavy work to well-defined interpreters and backends.

If we take just one lesson away, let it be this: transformation power comes from disciplined boundaries, not from magic. Once we start wrapping, flattening, validating, and dispatching in a consistent way, we can add surprisingly sophisticated capabilities without losing our minds—or our users.

Full Source Code

The full source code of the file that inspired this article is available on GitHub.
Read on GitHub


Thanks for reading! I hope this was useful. If you have questions or thoughts, feel free to reach out.

Content Creation Process: This article was generated via a semi-automated workflow using AI tools. I prepared the strategic framework, including specific prompts and data sources. From there, the automation system conducted the research, analysis, and writing. The content passed through automated verification steps before being finalized and published without manual intervention.

Mahmoud Zalt

About the Author

I’m Zalt, a technologist with 15+ years of experience, passionate about designing and building AI systems that move us closer to a world where machines handle everything and humans reclaim wonder.

Let's connect if you're working on interesting AI projects, looking for technical advice or want to discuss your career.
