Source code for training.models

"""Model registry and entrypoint resolution for training workflows."""

import importlib
import inspect
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class ModelEntry:
    """One registry record: a model's build callable and the adapter name paired with it."""

    # Callable that constructs the model, stored exactly as passed to register_model.
    build_fn: Callable
    # Adapter name; register_model only stores names listed in _VALID_ADAPTERS.
    adapter: str
# Built-in models keyed by name; filled in by register_model.
MODEL_REGISTRY: dict[str, ModelEntry] = {}

# Adapter names a model may declare. Kept as a tuple in this exact order
# because its repr is interpolated into this module's validation errors.
_VALID_ADAPTERS = (
    "grid",
    "graph",
    "pointwise",
    "profile",
)
def register_model(name: str, build_fn: Callable, adapter: str) -> None:
    """Add a built-in model to ``MODEL_REGISTRY``.

    Args:
        name: Registry key; ``get_build_fn_and_adapter`` looks models up by
            ``model.name`` under this key.
        build_fn: Callable that constructs the model.
        adapter: Adapter name; must be one of ``_VALID_ADAPTERS``.

    Raises:
        ValueError: If ``adapter`` is not a valid adapter name, or ``name`` is
            already registered.
        TypeError: If ``build_fn`` is not callable.
    """
    if adapter not in _VALID_ADAPTERS:
        raise ValueError(f"adapter must be one of {_VALID_ADAPTERS}, got '{adapter}'.")
    if name in MODEL_REGISTRY:
        raise ValueError(f"Model '{name}' is already registered.")
    # Same guarantee resolve_entrypoint gives custom entrypoints: reject a
    # non-callable at registration rather than when training calls it.
    if not callable(build_fn):
        raise TypeError(f"build_fn for model '{name}' is not callable.")
    MODEL_REGISTRY[name] = ModelEntry(build_fn=build_fn, adapter=adapter)
def resolve_entrypoint(entrypoint: str) -> Callable:
    """Import and return the callable named by a ``module.path:callable`` string.

    The part after the colon may be a dotted attribute path (for example
    ``pkg.models:Factory.build``); it is resolved one attribute at a time
    starting from the imported module.

    Args:
        entrypoint: String of the form ``"module.path:callable"``.

    Returns:
        The resolved callable object.

    Raises:
        ValueError: If there is no colon, the module path is empty, or any
            attribute segment after the colon is empty.
        ImportError: Propagated from ``importlib.import_module`` when the
            module cannot be imported.
        AttributeError: If the attribute path does not exist on the module.
        TypeError: If the resolved object is not callable.
    """
    module_path, sep, callable_name = entrypoint.rpartition(":")
    attr_parts = callable_name.split(".")
    # Reject "mod:", ":fn" and "mod:a..b" up front; otherwise they surface as
    # an opaque "Empty module name" error or a lookup of the empty attribute.
    if not sep or not module_path or not all(attr_parts):
        raise ValueError(
            f"Invalid entrypoint '{entrypoint}'. Expected format 'module.path:callable'."
        )
    module = importlib.import_module(module_path)
    target = module
    for part in attr_parts:
        try:
            target = getattr(target, part)
        except AttributeError:
            raise AttributeError(
                f"Entrypoint callable '{callable_name}' not found in module '{module_path}'."
            ) from None
    if not callable(target):
        raise TypeError(f"Entrypoint '{entrypoint}' is not callable.")
    return target
