Core: Callbacks

Lightning owns the training loop. graphids only ships callbacks that encode policy Lightning's stock callbacks don't:

  • Sha256ModelCheckpoint: lightning.pytorch.callbacks.ModelCheckpoint plus a <ckpt>.sha256 sidecar so graphids._fs.atomic_load can verify bytes on read (GPFS truncation can catch us by surprise; the sidecar is the load-time integrity check).
  • TauNormCallback: Kang et al. (ICLR 2020) τ-norm of the GAT classifier head at fit-end. Loads the best ckpt, divides each row w_c of the final fc_layers[-1] nn.Linear weight by ‖w_c‖^τ in place, and re-saves.
  • VRAMDriftCallback: warns once when free VRAM shrinks past the threshold between epochs (baseline probed at fit-start).

pl.callbacks.EarlyStopping is wired straight from the libsonnet — graphids no longer ships its own.

MLflowTrainingCallback (in graphids._mlflow) forwards per-epoch metrics and the run config and handles LoggedModel registration; it is registered alongside these callbacks but lives in the MLflow surface for discoverability.

The training loop, AMP autocast, gradient clipping, optimizer state, scheduler stepping, ckpt save/load schema, and the callback lifecycle all live in lightning.pytorch.Trainer. The core/trainer.py, core/_metric_acc.py, and core/_ckpt.py modules that previously re-implemented these were removed in the 2026-05-02 Lightning migration (commit c974185); see ~/plans/lightning-migration-spike.md for the inventory of what migrated and what was kept.

graphids.core.callbacks

callbacks

graphids-specific Lightning callbacks.

Lightning's stock ModelCheckpoint / EarlyStopping, together with MLflowTrainingCallback in graphids._mlflow, cover the universal trio (checkpoint + early-stop + MLflow forwarding); we ship only callbacks that encode graphids-specific policy:

  • Sha256ModelCheckpoint: ModelCheckpoint + sha256 sidecar so _fs.atomic_load can verify integrity at load time on GPFS.
  • TauNormCallback: Kang et al. (ICLR 2020) τ-norm of the GAT classifier head at fit-end (divides each row w_c of the final fc_layers[-1] weight by ‖w_c‖^τ).
  • VRAMDriftCallback: warn-once when free VRAM shrinks past threshold across epoch boundaries.

Sha256ModelCheckpoint

Bases: ModelCheckpoint

ModelCheckpoint + <ckpt>.sha256 sidecar after every save.

GPFS truncation surprises happen on OSC; the sidecar is the load-time integrity check used by _fs.atomic_load.
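
The sidecar idea can be illustrated with the standard library alone. This is a minimal sketch, not the graphids implementation: the helper names (sha256_of, write_sidecar, verify_sidecar, demo) and the sidecar format (hex digest plus a newline) are assumptions made for the example, and the real check lives in _fs.atomic_load.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream the file through SHA-256 so large checkpoints need not fit in RAM."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def write_sidecar(ckpt: Path) -> Path:
    """Write <ckpt>.sha256 next to the checkpoint (hypothetical helper)."""
    sidecar = ckpt.with_name(ckpt.name + ".sha256")
    sidecar.write_text(sha256_of(ckpt) + "\n")
    return sidecar


def verify_sidecar(ckpt: Path) -> bool:
    """Return True when the bytes on disk still match the recorded digest."""
    sidecar = ckpt.with_name(ckpt.name + ".sha256")
    return sidecar.read_text().strip() == sha256_of(ckpt)


def demo() -> tuple[bool, bool]:
    """Save, verify, then simulate a truncated file and verify again."""
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp) / "best.ckpt"
        ckpt.write_bytes(b"x" * 4096)
        write_sidecar(ckpt)
        ok_before = verify_sidecar(ckpt)
        ckpt.write_bytes(b"x" * 1024)  # stand-in for a GPFS truncation
        ok_after = verify_sidecar(ckpt)
    return ok_before, ok_after
```

A truncated file hashes to a different digest, so the mismatch surfaces at load time instead of as a confusing deserialization error later.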

TauNormCallback dataclass

TauNormCallback(tau: float = 0.5)

Bases: Callback

Apply Kang τ-norm to GAT's classifier head at fit-end.

Loads the best ckpt, divides each row of the highest-indexed fc_layers.<N>.weight by ‖w_c‖^τ, then saves atomically. Hidden FC layers (fc_layers[:-1]) are encoder-side per Kang's framing, so only the logit-producing matrix is normed.
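
Picking the "highest-indexed fc_layers.<N>.weight" from a checkpoint's state-dict keys could look like the sketch below. The function name find_head_weight_key and the key prefixes in the example are hypothetical; the callback's actual lookup may differ. The point of the sketch is that <N> has to be compared numerically, because a plain string sort would rank "2" above "10".

```python
import re

# Matches "...fc_layers.<N>.weight" and captures N; bias keys are ignored.
_FC_WEIGHT = re.compile(r"(?:^|\.)fc_layers\.(\d+)\.weight$")


def find_head_weight_key(state_dict_keys):
    """Return the key of the fc_layers weight with the largest numeric index."""
    best_idx, best_key = -1, None
    for key in state_dict_keys:
        match = _FC_WEIGHT.search(key)
        if match and int(match.group(1)) > best_idx:
            best_idx, best_key = int(match.group(1)), key
    if best_key is None:
        raise KeyError("no fc_layers.<N>.weight key found")
    return best_key


# Hypothetical key layout, for illustration only.
example_keys = [
    "model.conv1.lin.weight",
    "model.fc_layers.0.weight",
    "model.fc_layers.0.bias",
    "model.fc_layers.2.weight",
    "model.fc_layers.10.weight",
    "model.fc_layers.10.bias",
]
head_key = find_head_weight_key(example_keys)
```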

VRAMDriftCallback dataclass

VRAMDriftCallback(threshold: float = 0.2)

Bases: Callback

Warn-once when free VRAM shrinks past threshold across epochs.

The budget probe captures free VRAM at build time. Over long runs the pool drifts (co-resident processes, PyG activation leaks). The callback takes a baseline at fit-start and checks it at each epoch start. It only warns: re-probing mid-run would race optimizer state, so the researcher decides whether to abort.
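
The baseline-then-warn-once pattern can be sketched without a GPU by injecting the probe. Everything here is illustrative: DriftMonitor, its method names, and demo_drift are hypothetical, and reading threshold as a fraction of the fit-start baseline (0.2 meaning a 20% shrink) is an assumption rather than something the docstring states.

```python
import warnings
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class DriftMonitor:
    free_bytes: Callable[[], int]  # injected probe returning free memory
    threshold: float = 0.2  # assumed: fractional shrink relative to baseline
    baseline: Optional[int] = field(default=None, init=False)
    warned: bool = field(default=False, init=False)

    def on_fit_start(self) -> None:
        """Record the baseline once, before the first epoch."""
        self.baseline = self.free_bytes()

    def on_epoch_start(self) -> bool:
        """Return True only on the single epoch where the warning fires."""
        if self.warned or not self.baseline:
            return False
        shrink = (self.baseline - self.free_bytes()) / self.baseline
        if shrink > self.threshold:
            self.warned = True
            warnings.warn(f"free VRAM shrank {shrink:.0%} since fit-start", stacklevel=2)
            return True
        return False


def demo_drift() -> list:
    """Feed scripted readings: baseline 1000, then 900, 700, 600."""
    readings = iter([1000, 900, 700, 600])
    monitor = DriftMonitor(free_bytes=lambda: next(readings))
    monitor.on_fit_start()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # keep the demo quiet
        return [monitor.on_epoch_start() for _ in range(3)]
```

In a real callback the probe would presumably be something like the free-bytes element of torch.cuda.mem_get_info(); injecting it keeps the warn-once logic testable on a CPU-only machine.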

apply_tau_norm

apply_tau_norm(weight: Tensor, tau: float) -> None

In-place row-wise τ-norm: w_c /= ‖w_c‖^τ.

Kang et al. ICLR 2020 §3.4 (arXiv 1910.09217). τ=0 identity, τ=1 unit-norm rows. Damps majority-class rows under imbalance.

Source code in graphids/core/callbacks.py
def apply_tau_norm(weight: torch.Tensor, tau: float) -> None:
    """In-place row-wise τ-norm: ``w_c /= ‖w_c‖^τ``.

    Kang et al. ICLR 2020 §3.4 (arXiv 1910.09217). τ=0 identity, τ=1
    unit-norm rows. Damps majority-class rows under imbalance.
    """
    if weight.ndim != 2:
        raise ValueError(f"τ-norm needs 2-D weight, got shape {tuple(weight.shape)}")
    norms = torch.linalg.vector_norm(weight, dim=1, keepdim=True).clamp_min(1e-12)
    weight.div_(norms.pow(tau))
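
To see what the exponent does, the same row-wise arithmetic can be restated without torch. This is a plain-Python illustration, not part of graphids: tau_norm_rows and row_norms are hypothetical names, and unlike apply_tau_norm the sketch returns new lists instead of mutating in place.

```python
import math


def row_norms(rows):
    """L2 norm of each row."""
    return [math.sqrt(sum(x * x for x in row)) for row in rows]


def tau_norm_rows(rows, tau):
    """Divide every row w_c by ‖w_c‖^τ, with the same 1e-12 floor on the norm."""
    out = []
    for row, norm in zip(rows, row_norms(rows)):
        scale = max(norm, 1e-12) ** tau
        out.append([x / scale for x in row])
    return out


# Row 0 plays a majority class with a large norm (5); row 1 already has norm 1.
weights = [[3.0, 4.0], [0.6, 0.8]]
identity = tau_norm_rows(weights, 0.0)  # τ=0 leaves rows unchanged
halfway = tau_norm_rows(weights, 0.5)   # each norm becomes ‖w_c‖^(1-τ)
unit = tau_norm_rows(weights, 1.0)      # τ=1 gives unit-norm rows
```

Since each row's norm ends up at ‖w_c‖^(1-τ), τ=0.5 shrinks the norm-5 row to about 2.24 while the norm-1 row is untouched, which is the damping of large-norm (typically majority-class) rows that the docstring describes.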