Skip to content

Runtime

The experiment runtime is the live launch surface. It accepts a typed RunConfig, writes the run manifest and event log, and dispatches the stage-specific work directly:

  • fit / test → Lightning trainer launch with config-driven data/model instantiation
  • cache → materialize the configured graph cache
  • extract → feature extraction over configured checkpoints and dataset
  • analyze → per-checkpoint artifact generation through graphids.core.artifacts.analyzer.Analyzer

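The runtime module is intentionally narrow and replaces the old row/orchestrate chassis. Any other stage raises: hf_push is reserved but not yet wired (NotImplementedError), and unknown stages raise ValueError.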

graphids.exp.runtime

runtime

Execution helpers for the new experiment seam.

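Ray is already supported as an optional backend (launch_run falls back to local execution when ray cannot be imported); Hydra can attach here later. This module gives us a single place to write manifests and events around any callable run body.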

launch_run

launch_run(run: RunConfig) -> RunSummary

Run one launchable config with manifest/event tracking.

Source code in graphids/exp/runtime.py
def launch_run(
    run: RunConfig,
) -> RunSummary:
    """Run one launchable config with manifest/event tracking.

    For GPU fit/test launches, the PyG CUDA extensions are checked first,
    before any tracking starts. The function then:

    1. Writes a ``running`` manifest and a ``launch_started`` event.
    2. Executes the stage, either in a Ray task (``backend == "ray"`` and
       ``ray`` is importable) or in-process via ``run_stage``.
    3. Records ``finished`` or ``failed`` in the event log, the manifest,
       and the MLflow run.

    Args:
        run: The launch config.

    Returns:
        A ``RunSummary`` for the finished run.

    Raises:
        RuntimeError: The Ray backend was selected but the MLflow logger has
            no run id to hand to the worker.
        BaseException: Any exception raised by the stage (including
            ``KeyboardInterrupt``) is journaled as a failure and re-raised.
    """
    if run.stage in {"fit", "test"} and run.resources.accelerator == "gpu":
        from graphids.runtime_checks import assert_pyg_cuda_extensions_match

        assert_pyg_cuda_extensions_match()

    backend = run.resources.backend
    if backend == "ray":
        try:
            import ray  # noqa: F401
        except ImportError:
            # Ray is optional: fall back to in-process execution. The backend
            # actually used is recorded in the hparams and the launch event.
            backend = "local"
    logger = _make_run_logger(run)
    logger.log_hyperparams(run.mlflow_hparams(backend=backend))
    manifest = run.journal_manifest(status="running")
    write_manifest(run.outputs.run_dir, manifest, name=run.outputs.manifest_name)
    append_event(
        run.outputs.run_dir,
        EventRecord(
            status="running",
            stage=run.stage,
            message="launch_started",
            details={"backend": backend},
        ),
        name=run.outputs.events_name,
    )

    try:
        if backend == "ray":
            import ray

            ray.init(ignore_reinit_error=True, include_dashboard=False)
            # The worker is handed this process's MLflow run id so it can log
            # into the same run.
            if logger.run_id is None:
                raise RuntimeError("MLflow logger did not create a run id before Ray launch")
            result = ray.get(ray.remote(_run_stage_with_existing_mlflow_run).remote(run, logger.run_id))
        else:
            result = run_stage(run, logger=logger)
        if logger.run_id is not None:
            logger.experiment.set_terminated(logger.run_id, status="FINISHED")
        append_event(
            run.outputs.run_dir,
            EventRecord(status="finished", stage=run.stage, message="run_finished", details=_payload(result)),
            name=run.outputs.events_name,
        )
        write_manifest(
            run.outputs.run_dir,
            manifest.model_copy(update={"status": "finished"}),
            name=run.outputs.manifest_name,
        )
        return RunSummary(
            run_dir=str(run.outputs.run_dir),
            status="finished",
            stage=run.stage,
            name=run.name,
            last_event="run_finished",
        )
    except BaseException as exc:  # noqa: BLE001 - record all failures, then re-raise
        failure = f"{type(exc).__name__}: {exc}"
        # Journal the failure locally before contacting MLflow. If the remote
        # call fails (e.g. tracking server unreachable), the manifest must not
        # be left stuck in "running". The ``finally`` still attempts the MLflow
        # termination when a local write raises.
        try:
            append_event(
                run.outputs.run_dir,
                EventRecord(
                    status="failed",
                    stage=run.stage,
                    message="run_failed",
                    details={"failure": failure},
                ),
                name=run.outputs.events_name,
            )
            write_manifest(
                run.outputs.run_dir,
                manifest.model_copy(update={"status": "failed", "failure": failure}),
                name=run.outputs.manifest_name,
            )
        finally:
            if logger.run_id is not None:
                logger.experiment.set_terminated(logger.run_id, status="FAILED")
        raise

run_stage

run_stage(run: RunConfig, logger: Any | None = None) -> dict[str, Any] | None

Default stage dispatcher for experiment launches.

Fit/test, cache, extract, and analyze all run directly from the typed experiment config objects.

Source code in graphids/exp/runtime.py
def run_stage(run: RunConfig, logger: Any | None = None) -> dict[str, Any] | None:
    """Default stage dispatcher for experiment launches.

    Fit/test, cache, extract, and analyze all run directly from the typed
    experiment config objects. ``run.payload`` is dumped to a JSON-mode dict
    and read per stage.

    Args:
        run: Launch config. ``run.stage`` selects the branch. ``run.seed``,
            ``run.name``, ``run.plan_id``, ``run.dataset`` and
            ``run.representation_cfg`` supply fallbacks for extract/analyze.
        logger: Forwarded to fit/test only; other stages ignore it.

    Returns:
        For fit/test and cache, the result of the corresponding helper.
        For extract and analyze, a ``{"stage": ..., "output_dir": ...}``
        summary.

    Raises:
        ValueError: Unknown stage, or extract is missing checkpoints,
            dataset, or output_dir.
        NotImplementedError: The ``hf_push`` stage is not wired yet.
    """

    def _opt(cfg: dict[str, Any], key: str, default: Any) -> Any:
        # The dumped payload may contain keys whose value is None (e.g. unset
        # optional fields). ``cfg.get(key, default)`` alone would pass that
        # None to int()/float() (TypeError) or str() (the literal "None"), so
        # treat None the same as a missing key.
        value = cfg.get(key)
        return default if value is None else value

    if run.stage in {"fit", "test"}:
        payload = run.payload.model_dump(mode="json")
        return _run_fit_or_test(
            run.stage,
            payload,
            ckpt_path=payload.get("ckpt_path"),
            logger=logger,
        )
    if run.stage == "cache":
        payload = run.payload.model_dump(mode="json")
        return _run_cache(payload)
    if run.stage == "extract":
        run_cfg = run.payload.model_dump(mode="json")
        # ``or`` lets an empty/missing "checkpoints" fall back to "extractor_ckpts".
        checkpoints = run_cfg.get("checkpoints") or run_cfg.get("extractor_ckpts")
        if not checkpoints:
            # An empty list is as unusable as a missing one.
            raise ValueError("extract requires checkpoints or extractor_ckpts")
        dataset = run_cfg.get("dataset")
        output_dir = run_cfg.get("output_dir")
        if not dataset or not output_dir:
            raise ValueError("extract requires dataset and output_dir")

        # Imported only after validation so config errors surface without
        # loading the extraction stack.
        from graphids.core.data.extract import extract_states

        extract_states(
            checkpoints=checkpoints,
            dataset=dataset,
            output_dir=output_dir,
            max_samples=int(_opt(run_cfg, "max_samples", 150_000)),
            max_val_samples=int(_opt(run_cfg, "max_val_samples", 30_000)),
            batch_size=int(_opt(run_cfg, "batch_size", 256)),
            seed=int(_opt(run_cfg, "seed", run.seed)),
            val_fraction=float(_opt(run_cfg, "val_fraction", 0.2)),
            representation_cfg=run.representation_cfg,
        )
        return {"stage": "extract", "output_dir": output_dir}
    if run.stage == "analyze":
        from graphids.core.artifacts.analyzer import AnalysisConfig, Analyzer

        run_cfg = run.payload.model_dump(mode="json")
        spec = AnalysisConfig(
            name=_opt(run_cfg, "name", run.name),
            plan_id=_opt(run_cfg, "plan_id", run.plan_id or run.name),
            ckpt_path=str(_opt(run_cfg, "ckpt_path", "")),
            dataset=str(_opt(run_cfg, "dataset", run.dataset or "")),
            model_type=str(_opt(run_cfg, "model_type", "gat")),
            output_dir=str(_opt(run_cfg, "output_dir", "")),
            lake_root=str(_opt(run_cfg, "lake_root", "")),
            embeddings=bool(_opt(run_cfg, "embeddings", True)),
            attention=bool(_opt(run_cfg, "attention", False)),
            cka=bool(_opt(run_cfg, "cka", False)),
            landscape=bool(_opt(run_cfg, "landscape", False)),
            fusion_policy=bool(_opt(run_cfg, "fusion_policy", False)),
            cka_teacher_ckpt=str(_opt(run_cfg, "cka_teacher_ckpt", "")),
            cka_max_samples=int(_opt(run_cfg, "cka_max_samples", 500)),
            landscape_resolution=int(_opt(run_cfg, "landscape_resolution", 51)),
            landscape_scale=float(_opt(run_cfg, "landscape_scale", 1.0)),
            landscape_max_graphs=int(_opt(run_cfg, "landscape_max_graphs", 500)),
            embedding_max_samples=int(_opt(run_cfg, "embedding_max_samples", 2000)),
            attention_max_samples=int(_opt(run_cfg, "attention_max_samples", 50)),
            batch_size=int(_opt(run_cfg, "batch_size", 256)),
            seed=int(_opt(run_cfg, "seed", run.seed)),
            vocab_scope=str(_opt(run_cfg, "vocab_scope", "train")),
            representation_cfg=run.representation_cfg,
            vgae_ckpt_path=str(_opt(run_cfg, "vgae_ckpt_path", "")),
            gat_ckpt_path=str(_opt(run_cfg, "gat_ckpt_path", "")),
        )
        Analyzer(spec).run()
        return {"stage": "analyze", "output_dir": spec.output_dir}
    if run.stage == "hf_push":
        raise NotImplementedError(f"stage {run.stage!r} is not wired yet")
    raise ValueError(f"unknown stage: {run.stage!r}")