Skip to content

Training API

modern_yolonas.training.trainer.Trainer

YOLO-NAS training loop.

Parameters:

Name Type Description Default
model Module

YoloNAS model.

required
train_loader DataLoader

Training DataLoader.

required
val_loader DataLoader | None

Optional validation DataLoader.

None
num_classes int

Number of classes.

80
epochs int

Total training epochs.

300
lr float

Learning rate.

0.0002
optimizer_name str

Optimizer name ('adamw' or 'sgd').

'adamw'
weight_decay float

Weight decay.

1e-05
warmup_steps int

LR warmup steps.

1000
use_amp bool

Enable automatic mixed precision.

True
use_ema bool

Enable exponential moving average.

True
output_dir str | Path

Directory for checkpoints.

'runs/train'
device str | device

Training device.

'cuda'
local_rank int

Local rank for DDP (-1 for single GPU).

-1
class_names list[str] | None

Optional list of class names used when drawing validation images. When None the visualiser falls back to COCO names.

None
val_freq int

Run validation every n epochs (default: 10). Set to 1 to validate after every epoch, or to a larger value to reduce overhead.

10
val_vis_images int

Number of sample images to annotate and log to WandB / TensorBoard during each validation run. Set to 0 to disable image logging.

8
callbacks list[Callback] | None

Optional list of callbacks whose hooks are fired at train/epoch/batch/validation boundaries.

None
gradient_accum int

Number of batches over which gradients are accumulated before each optimizer step (minimum 1).

