Skip to content

Core: Trainer

Pure-PyTorch training loop — Lightning was removed. Single-GPU only (project targets 1× V100), handles AMP via GradScaler, gradient clipping, AMP-safe scheduler skipping on inf/nan scale-warmup batches, and callback lifecycle using the same hook names as Lightning so the OTel + curriculum callbacks ported over without change.

graphids.core.trainer

trainer

Pure-PyTorch training loop for GraphIDS.

Single-GPU only (project uses 1x V100). Handles: - AMP via torch.amp.autocast + GradScaler(enabled=...) (no-op when disabled) - Gradient clipping via clip_grad_norm_ - automatic_optimization=False for RL fusion models - Metric accumulation and logger dispatch - Callback lifecycle (same hook names as Lightning) - Checkpoint resume

MetricAccumulator

MetricAccumulator(nan_strategy: str = 'error')

Dynamic-keyed batch-weighted mean.

Plain dict[str, (sum, count)] — NOT an nn.Module. These are transient per-phase accumulators; storing them in a ModuleDict both pollutes the parent's state_dict and rejects keys with "." (add_module's attribute-name check), breaking metric names like "test/precision@0.95recall".

NaN detection hard-fails the run — under precision: 16-mixed a silent NaN in callback_metrics fools EarlyStopping (NaN < inf is False) and wastes the full patience window.

Source code in graphids/core/trainer.py
def __init__(self, nan_strategy: str = "error") -> None:
    """Create an empty accumulator; ``nan_strategy`` governs NaN handling (see class docs)."""
    # Parallel dicts keyed by metric name: running weighted sum and the
    # total weight observed so far.
    self._sums: dict[str, float] = {}
    self._counts: dict[str, float] = {}
    self._nan_strategy = nan_strategy

Trainer

Trainer(config: TrainerConfig, callbacks: list | None = None, logger: list | bool | None = None)

Single-GPU training loop with AMP, gradient clipping, and callbacks.

Source code in graphids/core/trainer.py
def __init__(
    self,
    config: TrainerConfig,
    callbacks: list | None = None,
    logger: list | bool | None = None,
) -> None:
    """Set up loop state, resolve the device, and index well-known callbacks."""
    self.config = config
    self.callbacks: list[CallbackBase] = list(callbacks or [])
    # build_loggers hands back either a list or a falsy sentinel
    # (None/False) — collapse both cases to a plain list up front.
    self.loggers: list = list(logger) if isinstance(logger, list) else []

    # State read directly by callbacks and model code.
    self.max_epochs: int = config.max_epochs
    self.current_epoch: int = 0
    self.global_step: int = 0
    self.callback_metrics: dict[str, float] = {}
    self.default_root_dir: str = config.default_root_dir
    self.should_stop: bool = False
    self.datamodule: Any = None

    # Filled in by fit().
    self._optimizers: list[torch.optim.Optimizer] = []
    self._schedulers: list[Any] = []

    self._device = self._resolve_device(config.accelerator)

    # Pick out the checkpoint / early-stopping callbacks, if registered.
    # When duplicates exist, the last one in the list wins.
    ckpt: ModelCheckpoint | None = None
    stopper: EarlyStopping | None = None
    for cb in self.callbacks:
        if isinstance(cb, ModelCheckpoint):
            ckpt = cb
        elif isinstance(cb, EarlyStopping):
            stopper = cb
    self.checkpoint_callback: ModelCheckpoint | None = ckpt
    self.early_stopping_callback: EarlyStopping | None = stopper

fit

fit(model: Module, datamodule: Any, ckpt_path: str | None = None) -> None

Fit the model.

Wires the datamodule into the model, runs setup hooks, then moves the model to the device before running the train/val loop up to max_epochs or until a callback flips trainer.should_stop. ckpt_path resumes weights + optimizer + scheduler + AMP scaler state; on_exception fires on any raise so callbacks can close MLflow runs cleanly before re-raising.

