Skip to content

Core: Models

Model families used as ablation rows. All inherit from GraphModuleBase (base.py), which owns the VRAM probe (compute_budget) plus the _store_init_kwargs / _build_id_encoder mixins.

  • autoencoder/ — VGAE family (unsupervised reconstruction). Stage 1 of the KD chain.
  • supervised/ — GAT family (supervised classification). Stage 2.
  • fusion/ — fusion modules, dispatched on the fusion_method TLA (Jsonnet top-level argument) over the per-method libsonnets. Stage 3.
  • id_encoding/ — categorical-ID encoders (embedding tables with reserved UNK at index 0).

graphids.core.models

models

Core model families and shared model base classes.

BanditFusionModule

BanditFusionModule(state_dim: int = 18, alpha_steps: int = 21, ucb_alpha: float = 1.0, lambda_reg: float = 1.0, hidden_dim: int = 128, num_layers: int = 3, backbone_lr: float = 0.001, backbone_retrain_freq: int = 50, backbone_epochs: int = 5, buffer_size: int = 100000, batch_size: int = 128, decision_threshold: float = 0.5, reward_kwargs: dict | None = None)

Bases: RLFusionBase

Neural-LinUCB: backbone + per-arm ridge with Sherman-Morrison online updates, and a frequency-gated backbone refit.

Source code in graphids/core/models/fusion/bandit.py
def __init__(
    self,
    state_dim: int = 18,
    alpha_steps: int = 21,
    ucb_alpha: float = 1.0,
    lambda_reg: float = 1.0,
    hidden_dim: int = 128,
    num_layers: int = 3,
    backbone_lr: float = 1e-3,
    backbone_retrain_freq: int = 50,
    backbone_epochs: int = 5,
    buffer_size: int = 100_000,
    batch_size: int = 128,
    decision_threshold: float = 0.5,
    reward_kwargs: dict | None = None,
):
    """Set up the backbone, the per-arm ridge statistics, and the backbone optimizer.

    ``buffer_size``, ``batch_size``, ``state_dim``, ``alpha_steps``,
    ``decision_threshold`` and ``reward_kwargs`` are forwarded to the parent
    constructor. ``ucb_alpha`` and ``backbone_retrain_freq`` are not read
    directly in this body; they reach the instance only through the
    ``locals()`` snapshot handed to ``_store_init_kwargs``.

    One arm is allocated per ``alpha_steps`` entry: ``A_inv`` holds
    ``alpha_steps`` copies of ``I / lambda_reg`` (each ``d x d``), and ``b`` /
    ``theta`` are zero matrices of shape ``(alpha_steps, d)``, where ``d`` is
    ``self.backbone.out_dim``.
    """
    super().__init__(
        buffer_size=buffer_size,
        batch_size=batch_size,
        state_dim=state_dim,
        alpha_steps=alpha_steps,
        decision_threshold=decision_threshold,
        reward_kwargs=reward_kwargs,
    )
    # NOTE: locals() is captured here, before any further local is bound, so
    # the snapshot holds the constructor arguments (plus ``self``). Adding a
    # local above this line would change what gets stored.
    self._store_init_kwargs(locals())

    self.backbone = _Backbone(state_dim, hidden_dim, num_layers)
    d = self.backbone.out_dim

    # Per-arm ridge state, registered as buffers so it follows the module
    # across devices. lambda_reg == 0 would divide by zero here.
    self.register_buffer(
        "A_inv", torch.eye(d).unsqueeze(0).repeat(alpha_steps, 1, 1) / lambda_reg
    )
    self.register_buffer("b", torch.zeros(alpha_steps, d))
    self.register_buffer("theta", torch.zeros(alpha_steps, d))

    # gpu_training_steps drives RLFusionBase._learn_step's inner loop.
    self.gpu_training_steps = backbone_epochs

    # Only the backbone is optimised by gradient descent; the ridge buffers
    # above are not parameters.
    self._optimizer = optim.AdamW(self.backbone.parameters(), lr=backbone_lr)
    self._episode = 0
    self._ucb_widths: list[float] = []

GraphModuleBase

Bases: _ModelBase

Shared base for VGAE, GAT, DGI — lazy setup, threshold metrics.

Subclasses must implement _build() using self.hparams.

compute_budget

compute_budget(train_dataset, dataset_name: str, min_steps: int | None = None) -> Any

Probe-once VRAM budget cached on the model.

Source code in graphids/core/models/base.py
def compute_budget(self, train_dataset, dataset_name: str, min_steps: int | None = None) -> Any:
    """Probe-once VRAM budget cached on the model."""
    if self._budget_cache is None:
        from graphids.core.budget import node_budget

        self._budget_cache = node_budget(
            dataset_name, model=self, train_dataset=train_dataset, min_steps=min_steps
        )
    return self._budget_cache

configure_optimizers

configure_optimizers()

Adam over all params using self.hparams.lr / weight_decay.

Source code in graphids/core/models/base.py
def configure_optimizers(self):
    """Build an Adam optimizer over ``self.parameters()``.

    ``lr`` and ``weight_decay`` are read from ``self.hparams`` by attribute,
    falling back to ``1e-3`` and ``0.0`` when absent.
    """
    hp = self.hparams
    adam_kwargs = {
        "lr": getattr(hp, "lr", 1e-3),
        "weight_decay": getattr(hp, "weight_decay", 0.0),
    }
    return torch.optim.Adam(self.parameters(), **adam_kwargs)

on_validation_epoch_end

on_validation_epoch_end() -> None

Override in subclasses to compute epoch-level val metrics.

Source code in graphids/core/models/base.py
def on_validation_epoch_end(self) -> None:
    """No-op hook; subclasses override it to compute epoch-level val metrics."""
    return None

prepare_from_datamodule

prepare_from_datamodule(dm) -> None

Lazy-build with DM-supplied sizes, then capture test-set names.

Source code in graphids/core/models/base.py
def prepare_from_datamodule(self, dm) -> None:
    """Build the model lazily from datamodule sizes, then defer to the parent.

    The build is skipped when ``self._built`` is truthy or when ``self.model``
    exists and is not ``None``. Otherwise ``num_ids``, ``in_channels`` and
    ``num_classes`` are copied from ``dm`` onto both ``self`` and
    ``self.hparams`` before ``_build()`` runs.
    """
    built = getattr(self, "_built", False)
    if not built:
        # A present, non-None ``model`` attribute also counts as built.
        built = getattr(self, "model", "_sentinel") not in (None, "_sentinel")

    if not built:
        # Mirror DM values onto self and hparams so _build() sees them.
        for name in ("num_ids", "in_channels", "num_classes"):
            value = getattr(dm, name)
            setattr(self, name, value)
            self.hparams[name] = value
        self._build()
        self._built = True

    super().prepare_from_datamodule(dm)

autoencoder

Autoencoder/self-supervised model family exports.

DGI

DGI(conv_type: str = 'gatv2', hidden_dims: list[int] | None = None, latent_dim: int | None = None, heads: int | None = None, embedding_dim: int | None = None, dropout: float = 0.15, edge_dim: int = 11, proj_dim: int = 0, gradient_checkpointing: bool = True, compile_model: bool = False, batch_norm: bool = True, id_encoder_cfg: IdEncodingCfg | None = None, id_encoder_class_path: str = 'graphids.core.models.id_encoding.LookupIdEncoder', id_encoder_kwargs: dict | None = None, lr: float = 0.001, weight_decay: float = 0.0001, scale: str = 'small', model_type: ModelType = 'dgi', dataset: str = '', seed: int = 42, num_ids: int = 0, in_channels: int = 0, num_classes: int = 2)

Bases: ScoreBasedDetectorMixin

Collapsed DGI — arch + trainer-bridge in one nn.Module.

No loss_fn kwarg: the contrastive MI loss is intrinsic to the architecture (built into the discriminator).

Source code in graphids/core/models/autoencoder/dgi.py
def __init__(
    self,
    # --- architecture (latent_dim/embedding_dim/heads=None → scale) ---
    conv_type: str = "gatv2",
    hidden_dims: list[int] | None = None,
    latent_dim: int | None = None,
    heads: int | None = None,
    embedding_dim: int | None = None,
    dropout: float = 0.15,
    edge_dim: int = 11,
    proj_dim: int = 0,
    gradient_checkpointing: bool = True,
    compile_model: bool = False,
    batch_norm: bool = True,
    id_encoder_cfg: IdEncodingCfg | None = None,
    id_encoder_class_path: str = "graphids.core.models.id_encoding.LookupIdEncoder",
    id_encoder_kwargs: dict | None = None,
    # --- training ---
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    # --- identity / dynamic ---
    scale: str = "small",
    model_type: ModelType = "dgi",
    dataset: str = "",
    seed: int = 42,
    num_ids: int = 0,
    in_channels: int = 0,
    num_classes: int = 2,
):
    """Resolve scale-dependent sizes, register the SVDD centre, and finish init.

    ``latent_dim``, ``embedding_dim`` and ``heads`` left as ``None`` are
    looked up in ``self._SCALES[scale]``, falling back to 48, 32 and 4 when
    the scale (or the key) is missing. All other arguments are not read in
    this body; they reach ``_init_post`` through the ``locals()`` snapshot.
    """
    # An unknown scale yields an empty dict, so the literal fallbacks apply.
    s = self._SCALES.get(scale, {})
    if latent_dim is None:
        latent_dim = s.get("latent_dim", 48)
    if embedding_dim is None:
        embedding_dim = s.get("embedding_dim", 32)
    if heads is None:
        heads = s.get("heads", 4)
    super().__init__()
    # OCGIN scoring head: centroid of training-normal pooled embeddings.
    # Re-fit at test-start by ``Trainer.test`` via ``on_test_setup`` —
    # the centroid is a deterministic statistic of (encoder weights,
    # benign train data). Zero init means an uncalibrated forward pass
    # raises in ``score`` rather than returning bogus scores.
    self.register_buffer("svdd_center", torch.zeros(latent_dim))
    # locals() carries the resolved values above plus ``s`` and ``self``;
    # renaming or adding locals changes what _init_post receives.
    self._init_post(locals())
dgi_loss
dgi_loss(pos_z, neg_z, summary, batch_idx)

Contrastive MI loss: maximize real node–summary agreement.

Source code in graphids/core/models/autoencoder/dgi.py
def dgi_loss(self, pos_z, neg_z, summary, batch_idx):
    """Contrastive MI loss over real and corrupted node embeddings.

    Returns ``-mean(log(D(pos) + eps)) - mean(log(1 - D(neg) + eps))`` with
    ``eps = 1e-6``, where ``D`` is ``self.discriminate``.
    """
    eps = 1e-6
    real_prob = self.discriminate(pos_z, summary, batch_idx)
    fake_prob = self.discriminate(neg_z, summary, batch_idx)
    real_term = torch.log(real_prob + eps).mean()
    fake_term = torch.log(1 - fake_prob + eps).mean()
    return -real_term - fake_term
discriminate
discriminate(z, summary, batch)

Bilinear scoring: sigmoid(z^T W s) per node.

Source code in graphids/core/models/autoencoder/dgi.py
def discriminate(self, z, summary, batch):
    """Bilinear node-vs-summary score: ``sigmoid(z^T W s)`` for each node.

    ``summary`` is indexed by ``batch`` to pick, for every node, the summary
    row of the graph it belongs to.
    """
    per_node_summary = summary[batch]
    projected = z @ self.discriminator_weight
    logits = (projected * per_node_summary).sum(dim=1)
    return torch.sigmoid(logits)
encode
encode(x, edge_index, edge_attr=None, batch=None, node_id=None)

Encode nodes to latent embeddings (same contract as VGAE minus KL).

Source code in graphids/core/models/autoencoder/dgi.py
def encode(self, x, edge_index, edge_attr=None, batch=None, node_id=None):
    """Map node features to latent embeddings.

    Runs ``input_encoder(x, node_id)``, then every layer in
    ``encoder_layers`` through ``conv_forward``, then ``z_proj``. Edge
    attributes are dropped when ``self._uses_edge_attr`` is falsy.
    """
    h = self.input_encoder(x, node_id)
    edge_features = edge_attr if self._uses_edge_attr else None

    for layer_idx, layer in enumerate(self.encoder_layers):
        # encoder_bns is only consulted when batch norm is enabled.
        norm = self.encoder_bns[layer_idx] if self.batch_norm else None
        h = conv_forward(
            layer,
            h,
            edge_index,
            edge_features,
            bn=norm,
            batch=batch,
            dropout_p=self.dropout_rate,
            training=self.training,
            use_checkpointing=self.use_checkpointing,
        )

    return self.z_proj(h)
extract_features
extract_features(batch, device: device) -> dict[str, torch.Tensor]

Per-graph fusion features as named tensors (symmetric to VGAE/GAT).

  • pos_stats [N, 3] — anomaly, pos_mean, pos_spread (discriminator-derived)
  • conf [N, 1] — 1 / (1 + anomaly)
  • z_stats [N, 4] — z_mean, z_std, z_max, z_min (latent-pooled)
Source code in graphids/core/models/autoencoder/dgi.py
def extract_features(self, batch, device: torch.device) -> dict[str, torch.Tensor]:
    """Per-graph fusion features as named tensors (symmetric to VGAE/GAT).

    - ``pos_stats`` [N, 3] — anomaly, pos_mean, pos_spread (discriminator-derived)
    - ``conf``      [N, 1] — 1 / (1 + anomaly)
    - ``z_stats``   [N, 4] — z_mean, z_std, z_max, z_min (latent-pooled)

    ``device`` is accepted for interface parity and is not read in this body.
    """
    from torch_geometric.utils import scatter

    graph_idx = batch.batch
    z = self.encode(
        batch.x,
        batch.edge_index,
        getattr(batch, "edge_attr", None),
        graph_idx,
        batch.node_id,
    )
    node_scores = self.discriminate(z, self.summarize(z, graph_idx), graph_idx)

    def per_graph(values, how):
        # Reduce a per-node vector to one value per graph.
        return scatter(values, graph_idx, dim=0, reduce=how)

    pos_mean = per_graph(node_scores, "mean")
    second_moment = per_graph(node_scores.pow(2), "mean")
    # Var = E[x^2] - E[x]^2; clamp guards tiny negative values from rounding.
    pos_spread = (second_moment - pos_mean.pow(2)).clamp(min=0).sqrt()
    anomaly = 1.0 - pos_mean

    latent_stats = [
        per_graph(z.mean(1), "mean"),
        per_graph(z.std(1), "mean"),
        per_graph(z.max(1).values, "max"),
        per_graph(z.min(1).values, "min"),
    ]
    return {
        "pos_stats": torch.stack([anomaly, pos_mean, pos_spread], dim=1),
        "conf": (1.0 / (1.0 + anomaly)).unsqueeze(-1),
        "z_stats": torch.stack(latent_stats, dim=1),
    }
