"""Optuna objective function for hyperparameter optimization."""
from collections.abc import Callable
from typing import Any
import optuna
import torch
from torch.utils.data import DataLoader
from training.hpo.search_space import apply_overrides, sample_from_search_space
from training.losses import get_loss_fn
from training.models import get_build_fn_and_adapter
from training.runner import (
build_experiment,
compute_val_loss,
set_seed,
train_one_epoch,
)
[docs]
def make_objective(
    base_cfg: dict,
    search_space: dict[str, dict],
    hpo_cfg: dict,
    prepared: dict[str, Any],
    train_inner_idx: list[int],
    val_idx: list[int],
) -> Callable[[optuna.Trial], float]:
    """Build the per-trial objective closure for an Optuna study.

    The dataset, adapter, dataset_info and device are taken from *prepared*
    once, and the train/validation subsets are built once; all of them are
    shared by every trial.  Each call of the returned closure creates a
    fresh model, optimizer, loss, experiment, pair of DataLoaders and
    (optionally) LR scheduler from the configuration sampled for that trial.

    Parameters
    ----------
    base_cfg : dict
        Full training config **without** the ``hpo`` section.
    search_space : dict
        YAML search-space definition (dot-path -> spec).
    hpo_cfg : dict
        The ``hpo`` section of the config.  Accepted for interface
        compatibility; nothing in this function reads it.
    prepared : dict
        Output of ``runner.prepare_training(base_cfg)``.  The keys
        ``"dataset"``, ``"adapter"``, ``"dataset_info"`` and ``"device"``
        are read.
    train_inner_idx : list[int]
        Case indices (into ``dataset.sim_names``) for inner training.
    val_idx : list[int]
        Case indices for validation.

    Returns
    -------
    Callable[[optuna.Trial], float]
        Objective that trains for ``training.epochs`` epochs and returns the
        validation loss of the last epoch (NaN when no epoch is run).  It
        raises ``optuna.TrialPruned`` as soon as the trial reports that it
        should be pruned.
    """
    dataset = prepared["dataset"]
    adapter = prepared["adapter"]
    dataset_info = prepared["dataset_info"]
    device = prepared["device"]

    # Pick the subsetting strategy once: the dataset's own case-based
    # subsetting when it offers one, a plain torch ``Subset`` otherwise.
    if hasattr(dataset, "subset_by_case_indices"):
        take_subset = dataset.subset_by_case_indices
    else:
        from torch.utils.data import Subset

        def take_subset(indices: list[int]) -> Any:
            return Subset(dataset, indices)

    # Built a single time and reused by every trial.
    train_ds = take_subset(train_inner_idx)
    val_ds = take_subset(val_idx)

    def objective(trial: optuna.Trial) -> float:
        # --- trial configuration -------------------------------------
        sampled = sample_from_search_space(trial, search_space)
        cfg = apply_overrides(base_cfg, sampled)
        model_section = dict(cfg.get("model") or {})
        train_section = dict(cfg.get("training") or {})

        # Seeding happens before the model is constructed
        # (presumably so weight initialisation is repeatable).
        set_seed(int(train_section.get("seed", 42)))

        # --- model, optimizer, loss ----------------------------------
        # The adapter returned here is ignored; the one from *prepared*
        # is used for every trial.
        builder, _ = get_build_fn_and_adapter(model_section)
        net = builder(
            dict(model_section.get("params") or {}), dataset_info
        ).to(device)
        optimizer = _build_optimizer(net, train_section)
        criterion = get_loss_fn(str(train_section.get("loss", "mse")))

        # --- experiment (honours the training.experiment entrypoint) --
        experiment = build_experiment(
            experiment_entrypoint=train_section.get("experiment"),
            model=net,
            optimizer=optimizer,
            loss_fn=criterion,
            adapter=adapter,
            device=device,
        )
        experiment.prepare_for_training(train_ds, val_ds, device)

        # --- data loaders --------------------------------------------
        n_epochs = int(train_section.get("epochs", 20))
        batch_size = int(train_section.get("batch_size", 4))
        workers = int(train_section.get("num_workers", 0))
        train_loader = _build_loader(
            train_ds, adapter, device, batch_size, workers, shuffle=True
        )
        val_loader = _build_loader(
            val_ds, adapter, device, batch_size, workers, shuffle=False
        )

        # --- LR scheduler (match retrain behavior) -------------------
        scheduler = _build_lr_scheduler(optimizer, train_section, n_epochs)

        # --- train, validate, report, maybe prune --------------------
        last_val = float("nan")
        for epoch in range(1, n_epochs + 1):
            train_one_epoch(experiment, train_loader)
            experiment.on_epoch_end_extra_step()
            if scheduler is not None:
                scheduler.step()
            last_val = compute_val_loss(experiment, val_loader)
            trial.report(last_val, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

        final_loss = float(last_val)
        trial.set_user_attr("val_loss", final_loss)
        return final_loss

    return objective


def _build_optimizer(model: Any, training_cfg: dict[str, Any]) -> Any:
    """Return AdamW when ``weight_decay`` is positive, plain Adam otherwise."""
    step_size = float(training_cfg.get("lr", 1e-3))
    decay = float(training_cfg.get("weight_decay", 0.0))
    if decay > 0:
        return torch.optim.AdamW(
            model.parameters(), lr=step_size, weight_decay=decay
        )
    return torch.optim.Adam(model.parameters(), lr=step_size)


def _build_loader(
    data: Any,
    adapter: Any,
    device: Any,
    batch_size: int,
    num_workers: int,
    *,
    shuffle: bool,
) -> DataLoader:
    """Wrap *data* in a DataLoader using the adapter's collate function."""
    return DataLoader(
        data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Page-locked host memory is only requested for CUDA devices.
        pin_memory=device.type == "cuda",
        collate_fn=adapter.collate_fn(),
    )


def _build_lr_scheduler(
    optimizer: Any, training_cfg: dict[str, Any], epochs: int
) -> Any:
    """Return the cosine (optionally warmed-up) scheduler, or ``None``.

    Only ``lr_scheduler == "cosine"`` yields a scheduler; any other value,
    including a missing or empty one, yields ``None``.
    """
    if str(training_cfg.get("lr_scheduler") or "") != "cosine":
        return None

    schedulers = torch.optim.lr_scheduler
    warmup = int(training_cfg.get("lr_warmup_epochs", 0))

    # Warmup is applied only when it is positive and leaves at least one
    # epoch for the cosine phase; otherwise cosine spans all epochs.
    if not 0 < warmup < epochs:
        return schedulers.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=1e-7
        )

    ramp = schedulers.LinearLR(
        optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup
    )
    anneal = schedulers.CosineAnnealingLR(
        optimizer, T_max=epochs - warmup, eta_min=1e-7
    )
    return schedulers.SequentialLR(
        optimizer, schedulers=[ramp, anneal], milestones=[warmup]
    )