Source code in graphids/core/trainer.py
def fit(
    self,
    model: nn.Module,
    datamodule: Any,
    ckpt_path: str | None = None,
) -> None:
    """Fit the model.

    Wires the datamodule into the model, runs ``setup("fit")`` hooks,
    moves the model to the resolved device, then runs the train/val
    loop up to ``max_epochs`` or until a callback flips
    ``trainer.should_stop``. ``ckpt_path`` resumes weights +
    optimizer + scheduler + AMP scaler state; ``on_exception`` fires
    on any raise so callbacks can close MLflow runs cleanly before
    re-raising.

    Args:
        model: Module exposing ``setup``/``build_optimizers`` plus the
            step hooks invoked by the epoch loops.
        datamodule: Provides train/val dataloaders after ``setup("fit")``.
        ckpt_path: Optional checkpoint to resume training from.
    """
    self.datamodule = datamodule
    self._wire_datamodule(datamodule, model)

    datamodule.setup("fit")
    model.setup(datamodule)
    # Must move AFTER setup(): _build() creates self.model and may wrap it
    # in torch.compile — moving before setup leaves the inner module on CPU.
    model.to(self._device)

    opt, sched = model.build_optimizers(self.max_epochs)
    self._optimizers = [opt] if opt else []
    self._schedulers = [sched] if sched else []

    # GradScaler(enabled=False) is a complete no-op passthrough —
    # all methods become identity. No branching needed in the loop.
    use_amp = "16" in str(self.config.precision) and self._device.type == "cuda"
    scaler = torch.amp.GradScaler(enabled=use_amp)

    if ckpt_path:
        # NOTE(review): assumes _resume_fit also restores
        # self.current_epoch — the range() below continues from it.
        self._resume_fit(ckpt_path, model, opt, sched, scaler)

    self._dispatch("on_fit_start", model)
    self._log_hyperparams(model)

    try:
        for epoch in range(self.current_epoch, self.max_epochs):
            self.current_epoch = epoch
            self._train_one_epoch(model, datamodule, opt, scaler, use_amp)
            self._validate_one_epoch(model, datamodule, use_amp)

            self._dispatch("on_train_epoch_end", model)

            # Skip scheduler.step() when the optimizer hasn't stepped
            # this run — GradScaler skips opt.step() on inf/nan grads,
            # which is common on early fp16 batches while the scale warms
            # up. Stepping the scheduler anyway trips PyTorch's
            # "lr_scheduler.step() before optimizer.step()" warning and
            # silently burns the first LR value.
            opt_stepped = any(getattr(o, "_opt_called", False) for o in self._optimizers)
            if opt_stepped:
                for s in self._schedulers:
                    if s is not None:
                        s.step()

            if self.should_stop:
                _log.info("early_stopping", epoch=epoch)
                break

    except BaseException as exc:
        # BaseException (not Exception) so KeyboardInterrupt also
        # reaches callbacks before the re-raise.
        self._dispatch("on_exception", model, exc)
        raise

    self._dispatch("on_fit_end", model)

predict

predict(model: Module, datamodule: Any, ckpt_path: str | None = None) -> list

Runs predict_step over every test loader and returns the concatenated list. Calls setup("predict") so datamodules can swap in a predict-specific loader.

Source code in graphids/core/trainer.py
def predict(
    self,
    model: nn.Module,
    datamodule: Any,
    ckpt_path: str | None = None,
) -> list:
    """Collect ``predict_step`` outputs from every test loader.

    Calls ``setup("predict")`` so datamodules may substitute a
    predict-specific loader; returns one flat list across all loaders.
    """
    self.datamodule = datamodule
    self._wire_datamodule(datamodule, model)

    datamodule.setup("predict")
    model.setup(datamodule)
    model.to(self._device)

    if ckpt_path:
        self._load_model_weights(ckpt_path, model)

    loaders = datamodule.test_dataloader()
    # A single loader goes through the same path as a list of them.
    if not isinstance(loaders, list):
        loaders = [loaders]
    return [
        out
        for loader in loaders
        for out in self.predict_on(model, loader)
    ]

predict_on

predict_on(model: Module, loader: Any) -> list

Run predict_step over a single loader. Assumes model/dm set up.

Source code in graphids/core/trainer.py
def predict_on(self, model: nn.Module, loader: Any) -> list:
    """Collect non-None ``predict_step`` outputs from one loader.

    Assumes the model and datamodule are already set up; runs in eval
    mode with gradients disabled.
    """
    model.eval()
    with torch.no_grad():
        return [
            out
            for idx, batch in enumerate(loader)
            if (out := model.predict_step(batch, idx)) is not None
        ]

test

