Source code for training.split_io
"""Helpers for exporting reusable train/test split files."""
import json
from pathlib import Path
from typing import Any
def _clean_sim_names(sim_names: list[str] | tuple[str, ...]) -> list[str]:
cleaned = [str(name).strip().removesuffix(".zarr") for name in sim_names if str(name).strip()]
if not cleaned:
raise ValueError("Expected at least one simulation name.")
return cleaned
[docs]
def write_sim_name_list(path: str | Path, sim_names: list[str] | tuple[str, ...]) -> Path:
    """Write the cleaned simulation names to *path*, one name per line.

    The names are passed through ``_clean_sim_names`` before anything is
    written. The target path has ``~`` expanded and is resolved to an
    absolute path; missing parent directories are created. The file is
    written as UTF-8 and ends with a trailing newline.

    Returns:
        The resolved path of the file that was written.
    """
    # Clean first so invalid input raises before the filesystem is touched.
    names = _clean_sim_names(sim_names)

    destination = Path(path).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)

    text = "\n".join(names)
    destination.write_text(text + "\n", encoding="utf-8")
    return destination
[docs]
def export_split_files(
    train_sims: list[str] | tuple[str, ...],
    test_sims: list[str] | tuple[str, ...],
    output_dir: str | Path,
    *,
    train_filename: str = "train.txt",
    test_filename: str = "test.txt",
) -> dict[str, str]:
    """Export train/test simulation-name lists into a directory.

    Both lists and both target paths are validated before anything is
    written, so a failure does not leave a half-exported split on disk.

    Args:
        train_sims: Simulation names for the training split.
        test_sims: Simulation names for the test split.
        output_dir: Directory that receives both files; ``~`` is expanded
            and the path is resolved to an absolute path.
        train_filename: File name of the training list inside ``output_dir``.
        test_filename: File name of the test list inside ``output_dir``.

    Returns:
        A dict with the string paths ``"output_dir"``, ``"train_file"`` and
        ``"test_file"``.

    Raises:
        ValueError: If either list contains no usable simulation name, or if
            the train and test files would resolve to the same path.
    """
    # Materialize once so the inputs can be validated here and then written
    # below, even if a caller passes a one-shot iterable.
    train_names = list(train_sims)
    test_names = list(test_sims)

    # Validate both splits up front: otherwise an empty test list would raise
    # only after the train file had already been written.
    _clean_sim_names(train_names)
    _clean_sim_names(test_names)

    output_root = Path(output_dir).expanduser().resolve()
    train_target = output_root / train_filename
    test_target = output_root / test_filename
    # Identical targets would silently overwrite the train list with the
    # test list.
    if train_target.resolve() == test_target.resolve():
        raise ValueError("Train and test files must not resolve to the same path.")

    train_path = write_sim_name_list(train_target, train_names)
    test_path = write_sim_name_list(test_target, test_names)
    return {
        "output_dir": str(output_root),
        "train_file": str(train_path),
        "test_file": str(test_path),
    }