Source code for training.hpo.study

"""Study creation, orchestration, and artifact saving."""

import json
import logging
import random
from pathlib import Path
from typing import Any

import optuna

from training.datasets import split_indices
from training.hpo.objective import make_objective
from training.hpo.search_space import validate_search_space
from training.hpo.visualize import save_study_plots
from training.runner import normalize_split_cfg, prepare_training, train

logger = logging.getLogger(__name__)


def _sqlite_db_path(storage):
    """Return the filesystem path of a ``sqlite:///`` storage URL, else ``None``.

    Non-string storages (e.g. a storage object) and non-SQLite URLs such as
    ``postgresql://...`` yield ``None`` so they are never treated as local files.
    """
    prefix = "sqlite:///"
    if isinstance(storage, str) and storage.startswith(prefix):
        return storage[len(prefix):]
    return None


def create_study(hpo_cfg: dict) -> optuna.Study:
    """Create or resume an Optuna study from the ``hpo`` config section.

    Parameters
    ----------
    hpo_cfg : dict
        HPO config. Keys read here:

        - ``sampler`` / ``pruner``: mappings with ``name`` (a class name looked up
          in ``optuna.samplers`` / ``optuna.pruners``; defaults ``TPESampler`` /
          ``MedianPruner``) and optional ``params`` (constructor kwargs).
          ``null`` values fall back to the defaults.
        - ``storage``: storage passed to ``optuna.create_study``. For a
          ``sqlite:///`` URL the database file's parent directory is created.
        - ``study_name`` (default ``"hpo_study"``), ``direction`` (default
          ``"minimize"``) and ``load_if_exists`` (default ``True``).

    Returns
    -------
    optuna.Study

    Raises
    ------
    AttributeError
        If the sampler/pruner ``name`` is not found in the Optuna module.
    optuna.exceptions.DuplicatedStudyError
        If the study already exists and ``load_if_exists`` is false. A
        remediation hint is printed before re-raising.
    """
    sampler_cfg = hpo_cfg.get("sampler") or {}
    sampler_cls = getattr(optuna.samplers, sampler_cfg.get("name") or "TPESampler")
    sampler = sampler_cls(**(sampler_cfg.get("params") or {}))

    pruner_cfg = hpo_cfg.get("pruner") or {}
    pruner_cls = getattr(optuna.pruners, pruner_cfg.get("name") or "MedianPruner")
    pruner = pruner_cls(**(pruner_cfg.get("params") or {}))

    storage = hpo_cfg.get("storage")
    db_path = _sqlite_db_path(storage)
    if db_path is not None:
        # Only SQLite URLs map to a local file; create its directory up front.
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    study_name = hpo_cfg.get("study_name", "hpo_study")
    load_if_exists = hpo_cfg.get("load_if_exists", True)

    try:
        return optuna.create_study(
            study_name=study_name,
            direction=hpo_cfg.get("direction", "minimize"),
            sampler=sampler,
            pruner=pruner,
            storage=storage,
            load_if_exists=load_if_exists,
        )
    except optuna.exceptions.DuplicatedStudyError:
        if db_path is not None:
            location = db_path
            delete_title = " 3) Delete the database and start over:\n"
            delete_cmd = f" rm {db_path}\n"
        else:
            # Not a local file: "rm <url>" would be meaningless, so point at the
            # Optuna CLI instead.
            location = storage if isinstance(storage, str) and storage else "<storage>"
            delete_title = " 3) Delete the study and start over:\n"
            delete_cmd = f" optuna delete-study --study-name {study_name} --storage {location}\n"
        print(
            f"\n[HPO Error] Study '{study_name}' already exists in the database.\n"
            f" This happens when hpo.load_if_exists=false but the study name is already\n"
            f" registered in: {location}\n"
            f"\n"
            f" Fix options:\n"
            f" 1) Resume the existing study (recommended):\n"
            f" hpo.load_if_exists=true\n"
            f"\n"
            f" 2) Start fresh under a new name (old study preserved):\n"
            f" hpo.study_name={study_name}_v2\n"
            f"\n"
            f"{delete_title}"
            f"{delete_cmd}"
        )
        raise