on_test_setup
on_test_setup(datamodule, device) -> None

Fit SVDD center from training-normal graphs at test-start. Always re-fits (no idempotence flag — center isn't persisted in state_dict; see Cardinal jid 8772115 for ckpt-ordering rationale).

Source code in graphids/core/models/autoencoder/dgi.py
def on_test_setup(self, datamodule, device) -> None:
    """Fit the SVDD center from ``datamodule.train_eval_dataloader()``.

    Re-fits on every call; there is no idempotence flag. The original note
    says the center isn't persisted in state_dict and cites Cardinal jid
    8772115 for the ckpt-ordering rationale.
    NOTE(review): ``svdd_center`` is registered with a plain
    ``register_buffer`` call, which is persistent by default — confirm the
    "not persisted" statement against the checkpoint code.
    """
    calibration_loader = datamodule.train_eval_dataloader()
    self._calibrate_svdd_center(calibration_loader, device)
score
score(batch) -> torch.Tensor

OCGIN score: squared L2 distance from the SVDD centroid in pooled-latent space.

Source code in graphids/core/models/autoencoder/dgi.py
def score(self, batch) -> torch.Tensor:
    """OCGIN score: squared L2 distance from the SVDD centroid.

    The distance is computed between ``self._pooled_latent(batch)`` and
    ``self.svdd_center`` and summed over dim 1 (no square root is taken).

    Raises:
        RuntimeError: if ``svdd_center`` is still all zeros (uncalibrated).
    """
    center = self.svdd_center
    if not torch.any(center):
        raise RuntimeError(
            "DGI.svdd_center is zero. Call "
            "on_test_setup(datamodule, device) before scoring "
            "(Trainer.test does this automatically for the test phase)."
        )
    offset = self._pooled_latent(batch) - center
    return offset.pow(2).sum(dim=1)
summarize
summarize(z, batch)

Graph-level summary: sigmoid(mean_pool(z)).

Source code in graphids/core/models/autoencoder/dgi.py
def summarize(self, z, batch):
    """Graph-level summary: sigmoid applied to the mean-pooled ``z``."""
    pooled = global_mean_pool(z, batch)
    return torch.sigmoid(pooled)

VGAE

VGAE(*, loss_fn: Module | None = None, conv_type: str = 'gatv2', hidden_dims: list[int] | None = None, latent_dim: int | None = None, heads: int = 4, embedding_dim: int = 32, dropout: float = 0.1, edge_dim: int = 11, proj_dim: int = 0, gradient_checkpointing: bool = True, compile_model: bool = False, batch_norm: bool = True, mlp_hidden: int | None = None, id_encoder_cfg: IdEncodingCfg | None = None, id_encoder_class_path: str = 'graphids.core.models.id_encoding.LookupIdEncoder', id_encoder_kwargs: dict | None = None, lr: float = 0.003, weight_decay: float = 0.0001, mask_rate: float = 0.15, score_recon_weight: float = 1.0, score_mahal_weight: float = 1.0, score_kl_weight: float = 1.0, scale: str = 'small', model_type: ModelType = 'vgae', dataset: str = '', seed: int = 42, num_ids: int = 0, in_channels: int = 0, num_classes: int = 2)

Bases: ScoreBasedDetectorMixin

Collapsed VGAE — arch + trainer-bridge in one nn.Module.

Loss selection is decoupled: loss_fn is an nn.Module built from experiment config.

Anomaly score = max-σ over four components (masked recon mean, masked recon max, TAM affinity, Rayleigh quotient). Calibration buffers are filled by :meth:on_test_setup at test-start.

Source code in graphids/core/models/autoencoder/vgae.py
def __init__(
    self,
    *,
    loss_fn: nn.Module | None = None,
    # --- architecture (latent_dim/hidden_dims=None → resolve from scale) ---
    conv_type: str = "gatv2",
    hidden_dims: list[int] | None = None,
    latent_dim: int | None = None,
    heads: int = 4,
    embedding_dim: int = 32,
    dropout: float = 0.1,
    edge_dim: int = 11,
    proj_dim: int = 0,
    gradient_checkpointing: bool = True,
    compile_model: bool = False,
    batch_norm: bool = True,
    mlp_hidden: int | None = None,
    id_encoder_cfg: IdEncodingCfg | None = None,
    id_encoder_class_path: str = "graphids.core.models.id_encoding.LookupIdEncoder",
    id_encoder_kwargs: dict | None = None,
    # --- training ---
    lr: float = 0.003,
    weight_decay: float = 0.0001,
    mask_rate: float = 0.15,
    # --- anomaly scoring (config-schema stability; calibrated max-σ
    # path doesn't read these). ---
    score_recon_weight: float = 1.0,
    score_mahal_weight: float = 1.0,
    score_kl_weight: float = 1.0,
    # --- identity / dynamic ---
    scale: str = "small",
    model_type: ModelType = "vgae",
    dataset: str = "",
    seed: int = 42,
    num_ids: int = 0,
    in_channels: int = 0,
    num_classes: int = 2,
):
    """Resolve scale-dependent sizes, register score-norm buffers, finish init.

    All arguments are keyword-only. ``latent_dim`` and ``hidden_dims`` left as
    ``None`` are looked up in ``self._SCALES[scale]``; ``latent_dim`` falls
    back to 48, while ``hidden_dims`` has no fallback and stays ``None`` when
    the scale does not define it. All other arguments are not read in this
    body; they reach ``_init_post`` through the ``locals()`` snapshot.
    """
    # An unknown scale yields an empty dict, so only the literal fallback applies.
    s = self._SCALES.get(scale, {})
    if latent_dim is None:
        latent_dim = s.get("latent_dim", 48)
    if hidden_dims is None:
        hidden_dims = s.get("hidden_dims")
    super().__init__()
    self._register_score_norm_buffers(latent_dim)
    # locals() carries the resolved values above plus ``s`` and ``self``;
    # renaming or adding locals changes what _init_post receives.
    self._init_post(locals())
encode
encode(x, edge_index, edge_attr=None, batch=None, node_id=None)

Returns (z, kl_per_node, mu).

Source code in graphids/core/models/autoencoder/vgae.py
def encode(self, x, edge_index, edge_attr=None, batch=None, node_id=None):
    """Encode nodes and sample a latent; returns ``(z, kl_per_node, mu)``.

    ``z`` is drawn as ``mu + eps * std`` with fresh Gaussian noise on every
    call. The log-variance is clamped to ``[-10, 10]`` before use, and the KL
    term is averaged over the latent dimension.
    """
    h = self.input_encoder(x, node_id)
    for layer_idx, layer in enumerate(self.encoder_layers):
        # encoder_bns is only consulted when batch norm is enabled.
        norm = self.encoder_bns[layer_idx] if self.batch_norm else None
        h = conv_forward(
            layer,
            h,
            edge_index,
            edge_attr,
            bn=norm,
            batch=batch,
            dropout_p=self.dropout_rate,
            training=self.training,
            use_checkpointing=self.use_checkpointing,
        )

    mu = self.z_mean(h)
    logvar = torch.clamp(self.z_logvar(h), -10, 10)

    # Reparameterisation: z = mu + sigma * noise.
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    z = mu + noise * sigma

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_per_node = -0.5 * kl_terms.mean(dim=-1)
    return z, kl_per_node, mu
extract_features
extract_features(batch, device: device) -> dict[str, torch.Tensor]

Per-graph fusion features as named tensors.

  • errors [N, 3] — recon, mahal, kl (anomaly evidence)
  • conf [N, 1] — 1 / (1 + recon)
  • z_stats [N, 4] — z_mean, z_std, z_max, z_min
  • spike [N, 1] — recon_max (per-graph max masked-node MSE)
  • affinity [N, 1] — TAM per-graph mean affinity
  • rq [N, 1] — Rayleigh quotient (input-space spectral smoothness)
Source code in graphids/core/models/autoencoder/vgae.py
def extract_features(self, batch, device: torch.device) -> dict[str, torch.Tensor]:
    """Per-graph fusion features as named tensors.

    - ``errors``   [N, 3] — recon, mahal, kl (anomaly evidence)
    - ``conf``     [N, 1] — 1 / (1 + recon)
    - ``z_stats``  [N, 4] — z_mean, z_std, z_max, z_min
    - ``spike``    [N, 1] — recon_max (per-graph max masked-node MSE)
    - ``affinity`` [N, 1] — TAM per-graph mean affinity
    - ``rq``       [N, 1] — Rayleigh quotient (input-space spectral smoothness)

    ``device`` is accepted for interface parity and is not read in this body.
    """
    from torch_geometric.utils import scatter

    recon, recon_max, affinity, rq, mahal, kl, z = self._score(batch)
    graph_idx = batch.batch

    def per_graph(values, how):
        # Reduce a per-node vector to one value per graph.
        return scatter(values, graph_idx, dim=0, reduce=how)

    latent_stats = [
        per_graph(z.mean(1), "mean"),
        per_graph(z.std(1), "mean"),
        per_graph(z.max(1).values, "max"),
        per_graph(z.min(1).values, "min"),
    ]
    return {
        "errors": torch.stack([recon, mahal, kl], dim=1),
        "conf": (1.0 / (1.0 + recon)).unsqueeze(-1),
        "z_stats": torch.stack(latent_stats, dim=1),
        "spike": recon_max.unsqueeze(-1),
        "affinity": affinity.unsqueeze(-1),
        "rq": rq.unsqueeze(-1),
    }
on_test_setup
on_test_setup(datamodule, device) -> None

Fit z-norm calibration buffers from benign val if not already populated. Idempotent: skips if a calibrated ckpt was reloaded.

Source code in graphids/core/models/autoencoder/vgae.py
def on_test_setup(self, datamodule, device) -> None:
    """Fit the z-norm calibration buffers from ``datamodule.val_dataloader()``.

    Does nothing when ``self.score_norm_fitted`` is already truthy, e.g.
    after reloading a calibrated checkpoint.
    """
    if bool(self.score_norm_fitted):
        return
    self._fit_score_norm(datamodule.val_dataloader(), device)
score
score(batch) -> torch.Tensor

Per-graph anomaly score: max-σ over (recon, recon_max, TAM affinity, RQ) in the calibrated z-norm space.

Source code in graphids/core/models/autoencoder/vgae.py
def score(self, batch) -> torch.Tensor:
    """Per-graph anomaly score: the largest z-normalised component.

    Each of recon, recon_max, affinity and rq is standardised with its
    ``score_<name>_mean`` / ``score_<name>_std`` buffer, and the element-wise
    maximum across the four is returned.

    Raises:
        RuntimeError: if ``score_norm_fitted`` is falsy (not yet calibrated).
    """
    if not bool(self.score_norm_fitted):
        raise RuntimeError(
            "VGAE scoring requires on_test_setup() to have run. "
            "If loading an old ckpt without masker.mask_token, retrain under "
            "the mask-recon code or use the legacy scoring path."
        )

    recon, recon_max, affinity, rq, _mahal, _kl, _z = self._score(batch)
    components = {
        "recon": recon,
        "recon_max": recon_max,
        "affinity": affinity,
        "rq": rq,
    }

    eps = 1e-6  # keeps the division finite when a calibrated std is zero
    normalized = []
    for key, value in components.items():
        center = getattr(self, f"score_{key}_mean")
        spread = getattr(self, f"score_{key}_std")
        normalized.append((value - center) / (spread + eps))
    return torch.stack(normalized, dim=0).amax(dim=0)

dgi

Deep Graph Infomax — collapsed arch + trainer-bridge.

Maximizes mutual information between node embeddings and a graph-level summary via a bilinear discriminator. Uses the same encoder backbone as VGAE (InputEncoder + conv stack) for fair ablation comparison.

Anomaly scoring at test time: OCGIN-style L2 distance between the pooled node embedding of a query graph and the centroid of training-normal pooled embeddings (Zhao & Akoglu 2021, arxiv:2103.04494).

Reference: Veličković et al., "Deep Graph Infomax" (ICLR 2019).

DGI
DGI(conv_type: str = 'gatv2', hidden_dims: list[int] | None = None, latent_dim: int | None = None, heads: int | None = None, embedding_dim: int | None = None, dropout: float = 0.15, edge_dim: int = 11, proj_dim: int = 0, gradient_checkpointing: bool = True, compile_model: bool = False, batch_norm: bool = True, id_encoder_cfg: IdEncodingCfg | None = None, id_encoder_class_path: str = 'graphids.core.models.id_encoding.LookupIdEncoder', id_encoder_kwargs: dict | None = None, lr: float = 0.001, weight_decay: float = 0.0001, scale: str = 'small', model_type: ModelType = 'dgi', dataset: str = '', seed: int = 42, num_ids: int = 0, in_channels: int = 0, num_classes: int = 2)

Bases: ScoreBasedDetectorMixin

Collapsed DGI — arch + trainer-bridge in one nn.Module.

No loss_fn kwarg: the contrastive MI loss is intrinsic to the architecture (built into the discriminator).

Source code in graphids/core/models/autoencoder/dgi.py
def __init__(
    self,
    # --- architecture (latent_dim/embedding_dim/heads=None → scale) ---
    conv_type: str = "gatv2",
    hidden_dims: list[int] | None = None,
    latent_dim: int | None = None,
    heads: int | None = None,
    embedding_dim: int | None = None,
    dropout: float = 0.15,
    edge_dim: int = 11,
    proj_dim: int = 0,
    gradient_checkpointing: bool = True,
    compile_model: bool = False,
    batch_norm: bool = True,
    id_encoder_cfg: IdEncodingCfg | None = None,
    id_encoder_class_path: str = "graphids.core.models.id_encoding.LookupIdEncoder",
    id_encoder_kwargs: dict | None = None,
    # --- training ---
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    # --- identity / dynamic ---
    scale: str = "small",
    model_type: ModelType = "dgi",
    dataset: str = "",
    seed: int = 42,
    num_ids: int = 0,
    in_channels: int = 0,
    num_classes: int = 2,
):
    """Resolve scale-dependent sizes, register the SVDD centre, and finish init.

    ``latent_dim``, ``embedding_dim`` and ``heads`` left as ``None`` are
    looked up in ``self._SCALES[scale]``, falling back to 48, 32 and 4 when
    the scale (or the key) is missing. All other arguments are not read in
    this body; they reach ``_init_post`` through the ``locals()`` snapshot.
    """
    # An unknown scale yields an empty dict, so the literal fallbacks apply.
    s = self._SCALES.get(scale, {})
    if latent_dim is None:
        latent_dim = s.get("latent_dim", 48)
    if embedding_dim is None:
        embedding_dim = s.get("embedding_dim", 32)
    if heads is None:
        heads = s.get("heads", 4)
    super().__init__()
    # OCGIN scoring head: centroid of training-normal pooled embeddings.
    # Re-fit at test-start by ``Trainer.test`` via ``on_test_setup`` —
    # the centroid is a deterministic statistic of (encoder weights,
    # benign train data). Zero init means an uncalibrated forward pass
    # raises in ``score`` rather than returning bogus scores.
    self.register_buffer("svdd_center", torch.zeros(latent_dim))
    # locals() carries the resolved values above plus ``s`` and ``self``;
    # renaming or adding locals changes what _init_post receives.
    self._init_post(locals())
dgi_loss
dgi_loss(pos_z, neg_z, summary, batch_idx)

Contrastive MI loss: maximize real node–summary agreement.

Source code in graphids/core/models/autoencoder/dgi.py
def dgi_loss(self, pos_z, neg_z, summary, batch_idx):
    """Contrastive MI loss over real and corrupted node embeddings.

    Returns ``-mean(log(D(pos) + eps)) - mean(log(1 - D(neg) + eps))`` with
    ``eps = 1e-6``, where ``D`` is ``self.discriminate``.
    """
    eps = 1e-6
    real_prob = self.discriminate(pos_z, summary, batch_idx)
    fake_prob = self.discriminate(neg_z, summary, batch_idx)
    real_term = torch.log(real_prob + eps).mean()
    fake_term = torch.log(1 - fake_prob + eps).mean()
    return -real_term - fake_term
discriminate
discriminate(z, summary, batch)

Bilinear scoring: sigmoid(z^T W s) per node.

Source code in graphids/core/models/autoencoder/dgi.py
def discriminate(self, z, summary, batch):
    """Bilinear node-vs-summary score: ``sigmoid(z^T W s)`` for each node.

    ``summary`` is indexed by ``batch`` to pick, for every node, the summary
    row of the graph it belongs to.
    """
    per_node_summary = summary[batch]
    projected = z @ self.discriminator_weight
    logits = (projected * per_node_summary).sum(dim=1)
    return torch.sigmoid(logits)
encode
encode(x, edge_index, edge_attr=None, batch=None, node_id=None)

Encode nodes to latent embeddings (same contract as VGAE minus KL).

Source code in graphids/core/models/autoencoder/dgi.py
def encode(self, x, edge_index, edge_attr=None, batch=None, node_id=None):
    """Map node features to latent embeddings.

    Runs ``input_encoder(x, node_id)``, then every layer in
    ``encoder_layers`` through ``conv_forward``, then ``z_proj``. Edge
    attributes are dropped when ``self._uses_edge_attr`` is falsy.
    """
    h = self.input_encoder(x, node_id)
    edge_features = edge_attr if self._uses_edge_attr else None

    for layer_idx, layer in enumerate(self.encoder_layers):
        # encoder_bns is only consulted when batch norm is enabled.
        norm = self.encoder_bns[layer_idx] if self.batch_norm else None
        h = conv_forward(
            layer,
            h,
            edge_index,
            edge_features,
            bn=norm,
            batch=batch,
            dropout_p=self.dropout_rate,
            training=self.training,
            use_checkpointing=self.use_checkpointing,
        )

    return self.z_proj(h)
extract_features
extract_features(batch, device: device) -> dict[str, torch.Tensor]

Per-graph fusion features as named tensors (symmetric to VGAE/GAT).

  • pos_stats [N, 3] — anomaly, pos_mean, pos_spread (discriminator-derived)
  • conf [N, 1] — 1 / (1 + anomaly)
  • z_stats [N, 4] — z_mean, z_std, z_max, z_min (latent-pooled)
Source code in graphids/core/models/autoencoder/dgi.py
def extract_features(self, batch, device: torch.device) -> dict[str, torch.Tensor]:
    """Per-graph fusion features as named tensors (symmetric to VGAE/GAT).

    - ``pos_stats`` [N, 3] — anomaly, pos_mean, pos_spread (discriminator-derived)
    - ``conf``      [N, 1] — 1 / (1 + anomaly)
    - ``z_stats``   [N, 4] — z_mean, z_std, z_max, z_min (latent-pooled)

    ``device`` is accepted for interface parity and is not read in this body.
    """
    from torch_geometric.utils import scatter

    graph_idx = batch.batch
    z = self.encode(
        batch.x,
        batch.edge_index,
        getattr(batch, "edge_attr", None),
        graph_idx,
        batch.node_id,
    )
    node_scores = self.discriminate(z, self.summarize(z, graph_idx), graph_idx)

    def per_graph(values, how):
        # Reduce a per-node vector to one value per graph.
        return scatter(values, graph_idx, dim=0, reduce=how)

    pos_mean = per_graph(node_scores, "mean")
    second_moment = per_graph(node_scores.pow(2), "mean")
    # Var = E[x^2] - E[x]^2; clamp guards tiny negative values from rounding.
    pos_spread = (second_moment - pos_mean.pow(2)).clamp(min=0).sqrt()
    anomaly = 1.0 - pos_mean

    latent_stats = [
        per_graph(z.mean(1), "mean"),
        per_graph(z.std(1), "mean"),
        per_graph(z.max(1).values, "max"),
        per_graph(z.min(1).values, "min"),
    ]
    return {
        "pos_stats": torch.stack([anomaly, pos_mean, pos_spread], dim=1),
        "conf": (1.0 / (1.0 + anomaly)).unsqueeze(-1),
        "z_stats": torch.stack(latent_stats, dim=1),
    }
on_test_setup
on_test_setup(datamodule, device) -> None

Fit SVDD center from training-normal graphs at test-start. Always re-fits (no idempotence flag — center isn't persisted in state_dict; see Cardinal jid 8772115 for ckpt-ordering rationale).

Source code in graphids/core/models/autoencoder/dgi.py
def on_test_setup(self, datamodule, device) -> None:
    """Fit the SVDD center from ``datamodule.train_eval_dataloader()``.

    Re-fits on every call; there is no idempotence flag. The original note
    says the center isn't persisted in state_dict and cites Cardinal jid
    8772115 for the ckpt-ordering rationale.
    NOTE(review): ``svdd_center`` is registered with a plain
    ``register_buffer`` call, which is persistent by default — confirm the
    "not persisted" statement against the checkpoint code.
    """
    calibration_loader = datamodule.train_eval_dataloader()
    self._calibrate_svdd_center(calibration_loader, device)
score
score(batch) -> torch.Tensor

OCGIN score: squared L2 distance from the SVDD centroid in pooled-latent space.

Source code in graphids/core/models/autoencoder/dgi.py
def score(self, batch) -> torch.Tensor:
    """OCGIN score: squared L2 distance from the SVDD centroid.

    The distance is computed between ``self._pooled_latent(batch)`` and
    ``self.svdd_center`` and summed over dim 1 (no square root is taken).

    Raises:
        RuntimeError: if ``svdd_center`` is still all zeros (uncalibrated).
    """
    center = self.svdd_center
    if not torch.any(center):
        raise RuntimeError(
            "DGI.svdd_center is zero. Call "
            "on_test_setup(datamodule, device) before scoring "
            "(Trainer.test does this automatically for the test phase)."
        )
    offset = self._pooled_latent(batch) - center
    return offset.pow(2).sum(dim=1)
summarize
summarize(z, batch)

Graph-level summary: sigmoid(mean_pool(z)).

Source code in graphids/core/models/autoencoder/dgi.py
def summarize(self, z, batch):
    """Graph-level summary: sigmoid applied to the mean-pooled ``z``."""
    pooled = global_mean_pool(z, batch)
    return torch.sigmoid(pooled)

vgae

Variational graph autoencoder — collapsed arch + trainer-bridge.

The single :class:VGAE class is both the architecture (encoder / decoder / aux heads / mask token / score-norm calibration buffers) and the trainer-bridge (training_step/validation_step/test_step, score primitives, fusion-feature extractor). No wrapper module — see ~/plans/graphids-collapse-model-modules.md Phase 1.

Encoder maps node features to q(z|x) = N(mu, σ²); decoder reconstructs continuous features from the reparameterized z. Mask-recon training (15% random node masking) commits the encoder to "predict v from neighborhood" rather than "echo v back".

VGAE
VGAE(*, loss_fn: Module | None = None, conv_type: str = 'gatv2', hidden_dims: list[int] | None = None, latent_dim: int | None = None, heads: int = 4, embedding_dim: int = 32, dropout: float = 0.1, edge_dim: int = 11, proj_dim: int = 0, gradient_checkpointing: bool = True, compile_model: bool = False, batch_norm: bool = True, mlp_hidden: int | None = None, id_encoder_cfg: IdEncodingCfg | None = None, id_encoder_class_path: str = 'graphids.core.models.id_encoding.LookupIdEncoder', id_encoder_kwargs: dict | None = None, lr: float = 0.003, weight_decay: float = 0.0001, mask_rate: float = 0.15, score_recon_weight: float = 1.0, score_mahal_weight: float = 1.0, score_kl_weight: float = 1.0, scale: str = 'small', model_type: ModelType = 'vgae', dataset: str = '', seed: int = 42, num_ids: int = 0, in_channels: int = 0, num_classes: int = 2)

Bases: ScoreBasedDetectorMixin

Collapsed VGAE — arch + trainer-bridge in one nn.Module.

Loss selection is decoupled: loss_fn is an nn.Module built from experiment config.

Anomaly score = max-σ over four components (masked recon mean, masked recon max, TAM affinity, Rayleigh quotient). Calibration buffers are filled by :meth:on_test_setup at test-start.

Source code in graphids/core/models/autoencoder/vgae.py
def __init__(
    self,
    *,
    loss_fn: nn.Module | None = None,
    # --- architecture (latent_dim/hidden_dims=None → resolve from scale) ---
    conv_type: str = "gatv2",
    hidden_dims: list[int] | None = None,
    latent_dim: int | None = None,
    heads: int = 4,
    embedding_dim: int = 32,
    dropout: float = 0.1,
    edge_dim: int = 11,
    proj_dim: int = 0,
    gradient_checkpointing: bool = True,
    compile_model: bool = False,
    batch_norm: bool = True,
    mlp_hidden: int | None = None,
    id_encoder_cfg: IdEncodingCfg | None = None,
    id_encoder_class_path: str = "graphids.core.models.id_encoding.LookupIdEncoder",
    id_encoder_kwargs: dict | None = None,
    # --- training ---
    lr: float = 0.003,
    weight_decay: float = 0.0001,
    mask_rate: float = 0.15,
    # --- anomaly scoring (config-schema stability; calibrated max-σ
    # path doesn't read these). ---
    score_recon_weight: float = 1.0,
    score_mahal_weight: float = 1.0,
    score_kl_weight: float = 1.0,
    # --- identity / dynamic ---
    scale: str = "small",
    model_type: ModelType = "vgae",
    dataset: str = "",
    seed: int = 42,
    num_ids: int = 0,
    in_channels: int = 0,
    num_classes: int = 2,
):
    """Resolve scale-dependent sizes, register score-norm buffers, finish init.

    All arguments are keyword-only. ``latent_dim`` and ``hidden_dims`` left as
    ``None`` are looked up in ``self._SCALES[scale]``; ``latent_dim`` falls
    back to 48, while ``hidden_dims`` has no fallback and stays ``None`` when
    the scale does not define it. All other arguments are not read in this
    body; they reach ``_init_post`` through the ``locals()`` snapshot.
    """
    # An unknown scale yields an empty dict, so only the literal fallback applies.
    s = self._SCALES.get(scale, {})
    if latent_dim is None:
        latent_dim = s.get("latent_dim", 48)
    if hidden_dims is None:
        hidden_dims = s.get("hidden_dims")
    super().__init__()
    self._register_score_norm_buffers(latent_dim)
    # locals() carries the resolved values above plus ``s`` and ``self``;
    # renaming or adding locals changes what _init_post receives.
    self._init_post(locals())
encode
encode(x, edge_index, edge_attr=None, batch=None, node_id=None)

Returns (z, kl_per_node, mu).

Source code in graphids/core/models/autoencoder/vgae.py
def encode(self, x, edge_index, edge_attr=None, batch=None, node_id=None):
    """Encode nodes and sample a latent; returns ``(z, kl_per_node, mu)``.

    ``z`` is drawn as ``mu + eps * std`` with fresh Gaussian noise on every
    call. The log-variance is clamped to ``[-10, 10]`` before use, and the KL
    term is averaged over the latent dimension.
    """
    h = self.input_encoder(x, node_id)
    for layer_idx, layer in enumerate(self.encoder_layers):
        # encoder_bns is only consulted when batch norm is enabled.
        norm = self.encoder_bns[layer_idx] if self.batch_norm else None
        h = conv_forward(
            layer,
            h,
            edge_index,
            edge_attr,
            bn=norm,
            batch=batch,
            dropout_p=self.dropout_rate,
            training=self.training,
            use_checkpointing=self.use_checkpointing,
        )

    mu = self.z_mean(h)
    logvar = torch.clamp(self.z_logvar(h), -10, 10)

    # Reparameterisation: z = mu + sigma * noise.
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    z = mu + noise * sigma

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_per_node = -0.5 * kl_terms.mean(dim=-1)
    return z, kl_per_node, mu
extract_features
extract_features(batch, device: device) -> dict[str, torch.Tensor]

Per-graph fusion features as named tensors.

  • errors [N, 3] — recon, mahal, kl (anomaly evidence)
  • conf [N, 1] — 1 / (1 + recon)
  • z_stats [N, 4] — z_mean, z_std, z_max, z_min
  • spike [N, 1] — recon_max (per-graph max masked-node MSE)
  • affinity [N, 1] — TAM per-graph mean affinity
  • rq [N, 1] — Rayleigh quotient (input-space spectral smoothness)
Source code in graphids/core/models/autoencoder/vgae.py
def extract_features(self, batch, device: torch.device) -> dict[str, torch.Tensor]:
    """Build the named per-graph feature tensors consumed by fusion.

    Keys of the returned dict:

    - ``errors``   [N, 3] — recon, mahal, kl (anomaly evidence)
    - ``conf``     [N, 1] — 1 / (1 + recon)
    - ``z_stats``  [N, 4] — z_mean, z_std, z_max, z_min
    - ``spike``    [N, 1] — recon_max (per-graph max masked-node MSE)
    - ``affinity`` [N, 1] — TAM per-graph mean affinity
    - ``rq``       [N, 1] — Rayleigh quotient (input-space spectral smoothness)

    ``device`` is accepted but not read by this body.
    """
    from torch_geometric.utils import scatter

    recon, recon_max, affinity, rq, mahal, kl, z = self._score(batch)
    graph_index = batch.batch

    def pool(values, how):
        # Reduce per-row statistics of z to one value per graph.
        return scatter(values, graph_index, dim=0, reduce=how)

    latent_stats = [
        pool(z.mean(1), "mean"),
        pool(z.std(1), "mean"),
        pool(z.max(1).values, "max"),
        pool(z.min(1).values, "min"),
    ]

    features = {
        "errors": torch.stack([recon, mahal, kl], dim=1),
        "conf": (1.0 / (1.0 + recon)).unsqueeze(-1),
        "z_stats": torch.stack(latent_stats, dim=1),
        "spike": recon_max.unsqueeze(-1),
        "affinity": affinity.unsqueeze(-1),
        "rq": rq.unsqueeze(-1),
    }
    return features
on_test_setup
on_test_setup(datamodule, device) -> None

Fit z-norm calibration buffers from benign val if not already populated. Idempotent: skips if a calibrated ckpt was reloaded.

Source code in graphids/core/models/autoencoder/vgae.py
def on_test_setup(self, datamodule, device) -> None:
    """Fit the score-normalisation buffers unless they are already fitted.

    When ``score_norm_fitted`` is truthy (e.g. after reloading a calibrated
    checkpoint) this is a no-op, so repeated calls are harmless. Otherwise
    the datamodule's validation dataloader is handed to ``_fit_score_norm``.
    """
    if bool(self.score_norm_fitted):
        return
    self._fit_score_norm(datamodule.val_dataloader(), device)
score
score(batch) -> torch.Tensor

Per-graph anomaly score: max-σ over (recon, recon_max, TAM affinity, RQ) in the calibrated z-norm space.

Source code in graphids/core/models/autoencoder/vgae.py
def score(self, batch) -> torch.Tensor:
    """Per-graph anomaly score as the largest normalised component.

    Each of recon, recon_max, affinity and rq is shifted by its stored
    ``score_<name>_mean`` and divided by ``score_<name>_std`` (plus a small
    epsilon); the element-wise maximum over the four is returned.

    Raises:
        RuntimeError: the normalisation buffers have not been fitted yet.
    """
    if not bool(self.score_norm_fitted):
        raise RuntimeError(
            "VGAE scoring requires on_test_setup() to have run. "
            "If loading an old ckpt without masker.mask_token, retrain under "
            "the mask-recon code or use the legacy scoring path."
        )

    recon, recon_max, affinity, rq, _mahal, _kl, _z = self._score(batch)
    raw_signals = {
        "recon": recon,
        "recon_max": recon_max,
        "affinity": affinity,
        "rq": rq,
    }
    eps = 1e-6  # keeps the division finite when a stored std is zero
    normalised = [
        (value - getattr(self, f"score_{name}_mean"))
        / (getattr(self, f"score_{name}_std") + eps)
        for name, value in raw_signals.items()
    ]
    return torch.stack(normalised, dim=0).amax(dim=0)

base

Shared model infrastructure — base classes, utilities, contracts.

Graph family: - GraphModuleBase — base for VGAE, GAT, DGI - try_compile — safe torch.compile with conv-type gating - eval_mode — context manager that restores training state

Shared: - _ModelBase(pl.LightningModule) — mixin shared by GraphModuleBase + FusionModuleBase. Lightning provides self.device, self.log, self.log_dict, self.hparams, self.trainer, etc. - safe_load_checkpoint — checkpoint loading via class_path registry - strip_orig_mod_prefix — drop _orig_mod. keys from state_dicts produced under torch.compile

GraphModuleBase

Bases: _ModelBase

Shared base for VGAE, GAT, DGI — lazy setup, threshold metrics.

Subclasses must implement _build() using self.hparams.

compute_budget
compute_budget(train_dataset, dataset_name: str, min_steps: int | None = None) -> Any

Probe-once VRAM budget cached on the model.

Source code in graphids/core/models/base.py
def compute_budget(self, train_dataset, dataset_name: str, min_steps: int | None = None) -> Any:
    """Probe-once VRAM budget cached on the model."""
    if self._budget_cache is None:
        from graphids.core.budget import node_budget

        self._budget_cache = node_budget(
            dataset_name, model=self, train_dataset=train_dataset, min_steps=min_steps
        )
    return self._budget_cache
configure_optimizers
configure_optimizers()

Adam over all params using self.hparams.lr / weight_decay.

Source code in graphids/core/models/base.py
def configure_optimizers(self):
    """Create an Adam optimizer over every parameter of the module.

    ``lr`` and ``weight_decay`` are read from ``self.hparams`` and default to
    ``1e-3`` and ``0.0`` when the attribute is missing.
    """
    hparams = self.hparams
    return torch.optim.Adam(
        self.parameters(),
        lr=getattr(hparams, "lr", 1e-3),
        weight_decay=getattr(hparams, "weight_decay", 0.0),
    )
on_validation_epoch_end
on_validation_epoch_end() -> None

Override in subclasses to compute epoch-level val metrics.

Source code in graphids/core/models/base.py
def on_validation_epoch_end(self) -> None:
    """Hook for epoch-level validation metrics; a no-op in this base class.

    Override in subclasses to compute epoch-level val metrics.
    """
prepare_from_datamodule
prepare_from_datamodule(dm) -> None

Lazy-build with DM-supplied sizes, then capture test-set names.

Source code in graphids/core/models/base.py
def prepare_from_datamodule(self, dm) -> None:
    """Build the model on first use from datamodule sizes, then defer to the parent.

    The build is skipped when ``_built`` is truthy or when a ``model``
    attribute already exists and is not ``None``. Otherwise ``num_ids``,
    ``in_channels`` and ``num_classes`` are copied from ``dm`` onto both the
    instance and ``self.hparams`` before ``_build()`` runs.
    """
    if not getattr(self, "_built", False):
        existing_model = getattr(self, "model", "_sentinel")
        if existing_model in (None, "_sentinel"):
            # Mirror onto self and into hparams so _build() sees DM values.
            for field_name in ("num_ids", "in_channels", "num_classes"):
                size = getattr(dm, field_name)
                setattr(self, field_name, size)
                self.hparams[field_name] = size
            self._build()
            self._built = True
    super().prepare_from_datamodule(dm)

ScoreBasedDetectorMixin

ScoreBasedDetectorMixin()

Bases: GraphModuleBase

Mix-in for graph models that emit per-graph anomaly scores.

Source code in graphids/core/models/base.py
def __init__(self) -> None:
    """Initialise the parent class, then attach threshold and test metrics."""
    super().__init__()
    # Defined outside this block; presumably registers the threshold-based
    # metrics this mixin relies on — confirm against the base class.
    self._init_threshold_metrics()
    # Metric container produced by the module-level ``binary_test_metrics`` helper.
    self.test_metrics = binary_test_metrics()
score
score(batch) -> torch.Tensor

Per-graph anomaly score, higher = more anomalous.

Source code in graphids/core/models/base.py
def score(self, batch) -> torch.Tensor:
    """Per-graph anomaly score, higher = more anomalous.

    Abstract hook: this implementation always raises, so concrete detectors
    must override it.

    Raises:
        NotImplementedError: always, when not overridden.
    """
    raise NotImplementedError

eval_mode

eval_mode(model)

Context manager: set model.eval(), restore original training state on exit.

Source code in graphids/core/models/base.py
@contextlib.contextmanager
def eval_mode(model):
    """Temporarily switch ``model`` to eval mode.

    The ``training`` flag observed on entry is restored on exit, including
    when the body of the ``with`` block raises.
    """
    previous_mode = model.training
    model.eval()
    try:
        yield
    finally:
        # Restore whatever mode the caller had, train or eval.
        model.train(previous_mode)

safe_load_checkpoint

safe_load_checkpoint(model_type: str, ckpt_path, *, map_location='cpu')

Load a checkpoint, dispatching on the class_path saved at write time.

model_type is not consulted by the current implementation: init kwargs that are excluded from hyper_parameters (e.g. loss_fn) are rebuilt through the class's optional _rebuild_excluded_kwargs(hp) hook. Class lookup uses the self-describing class_path injected by _ModelBase.on_save_checkpoint.

Source code in graphids/core/models/base.py
def safe_load_checkpoint(model_type: str, ckpt_path, *, map_location="cpu"):
    """Rebuild a module from a checkpoint using the ``class_path`` stored in it.

    Steps, in order: load the file through ``atomic_load`` (with
    ``weights_only=True``), import the class named by ``class_path``,
    instantiate it from the saved ``hyper_parameters`` plus whatever the
    class's optional ``_rebuild_excluded_kwargs(hp)`` hook returns, load the
    state dict (``_orig_mod.`` occurrences and ``loss_fn.*`` keys removed),
    move the module to ``map_location`` and finally call
    ``on_load_checkpoint`` when the module defines it.

    Args:
        model_type: Not referenced anywhere in this function body; class
            selection relies solely on ``class_path``.
            NOTE(review): presumably kept for caller compatibility — confirm.
        ckpt_path: Checkpoint location; converted to ``Path``.
        map_location: Passed to ``atomic_load`` and afterwards to ``module.to``.

    Returns:
        The instantiated module with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: ``ckpt_path`` does not exist.
        KeyError: The checkpoint has no (or an empty) ``class_path`` entry.
    """
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    from graphids._fs import atomic_load

    ckpt = atomic_load(ckpt_path, map_location=map_location, weights_only=True)
    dotted = ckpt.get("class_path")
    if not dotted:
        raise KeyError(
            f"Checkpoint {ckpt_path} missing 'class_path'. Re-train under the "
            "current LightningModule + on_save_checkpoint contract."
        )
    # Split "package.module.ClassName" into its module path and class name.
    module_path, cls_name = dotted.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)

    # Lightning serializes self.hparams as an AttributeDict; coerce to plain
    # dict so the **-spread into init_kwargs is well-defined.
    hp = dict(ckpt.get("hyper_parameters", {}))

    # Per-class hook for rebuilding excluded init kwargs (e.g. ``loss_fn``,
    # which can't be pickled into ``hyper_parameters``). Each class that needs
    # something rebuilt declares ``_rebuild_excluded_kwargs(hp) -> dict`` as a
    # classmethod or staticmethod. Default: nothing extra.
    rebuild = getattr(cls, "_rebuild_excluded_kwargs", None)
    extra_kwargs: dict = rebuild(hp) if rebuild is not None else {}

    # Rebuilt kwargs take precedence over same-named saved hyperparameters.
    init_kwargs = {**hp, **extra_kwargs}
    module = cls(**init_kwargs)
    state_dict = strip_orig_mod_prefix(ckpt["state_dict"])
    # Strip training-only submodules (e.g. loss_fn) that aren't present on the
    # inference model but may have been saved into the checkpoint.
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith("loss_fn.")}
    module.load_state_dict(state_dict)
    module.to(map_location)

    # Give the module a chance to restore extra state from the raw checkpoint.
    if hasattr(module, "on_load_checkpoint"):
        module.on_load_checkpoint(ckpt)

    return module