def _validate_build_signature(build_fn: Callable, entrypoint: str) -> None: """Ensure a build function can accept `(model_cfg, dataset_info)` arguments.""" signature = inspect.signature(build_fn) parameters = list(signature.parameters.values()) positional_count = sum( 1 for param in parameters if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) ) has_varargs = any(param.kind == inspect.Parameter.VAR_POSITIONAL for param in parameters) if positional_count < 2 and not has_varargs: raise TypeError( f"Model build function '{entrypoint}' must accept (model_cfg, dataset_info), " f"but has signature: {signature}" )
def get_build_fn_and_adapter(model_cfg: dict) -> tuple[Callable, str]:
    """Resolve model build function and adapter.

    Built-in models derive adapter from registry. Custom entrypoints must
    specify `model.adapter` explicitly.

    * If ``model.entrypoint`` is set, ``model.adapter`` is validated first,
      then the callable is imported via ``resolve_entrypoint`` and its
      signature checked.
    * Otherwise ``model.name`` is looked up in ``MODEL_REGISTRY``; an optional
      ``model.adapter`` must equal the registered adapter.

    Args:
        model_cfg: The model config section, read with ``.get``.

    Returns:
        ``(build_fn, adapter_name)``.

    Raises:
        ValueError: Missing/invalid adapter, missing or unknown model name,
            adapter conflicting with the registry, or a malformed entrypoint.
        ImportError, AttributeError, TypeError: From resolving or validating a
            custom entrypoint.
    """
    entrypoint = model_cfg.get("entrypoint")
    if entrypoint:
        entrypoint_str = str(entrypoint)
        adapter_name = model_cfg.get("adapter")
        # Validate the adapter before importing user code: it is a pure config
        # check and should not wait on, or be masked by, the import.
        if not adapter_name:
            raise ValueError(
                "model.adapter is required when using model.entrypoint. "
                f"Set model.adapter to one of {_VALID_ADAPTERS}."
            )
        if adapter_name not in _VALID_ADAPTERS:
            raise ValueError(
                f"model.adapter must be one of {_VALID_ADAPTERS}, got '{adapter_name}'."
            )
        build_fn = resolve_entrypoint(entrypoint_str)
        _validate_build_signature(build_fn, entrypoint_str)
        return build_fn, adapter_name

    name = model_cfg.get("name")
    if not name:
        raise ValueError("model.name is required when model.entrypoint is not set.")
    if name not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{name}'. Registered models: {sorted(MODEL_REGISTRY)}")
    entry = MODEL_REGISTRY[name]
    user_adapter = model_cfg.get("adapter")
    # An explicit adapter on a built-in model is allowed only as a redundant
    # restatement of the registered one.
    if user_adapter and user_adapter != entry.adapter:
        raise ValueError(
            f"model.adapter='{user_adapter}' conflicts with registered adapter "
            f"'{entry.adapter}' for model '{name}'. Remove model.adapter or fix the mismatch."
        )
    return entry.build_fn, entry.adapter
def model_entrypoint_string(model_cfg: dict, build_fn: Callable) -> str:
    """Return a ``module:callable`` string identifying the model build function.

    An explicit ``model.entrypoint`` in the config is returned as-is.
    Otherwise the string is built from ``build_fn``'s ``__module__`` and
    ``__qualname__``, so methods are named by their full attribute path
    (``Factory.build``) rather than the bare ``build``.

    Args:
        model_cfg: The model config section, read with ``.get``.
        build_fn: The resolved build callable.

    Returns:
        The entrypoint string.
    """
    entrypoint = model_cfg.get("entrypoint")
    if entrypoint:
        return str(entrypoint)
    # functools.partial objects and callable instances carry no __qualname__ /
    # __name__ of their own, and some builtins report __module__ as None; fall
    # back to the callable's type so this never raises (the result then names
    # the class rather than the specific instance).
    module = getattr(build_fn, "__module__", None) or type(build_fn).__module__
    name = (
        getattr(build_fn, "__qualname__", None)
        or getattr(build_fn, "__name__", None)
        or type(build_fn).__qualname__
    )
    return f"{module}:{name}"
def _load_builtins() -> None:
    """Import every built-in model module for its import-time side effects.

    Return values are discarded; presumably each module calls
    ``register_model`` when imported (TODO confirm), which is why this runs
    at the end of the module, after the registry helpers are defined.
    """
    modules_to_import = [
        "training.models.afno",
        "training.models.fno",
        "training.models.meshgraphnet",
        "training.models.mlp",
        "training.models.pix2pix",
        "training.models.conv1d_profile",
    ]
    for dotted_name in modules_to_import:
        importlib.import_module(dotted_name)


_load_builtins()