
The Training Conductor Behind Keras Models

How does Keras actually run your training under the hood? Think of a training conductor quietly coordinating every step behind model.fit().

Code Cracking
25m read
#Keras#DeepLearning#MachineLearning#NeuralNetworks


We're examining how Keras orchestrates model training behind the deceptively simple model.fit() API. Keras is the high-level deep learning interface built on top of TensorFlow, and its Model class is where training, evaluation, prediction, and checkpointing all come together. I'm Mahmoud Zalt, an AI software engineer, and we'll look at how this class acts as a training conductor — cleanly separating “what a single step does” from “how steps run at scale across devices, workers, and APIs.” That split is the core lesson, and we’ll see how it shapes extensibility, distribution, memory behavior, and operational concerns.

Model as a training conductor

The Keras training engine lives in keras/engine/training.py. This file hosts the Model class, which intentionally centralizes almost everything related to training and I/O.

keras/
  engine/
    base_layer.py
    training.py   <-- tf.keras.Model training & IO
    compile_utils.py
    data_adapter.py
    training_utils.py

Model.compile()
   |
   v
[LossesContainer, MetricsContainer, optimizer]
   |
   v
Model.fit()
   |
   +--> data_adapter.get_data_handler()
   |        (build Dataset / iterator)
   +--> Model.make_train_function()
            |
            +--> train_function(iterator)
                      |
                      +--> strategy.run(run_step)
                                 |
                                 +--> Model.train_step(data)
The Keras training engine as an orchestration layer around core TensorFlow primitives.

The Model class is a deliberate “god object” for training concerns. It owns:

  • Configuration: compile() wires optimizer, losses, metrics, and execution knobs like run_eagerly and steps_per_execution.
  • Loops: fit(), evaluate(), and predict() drive data handlers, callbacks, and distributed execution.
  • Steps: train_step(), test_step(), predict_step() define the per-batch math and are meant to be overridden.
  • Weights & persistence: save(), save_weights(), load_weights() coordinate SavedModel, checkpoints, and HDF5.

Think of Model as an orchestra conductor: it doesn’t implement the math of individual layers or optimizers, but it decides when and how everything plays together during training and inference.

The core pattern: step vs. loop

The dominant design move in this file is the strict separation between a step (one batch’s worth of work) and a loop (how those steps are executed across time and hardware). Nearly every advanced feature — custom training, distribution, and performance tuning — hangs off this split.

The step: one batch of semantics

The default train_step implementation is intentionally small and readable:

def train_step(self, data):
  """The logic for one training step."""
  # Normalize data structure.
  data = data_adapter.expand_1d(data)
  x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

  # Forward pass.
  with backprop.GradientTape() as tape:
    y_pred = self(x, training=True)
    loss = self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses)

  # Backward pass.
  self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
  self.compiled_metrics.update_state(y, y_pred, sample_weight)

  # Package metrics.
  return_metrics = {}
  for metric in self.metrics:
    result = metric.result()
    if isinstance(result, dict):
      return_metrics.update(result)
    else:
      return_metrics[metric.name] = result
  return return_metrics
The default train_step: a single-batch contract, easy to override.

This is a textbook Template Method pattern: the base class defines the high-level algorithm, and subclasses override selected steps. Here, the contract is clear:

  • Input: a data object that has already been normalized by data_adapter.
  • Work: forward pass, loss computation, optimizer step, and metric updates.
  • Output: a metrics dictionary for callbacks and logging.

Within that contract you have freedom to implement multiple optimizers, gradient clipping, adversarial training, or custom logging — all without touching distribution, callbacks, or dataset handling.
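Stripped of TensorFlow, that contract reads like textbook Template Method. Here is a minimal plain-Python sketch of the same split (all class and method names are hypothetical, not Keras APIs):

```python
class Trainer:
    """Base class: owns the loop, delegates per-batch semantics to train_step."""

    def train_step(self, batch):
        # Default single-batch logic; subclasses override this hook.
        x, y = batch
        loss = sum((xi - yi) ** 2 for xi, yi in zip(x, y)) / len(x)
        return {"loss": loss}

    def fit(self, batches):
        # The loop: iteration, logging, aggregation. Never overridden.
        logs = []
        for batch in batches:
            logs.append(self.train_step(batch))
        return logs


class ClippedTrainer(Trainer):
    """Customizes one step's semantics without touching the loop."""

    def train_step(self, batch):
        metrics = super().train_step(batch)
        metrics["loss"] = min(metrics["loss"], 1.0)  # e.g., clip the loss
        return metrics


logs = ClippedTrainer().fit([([1.0, 2.0], [1.0, 4.0])])
```

The subclass changes only "what a step means"; everything about "how steps run" is inherited unchanged, which is exactly the property that makes overriding train_step safe in Keras.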

The loop: how steps run at scale

make_train_function takes the pure per-batch train_step and turns it into an executable training loop that knows about distribution strategies, counters, summaries, and performance knobs like steps_per_execution:

def make_train_function(self):
  if self.train_function is not None:
    return self.train_function

  def step_function(model, iterator):
    def run_step(data):
      outputs = model.train_step(data)
      # Only increment if `train_step` succeeded.
      with ops.control_dependencies(_minimum_control_deps(outputs)):
        model._train_counter.assign_add(1)
      return outputs

    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    outputs = reduce_per_replica(outputs, self.distribute_strategy,
                                 reduction='first')
    write_scalar_summaries(outputs, step=model._train_counter)
    return outputs

  if self._steps_per_execution.numpy().item() == 1:
    def train_function(iterator):
      return step_function(self, iterator)
  else:
    def train_function(iterator):
      for _ in math_ops.range(self._steps_per_execution):
        outputs = step_function(self, iterator)
      return outputs

  if not self.run_eagerly:
    train_function = def_function.function(
        train_function, experimental_relax_shapes=True)
    self.train_tf_function = train_function

  self.train_function = train_function
  if self._cluster_coordinator:
    self.train_function = lambda it: self._cluster_coordinator.schedule(
        train_function, args=(it,))
  return self.train_function
The training loop wrapper: same train_step, different execution strategies.

Several design decisions show the value of the step/loop split:

  • Distribution-agnostic step: strategy.run(run_step, ...) executes train_step across replicas; the step itself is unaware of replica count or device type.
  • Configurable loop granularity: steps_per_execution lets you execute many steps inside one tf.function call, reducing Python overhead per batch.
  • Safe state updates: _minimum_control_deps ensures the training counter only advances if the step actually completed.
  • Caching and scheduling: the compiled train_function is cached and, when a ClusterCoordinator is present, scheduled onto workers.

The same pattern appears for evaluation and prediction: make_test_function and make_predict_function define loops that decide how to pull from iterators, how many steps to run per call, how to reduce or concatenate per-replica outputs, and whether to wrap everything in tf.function.
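The steps_per_execution trick is easy to see in miniature. A plain-Python sketch (hypothetical names, no TensorFlow): wrapping a per-batch step so each outer call runs several steps amortizes the fixed per-call cost that tf.function entry and callback dispatch would otherwise add.

```python
def make_train_function(step_fn, steps_per_execution=1):
    """Wrap a per-batch step into a loop that runs several steps per call."""
    def train_function(iterator):
        outputs = None
        for _ in range(steps_per_execution):
            outputs = step_fn(next(iterator))
        return outputs  # last step's outputs, mirroring the Keras loop
    return train_function


def step(batch):
    return {"loss": float(batch)}


it = iter(range(8))
fn = make_train_function(step, steps_per_execution=4)
first = fn(it)   # one outer call covers batches 0..3
second = fn(it)  # a second call covers batches 4..7
```

Two outer calls now cover eight batches, so the per-call overhead is paid twice instead of eight times; the step function itself never changes.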

Distribution, reduction, and guardrails

Once step and loop are separated, distribution strategies can be layered on without contaminating per-batch logic. The remaining challenge is reconciling per-replica outputs and preventing unsupported usage patterns.

From per-replica outputs to normal tensors

Under a distribution strategy, strategy.run returns PerReplica objects: one tensor per replica, wrapped in a container. The helper reduce_per_replica converts these into regular tensors:

def reduce_per_replica(values, strategy, reduction='first'):
  """Reduce PerReplica objects."""
  def _reduce(v):
    if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(v, strategy)
    if not _is_per_replica_instance(v):
      return v
    elif reduction == 'first':
      return strategy.unwrap(v)[0]
    elif reduction == 'concat':
      if _is_tpu_multi_host(strategy):
        return _tpu_multi_host_concat(v, strategy)
      else:
        return concat(strategy.unwrap(v))
    else:
      raise ValueError('`reduction` must be "first" or "concat".')

  return nest.map_structure(_reduce, values)
Reducing distributed results: take the first replica or concatenate across all.

Two reduction modes matter in practice:

  • reduction='first' takes outputs from the first replica. This is enough for scalar logs and summaries during training.
  • reduction='concat' concatenates along the batch dimension, which is necessary for prediction outputs.

The implementation hides several infrastructure quirks:

  • Multi-worker all-reduce: _multi_worker_concat uses strategy.gather and stored shapes to keep cross-worker ordering consistent.
  • TPU multi-host layout: _tpu_multi_host_concat compensates for the difference between sharding order and unwrapping order on TPUs.
  • Data types: concat() knows how to combine SparseTensor, scalars, and dense tensors safely.
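The two reduction modes are easy to mimic in plain Python. The sketch below uses a toy PerReplica stand-in, not the real tf.distribute type, and ignores the multi-worker and TPU branches:

```python
class PerReplica:
    """Toy stand-in: one value per replica, as returned by strategy.run."""
    def __init__(self, values):
        self.values = list(values)


def reduce_per_replica(value, reduction="first"):
    if not isinstance(value, PerReplica):
        return value                     # already a normal value
    if reduction == "first":
        return value.values[0]           # enough for scalar logs/summaries
    if reduction == "concat":
        out = []                         # needed for predict() outputs
        for v in value.values:
            out.extend(v)
        return out
    raise ValueError('`reduction` must be "first" or "concat".')


losses = PerReplica([0.5, 0.5])          # same scalar on each replica
preds = PerReplica([[1, 2], [3, 4]])     # a different batch shard per replica
```

For synchronized metrics every replica holds the same scalar, so taking the first is lossless; for predictions each replica holds a different shard of the batch, so only concatenation preserves all outputs.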

Guardrails against illegal combinations

The conductor also enforces guardrails to prevent confusing or undefined behavior. Two checks are particularly important:

  • _validate_compile: blocks TF1-style optimizers and enforces that model variables, metrics, and optimizer all live under the same strategy scope. This avoids subtle cross-scope bugs.
  • _disallow_inside_tf_function: prevents calling fit, evaluate, or predict inside a user-defined @tf.function.

Why fit() inside tf.function is rejected

High-level methods like fit() create and manage their own tf.function wrappers, dataset iterators, and callbacks. Nesting them inside another tf.function makes tracing and side-effects hard to reason about: retracing, callback invocation, and dataset exhaustion semantics can all become unpredictable. To avoid that, this file explicitly checks ops.inside_function() and raises with a clear error, nudging you to call the model directly inside tf.function instead.

Prediction and the memory cliff

The same conductor pattern is used for inference, but prediction exposes a scalability trade-off that many teams only discover in production: predict() is convenient but accumulates outputs in memory.

How predict() accumulates outputs

The prediction loop mirrors training at a high level: build a dataset, get a per-step function, and iterate. The key difference is how outputs are handled:

def predict(self, x, batch_size=None, ...):
  ...
  outputs = None
  with self.distribute_strategy.scope():
    data_handler = data_adapter.get_data_handler(...)
    ...
    self.predict_function = self.make_predict_function()
    self._predict_counter.assign(0)
    callbacks.on_predict_begin()
    batch_outputs = None
    for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
      with data_handler.catch_stop_iteration():
        for step in data_handler.steps():
          callbacks.on_predict_batch_begin(step)
          tmp_batch_outputs = self.predict_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
          batch_outputs = tmp_batch_outputs
          if outputs is None:
            outputs = nest.map_structure(
                lambda batch_output: [batch_output], batch_outputs)
          else:
            nest.map_structure_up_to(
                batch_outputs,
                lambda output, batch_output: output.append(batch_output),
                outputs, batch_outputs)
          callbacks.on_predict_batch_end(...)
    if batch_outputs is None:
      raise ValueError('Expect x to be a non-empty array or dataset.')
    callbacks.on_predict_end()
  all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
  return tf_utils.sync_to_numpy_or_python_type(all_outputs)
predict() collects every batch output in Python lists, then concatenates at the end.

Each batch output is appended to a Python list; only after the loop finishes does concat() run to produce final arrays. This design is ergonomic — callers get a single NumPy array — but creates a clear memory profile:

  • Memory grows roughly linearly with the dataset: on the order of N × S values, where N is the number of samples and S is the per-sample output size.
  • Training and evaluation hold only one batch (plus metrics/optimizer state), so they scale mostly with batch size; prediction scales with dataset size.
Phase              | Data retained in memory                        | Risk
fit() / evaluate() | One batch + metrics/optimizer state            | Low, scales with batch size
predict()          | All batch outputs in lists, then concatenated  | High for very large datasets
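A back-of-envelope estimate makes the cliff concrete. Assuming float32 outputs and illustrative sizes:

```python
n_samples = 10_000_000     # N: dataset size
out_dim = 128              # S: per-sample output size
bytes_per_value = 4        # float32

# predict() accumulates every sample's output before concatenating.
predict_bytes = n_samples * out_dim * bytes_per_value

# fit()/evaluate() retain roughly one batch of activations for outputs.
batch_bytes = 256 * out_dim * bytes_per_value

predict_gib = predict_bytes / 2**30  # roughly 4.77 GiB held at once
```

Ten million 128-float outputs cost about 5 GB just for the accumulated results, independent of model size, while the training loop's equivalent footprint stays in the hundreds of kilobytes.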

A pattern for streaming predictions

The file itself hints at an alternative: call model(x) directly for small inputs. For large-scale inference, the same idea becomes a pattern where you reuse the step but own the loop:

# Illustrative example: streaming prediction
for batch_x in dataset:  # e.g., a tf.data.Dataset
  batch_y = model(batch_x, training=False)
  write_batch_to_disk_or_socket(batch_y)  # your custom sink
  # Do NOT accumulate all batch_y in a list.
Drive your own loop around model(x) when you need streaming or chunked outputs.

You’re still using the same forward-pass “step,” but your loop streams results to disk, a database, or a queue instead of building a single monolithic array. One could even imagine a future streaming predict_* variant that yields batches, but the underlying idea is the same: keep the semantic step small and let callers choose their loop semantics.
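The same idea can be packaged as a generator so callers pull one batch of predictions at a time. predict_batches below is a hypothetical helper, not a Keras API; a stand-in function plays the role of the model:

```python
def predict_batches(model_fn, dataset):
    """Yield one batch of predictions at a time; never accumulate."""
    for batch_x in dataset:
        yield model_fn(batch_x)


# Example with a stand-in "model" that doubles every input.
double = lambda xs: [2 * x for x in xs]

sink = []
for batch_y in predict_batches(double, [[1, 2], [3, 4]]):
    sink.append(sum(batch_y))  # in real use: write to disk, DB, or queue
```

Because the generator holds only the current batch, peak memory is bounded by batch size no matter how large the dataset grows.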

Persistence and operational concerns

The conductor doesn’t just coordinate steps; it also controls how models are persisted and how the training loop behaves under load. Both aspects are wired through the same step/loop design.

Saving and loading: choosing the right format

Weight loading is centralized around a small helper, _detect_save_format, which decides how to interpret a filepath:

def _detect_save_format(filepath):
  filepath = path_to_string(filepath)
  if saving_utils.is_hdf5_filepath(filepath):
    return filepath, 'h5'

  if _is_readable_tf_checkpoint(filepath):
    save_format = 'tf'
  elif sm_loader.contains_saved_model(filepath):
    ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
                             sm_constants.VARIABLES_FILENAME)
    if _is_readable_tf_checkpoint(ckpt_path):
      filepath = ckpt_path
      save_format = 'tf'
    else:
      raise ValueError('Unable to load weights ...')
  else:
    save_format = 'h5'
  return filepath, save_format
Weight loading format detection: choose between HDF5 and TF checkpoint based on the path.

Two big ideas sit behind this helper:

  • HDF5 vs TensorFlow checkpoints: HDF5 uses a flat list of weights; TensorFlow checkpoints use the object graph (attributes on the model and sublayers). That’s why TF checkpoints are stricter about architecture compatibility, while HDF5 can load by name into different but compatible topologies.
  • Safety checks: certain strategies and settings are blocked from incompatible load paths, and loading HDF5 into an unbuilt subclassed model raises an explicit ValueError instead of failing later.
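The decision order is the interesting part. A simplified mimic (the suffix check stands in for saving_utils.is_hdf5_filepath; actual checkpoint and SavedModel directory probing is reduced to a boolean flag, and the function name is hypothetical):

```python
def detect_save_format(filepath, is_tf_checkpoint=False):
    """Pick a weight format from the path, HDF5 suffixes first."""
    if filepath.endswith((".h5", ".hdf5", ".keras")):
        return "h5"              # explicit HDF5-style filename wins
    if is_tf_checkpoint:
        return "tf"              # a readable TF checkpoint prefix
    return "h5"                  # fall back to HDF5, as the real helper does
```

Note the asymmetry: an HDF5-looking suffix short-circuits everything, while the TF path requires the file to actually be a readable checkpoint, which is why ambiguous paths default back to HDF5.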

Again, the conductor owns the orchestration: when to snapshot the orchestra, how to restore it, and which incompatible combinations must be rejected up front.

Throughput, overhead, and useful metrics

For each batch, most wall time is spent in your model’s forward and backward passes, but the orchestration still matters at scale:

  • Python overhead: data handler iteration, callbacks, and tf.function entry/exit add a fixed cost per step.
  • Distribution overhead: strategy.run, strategy.gather, and cross-worker concatenation add cost proportional to replica count.
  • Logging and summaries: write_scalar_summaries writes metrics each step; too-frequent logging can noticeably reduce throughput.

steps_per_execution exists precisely to amortize that overhead by looping inside the compiled function. From an operational perspective, several metrics naturally map onto this design and help you reason about performance and behavior:

  • Training step latency: time per call to train_function (median and tail percentiles) to catch regressions in either model compute or orchestration.
  • Training throughput: samples per second processed by fit(), which implicitly includes distribution and callback overhead.
  • Prediction memory usage: tracking memory while predict() runs to surface the accumulation behavior before it causes out-of-memory errors.
  • Checkpoint write time: the duration of save() or save_weights(), especially important for large models where saving can eat into epoch time.
  • Replica synchronization time: time spent in synchronization primitives like reduce_per_replica and strategy.gather, to see whether scaling out actually helps.
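The first two metrics fall straight out of raw per-call timings. A sketch with illustrative numbers:

```python
import statistics

# Wall-clock duration of each train_function call, in seconds (illustrative).
step_latencies_s = [0.045, 0.047, 0.046, 0.120, 0.046]
batch_size = 256
steps_per_execution = 1

median_latency = statistics.median(step_latencies_s)
tail_latency = max(step_latencies_s)  # the tail catches stragglers/retracing

# Samples processed per second at the median step latency.
throughput = batch_size * steps_per_execution / median_latency
```

Tracking median and tail separately matters: the 0.120 s outlier above barely moves the median but is exactly the kind of spike that signals retracing or a slow worker.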

Practical design takeaways

Viewed as a whole, training.py is a case study in using a training conductor to separate semantics from orchestration. That pattern is applicable far beyond Keras.

1. Keep “what a step means” separate from “how steps run”

  • Define a small, override-friendly step method that expresses a single unit of work (one training batch, one job, one request).
  • Keep retries, distribution, counters, logging, and tf.function in a separate orchestration layer that calls that step.
  • Avoid mixing loops and business logic if you care about testability and extensibility.
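Outside ML the split looks the same. A sketch where a retry policy is the orchestration layer and the step stays pure (all names hypothetical):

```python
def run_with_retries(step_fn, item, max_attempts=3):
    """Orchestration: the retry policy lives here, never inside the step."""
    for attempt in range(1, max_attempts + 1):
        try:
            return step_fn(item)
        except RuntimeError:
            if attempt == max_attempts:
                raise  # exhausted: surface the failure to the caller


attempts = []

def flaky_step(item):
    """A step that fails transiently twice, then succeeds."""
    attempts.append(item)
    if len(attempts) < 3:
        raise RuntimeError("transient failure")
    return item * 2


result = run_with_retries(flaky_step, 21)
```

The step can be unit-tested in isolation, and the retry policy can be swapped (backoff, jitter, circuit breaking) without the step ever knowing.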

2. Isolate infrastructure quirks behind narrow helpers

  • Multi-worker ordering rules, TPU host layouts, and other platform details all live behind helpers like reduce_per_replica and _tpu_multi_host_concat.
  • Do the same in your systems: when you must handle platform-specific weirdness, hide it behind a tiny API with a clear contract.

3. Fail fast on unsupported combinations

  • Checks like _disallow_inside_tf_function and _validate_compile reject invalid usage with explicit errors instead of allowing subtle bugs.
  • Be explicit about which combinations your APIs support, and enforce those constraints at the conductor level.

4. Design clear extension points

  • train_step, test_step, predict_step, and the make_*_function family are documented as extension points, while lower-level helpers are kept internal.
  • In your own code, mark which methods are meant to be overridden and keep orchestration logic reusable across those customizations.

5. Offer streaming alternatives to “return everything” APIs

  • predict() is convenient but accumulates outputs in memory; the design naturally suggests a streaming alternative where callers own the loop.
  • Whenever you design a “give me all the results” API, consider also providing a batched or streaming variant that reuses the same step semantics.

The primary lesson from Keras’ training engine is straightforward: treat your core model as a training conductor. Keep step logic small and semantic, put orchestration in its own layer, fence off invalid combinations, and expose clear extension points. Once you adopt that pattern, complexity from distribution, persistence, and scaling has a place to live that doesn’t pollute the heart of your model logic.


Full Source Code

Here's the full source code of the file that inspired this article.
Read on GitHub

Thanks for reading! I hope this was useful. If you have questions or thoughts, feel free to reach out.

Content Creation Process: This article was generated via a semi-automated workflow using AI tools. I prepared the strategic framework, including specific prompts and data sources. From there, the automation system conducted the research, analysis, and writing. The content passed through automated verification steps before being finalized and published without manual intervention.

Mahmoud Zalt

About the Author

I’m Zalt, a technologist with 16+ years of experience, passionate about designing and building AI systems that move us closer to a world where machines handle everything and humans reclaim wonder.

Let's connect if you're working on interesting AI projects, looking for technical advice or want to discuss anything.
