Data: Sampler

Dual-budget bin-packing sampler — closes a batch when adding a graph would exceed either the node budget or the edge budget. A single-axis, node-only budget allowed edge-heavy batches to OOM; see .claude/rules/critical-constraints.md.

Two paths (wiring sketch after the list):

  • NodeBudgetBatchSampler — live sampler, bucket-shuffled, fresh each epoch. Used when shuffle=True.
  • pack_offline — first-fit-decreasing packing used by the prebatch path at setup. ~10-20% tighter than sequential; no epoch-to-epoch randomness to preserve.
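A hedged wiring sketch of the two paths: node_counts, edge_counts, and the budget values below are invented stand-ins, and the toy list dataset stands in for the repo's real graph dataset.

import torch
from torch.utils.data import DataLoader

from graphids.core.data.sampler import NodeBudgetBatchSampler, pack_offline

node_counts = torch.tensor([120, 340, 90, 510, 75])      # made-up per-graph node counts
edge_counts = torch.tensor([800, 4200, 300, 9100, 150])  # made-up per-graph edge counts
dataset = list(range(len(node_counts)))                  # stand-in for the graph dataset

# Live path (shuffle=True): re-packed with bucket-shuffling every epoch.
sampler = NodeBudgetBatchSampler(
    node_counts,
    max_num=1024,
    edge_sizes=edge_counts,
    max_edges=16_384,
    shuffle=True,
)
loader = DataLoader(dataset, batch_sampler=sampler)

# Prebatch path: pack once at setup with FFD; batches stay fixed across epochs.
batches = pack_offline(node_counts, 1024, edge_sizes=edge_counts, max_edges=16_384)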

graphids.core.data.sampler


Node-budget batch sampler for variable-size graphs.

Bin-packing sampler that yields index batches honoring a node budget, and optionally an edge budget as a dual constraint. The dual constraint matters when per-batch memory is dominated by message-passing activations (∝ edges) rather than node features (∝ nodes) — the edge budget prevents rare dense-edge graphs from OOMing even when the node budget would admit them.
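To make the dual constraint concrete, a toy check (the budgets here are illustrative, not defaults from this repo):

max_num, max_edges = 1_000, 10_000  # illustrative budgets

# A near-complete graph: modest node count, but edges grow ~ n^2.
n = 150
e = n * (n - 1)  # 22,350 directed edges

print(n <= max_num)    # True:  the node budget alone would admit it
print(e <= max_edges)  # False: the edge budget rejects it, so it is skipped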

NodeBudgetBatchSampler

NodeBudgetBatchSampler(sizes: Tensor, max_num: int, *, edge_sizes: Tensor | None = None, max_edges: int | None = None, shuffle: bool = True, num_buckets: int = 20, indices: Tensor | list[int] | None = None)

Bases: Sampler[list[int]]

Bin-packing sampler with optional dual node/edge budget.

  • sizes / max_num: per-graph node counts, max nodes per batch.
  • edge_sizes / max_edges (optional): per-graph edge counts, max edges per batch. A batch closes when adding a graph would exceed EITHER budget. A graph exceeding either budget on its own is skipped (with a one-line per-epoch summary warning).

Bucket-shuffle keeps batch-to-batch size variance low. indices maps local positions to dataset-global indices (for curriculum subsets).
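A hedged sketch of the indices mapping; the ids and sizes are invented, and the predicted batches assume the sequential greedy rule described under pack_offline below.

import torch

from graphids.core.data.sampler import NodeBudgetBatchSampler

# Hypothetical curriculum subset: dataset-global ids and their node counts.
subset_ids = torch.tensor([7, 12, 40, 41])
subset_sizes = torch.tensor([300, 150, 600, 90])

sampler = NodeBudgetBatchSampler(
    subset_sizes,
    max_num=1024,
    indices=subset_ids,  # batches yield 7/12/40/41, not local 0..3
    shuffle=False,
)
# Sequential greedy packing should yield [7, 12] (450 nodes) and then
# [40, 41] (690 nodes), since 450 + 600 > 1024 closes the first batch.
for batch in sampler:
    print(batch)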

Source code in graphids/core/data/sampler.py
def __init__(
    self,
    sizes: torch.Tensor,
    max_num: int,
    *,
    edge_sizes: torch.Tensor | None = None,
    max_edges: int | None = None,
    shuffle: bool = True,
    num_buckets: int = 20,
    indices: torch.Tensor | list[int] | None = None,
):
    if max_num <= 0:
        raise ValueError(f"max_num must be positive, got {max_num}")
    self.sizes = sizes.to(torch.long)
    self.max_num = int(max_num)

    if edge_sizes is not None:
        if len(edge_sizes) != len(self.sizes):
            raise ValueError(
                f"edge_sizes length ({len(edge_sizes)}) != sizes length ({len(self.sizes)})"
            )
        if max_edges is None or max_edges <= 0:
            raise ValueError("max_edges must be a positive int when edge_sizes is given")
        self.edge_sizes: torch.Tensor | None = edge_sizes.to(torch.long)
        self.max_edges: int | None = int(max_edges)
    else:
        self.edge_sizes = None
        self.max_edges = None

    self.shuffle = shuffle
    self.num_buckets = max(1, int(num_buckets))
    if indices is not None:
        idx = torch.as_tensor(indices, dtype=torch.long)
        if len(idx) != len(self.sizes):
            raise ValueError(f"indices length ({len(idx)}) != sizes length ({len(self.sizes)})")
        self._index_map: list[int] | None = idx.tolist()
    else:
        self._index_map = None
    # Populated by __iter__ after a full pass; read by __len__ so the
    # DataLoader progress-bar probe doesn't trigger a separate full pack.
    self._cached_len: int | None = None
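The live __iter__ is not reproduced on this page. As a rough sketch of the close-on-either-budget rule it implements (ignoring bucket-shuffling, the indices mapping, and the per-epoch skip summary), under the next-fit reading of the docs above:

def greedy_pack_sketch(
    sizes: list[int],
    max_num: int,
    edge_sizes: list[int] | None = None,
    max_edges: int | None = None,
):
    # Sketch only, not the repo's __iter__: one open batch at a time,
    # closed as soon as adding the next graph would exceed EITHER budget.
    batch: list[int] = []
    n_sum = e_sum = 0
    for i, n in enumerate(sizes):
        e = edge_sizes[i] if edge_sizes is not None else 0
        if n > max_num or (max_edges is not None and e > max_edges):
            continue  # oversize on its own: skipped, never batched
        over_n = n_sum + n > max_num
        over_e = max_edges is not None and e_sum + e > max_edges
        if batch and (over_n or over_e):
            yield batch
            batch, n_sum, e_sum = [], 0, 0
        batch.append(i)
        n_sum, e_sum = n_sum + n, e_sum + e
    if batch:
        yield batch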

pack_offline

pack_offline(sizes: Tensor, max_num: int, *, edge_sizes: Tensor | None = None, max_edges: int | None = None) -> list[list[int]]

First-fit-decreasing packing for the prebatch path.

The sampler's live packing walks indices sequentially (or bucket-shuffled) and closes a batch greedily — next-fit packing, which can need up to ~2 × the optimal number of batches in the worst case, and does markedly worse in practice when dataset order isn't size-sorted. FFD sorts graphs by size descending, then places each into the first batch it fits; its classic guarantee is 11/9 × OPT + 6/9. For variable-size graphs this gives ~10-20% better node-budget utilization than sequential packing, with no epoch-to-epoch randomness to preserve.

Returns a list of batch index lists (dataset-global indices; no shuffle). Used by GraphDataModule._prebatch — the class sampler is still used for live training where shuffle=True re-buckets per epoch.

Source code in graphids/core/data/sampler.py
def pack_offline(
    sizes: torch.Tensor,
    max_num: int,
    *,
    edge_sizes: torch.Tensor | None = None,
    max_edges: int | None = None,
) -> list[list[int]]:
    """First-fit-decreasing packing for the prebatch path.

    The sampler's live packing walks indices sequentially (or
    bucket-shuffled) and closes a batch greedily — next-fit packing,
    which can need up to ~2 × the optimal number of batches in the
    worst case, and does markedly worse in practice when dataset order
    isn't size-sorted. FFD sorts graphs by size descending, then places
    each into the first batch it fits; its classic guarantee is
    11/9 × OPT + 6/9. For variable-size graphs this gives ~10-20%
    better node-budget utilization than sequential packing, with no
    epoch-to-epoch randomness to preserve.

    Returns a list of batch index lists (dataset-global indices; no
    shuffle). Used by ``GraphDataModule._prebatch`` — the class sampler
    is still used for live training where ``shuffle=True`` re-buckets
    per epoch.
    """
    if max_num <= 0:
        raise ValueError(f"max_num must be positive, got {max_num}")
    has_edges = edge_sizes is not None
    if has_edges:
        if len(edge_sizes) != len(sizes):
            raise ValueError(
                f"edge_sizes length ({len(edge_sizes)}) != sizes length ({len(sizes)})"
            )
        if max_edges is None or max_edges <= 0:
            raise ValueError("max_edges must be a positive int when edge_sizes is given")

    sizes_l = sizes.to(torch.long)
    edges_l = edge_sizes.to(torch.long) if has_edges else None
    order = torch.argsort(sizes_l, descending=True).tolist()

    bins: list[_Bin] = []
    skipped = 0
    for i in order:
        n_i = int(sizes_l[i].item())
        e_i = int(edges_l[i].item()) if has_edges else 0
        oversize = n_i > max_num or (max_edges is not None and e_i > max_edges)
        if oversize:
            skipped += 1
            continue
        placed = False
        for b in bins:
            if b.n_sum + n_i <= max_num and (max_edges is None or b.e_sum + e_i <= max_edges):
                b.indices.append(i)
                b.n_sum += n_i
                b.e_sum += e_i
                placed = True
                break
        if not placed:
            bins.append(_Bin(indices=[i], n_sum=n_i, e_sum=e_i))

    if skipped:
        log.warning(
            "sampler_skipped_oversize",
            n_skipped=skipped,
            n_total=len(sizes_l),
            max_nodes=max_num,
            max_edges=max_edges,
        )
    return [b.indices for b in bins]
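A worked toy run of pack_offline, traceable against the source above:

import torch

from graphids.core.data.sampler import pack_offline

sizes = torch.tensor([5, 3, 8, 2])
batches = pack_offline(sizes, max_num=10)
# Descending order is [2, 0, 1, 3] (sizes 8, 5, 3, 2):
#   8 opens bin A; 5 does not fit A (13 > 10), opens bin B;
#   3 fits B (5 + 3 = 8); 2 fits A (8 + 2 = 10).
print(batches)  # [[2, 3], [0, 1]]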