test(model: Module, datamodule: Any, ckpt_path: str | None = None) -> dict[str, float]

Evaluate on all test dataloaders, return aggregated metrics.

Multiple test loaders (e.g. one per attack subdir) are dispatched with a dataloader_idx so test_step can name metrics per subdir.

Source code in graphids/core/trainer.py
def test(
    self,
    model: nn.Module,
    datamodule: Any,
    ckpt_path: str | None = None,
) -> dict[str, float]:
    """Evaluate on all test dataloaders, return aggregated metrics.

    Multiple test loaders (e.g. one per attack subdir) are dispatched
    with a ``dataloader_idx`` so ``test_step`` can name metrics per
    subdir.

    Args:
        model: Module exposing ``test_step`` and the test-epoch hooks;
            metrics are accumulated in ``model._metric_acc``.
        datamodule: Provides test dataloaders after ``setup("test")``.
        ckpt_path: Optional checkpoint whose weights are loaded first.

    Returns:
        Snapshot of ``callback_metrics`` after the test pass.
    """
    self.datamodule = datamodule
    self._wire_datamodule(datamodule, model)

    datamodule.setup("test")
    model.setup(datamodule)
    model.to(self._device)

    if ckpt_path:
        self._load_model_weights(ckpt_path, model)

    # OCGIN centroid fit: deterministic statistic of (trained encoder,
    # benign train data), so re-fit at test-start rather than persisting
    # through state_dict (which deadlocked on callback/ckpt-save ordering
    # and shipped uncalibrated ckpts — Cardinal jid 8772115). No-op for
    # models that don't expose calibrate_svdd_center.
    calibrate = getattr(model, "calibrate_svdd_center", None)
    if calibrate is not None:
        # setup("fit") builds the train loader the calibration consumes.
        datamodule.setup("fit")
        calibrate(datamodule.train_dataloader(), self._device)

    test_loaders = datamodule.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    model.eval()
    model.on_test_epoch_start()

    with torch.no_grad():
        for dl_idx, loader in enumerate(test_loaders):
            for batch_idx, batch in enumerate(loader):
                model.test_step(batch, batch_idx, dataloader_idx=dl_idx)

    model.on_test_epoch_end()

    # Drain the model's accumulator into the trainer-visible metrics
    # dict and reset it so a later phase starts clean.
    self.callback_metrics.update(model._metric_acc.compute())
    model._metric_acc.reset()

    return dict(self.callback_metrics)

validate

validate(model: Module, datamodule: Any, ckpt_path: str | None = None) -> dict[str, float]

Run one validation pass, return aggregated metrics.

Calls setup("fit") (not "validate") because the val loader is built there; the train loader is allocated but never iterated.

Source code in graphids/core/trainer.py
def validate(
    self,
    model: nn.Module,
    datamodule: Any,
    ckpt_path: str | None = None,
) -> dict[str, float]:
    """Run a single validation epoch and return the aggregated metrics.

    Uses ``setup("fit")`` rather than ``"validate"`` — that's where the
    val loader is built; the train loader gets allocated but is never
    iterated.
    """
    self.datamodule = datamodule
    self._wire_datamodule(datamodule, model)

    datamodule.setup("fit")
    model.setup(datamodule)
    model.to(self._device)

    if ckpt_path:
        self._load_model_weights(ckpt_path, model)

    # Same AMP gate as fit(): fp16 autocast only when running on CUDA.
    amp_enabled = self._device.type == "cuda" and "16" in str(self.config.precision)
    self._validate_one_epoch(model, datamodule, amp_enabled)
    return dict(self.callback_metrics)

TrainerConfig dataclass

TrainerConfig(max_epochs: int = 300, precision: str = '16-mixed', gradient_clip_val: float = 1.0, log_every_n_steps: int = 50, accelerator: str = 'auto', devices: str | int = 'auto', default_root_dir: str = '')

Flat config matching the jsonnet trainer section keys.

seed_everything

seed_everything(seed: int) -> None

Seed Python, NumPy, and PyTorch RNGs. torch.manual_seed covers CPU + CUDA.

Source code in graphids/core/trainer.py
def seed_everything(seed: int) -> None:
    """Seed every RNG in play: Python's ``random``, NumPy, and PyTorch.

    ``torch.manual_seed`` covers both CPU and CUDA generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)