Source code for training.losses

"""Loss registry for supervised one-step training."""

from __future__ import annotations

from collections.abc import Callable
from typing import Protocol

import torch

class LossFn(Protocol):
    """Call signature shared by the losses registered in ``LOSS_REGISTRY``.

    Every loss in this module accepts ``(pred, target)`` plus an optional
    element-wise ``weight`` tensor and returns a tensor. A plain
    ``Callable[[Tensor, Tensor], Tensor]`` alias would hide the ``weight``
    parameter from type checkers, making weighted calls through
    ``get_loss_fn`` look like errors.
    """

    def __call__(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor | None = None,
    ) -> torch.Tensor: ...


def mse_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor | None = None,
) -> torch.Tensor:
    """Return the mean squared error between ``pred`` and ``target``.

    Args:
        pred: Predicted values.
        target: Reference values; broadcast against ``pred`` by the subtraction.
        weight: Optional multiplier applied element-wise to the squared error
            before averaging.

    Returns:
        The mean over all elements of the (optionally weighted) squared error.
        The divisor is the element count, not the sum of the weights.
    """
    squared = (pred - target).pow(2)
    weighted = squared if weight is None else squared * weight
    return weighted.mean()
def l1_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor | None = None,
) -> torch.Tensor:
    """Return the mean absolute error between ``pred`` and ``target``.

    Args:
        pred: Predicted values.
        target: Reference values; broadcast against ``pred`` by the subtraction.
        weight: Optional multiplier applied element-wise to the absolute error
            before averaging.

    Returns:
        The mean over all elements of the (optionally weighted) absolute error.
        The divisor is the element count, not the sum of the weights.
    """
    absolute = (pred - target).abs()
    if weight is None:
        return absolute.mean()
    return (absolute * weight).mean()
def relative_l2_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor | None = None,
) -> torch.Tensor:
    """Return ``||pred - target|| / ||target||`` using flattened L2 norms.

    Both norms are taken over all elements at once (``torch.linalg.norm``
    with no ``dim``), so this is a single global ratio, not a per-sample mean.

    Args:
        pred: Predicted values.
        target: Reference values; broadcast against ``pred`` by the subtraction.
        weight: Optional element-wise weights. The squared error and the
            squared target are both scaled by ``weight`` (via ``sqrt(weight)``
            before the norm). This gives
            ``sqrt(sum(w * (pred - target)**2) / sum(w * target**2))``.
            Weights are presumably non-negative; negative entries yield NaN
            from ``sqrt``.

    Returns:
        A scalar tensor. The denominator is clamped to at least ``1e-12`` to
        avoid division by zero when the (weighted) target is all zeros.
    """
    diff = pred - target
    if weight is not None:
        root_w = weight.sqrt()
        diff = diff * root_w
        # Weight the reference norm with the same weights as the error norm.
        # Otherwise masked-out entries still count toward ||target|| and bias
        # the ratio low, and a uniform weight c rescales the result by sqrt(c).
        target = target * root_w
    numerator = torch.linalg.norm(diff)
    denominator = torch.linalg.norm(target).clamp_min(1e-12)
    return numerator / denominator
# Maps each public loss name to its implementation; looked up by get_loss_fn.
LOSS_REGISTRY: dict[str, LossFn] = dict(
    mse=mse_loss,
    l1=l1_loss,
    relative_l2=relative_l2_loss,
)
def get_loss_fn(name: str) -> LossFn:
    """Return the loss function registered under ``name``.

    Args:
        name: Key into ``LOSS_REGISTRY`` (currently ``"mse"``, ``"l1"`` or
            ``"relative_l2"``).

    Returns:
        The registered loss callable.

    Raises:
        ValueError: If ``name`` is not registered. The message lists the
            registered names in sorted order.
    """
    if name in LOSS_REGISTRY:
        return LOSS_REGISTRY[name]
    known = ", ".join(sorted(LOSS_REGISTRY))
    raise ValueError(f"Unknown loss '{name}'. Available losses: {known}")