Single-stage primitives.
build / train / evaluate are the atomic verbs between a
ResolvedConfig and a running Trainer. Each takes the resolved
config directly so callers don't have to unpack rendered /
validated / run_dir / ckpt_file into positional arguments at
every call site.
wire_file_exporters is called once per stage by the caller, not by
these primitives.
When resolved.run_dir is None (CLI smoke with no
default_root_dir), the primitives skip all filesystem side effects:
no markers, no file exporters, no ckpt_path hand-off to .test().
build
build(resolved: ResolvedConfig) -> InstantiatedRun
Instantiate trainer + model + datamodule from a resolved config.
GPU state is reset first so a prior stage's VRAM / compiled kernels
don't leak into this one.
Source code in graphids/orchestrate/stage.py
def build(resolved: ResolvedConfig) -> InstantiatedRun:
    """Instantiate trainer + model + datamodule from a resolved config.

    GPU state is reset first so a prior stage's VRAM / compiled kernels
    don't leak into this one.
    """
    # Drop dangling Python references before clearing CUDA caches, so the
    # allocator can actually release the freed blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.compiler.reset()
    return build_run(resolved.rendered, validated=resolved.validated)
evaluate
evaluate(artifacts: InstantiatedRun, resolved: ResolvedConfig) -> dict[str, Any]
Run the test phase and return metrics.
On success, writes the test-phase marker, the per-test-set prediction
sidecars, and the MLflow run row (params + final scalars + tags).
traces.jsonl remains authoritative for timeseries / spans; MLflow
is the queryable index for cross-run comparison.
Source code in graphids/orchestrate/stage.py
def evaluate(
    artifacts: InstantiatedRun,
    resolved: ResolvedConfig,
) -> dict[str, Any]:
    """Run the test phase and return metrics.

    On success, writes the test-phase marker, the per-test-set prediction
    sidecars, and the MLflow run row (params + final scalars + tags).
    ``traces.jsonl`` remains authoritative for timeseries / spans; MLflow
    is the queryable index for cross-run comparison.
    """
    stage = resolved.stage_name
    out_dir = resolved.run_dir
    checkpoint = resolved.ckpt_file

    log.info("stage_test", stage=stage)
    if checkpoint is not None:
        _check_ckpt_compat(checkpoint, resolved)

    results = artifacts.trainer.test(
        artifacts.model,
        datamodule=artifacts.datamodule,
        ckpt_path=None if checkpoint is None else str(checkpoint),
    )

    # run_dir is None for CLI smoke runs: skip every filesystem / MLflow
    # side effect and just hand back the metrics.
    if out_dir is not None:
        # Lazy import keeps MLflow out of the hot path for smoke runs.
        from graphids._mlflow import log_test_run

        touch_marker(out_dir / PHASE_MARKERS["test"])
        _save_test_predictions(artifacts.model, out_dir / "predictions" / "test")
        log_test_run(
            out_dir,
            resolved_config=resolved.validated.model_dump(),
            metrics=results or {},
        )

    log.info("stage_complete", stage=stage)
    # trainer.test may return a falsy value; normalise to an empty dict.
    return results or {}
train
train(artifacts: InstantiatedRun, resolved: ResolvedConfig, *, resume_from: str | None = None) -> Path | None
Fit the model and return the canonical checkpoint path.
Starts the MLflow run before fit so MLflowTrainingCallback has an
active run to log per-epoch metrics into; the callback closes it at
on_fit_end (or on_exception). end_training_run in finally
is a safety net for callback-raises-during-teardown.
Source code in graphids/orchestrate/stage.py
def train(
    artifacts: InstantiatedRun,
    resolved: ResolvedConfig,
    *,
    resume_from: str | None = None,
) -> Path | None:
    """Fit the model and return the canonical checkpoint path.

    Starts the MLflow run before fit so ``MLflowTrainingCallback`` has an
    active run to log per-epoch metrics into; the callback closes it at
    ``on_fit_end`` (or ``on_exception``). ``end_training_run`` in finally
    is a safety net for callback-raises-during-teardown.
    """
    from graphids._mlflow import end_training_run, start_training_run

    stage = resolved.stage_name
    out_dir = resolved.run_dir
    checkpoint = resolved.ckpt_file

    log.info("stage_train", stage=stage, run_dir=str(out_dir) if out_dir else "")
    if resume_from is not None:
        _check_ckpt_compat(resume_from, resolved)
    if out_dir is not None:
        # Open the MLflow run before fit so the callback can log into it.
        start_training_run(out_dir, resolved.validated.model_dump())

    try:
        artifacts.trainer.fit(
            artifacts.model,
            datamodule=artifacts.datamodule,
            ckpt_path=resume_from,
        )
    finally:
        # No-op when the callback already closed the run; catches the
        # callback-raises-during-teardown case.
        end_training_run()

    # run_dir is None for CLI smoke runs: no marker, no prediction sidecars.
    if out_dir is not None:
        touch_marker(out_dir / PHASE_MARKERS["train"])
        pred_dir = out_dir / "predictions"
        for split in ("train", "val"):
            _save_split_predictions(artifacts, split, pred_dir)

    log.info(
        "stage_train_complete",
        stage=stage,
        ckpt=str(checkpoint) if checkpoint else "",
    )
    return checkpoint