1
Source code in src/modern_yolonas/training/trainer.py
class Trainer:
    """YOLO-NAS training loop.

    Args:
        model: YoloNAS model.
        train_loader: Training DataLoader.
        val_loader: Optional validation DataLoader.
        num_classes: Number of classes.
        epochs: Total training epochs.
        lr: Learning rate.
        optimizer_name: Optimizer name ('adamw' or 'sgd').
        weight_decay: Weight decay.
        warmup_steps: LR warmup steps.
        use_amp: Enable automatic mixed precision.
        use_ema: Enable exponential moving average.
        output_dir: Directory for checkpoints.
        device: Training device.
        local_rank: Local rank for DDP (-1 for single GPU).
        callbacks: Optional list of callbacks whose hooks are fired at
            train/epoch/batch/validation boundaries via :meth:`_fire`.
        class_names: Optional list of class names used when drawing validation
            images.  When ``None`` the visualiser falls back to COCO names.
        val_freq: Run validation every *n* epochs (default: 10).  Set to 1 to
            validate after every epoch, or to a larger value to reduce overhead.
        val_vis_images: Number of sample images to annotate and log to
            WandB / TensorBoard during each validation run.  Set to 0 to
            disable image logging.
        gradient_accum: Number of batches over which gradients are accumulated
            before each optimizer step (clamped to a minimum of 1).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader | None = None,
        num_classes: int = 80,
        epochs: int = 300,
        lr: float = 2e-4,
        optimizer_name: str = "adamw",
        weight_decay: float = 1e-5,
        warmup_steps: int = 1000,
        use_amp: bool = True,
        use_ema: bool = True,
        output_dir: str | Path = "runs/train",
        device: str | torch.device = "cuda",
        local_rank: int = -1,
        callbacks: list[Callback] | None = None,
        class_names: list[str] | None = None,
        val_freq: int = 10,
        val_vis_images: int = 8,
        gradient_accum: int = 1,
    ):
        self.epochs = epochs
        self.callbacks = callbacks or []
        self._stop_training = False
        self.device = torch.device(device)
        self.local_rank = local_rank
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Rank 0 (or single-GPU / local_rank == -1) is responsible for console
        # logging, validation and checkpointing.
        self.is_main = local_rank <= 0
        self.gradient_accum = max(1, gradient_accum)

        # DDP setup
        if local_rank >= 0:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = model.to(self.device)
            model = DDP(model, device_ids=[local_rank])
        else:
            model = model.to(self.device)

        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Loss
        self.criterion = PPYoloELoss(num_classes=num_classes)

        # Optimizer operates on the unwrapped model so parameter groups are
        # identical with and without DDP.
        raw_model = model.module if isinstance(model, DDP) else model
        self.optimizer = create_optimizer(raw_model, optimizer_name, lr, weight_decay)

        # Scheduler steps once per batch, hence total_steps = epochs * batches.
        total_steps = epochs * len(train_loader)
        self.scheduler = cosine_with_warmup(
            self.optimizer,
            warmup_steps=warmup_steps,
            total_steps=total_steps,
        )

        # AMP
        self.use_amp = use_amp
        self.scaler = torch.amp.GradScaler("cuda", enabled=use_amp) if use_amp else None

        # EMA
        self.ema = ModelEMA(raw_model) if use_ema else None

        self.start_epoch = 0
        self.best_map = 0.0
        # Global optimisation step counter (incremented once per batch).
        # Persisted through checkpoints so that loggers keep a monotonic x-axis
        # when training is resumed.
        self.global_step = 0
        self.class_names = class_names
        self.val_freq = val_freq
        self.val_vis_images = val_vis_images

    def _fire(self, hook: str, *args, **kwargs):
        """Invoke *hook* on every registered callback, passing this trainer first."""
        for cb in self.callbacks:
            getattr(cb, hook)(self, *args, **kwargs)

    def train(self):
        """Run training loop."""
        self._fire("on_train_start")
        for epoch in range(self.start_epoch, self.epochs):
            if self._stop_training:
                break
            self.model.train()

            if isinstance(self.train_loader.sampler, DistributedSampler):
                self.train_loader.sampler.set_epoch(epoch)

            epoch_loss = 0.0
            num_batches = 0

            self._fire("on_epoch_start", epoch)
            if self.is_main:
                console.print(f"\n[bold]Epoch {epoch + 1}/{self.epochs}[/bold]")

            for batch_idx, (images, targets) in enumerate(self.train_loader):
                # Step on every gradient_accum-th batch, and always on the last
                # batch so no accumulated gradients are dropped at epoch end.
                is_last_batch = (batch_idx + 1 == len(self.train_loader))
                should_step = (batch_idx + 1) % self.gradient_accum == 0 or is_last_batch

                if batch_idx % self.gradient_accum == 0:
                    self.optimizer.zero_grad(set_to_none=True)

                images = images.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)

                # autocast expects the device *type* ("cuda"/"cpu"); passing
                # str(self.device) would break for e.g. "cuda:0".
                with torch.amp.autocast(self.device.type, enabled=self.use_amp):
                    predictions = self.model(images)
                    loss, loss_dict = self.criterion(
                        predictions, targets,
                        input_size=(images.shape[2], images.shape[3]),
                        epoch=epoch,
                    )
                    # Scale down so accumulated gradients average over the
                    # accumulation window.
                    loss = loss / self.gradient_accum

                if self.scaler is not None:
                    self.scaler.scale(loss).backward()
                    if should_step:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                else:
                    loss.backward()
                    if should_step:
                        self.optimizer.step()

                # Per-batch LR schedule (matches total_steps in __init__).
                self.scheduler.step()

                if should_step and self.ema is not None:
                    raw_model = self.model.module if isinstance(self.model, DDP) else self.model
                    self.ema.update(raw_model)

                unscaled_loss = loss.item() * self.gradient_accum
                epoch_loss += unscaled_loss
                num_batches += 1
                self.global_step += 1

                self._fire("on_batch_end", epoch, batch_idx, unscaled_loss, loss_dict)

                if self.is_main and (batch_idx + 1) % 50 == 0:
                    avg_loss = epoch_loss / num_batches
                    lr = self.optimizer.param_groups[0]["lr"]
                    console.print(
                        f"  [{batch_idx + 1}/{len(self.train_loader)}] "
                        f"loss={avg_loss:.4f} "
                        f"cls={loss_dict['cls_loss']:.4f} "
                        f"iou={loss_dict['iou_loss']:.4f} "
                        f"dfl={loss_dict['dfl_loss']:.4f} "
                        f"lr={lr:.6f}"
                    )

            avg_loss = epoch_loss / max(num_batches, 1)
            if self.is_main:
                console.print(f"  Epoch {epoch + 1} avg loss: {avg_loss:.4f}")

            self._fire("on_epoch_end", epoch, avg_loss)

            # Validation
            if self.val_loader is not None and self.is_main and (epoch + 1) % self.val_freq == 0:
                self._validate(epoch)

            # Save checkpoint
            if self.is_main:
                self._save_checkpoint(epoch)

        self._fire("on_train_end")

    @torch.no_grad()
    def _validate(self, epoch: int) -> dict[str, float]:
        """Run one pass over the validation set, computing losses and detection metrics.

        The model is called in training mode (so raw predictions are available for
        the loss) but inside :func:`torch.no_grad`, so no gradients are accumulated.
        Predicted boxes decoded in eval mode are used separately for mAP/mAR.

        Args:
            epoch: Zero-based epoch index (used for logging only).

        Returns:
            Dict with keys ``"val/loss"``, ``"val/cls_loss"``, ``"val/iou_loss"``,
            ``"val/dfl_loss"`` (averaged over batches) plus ``"val_metrics/mAP"``,
            ``"val_metrics/mAP_50"``, ``"val_metrics/mAR_100"``.
        """
        raw_model = self.model.module if isinstance(self.model, DDP) else self.model
        eval_model = self.ema.ema if self.ema is not None else raw_model

        # Training mode gives raw predictions needed for the loss; no gradients
        # are computed because we are inside torch.no_grad().
        # NOTE(review): train mode also updates BatchNorm running statistics on
        # validation data — confirm this is intended, or freeze BN momentum here.
        eval_model.train()

        det_metrics = DetectionMetrics(device=self.device)
        num_batches = 0
        vis_images: list = []

        loss_sum: dict[str, float] = {"total": 0.0, "cls": 0.0, "iou": 0.0, "dfl": 0.0}

        for images, targets in self.val_loader:
            images  = images.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)

            # autocast expects the device *type* ("cuda"/"cpu"), not "cuda:0".
            with torch.amp.autocast(self.device.type, enabled=self.use_amp):
                predictions = eval_model(images)
                loss, loss_dict = self.criterion(
                    predictions, targets,
                    input_size=(images.shape[2], images.shape[3]),
                    epoch=epoch,
                )

            loss_sum["total"] += loss.item()
            loss_sum["cls"]   += loss_dict.get("cls_loss", 0.0)
            loss_sum["iou"]   += loss_dict.get("iou_loss", 0.0)
            loss_sum["dfl"]   += loss_dict.get("dfl_loss", 0.0)

            # Decoded predictions (pred_bboxes, pred_scores) are the first element
            pred_bboxes, pred_scores = predictions[0]
            detections = postprocess(pred_bboxes, pred_scores)
            preds = [
                {
                    "boxes":  boxes.float(),
                    "scores": scores.float(),
                    "labels": labels.int(),
                }
                for boxes, scores, labels in detections
            ]

            # Convert collated targets [sum_N, 6] → per-image xyxy dicts
            batch_size = images.shape[0]
            img_h, img_w = images.shape[2], images.shape[3]
            target_list: list[dict] = []
            for i in range(batch_size):
                mask = targets[:, 0] == i
                t = targets[mask]
                if t.shape[0] == 0:
                    target_list.append({
                        "boxes":  torch.zeros(0, 4, device=self.device),
                        "labels": torch.zeros(0, dtype=torch.int, device=self.device),
                    })
                else:
                    # targets are normalized cxcywh — scale to pixels, then xyxy.
                    cx = t[:, 2] * img_w
                    cy = t[:, 3] * img_h
                    bw = t[:, 4] * img_w
                    bh = t[:, 5] * img_h
                    boxes = torch.stack(
                        [cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], dim=-1
                    )
                    target_list.append({
                        "boxes":  boxes,
                        "labels": t[:, 1].int(),
                    })

            det_metrics.update(preds, target_list)
            num_batches += 1

            # Collect visual samples from the first batch only
            if num_batches == 1 and self.val_vis_images > 0:
                n_vis = min(self.val_vis_images, images.shape[0])
                imgs_np = images[:n_vis].cpu().numpy()
                for idx in range(n_vis):
                    p_boxes  = preds[idx]["boxes"].cpu().numpy()
                    p_scores = preds[idx]["scores"].cpu().numpy()
                    p_labels = preds[idx]["labels"].cpu().numpy()
                    g_boxes  = target_list[idx]["boxes"].cpu().numpy()
                    g_labels = target_list[idx]["labels"].cpu().numpy()
                    vis_images.append(
                        annotate_validation_sample(
                            imgs_np[idx], p_boxes, p_scores, p_labels, g_boxes, g_labels,
                            class_names=self.class_names,
                        )
                    )

        # Restore eval mode for any subsequent inference
        eval_model.eval()

        n = max(num_batches, 1)
        map_results = det_metrics.compute()

        results = {
            # Losses — averaged across all validation batches
            "val/loss":     loss_sum["total"] / n,
            "val/cls_loss": loss_sum["cls"]   / n,
            "val/iou_loss": loss_sum["iou"]   / n,
            "val/dfl_loss": loss_sum["dfl"]   / n,
            # Detection metrics — in their own group
            "val_metrics/mAP":     map_results["mAP"],
            "val_metrics/mAP_50":  map_results["mAP_50"],
            "val_metrics/mAR_100": map_results["mAR_100"],
        }

        if self.is_main:
            console.print(
                f"  Val [{num_batches} batches] "
                f"loss={results['val/loss']:.4f}  "
                f"cls={results['val/cls_loss']:.4f}  "
                f"iou={results['val/iou_loss']:.4f}  "
                f"dfl={results['val/dfl_loss']:.4f}  "
                f"mAP={results['val_metrics/mAP']:.4f}  "
                f"mAP_50={results['val_metrics/mAP_50']:.4f}  "
                f"mAR_100={results['val_metrics/mAR_100']:.4f}"
            )

        # Persist the best checkpoint when mAP_50 improves ("best" is tracked
        # on mAP@0.5, not the averaged mAP).
        if results["val_metrics/mAP_50"] > self.best_map:
            self.best_map = results["val_metrics/mAP_50"]
            self._save_checkpoint(epoch, is_best=True)

        self._fire("on_validation_end", epoch, results)
        if vis_images:
            self._fire("on_validation_images", epoch, vis_images)
        return results

    def _save_checkpoint(self, epoch: int, *, is_best: bool = False):
        """Write ``last.pt`` always, ``best.pt`` when *is_best*, and a periodic
        ``epoch_{n}.pt`` snapshot every 50 epochs."""
        raw_model = self.model.module if isinstance(self.model, DDP) else self.model
        state = {
            "epoch": epoch + 1,
            "model_state_dict": raw_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_map": self.best_map,
            "global_step": self.global_step,
        }
        if self.ema is not None:
            state["ema"] = self.ema.state_dict()
        if self.scaler is not None:
            state["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(state, self.output_dir / "last.pt")
        if is_best:
            torch.save(state, self.output_dir / "best.pt")
        if (epoch + 1) % 50 == 0:
            torch.save(state, self.output_dir / f"epoch_{epoch + 1}.pt")

    def resume(self, checkpoint_path: str | Path):
        """Resume training from checkpoint."""
        # weights_only=True restricts unpickling to tensors and plain
        # containers; presumably everything _save_checkpoint stores satisfies
        # this — verify if scheduler/EMA state ever holds custom objects.
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=True)

        raw_model = self.model.module if isinstance(self.model, DDP) else self.model
        raw_model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        self.start_epoch = ckpt["epoch"]
        self.best_map = ckpt.get("best_map", 0.0)
        self.global_step = ckpt.get("global_step", self.start_epoch * len(self.train_loader))

        if self.ema is not None and "ema" in ckpt:
            self.ema.load_state_dict(ckpt["ema"])
        if self.scaler is not None and "scaler_state_dict" in ckpt:
            self.scaler.load_state_dict(ckpt["scaler_state_dict"])

        console.print(f"Resumed from epoch {self.start_epoch}")

train()

Run training loop.

Source code in src/modern_yolonas/training/trainer.py
def train(self):
    """Run training loop."""
    self._fire("on_train_start")
    for epoch in range(self.start_epoch, self.epochs):
        if self._stop_training:
            break
        self.model.train()

        if isinstance(self.train_loader.sampler, DistributedSampler):
            self.train_loader.sampler.set_epoch(epoch)

        epoch_loss = 0.0
        num_batches = 0

        self._fire("on_epoch_start", epoch)
        if self.is_main:
            console.print(f"\n[bold]Epoch {epoch + 1}/{self.epochs}[/bold]")

        for batch_idx, (images, targets) in enumerate(self.train_loader):
            # Step on every gradient_accum-th batch, and always on the last
            # batch so no accumulated gradients are dropped at epoch end.
            is_last_batch = (batch_idx + 1 == len(self.train_loader))
            should_step = (batch_idx + 1) % self.gradient_accum == 0 or is_last_batch

            if batch_idx % self.gradient_accum == 0:
                self.optimizer.zero_grad(set_to_none=True)

            images = images.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)

            # autocast expects the device *type* ("cuda"/"cpu"); passing
            # str(self.device) would break for e.g. "cuda:0".
            with torch.amp.autocast(self.device.type, enabled=self.use_amp):
                predictions = self.model(images)
                loss, loss_dict = self.criterion(
                    predictions, targets,
                    input_size=(images.shape[2], images.shape[3]),
                    epoch=epoch,
                )
                # Scale down so accumulated gradients average over the window.
                loss = loss / self.gradient_accum

            if self.scaler is not None:
                self.scaler.scale(loss).backward()
                if should_step:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
            else:
                loss.backward()
                if should_step:
                    self.optimizer.step()

            # Per-batch LR schedule (total_steps = epochs * batches).
            self.scheduler.step()

            if should_step and self.ema is not None:
                raw_model = self.model.module if isinstance(self.model, DDP) else self.model
                self.ema.update(raw_model)

            unscaled_loss = loss.item() * self.gradient_accum
            epoch_loss += unscaled_loss
            num_batches += 1
            self.global_step += 1

            self._fire("on_batch_end", epoch, batch_idx, unscaled_loss, loss_dict)

            if self.is_main and (batch_idx + 1) % 50 == 0:
                avg_loss = epoch_loss / num_batches
                lr = self.optimizer.param_groups[0]["lr"]
                console.print(
                    f"  [{batch_idx + 1}/{len(self.train_loader)}] "
                    f"loss={avg_loss:.4f} "
                    f"cls={loss_dict['cls_loss']:.4f} "
                    f"iou={loss_dict['iou_loss']:.4f} "
                    f"dfl={loss_dict['dfl_loss']:.4f} "
                    f"lr={lr:.6f}"
                )

        avg_loss = epoch_loss / max(num_batches, 1)
        if self.is_main:
            console.print(f"  Epoch {epoch + 1} avg loss: {avg_loss:.4f}")

        self._fire("on_epoch_end", epoch, avg_loss)

        # Validation
        if self.val_loader is not None and self.is_main and (epoch + 1) % self.val_freq == 0:
            self._validate(epoch)

        # Save checkpoint
        if self.is_main:
            self._save_checkpoint(epoch)

    self._fire("on_train_end")

resume(checkpoint_path)

Resume training from checkpoint.

Source code in src/modern_yolonas/training/trainer.py
def resume(self, checkpoint_path: str | Path):
    """Restore trainer state (model, optimizer, scheduler, EMA, scaler) from a checkpoint."""
    state = torch.load(checkpoint_path, map_location=self.device, weights_only=True)

    # Load into the unwrapped model so keys match regardless of DDP wrapping.
    target = self.model.module if isinstance(self.model, DDP) else self.model
    target.load_state_dict(state["model_state_dict"])
    self.optimizer.load_state_dict(state["optimizer_state_dict"])
    self.scheduler.load_state_dict(state["scheduler_state_dict"])

    self.start_epoch = state["epoch"]
    self.best_map = state.get("best_map", 0.0)
    # Older checkpoints may lack global_step; approximate it from the epoch.
    fallback_step = self.start_epoch * len(self.train_loader)
    self.global_step = state.get("global_step", fallback_step)

    if "ema" in state and self.ema is not None:
        self.ema.load_state_dict(state["ema"])
    if "scaler_state_dict" in state and self.scaler is not None:
        self.scaler.load_state_dict(state["scaler_state_dict"])

    console.print(f"Resumed from epoch {self.start_epoch}")

modern_yolonas.training.loss.PPYoloELoss

Bases: Module

Combined loss for YOLO-NAS training.

Components: VarifocalLoss (classification); GIoULoss (box regression, weighted by assigned scores); DFLLoss (distribution focal loss, weighted by assigned scores).

Weighted sum: cls_weight * vfl + iou_weight * giou + dfl_weight * dfl All terms normalized by assigned_scores_sum (matching super-gradients).

Parameters:

Name Type Description Default
num_classes int

Number of object classes.

80
reg_max int

Distribution regression maximum.

16
cls_weight float

Classification loss weight.

1.0
iou_weight float

Box regression loss weight.

2.5
dfl_weight float

Distribution focal loss weight.

0.5
static_assigner_epochs int

Use ATSS for the first N epochs (0 to disable).

4
Source code in src/modern_yolonas/training/loss.py
class PPYoloELoss(nn.Module):
    """Combined loss for YOLO-NAS training.

    Components:
    - VarifocalLoss (classification)
    - GIoULoss (box regression, weighted by assigned scores)
    - DFLLoss (distribution focal loss, weighted by assigned scores)

    Weighted sum: cls_weight * vfl + iou_weight * giou + dfl_weight * dfl
    All terms normalized by ``assigned_scores_sum`` (matching super-gradients).

    Args:
        num_classes: Number of object classes.
        reg_max: Distribution regression maximum.
        cls_weight: Classification loss weight.
        iou_weight: Box regression loss weight.
        dfl_weight: Distribution focal loss weight.
        static_assigner_epochs: Use ATSS for the first N epochs (0 to disable).
    """

    def __init__(
        self,
        num_classes: int = 80,
        reg_max: int = 16,
        cls_weight: float = 1.0,
        iou_weight: float = 2.5,
        dfl_weight: float = 0.5,
        static_assigner_epochs: int = 4,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max
        self.cls_weight = cls_weight
        self.iou_weight = iou_weight
        self.dfl_weight = dfl_weight
        self.static_assigner_epochs = static_assigner_epochs

        self.assigner = TaskAlignedAssigner()
        self.static_assigner = ATSSAssigner(topk=9) if static_assigner_epochs > 0 else None
        self.vfl = VarifocalLoss()
        self.giou_loss = GIoULoss()
        self.dfl_loss = DFLLoss()

    def _bbox2dist(self, anchor_points: Tensor, gt_bboxes: Tensor) -> Tensor:
        """Convert bounding boxes to distances from anchor points."""
        x1y1 = anchor_points - gt_bboxes[..., :2]
        x2y2 = gt_bboxes[..., 2:] - anchor_points
        dist = torch.cat([x1y1, x2y2], dim=-1)
        return dist.clamp(0, self.reg_max - 0.01)

    def forward(
        self,
        predictions: tuple,
        targets: Tensor,
        input_size: tuple[int, int] | None = None,
        epoch: int | None = None,
    ) -> tuple[Tensor, dict[str, float]]:
        """Compute loss.

        Args:
            predictions: ``(decoded_predictions, raw_predictions)`` from NDFLHeads in training mode.
                decoded_predictions: ``(pred_bboxes [B,N,4], pred_scores [B,N,C])``
                raw_predictions: ``(cls_logits [B,N,C], reg_distri [B,N,4*(reg_max+1)],
                                   anchors, anchor_points, num_anchors_list, stride_tensor)``
            targets: ``[sum(N_i), 6]`` with ``[batch_idx, class_id, x, y, w, h]`` (normalized xywh).
            input_size: ``(H, W)`` of the input image. If None, inferred from anchor grid.
            epoch: Current training epoch (used for ATSS warmup).

        Returns:
            (total_loss, loss_dict)
        """
        (pred_bboxes_decoded, pred_scores_decoded), (
            cls_logits,
            reg_distri,
            anchors,
            anchor_points,
            num_anchors_list,
            stride_tensor,
        ) = predictions

        batch_size = cls_logits.shape[0]
        device = cls_logits.device

        # Determine input image size for scaling normalized GT to pixel coords
        if input_size is not None:
            img_h, img_w = input_size[0], input_size[1]
        else:
            inferred = (anchor_points.max(dim=0).values + stride_tensor.min() / 2).clamp(min=1)
            img_w, img_h = inferred[0], inferred[1]

        # Validate GT class labels
        if targets.numel() > 0:
            class_ids = targets[:, 1]
            if (class_ids < 0).any() or (class_ids >= self.num_classes).any():
                logger.warning(
                    "GT class labels out of range [0, %d): min=%d, max=%d. "
                    "Check your dataset labels.",
                    self.num_classes,
                    int(class_ids.min().item()),
                    int(class_ids.max().item()),
                )

        # Prepare GT in format expected by assigner
        gt_labels_list = []
        gt_bboxes_list = []
        for b in range(batch_size):
            mask = targets[:, 0] == b
            if mask.any():
                t = targets[mask]
                gt_labels_list.append(t[:, 1:2])
                xc, yc, w, h = t[:, 2], t[:, 3], t[:, 4], t[:, 5]
                xc, w = xc * img_w, w * img_w
                yc, h = yc * img_h, h * img_h
                gt_bboxes_list.append(torch.stack([
                    xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
                ], dim=-1))
            else:
                gt_labels_list.append(torch.zeros(0, 1, device=device))
                gt_bboxes_list.append(torch.zeros(0, 4, device=device))

        max_gt = max(len(g) for g in gt_labels_list)
        if max_gt == 0:
            zero_loss = cls_logits.sum() * 0.0
            return zero_loss, {"cls_loss": 0.0, "iou_loss": 0.0, "dfl_loss": 0.0, "total_loss": 0.0}

        gt_labels = torch.zeros(batch_size, max_gt, 1, device=device)
        gt_bboxes = torch.zeros(batch_size, max_gt, 4, device=device)
        mask_gt = torch.zeros(batch_size, max_gt, 1, device=device)

        for b in range(batch_size):
            n = len(gt_labels_list[b])
            if n > 0:
                gt_labels[b, :n] = gt_labels_list[b]
                gt_bboxes[b, :n] = gt_bboxes_list[b]
                mask_gt[b, :n] = 1.0

        # Select assigner: ATSS for warmup, TAL after
        use_static = (
            self.static_assigner is not None
            and epoch is not None
            and epoch < self.static_assigner_epochs
        )

        if use_static:
            assigned_labels, assigned_bboxes, assigned_scores, fg_mask = self.static_assigner.assign(
                anchors, num_anchors_list, gt_labels, gt_bboxes, mask_gt, self.num_classes,
                pred_bboxes=pred_bboxes_decoded.detach(),
            )
        else:
            assigned_labels, assigned_bboxes, assigned_scores, fg_mask = self.assigner.assign(
                pred_scores_decoded, pred_bboxes_decoded, anchor_points,
                gt_labels, gt_bboxes, mask_gt,
            )

        # Normalization: sum of assigned soft scores (matches super-gradients)
        assigned_scores_sum = assigned_scores.sum().clamp(min=1)

        # Classification loss (VFL)
        cls_loss = self.vfl(cls_logits, assigned_scores, (assigned_scores > 0).float()) / assigned_scores_sum

        # Box regression loss (GIoU) — weighted by per-anchor assigned scores
        if fg_mask.any():
            pos_pred_bboxes = pred_bboxes_decoded[fg_mask]
            pos_target_bboxes = assigned_bboxes[fg_mask]
            bbox_weight = assigned_scores[fg_mask].sum(-1)  # [num_pos]
            iou_loss_per_anchor = self.giou_loss(pos_pred_bboxes, pos_target_bboxes)  # [num_pos]
            iou_loss = (iou_loss_per_anchor * bbox_weight).sum() / assigned_scores_sum
        else:
            iou_loss = torch.tensor(0.0, device=device)

        # DFL loss — weighted by per-anchor assigned scores
        if fg_mask.any():
            pos_reg_distri = reg_distri[fg_mask]
            pos_anchor_points = anchor_points.unsqueeze(0).expand(batch_size, -1, -1)[fg_mask]
            pos_stride = stride_tensor.unsqueeze(0).expand(batch_size, -1, -1)[fg_mask]
            pos_target_bboxes = assigned_bboxes[fg_mask] / pos_stride
            pos_anchor_points_scaled = pos_anchor_points / pos_stride
            target_dist = self._bbox2dist(pos_anchor_points_scaled, pos_target_bboxes)
            dfl_loss_per_anchor = self.dfl_loss(pos_reg_distri, target_dist)  # [num_pos]
            dfl_loss = (dfl_loss_per_anchor * bbox_weight).sum() / assigned_scores_sum
        else:
            dfl_loss = torch.tensor(0.0, device=device)

        total_loss = self.cls_weight * cls_loss + self.iou_weight * iou_loss + self.dfl_weight * dfl_loss

        loss_dict = {
            "cls_loss": cls_loss.item(),
            "iou_loss": iou_loss.item(),
            "dfl_loss": dfl_loss.item(),
            "total_loss": total_loss.item(),
        }

        return total_loss, loss_dict

forward(predictions, targets, input_size=None, epoch=None)

Compute loss.

Parameters:

Name Type Description Default
predictions tuple

(decoded_predictions, raw_predictions) from NDFLHeads in training mode. decoded_predictions: (pred_bboxes [B,N,4], pred_scores [B,N,C]) raw_predictions: (cls_logits [B,N,C], reg_distri [B,N,4*(reg_max+1)], anchors, anchor_points, num_anchors_list, stride_tensor)

required
targets Tensor

[sum(N_i), 6] with [batch_idx, class_id, x, y, w, h] (normalized xywh).

required
input_size tuple[int, int] | None

(H, W) of the input image. If None, inferred from anchor grid.

None
epoch int | None

Current training epoch (used for ATSS warmup).

None

Returns:

Type Description
tuple[Tensor, dict[str, float]]

(total_loss, loss_dict)

Source code in src/modern_yolonas/training/loss.py
def forward(
    self,
    predictions: tuple,
    targets: Tensor,
    input_size: tuple[int, int] | None = None,
    epoch: int | None = None,
) -> tuple[Tensor, dict[str, float]]:
    """Compute loss.

    Args:
        predictions: ``(decoded_predictions, raw_predictions)`` from NDFLHeads in training mode.
            decoded_predictions: ``(pred_bboxes [B,N,4], pred_scores [B,N,C])``
            raw_predictions: ``(cls_logits [B,N,C], reg_distri [B,N,4*(reg_max+1)],
                               anchors, anchor_points, num_anchors_list, stride_tensor)``
        targets: ``[sum(N_i), 6]`` with ``[batch_idx, class_id, x, y, w, h]`` (normalized xywh).
        input_size: ``(H, W)`` of the input image. If None, inferred from anchor grid.
        epoch: Current training epoch (used for ATSS warmup).

    Returns:
        (total_loss, loss_dict)
    """
    (pred_bboxes_decoded, pred_scores_decoded), (
        cls_logits,
        reg_distri,
        anchors,
        anchor_points,
        num_anchors_list,
        stride_tensor,
    ) = predictions

    batch_size = cls_logits.shape[0]
    device = cls_logits.device

    # Determine input image size for scaling normalized GT to pixel coords.
    if input_size is not None:
        img_h, img_w = input_size[0], input_size[1]
    else:
        # Anchor centers sit half a (minimum) stride inside the border, so
        # max-center + stride/2 approximately recovers the image extent.
        inferred = (anchor_points.max(dim=0).values + stride_tensor.min() / 2).clamp(min=1)
        img_w, img_h = inferred[0], inferred[1]

    # Validate GT class labels — warn instead of raising so training continues.
    if targets.numel() > 0:
        class_ids = targets[:, 1]
        if (class_ids < 0).any() or (class_ids >= self.num_classes).any():
            logger.warning(
                "GT class labels out of range [0, %d): min=%d, max=%d. "
                "Check your dataset labels.",
                self.num_classes,
                int(class_ids.min().item()),
                int(class_ids.max().item()),
            )

    # Group GT by image and convert normalized xywh -> pixel xyxy, the format
    # expected by the assigners.
    gt_labels_list = []
    gt_bboxes_list = []
    for b in range(batch_size):
        mask = targets[:, 0] == b
        if mask.any():
            t = targets[mask]
            gt_labels_list.append(t[:, 1:2])
            xc, yc, w, h = t[:, 2], t[:, 3], t[:, 4], t[:, 5]
            xc, w = xc * img_w, w * img_w
            yc, h = yc * img_h, h * img_h
            gt_bboxes_list.append(torch.stack([
                xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
            ], dim=-1))
        else:
            gt_labels_list.append(torch.zeros(0, 1, device=device))
            gt_bboxes_list.append(torch.zeros(0, 4, device=device))

    max_gt = max(len(g) for g in gt_labels_list)
    if max_gt == 0:
        # No GT in the whole batch: return a zero that still depends on the
        # graph so backward() remains valid.
        zero_loss = cls_logits.sum() * 0.0
        return zero_loss, {"cls_loss": 0.0, "iou_loss": 0.0, "dfl_loss": 0.0, "total_loss": 0.0}

    # Pad per-image GT to [B, max_gt, ...] with a validity mask.
    gt_labels = torch.zeros(batch_size, max_gt, 1, device=device)
    gt_bboxes = torch.zeros(batch_size, max_gt, 4, device=device)
    mask_gt = torch.zeros(batch_size, max_gt, 1, device=device)

    for b in range(batch_size):
        n = len(gt_labels_list[b])
        if n > 0:
            gt_labels[b, :n] = gt_labels_list[b]
            gt_bboxes[b, :n] = gt_bboxes_list[b]
            mask_gt[b, :n] = 1.0

    # Select assigner: static ATSS for the warmup epochs, TAL after.
    use_static = (
        self.static_assigner is not None
        and epoch is not None
        and epoch < self.static_assigner_epochs
    )

    if use_static:
        assigned_labels, assigned_bboxes, assigned_scores, fg_mask = self.static_assigner.assign(
            anchors, num_anchors_list, gt_labels, gt_bboxes, mask_gt, self.num_classes,
            pred_bboxes=pred_bboxes_decoded.detach(),
        )
    else:
        assigned_labels, assigned_bboxes, assigned_scores, fg_mask = self.assigner.assign(
            pred_scores_decoded, pred_bboxes_decoded, anchor_points,
            gt_labels, gt_bboxes, mask_gt,
        )

    # Normalization: sum of assigned soft scores (matches super-gradients)
    assigned_scores_sum = assigned_scores.sum().clamp(min=1)

    # Classification loss (VFL)
    cls_loss = self.vfl(cls_logits, assigned_scores, (assigned_scores > 0).float()) / assigned_scores_sum

    # Regression losses (GIoU + DFL) on foreground anchors only, both weighted
    # by the per-anchor assigned soft scores. Computed in one branch so the
    # shared bbox_weight / target-bbox tensors are built exactly once.
    if fg_mask.any():
        bbox_weight = assigned_scores[fg_mask].sum(-1)  # [num_pos]
        pos_target_bboxes = assigned_bboxes[fg_mask]

        # GIoU loss on decoded boxes in pixel space.
        pos_pred_bboxes = pred_bboxes_decoded[fg_mask]
        iou_loss_per_anchor = self.giou_loss(pos_pred_bboxes, pos_target_bboxes)  # [num_pos]
        iou_loss = (iou_loss_per_anchor * bbox_weight).sum() / assigned_scores_sum

        # DFL loss on the raw distribution, with targets expressed as
        # stride-normalized distances from the anchor point.
        pos_reg_distri = reg_distri[fg_mask]
        pos_stride = stride_tensor.unsqueeze(0).expand(batch_size, -1, -1)[fg_mask]
        pos_anchor_points = anchor_points.unsqueeze(0).expand(batch_size, -1, -1)[fg_mask]
        target_dist = self._bbox2dist(pos_anchor_points / pos_stride, pos_target_bboxes / pos_stride)
        dfl_loss_per_anchor = self.dfl_loss(pos_reg_distri, target_dist)  # [num_pos]
        dfl_loss = (dfl_loss_per_anchor * bbox_weight).sum() / assigned_scores_sum
    else:
        iou_loss = torch.tensor(0.0, device=device)
        dfl_loss = torch.tensor(0.0, device=device)

    total_loss = self.cls_weight * cls_loss + self.iou_weight * iou_loss + self.dfl_weight * dfl_loss

    loss_dict = {
        "cls_loss": cls_loss.item(),
        "iou_loss": iou_loss.item(),
        "dfl_loss": dfl_loss.item(),
        "total_loss": total_loss.item(),
    }

    return total_loss, loss_dict

modern_yolonas.training.ema.ModelEMA

Maintains an exponential moving average of model parameters.

Parameters:

Name Type Description Default
model Module

The model to track.

required
decay float

EMA decay factor.

0.9997
warmup_steps int

Steps before reaching full decay.

2000
Source code in src/modern_yolonas/training/ema.py
class ModelEMA:
    """Maintains an exponential moving average of model parameters.

    Args:
        model: The model to track.
        decay: EMA decay factor.
        warmup_steps: Steps before reaching full decay.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9997, warmup_steps: int = 2000):
        # Deep-copy so the EMA weights evolve independently of the live model;
        # eval() + requires_grad_(False) keeps it out of autograd entirely.
        self.ema = deepcopy(model).eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.updates = 0

    @torch.no_grad()
    def update(self, model: nn.Module):
        """Blend the EMA state toward ``model``'s current state.

        The effective decay ``d`` ramps from 0 to ``self.decay`` over roughly
        ``warmup_steps`` updates, so early training is tracked closely.

        Note: this performs in-place tensor ops and must run under
        ``torch.no_grad()`` — otherwise in-place updates on tracked tensors
        could be recorded by (or rejected by) autograd.
        """
        self.updates += 1
        d = self.decay * (1 - math.exp(-self.updates / self.warmup_steps))
        # NOTE(review): assumes `model` and `self.ema` share identical
        # state-dict keys (i.e. no DDP "module." prefix mismatch) — confirm
        # at the call site when wrapping the model.
        model_sd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                # Integer buffers (e.g. step counters) are copied-by-reference
                # semantics and intentionally skipped here.
                v.mul_(d).add_(model_sd[k], alpha=1.0 - d)

    def state_dict(self) -> dict:
        """Return a checkpointable snapshot of the EMA weights and schedule."""
        return {
            "ema_state_dict": self.ema.state_dict(),
            "decay": self.decay,
            "updates": self.updates,
        }

    def load_state_dict(self, state: dict):
        """Restore EMA weights and schedule from :meth:`state_dict` output."""
        self.ema.load_state_dict(state["ema_state_dict"])
        self.decay = state["decay"]
        self.updates = state["updates"]