Source code for cases.alpha_d.metrics

"""Alpha-D-specific extended evaluation metrics.

The generic runner calls :meth:`Experiment.compute_extended_metrics`, which
``AlphaDExperiment`` (in ``cases/alpha_d/experiment.py``) overrides to invoke
the helpers in this module. Base experiments inherit a no-op and skip this
module entirely, so the runner stays alpha-D-agnostic.

Three functions live here:

- :func:`compute_pointwise_extended_metrics` — R², physical-space, per-region,
  per-case error breakdown for the TabularPairDataset adapter.
- :func:`compute_delta_p_metrics` — trapezoidal integration of the predicted
  α_D profile per case and comparison vs the ground-truth ``delta_p_case``.
- :func:`print_extended_metrics` — human-readable summary used by ``evaluate()``.
"""

from __future__ import annotations

import math
from typing import Any

import numpy as np
import torch

from cases.alpha_d.physics.targets import (
    alpha_d_values_to_bulk,
    field_values_to_physical,
    is_alpha_d_target,
)


def compute_delta_p_metrics(
    model: Any,
    eval_dataset,
    device: torch.device,
    *,
    alpha_d_target_name: str = "log_alpha_D",
    local_velocity_normalization: bool = False,
) -> dict[str, Any]:
    """Per-case Δp prediction error statistics.

    Integrates the predicted α_D profile via the trapezoidal rule to obtain
    ``delta_p_pred``, then compares with ``delta_p_case`` stored in the zarr
    metadata.

    Per-case geometry constants (``D_big``, ``outer_height_m``,
    ``buffer_diams``, ``rho``, ``V_bulk``) are read from each case's metadata
    and fall back to the historical AlphaD-ETL defaults when missing. ``Lr``
    has no fallback and must be present in every evaluated case's metadata.

    Args:
        model: Callable network; switched to eval mode and run under
            ``torch.no_grad()``.
        eval_dataset: Dataset exposing ``_case_meta``, ``_case_ids_unique``,
            ``_row_case_idx``, ``_raw_z_hat``, ``_raw_d_local_over_D``, ``_x``
            and ``add_baseline_to_encoded``. If it also exposes
            ``_case_slices`` the model is fed one ``[1, C, S]`` tensor per
            case instead of pointwise ``[S, C]`` rows.
        device: Device the per-case inputs are moved to before the forward
            pass.
        alpha_d_target_name: Name of the α_D target the model predicts.
        local_velocity_normalization: Forwarded to
            ``alpha_d_values_to_bulk``.

    Returns:
        ``{}`` when the target is not an α_D target, when the dataset lacks
        any of the required attributes, or when no case could be evaluated.
        Otherwise a dict with ``n_cases``, relative-error statistics (median,
        mean, p90, max), log-abs-error statistics (mean, median), the ten
        worst and ten best cases, and the full ``per_case`` list sorted by
        descending relative error.

    Raises:
        KeyError: If an evaluated case's metadata has no ``"Lr"`` entry.
    """
    if not is_alpha_d_target(alpha_d_target_name):
        return {}

    case_meta = getattr(eval_dataset, "_case_meta", None)
    case_names = getattr(eval_dataset, "_case_ids_unique", None)
    row_case_idx = getattr(eval_dataset, "_row_case_idx", None)
    raw_z_hat = getattr(eval_dataset, "_raw_z_hat", None)
    raw_d_over_D = getattr(eval_dataset, "_raw_d_local_over_D", None)
    required = (case_meta, case_names, row_case_idx, raw_z_hat, raw_d_over_D)
    if any(item is None for item in required):
        return {}

    per_case: list[dict[str, Any]] = []
    model.eval()

    # Profile-dataset wrappers expose _case_slices (per-case row indices
    # into the inner TabularPair, sorted by z_hat). When present, feed the
    # Conv1D model its expected [1, C, S] layout instead of the pointwise
    # [S, C] forward used by the MLP path.
    case_slices = getattr(eval_dataset, "_case_slices", None)
    profile_mode = case_slices is not None

    with torch.no_grad():
        for ci, case_name in enumerate(case_names):
            cm = case_meta[ci]

            # Skip cases without a usable ground truth. The isfinite check
            # also rejects NaN/inf, which would otherwise slip past ``<= 0``
            # and poison every aggregate statistic below.
            delta_p_gt = float(cm.get("delta_p_case", 0.0))
            if not math.isfinite(delta_p_gt) or delta_p_gt <= 0:
                continue

            D_big = float(cm.get("D_big", 0.2))
            outer_height_m = float(cm.get("outer_height_m", 1.0))
            buffer_diams = float(cm.get("buffer_diams", 1.0))
            rho = float(cm.get("rho", 1.0))
            V_bulk = float(cm.get("V_bulk", 1.0))

            # ``selector`` picks this case's rows: explicit indices in
            # profile mode, a boolean row mask otherwise.
            if profile_mode:
                selector = torch.as_tensor(case_slices[ci], dtype=torch.long)
            else:
                selector = row_case_idx == ci

            z_hat = raw_z_hat[selector].to(device)
            # A trapezoidal integral needs at least two stations; with fewer
            # it would silently evaluate to 0 and be reported as a 100% miss.
            if z_hat.numel() < 2:
                continue
            d_over_D = raw_d_over_D[selector].to(device)

            x_rows = eval_dataset._x[selector]
            if profile_mode:
                x_case = x_rows.T.unsqueeze(0).to(device)  # [1, C, S]
                pred_values = model(x_case)[0, 0]  # [n_stations]
            else:
                pred_values = model(x_rows.to(device)).squeeze(-1)  # [n_stations]
            pred_values = eval_dataset.add_baseline_to_encoded(
                pred_values,
                row_mask=selector,
                field_idx=0,
            )

            alpha_D_bulk = alpha_d_values_to_bulk(
                pred_values,
                target_name=alpha_d_target_name,
                d_over_D=d_over_D,
                local_velocity_normalization=local_velocity_normalization,
            )

            # dp/dz = alpha_D_bulk * rho * V_bulk^2 / (2 * D_h)
            D_h = d_over_D * D_big
            dp_dz = alpha_D_bulk * rho * V_bulk**2 / (2.0 * D_h)

            # Trapezoidal integration over physical z. The rule is only
            # meaningful for stations ordered along z, so order them by z_hat
            # first; the stable sort is a no-op for already-sorted rows.
            # NOTE(review): assumes z_hat and dp_dz are 1-D with one entry
            # per station — confirm against the dataset.
            L_roi = cm["Lr"] * outer_height_m + 2.0 * buffer_diams * D_big
            z_sorted, order = torch.sort(z_hat, stable=True)
            z_physical = z_sorted * L_roi
            delta_p_pred = float(torch.trapezoid(dp_dz[order], z_physical).cpu())

            rel_err = abs(delta_p_pred - delta_p_gt) / abs(delta_p_gt)
            # Clamp before the log so a non-positive prediction yields a
            # large finite error rather than a math domain error.
            log_err = abs(
                math.log(max(delta_p_pred, 1e-8)) - math.log(max(delta_p_gt, 1e-8))
            )

            per_case.append(
                {
                    "case": case_name,
                    "delta_p_gt": delta_p_gt,
                    "delta_p_pred": delta_p_pred,
                    "relative_error": rel_err,
                    "log_abs_error": log_err,
                    "Dr": float(cm.get("Dr", 0.0)),
                    "Re": float(cm.get("Re", 0.0)),
                }
            )

    if not per_case:
        return {}

    rel_errors = [c["relative_error"] for c in per_case]
    log_errors = [c["log_abs_error"] for c in per_case]
    # Worst first, so the head of the list is the "worst" slice and the
    # reversed tail is the "best" slice.
    per_case.sort(key=lambda c: c["relative_error"], reverse=True)

    return {
        "n_cases": len(per_case),
        "relative_error_median": float(np.median(rel_errors)),
        "relative_error_mean": float(np.mean(rel_errors)),
        "relative_error_p90": float(np.quantile(rel_errors, 0.9)),
        "relative_error_max": float(max(rel_errors)),
        "log_abs_error_mean": float(np.mean(log_errors)),
        "log_abs_error_median": float(np.median(log_errors)),
        "worst_cases": per_case[:10],
        "best_cases": list(reversed(per_case[-10:])),
        "per_case": per_case,
    }