def _split_inner_train_val(
    train_idx: list[int], val_ratio: float, val_seed: int
) -> tuple[list[int], list[int]]:
    """Carve a seeded validation subset out of the outer-train indices.

    At least one case is always reserved for validation (``max(1, ...)``), so the
    effective ratio can exceed ``val_ratio`` on very small pools.

    Returns
    -------
    tuple[list[int], list[int]]
        ``(train_inner_idx, val_idx)``, each sorted ascending.

    Raises
    ------
    ValueError
        If either resulting split is empty.
    """
    shuffled = list(train_idx)
    random.Random(val_seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_ratio))
    val_idx = sorted(shuffled[:n_val])
    train_inner_idx = sorted(shuffled[n_val:])

    # Checked first: an empty validation split only happens when the outer-train
    # pool itself is empty, and this message describes that case accurately.
    if not val_idx:
        raise ValueError(
            f"Validation split is empty. Only {len(train_idx)} training case(s) "
            f"available with val_ratio={val_ratio}. Provide more data."
        )
    if not train_inner_idx:
        raise ValueError(
            f"Inner training split is empty after reserving {n_val} validation case(s) "
            f"from {len(train_idx)} training case(s). Reduce hpo.validation.split_ratio "
            f"(currently {val_ratio}) or provide more data."
        )
    return train_inner_idx, val_idx


def _write_best_config(best_config: dict, path: Path) -> None:
    """Write ``best_config`` as YAML, preferring OmegaConf and falling back to PyYAML."""
    try:
        from omegaconf import OmegaConf
    except ImportError:
        import yaml

        path.write_text(yaml.dump(best_config, default_flow_style=False), encoding="utf-8")
    else:
        OmegaConf.save(OmegaConf.create(best_config), str(path))


