We're examining how Keras orchestrates model training behind the deceptively simple model.fit() API. Keras is the high-level deep learning interface built on top of TensorFlow, and its Model class is where training, evaluation, prediction, and checkpointing all come together. I'm Mahmoud Zalt, an AI software engineer, and we'll look at how this class acts as a training conductor — cleanly separating “what a single step does” from “how steps run at scale across devices, workers, and APIs.” That split is the core lesson, and we’ll see how it shapes extensibility, distribution, memory behavior, and operational concerns.
Model as a training conductor
The Keras training engine lives in keras/engine/training.py. This file hosts the Model class, which intentionally centralizes almost everything related to training and I/O.
keras/
  engine/
    base_layer.py
    training.py        <-- tf.keras.Model training & I/O
    compile_utils.py
    data_adapter.py
    training_utils.py
Model.compile()
      |
      v
[LossesContainer, MetricsContainer, optimizer]
      |
      v
Model.fit()
      |
      +--> data_adapter.get_data_handler()
      |       (build Dataset / iterator)
      +--> Model.make_train_function()
             |
             +--> train_function(iterator)
                    |
                    +--> strategy.run(run_step)
                           |
                           +--> Model.train_step(data)
The Model class is a deliberate “god object” for training concerns. It owns:
- Configuration: `compile()` wires optimizer, losses, metrics, and execution knobs like `run_eagerly` and `steps_per_execution`.
- Loops: `fit()`, `evaluate()`, and `predict()` drive data handlers, callbacks, and distributed execution.
- Steps: `train_step()`, `test_step()`, and `predict_step()` define the per-batch math and are meant to be overridden.
- Weights & persistence: `save()`, `save_weights()`, and `load_weights()` coordinate SavedModel, checkpoints, and HDF5.
Think of Model as an orchestra conductor: it doesn’t implement the math of individual layers or optimizers, but it decides when and how everything plays together during training and inference.
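To make the conductor's surface concrete, here is a minimal, self-contained sketch of those responsibilities in action. The toy architecture and random data are our own illustration, not anything from training.py:

```python
import numpy as np
import tensorflow as tf

# compile() wires the configuration; fit() and predict() run the loops.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.rand(16, 3).astype("float32")
y = np.random.rand(16, 1).astype("float32")

history = model.fit(x, y, epochs=1, batch_size=8, verbose=0)  # the training loop
preds = model.predict(x, verbose=0)                           # the inference loop
```

Everything else in this article happens behind those three calls.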
The core pattern: step vs. loop
The dominant design move in this file is the strict separation between a step (one batch’s worth of work) and a loop (how those steps are executed across time and hardware). Nearly every advanced feature — custom training, distribution, and performance tuning — hangs off this split.
The step: one batch of semantics
The default train_step implementation is intentionally small and readable:
def train_step(self, data):
  """The logic for one training step."""
  # Normalize data structure.
  data = data_adapter.expand_1d(data)
  x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
  # Forward pass.
  with backprop.GradientTape() as tape:
    y_pred = self(x, training=True)
    loss = self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses)
  # Backward pass.
  self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
  self.compiled_metrics.update_state(y, y_pred, sample_weight)
  # Package metrics.
  return_metrics = {}
  for metric in self.metrics:
    result = metric.result()
    if isinstance(result, dict):
      return_metrics.update(result)
    else:
      return_metrics[metric.name] = result
  return return_metrics
train_step: a single-batch contract, easy to override.

This is a textbook Template Method pattern: the base class defines the high-level algorithm, and subclasses override selected steps. Here, the contract is clear:
- Input: a `data` object that has already been normalized by `data_adapter`.
- Work: forward pass, loss computation, optimizer step, and metric updates.
- Output: a metrics dictionary for callbacks and logging.
Within that contract you have freedom to implement multiple optimizers, gradient clipping, adversarial training, or custom logging — all without touching distribution, callbacks, or dataset handling.
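As a sketch of that freedom, here is a hypothetical subclass that keeps the same single-batch contract but adds global-norm gradient clipping. The class name and the simplified data handling (plain (x, y) batches, loss-only logging) are our own; the override relies on Model.compute_loss, which recent Keras versions expose for exactly this purpose:

```python
import numpy as np
import tensorflow as tf

class ClippedModel(tf.keras.Model):
    """Illustrative override: same contract as the default train_step,
    plus global-norm gradient clipping. Not code from training.py."""

    def train_step(self, data):
        x, y = data  # assumes (x, y) batches; sample_weight omitted for brevity
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, 1.0)  # the custom part
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}  # metrics dict, consumed by callbacks/logging

# Usage: distribution, callbacks, and dataset handling are untouched.
inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = ClippedModel(inputs, outputs)
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(16, 3).astype("float32")
y = np.random.rand(16, 1).astype("float32")
history = model.fit(x, y, epochs=1, batch_size=8, verbose=0)
```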
The loop: how steps run at scale
make_train_function takes the pure per-batch train_step and turns it into an executable training loop that knows about distribution strategies, counters, summaries, and performance knobs like steps_per_execution:
def make_train_function(self):
  if self.train_function is not None:
    return self.train_function

  def step_function(model, iterator):
    def run_step(data):
      outputs = model.train_step(data)
      # Only increment if `train_step` succeeded.
      with ops.control_dependencies(_minimum_control_deps(outputs)):
        model._train_counter.assign_add(1)
      return outputs

    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    outputs = reduce_per_replica(outputs, self.distribute_strategy,
                                 reduction='first')
    write_scalar_summaries(outputs, step=model._train_counter)
    return outputs

  if self._steps_per_execution.numpy().item() == 1:
    def train_function(iterator):
      return step_function(self, iterator)
  else:
    def train_function(iterator):
      for _ in math_ops.range(self._steps_per_execution):
        outputs = step_function(self, iterator)
      return outputs

  if not self.run_eagerly:
    train_function = def_function.function(
        train_function, experimental_relax_shapes=True)
    self.train_tf_function = train_function

  self.train_function = train_function

  if self._cluster_coordinator:
    self.train_function = lambda it: self._cluster_coordinator.schedule(
        train_function, args=(it,))

  return self.train_function
One train_step, different execution strategies.

Several design decisions show the value of the step/loop split:

- Distribution-agnostic step: `strategy.run(run_step, ...)` executes `train_step` across replicas; the step itself is unaware of replica count or device type.
- Configurable loop granularity: `steps_per_execution` lets you execute many steps inside one `tf.function` call, reducing Python overhead per batch.
- Safe state updates: `_minimum_control_deps` ensures the training counter only advances if the step actually completed.
- Caching and scheduling: the compiled `train_function` is cached and, when a `ClusterCoordinator` is present, scheduled onto workers.
The same pattern appears for evaluation and prediction: make_test_function and make_predict_function define loops that decide how to pull from iterators, how many steps to run per call, how to reduce or concatenate per-replica outputs, and whether to wrap everything in tf.function.
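Notably, the loop knobs are all set in one place at compile time. A small illustrative configuration (the toy model and data are assumptions): with steps_per_execution=4 and eight batches per epoch, the compiled function is entered only twice per epoch instead of eight times.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)),
                             tf.keras.layers.Dense(1)])

# steps_per_execution=4 runs four batches inside each compiled call,
# amortizing the per-step Python and tf.function entry overhead.
# (run_eagerly=True would instead disable compilation for debugging.)
model.compile(optimizer="sgd", loss="mse", steps_per_execution=4)

x = np.random.rand(32, 2).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = model.fit(x, y, epochs=1, batch_size=4, verbose=0)  # 8 steps total
```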
Distribution, reduction, and guardrails
Once step and loop are separated, distribution strategies can be layered on without contaminating per-batch logic. The remaining challenge is reconciling per-replica outputs and preventing unsupported usage patterns.
From per-replica outputs to normal tensors
Under a distribution strategy, strategy.run returns PerReplica objects: one tensor per replica, wrapped in a container. The helper reduce_per_replica converts these into regular tensors:
def reduce_per_replica(values, strategy, reduction='first'):
  """Reduce PerReplica objects."""

  def _reduce(v):
    if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(v, strategy)
    if not _is_per_replica_instance(v):
      return v
    elif reduction == 'first':
      return strategy.unwrap(v)[0]
    elif reduction == 'concat':
      if _is_tpu_multi_host(strategy):
        return _tpu_multi_host_concat(v, strategy)
      else:
        return concat(strategy.unwrap(v))
    else:
      raise ValueError('`reduction` must be "first" or "concat".')

  return nest.map_structure(_reduce, values)
Two reduction modes matter in practice:
- `reduction='first'` takes outputs from the first replica. This is enough for scalar logs and summaries during training.
- `reduction='concat'` concatenates along the batch dimension, which is necessary for prediction outputs.
The implementation hides several infrastructure quirks:
- Multi-worker all-reduce: `_multi_worker_concat` uses `strategy.gather` and stored shapes to keep cross-worker ordering consistent.
- TPU multi-host layout: `_tpu_multi_host_concat` compensates for the difference between sharding order and unwrapping order on TPUs.
- Data types: `concat()` knows how to combine `SparseTensor`, scalars, and dense tensors safely.
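The two modes can be sketched in plain Python, with a list of per-replica NumPy arrays standing in for the real PerReplica containers. The helper name below is hypothetical, not the Keras function:

```python
import numpy as np

def reduce_per_replica_sketch(per_replica, reduction="first"):
    """Plain-Python stand-in for the reduction modes:
    'first' keeps replica 0's value (enough for scalar logs);
    'concat' stitches replica outputs along the batch axis
    (needed for prediction outputs)."""
    if reduction == "first":
        return per_replica[0]
    if reduction == "concat":
        return np.concatenate(per_replica, axis=0)
    raise ValueError('`reduction` must be "first" or "concat".')

# Two replicas each produced a half-batch of predictions:
replica_outputs = [np.zeros((4, 1)), np.ones((4, 1))]
loss_like = reduce_per_replica_sketch([0.5, 0.5], reduction="first")
preds = reduce_per_replica_sketch(replica_outputs, reduction="concat")
```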
Guardrails against illegal combinations
The conductor also enforces guardrails to prevent confusing or undefined behavior. Two checks are particularly important:
- `_validate_compile` blocks TF1-style optimizers and enforces that model variables, metrics, and optimizer all live under the same strategy scope. This avoids subtle cross-scope bugs.
- `_disallow_inside_tf_function` prevents calling `fit`, `evaluate`, or `predict` inside a user-defined `@tf.function`.
Why fit() inside tf.function is rejected
High-level methods like fit() create and manage their own tf.function wrappers, dataset iterators, and callbacks. Nesting them inside another tf.function makes tracing and side-effects hard to reason about: retracing, callback invocation, and dataset exhaustion semantics can all become unpredictable. To avoid that, this file explicitly checks ops.inside_function() and raises with a clear error, nudging you to call the model directly inside tf.function instead.
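A brief sketch of the sanctioned alternative (the toy model and function name are our own): wrap the model call itself in tf.function, not the training loop around it.

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)),
                             tf.keras.layers.Dense(1)])

# Calling model.fit()/predict() in here would raise; a direct model
# call is a plain forward pass and traces cleanly.
@tf.function
def serve(batch):
    return model(batch, training=False)

out = serve(tf.zeros((2, 3)))  # one traced forward pass
```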
Prediction and the memory cliff
The same conductor pattern is used for inference, but prediction exposes a scalability trade-off that many teams only discover in production: predict() is convenient but accumulates outputs in memory.
How predict() accumulates outputs
The prediction loop mirrors training at a high level: build a dataset, get a per-step function, and iterate. The key difference is how outputs are handled:
def predict(self, x, batch_size=None, ...):
  ...
  outputs = None
  with self.distribute_strategy.scope():
    data_handler = data_adapter.get_data_handler(...)
    ...
    self.predict_function = self.make_predict_function()
    self._predict_counter.assign(0)
    callbacks.on_predict_begin()
    batch_outputs = None
    for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
      with data_handler.catch_stop_iteration():
        for step in data_handler.steps():
          callbacks.on_predict_batch_begin(step)
          tmp_batch_outputs = self.predict_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
          batch_outputs = tmp_batch_outputs
          if outputs is None:
            outputs = nest.map_structure(
                lambda batch_output: [batch_output], batch_outputs)
          else:
            nest.map_structure_up_to(
                batch_outputs,
                lambda output, batch_output: output.append(batch_output),
                outputs, batch_outputs)
          callbacks.on_predict_batch_end(...)
  if batch_outputs is None:
    raise ValueError('Expect x to be a non-empty array or dataset.')
  callbacks.on_predict_end()
  all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
  return tf_utils.sync_to_numpy_or_python_type(all_outputs)
predict() collects every batch output in Python lists, then concatenates at the end.

Each batch output is appended to a Python list; only after the loop finishes does concat() run to produce the final arrays. This design is ergonomic — callers get a single NumPy array — but it creates a clear memory profile:
- Memory grows roughly as O(N * S), where N is the number of samples and S is the per-sample output size.
- Training and evaluation hold only one batch (plus metrics/optimizer state), so they scale mostly with batch size; prediction scales with dataset size.
| Phase | Data retained in memory | Risk |
|---|---|---|
| `fit()` / `evaluate()` | One batch + metrics/optimizer state | Low; scales with batch size |
| `predict()` | All batch outputs in lists, then concatenated | High for very large datasets |
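A back-of-envelope calculation makes the risk concrete. The workload numbers below are assumptions, chosen only to show how retained output memory grows with dataset size rather than batch size:

```python
# Assumed workload: 10M samples, each producing a 1000-class
# float32 probability vector, all retained until the final concat.
num_samples = 10_000_000
per_sample_floats = 1_000
bytes_per_float32 = 4

total_bytes = num_samples * per_sample_floats * bytes_per_float32
total_gib = total_bytes / 2**30  # roughly 37 GiB held before concatenation
```

The same model served with a streaming loop would hold only one batch of outputs at a time.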
A pattern for streaming predictions
The file itself hints at an alternative: call model(x) directly for small inputs. For large-scale inference, the same idea becomes a pattern where you reuse the step but own the loop:
# Illustrative example: streaming prediction
for batch_x in dataset:  # e.g., a tf.data.Dataset
    batch_y = model(batch_x, training=False)
    write_batch_to_disk_or_socket(batch_y)  # your custom sink
    # Do NOT accumulate all batch_y in a list.
Call model(x) directly when you need streaming or chunked outputs.

You're still using the same forward-pass "step," but your loop streams results to disk, a database, or a queue instead of building a single monolithic array. One can even imagine a future streaming predict variant that yields batches; the underlying idea is the same either way: keep the semantic step small and let callers choose their loop semantics.
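That idea packages naturally as a small generator — essentially the streaming variant sketched above. The function name and toy model here are our own:

```python
import numpy as np
import tensorflow as tf

def stream_predictions(model, batches):
    """Hypothetical streaming predict: reuse the forward-pass step,
    but yield one batch of outputs at a time instead of accumulating."""
    for batch_x in batches:
        yield model(batch_x, training=False).numpy()

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)),
                             tf.keras.layers.Dense(2)])
dataset = tf.data.Dataset.from_tensor_slices(
    np.random.rand(10, 3).astype("float32")).batch(4)

# Each yielded chunk can go straight to disk or a queue;
# nothing is retained between iterations.
shapes = [out.shape for out in stream_predictions(model, dataset)]
```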
Persistence and operational concerns
The conductor doesn’t just coordinate steps; it also controls how models are persisted and how the training loop behaves under load. Both aspects are wired through the same step/loop design.
Saving and loading: choosing the right format
Weight loading is centralized around a small helper, _detect_save_format, which decides how to interpret a filepath:
def _detect_save_format(filepath):
  filepath = path_to_string(filepath)
  if saving_utils.is_hdf5_filepath(filepath):
    return filepath, 'h5'
  if _is_readable_tf_checkpoint(filepath):
    save_format = 'tf'
  elif sm_loader.contains_saved_model(filepath):
    ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
                             sm_constants.VARIABLES_FILENAME)
    if _is_readable_tf_checkpoint(ckpt_path):
      filepath = ckpt_path
      save_format = 'tf'
    else:
      raise ValueError('Unable to load weights ...')
  else:
    save_format = 'h5'
  return filepath, save_format
Two big ideas sit behind this helper:
- HDF5 vs TensorFlow checkpoints: HDF5 uses a flat list of weights; TensorFlow checkpoints use the object graph (attributes on the model and sublayers). That’s why TF checkpoints are stricter about architecture compatibility, while HDF5 can load by name into different but compatible topologies.
- Safety checks: certain strategies and settings are blocked from incompatible load paths, and loading HDF5 into an unbuilt subclassed model raises an explicit `ValueError` instead of failing later.
Again, the conductor owns the orchestration: when to snapshot the orchestra, how to restore it, and which incompatible combinations must be rejected up front.
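From the caller's side, the format detection is driven entirely by the filepath. A minimal round trip (the toy architecture and file name are assumptions; the .weights.h5 suffix keeps newer Keras versions happy):

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

def build():
    # Same topology both times; only the weight values travel through disk.
    return tf.keras.Sequential([tf.keras.Input(shape=(3,)),
                                tf.keras.layers.Dense(1)])

model = build()
path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
model.save_weights(path)      # .h5 suffix selects the HDF5 weight format

restored = build()            # fresh random weights
restored.load_weights(path)   # format re-detected from the filepath
```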
Throughput, overhead, and useful metrics
For each batch, most wall time is spent in your model’s forward and backward passes, but the orchestration still matters at scale:
- Python overhead: data handler iteration, callbacks, and `tf.function` entry/exit add a fixed cost per step.
- Distribution overhead: `strategy.run`, `strategy.gather`, and cross-worker concatenation add cost proportional to replica count.
- Logging and summaries: `write_scalar_summaries` writes metrics each step; too-frequent logging can noticeably reduce throughput.
steps_per_execution exists precisely to amortize that overhead by looping inside the compiled function. From an operational perspective, several metrics naturally map onto this design and help you reason about performance and behavior:
- Training step latency: time per call to `train_function` (median and tail percentiles) to catch regressions in either model compute or orchestration.
- Training throughput: samples per second processed by `fit()`, which implicitly includes distribution and callback overhead.
- Prediction memory usage: tracking memory while `predict()` runs to surface the accumulation behavior before it causes out-of-memory errors.
- Checkpoint write time: the duration of `save()` or `save_weights()`, especially important for large models where saving can eat into epoch time.
- Replica synchronization time: time spent in synchronization primitives like `reduce_per_replica` and `strategy.gather`, to see whether scaling out actually helps.
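The first of those metrics is easy to collect with an ordinary callback, since callbacks run on the same loop boundaries as everything else. A sketch (class name and toy training job are our own):

```python
import time
import numpy as np
import tensorflow as tf

class StepTimer(tf.keras.callbacks.Callback):
    """Illustrative callback: record per-batch wall time so median
    and tail step latency can be tracked over time."""

    def on_train_begin(self, logs=None):
        self.batch_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self._t0 = time.perf_counter()

    def on_train_batch_end(self, batch, logs=None):
        self.batch_times.append(time.perf_counter() - self._t0)

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)),
                             tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

timer = StepTimer()
x = np.random.rand(16, 2).astype("float32")
y = np.random.rand(16, 1).astype("float32")
model.fit(x, y, epochs=1, batch_size=4, verbose=0, callbacks=[timer])
# timer.batch_times now holds one latency sample per training step.
```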
Practical design takeaways
Viewed as a whole, training.py is a case study in using a training conductor to separate semantics from orchestration. That pattern is applicable far beyond Keras.
1. Keep “what a step means” separate from “how steps run”
- Define a small, override-friendly step method that expresses a single unit of work (one training batch, one job, one request).
- Keep retries, distribution, counters, logging, and `tf.function` in a separate orchestration layer that calls that step.
- Avoid mixing loops and business logic if you care about testability and extensibility.
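Outside of Keras, the same split can be expressed in a few lines. A deliberately generic sketch, with all names our own: the step knows only its semantics, while the loop owns retries and bookkeeping.

```python
def process_step(item):
    """One unit of work: semantics only — no retries, logging,
    or distribution concerns."""
    return item * 2

def run_loop(items, step, max_retries=2):
    """The orchestration layer: iteration, retries, and counters
    live here, never inside the step."""
    results, steps_run = [], 0
    for item in items:
        for attempt in range(max_retries + 1):
            try:
                results.append(step(item))
                steps_run += 1
                break
            except Exception:
                if attempt == max_retries:
                    raise  # surfaced only after retries are exhausted

    return results, steps_run

results, steps_run = run_loop([1, 2, 3], process_step)
```

Swapping in a different step (or a different loop) requires no changes to the other half — the same property make_train_function gives train_step.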
2. Isolate infrastructure quirks behind narrow helpers
- Multi-worker ordering rules, TPU host layouts, and other platform details all live behind helpers like `reduce_per_replica` and `_tpu_multi_host_concat`.
- Do the same in your systems: when you must handle platform-specific weirdness, hide it behind a tiny API with a clear contract.
3. Fail fast on unsupported combinations
- Checks like `_disallow_inside_tf_function` and `_validate_compile` reject invalid usage with explicit errors instead of allowing subtle bugs.
- Be explicit about which combinations your APIs support, and enforce those constraints at the conductor level.
4. Design clear extension points
- `train_step`, `test_step`, `predict_step`, and the `make_*_function` family are documented as extension points, while lower-level helpers are kept internal.
- In your own code, mark which methods are meant to be overridden and keep orchestration logic reusable across those customizations.
5. Offer streaming alternatives to “return everything” APIs
- `predict()` is convenient but accumulates outputs in memory; the design naturally suggests a streaming alternative where callers own the loop.
- Whenever you design a "give me all the results" API, consider also providing a batched or streaming variant that reuses the same step semantics.
The primary lesson from Keras’ training engine is straightforward: treat your core model as a training conductor. Keep step logic small and semantic, put orchestration in its own layer, fence off invalid combinations, and expose clear extension points. Once you adopt that pattern, complexity from distribution, persistence, and scaling has a place to live that doesn’t pollute the heart of your model logic.