Source code for training.models.pix2pix
"""Built-in Pix2Pix model definition."""
from training import import_physicsnemo_attr
from training.models import register_model
[docs]
def build(model_cfg: dict, dataset_info: dict):
    """Construct a Pix2Pix model from a model config and dataset metadata.

    The ``Pix2Pix`` class is looked up through ``import_physicsnemo_attr``
    in ``physicsnemo.models.pix2pix.pix2pix``. Channel counts are taken
    verbatim from ``dataset_info``; every other constructor argument is read
    from ``model_cfg`` (falling back to a built-in default) and coerced to
    ``int``.

    Args:
        model_cfg: Model options. Recognized keys are ``dimension``
            (default 2), ``conv_layer_size`` (default 64),
            ``n_downsampling`` (default 3), ``n_upsampling`` (default 3)
            and ``n_blocks`` (default 3).
        dataset_info: Must provide ``in_channels`` and ``out_channels``.

    Returns:
        The instantiated model, with a copy of the keyword arguments used to
        build it stored on the ``_resolved_model_params`` attribute.

    Raises:
        KeyError: If ``dataset_info`` lacks ``in_channels`` or
            ``out_channels``.
        ValueError, TypeError: If a ``model_cfg`` value cannot be converted
            by ``int()``.
    """
    model_cls = import_physicsnemo_attr(
        "physicsnemo.models.pix2pix.pix2pix", "Pix2Pix"
    )

    # Integer hyperparameters and the value used when model_cfg omits them.
    int_defaults = (
        ("dimension", 2),
        ("conv_layer_size", 64),
        ("n_downsampling", 3),
        ("n_upsampling", 3),
        ("n_blocks", 3),
    )

    ctor_kwargs = {
        "in_channels": dataset_info["in_channels"],
        "out_channels": dataset_info["out_channels"],
    }
    for key, fallback in int_defaults:
        ctor_kwargs[key] = int(model_cfg.get(key, fallback))

    instance = model_cls(**ctor_kwargs)
    # Store a shallow copy so later edits to either dict do not affect the other.
    instance._resolved_model_params = {**ctor_kwargs}
    return instance
# Register the builder above under the name "pix2pix" with the "grid" adapter
# when this module is imported.
register_model(
    "pix2pix",
    build_fn=build,
    adapter="grid",
)