
Weights API

modern_yolonas.weights.load_pretrained(model, variant, strict=True, repo_id=HF_REPO_ID, revision=None)

Download a safetensors checkpoint from the Hugging Face Hub and load it into model.

Parameters:

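Name | Type | Description | Default
--- | --- | --- | ---
model | Module | A YoloNAS instance (or any nn.Module with matching keys). | required
variant | str | One of "yolo_nas_s", "yolo_nas_m", "yolo_nas_l". | required
strict | bool | Passed to load_state_dict. Checkpoint keys the model does not define are dropped before loading, so with True only model keys missing from the checkpoint raise an error. | True
repo_id | str | HF Hub repo (default HF_REPO_ID, overridable via the YOLONAS_HF_REPO environment variable or this argument). | HF_REPO_ID
revision | str \| None | Git ref (branch, tag, or commit) of the repo. | None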
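The precedence described for repo_id (explicit argument first, then the YOLONAS_HF_REPO environment variable, then the built-in default) can be sketched as follows. This is a hypothetical illustration, not the library's code: resolve_repo_id and DEFAULT_REPO are invented names, and the real value of HF_REPO_ID is not shown on this page.

```python
import os
from typing import Mapping, Optional

# Hypothetical placeholder; the real value of HF_REPO_ID is not shown here.
DEFAULT_REPO = "example-org/yolo-nas-weights"


def resolve_repo_id(
    explicit: Optional[str] = None,
    env: Mapping[str, str] = os.environ,
) -> str:
    """Hypothetical helper: explicit argument > YOLONAS_HF_REPO > default."""
    if explicit is not None:
        return explicit
    # An unset or empty environment variable falls back to the default.
    return env.get("YOLONAS_HF_REPO") or DEFAULT_REPO


# Usage: resolve_repo_id("my-org/my-weights") always wins over the environment.
```

One detail worth checking in the real module: because repo_id=HF_REPO_ID is a default argument, Python evaluates it once, when load_pretrained is defined. If HF_REPO_ID is read from YOLONAS_HF_REPO at import time (an assumption), setting the variable after importing modern_yolonas.weights would not change the default; passing repo_id explicitly always works.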

Returns:

Type | Description
--- | ---
Module | The model with loaded weights.

Source code in src/modern_yolonas/weights.py
def load_pretrained(
    model: nn.Module,
    variant: str,
    strict: bool = True,
    repo_id: str = HF_REPO_ID,
    revision: str | None = None,
) -> nn.Module:
    """Download safetensors checkpoint from HF Hub and load into model.

    Args:
        model: A ``YoloNAS`` instance (or any nn.Module with matching keys).
        variant: One of ``"yolo_nas_s"``, ``"yolo_nas_m"``, ``"yolo_nas_l"``.
        strict: Require exact key matching.
        repo_id: HF Hub repo (default :data:`HF_REPO_ID`, overridable via the
            ``YOLONAS_HF_REPO`` env var or this arg).
        revision: Git ref (branch / tag / commit) of the repo.

    Returns:
        The model with loaded weights.
    """
    path = _download(variant, repo_id=repo_id, revision=revision)
    raw_sd = load_file(path)
    sd = remap_state_dict(raw_sd)

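    # Drop checkpoint keys the model does not define; after this filter,
    # strict=True can only fail on model keys missing from the checkpoint.
    model_keys = set(model.state_dict().keys())
    sd = {k: v for k, v in sd.items() if k in model_keys}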

    model.load_state_dict(sd, strict=strict)
    return model
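To see what strict actually checks after the filtering step, the sketch below replays that logic on plain dicts, with small integers standing in for tensors and made-up key names. filter_and_check is a hypothetical helper; it mirrors the filter in load_pretrained and reports the keys that are dropped and the model keys that torch's load_state_dict would flag as missing under strict=True.

```python
from typing import Dict, List, Set, Tuple


def filter_and_check(
    checkpoint: Dict[str, object], model_keys: Set[str]
) -> Tuple[Dict[str, object], List[str], List[str]]:
    """Hypothetical helper mirroring load_pretrained's filtering step.

    Returns the filtered state dict, the checkpoint keys that were dropped,
    and the model keys the checkpoint does not provide.
    """
    # Same filter as load_pretrained: keep only keys the model defines.
    filtered = {k: v for k, v in checkpoint.items() if k in model_keys}
    # Checkpoint keys the model lacks; they never reach load_state_dict,
    # so strict=True cannot complain about them.
    dropped = sorted(set(checkpoint) - model_keys)
    # Model keys absent from the filtered dict; with strict=True,
    # load_state_dict raises on these as missing keys.
    missing = sorted(model_keys - set(filtered))
    return filtered, dropped, missing


# Illustrative key names only.
ckpt = {"backbone.stem.w": 1, "heads.head1.w": 2, "extra.w": 3}
model_keys = {"backbone.stem.w", "heads.head1.w", "neck.neck1.w"}
sd, dropped, missing = filter_and_check(ckpt, model_keys)
print(dropped)  # ['extra.w']
print(missing)  # ['neck.neck1.w']
```

The effect of the filter is that extra entries in the checkpoint are tolerated silently, while weights the model expects but the checkpoint lacks still fail loudly when strict=True.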

modern_yolonas.weights.remap_state_dict(raw_sd)

Remap super-gradients state_dict keys to our module hierarchy.

Super-gradients wraps the model in CustomizableDetector with:

backbone  → backbone.stem, backbone.stage1 … backbone.stage4, backbone.context_module
neck      → neck.neck1 … neck.neck4
heads     → heads.head1 … heads.head3

Our YoloNAS uses the same attribute names, so the only work is stripping DDP/EMA prefixes.
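The body of _strip_prefix is not shown on this page. As a hedged sketch, assuming the wrappers add dotted prefixes such as the `module.` that torch's DistributedDataParallel places in front of every key, prefix stripping could look like this. strip_prefix and WRAPPER_PREFIXES are hypothetical names, and whatever EMA prefix real checkpoints carry would need to be added to the tuple.

```python
from typing import Iterable

# Hypothetical: "module." is the prefix DDP adds; the actual prefixes handled
# by _strip_prefix (including any EMA prefix) are not shown here.
WRAPPER_PREFIXES = ("module.",)


def strip_prefix(key: str, prefixes: Iterable[str] = WRAPPER_PREFIXES) -> str:
    """Repeatedly remove known wrapper prefixes from the front of a key."""
    # Ignore empty prefixes, which would otherwise loop forever.
    active = tuple(p for p in prefixes if p)
    changed = True
    while changed:  # handles nested wrappers such as "module.module."
        changed = False
        for p in active:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key


raw = {"module.backbone.stem.conv.weight": 0, "neck.neck1.conv.weight": 1}
print({strip_prefix(k): v for k, v in raw.items()})
# {'backbone.stem.conv.weight': 0, 'neck.neck1.conv.weight': 1}
```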

Source code in src/modern_yolonas/weights.py
def remap_state_dict(raw_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Remap super-gradients state_dict keys to our module hierarchy.

    Super-gradients wraps the model in ``CustomizableDetector`` with::

        backbone  → backbone.stem, backbone.stage1 … backbone.stage4, backbone.context_module
        neck      → neck.neck1 … neck.neck4
        heads     → heads.head1 … heads.head3

    Our ``YoloNAS`` uses the same attribute names, so the only work is
    stripping DDP/EMA prefixes.
    """
    return {_strip_prefix(k): v for k, v in raw_sd.items()}
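Because remap_state_dict builds its result with a dict comprehension, two raw keys that strip to the same name (for example a `module.`-prefixed and an unprefixed copy of the same weight) collapse into one entry, and the later one silently wins. A hypothetical check like the one below can flag such collisions before loading; find_collisions is an invented helper, and strip_module is a stand-in for _strip_prefix, whose real behavior is not shown here.

```python
from collections import defaultdict
from typing import Callable, Dict, List


def find_collisions(
    raw_keys: List[str], strip: Callable[[str], str]
) -> Dict[str, List[str]]:
    """Map each stripped key to the raw keys producing it, keeping only collisions."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for k in raw_keys:
        groups[strip(k)].append(k)
    return {new: olds for new, olds in groups.items() if len(olds) > 1}


def strip_module(key: str) -> str:
    # Stand-in for _strip_prefix (assumption): drop one leading "module.".
    return key[len("module."):] if key.startswith("module.") else key


keys = ["module.heads.head1.w", "heads.head1.w", "module.neck.neck1.w"]
print(find_collisions(keys, strip_module))
# {'heads.head1.w': ['module.heads.head1.w', 'heads.head1.w']}
```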