Source code for training.hpo.search_space

"""Parse YAML search space definitions into Optuna trial suggestions."""

import copy
from typing import Any

import optuna

# Dot-path prefixes that must never be overridden during HPO because they
# change the dataset or model identity rather than tuning hyperparameters.
# ``validate_search_space`` rejects any search-space key that equals one of
# these entries or extends it with a further ".<segment>" suffix.
UNSAFE_PREFIXES = (
    "data.zarr_dir",
    "data.input_columns",
    "data.output_columns",
    "data.split",
    "model.name",
    "model.entrypoint",
    "model.adapter",
)


def validate_search_space(search_space: dict[str, dict], base_cfg: dict) -> None:
    """Check that all search-space keys are safe and exist in *base_cfg*.

    Only the keys (dot-paths) of *search_space* are inspected; the spec dicts
    are not examined here.

    Args:
        search_space: Mapping of dot-path config keys to sampling specs.
        base_cfg: Nested base configuration the dot-paths are resolved against.

    Raises:
        ValueError: If a dot-path equals an entry of ``UNSAFE_PREFIXES``, lies
            beneath one, or is an ancestor of one (overriding an ancestor
            would replace the protected key along with its whole subtree).
        KeyError: If any segment of a dot-path does not exist in *base_cfg*
            (catches typos).
    """
    for dot_path in search_space:
        for prefix in UNSAFE_PREFIXES:
            touches_protected = (
                dot_path == prefix
                or dot_path.startswith(prefix + ".")
                # An ancestor such as "data" would overwrite the subtree that
                # contains the protected key, so it is just as unsafe.
                or prefix.startswith(dot_path + ".")
            )
            if touches_protected:
                raise ValueError(
                    f"Search-space key '{dot_path}' is not allowed. "
                    f"Overriding '{prefix}' changes the dataset or model identity. "
                    "Only training.* and model.params.* paths are supported."
                )

        # Walk the base config one segment at a time so the error can name
        # the first segment that is missing.
        keys = dot_path.split(".")
        node = base_cfg
        for i, key in enumerate(keys):
            if not isinstance(node, dict) or key not in node:
                partial = ".".join(keys[: i + 1])
                raise KeyError(
                    f"Search-space path '{dot_path}' is invalid: "
                    f"'{partial}' does not exist in the base config. "
                    "Check for typos."
                )
            node = node[key]
def sample_from_search_space(
    trial: optuna.Trial,
    search_space: dict[str, dict],
) -> dict[str, Any]:
    """Draw one value per search-space entry from *trial*.

    *search_space* maps a dot-path config key to a spec dict, e.g.::

        training.lr:
          type: float
          low: 1e-5
          high: 1e-2
          log: true

    The ``type`` field selects the sampler and must be ``float``, ``int`` or
    ``categorical``. Numeric specs read ``low``, ``high`` and an optional
    ``log`` flag (default ``False``); categorical specs read ``choices``.

    Returns a dict mapping each dot-path key to its sampled value.

    Raises ``ValueError`` when ``type`` is not one of the supported names.
    """
    suggestions: dict[str, Any] = {}

    for name, definition in search_space.items():
        kind = definition["type"]

        if kind == "float":
            suggest, coerce = trial.suggest_float, float
        elif kind == "int":
            suggest, coerce = trial.suggest_int, int
        elif kind == "categorical":
            suggestions[name] = trial.suggest_categorical(name, definition["choices"])
            continue
        else:
            raise ValueError(
                f"Unknown search-space type '{kind}' for '{name}'. "
                "Must be 'float', 'int', or 'categorical'."
            )

        # Bounds are coerced explicitly so that values which arrive as
        # strings (or as the other numeric type) are still accepted.
        suggestions[name] = suggest(
            name,
            low=coerce(definition["low"]),
            high=coerce(definition["high"]),
            log=bool(definition.get("log", False)),
        )

    return suggestions
def apply_overrides(base_cfg: dict, overrides: dict[str, Any]) -> dict:
    """Deep-copy *base_cfg* and set values at dot-paths.

    Args:
        base_cfg: Nested configuration dict; it is never mutated.
        overrides: Mapping of dot-path keys to the values to store there.

    Returns:
        A deep copy of *base_cfg* with every override applied.

    Raises:
        KeyError: If any intermediate or leaf key does not already exist in
            the base config, or if a path tries to descend into a value that
            is not a dict. This prevents silent creation of new keys from
            typos.
    """
    cfg = copy.deepcopy(base_cfg)

    for dot_path, value in overrides.items():
        *parents, leaf = dot_path.split(".")

        node = cfg
        for depth, key in enumerate(parents):
            if not isinstance(node, dict) or key not in node:
                partial = ".".join(parents[: depth + 1])
                raise KeyError(
                    f"Cannot apply override '{dot_path}': '{partial}' does not exist in config."
                )
            node = node[key]

        # The parent of the leaf must itself be a dict; without this guard a
        # scalar, string or list parent would surface as a TypeError (or a
        # misleading substring/element membership test) instead of KeyError.
        if not isinstance(node, dict) or leaf not in node:
            parent_path = ".".join(parents)
            raise KeyError(
                f"Cannot apply override '{dot_path}': "
                f"key '{leaf}' does not exist in config at '{parent_path}'."
            )
        node[leaf] = value

    return cfg