def compute_pointwise_extended_metrics(
    preds: torch.Tensor,
    targets: torch.Tensor,
    dataset,
    output_fields: list[str],
) -> dict[str, Any]:
    """R², physical-space, per-region, and per-case metrics.

    Only meaningful for the pointwise adapter with TabularPairDataset.

    Args:
        preds: Predictions; column ``i`` belongs to ``output_fields[i]``.
        targets: Ground truth laid out like ``preds``.
        dataset: Dataset providing ``add_baseline_to_encoded`` and,
            optionally, ``_raw_d_local_over_D``,
            ``local_velocity_normalization``, ``input_columns``, ``_x``,
            ``normalize``, ``norm_stats``, ``_row_case_idx`` and
            ``_case_ids_unique``. Missing optional attributes disable the
            corresponding section of the result.
        output_fields: Names of the predicted fields, in column order.

    Returns:
        A dict that always contains ``per_field``. It also contains
        ``per_region`` when any of the ``is_upstream`` / ``is_throat`` /
        ``is_downstream`` input columns exist, and ``worst_cases`` /
        ``best_cases`` (ranked by RMSE over every case/field pair) when the
        dataset exposes case indices and names.
    """
    raw_d_over_D = getattr(dataset, "_raw_d_local_over_D", None)
    local_vel_norm = bool(getattr(dataset, "local_velocity_normalization", False))

    def _has_physical_inverse(name: str) -> bool:
        """Return True when *name* can be mapped back to physical units."""
        return is_alpha_d_target(name) or name.startswith(("log_", "log10_"))

    def _r2(p: torch.Tensor, t: torch.Tensor) -> float:
        """Coefficient of determination; NaN when the target has no variance."""
        ss_res = float(((p - t) ** 2).sum())
        ss_tot = float(((t - t.mean()) ** 2).sum())
        return 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")

    def _rmse(p: torch.Tensor, t: torch.Tensor) -> float:
        """Root-mean-square error in encoded space."""
        return float(((p - t) ** 2).mean().sqrt())

    def _physical_pair(
        p: torch.Tensor,
        t: torch.Tensor,
        field_idx: int,
        name: str,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Re-add the baseline and convert *p* / *t* to physical values."""
        d_over_D = None
        if raw_d_over_D is not None and is_alpha_d_target(name):
            d_over_D = raw_d_over_D if mask is None else raw_d_over_D[mask]
        # Only pass row_mask when a subset is selected, so the full-dataset
        # call keeps using the dataset's own default for that argument.
        baseline_kwargs: dict[str, Any] = {"field_idx": field_idx}
        if mask is not None:
            baseline_kwargs["row_mask"] = mask
        p_full = dataset.add_baseline_to_encoded(p, **baseline_kwargs)
        t_full = dataset.add_baseline_to_encoded(t, **baseline_kwargs)
        p_phys = field_values_to_physical(
            p_full,
            field_name=name,
            d_over_D=d_over_D,
            local_velocity_normalization=local_vel_norm,
        )
        t_phys = field_values_to_physical(
            t_full,
            field_name=name,
            d_over_D=d_over_D,
            local_velocity_normalization=local_vel_norm,
        )
        return p_phys, t_phys

    def _relative_error(p_phys: torch.Tensor, t_phys: torch.Tensor) -> torch.Tensor:
        """Element-wise |p - t| / |t|, with the denominator clamped at 1e-8."""
        return ((p_phys - t_phys) / t_phys.abs().clamp(min=1e-8)).abs()

    metrics: dict[str, Any] = {}

    # ---- Per-field metrics over the whole dataset -------------------------
    per_field_extended: list[dict[str, Any]] = []
    for i, name in enumerate(output_fields):
        p, t = preds[:, i], targets[:, i]
        entry: dict[str, Any] = {
            "name": name,
            "r2": _r2(p, t),
            "mae": float((p - t).abs().mean()),
        }
        if _has_physical_inverse(name):
            p_phys, t_phys = _physical_pair(p, t, i, name)
            rel_err = _relative_error(p_phys, t_phys)
            entry["physical_median_relative_error"] = float(rel_err.median())
            entry["physical_p90_relative_error"] = float(rel_err.quantile(0.9))
            log_abs_err = (
                p_phys.abs().clamp(min=1e-8).log()
                - t_phys.abs().clamp(min=1e-8).log()
            ).abs()
            entry["physical_log_abs_error_median"] = float(log_abs_err.median())
            entry["physical_log_abs_error_mean"] = float(log_abs_err.mean())
            entry["physical_log_abs_error_p90"] = float(log_abs_err.quantile(0.9))
        per_field_extended.append(entry)
    metrics["per_field"] = per_field_extended

    # ---- Per-region metrics (rows flagged by indicator input columns) -----
    input_columns = getattr(dataset, "input_columns", [])
    region_col_indices: dict[str, int] = {
        col_name: input_columns.index(col_name)
        for col_name in ("is_upstream", "is_throat", "is_downstream")
        if col_name in input_columns
    }
    if region_col_indices:
        # Read-only use, so no clone is needed; de-normalisation below
        # allocates a fresh tensor anyway.
        raw_x = dataset._x.cpu()
        norm_stats = getattr(dataset, "norm_stats", None)
        if getattr(dataset, "normalize", False) and norm_stats is not None:
            # Undo input normalisation so the 0/1 flags can be thresholded.
            raw_x = raw_x * norm_stats["x_std"].cpu() + norm_stats["x_mean"].cpu()

        per_region: dict[str, Any] = {}
        for region_name, col_idx in region_col_indices.items():
            mask = raw_x[:, col_idx] > 0.5
            n_region = int(mask.sum())
            if n_region == 0:
                continue
            region_entry: dict[str, Any] = {"n_samples": n_region}
            for i, field_name in enumerate(output_fields):
                p, t = preds[mask, i], targets[mask, i]
                field_metrics: dict[str, Any] = {"r2": _r2(p, t), "rmse": _rmse(p, t)}
                if _has_physical_inverse(field_name):
                    rel_err = _relative_error(
                        *_physical_pair(p, t, i, field_name, mask)
                    )
                    field_metrics["median_relative_error"] = float(rel_err.median())
                region_entry[field_name] = field_metrics
            per_region[region_name] = region_entry
        metrics["per_region"] = per_region

    # ---- Per-case metrics, ranked by RMSE ---------------------------------
    case_idx_arr = getattr(dataset, "_row_case_idx", None)
    case_names = getattr(dataset, "_case_ids_unique", None)
    if case_idx_arr is not None and case_names is not None:
        per_case: list[dict[str, Any]] = []
        # as_tensor accepts NumPy arrays, tensors and sequences alike;
        # from_numpy would raise TypeError for anything but an ndarray.
        case_idx_t = torch.as_tensor(case_idx_arr).cpu()
        for ci, case_name in enumerate(case_names):
            mask = case_idx_t == ci
            if not bool(mask.any()):
                continue
            for i, field_name in enumerate(output_fields):
                p, t = preds[mask, i], targets[mask, i]
                entry = {"case": case_name, "field": field_name, "rmse": _rmse(p, t)}
                if _has_physical_inverse(field_name):
                    rel_err = _relative_error(
                        *_physical_pair(p, t, i, field_name, mask)
                    )
                    entry["median_relative_error"] = float(rel_err.median())
                per_case.append(entry)
        per_case.sort(key=lambda item: item["rmse"], reverse=True)
        metrics["worst_cases"] = per_case[:10]
        metrics["best_cases"] = list(reversed(per_case[-10:]))

    return metrics