Artifacts

Per-checkpoint artifact generation: embeddings, GAT attention weights, teacher↔student CKA, loss-landscape grids, fusion-policy traces. Driven by AnalysisConfig and dispatched directly through graphids.exp.runtime.run_stage to Analyzer(spec).run().

Distinct from graphids.analysis, which owns cross-run statistical comparison from the MLflow catalog (no torch dependency, safe to run on a login node).

Layout

graphids/core/artifacts/
├── analyzer.py      orchestrates load → compute → save loop, writes manifest
├── _dispatch.py     ARTIFACTS table — each row is the only place compute + I/O meet
├── compute.py       pure compute fns + frozen result dataclasses (no fs)
├── io.py            every read (val data, teacher ckpt, fusion eval) + every write
└── __init__.py

The compute / I/O split is structural: every compute_* function takes pre-loaded models and pre-built val_data, returns a frozen dataclass, and never touches the filesystem. io.save_* consumes those dataclasses and writes; io.load_* reads. The dispatch table is the single seam between the two — adding an artifact means one new row in ARTIFACTS, one new compute fn, and one new save fn.

Adding an artifact

  1. Add compute_X(...) -> XResult (frozen dataclass) to compute.py.
  2. Add save_X(out, result) to io.py.
  3. Add an Artifact("X", "x.npz", frozenset({...}), _run_X) row to _dispatch.ARTIFACTS and a _run_X glue fn that wires load → compute → save.
  4. Add a toggle field on AnalysisConfig (default off, or default-on for the model types in applies_to via default_toggles_for).

expected_outputs(spec) and Analyzer.run() derive from ARTIFACTS automatically — no parallel declaration to update.
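
The four steps above can be sketched end to end with standard-library stand-ins. Everything here is illustrative: node_count is a hypothetical artifact, the Artifact dataclass mirrors only the fields this page mentions (the field name run and the glue-fn signature are assumptions), and toy lists replace real graph objects.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Callable


@dataclass(frozen=True)
class NodeCountResult:
    """Hypothetical result type: one node count per graph."""
    counts: tuple[int, ...]


def compute_node_count(val_data: list) -> NodeCountResult:
    """Pure compute: no filesystem access, returns a frozen dataclass."""
    return NodeCountResult(counts=tuple(len(g) for g in val_data))


def save_node_count(out: Path, result: NodeCountResult) -> None:
    """I/O side: serialize the result produced by compute_node_count."""
    out.write_text(json.dumps({"counts": list(result.counts)}))


@dataclass(frozen=True)
class Artifact:
    """Stand-in for the real dispatch row; `run` is an assumed field name."""
    name: str
    output: str
    applies_to: frozenset
    run: Callable


def _run_node_count(out_dir: Path, val_data: list) -> None:
    """Glue fn: the only place where compute and I/O meet."""
    result = compute_node_count(val_data)
    save_node_count(out_dir / "node_count.json", result)


ROW = Artifact("node_count", "node_count.json", frozenset({"gat"}), _run_node_count)

# Toy "graphs" as lists of node ids, written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    ROW.run(Path(tmp), [[0, 1, 2], [0, 1]])
    written = json.loads((Path(tmp) / ROW.output).read_text())
```

Because compute_node_count never touches the filesystem, it can be unit-tested with in-memory data alone; only the glue fn needs a directory.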

Reuse with training/eval

io.load_val_data goes through CANBusSource → state.get_or_build, the same path GraphDataModule.setup takes during training. val_fraction, scaler strategy, and cache digest live on the source dataclass, so the analyzer picks up changes there automatically with no parallel declaration.

io.load_teacher and the student ckpt load in Analyzer.run both go through safe_load_checkpoint — the canonical "ckpt → module" registry. io.load_fusion_eval wraps the same FusionDataModule training/eval uses.

Manifest sidecar

Analyzer.run() writes analysis_manifest.json next to the artifacts: the rendered analysis identity, expected outputs (derived from expected_outputs(spec)), and which actually exist on disk after the run. Useful as provenance when an analyze run was submitted via SLURM and the output dir is the only artifact left.
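
A minimal sketch of what such a sidecar could look like, assuming a flat JSON object. The key names (identity, expected, present) and the write_manifest helper are hypothetical; only the filename analysis_manifest.json comes from this page.

```python
import json
import tempfile
from pathlib import Path

MANIFEST_NAME = "analysis_manifest.json"


def write_manifest(out_dir: Path, identity: str, expected: tuple[str, ...]) -> dict:
    """Record the expected outputs and which of them exist on disk."""
    manifest = {
        "identity": identity,  # key names are hypothetical
        "expected": list(expected),
        "present": [name for name in expected if (out_dir / name).exists()],
    }
    (out_dir / MANIFEST_NAME).write_text(json.dumps(manifest, indent=2))
    return manifest


with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    (out_dir / "embeddings.npz").write_bytes(b"")  # pretend one artifact was written
    manifest = write_manifest(out_dir, "gat/demo", ("embeddings.npz", "cka.json"))
    # Anything expected but absent points at an artifact that failed or was skipped.
    missing = sorted(set(manifest["expected"]) - set(manifest["present"]))
```

Comparing the expected and present lists is what makes the sidecar useful after a batch job: a non-empty difference identifies which artifact to rerun.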

graphids.core.artifacts

artifacts

Per-checkpoint artifact generation.

ARTIFACTS module-attribute

ARTIFACTS: tuple[Artifact, ...] = (
    Artifact('embeddings', 'embeddings.npz', frozenset({'vgae', 'dgi', 'gat'}), _run_embeddings),
    Artifact('attention', 'attention_weights.npz', frozenset({'gat'}), _run_attention),
    Artifact('cka', 'cka.json', frozenset({'gat'}), _run_cka),
    Artifact('landscape', 'loss_landscape_{model_type}.parquet', frozenset({'vgae', 'dgi', 'gat'}), _run_landscape),
    Artifact('fusion_policy', 'dqn_policy.json', frozenset({'fusion'}), _run_fusion_policy),
)

MANIFEST_NAME module-attribute

MANIFEST_NAME = 'analysis_manifest.json'

Analyzer

Analyzer(spec: AnalysisConfig)

Generate analysis artifacts from a trained checkpoint.

Source code in graphids/core/artifacts/analyzer.py
def __init__(self, spec: AnalysisConfig):
    self.spec = spec
    if not Path(spec.ckpt_path).exists():
        raise FileNotFoundError(f"Checkpoint not found: {spec.ckpt_path}")
    if spec.cka and not Path(spec.cka_teacher_ckpt).exists():
        raise FileNotFoundError(f"Teacher checkpoint not found: {spec.cka_teacher_ckpt}")

expected_outputs

expected_outputs(spec: AnalysisConfig) -> tuple[str, ...]
Source code in graphids/core/artifacts/analyzer.py
def expected_outputs(spec: AnalysisConfig) -> tuple[str, ...]:
    out: list[str] = []
    for a in ARTIFACTS:
        if getattr(spec, a.name):
            out.append(a.output.format(model_type=spec.model_type))
    return tuple(out)
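
A usage sketch of the same logic, runnable without the package. The Artifact stand-in keeps only the name and output fields, and a SimpleNamespace replaces AnalysisConfig (whose full field list is not shown on this page); the artifact names and filenames are copied from the ARTIFACTS table.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(frozen=True)
class Artifact:
    name: str    # also the toggle attribute looked up on the spec
    output: str  # filename template; may contain {model_type}


# applies_to and the run fn are omitted: expected_outputs does not read them.
ARTIFACTS = (
    Artifact("embeddings", "embeddings.npz"),
    Artifact("attention", "attention_weights.npz"),
    Artifact("cka", "cka.json"),
    Artifact("landscape", "loss_landscape_{model_type}.parquet"),
    Artifact("fusion_policy", "dqn_policy.json"),
)


def expected_outputs(spec) -> tuple[str, ...]:
    return tuple(
        a.output.format(model_type=spec.model_type)
        for a in ARTIFACTS
        if getattr(spec, a.name)
    )


spec = SimpleNamespace(
    model_type="gat",
    embeddings=True,
    attention=False,
    cka=False,
    landscape=True,
    fusion_policy=False,
)
outputs = expected_outputs(spec)
```

Only the landscape template contains a {model_type} placeholder; the other filenames pass through str.format unchanged.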

analyzer

Per-checkpoint artifact generation.

compute

Pure compute primitives — no filesystem, no MLflow, no logging side-effects.

Each compute_* returns a frozen dataclass (or plain dict, for CKA's single layer→score mapping) that io.save_* knows how to serialize. The analyzer wraps the whole batch in eval_mode, so no compute function re-enters it.

compute_attention

compute_attention(model: Module, val_data: list, device: device, *, max_samples: int = 50) -> AttentionResult | None

Per-sample, per-layer GAT attention weights. Returns None if the model lacks them (its conv_type is not "gat").

Source code in graphids/core/artifacts/compute.py
@torch.no_grad()
def compute_attention(
    model: torch.nn.Module,
    val_data: list,
    device: torch.device,
    *,
    max_samples: int = 50,
) -> AttentionResult | None:
    """Per-sample per-layer GAT attention weights. ``None`` if model lacks them."""
    if getattr(model, "conv_type", None) != "gat":
        return None

    loader = PyGDataLoader(val_data[:max_samples], batch_size=1)
    out: dict[str, np.ndarray] = {}
    sample_idx = 0
    for batch in loader:
        batch = batch.clone().to(device)
        _xs, attention_weights = model(batch, return_attention_weights=True)
        prefix = f"sample_{sample_idx}"
        out[f"{prefix}_label"] = batch.y[0].cpu().numpy()
        for layer_idx, alpha in enumerate(attention_weights):
            out[f"{prefix}_layer_{layer_idx}_alpha"] = alpha.cpu().numpy()
        sample_idx += 1
    return AttentionResult(weights=out, n_samples=sample_idx)
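
The weights dict uses flat keys of the form sample_{i}_label and sample_{i}_layer_{l}_alpha. A consumer may prefer them regrouped per sample; the sketch below shows one way to do that. The group_by_sample helper is hypothetical, and plain lists stand in for the NumPy arrays stored under each key.

```python
import re
from collections import defaultdict

_ALPHA_KEY = re.compile(r"sample_(\d+)_layer_(\d+)_alpha")
_LABEL_KEY = re.compile(r"sample_(\d+)_label")


def group_by_sample(weights: dict) -> dict:
    """Regroup flat keys into {sample_idx: {"label": ..., "layers": {layer_idx: alpha}}}."""
    grouped = defaultdict(lambda: {"label": None, "layers": {}})
    for key, value in weights.items():
        if m := _ALPHA_KEY.fullmatch(key):
            grouped[int(m.group(1))]["layers"][int(m.group(2))] = value
        elif m := _LABEL_KEY.fullmatch(key):
            grouped[int(m.group(1))]["label"] = value
    return dict(grouped)


flat = {
    "sample_0_label": 1,
    "sample_0_layer_0_alpha": [0.7, 0.3],
    "sample_0_layer_1_alpha": [0.5, 0.5],
    "sample_1_label": 0,
    "sample_1_layer_0_alpha": [1.0],
}
grouped = group_by_sample(flat)
```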

compute_cka

compute_cka(student: Module, teacher: Module, val_data: list, device: device, *, max_samples: int = 500) -> dict[str, float]

Full cross-matrix linear CKA between all teacher and student layers.

Returns keys teacher_{i}_student_{j} for every combination, giving an n_teacher × n_student matrix. An earlier version used min(n_teacher, n_student) and compared only corresponding pairs, which silently dropped teacher layers that had no student counterpart (a 3-layer teacher × 2-layer student produced only 2 values instead of 6).

Source code in graphids/core/artifacts/compute.py
def compute_cka(
    student: torch.nn.Module,
    teacher: torch.nn.Module,
    val_data: list,
    device: torch.device,
    *,
    max_samples: int = 500,
) -> dict[str, float]:
    """Full cross-matrix linear CKA between all teacher and student layers.

    Returns keys teacher_{i}_student_{j} for every combination, giving an
    n_teacher × n_student matrix. Previously used min(n_teacher, n_student)
    and only compared corresponding pairs — this silently dropped teacher
    layers that had no student counterpart (bug: 3-layer teacher × 2-layer
    student only produced 2 values instead of 6).
    """
    student_reps = _collect_reps(student, val_data, device, max_samples)
    teacher_reps = _collect_reps(teacher, val_data, device, max_samples)
    return {
        f"teacher_{i}_student_{j}": _linear_cka(teacher_reps[i], student_reps[j])
        for i in range(len(teacher_reps))
        for j in range(len(student_reps))
    }
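
_linear_cka is not shown on this page. As a reference point, the sketch below implements the standard linear CKA formula on centered features, ||YᵀX||_F² / (||XᵀX||_F · ||YᵀY||_F), and builds the same flat-key dict. It assumes each representation is an (n_samples, n_features) NumPy array; the real helper may differ in details such as dtype or pooling.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between two (n_samples, n_features) representation matrices."""
    x = x - x.mean(axis=0, keepdims=True)  # center each feature
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


# Random stand-ins: 3 teacher layers (width 8) and 2 student layers (width 4).
rng = np.random.default_rng(0)
teacher_reps = [rng.normal(size=(64, 8)) for _ in range(3)]
student_reps = [rng.normal(size=(64, 4)) for _ in range(2)]

scores = {
    f"teacher_{i}_student_{j}": linear_cka(t, s)
    for i, t in enumerate(teacher_reps)
    for j, s in enumerate(student_reps)
}

# Rebuild the n_teacher x n_student matrix from the flat keys.
matrix = np.array(
    [[scores[f"teacher_{i}_student_{j}"] for j in range(2)] for i in range(3)]
)
```

Layer widths need not match between teacher and student, which is why the full cross-matrix is well defined even when the two networks have different depths and hidden sizes.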

compute_embeddings

compute_embeddings(model: Module, val_data: list, device: device, *, model_type: str, max_samples: int = 2000, batch_size: int = 256) -> EmbeddingsResult

Pool per-graph embeddings + labels for val_data[:max_samples].

Source code in graphids/core/artifacts/compute.py
@torch.no_grad()
def compute_embeddings(
    model: torch.nn.Module,
    val_data: list,
    device: torch.device,
    *,
    model_type: str,
    max_samples: int = 2000,
    batch_size: int = 256,
) -> EmbeddingsResult:
    """Pool per-graph embeddings + labels for ``val_data[:max_samples]``."""
    loader = PyGDataLoader(val_data[:max_samples], batch_size=batch_size)
    all_emb, all_labels = [], []
    for batch in loader:
        batch = batch.clone().to(device)
        edge_attr = getattr(batch, "edge_attr", None)
        if model_type == "vgae":
            z, *_ = model.encode(batch.x, batch.edge_index, edge_attr, batch.batch, batch.node_id)
            emb = scatter(z, batch.batch, dim=0, reduce="mean")
        elif model_type == "dgi":
            z = model.encode(batch.x, batch.edge_index, edge_attr, batch.batch, batch.node_id)
            emb = scatter(z, batch.batch, dim=0, reduce="mean")
        else:
            # GATWithJK: forward(data, return_embedding=True) → (logits, emb_pooled)
            _, emb = model(batch, return_embedding=True)
        all_emb.append(emb.cpu().numpy())
        all_labels.append(batch.y.cpu().numpy())
    return EmbeddingsResult(
        embeddings=np.concatenate(all_emb),
        labels=np.concatenate(all_labels),
        model_type=model_type,
    )
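
For the vgae and dgi branches, scatter(z, batch.batch, dim=0, reduce="mean") averages node embeddings per graph. The NumPy sketch below illustrates that pooling step only; mean_pool is a hypothetical helper, and it assumes graph ids run from 0 to n_graphs - 1 with at least one node per graph.

```python
import numpy as np


def mean_pool(z: np.ndarray, batch: np.ndarray) -> np.ndarray:
    """Average the rows of z that share a graph id in batch."""
    n_graphs = int(batch.max()) + 1
    sums = np.zeros((n_graphs, z.shape[1]), dtype=z.dtype)
    np.add.at(sums, batch, z)  # unbuffered scatter-add, keyed by graph id
    counts = np.bincount(batch, minlength=n_graphs)
    return sums / counts[:, None]


# Five nodes across two graphs: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
z = np.array([[1.0, 0.0], [3.0, 0.0], [5.0, 0.0], [2.0, 2.0], [4.0, 6.0]])
batch = np.array([0, 0, 0, 1, 1])
pooled = mean_pool(z, batch)
```

The result has one row per graph, which is what lets embeddings and labels be concatenated along the same axis at the end of compute_embeddings.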

compute_fusion_policy

compute_fusion_policy(module, td: TensorDict, labels: Tensor) -> PolicyResult

Run fusion module on pre-extracted val states; return alphas + Q-values + labels.

Source code in graphids/core/artifacts/compute.py
def compute_fusion_policy(module, td: TensorDict, labels: torch.Tensor) -> PolicyResult:
    """Run fusion module on pre-extracted val states; return alphas + Q-values + labels."""
    from graphids.core.models.fusion.base import flatten_features

    result = module.predict(td)
    q_vals: np.ndarray | None = None
    if hasattr(module, "q_values"):
        flat_obs = flatten_features(result["td_norm"])
        q_vals = module.q_values(flat_obs).cpu().numpy()
    # RL models (Bandit/DQN) return alphas (mixing weights); non-RL models
    # (MLP/MoE/WeightedAvg) return only fused_scores — use those as the signal.
    alpha_tensor = result.get("alphas", result["fused_scores"])
    return PolicyResult(
        alphas=alpha_tensor.detach().cpu().numpy(),
        labels=labels.cpu().numpy(),
        q_values=q_vals,
    )
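
One Python detail worth keeping in mind about the fallback above: the default argument of .get is evaluated eagerly, so result["fused_scores"] is looked up even when "alphas" is present. That is harmless as long as every predict output carries fused_scores. The sketch below, using plain dicts as stand-ins for the predict output, shows a lazy variant that only touches fused_scores when it is needed; policy_signal is a hypothetical helper.

```python
def policy_signal(result: dict):
    """Prefer mixing weights; fall back to fused scores when they are absent."""
    if "alphas" in result:
        return result["alphas"]
    return result["fused_scores"]


rl_output = {"alphas": [0.2, 0.8], "fused_scores": [0.61, 0.12]}
non_rl_output = {"fused_scores": [0.40, 0.95]}
rl_only_alphas = {"alphas": [0.5, 0.5]}  # hypothetical: no fused_scores key

signals = [policy_signal(r) for r in (rl_output, non_rl_output, rl_only_alphas)]

# The eager form raises for the third case, even though "alphas" exists.
try:
    rl_only_alphas.get("alphas", rl_only_alphas["fused_scores"])
    eager_raised = False
except KeyError:
    eager_raised = True
```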

compute_landscape

compute_landscape(model: Module, model_type: str, val_data: list, device: device, hparams, *, resolution: int = 51, scale: float = 1.0, seed: int = 42, max_graphs: int = 500, dataset: str = '') -> LandscapeResult

Loss on a resolution × resolution grid of filter-normalized perturbations.

Raises KeyError on an unknown model_type. The dispatch table's applies_to set should filter callers, so reaching this function with an unsupported type is a routing bug.

Source code in graphids/core/artifacts/compute.py
def compute_landscape(
    model: torch.nn.Module,
    model_type: str,
    val_data: list,
    device: torch.device,
    hparams,
    *,
    resolution: int = 51,
    scale: float = 1.0,
    seed: int = 42,
    max_graphs: int = 500,
    dataset: str = "",
) -> LandscapeResult:
    """Loss on a ``resolution × resolution`` grid of filter-normalized perturbations.

    ``KeyError`` on unknown ``model_type`` — dispatch's ``applies_to`` should
    filter callers; reaching this with an unsupported type is a routing bug.
    """
    loss_fn = _LOSS_FN[model_type]
    if len(val_data) > max_graphs:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(val_data), max_graphs, replace=False)
        data = [val_data[i] for i in idx]
    else:
        data = val_data
    dataloader = PyGDataLoader(data, batch_size=min(256, len(data)))

    dir1 = _random_direction(model, seed)
    dir2 = _random_direction(model, seed + 1)
    base = [p.data.clone() for p in model.parameters()]
    alphas = np.linspace(-scale, scale, resolution)
    betas = np.linspace(-scale, scale, resolution)

    xs, ys, losses = [], [], []
    for a in alphas:
        for b in betas:
            _perturb_model(model, base, dir1, dir2, a, b)
            losses.append(loss_fn(model, dataloader, device, hparams))
            xs.append(float(a))
            ys.append(float(b))
    _perturb_model(model, base, dir1, dir2, 0.0, 0.0)  # restore

    return LandscapeResult(x=xs, y=ys, loss=losses, model_type=model_type, dataset=dataset)
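
_random_direction and _perturb_model are not shown on this page. The NumPy sketch below illustrates one common reading of "filter-normalized": each row of a random direction is rescaled to the norm of the matching weight row, and 1-D parameters such as biases are left unperturbed. Both helper names and the bias convention are assumptions; the real functions operate on torch parameters and may differ.

```python
import numpy as np


def filter_normalized_direction(weights: list, seed: int) -> list:
    """Random direction whose rows ("filters") match the norms of the weight rows."""
    rng = np.random.default_rng(seed)
    direction = []
    for w in weights:
        if w.ndim <= 1:
            direction.append(np.zeros_like(w))  # assumed convention: skip biases
            continue
        d = rng.normal(size=w.shape).reshape(w.shape[0], -1)
        flat_w = w.reshape(w.shape[0], -1)
        scale = np.linalg.norm(flat_w, axis=1) / (np.linalg.norm(d, axis=1) + 1e-10)
        direction.append((d * scale[:, None]).reshape(w.shape))
    return direction


def perturb(base: list, dir1: list, dir2: list, a: float, b: float) -> list:
    """theta(a, b) = base + a * dir1 + b * dir2, always computed from the saved base."""
    return [w + a * d1 + b * d2 for w, d1, d2 in zip(base, dir1, dir2)]


base = [np.arange(6, dtype=float).reshape(2, 3) + 1.0, np.ones(2)]
dir1 = filter_normalized_direction(base, seed=42)
dir2 = filter_normalized_direction(base, seed=43)

alphas = np.linspace(-1.0, 1.0, 3)
# Same loop order as compute_landscape: the outer axis varies slowest.
grid = [(float(a), float(b)) for a in alphas for b in alphas]
restored = perturb(base, dir1, dir2, 0.0, 0.0)
```

Perturbing from a saved copy of the base weights, instead of accumulating offsets, is what makes the final (0.0, 0.0) call an exact restore.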

io

Filesystem I/O for the artifact pipeline.

load_fusion_eval

load_fusion_eval(*, dataset: str, seed: int, device: device) -> tuple

Load pre-extracted fusion validation tensors.

Source code in graphids/core/artifacts/io.py
def load_fusion_eval(
    *,
    dataset: str,
    seed: int,
    device: torch.device,
) -> tuple:
    """Load pre-extracted fusion validation tensors."""
    from graphids.core.data.datamodule.fusion import FusionDataModule
    from graphids.paths import trial_dir

    dm = FusionDataModule(
        cached_states_dir=trial_dir() / "cached_states" / dataset / "default" / f"seed_{int(seed)}"
    )
    dm.setup("test")
    labels = dm.val_td["labels"].clone()
    td = dm.val_td.exclude("labels").to(device)
    return td, labels

load_teacher

load_teacher(model_type: str, ckpt_path: str, device: device) -> torch.nn.Module

Load a teacher checkpoint for analysis.

Source code in graphids/core/artifacts/io.py
def load_teacher(model_type: str, ckpt_path: str, device: torch.device) -> torch.nn.Module:
    """Load a teacher checkpoint for analysis."""
    teacher = safe_load_checkpoint(model_type, ckpt_path, map_location=device)
    teacher.eval()
    return teacher

load_val_data

load_val_data(*, lake_root: str, dataset: str, vocab_scope: str, seed: int, representation_cfg: GraphRepresentationCfg) -> list

Load the val split through the same source/cache path as training.

Source code in graphids/core/artifacts/io.py
def load_val_data(
    *,
    lake_root: str,
    dataset: str,
    vocab_scope: str,
    seed: int,
    representation_cfg: GraphRepresentationCfg,
) -> list:
    """Load the val split through the same source/cache path as training."""
    from graphids.core.data.datasets.can_bus import CANBusSource
    from graphids.core.data.state import get_or_build

    state = get_or_build(
        CANBusSource(
            name=dataset,
            lake_root=lake_root,
            seed=seed,
            vocab_scope=vocab_scope,
            representation_cfg=representation_cfg,
        )
    )
    val = list(state.val)
    log.info("data_loaded", n_val=len(val))
    return val