When Your Trainer Becomes an Orchestrator

By Mahmoud Zalt

When does a simple ML training loop stop being “just training” and start acting like an orchestrator for your whole system? This post digs into that shift.

Most of us start with a tiny training loop: a for loop over a DataLoader, a loss, an optimizer.step(), and we ship it. Then reality shows up with multi-GPU runs, out-of-memory errors, NaNs, resume logic, and time-limited jobs. Suddenly that cute loop wants to be an entire system.
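
For reference, that starting point looks roughly like the sketch below, with placeholder names (model, train_loader, criterion) rather than code from any particular project:

import torch


def train(model, train_loader, criterion, epochs=10, lr=1e-3):
    # The "cute loop": iterate, compute a loss, step the optimizer.
    # No DDP, no resume logic, no OOM or NaN handling, no time budget.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        for images, targets in train_loader:
            preds = model(images)
            loss = criterion(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()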

We're examining how Ultralytics' BaseTrainer turns that simple loop into a robust training orchestrator. Ultralytics is the engine behind the YOLO family of vision models, where training has to work reliably across tasks, hardware setups, and production constraints. At the center of that engine is BaseTrainer, the class that owns the full training lifecycle.

I'm Mahmoud Zalt, an AI solutions architect. We’ll walk through how this trainer coordinates models, data, distributed runtimes, optimizers, and recovery logic, and how you can structure your own trainer to act as an orchestrator instead of a fragile loop.

Trainer as Orchestrator, Not Just a Loop

BaseTrainer is not a monolithic training script; it's an orchestration layer. It coordinates models, datasets, distributed training, optimizers, schedulers, EMA, and error recovery. The model, optimizer, and dataloader each know how to "play"; the trainer decides when and how they play together.

Architecturally, it follows the Template Method pattern: a base class defines the lifecycle, and subclasses fill in task-specific details. BaseTrainer owns the overall algorithm, while detection, segmentation, or classification trainers override hooks like get_model(), get_dataloader(), and preprocess_batch().
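
As a minimal illustration of the pattern (the hook names match the ones above, but the bodies are placeholders, not the actual Ultralytics implementation):

class BaseTrainer:
    # The base class owns the lifecycle...
    def train(self):
        model = self.get_model()
        loader = self.get_dataloader()
        for batch in loader:
            batch = self.preprocess_batch(batch)
            preds = model(batch)
            # ...loss computation, optimizer step, validation, checkpointing

    # ...while task-specific trainers fill in the hooks.
    def get_model(self):
        raise NotImplementedError

    def get_dataloader(self):
        raise NotImplementedError

    def preprocess_batch(self, batch):
        return batch


class DetectionTrainer(BaseTrainer):
    def get_model(self):
        ...  # build a detection model here

    def get_dataloader(self):
        ...  # build the detection dataloader here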

ultralytics/
  engine/
    trainer.py   <-- BaseTrainer (orchestration layer)
  data/
    utils.py     (dataset checks)
  nn/
    tasks.py     (load_checkpoint, model creation)
  optim/
    __init__.py  (MuSGD)
  utils/
    cfg.py       (get_cfg, get_save_dir)
    dist.py      (ddp_cleanup, generate_ddp_command)
    torch_utils.py (ModelEMA, attempt_compile, EarlyStopping, unwrap_model)
    plotting.py  (plot_results)
The trainer sits in the engine and delegates work to lower-level modules.

Wiring the Training World Together

The orchestration becomes clear when we follow the main call graph. All public callers go through train(), which either spawns DDP processes or runs the core routine _do_train().

BaseTrainer.train()
  ├─ if ddp: generate_ddp_command() → subprocess.run() → ddp_cleanup()
  └─ else: _do_train()
       ├─ _setup_ddp()           # multi-GPU
       ├─ _setup_train()
       │    ├─ setup_model() → get_model()
       │    ├─ attempt_compile()
       │    ├─ _build_train_pipeline()
       │    │    ├─ get_dataloader()
       │    │    └─ build_optimizer()
       │    ├─ get_validator()
       │    └─ resume_training()
       ├─ per-epoch loop
       │    ├─ scheduler.step()
       │    ├─ _model_train()
       │    ├─ per-batch loop
       │    │    ├─ preprocess_batch()
       │    │    ├─ model(...) / unwrap_model(model).loss(...)
       │    │    └─ optimizer_step()
       │    ├─ validate()
       │    ├─ _handle_nan_recovery()
       │    └─ save_model()
       └─ final_eval()
One public train(), many coordinated subsystems behind it.

Inside _setup_train(), the trainer normalizes configuration with get_cfg(), sets up devices and distributed training, builds or loads the model via setup_model(), and wraps it with EMA, AMP, and optional compilation. Then it builds the data and optimization pipeline.

The pipeline builder shows the orchestration style well:

def _build_train_pipeline(self):
    batch_size = self.batch_size // max(self.world_size, 1)  # per-GPU batch size under DDP

    self.train_loader = self.get_dataloader(
        self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
    )

    self.test_loader = self.get_dataloader(
        self.data.get("val") or self.data.get("test"),
        batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
        rank=LOCAL_RANK,
        mode="val",
    )

    # gradient-accumulation steps needed to approximate the nominal batch size (nbs)
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)
    # weight decay is scaled with the effective (accumulated) batch size
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs

    # rough estimate of total optimizer updates, passed to build_optimizer()
    iterations = math.ceil(
        len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)
    ) * self.epochs

    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )

    self._setup_scheduler()

Rather than burying decisions inside the model or dataset, the trainer glues them together using a few derived quantities: effective batch size, gradient accumulation, scaled weight decay, and a rough iteration budget. That makes the same orchestration logic reusable across very different tasks.
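
To make those quantities concrete, here is the same arithmetic with made-up numbers (nbs=64, a requested batch of 24, base weight decay 0.0005, a 12,800-image dataset, 100 epochs):

import math

# Hypothetical numbers, just to illustrate the arithmetic above.
nbs, batch_size, base_wd = 64, 24, 0.0005
dataset_len, epochs = 12_800, 100

accumulate = max(round(nbs / batch_size), 1)            # 3 micro-batches per optimizer step
weight_decay = base_wd * batch_size * accumulate / nbs  # 0.0005 * 24 * 3 / 64 ≈ 0.00056
iterations = math.ceil(dataset_len / max(batch_size, nbs)) * epochs  # 200 * 100 = 20000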

Resilience Built into the Loop

Once the wiring is solid, the next step is keeping long-running jobs alive under real-world failures: OOMs, NaNs, and wall-clock limits. This is where BaseTrainer stops being a control loop and becomes an operational system.

Automatic OOM Recovery by Tuning Batch Size

Out-of-memory errors on the first epoch are common when probing new models or hardware. Here, OOM is treated as a configuration problem (batch too big), not a fatal runtime error. The trainer shrinks the batch size and rebuilds the pipeline.

for i, batch in pbar:
    try:
        with autocast(self.amp):
            batch = self.preprocess_batch(batch)
            if self.args.compile:
                preds = self.model(batch["img"])
                loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
            else:
                loss, self.loss_items = self.model(batch)
            self.loss = loss.sum()
            if RANK != -1:
                self.loss *= self.world_size
            self.tloss = (
                self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)
            )

        self.scaler.scale(self.loss).backward()

    except torch.cuda.OutOfMemoryError:
        if epoch > self.start_epoch or self._oom_retries >= 3 or RANK != -1:
            raise
        self._oom_retries += 1
        old_batch = self.batch_size
        self.args.batch = self.batch_size = max(self.batch_size // 2, 1)
        LOGGER.warning(
            f"CUDA out of memory with batch={old_batch}. "
            f"Reducing to batch={self.batch_size} and retrying ({self._oom_retries}/3)."
        )
        self._clear_memory()
        self._build_train_pipeline()
        self.scheduler.last_epoch = self.start_epoch - 1
        self.optimizer.zero_grad()
        break

The policy is simple:

  • Only first-epoch OOMs on single GPU are auto-handled; others are raised immediately.
  • Batch size is halved on each retry (down to 1), with at most three retries.
  • The trainer clears memory, rebuilds the pipeline, and restarts the epoch with a consistent scheduler state.

NaN Recovery as a First-Class Feature

Numerical problems are subtler than OOMs. A NaN can signal unstable loss, broken data, or a bug in augmentation. Here, the trainer again prefers resilience, but with stricter safeguards and clear failure modes.

def _handle_nan_recovery(self, epoch):
    loss_nan = self.loss is not None and not self.loss.isfinite()
    fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
    fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0

    corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
    reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"

    if RANK != -1:  # DDP: broadcast decision
        broadcast_list = [corrupted if RANK == 0 else None]
        dist.broadcast_object_list(broadcast_list, 0)
        corrupted = broadcast_list[0]

    if not corrupted:
        return False

    if epoch == self.start_epoch or not self.last.exists():
        LOGGER.warning(f"{reason} detected but can not recover from last.pt...")
        return False

    self.nan_recovery_attempts += 1
    if self.nan_recovery_attempts > 3:
        raise RuntimeError(
            f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs"
        )

    LOGGER.warning(
        f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt..."
    )

    self._model_train()
    _, ckpt = load_checkpoint(self.last)
    ema_state = ckpt["ema"].float().state_dict()
    if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
        raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")

    unwrap_model(self.model).load_state_dict(ema_state)
    self._load_checkpoint_state(ckpt)
    self.scheduler.last_epoch = epoch - 1
    return True

Design decisions embedded here:

  • NaNs are detected both on raw loss and on derived fitness, catching both direct and indirect instability.
  • In DDP, rank 0 decides whether the run is corrupted and broadcasts that decision, so all workers stay in sync.
  • The last checkpoint is treated as the "known good" state, but it's validated for finite weights before reuse.
  • Recovery is limited to three attempts; beyond that, the trainer fails loudly with a clear exception.

Time-Based Stopping

Many production runs are constrained by wall-clock time, not epochs. BaseTrainer supports a time budget (in hours) and monitors progress inside the loop. With args.time set, it estimates epoch duration from observed timings, adjusts self.epochs and the scheduler to fit within the remaining budget, and checks for budget exhaustion on optimizer steps and at epoch boundaries.

The effect is that jobs end gracefully within their time window: you still get validation, checkpoints, and consistent scheduler state, instead of an abrupt kill from the outside.
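
A simplified sketch of that idea is shown below; the function and variable names (fit_epochs_to_budget, budget_hours) are illustrative, not the actual Ultralytics attributes:

import math
import time


def fit_epochs_to_budget(start_time, epochs_done, planned_epochs, budget_hours):
    # Illustrative sketch of time-budgeted training, not the Ultralytics code.
    elapsed_hours = (time.time() - start_time) / 3600
    mean_epoch_hours = max(elapsed_hours / max(epochs_done, 1), 1e-9)  # observed cost per epoch
    remaining_hours = budget_hours - elapsed_hours
    if remaining_hours <= 0:
        return epochs_done  # budget exhausted: stop after the current epoch
    # Shrink (never grow) the planned epoch count to what still fits the budget;
    # the scheduler is then re-aligned to the new horizon.
    feasible = epochs_done + math.floor(remaining_hours / mean_epoch_hours)
    return min(planned_epochs, feasible)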

Smart Optimizer and Config Choices

The trainer also encodes operational experience into its defaults. Instead of asking users to specify every hyperparameter, it uses simple heuristics to choose reasonable optimizers and schedules from the training budget and dataset.

Auto-Choosing an Optimizer from Iteration Budget

The build_optimizer() method supports explicit choices, but optimizer="auto" delegates the decision to the trainer. It looks at the expected number of iterations and picks between AdamW and a custom MuSGD variant.

def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9,
                   decay=1e-5, iterations=1e5):
    g = [{}, {}, {}, {}]  # parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)

    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = self.data.get("nc", 10)
        lr_fit = round(0.002 * 5 / (4 + nc), 6)
        name, lr, momentum = ("MuSGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0

    use_muon = name == "MuSGD"

    for module_name, module in unwrap_model(model).named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if param.ndim >= 2 and use_muon:
                g[3][fullname] = param       # MuON params
            elif "bias" in fullname:
                g[2][fullname] = param       # biases
            elif isinstance(module, bn) or "logit_scale" in fullname:
                g[1][fullname] = param       # non-decayed params
            else:
                g[0][fullname] = param       # decayed weights

    if not use_muon:
        g = [x.values() for x in g[:3]]

    # NOTE: the `muon` and `sgd` sub-configurations referenced here are built
    # earlier in the full method; that wiring is omitted from this excerpt.
    optimizer = getattr(optim, name, partial(MuSGD, muon=muon, sgd=sgd))(params=g)
    return optimizer

Parameters are split into groups (decayed weights, non-decayed weights, biases, optional MuON group). The trainer can then apply appropriate decay and learning rates per group, centralizing optimization strategy so that individual models don't need to know about it.
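
As a generic illustration of how such parameter groups are consumed by a torch optimizer (the split mirrors the excerpt above; the values are hypothetical, not the Ultralytics settings):

import torch
from torch import nn

# Toy model, only used here as a source of parameters.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 4))

decayed, no_decay, biases = [], [], []
for module in model.modules():
    for name, param in module.named_parameters(recurse=False):
        if "bias" in name:
            biases.append(param)                 # biases
        elif isinstance(module, nn.BatchNorm2d):
            no_decay.append(param)               # norm weights
        else:
            decayed.append(param)                # regular weights

optimizer = torch.optim.SGD(
    [
        {"params": decayed, "weight_decay": 5e-4},  # decay applied here only
        {"params": no_decay, "weight_decay": 0.0},
        {"params": biases, "weight_decay": 0.0},
    ],
    lr=0.01,
    momentum=0.9,
)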

Checkpoint Content and Trade-Offs

Checkpointing is another place where orchestration decisions matter. The trainer doesn't just save weights; it captures enough context to reconstruct and audit a run.

def save_model(self):
    import io
    buffer = io.BytesIO()

    torch.save(
        {
            "epoch": self.epoch,
            "best_fitness": self.best_fitness,
            "model": None,
            "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
            "updates": self.ema.updates,
            "optimizer": convert_optimizer_state_dict_to_fp16(
                deepcopy(self.optimizer.state_dict())
            ),
            "scaler": self.scaler.state_dict(),
            "train_args": vars(self.args),
            "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
            "train_results": self.read_results_csv(),
            "date": datetime.now().isoformat(),
            "version": __version__,
            "git": {
                "root": str(GIT.root),
                "branch": GIT.branch,
                "commit": GIT.commit,
                "origin": GIT.origin,
            },
            "license": "AGPL-3.0 (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        },
        buffer,
    )

    serialized_ckpt = buffer.getvalue()
    self.wdir.mkdir(parents=True, exist_ok=True)
    self.last.write_bytes(serialized_ckpt)

    if self.best_fitness == self.fitness:
        self.best.write_bytes(serialized_ckpt)

    if (self.save_period > 0) and (self.epoch % self.save_period == 0):
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)

Alongside EMA weights and optimizer state, checkpoints include training arguments, metrics, Git metadata, license info, and a parsed copy of results.csv. This makes checkpoints self-contained experiment artifacts, but it also increases size and I/O cost as the CSV grows. The obvious refinement is to make history embedding configurable or store only a compact summary.
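
One way to express that refinement inside save_model(), as a sketch only: save_history below is a hypothetical flag, not an existing Ultralytics option.

    # Sketch only: `save_history` is a hypothetical flag, not an Ultralytics option.
    if getattr(self.args, "save_history", False):
        ckpt["train_results"] = self.read_results_csv()  # full per-epoch history, opt-in
    else:
        ckpt["train_results"] = {"epoch": self.epoch, "fitness": self.fitness}  # compact summary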

Practical Lessons You Can Steal

Stepping back, the pattern is consistent: BaseTrainer treats training as a system to orchestrate, not a tight inner loop to micro-optimize. That mindset shows up in how it centralizes lifecycle, encodes default strategies, and bakes resilience into the core flow.

There are a few concrete design moves you can apply directly:

  1. Centralize the lifecycle behind a trainer. Create a single object that owns configuration, setup, training, validation, checkpointing, and teardown. Expose abstract hooks like get_dataloader(), get_model(), and preprocess_batch() for task-specific behavior instead of duplicating loops across entrypoints.
  2. Handle instability as part of the design. OOM, NaN, and time limits are normal, not edge cases. Treat "too big" errors as opportunities to auto-tune (e.g., halve batch size on first-epoch OOM), and treat NaNs as triggers to roll back to the last known good checkpoint with a bounded number of retries.
  3. Encode optimization strategy once. Compute a rough iteration budget and use it to select optimizers and schedules. Group parameters for decay and learning rate inside the trainer. Let advanced users override, but make the default path informed by the training regime, not arbitrary constants.
  4. Make checkpoints useful, not just small. Save enough state to reproduce and audit a run: arguments, metrics, optimizer state, and some training history. Then watch size and frequency, and make the heavier pieces (like full CSV history) opt-in.
  5. Think in terms of orchestration. Once you view your trainer as the component that coordinates hardware, data, models, optimization, and failure recovery, features like EMA, DDP setup, auto-batch sizing, and time-based stopping stop feeling like extras. They become the core of a reliable training engine.

As your own projects move from experiments to production systems, shaping your trainer as an orchestrator like this will matter far more than the specific model you plug into it. The orchestration layer is what turns "a training loop" into an asset you can run, monitor, and trust.

Full Source Code

Here's the full source code of the file that inspired this article.
Read on GitHub

Thanks for reading! I hope this was useful. If you have questions or thoughts, feel free to reach out.

Content Creation Process: This article was generated via a semi-automated workflow using AI tools. I prepared the strategic framework, including specific prompts and data sources. From there, the automation system conducted the research, analysis, and writing. The content passed through automated verification steps before being finalized and published without manual intervention.

Mahmoud Zalt

About the Author

I’m Zalt, a technologist with 16+ years of experience, passionate about designing and building AI systems that move us closer to a world where machines handle everything and humans reclaim wonder.

Let's connect if you're working on interesting AI projects, looking for technical advice or want to discuss anything.