def run_hpo(cfg_dict: dict) -> dict[str, Any]:
    """Run hyperparameter optimization.

    The dataset is prepared once. An outer train/test split is taken from the
    data config. The outer-train pool is then split again into inner-train and
    validation subsets for the trials, so the test cases are never seen during
    the search.

    Parameters
    ----------
    cfg_dict : dict
        Full Hydra config (resolved) with ``hpo`` section + base training config.
        ``hpo`` keys read here: ``search_space``, ``validation.split_ratio``
        (default 0.2), ``validation.seed`` (default: the training seed),
        ``n_trials`` (default 50; ``null`` means unlimited), ``timeout``
        (seconds; ``null`` means none), ``show_progress_bar``, ``output_dir``
        (default ``"hpo_results"``) and ``retrain_best``. The whole section is
        also passed to ``create_study`` and ``make_objective``.

    Returns
    -------
    dict
        Summary with best trial info, study stats, saved artifacts. If no trial
        completed, only ``study_name``, ``n_trials``, ``n_complete`` and
        ``n_pruned`` are returned and no artifacts are written.

    Raises
    ------
    ValueError
        If the inner-train or validation split would be empty.
    """
    hpo_cfg = dict(cfg_dict.get("hpo") or {})
    base_cfg = {k: v for k, v in cfg_dict.items() if k != "hpo"}
    search_space = dict(hpo_cfg.get("search_space") or {})

    # --- Validate search space before doing any work ---
    validate_search_space(search_space, base_cfg)

    # --- Build dataset + adapter once (cached for all trials) ---
    prepared = prepare_training(base_cfg)
    dataset = prepared["dataset"]
    seed = prepared["seed"]

    # --- Outer train/test split (same as normal training) ---
    data_cfg = prepared["data_cfg"]
    split_cfg = normalize_split_cfg(dict(data_cfg.get("split") or {}), default_seed=seed)
    # The split and the saved split metadata are both keyed by case name, so the
    # dataset must expose ``sim_names``.
    sim_names = dataset.sim_names
    train_idx, test_idx, train_sims, test_sims = split_indices(
        num_cases=len(sim_names),
        split_cfg=split_cfg,
        sim_names=sim_names,
    )

    # --- Inner train/val split from the training pool ---
    val_cfg = hpo_cfg.get("validation") or {}
    raw_ratio = val_cfg.get("split_ratio")
    val_ratio = 0.2 if raw_ratio is None else float(raw_ratio)
    raw_val_seed = val_cfg.get("seed")
    val_seed = int(seed if raw_val_seed is None else raw_val_seed)
    train_inner_idx, val_idx = _split_inner_train_val(train_idx, val_ratio, val_seed)

    # Fit tabular normalization only on inner-train cases (prevents val/test leakage).
    if prepared["adapter_name"] == "pointwise" and bool(data_cfg.get("normalize", False)):
        normalized_data_cfg = dict(data_cfg)
        normalized_data_cfg["norm_from_case_indices"] = train_inner_idx
        normalized_dataset = prepared["adapter"].build_dataset(normalized_data_cfg)
        # Shallow copy so the dict returned by prepare_training is not mutated.
        prepared = dict(prepared)
        prepared["dataset"] = normalized_dataset
        prepared["dataset_info"] = prepared["adapter"].dataset_info(normalized_dataset)
        dataset = normalized_dataset

    names = dataset.sim_names
    val_sims = [names[i] for i in val_idx]
    train_inner_sims = [names[i] for i in train_inner_idx]

    logger.info(
        "HPO splits: %d train_inner, %d val, %d test (held out)",
        len(train_inner_idx),
        len(val_idx),
        len(test_idx),
    )

    # --- Create study and objective ---
    study = create_study(hpo_cfg)
    objective = make_objective(
        base_cfg=base_cfg,
        search_space=search_space,
        hpo_cfg=hpo_cfg,
        prepared=prepared,
        train_inner_idx=train_inner_idx,
        val_idx=val_idx,
    )

    # --- Optimize ---
    # ``null`` is honoured for both limits, matching Study.optimize semantics.
    raw_n_trials = hpo_cfg.get("n_trials", 50)
    n_trials = None if raw_n_trials is None else int(raw_n_trials)
    raw_timeout = hpo_cfg.get("timeout")
    timeout = None if raw_timeout is None else float(raw_timeout)
    if n_trials is None and timeout is None:
        logger.warning(
            "hpo.n_trials and hpo.timeout are both unset; optimization runs until interrupted."
        )
    show_progress_bar = bool(hpo_cfg.get("show_progress_bar", True))
    study.optimize(
        objective,
        n_trials=n_trials,
        timeout=timeout,
        show_progress_bar=show_progress_bar,
    )

    # --- Results ---
    # Fetch once without deep-copying; the trials are only read below.
    all_trials = study.get_trials(deepcopy=False)
    completed = [t for t in all_trials if t.state == optuna.trial.TrialState.COMPLETE]
    pruned = [t for t in all_trials if t.state == optuna.trial.TrialState.PRUNED]

    if not completed:
        logger.warning("No trials completed successfully.")
        return {
            "study_name": study.study_name,
            "n_trials": len(all_trials),
            "n_complete": 0,
            "n_pruned": len(pruned),
        }

    best = study.best_trial

    # --- Save artifacts ---
    output_dir = Path(hpo_cfg.get("output_dir", "hpo_results"))
    output_dir.mkdir(parents=True, exist_ok=True)

    # best_params.json
    params_path = output_dir / "best_params.json"
    params_path.write_text(json.dumps(best.params, indent=2), encoding="utf-8")

    # best_config.yaml -- train-ready config (no hpo section)
    from training.hpo.search_space import apply_overrides

    best_config = apply_overrides(base_cfg, best.params)
    _write_best_config(best_config, output_dir / "best_config.yaml")

    # split_metadata.json
    split_meta = {
        "outer_train_sims": train_sims,
        "outer_test_sims": test_sims,
        "inner_train_sims": train_inner_sims,
        "inner_val_sims": val_sims,
    }
    split_path = output_dir / "split_metadata.json"
    split_path.write_text(json.dumps(split_meta, indent=2), encoding="utf-8")

    # Visualization (non-fatal)
    plot_files = save_study_plots(study, output_dir)

    results = {
        "study_name": study.study_name,
        "best_trial_number": best.number,
        "best_value": best.value,
        "best_params": best.params,
        "n_trials": len(all_trials),
        "n_complete": len(completed),
        "n_pruned": len(pruned),
        "output_dir": str(output_dir),
        "plots": plot_files,
    }

    # --- Optional: retrain best on the original outer train split ---
    if hpo_cfg.get("retrain_best", False):
        logger.info("Retraining best config on the original outer train split...")
        retrain_result = train(best_config)
        results["retrain"] = retrain_result

    return results