Orchestrate: Stage

Single-stage primitives: build resets GPU state and instantiates the trainer, model, and datamodule; train fits and writes the .train_complete marker plus train/val split predictions; evaluate runs .test(), writes the .test_complete marker plus per-test-set prediction sidecars, and persists the test-phase MLflow run row. _check_ckpt_compat guards resume/test against silent topology drift (a wrong Module class or a wrong IdEncoder).
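The topology guard can be sketched roughly as follows. This is a hypothetical reconstruction: the metadata keys `module_class` and `id_encoder` are assumptions for illustration, not the real checkpoint schema, and the actual `_check_ckpt_compat` lives in `graphids/orchestrate/stage.py`.

```python
def check_ckpt_compat(
    ckpt_meta: dict, expected_module: str, expected_encoder: str
) -> None:
    """Refuse to resume/test from a checkpoint with a drifted topology.

    ``ckpt_meta`` stands in for metadata read from the checkpoint file;
    the key names here are illustrative, not the project's real schema.
    """
    found_module = ckpt_meta.get("module_class")
    if found_module != expected_module:
        raise RuntimeError(
            f"checkpoint Module mismatch: {found_module!r} != {expected_module!r}"
        )
    found_encoder = ckpt_meta.get("id_encoder")
    if found_encoder != expected_encoder:
        raise RuntimeError(
            f"checkpoint IdEncoder mismatch: {found_encoder!r} != {expected_encoder!r}"
        )
```

Failing loudly here is the point: a checkpoint from a different Module class or IdEncoder would otherwise load partially and degrade silently.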

graphids.orchestrate.stage

stage

Single-stage primitives.

build / train / evaluate are the atomic verbs between a ResolvedConfig and a running Trainer. Each takes the resolved config directly so callers don't have to unpack rendered / validated / run_dir / ckpt_file into positional arguments at every call site.

wire_file_exporters is called once by the caller per stage, not by these primitives.

When resolved.run_dir is None (CLI smoke with no default_root_dir), the primitives skip all filesystem side effects: no markers, no file exporters, no ckpt_path hand-off to .test().

build

build(resolved: ResolvedConfig) -> InstantiatedRun

Instantiate trainer + model + datamodule from a resolved config.

GPU state is reset first so a prior stage's VRAM / compiled kernels don't leak into this one.

Source code in graphids/orchestrate/stage.py
def build(resolved: ResolvedConfig) -> InstantiatedRun:
    """Instantiate trainer + model + datamodule from a resolved config.

    GPU state is reset first so a prior stage's VRAM / compiled kernels
    don't leak into this one.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    torch.compiler.reset()
    return build_run(resolved.rendered, validated=resolved.validated)

evaluate

evaluate(artifacts: InstantiatedRun, resolved: ResolvedConfig) -> dict[str, Any]

Run the test phase and return metrics.

On success, writes the test-phase marker, the per-test-set prediction sidecars, and the MLflow run row (params + final scalars + tags). traces.jsonl remains authoritative for timeseries / spans; MLflow is the queryable index for cross-run comparison.

Source code in graphids/orchestrate/stage.py
def evaluate(
    artifacts: InstantiatedRun,
    resolved: ResolvedConfig,
) -> dict[str, Any]:
    """Run the test phase and return metrics.

    On success, writes the test-phase marker, the per-test-set prediction
    sidecars, and the MLflow run row (params + final scalars + tags).
    ``traces.jsonl`` remains authoritative for timeseries / spans; MLflow
    is the queryable index for cross-run comparison.
    """
    stage_name = resolved.stage_name
    run_dir = resolved.run_dir
    ckpt_file = resolved.ckpt_file
    log.info("stage_test", stage=stage_name)
    if ckpt_file is not None:
        _check_ckpt_compat(ckpt_file, resolved)
    metrics = artifacts.trainer.test(
        artifacts.model,
        datamodule=artifacts.datamodule,
        ckpt_path=str(ckpt_file) if ckpt_file is not None else None,
    )
    if run_dir is not None:
        from graphids._mlflow import log_test_run

        touch_marker(run_dir / PHASE_MARKERS["test"])
        _save_test_predictions(artifacts.model, run_dir / "predictions" / "test")
        log_test_run(
            run_dir,
            resolved_config=resolved.validated.model_dump(),
            metrics=metrics or {},
        )
    log.info("stage_complete", stage=stage_name)
    return metrics or {}
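The marker convention evaluate relies on can be sketched like this. The `.train_complete` / `.test_complete` names match the markers described above, but this `touch_marker` is a plausible reimplementation for illustration, not the project's own helper.

```python
import tempfile
from pathlib import Path

# Marker filenames as described in the module docs.
PHASE_MARKERS = {"train": ".train_complete", "test": ".test_complete"}


def touch_marker(path: Path) -> None:
    # Create an empty sentinel file; parent dirs are made so a fresh
    # run_dir works on the first stage.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.touch()


with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "run-001"
    touch_marker(run_dir / PHASE_MARKERS["test"])
    print((run_dir / ".test_complete").exists())
```

Because the marker is only written after `.test()` returns, its presence is a cheap completed-phase check for downstream tooling, without parsing logs or metrics.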

train

train(artifacts: InstantiatedRun, resolved: ResolvedConfig, *, resume_from: str | None = None) -> Path | None

Fit the model and return the canonical checkpoint path.

Starts the MLflow run before fit so MLflowTrainingCallback has an active run to log per-epoch metrics into; the callback closes it at on_fit_end (or on_exception). end_training_run in the finally block is a safety net for the case where the callback itself raises during teardown.

Source code in graphids/orchestrate/stage.py
def train(
    artifacts: InstantiatedRun,
    resolved: ResolvedConfig,
    *,
    resume_from: str | None = None,
) -> Path | None:
    """Fit the model and return the canonical checkpoint path.

    Starts the MLflow run before fit so ``MLflowTrainingCallback`` has an
    active run to log per-epoch metrics into; the callback closes it at
    ``on_fit_end`` (or ``on_exception``). ``end_training_run`` in finally
    is a safety net for callback-raises-during-teardown.
    """
    from graphids._mlflow import end_training_run, start_training_run

    stage_name = resolved.stage_name
    run_dir = resolved.run_dir
    ckpt_file = resolved.ckpt_file
    log.info("stage_train", stage=stage_name, run_dir=str(run_dir) if run_dir else "")
    if resume_from is not None:
        _check_ckpt_compat(resume_from, resolved)
    if run_dir is not None:
        start_training_run(run_dir, resolved.validated.model_dump())
    try:
        artifacts.trainer.fit(
            artifacts.model,
            datamodule=artifacts.datamodule,
            ckpt_path=resume_from,
        )
    finally:
        end_training_run()
    if run_dir is not None:
        touch_marker(run_dir / PHASE_MARKERS["train"])
        pred_dir = run_dir / "predictions"
        _save_split_predictions(artifacts, "train", pred_dir)
        _save_split_predictions(artifacts, "val", pred_dir)
    log.info(
        "stage_train_complete",
        stage=stage_name,
        ckpt=str(ckpt_file) if ckpt_file else "",
    )
    return ckpt_file