strip_orig_mod_prefix

strip_orig_mod_prefix(state: dict[str, Any]) -> dict[str, Any]

Drop _orig_mod. prefix injected by torch.compile's OptimizedModule.

_orig_mod. can appear mid-key (e.g. model._orig_mod.encoder.weight) when compile wraps an inner submodule; replace handles every position.

Source code in graphids/core/models/base.py
def strip_orig_mod_prefix(state: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``state`` whose keys no longer contain ``_orig_mod.``.

    ``torch.compile`` wraps modules in an OptimizedModule, which injects
    ``_orig_mod.`` into state-dict keys. The marker may sit anywhere in a key
    (e.g. ``model._orig_mod.encoder.weight``), so every occurrence is removed.
    Values are passed through untouched.
    """
    marker = "_orig_mod."
    cleaned: dict[str, Any] = {}
    for key, value in state.items():
        cleaned[key.replace(marker, "")] = value
    return cleaned

try_compile

try_compile(model: Module, *, conv_type: str | None = None, **kwargs) -> nn.Module

Wrap the model with torch.compile unless its conv type is known to be incompatible. Errors raised by torch.compile are not caught; they propagate to the caller.

Skips compile entirely for conv types that use to_dense_batch() (e.g. GPS) — the Tensor.item() call causes graph breaks, repeated recompilation, and eventual CUDA illegal memory access.

Source code in graphids/core/models/base.py
def try_compile(model: nn.Module, *, conv_type: str | None = None, **kwargs) -> nn.Module:
    """Wrap ``model`` with ``torch.compile`` unless its conv type is incompatible.

    Conv types that use ``to_dense_batch()`` (e.g. GPS) are returned
    uncompiled with a warning — the ``Tensor.item()`` call causes graph
    breaks, repeated recompilation, and eventual CUDA illegal memory access.
    Nothing raised by ``torch.compile`` is caught here; extra keyword
    arguments are forwarded to it unchanged.
    """
    skip_for = frozenset({"gps"})
    if conv_type not in skip_for:
        # torch.compile is lazy — errors surface at first forward, not here.
        # A broad except at wrap time masked zero real failures and swallowed
        # config bugs. Let exceptions propagate.
        return torch.compile(model, **kwargs)

    _log.warning(
        "compile_skipped",
        conv_type=conv_type,
        reason="to_dense_batch Tensor.item() causes graph breaks and CUDA crash",
    )
    return model

fusion

Fusion policy model family exports.

BanditFusionModule

BanditFusionModule(state_dim: int = 18, alpha_steps: int = 21, ucb_alpha: float = 1.0, lambda_reg: float = 1.0, hidden_dim: int = 128, num_layers: int = 3, backbone_lr: float = 0.001, backbone_retrain_freq: int = 50, backbone_epochs: int = 5, buffer_size: int = 100000, batch_size: int = 128, decision_threshold: float = 0.5, reward_kwargs: dict | None = None)

Bases: RLFusionBase

Neural-LinUCB: backbone + per-arm ridge with Sherman-Morrison online updates, and a frequency-gated backbone refit.

Source code in graphids/core/models/fusion/bandit.py
def __init__(
    self,
    state_dim: int = 18,
    alpha_steps: int = 21,
    ucb_alpha: float = 1.0,
    lambda_reg: float = 1.0,
    hidden_dim: int = 128,
    num_layers: int = 3,
    backbone_lr: float = 1e-3,
    backbone_retrain_freq: int = 50,
    backbone_epochs: int = 5,
    buffer_size: int = 100_000,
    batch_size: int = 128,
    decision_threshold: float = 0.5,
    reward_kwargs: dict | None = None,
):
    """Build the backbone, the per-arm ridge buffers and the backbone optimizer.

    Args:
        state_dim: First argument to ``_Backbone``; also forwarded to the parent.
        alpha_steps: Leading dimension of ``A_inv``, ``b`` and ``theta``
            (one slice per arm); also forwarded to the parent.
        ucb_alpha: Not read in this constructor beyond the ``locals()``
            snapshot — presumably the UCB exploration weight; confirm at use site.
        lambda_reg: ``A_inv`` starts as the identity divided by this value.
        hidden_dim: Second argument to ``_Backbone``.
        num_layers: Third argument to ``_Backbone``.
        backbone_lr: Learning rate of the AdamW optimizer over the backbone.
        backbone_retrain_freq: Not read in this constructor beyond the
            ``locals()`` snapshot — presumably gates backbone refits; confirm.
        backbone_epochs: Stored as ``gpu_training_steps``.
        buffer_size: Forwarded to the parent constructor.
        batch_size: Forwarded to the parent constructor.
        decision_threshold: Forwarded to the parent constructor.
        reward_kwargs: Forwarded to the parent constructor.
    """
    super().__init__(
        buffer_size=buffer_size,
        batch_size=batch_size,
        state_dim=state_dim,
        alpha_steps=alpha_steps,
        decision_threshold=decision_threshold,
        reward_kwargs=reward_kwargs,
    )
    # Captures ``locals()`` as they stand here, so it has to run before any
    # further local name (such as ``d`` below) is bound.
    self._store_init_kwargs(locals())

    self.backbone = _Backbone(state_dim, hidden_dim, num_layers)
    d = self.backbone.out_dim

    # A_inv: one (d, d) identity per arm, scaled by 1 / lambda_reg.
    self.register_buffer(
        "A_inv", torch.eye(d).unsqueeze(0).repeat(alpha_steps, 1, 1) / lambda_reg
    )
    # b and theta: one zero vector of length d per arm.
    self.register_buffer("b", torch.zeros(alpha_steps, d))
    self.register_buffer("theta", torch.zeros(alpha_steps, d))

    # gpu_training_steps drives RLFusionBase._learn_step's inner loop.
    self.gpu_training_steps = backbone_epochs

    # Only the backbone parameters are handed to the optimizer.
    self._optimizer = optim.AdamW(self.backbone.parameters(), lr=backbone_lr)
    self._episode = 0
    self._ucb_widths: list[float] = []

MLPFusionModule

MLPFusionModule(state_dim: int = 18, hidden_dims: tuple[int, ...] = (64, 32), lr: float = 0.001, decision_threshold: float = 0.5)

Bases: FusionModuleBase

Same features as DQN, trained with BCE instead of RL.

Source code in graphids/core/models/fusion/mlp.py
def __init__(
    self,
    state_dim: int = 18,
    hidden_dims: tuple[int, ...] = (64, 32),
    lr: float = 1e-3,
    decision_threshold: float = 0.5,
):
    """Assemble the MLP: ``[Linear → ReLU → Dropout(0.2)]`` per hidden width, then ``Linear(·, 1)``.

    Args:
        state_dim: Input width of the first linear layer; forwarded to the parent.
        hidden_dims: Output width of each hidden stage, in order.
        lr: Not read here beyond the ``locals()`` snapshot.
        decision_threshold: Forwarded to the parent constructor.
    """
    super().__init__(state_dim=state_dim, decision_threshold=decision_threshold)
    # Must stay ahead of every other local: it captures ``locals()`` verbatim.
    self._store_init_kwargs(locals())

    widths = (state_dim, *hidden_dims)
    stack: list[nn.Module] = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
    # Single-output head on top of the last hidden width (or state_dim if none).
    stack.append(nn.Linear(widths[-1], 1))
    self.model = nn.Sequential(*stack)

MoEFusionModule

MoEFusionModule(state_dim: int = 18, num_experts: int = 3, expert_hidden: tuple[int, ...] = (64, 32), gate_hidden: tuple[int, ...] = (32,), lr: float = 0.001, decision_threshold: float = 0.5, aux_weight: float = 0.01)

Bases: FusionModuleBase

Dense soft-gated mixture of K identical experts over the flat feature vector.

Specialization is emergent: experts share architecture and input; only the gate's softmax over per-sample logits selects how their outputs combine. If gate entropy stays at log(K) (uniform) on a fitted run, the features carry no routable signal — see diagnostics + escalation table in the design doc.

Source code in graphids/core/models/fusion/moe.py
def __init__(
    self,
    state_dim: int = 18,
    num_experts: int = 3,
    expert_hidden: tuple[int, ...] = (64, 32),
    gate_hidden: tuple[int, ...] = (32,),
    lr: float = 1e-3,
    decision_threshold: float = 0.5,
    aux_weight: float = 0.01,
):
    """Create ``num_experts`` expert heads and one gate head over the same input.

    Args:
        state_dim: Input width passed to every ``_build_head`` call; also
            forwarded to the parent constructor.
        num_experts: Number of expert heads and the gate's ``out_dim``.
        expert_hidden: Hidden widths passed to each expert's ``_build_head``.
        gate_hidden: Hidden widths passed to the gate's ``_build_head``.
        lr: Not read here beyond the ``locals()`` snapshot.
        decision_threshold: Forwarded to the parent constructor.
        aux_weight: Not read here beyond the ``locals()`` snapshot.
    """
    super().__init__(state_dim=state_dim, decision_threshold=decision_threshold)
    # Captures ``locals()`` as they stand here; keep it ahead of any new local.
    self._store_init_kwargs(locals())

    # Every expert is built with identical arguments and a single output.
    self.experts = nn.ModuleList(
        [_build_head(state_dim, expert_hidden, out_dim=1) for _ in range(num_experts)]
    )
    # The gate emits one output per expert.
    self.gate = _build_head(state_dim, gate_hidden, out_dim=num_experts)

    # Last-batch routing diagnostics; set by forward_scores, read by
    # training_step / validation_step. Never read across batches.
    self._last_gate_weights: torch.Tensor | None = None
    self._last_expert_scores: torch.Tensor | None = None

RLFusionBase

RLFusionBase(*, buffer_size: int, batch_size: int, **kw)

Bases: FusionModuleBase

torchrl replay buffer + unified act/learn flow.

Subclass implements: - _compute_loss(sample) -> Tensor — scalar loss from a buffer sample. DQN delegates to a torchrl DQNLoss; Bandit computes MSE inline. The optimizer scope (self._optimizer) is whatever params the subclass actually trains — it does NOT have to match a single loss_module.

Subclass sets in __init__: - self._optimizer — optimizer over the trainable params.

Hooks: - _score_actions(td, training) — write td['action']. - _after_act(actions, obs, rewards) — online update. - _should_learn() — gate the optim step (default: every step). - _after_optim_step() — post-step (DQN target sync). - _after_learn() — post-batch (Bandit ridge reset). - _extra_metrics() — extra log fields.

Source code in graphids/core/models/fusion/base.py
def __init__(self, *, buffer_size: int, batch_size: int, **kw):
    """Initialise the parent, then create the replay buffer ``self._rb``.

    The buffer uses CPU-resident lazy tensor storage capped at
    ``buffer_size`` entries, a random sampler, and ``batch_size`` as its
    batch size. ``batch_size`` and any extra keyword arguments are also
    forwarded to the parent constructor.
    """
    super().__init__(batch_size=batch_size, **kw)
    cpu_storage = LazyTensorStorage(max_size=buffer_size, device=torch.device("cpu"))
    self._rb = TensorDictReplayBuffer(
        storage=cpu_storage,
        sampler=RandomSampler(),
        batch_size=batch_size,
    )
select_action_batch
select_action_batch(features_td: TensorDict, training: bool = True)

Returns (actions[N], alphas[N], normalized_features_td[N]).

Source code in graphids/core/models/fusion/base.py
def select_action_batch(self, features_td: TensorDict, training: bool = True):
    """Pick one action per row of ``features_td``.

    The features are normalised by ``reward_calc``, moved to the module's
    device, flattened into an observation and scored under ``no_grad`` by
    ``_score_actions``, which writes the ``"action"`` entry.

    Returns:
        ``(actions[N], alphas[N], normalized_features_td[N])`` — actions and
        the normalised features are returned on CPU; alphas are
        ``self.alpha_values`` indexed by the actions.
    """
    normalized = self.reward_calc.normalize(features_td).to(self.device)
    observation = flatten_features(normalized)
    scoring_td = TensorDict(
        {"observation": observation},
        batch_size=[observation.size(0)],
        device=self.device,
    )
    with torch.no_grad():
        self._score_actions(scoring_td, training=training)

    chosen = scoring_td["action"].detach().cpu()
    alphas = self.alpha_values[chosen]
    return chosen, alphas, normalized.cpu()

WeightedAvgModule

WeightedAvgModule(lr: float = 0.01, decision_threshold: float = 0.5, state_dim: int = 18)

Bases: FusionModuleBase

alpha = sigmoid(w); score = (1-alpha)·vgae_conf + alpha·gat_conf.

Source code in graphids/core/models/fusion/weighted_avg.py
def __init__(self, lr: float = 1e-2, decision_threshold: float = 0.5, state_dim: int = 18):
    """Create the single learnable blend weight, initialised to zero.

    ``state_dim`` and ``decision_threshold`` are forwarded to the parent
    constructor; ``lr`` is not read here beyond the ``locals()`` snapshot.
    """
    super().__init__(state_dim=state_dim, decision_threshold=decision_threshold)
    # Captures ``locals()`` as they stand here; keep it ahead of any new local.
    self._store_init_kwargs(locals())
    # One-element parameter starting at 0.
    self.weight = nn.Parameter(torch.zeros(1))

flatten_features

flatten_features(td: TensorDict) -> torch.Tensor

Concatenate every leaf tensor along the last dim. Stable order: sorted nested-key path so the same TD always yields the same layout.

Only tuple-keyed (model-namespaced) leaves are concatenated. Top-level str leaves are reserved for metadata (labels, attack_type); they pass through the TD untouched and reach test_step via td.get(...) instead of being treated as features.

Source code in graphids/core/models/fusion/base.py
def flatten_features(td: TensorDict) -> torch.Tensor:
    """Concatenate the tuple-keyed leaf tensors of ``td`` along the last dim.

    Keys are sorted by their nested-key path, so the same TD always yields
    the same layout.

    Top-level str leaves are reserved for metadata (``labels``,
    ``attack_type``); they are skipped here, pass through the TD untouched
    and reach ``test_step`` via ``td.get(...)`` instead of being treated as
    features.
    """
    feature_keys = [
        key
        for key in td.keys(include_nested=True, leaves_only=True)
        if isinstance(key, tuple)
    ]
    feature_keys.sort()
    tensors = [td[key] for key in feature_keys]
    return torch.cat(tensors, dim=-1)

bandit

Neural-LinUCB contextual bandit (Xu et al., ICLR 2022).

Backbone is gradient-trained (MSE between θ_a·z(s) and stored reward); the per-arm θ is updated analytically via Sherman-Morrison ridge. No torchrl LossModule — would be a vestigial wrapper here since θ is not gradient-trained and there's no target net.

BanditFusionModule
BanditFusionModule(state_dim: int = 18, alpha_steps: int = 21, ucb_alpha: float = 1.0, lambda_reg: float = 1.0, hidden_dim: int = 128, num_layers: int = 3, backbone_lr: float = 0.001, backbone_retrain_freq: int = 50, backbone_epochs: int = 5, buffer_size: int = 100000, batch_size: int = 128, decision_threshold: float = 0.5, reward_kwargs: dict | None = None)

Bases: RLFusionBase

Neural-LinUCB: backbone + per-arm ridge with Sherman-Morrison online updates, and a frequency-gated backbone refit.

Source code in graphids/core/models/fusion/bandit.py
def __init__(
    self,
    state_dim: int = 18,
    alpha_steps: int = 21,
    ucb_alpha: float = 1.0,
    lambda_reg: float = 1.0,
    hidden_dim: int = 128,
    num_layers: int = 3,
    backbone_lr: float = 1e-3,
    backbone_retrain_freq: int = 50,
    backbone_epochs: int = 5,
    buffer_size: int = 100_000,
    batch_size: int = 128,
    decision_threshold: float = 0.5,
    reward_kwargs: dict | None = None,
):
    """Build the backbone, the per-arm ridge buffers and the backbone optimizer.

    Args:
        state_dim: First argument to ``_Backbone``; also forwarded to the parent.
        alpha_steps: Leading dimension of ``A_inv``, ``b`` and ``theta``
            (one slice per arm); also forwarded to the parent.
        ucb_alpha: Not read in this constructor beyond the ``locals()``
            snapshot — presumably the UCB exploration weight; confirm at use site.
        lambda_reg: ``A_inv`` starts as the identity divided by this value.
        hidden_dim: Second argument to ``_Backbone``.
        num_layers: Third argument to ``_Backbone``.
        backbone_lr: Learning rate of the AdamW optimizer over the backbone.
        backbone_retrain_freq: Not read in this constructor beyond the
            ``locals()`` snapshot — presumably gates backbone refits; confirm.
        backbone_epochs: Stored as ``gpu_training_steps``.
        buffer_size: Forwarded to the parent constructor.
        batch_size: Forwarded to the parent constructor.
        decision_threshold: Forwarded to the parent constructor.
        reward_kwargs: Forwarded to the parent constructor.
    """
    super().__init__(
        buffer_size=buffer_size,
        batch_size=batch_size,
        state_dim=state_dim,
        alpha_steps=alpha_steps,
        decision_threshold=decision_threshold,
        reward_kwargs=reward_kwargs,
    )
    # Captures ``locals()`` as they stand here, so it has to run before any
    # further local name (such as ``d`` below) is bound.
    self._store_init_kwargs(locals())

    self.backbone = _Backbone(state_dim, hidden_dim, num_layers)
    d = self.backbone.out_dim

    # A_inv: one (d, d) identity per arm, scaled by 1 / lambda_reg.
    self.register_buffer(
        "A_inv", torch.eye(d).unsqueeze(0).repeat(alpha_steps, 1, 1) / lambda_reg
    )
    # b and theta: one zero vector of length d per arm.
    self.register_buffer("b", torch.zeros(alpha_steps, d))
    self.register_buffer("theta", torch.zeros(alpha_steps, d))

    # gpu_training_steps drives RLFusionBase._learn_step's inner loop.
    self.gpu_training_steps = backbone_epochs

    # Only the backbone parameters are handed to the optimizer.
    self._optimizer = optim.AdamW(self.backbone.parameters(), lr=backbone_lr)
    self._episode = 0
    self._ucb_widths: list[float] = []

base

Fusion model bases.

All fusion modules consume a feature TensorDict from the new extraction pipeline, not a flat state vector. Modules that need a flat input (Q-network for DQN/Bandit, MLP) call flatten_features(td) to concatenate every leaf tensor along the feature dim.

  • FusionModuleBase — predict / training_step / validation_step / test_step. Branches on automatic_optimization: supervised path (MLP, WeightedAvg) implements forward_scores(td) -> probs; RL path comes from RLFusionBase.

  • RLFusionBase — torchrl replay buffer + act → reward → push → learn. Subclasses implement _compute_loss(sample), set self._optimizer, and override the hooks listed under RLFusionBase as needed.

RLFusionBase
RLFusionBase(*, buffer_size: int, batch_size: int, **kw)

Bases: FusionModuleBase

torchrl replay buffer + unified act/learn flow.

Subclass implements: - _compute_loss(sample) -> Tensor — scalar loss from a buffer sample. DQN delegates to a torchrl DQNLoss; Bandit computes MSE inline. The optimizer scope (self._optimizer) is whatever params the subclass actually trains — it does NOT have to match a single loss_module.

Subclass sets in __init__: - self._optimizer — optimizer over the trainable params.

Hooks: - _score_actions(td, training) — write td['action']. - _after_act(actions, obs, rewards) — online update. - _should_learn() — gate the optim step (default: every step). - _after_optim_step() — post-step (DQN target sync). - _after_learn() — post-batch (Bandit ridge reset). - _extra_metrics() — extra log fields.

Source code in graphids/core/models/fusion/base.py
def __init__(self, *, buffer_size: int, batch_size: int, **kw):
    """Initialise the parent, then create the replay buffer ``self._rb``.

    The buffer uses CPU-resident lazy tensor storage capped at
    ``buffer_size`` entries, a random sampler, and ``batch_size`` as its
    batch size. ``batch_size`` and any extra keyword arguments are also
    forwarded to the parent constructor.
    """
    super().__init__(batch_size=batch_size, **kw)
    cpu_storage = LazyTensorStorage(max_size=buffer_size, device=torch.device("cpu"))
    self._rb = TensorDictReplayBuffer(
        storage=cpu_storage,
        sampler=RandomSampler(),
        batch_size=batch_size,
    )
select_action_batch
select_action_batch(features_td: TensorDict, training: bool = True)

Returns (actions[N], alphas[N], normalized_features_td[N]).

Source code in graphids/core/models/fusion/base.py
def select_action_batch(self, features_td: TensorDict, training: bool = True):
    """Pick one action per row of ``features_td``.

    The features are normalised by ``reward_calc``, moved to the module's
    device, flattened into an observation and scored under ``no_grad`` by
    ``_score_actions``, which writes the ``"action"`` entry.

    Returns:
        ``(actions[N], alphas[N], normalized_features_td[N])`` — actions and
        the normalised features are returned on CPU; alphas are
        ``self.alpha_values`` indexed by the actions.
    """
    normalized = self.reward_calc.normalize(features_td).to(self.device)
    observation = flatten_features(normalized)
    scoring_td = TensorDict(
        {"observation": observation},
        batch_size=[observation.size(0)],
        device=self.device,
    )
    with torch.no_grad():
        self._score_actions(scoring_td, training=training)

    chosen = scoring_td["action"].detach().cpu()
    alphas = self.alpha_values[chosen]
    return chosen, alphas, normalized.cpu()
build_mlp_body
build_mlp_body(state_dim: int, hidden_dim: int, num_layers: int) -> nn.Sequential

[Linear → LayerNorm → ReLU → Dropout(0.2)] x N.

Source code in graphids/core/models/fusion/base.py
def build_mlp_body(state_dim: int, hidden_dim: int, num_layers: int) -> nn.Sequential:
    """Stack ``num_layers`` stages of [Linear → LayerNorm → ReLU → Dropout(0.2)].

    The first stage maps ``state_dim`` to ``hidden_dim``; every later stage
    maps ``hidden_dim`` to ``hidden_dim``. With ``num_layers <= 0`` the
    returned ``nn.Sequential`` is empty.
    """
    widths = [state_dim] + [hidden_dim] * num_layers
    stages: list[nn.Module] = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        stages += [
            nn.Linear(fan_in, fan_out),
            nn.LayerNorm(fan_out),
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
    return nn.Sequential(*stages)
flatten_features
flatten_features(td: TensorDict) -> torch.Tensor

Concatenate every leaf tensor along the last dim. Stable order: sorted nested-key path so the same TD always yields the same layout.

Only tuple-keyed (model-namespaced) leaves are concatenated. Top-level str leaves are reserved for metadata (labels, attack_type); they pass through the TD untouched and reach test_step via td.get(...) instead of being treated as features.

Source code in graphids/core/models/fusion/base.py
def flatten_features(td: TensorDict) -> torch.Tensor:
    """Concatenate the tuple-keyed leaf tensors of ``td`` along the last dim.

    Keys are sorted by their nested-key path, so the same TD always yields
    the same layout.

    Top-level str leaves are reserved for metadata (``labels``,
    ``attack_type``); they are skipped here, pass through the TD untouched
    and reach ``test_step`` via ``td.get(...)`` instead of being treated as
    features.
    """
    feature_keys = [
        key
        for key in td.keys(include_nested=True, leaves_only=True)
        if isinstance(key, tuple)
    ]
    feature_keys.sort()
    tensors = [td[key] for key in feature_keys]
    return torch.cat(tensors, dim=-1)

dqn

DQN fusion: torchrl DQNLoss + EGreedyModule over QValueActor.

Subclasses RLFusionBase and contributes only the DQN-specific math: the Q-actor + epsilon-greedy explorer, the DQNLoss (with double_dqn toggle and delay_value target net), and SoftUpdate Polyak sync. gamma=0 because each graph is an independent context.

mlp

Supervised MLP baseline: binary classification from flattened fusion features.

MLPFusionModule
MLPFusionModule(state_dim: int = 18, hidden_dims: tuple[int, ...] = (64, 32), lr: float = 0.001, decision_threshold: float = 0.5)

Bases: FusionModuleBase

Same features as DQN, trained with BCE instead of RL.

Source code in graphids/core/models/fusion/mlp.py
def __init__(
    self,
    state_dim: int = 18,
    hidden_dims: tuple[int, ...] = (64, 32),
    lr: float = 1e-3,
    decision_threshold: float = 0.5,
):
    """Assemble the MLP: ``[Linear → ReLU → Dropout(0.2)]`` per hidden width, then ``Linear(·, 1)``.

    Args:
        state_dim: Input width of the first linear layer; forwarded to the parent.
        hidden_dims: Output width of each hidden stage, in order.
        lr: Not read here beyond the ``locals()`` snapshot.
        decision_threshold: Forwarded to the parent constructor.
    """
    super().__init__(state_dim=state_dim, decision_threshold=decision_threshold)
    # Must stay ahead of every other local: it captures ``locals()`` verbatim.
    self._store_init_kwargs(locals())

    widths = (state_dim, *hidden_dims)
    stack: list[nn.Module] = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
    # Single-output head on top of the last hidden width (or state_dim if none).
    stack.append(nn.Linear(widths[-1], 1))
    self.model = nn.Sequential(*stack)

moe

MoE+BCE per-sample gated fusion: K experts with softmax router, dense soft-gated.

Implements the canonical Jacobs & Jordan (1991) "Adaptive Mixtures of Local Experts" formulation: every sample passes through every expert; the gate emits per-sample weights w(x) ∈ Δ^{K-1}; final prediction is the convex combination Σᵢ wᵢ(x) · sigmoid(hᵢ(x)). Trained end-to-end with BCE on the mixed score — no per-expert supervision, no auxiliary losses in v0.

Why dense soft-gated and not sparse top-k: sparse routing's value is conditional compute at scale (Switch Transformer, Mixtral). At K=3 with 18-dim features the FLOPs argument is moot, and soft blending is the hypothesis we want to test. Design rationale, variant survey, and escalation paths: docs/drafts/moe-fusion-design.md.

MoEFusionModule
MoEFusionModule(state_dim: int = 18, num_experts: int = 3, expert_hidden: tuple[int, ...] = (64, 32), gate_hidden: tuple[int, ...] = (32,), lr: float = 0.001, decision_threshold: float = 0.5, aux_weight: float = 0.01)

Bases: FusionModuleBase

Dense soft-gated mixture of K identical experts over the flat feature vector.

Specialization is emergent: experts share architecture and input; only the gate's softmax over per-sample logits selects how their outputs combine. If gate entropy stays at log(K) (uniform) on a fitted run, the features carry no routable signal — see diagnostics + escalation table in the design doc.

Source code in graphids/core/models/fusion/moe.py
def __init__(
    self,
    state_dim: int = 18,
    num_experts: int = 3,
    expert_hidden: tuple[int, ...] = (64, 32),
    gate_hidden: tuple[int, ...] = (32,),
    lr: float = 1e-3,
    decision_threshold: float = 0.5,
    aux_weight: float = 0.01,
):
    """Create ``num_experts`` expert heads and one gate head over the same input.

    Args:
        state_dim: Input width passed to every ``_build_head`` call; also
            forwarded to the parent constructor.
        num_experts: Number of expert heads and the gate's ``out_dim``.
        expert_hidden: Hidden widths passed to each expert's ``_build_head``.
        gate_hidden: Hidden widths passed to the gate's ``_build_head``.
        lr: Not read here beyond the ``locals()`` snapshot.
        decision_threshold: Forwarded to the parent constructor.
        aux_weight: Not read here beyond the ``locals()`` snapshot.
    """
    super().__init__(state_dim=state_dim, decision_threshold=decision_threshold)
    # Captures ``locals()`` as they stand here; keep it ahead of any new local.
    self._store_init_kwargs(locals())

    # Every expert is built with identical arguments and a single output.
    self.experts = nn.ModuleList(
        [_build_head(state_dim, expert_hidden, out_dim=1) for _ in range(num_experts)]
    )
    # The gate emits one output per expert.
    self.gate = _build_head(state_dim, gate_hidden, out_dim=num_experts)

    # Last-batch routing diagnostics; set by forward_scores, read by
    # training_step / validation_step. Never read across batches.
    self._last_gate_weights: torch.Tensor | None = None
    self._last_expert_scores: torch.Tensor | None = None

reward

Fusion reward calculator.

This module now exposes one reward primitive. The old mode switch and alternate reward class were deleted to keep the model layer honest: one calculator, one contract.

FusionRewardCalculator
FusionRewardCalculator(*, vgae_weights: list[float] | tuple[float, ...], correct: float, incorrect: float, confidence_weight: float, combined_conf_weight: float, disagreement_penalty: float, overconf_penalty: float, balance_weight: float)

Bases: Module

Vectorized fusion reward over a feature TensorDict.

Required nested keys
  • ("vgae", "errors") [N, 3] — recon, mahal, kl
  • ("vgae", "conf") [N, 1]
  • ("gat", "probs") [N, 2]
  • ("gat", "conf") [N, 1]

Other keys (z_stats, emb_stats, …) are ignored — they're consumed by the supervised/Q-network paths after flattening, not by the reward.

Source code in graphids/core/models/fusion/reward.py
def __init__(
    self,
    *,
    vgae_weights: list[float] | tuple[float, ...],
    correct: float,
    incorrect: float,
    confidence_weight: float,
    combined_conf_weight: float,
    disagreement_penalty: float,
    overconf_penalty: float,
    balance_weight: float,
) -> None:
    """Store the reward coefficients; every argument is keyword-only and required.

    Args:
        vgae_weights: Registered as the float32 buffer ``_vgae_weights``.
            NOTE(review): presumably weights the VGAE error terms when
            deriving the anomaly score — confirm in ``derive_scores``.
        correct: Base reward where the prediction equals the label.
        incorrect: Base reward where the prediction differs from the label.
        confidence_weight: Multiplier on the label-aware confidence term.
        combined_conf_weight: Multiplier on ``max(vgae_conf, gat_conf)``.
        disagreement_penalty: Multiplier on ``1 - agreement``.
        overconf_penalty: Multiplier on the fused score (or its complement).
        balance_weight: Multiplier on ``1 - 2 * |alpha - 0.5|``.
    """
    super().__init__()
    # A buffer rather than a plain attribute, so it follows the module across devices.
    self.register_buffer("_vgae_weights", torch.tensor(vgae_weights, dtype=torch.float32))
    # Scalar coefficients, read back by ``compute``.
    self._reward_correct = correct
    self._reward_incorrect = incorrect
    self._confidence_weight = confidence_weight
    self._combined_conf_weight = combined_conf_weight
    self._disagreement_penalty = disagreement_penalty
    self._overconf_penalty = overconf_penalty
    self._balance_weight = balance_weight
compute
compute(td: TensorDict, preds: Tensor, labels: Tensor, alphas: Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]

Vectorized reward. Returns (total[N], components).

components is a per-term breakdown of what each shaping term contributed to total per graph (mutually-exclusive correct/wrong terms zero out on the inactive branch). sum(components.values()) == total by construction. Used by callers to log per-component means per epoch — diagnostic for which term the policy is exploiting.

Source code in graphids/core/models/fusion/reward.py
def compute(
    self,
    td: TensorDict,
    preds: torch.Tensor,
    labels: torch.Tensor,
    alphas: torch.Tensor,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Compute the fusion reward and its per-term breakdown.

    Returns ``(total, components)``. ``components`` maps a term name to the
    amount that term contributed to ``total`` for each graph; terms that
    only apply to correct (or only to wrong) predictions are zero on the
    other branch. ``total`` is literally ``sum(components.values())``, so
    callers can log per-component means to see which term the policy is
    exploiting.
    """
    anomaly, gat_prob = self.derive_scores(td)
    is_correct = preds == labels

    # Confidence of whichever model is more sure of itself.
    best_conf = torch.max(
        td["vgae", "conf"].squeeze(-1),
        td["gat", "conf"].squeeze(-1),
    )

    classification = torch.where(is_correct, self._reward_correct, self._reward_incorrect)
    zeros = torch.zeros_like(classification)

    def when_correct(term: torch.Tensor) -> torch.Tensor:
        # Keep ``term`` only where the prediction matched the label.
        return torch.where(is_correct, term, zeros)

    def when_wrong(term: torch.Tensor) -> torch.Tensor:
        # Keep ``term`` only where the prediction missed the label.
        return torch.where(is_correct, zeros, term)

    # --- terms paid out on correct predictions ---
    agreement = 1.0 - (anomaly - gat_prob).abs()
    top_score = torch.max(anomaly, gat_prob)
    label_conf = torch.where(labels == 1, top_score, 1.0 - top_score)
    conf_bonus = (
        self._confidence_weight * label_conf + self._combined_conf_weight * best_conf
    )

    # --- terms charged on wrong predictions ---
    disagree_term = self._disagreement_penalty * (1.0 - agreement)
    fused = alphas * gat_prob + (1 - alphas) * anomaly
    overconf_term = torch.where(
        preds == 1,
        self._overconf_penalty * fused,
        self._overconf_penalty * (1.0 - fused),
    )

    # --- term applied everywhere: largest at alpha == 0.5, zero at 0 or 1 ---
    balance_term = self._balance_weight * (1.0 - (alphas - 0.5).abs() * 2)

    components = {
        "r_classification": classification,
        "r_agreement": when_correct(agreement),
        "r_confidence": when_correct(conf_bonus),
        "r_disagreement_penalty": when_wrong(disagree_term),
        "r_overconfidence_penalty": when_wrong(overconf_term),
        "r_balance": balance_term,
    }
    total = sum(components.values())
    return total, components
derive_scores
derive_scores(td: TensorDict) -> tuple[torch.Tensor, torch.Tensor]

Return (anomaly_scores[N], gat_probs_pos[N]).

anomaly was previously (errors @ weights).clamp(0, 1) — saturated to 1.0 for nearly every sample because typical weighted error magnitudes are O(1)–O(10), well above the clamp ceiling. That broke the RL fusion path: at α≈0.5 the blended score α·gat_prob + (1−α)·anomaly was ≥0.5 everywhere → predict-attack on every sample → MCC≈0 even though AUROC was perfect (the ranking was right but the threshold was uniformly above benigns). Replace the clamp with the Möbius transform x/(1+x): bounded [0, 1) on non-negative errors (recon, mahal, kl are all non-negative), strictly monotonic, parameter-free, preserves rank ordering. This matches the sigmoidal compression weighted_avg already uses on recon_mean.

Source code in graphids/core/models/fusion/reward.py
def derive_scores(self, td: TensorDict) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(anomaly_scores[N], gat_probs_pos[N])``.

    The anomaly score is the weighted sum of the VGAE error columns,
    squashed with ``x / (1 + x)``. For non-negative errors (recon, mahal
    and kl are all non-negative) this lands in [0, 1), is strictly
    monotonic, has no tunable parameter and keeps the ranking intact.

    History: the squash replaced ``(errors @ weights).clamp(0, 1)``.
    Weighted errors are typically O(1)–O(10), so the clamp pinned almost
    every sample at 1.0. In the RL fusion path that pushed
    ``α·gat_prob + (1−α)·anomaly`` to ≥0.5 everywhere at α≈0.5, i.e.
    predict-attack on every sample and MCC≈0 despite a perfect AUROC.
    ``weighted_avg`` already compresses ``recon_mean`` the same way.
    """
    errors = td["vgae", "errors"]
    # Row-wise weighted sum over the error columns.
    raw_score = torch.sum(errors * self._vgae_weights, dim=1)
    squashed = raw_score / (1.0 + raw_score)

    # Column 1 of the GAT probabilities is the positive (attack) class.
    attack_prob = td["gat", "probs"][:, 1]
    return squashed, attack_prob
normalize
normalize(td: TensorDict) -> TensorDict

Clamp confidence keys to [0, 1]. Returns a shallow-cloned TD.

Source code in graphids/core/models/fusion/reward.py
def normalize(self, td: TensorDict) -> TensorDict:
    """Return a shallow clone of ``td`` with both ``conf`` entries clamped to [0, 1].

    The top level and the ``"vgae"`` / ``"gat"`` sub-containers are cloned
    with ``recurse=False``; only the two ``conf`` tensors are replaced.
    """
    branches = ("vgae", "gat")
    result = td.clone(recurse=False)

    # Give the result its own (shallow) copy of each branch before any
    # entry inside that branch is overwritten.
    for branch in branches:
        result[branch] = td[branch].clone(recurse=False)

    # Overwrite the confidence entries with their clamped versions.
    for branch in branches:
        result[branch, "conf"] = td[branch, "conf"].clamp(0.0, 1.0)

    return result

weighted_avg

Simplest fusion baseline: learns a single scalar alpha blending vgae anomaly + gat attack prob.

score = (1-alpha) * vgae_anom + alpha * gat_attack, where vgae_anom = 1 - vgae_conf = recon_mean/(1+recon_mean) (high = anomalous) and gat_attack = gat/probs[:,1] (high = attack).

If this matches DQN's F1, the RL approach is unjustified.

WeightedAvgModule
WeightedAvgModule(lr: float = 0.01, decision_threshold: float = 0.5, state_dim: int = 18)

Bases: FusionModuleBase

alpha = sigmoid(w); score = (1-alpha)·vgae_conf + alpha·gat_conf.

Source code in graphids/core/models/fusion/weighted_avg.py
def __init__(self, lr: float = 1e-2, decision_threshold: float = 0.5, state_dim: int = 18):
    """Create the module with its single learnable blend weight.

    ``state_dim`` and ``decision_threshold`` are forwarded to the base
    class. ``lr`` is not used directly in this method; it is only captured
    through the ``locals()`` snapshot below.
    """
    super().__init__(state_dim=state_dim, decision_threshold=decision_threshold)
    # Snapshot of the constructor arguments. ``locals()`` captures every
    # name in scope, so no extra local may be introduced above this line.
    self._store_init_kwargs(locals())
    # One scalar parameter, initialised to 0. Per the class docstring
    # alpha = sigmoid(w), which would make the initial blend 0.5 —
    # NOTE(review): confirm against the forward pass.
    self.weight = nn.Parameter(torch.zeros(1))

id_encoding

Pluggable identity-encoding strategies for graph nodes.

An IdEncoder maps a node_id LongTensor to per-node embedding vectors. Subclasses implement different strategies (lookup table, k-probe hash, ...) behind a uniform interface so VGAE / GAT / DGI do not know which strategy is in use.

Research basis: ~/plans/oov-embedding-handling.md.

HashIdEncoder

HashIdEncoder(num_buckets: int, embedding_dim: int, *, k: int = 2, seed: int = 42)

Bases: IdEncoder

Source code in graphids/core/models/id_encoding/hash_embedding.py
def __init__(
    self,
    num_buckets: int,
    embedding_dim: int,
    *,
    k: int = 2,
    seed: int = 42,
):
    """Create the shared bucket table and the ``k`` per-probe hash offsets.

    Args:
        num_buckets: Number of rows in the shared embedding table.
        embedding_dim: Width of each embedding row; exposed as ``out_dim``.
        k: Number of hash probes per id.
        seed: Base value the probe offsets are derived from.

    Raises:
        ValueError: If ``num_buckets < 2`` or ``k < 1``.
    """
    super().__init__()
    if num_buckets < 2:
        raise ValueError(f"num_buckets must be >= 2, got {num_buckets}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    self.embedding = nn.Embedding(num_buckets, embedding_dim)
    self.out_dim = embedding_dim
    self.num_buckets = num_buckets
    self.k = k
    # ``k`` hash offsets, deterministic in ``seed``. The stride between
    # consecutive offsets must be odd. The previous stride, ``1 << 30``,
    # is a multiple of every power-of-two bucket count up to 2**30 (the
    # only kind ``from_vocab_size`` produces by default), so all offsets
    # were congruent modulo ``num_buckets``. With the additive hash the
    # module docs describe, ``(id * KNUTH + offset_i) mod num_buckets``,
    # every probe then hit the same row and k-probe hashing collapsed to
    # a single probe. An odd stride gives pairwise-distinct residues
    # modulo any power-of-two ``num_buckets >= k``.
    # The constant is deliberately different from the Knuth multiplier
    # (otherwise probe i of ``id`` would equal probe 0 of ``id + i``) and
    # stays below 2**31 so offset magnitudes are comparable to before.
    # NOTE(review): with purely additive offsets the probes are still
    # shifted copies of one another; confirm against ``forward``.
    probe_stride = 0x7FEB352D
    offsets = torch.tensor(
        [seed + i * probe_stride for i in range(k)],
        dtype=torch.int64,
    )
    # Registered as a buffer so the offsets are saved and restored with
    # the checkpoint; models saved with the old offsets keep them on load.
    self.register_buffer("_hash_offsets", offsets)
from_vocab_size classmethod
from_vocab_size(num_ids: int, *, embedding_dim: int, k: int = 2, seed: int = 42, num_buckets_factor: int = 4, num_buckets: int | None = None) -> HashIdEncoder

Build from a datamodule-injected num_ids.

Default bucket count: next_pow2(num_buckets_factor · num_ids), minimum 8. Per plan: Yan 2021 / Coleman 2023 use 2–4× vocab size as a sweet spot between collision rate and parameter count. num_buckets can be passed explicitly to override.

Source code in graphids/core/models/id_encoding/hash_embedding.py
@classmethod
def from_vocab_size(
    cls,
    num_ids: int,
    *,
    embedding_dim: int,
    k: int = 2,
    seed: int = 42,
    num_buckets_factor: int = 4,
    num_buckets: int | None = None,
) -> HashIdEncoder:
    """Alternate constructor driven by the datamodule-supplied ``num_ids``.

    When ``num_buckets`` is omitted, the table size is
    ``num_buckets_factor * num_ids`` rounded up to the next power of two,
    and never smaller than 8. The plan cites Yan 2021 / Coleman 2023 for
    2–4× the vocab size as the sweet spot between collision rate and
    parameter count. An explicit ``num_buckets`` is used as given.
    """
    bucket_count = num_buckets
    if bucket_count is None:
        # Treat an empty vocab as size 1, then enforce the floor of 8.
        wanted = max(8, num_buckets_factor * max(1, num_ids))
        # Smallest power of two that is >= wanted.
        bucket_count = 1 << (wanted - 1).bit_length()
    return cls(
        num_buckets=bucket_count,
        embedding_dim=embedding_dim,
        k=k,
        seed=seed,
    )

IdEncoder

Bases: Module

Maps per-node identities to per-node embedding vectors.

Subclasses: LookupIdEncoder — dense nn.Embedding over a shared vocab, with optional stochastic UNK-drop (Stage 3 ablation); HashIdEncoder (Stage 2 primary) — k-probe hash embedding per Yan et al. 2021 (CIKM).

build_encoder

build_encoder(class_path: str, num_ids: int, embedding_dim: int, **kwargs: Any) -> IdEncoder

Resolve a dotted class_path and call from_vocab_size.

num_ids is data-dependent (populated by datamodule.setup), so encoder construction stays at model-build time.

Source code in graphids/core/models/id_encoding/base.py
def build_encoder(class_path: str, num_ids: int, embedding_dim: int, **kwargs: Any) -> IdEncoder:
    """Import the class named by ``class_path`` and build it via ``from_vocab_size``.

    ``class_path`` is a dotted path whose last component is the class name
    and whose prefix is the module to import. ``num_ids`` only becomes
    known after ``datamodule.setup``, which is why the encoder is built at
    model-build time. Extra ``kwargs`` are passed through to
    ``from_vocab_size`` unchanged.
    """
    # Split "pkg.module.ClassName" at the last dot.
    module_name, _sep, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    encoder_cls = getattr(module, class_name)
    return encoder_cls.from_vocab_size(
        num_ids=num_ids,
        embedding_dim=embedding_dim,
        **kwargs,
    )

base

Base class for pluggable identity encoders.

Contract (duck-typed, matching the rest of the codebase):

  • forward(node_id: LongTensor) -> Tensor of shape (N, out_dim).
  • out_dim: int attribute set in __init__.
  • All stateful policy (vocab size, hash seeds, UNK-drop rate) lives on the encoder instance — InputEncoder holds one and does not branch on its type.
IdEncoder

Bases: Module

Maps per-node identities to per-node embedding vectors.

Subclasses: LookupIdEncoder — dense nn.Embedding over a shared vocab, with optional stochastic UNK-drop (Stage 3 ablation); HashIdEncoder (Stage 2 primary) — k-probe hash embedding per Yan et al. 2021 (CIKM).

build_encoder
build_encoder(class_path: str, num_ids: int, embedding_dim: int, **kwargs: Any) -> IdEncoder

Resolve a dotted class_path and call from_vocab_size.

num_ids is data-dependent (populated by datamodule.setup), so encoder construction stays at model-build time.

Source code in graphids/core/models/id_encoding/base.py
def build_encoder(class_path: str, num_ids: int, embedding_dim: int, **kwargs: Any) -> IdEncoder:
    """Import the class named by ``class_path`` and build it via ``from_vocab_size``.

    ``class_path`` is a dotted path whose last component is the class name
    and whose prefix is the module to import. ``num_ids`` only becomes
    known after ``datamodule.setup``, which is why the encoder is built at
    model-build time. Extra ``kwargs`` are passed through to
    ``from_vocab_size`` unchanged.
    """
    # Split "pkg.module.ClassName" at the last dot.
    module_name, _sep, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    encoder_cls = getattr(module, class_name)
    return encoder_cls.from_vocab_size(
        num_ids=num_ids,
        embedding_dim=embedding_dim,
        **kwargs,
    )

config

Explicit ID-encoding configs and factories.

hash_embedding

k-probe hash embedding — primary Stage-2 treatment.

Every id (seen or unseen) deterministically maps to k rows of a bucketed embedding table by k decorrelated hash functions; the per-probe vectors are summed. Because any id hits trained buckets by construction, no special OOV slot is needed.

Shape follows Coleman et al. 2023 Unified Embedding (NeurIPS Spotlight): one shared table, k probes, sum combiner — minimum parameters, clean theoretical analysis. Yan et al. 2021 Binary Code Hash Embedding (CIKM) uses the same k-probe idea with separate tables per hash; at CAN scale (~100 ids, B=512) the shared table has the same expressive power at half the parameters.

Hash: bucket_i(id) = (id * KNUTH + offset_i) mod num_buckets, where KNUTH = 2654435761 (golden-ratio-derived Knuth multiplier) and the k offsets are deterministic functions of the seed constructor arg. The multiplier is coprime to any num_buckets >= 2 that isn't a specific pathological case, and Knuth's value is well-studied for integer-id hashing at tiny scale.

Research basis: ~/plans/oov-embedding-handling.md (Stage 2).

HashIdEncoder
HashIdEncoder(num_buckets: int, embedding_dim: int, *, k: int = 2, seed: int = 42)

Bases: IdEncoder

Source code in graphids/core/models/id_encoding/hash_embedding.py
def __init__(
    self,
    num_buckets: int,
    embedding_dim: int,
    *,
    k: int = 2,
    seed: int = 42,
):
    """Create the shared bucket table and the ``k`` per-probe hash offsets.

    Args:
        num_buckets: Number of rows in the shared embedding table.
        embedding_dim: Width of each embedding row; exposed as ``out_dim``.
        k: Number of hash probes per id.
        seed: Base value the probe offsets are derived from.

    Raises:
        ValueError: If ``num_buckets < 2`` or ``k < 1``.
    """
    super().__init__()
    if num_buckets < 2:
        raise ValueError(f"num_buckets must be >= 2, got {num_buckets}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    self.embedding = nn.Embedding(num_buckets, embedding_dim)
    self.out_dim = embedding_dim
    self.num_buckets = num_buckets
    self.k = k
    # ``k`` hash offsets, deterministic in ``seed``. The stride between
    # consecutive offsets must be odd. The previous stride, ``1 << 30``,
    # is a multiple of every power-of-two bucket count up to 2**30 (the
    # only kind ``from_vocab_size`` produces by default), so all offsets
    # were congruent modulo ``num_buckets``. With the additive hash the
    # module docs describe, ``(id * KNUTH + offset_i) mod num_buckets``,
    # every probe then hit the same row and k-probe hashing collapsed to
    # a single probe. An odd stride gives pairwise-distinct residues
    # modulo any power-of-two ``num_buckets >= k``.
    # The constant is deliberately different from the Knuth multiplier
    # (otherwise probe i of ``id`` would equal probe 0 of ``id + i``) and
    # stays below 2**31 so offset magnitudes are comparable to before.
    # NOTE(review): with purely additive offsets the probes are still
    # shifted copies of one another; confirm against ``forward``.
    probe_stride = 0x7FEB352D
    offsets = torch.tensor(
        [seed + i * probe_stride for i in range(k)],
        dtype=torch.int64,
    )
    # Registered as a buffer so the offsets are saved and restored with
    # the checkpoint; models saved with the old offsets keep them on load.
    self.register_buffer("_hash_offsets", offsets)
from_vocab_size classmethod
from_vocab_size(num_ids: int, *, embedding_dim: int, k: int = 2, seed: int = 42, num_buckets_factor: int = 4, num_buckets: int | None = None) -> HashIdEncoder

Build from a datamodule-injected num_ids.

Default bucket count: next_pow2(num_buckets_factor · num_ids), minimum 8. Per plan: Yan 2021 / Coleman 2023 use 2–4× vocab size as a sweet spot between collision rate and parameter count. num_buckets can be passed explicitly to override.

Source code in graphids/core/models/id_encoding/hash_embedding.py
@classmethod
def from_vocab_size(
    cls,
    num_ids: int,
    *,
    embedding_dim: int,
    k: int = 2,
    seed: int = 42,
    num_buckets_factor: int = 4,
    num_buckets: int | None = None,
) -> HashIdEncoder:
    """Alternate constructor driven by the datamodule-supplied ``num_ids``.

    When ``num_buckets`` is omitted, the table size is
    ``num_buckets_factor * num_ids`` rounded up to the next power of two,
    and never smaller than 8. The plan cites Yan 2021 / Coleman 2023 for
    2–4× the vocab size as the sweet spot between collision rate and
    parameter count. An explicit ``num_buckets`` is used as given.
    """
    bucket_count = num_buckets
    if bucket_count is None:
        # Treat an empty vocab as size 1, then enforce the floor of 8.
        wanted = max(8, num_buckets_factor * max(1, num_ids))
        # Smallest power of two that is >= wanted.
        bucket_count = 1 << (wanted - 1).bit_length()
    return cls(
        num_buckets=bucket_count,
        embedding_dim=embedding_dim,
        k=k,
        seed=seed,
    )

lookup

Dense lookup embedding with optional stochastic UNK-drop.

Default (p_unk_drop=0.0) reproduces the pre-refactor nn.Embedding behavior byte-for-byte so existing single-vocab runs are a no-op change.

p_unk_drop > 0.0 implements the Stage 3 ablation arm from ~/plans/oov-embedding-handling.md: during training, each node_id is remapped to UNK_INDEX with probability p, so the OOV row receives gradient and attack-introduced IDs at inference land in a trained slot instead of init noise.

supervised

Supervised graph model family exports.

GAT

GAT(*, loss_fn: Module | None = None, hidden: int | None = None, layers: int | None = None, heads: int | None = None, dropout: float = 0.2, fc_layers: int = 3, embedding_dim: int = 16, conv_type: str = 'gatv2', edge_dim: int = 11, pool_aggrs: list[str] | None = None, sequence_pool: Literal['auto', 'flat', 'mean', 'attention', 'gru'] = 'auto', proj_dim: int = 0, gradient_checkpointing: bool = True, compile_model: bool = False, id_encoder_cfg: IdEncodingCfg | None = None, id_encoder_class_path: str = 'graphids.core.models.id_encoding.LookupIdEncoder', id_encoder_kwargs: dict | None = None, lr: float = 0.001, weight_decay: float = 0.0001, scale: str = 'small', model_type: ModelType = 'gat', dataset: str = '', seed: int = 42, variational: bool = True, num_ids: int = 0, in_channels: int = 0, num_classes: int = 2)

Bases: GraphModuleBase

Collapsed GAT — arch + trainer-bridge in one nn.Module.

Loss selection is decoupled: loss_fn is an nn.Module built from the experiment config. When the block resolves to a :class:~graphids.core.losses.distillation.SoftLabelDistillation, training automatically becomes a KD run — no branching here.

scale selects per-axis hyperparam presets from :attr:_SCALES; explicit hidden / layers / heads kwargs (non-None) override the preset.

Source code in graphids/core/models/supervised/gat.py
def __init__(
    self,
    *,
    loss_fn: nn.Module | None = None,
    # --- architecture (None → resolve from ``scale`` preset) ---
    hidden: int | None = None,
    layers: int | None = None,
    heads: int | None = None,
    dropout: float = 0.2,
    fc_layers: int = 3,
    embedding_dim: int = 16,
    conv_type: str = "gatv2",
    edge_dim: int = 11,
    pool_aggrs: list[str] | None = None,
    sequence_pool: Literal["auto", "flat", "mean", "attention", "gru"] = "auto",
    proj_dim: int = 0,
    gradient_checkpointing: bool = True,
    compile_model: bool = False,
    id_encoder_cfg: IdEncodingCfg | None = None,
    id_encoder_class_path: str = "graphids.core.models.id_encoding.LookupIdEncoder",
    id_encoder_kwargs: dict | None = None,
    # --- training ---
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    # --- identity / dynamic ---
    scale: str = "small",
    model_type: ModelType = "gat",
    dataset: str = "",
    seed: int = 42,
    variational: bool = True,  # upstream VGAE type — identity key for supervised
    num_ids: int = 0,
    in_channels: int = 0,
    num_classes: int = 2,
):
    """Resolve the scale preset, then pass all arguments to ``_init_post``.

    ``hidden``, ``layers`` and ``heads`` default to ``None``; a ``None``
    value is filled from ``_SCALES[scale]``, falling back to 48 / 3 / 8
    when the preset lacks that key. Explicit non-None values are kept.

    Most arguments are not used directly in this method. They reach
    ``_init_post`` through the ``locals()`` snapshot taken on the last
    line, so the set of local names here is part of the contract.
    """
    # An unknown ``scale`` yields an empty preset rather than an error, so
    # the hard-coded fallbacks below apply silently.
    s = self._SCALES.get(scale, {})
    if hidden is None:
        hidden = s.get("hidden", 48)
    if layers is None:
        layers = s.get("layers", 3)
    if heads is None:
        heads = s.get("heads", 8)
    super().__init__()
    self.test_metrics = classification_test_metrics(num_classes)
    # Start empty; the names suggest per-epoch validation accumulators —
    # NOTE(review): confirm where they are filled and cleared.
    self._val_probs: list[torch.Tensor] = []
    self._val_labels: list[torch.Tensor] = []
    # ``locals()`` here holds ``self``, the preset dict ``s`` and every
    # constructor argument, with ``hidden`` / ``layers`` / ``heads``
    # already resolved. Presumably ``_init_post`` builds the layers from
    # it — TODO confirm. Do not add or rename locals above this line.
    self._init_post(locals())
extract_features
extract_features(batch, device: device) -> dict[str, torch.Tensor]

Per-graph fusion features as named tensors.

  • probs [N, 2] — prob_0, prob_1
  • conf [N, 1] — 1 - entropy / log(2)
  • emb_stats [N, 4] — emb_mean, emb_std, emb_max, emb_min
Source code in graphids/core/models/supervised/gat.py
def extract_features(self, batch, device: torch.device) -> dict[str, torch.Tensor]:
    """Build the per-graph feature tensors consumed by the fusion stage.

    Returned keys:

    - ``probs``     [N, 2] — softmax class probabilities (prob_0, prob_1)
    - ``conf``      [N, 1] — 1 - entropy / log(2), clamped to [0, 1]
    - ``emb_stats`` [N, 4] — emb_mean, emb_std, emb_max, emb_min
    """
    logits, embedding = self(batch, return_embedding=True)

    class_probs = F.softmax(logits, dim=1)
    # The 1e-8 keeps log() finite when a probability is exactly zero.
    log_probs = (class_probs + 1e-8).log()
    entropy = -(class_probs * log_probs).sum(dim=1)
    # log(2) is the maximum entropy of a two-class distribution.
    confidence = (1.0 - entropy / math.log(2)).clamp(0.0, 1.0)

    # Summary statistics taken along dim 1 of the embedding.
    stat_columns = [
        embedding.mean(1, keepdim=True),
        embedding.std(1, keepdim=True),
        embedding.max(1, keepdim=True).values,
        embedding.min(1, keepdim=True).values,
    ]

    features = {
        "probs": class_probs,
        "conf": confidence.unsqueeze(-1),
        "emb_stats": torch.cat(stat_columns, dim=1),
    }
    return features

gat

GAT supervised classifier — collapsed arch + trainer-bridge.

The single :class:GAT class is both the architecture (InputEncoder + conv stack + JK + pool + FC head) and the trainer-bridge (training_step/validation_step/test_step, fusion-feature extractor). No wrapper module — see ~/plans/graphids-collapse-model-modules.md Phase 3.

Supports GATConv, GATv2Conv (the default, conv_type="gatv2"), and TransformerConv via conv_type. TransformerConv natively uses edge_attr, enabling the 11-D edge features (frequency, temporal intervals, bidirectionality, degree products) that GATConv ignores.

GAT
GAT(*, loss_fn: Module | None = None, hidden: int | None = None, layers: int | None = None, heads: int | None = None, dropout: float = 0.2, fc_layers: int = 3, embedding_dim: int = 16, conv_type: str = 'gatv2', edge_dim: int = 11, pool_aggrs: list[str] | None = None, sequence_pool: Literal['auto', 'flat', 'mean', 'attention', 'gru'] = 'auto', proj_dim: int = 0, gradient_checkpointing: bool = True, compile_model: bool = False, id_encoder_cfg: IdEncodingCfg | None = None, id_encoder_class_path: str = 'graphids.core.models.id_encoding.LookupIdEncoder', id_encoder_kwargs: dict | None = None, lr: float = 0.001, weight_decay: float = 0.0001, scale: str = 'small', model_type: ModelType = 'gat', dataset: str = '', seed: int = 42, variational: bool = True, num_ids: int = 0, in_channels: int = 0, num_classes: int = 2)

Bases: GraphModuleBase

Collapsed GAT — arch + trainer-bridge in one nn.Module.

Loss selection is decoupled: loss_fn is an nn.Module built from the experiment config. When the block resolves to a :class:~graphids.core.losses.distillation.SoftLabelDistillation, training automatically becomes a KD run — no branching here.

scale selects per-axis hyperparam presets from :attr:_SCALES; explicit hidden / layers / heads kwargs (non-None) override the preset.

Source code in graphids/core/models/supervised/gat.py
def __init__(
    self,
    *,
    loss_fn: nn.Module | None = None,
    # --- architecture (None → resolve from ``scale`` preset) ---
    hidden: int | None = None,
    layers: int | None = None,
    heads: int | None = None,
    dropout: float = 0.2,
    fc_layers: int = 3,
    embedding_dim: int = 16,
    conv_type: str = "gatv2",
    edge_dim: int = 11,
    pool_aggrs: list[str] | None = None,
    sequence_pool: Literal["auto", "flat", "mean", "attention", "gru"] = "auto",
    proj_dim: int = 0,
    gradient_checkpointing: bool = True,
    compile_model: bool = False,
    id_encoder_cfg: IdEncodingCfg | None = None,
    id_encoder_class_path: str = "graphids.core.models.id_encoding.LookupIdEncoder",
    id_encoder_kwargs: dict | None = None,
    # --- training ---
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    # --- identity / dynamic ---
    scale: str = "small",
    model_type: ModelType = "gat",
    dataset: str = "",
    seed: int = 42,
    variational: bool = True,  # upstream VGAE type — identity key for supervised
    num_ids: int = 0,
    in_channels: int = 0,
    num_classes: int = 2,
):
    """Resolve the scale preset, then pass all arguments to ``_init_post``.

    ``hidden``, ``layers`` and ``heads`` default to ``None``; a ``None``
    value is filled from ``_SCALES[scale]``, falling back to 48 / 3 / 8
    when the preset lacks that key. Explicit non-None values are kept.

    Most arguments are not used directly in this method. They reach
    ``_init_post`` through the ``locals()`` snapshot taken on the last
    line, so the set of local names here is part of the contract.
    """
    # An unknown ``scale`` yields an empty preset rather than an error, so
    # the hard-coded fallbacks below apply silently.
    s = self._SCALES.get(scale, {})
    if hidden is None:
        hidden = s.get("hidden", 48)
    if layers is None:
        layers = s.get("layers", 3)
    if heads is None:
        heads = s.get("heads", 8)
    super().__init__()
    self.test_metrics = classification_test_metrics(num_classes)
    # Start empty; the names suggest per-epoch validation accumulators —
    # NOTE(review): confirm where they are filled and cleared.
    self._val_probs: list[torch.Tensor] = []
    self._val_labels: list[torch.Tensor] = []
    # ``locals()`` here holds ``self``, the preset dict ``s`` and every
    # constructor argument, with ``hidden`` / ``layers`` / ``heads``
    # already resolved. Presumably ``_init_post`` builds the layers from
    # it — TODO confirm. Do not add or rename locals above this line.
    self._init_post(locals())
extract_features
extract_features(batch, device: device) -> dict[str, torch.Tensor]

Per-graph fusion features as named tensors.

  • probs [N, 2] — prob_0, prob_1
  • conf [N, 1] — 1 - entropy / log(2)
  • emb_stats [N, 4] — emb_mean, emb_std, emb_max, emb_min
Source code in graphids/core/models/supervised/gat.py
def extract_features(self, batch, device: torch.device) -> dict[str, torch.Tensor]:
    """Build the per-graph feature tensors consumed by the fusion stage.

    Returned keys:

    - ``probs``     [N, 2] — softmax class probabilities (prob_0, prob_1)
    - ``conf``      [N, 1] — 1 - entropy / log(2), clamped to [0, 1]
    - ``emb_stats`` [N, 4] — emb_mean, emb_std, emb_max, emb_min
    """
    logits, embedding = self(batch, return_embedding=True)

    class_probs = F.softmax(logits, dim=1)
    # The 1e-8 keeps log() finite when a probability is exactly zero.
    log_probs = (class_probs + 1e-8).log()
    entropy = -(class_probs * log_probs).sum(dim=1)
    # log(2) is the maximum entropy of a two-class distribution.
    confidence = (1.0 - entropy / math.log(2)).clamp(0.0, 1.0)

    # Summary statistics taken along dim 1 of the embedding.
    stat_columns = [
        embedding.mean(1, keepdim=True),
        embedding.std(1, keepdim=True),
        embedding.max(1, keepdim=True).values,
        embedding.min(1, keepdim=True).values,
    ]

    features = {
        "probs": class_probs,
        "conf": confidence.unsqueeze(-1),
        "emb_stats": torch.cat(stat_columns, dim=1),
    }
    return features