# Source code for training.hpo.visualize
"""Non-fatal Optuna visualization helpers."""
import logging
from pathlib import Path
# Module-level logger; the helpers below report recoverable failures through
# logger.warning(...) instead of raising.
logger = logging.getLogger(__name__)
# [docs]
def _resolve_figure(plot_result):
    """Return the matplotlib figure behind whatever a plot function returned.

    Handles a single Axes (has ``.figure``), an array of Axes (has ``.flat``;
    the figure is taken from the first element) and, as a fallback, returns
    the object unchanged so a Figure passed in directly still works.
    """
    if hasattr(plot_result, "figure"):
        return plot_result.figure
    flat = getattr(plot_result, "flat", None)
    if flat is not None:
        # Multi-panel plots hand back an array of Axes sharing one figure.
        first = next(iter(flat), None)
        if first is not None and hasattr(first, "figure"):
            return first.figure
    return plot_result


def _render_plot(plot_fn, study, path: Path) -> None:
    """Render ``plot_fn(study)`` to ``path`` as PNG and release the figure.

    Any exception is propagated to the caller; the figure is closed either
    way so a failed save does not leave it registered in pyplot.
    """
    import matplotlib.pyplot as plt
    from matplotlib.figure import Figure

    fig = _resolve_figure(plot_fn(study))
    try:
        fig.savefig(str(path), dpi=150, bbox_inches="tight")
    finally:
        # Only close real figures: plt.close(None) would close the *current*
        # figure, which may belong to the caller.
        if isinstance(fig, Figure):
            plt.close(fig)


def save_study_plots(study, output_dir: str | Path) -> list[str]:
    """Generate and save standard Optuna plots plus a trials CSV.

    Plot and CSV failures are logged as warnings and never raised.

    Args:
        study: Object passed to each Optuna plot function; it must also
            provide ``trials_dataframe()`` for the CSV export.
        output_dir: Directory to write into; created (with parents) if it
            does not exist.

    Returns:
        Paths (as strings) of every file that was written successfully, in
        the order: optimization_history.png, param_importances.png,
        parallel_coordinate.png, slice_plot.png, trials.csv. Empty when
        ``optuna.visualization.matplotlib`` cannot be imported (in that case
        the CSV export is skipped as well) or when every step failed.

    Raises:
        OSError: If ``output_dir`` cannot be created; directory creation is
            the one step that is not guarded.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    saved: list[str] = []

    try:
        from optuna.visualization.matplotlib import (
            plot_optimization_history,
            plot_parallel_coordinate,
            plot_param_importances,
            plot_slice,
        )
    except ImportError:
        logger.warning(
            "optuna.visualization.matplotlib not available. Install matplotlib for HPO plots."
        )
        return saved

    plots = [
        ("optimization_history", plot_optimization_history),
        ("param_importances", plot_param_importances),
        ("parallel_coordinate", plot_parallel_coordinate),
        ("slice_plot", plot_slice),
    ]
    for name, plot_fn in plots:
        path = out / f"{name}.png"
        try:
            _render_plot(plot_fn, study, path)
        except Exception as exc:
            # Best effort: one failing plot must not prevent the others.
            logger.warning("Could not generate %s: %s", name, exc)
        else:
            saved.append(str(path))

    # Export trials CSV
    try:
        df = study.trials_dataframe()
        csv_path = out / "trials.csv"
        df.to_csv(str(csv_path), index=False)
    except Exception as exc:
        logger.warning("Could not export trials CSV: %s", exc)
    else:
        saved.append(str(csv_path))

    return saved