Streaming Link Prediction with Vector Symbolic Architectures

The ability to accurately predict future connections in large-scale, dynamic networks is a fundamental challenge with profound implications, impacting domains from social media recommendations and financial fraud detection to modeling biological interactions. As real-world graphs increasingly manifest as high-velocity data streams, the task of streaming link prediction becomes paramount. In this setting, models must not only achieve high accuracy but also operate under strict computational constraints, adapting to the continuous influx of new edges without requiring costly, full-scale retraining.

Current state-of-the-art approaches, predominantly based on Graph Neural Networks (GNNs), have demonstrated remarkable success in static link prediction. However, their application to the streaming context is fraught with challenges. The iterative, multi-layer message-passing mechanism inherent to GNNs, while powerful, is computationally intensive. Adapting these models to new graph structures typically necessitates retraining the entire model or engaging in complex incremental learning schemes, both of which are often too slow and resource-heavy for real-time applications where networks evolve rapidly. This computational bottleneck renders many GNN-based solutions impractical for genuine streaming scenarios.

In this article I explore a novel method that leverages the efficiency and scalability of Vector-Symbolic Architectures (VSA), a brain-inspired computing framework. VSAs encode entities and their relationships into high-dimensional vectors (hypervectors) using a fixed set of algebraic operations. A key advantage of this framework is its capacity for rapid, one-shot learning; the entire graph structure can be encoded into node representations in a single pass over the edge stream through efficient vector operations. This bypasses the need for iterative training of the core embedding model, making it well suited for streaming data.

In this work, I introduce a highly efficient hybrid method that combines the speed of VSA with the predictive power of a lightweight, learnable scoring function. Our approach unfolds in two stages. First, we construct a VSA "sketch" of the graph in a single pass over the historical edge stream, generating a fixed set of node hypervectors that encode rich, multi-hop structural information. These representations are then frozen. Second, a simple linear scorer is trained on top of these static hypervectors to discern the subtle patterns that indicate a likely link. By decoupling the construction of node representations from the final prediction task, our model achieves predictive performance close to the state of the art while remaining well suited to the demands of high-velocity graph streams.

For a concise introduction to Vector Symbolic Architectures, please read my previous article, An Intuitive Introduction to Hyperdimensional Computing.

2. Proposed Method: VSA Sketching with a Learned Scorer

Our proposed method addresses the challenge of streaming link prediction by decoupling the process into two distinct stages: (1) rapid, single-pass construction of node representations using Vector-Symbolic Architectures (VSA), and (2) training a lightweight, learnable scorer on top of these fixed representations. This hybrid approach is designed to harness the speed and scalability of VSA for encoding the evolving graph structure while leveraging the discriminative power of a focused machine learning model to achieve high predictive accuracy.

2.1. High-Level Intuition

The core intuition behind our approach is to separate the concerns of representation learning and task-specific prediction. Instead of training a complex, end-to-end model that learns both simultaneously, we first build a powerful, general-purpose "sketch" of the graph. This sketch, composed of high-dimensional vector embeddings for each node, captures rich, multi-hop neighborhood information. The key insight is that this encoding can be performed in a single pass over the stream of edges using efficient, parallelizable vector operations, completely avoiding iterative, gradient-based training for the representation learning phase.

Once these node vectors are built and frozen, the problem of link prediction is reframed as a much simpler supervised learning task: given the vector representations for two nodes, $u$ and $v$, what is the likelihood of an edge existing between them? We solve this by training a very lightweight linear model, our scorer, that learns a task-specific similarity metric over pairs of node vectors. This decoupling means that as new edges arrive, we only need to perform fast updates to the VSA sketch, without having to retrain the scorer, making the entire architecture natively suited for high-velocity streaming environments.

2.2. Stage 1: Streaming VSA Graph Sketch

Vector-Symbolic Architectures are a brain-inspired computational framework that represents and manipulates information using high-dimensional vectors, or hypervectors. The framework relies on three fundamental algebraic operations:

  1. Bundling (Addition): $C = A + B$. This operation superimposes hypervectors and is used to represent sets or aggregates. The resulting vector is similar to its components.

  2. Binding (Multiplication): $C = A \odot B$. This operation (typically element-wise multiplication) associates two hypervectors, creating a new vector dissimilar to its inputs. It is used to link information, such as binding a node to a role.

  3. Permutation: $\rho(A)$. This operation shuffles the elements of a hypervector in a fixed way. It is used to represent ordered sequences and, combined with binding, is critically non-commutative (in general, $A \odot \rho(B) \neq B \odot \rho(A)$).
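
The three operations above can be tried out numerically. The following is a minimal NumPy sketch, not the article's PyTorch implementation; the dimension `d = 4096`, the seed, and the helper names `random_bipolar` and `cosine` are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4096  # hypervector dimensionality (illustrative choice)


def random_bipolar(d, rng):
    """Dense random hypervector with entries in {-1, +1}."""
    return rng.integers(0, 2, size=d) * 2 - 1


def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


A = random_bipolar(d, rng)
B = random_bipolar(d, rng)
perm = rng.permutation(d)  # one fixed shuffle, reused for every vector

bundled = A + B      # bundling: the sum stays similar to both inputs
bound = A * B        # binding: the product is dissimilar to both inputs
permuted = A[perm]   # permutation: the shuffled vector is dissimilar to A

print("bundled vs A :", round(cosine(bundled, A), 3))
print("bound vs A   :", round(cosine(bound, A), 3))
print("permuted vs A:", round(cosine(permuted, A), 3))

# Element-wise binding alone is commutative (A * B == B * A); permuting one
# operand breaks the symmetry, which is how order can be encoded.
print("order-sensitive:", not np.array_equal(A * B[perm], B * A[perm]))
```

With random bipolar vectors, the bundled vector should score a cosine of roughly 0.7 against each component, while the bound and permuted vectors should be close to orthogonal to their inputs.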

Our graph sketching process leverages these operations to encode the K-hop neighborhood of every node $i$ into a single hypervector $H_i$.

Initialization: We begin by assigning every node $i$ in the graph a unique, dense, bipolar random hypervector $B_i \in \{-1, +1\}^d$ of dimension $d$. These base vectors serve as the unique identifiers for each node.

Single-Pass Encoding: The sketch is built by processing each edge in the training stream once. For an edge $(u, v)$, we interpret the update as a message passing event where information from $u$ is aggregated into $v$'s representation (and vice-versa, as the graph is undirected). To capture multi-hop information, we perform $K$ rounds of message passing in parallel for each edge. For hop $k$, the message from node $u$ is transformed by binding it with a unique, random "role" vector $R_k$ and applying a permutation $\rho_k$:

$$m_{u \to v}^{(k)} = \rho_k\left(H_u \odot R_k\right)$$

The role vector contextualizes the message as originating from a k-hop neighbor, while the permutation ensures that the order of relationships is preserved. All incoming messages to a node $v$ are then bundled (summed) to update its hypervector $H_v$. This process effectively compresses the structural information from a node's entire k-hop neighborhood into its representative hypervector. The use of random base vectors and fixed operations ensures that this entire encoding phase is computationally efficient, requires no backpropagation, and can be applied incrementally as new edges arrive.
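
To make the encoding step concrete, here is a simplified NumPy sketch on a five-node toy graph. It is a sketch under stated assumptions rather than the full pipeline: edge weights, the hop discount, and the decay factor are omitted, messages are simply added to the current vectors, and the names `update` and `l2_normalize` as well as the toy sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, K = 5, 2048, 2                      # toy sizes, chosen for illustration

B = rng.choice([-1.0, 1.0], size=(n, d))  # base hypervectors, one per node
R = rng.choice([-1.0, 1.0], size=(K, d))  # one random role vector per hop

# Hop k uses the base permutation composed with itself k times.
base_perm = rng.permutation(d)
perms = [base_perm]
for _ in range(K - 1):
    perms.append(perms[-1][base_perm])


def l2_normalize(M):
    """Scale every row of M to unit L2 norm."""
    norms = np.linalg.norm(M, axis=1, keepdims=True)
    return M / np.maximum(norms, 1e-12)


def update(H, edges):
    """Fold a batch of undirected edges into the sketch and return the new H."""
    edges = np.asarray(edges)
    src = np.concatenate([edges[:, 0], edges[:, 1]])  # send in both directions
    dst = np.concatenate([edges[:, 1], edges[:, 0]])
    acc = np.zeros_like(H)
    for k in range(K):
        msg = (H[src] * R[k])[:, perms[k]]  # bind with the role, then permute
        np.add.at(acc, dst, msg)            # bundle all messages per target node
    return l2_normalize(H + acc)


H = l2_normalize(B)
H = update(H, [(0, 1), (1, 2), (3, 4)])

# Nodes 0 and 2 share neighbor 1, so their sketches overlap; 0 and 3 do not.
print("sim(0, 2):", round(float(H[0] @ H[2]), 3))
print("sim(0, 3):", round(float(H[0] @ H[3]), 3))
```

Because nodes 0 and 2 receive the same messages from their common neighbor, their updated vectors end up far more similar to each other than to the vector of an unrelated node, which is the structural signal the scorer later exploits.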

2.3. Stage 2: Training the Lightweight Scorer

While the VSA sketch provides potent structural embeddings, the default similarity metric in VSA (cosine similarity) may not be optimal for the nuanced task of link prediction. Therefore, we freeze the generated hypervectors and train a lightweight scorer to learn a more effective, task-specific similarity function.

For a candidate edge $(u, v)$, we compute the element-wise (Hadamard) product of their hypervectors, $H_u \odot H_v$. This interaction vector captures the patterns of agreement and disagreement between the two node representations. A single, learnable weight vector $W$ of dimension $d$ is then used to project this interaction vector down to a scalar score:

$$s(u, v) = W^\top \left(H_u \odot H_v\right)$$

This linear scorer is exceptionally efficient. We train $W$ using a margin ranking loss, which pushes the scores of true edges (positives) to be higher than those of non-existent edges (negatives). Because the graph encoding is already complete, this training phase is very fast, focusing only on learning the parameters of the scorer. This allows us to train on the entire set of positive and negative training examples, resulting in an accurate prediction model without incurring the costs associated with traditional GNNs.
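
The scoring rule and the loss can be written in a few lines. The following NumPy sketch uses tiny hand-picked vectors so the arithmetic can be checked by hand; the function names `score` and `margin_ranking_loss` are illustrative and the three-dimensional inputs are assumptions, not data from the experiments.

```python
import numpy as np


def score(W, Hu, Hv):
    """Linear score over the Hadamard product. Hu, Hv: [batch, d]; W: [d]."""
    return (Hu * Hv) @ W


def margin_ranking_loss(pos_scores, neg_scores, margin=1.0):
    """Mean hinge penalty when a positive does not beat its paired negative by `margin`."""
    return float(np.mean(np.maximum(0.0, margin - pos_scores + neg_scores)))


W = np.array([0.5, -1.0, 2.0])
Hu = np.array([[1.0, 1.0, 1.0], [1.0, -1.0, 1.0]])
Hv = np.array([[1.0, -1.0, 1.0], [-1.0, 1.0, -1.0]])

# Row 0: Hu * Hv = [1, -1, 1], so the score is 0.5 + 1.0 + 2.0 = 3.5
# Row 1: Hu * Hv = [-1, -1, -1], so the score is -0.5 + 1.0 - 2.0 = -1.5
s = score(W, Hu, Hv)
print(s)

# Treat row 0 as a positive pair and row 1 as its negative: the gap exceeds the margin.
print(margin_ranking_loss(s[:1], s[1:]))
```

Note that with `W` set to all ones the score reduces to the plain dot product of the two hypervectors, which for unit-norm vectors is the cosine similarity; the learned weights can therefore be read as a per-dimension reweighting of that default metric.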

2.4. Mathematical Formulation and Theoretical Justification

This section provides a formal description of our method, grounding the previously described intuitions in mathematical principles.

VSA Sketch Construction. Let $G = (V, E)$ be a graph with $n = |V|$ nodes. We initialize the VSA sketch by assigning each node $i \in V$ a unique, random base hypervector $B_i \in \{-1, +1\}^d$, where $d$ is the dimensionality. The node's initial representation is set to its normalized base vector, $H_i^{(0)} = B_i / \lVert B_i \rVert_2$. We also define $K$ random role hypervectors $R_k \in \{-1, +1\}^d$ and fixed permutation operators $\rho_k$ for $k = 1, \dots, K$.

The sketch is updated in a single pass over the stream of edges. For a batch of undirected edges $E_b$ (each edge contributing a message in both directions), the message passed from a source node $s$ to a target node $t$ for hop $k$ is defined as:

$$m_{s \to t}^{(k)} = \rho_k\left(H_s \odot R_k\right)$$

where $\odot$ denotes element-wise multiplication. The total aggregated information for each node $t$ is the weighted sum of all incoming messages within the batch:

$$A_t = \sum_{(s, t) \in E_b} \sum_{k=1}^{K} \gamma^{k} \, w_{st} \, m_{s \to t}^{(k)}$$

Here, $w_{st}$ is a scalar edge weight incorporating factors like temporal recency and degree normalization, and $\gamma$ is a discount factor that down-weights the influence of longer paths. The matrix of node hypervectors $H$ is then updated via an exponential moving average:

$$H \leftarrow \alpha H + (1 - \alpha) A$$

where $A$ is the matrix of aggregated messages $A_t$ and $\alpha$ is a decay factor. Finally, each node's hypervector is L2-normalized, $H_i \leftarrow H_i / \lVert H_i \rVert_2$.
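
The scalar factors in this update are easy to check by hand. The code listing at the end of the article computes the edge weight as an exponential temporal decay multiplied by the square root of the co-authorship count, and discounts hop k by the discount factor raised to the power k. The sketch below mirrors those formulas with plain Python; the function names are hypothetical and the defaults (`t_ref=2018`, `lam_time=0.25`, `gamma=0.6`) are taken from the script's command-line defaults.

```python
import math


def edge_weight(year, count, t_ref=2018, lam_time=0.25):
    """Temporal recency times sqrt(count); edges at or after t_ref get no decay."""
    age = max(t_ref - year, 0)
    return math.exp(-lam_time * age) * math.sqrt(max(count, 1))


def hop_discount(k, gamma=0.6):
    """Discount applied to hop k (k = 1..K); longer paths count less."""
    return gamma ** k


def ema_update(h_old, aggregated, decay):
    """Exponential moving average of the old value and the aggregated message."""
    return decay * h_old + (1.0 - decay) * aggregated


for year in (2017, 2012, 2007):
    print(year, round(edge_weight(year, count=1), 4))

print("hop discounts:", [round(hop_discount(k), 3) for k in (1, 2, 3)])
print("ema:", ema_update(h_old=1.0, aggregated=0.0, decay=0.9))
```

With these defaults an edge from 2017 keeps about 78% of its weight, while one from ten years earlier keeps well under 10%, so recent collaborations dominate the sketch.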

Theoretical Justification. The efficacy of this method stems from the properties of high-dimensional spaces. Random vectors in high dimensions are quasi-orthogonal with high probability. This allows the bundling (addition) of many messages into a single hypervector $H_i$ while preserving the ability to query for constituent information. The binding operation $H_s \odot R_k$ creates a new vector that is dissimilar to both inputs, effectively creating a key-value pair. This bound product can be inverted by binding with the role again, since $R_k \odot R_k = \mathbf{1}$ for bipolar vectors: $(H_s \odot R_k) \odot R_k = H_s$. Once many such messages are bundled together, the recovery is only approximate. The permutation operator $\rho_k$ makes the binding non-commutative, allowing the model to distinguish the order of relations and encode directed path information. This update rule can be interpreted as a parallelized, non-parametric analogue of a K-layer GNN, where information from K-hop neighborhoods is compressed into the node representations in a single, efficient pass.
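
The query property described here can be checked directly. The following NumPy sketch bundles five bound and permuted vectors into one "memory" vector, then undoes the permutation and the binding and compares the result against stored and unstored items. The sizes, the seed, and the names `rand_hv` and `cosine` are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4096


def rand_hv():
    """Random bipolar hypervector."""
    return rng.choice([-1.0, 1.0], size=d)


def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


items = [rand_hv() for _ in range(5)]  # stand-ins for neighbor hypervectors
role = rand_hv()
perm = rng.permutation(d)
inv_perm = np.argsort(perm)            # inverse permutation

# Bundle five messages, each bound with the role vector and then permuted.
memory = sum((x * role)[perm] for x in items)

# Decode: undo the permutation, then unbind by multiplying with the role again
# (role * role is the all-ones vector for bipolar entries).
decoded = memory[inv_perm] * role

stored = [cosine(decoded, x) for x in items]
unrelated = cosine(decoded, rand_hv())

print("stored items  :", [round(c, 3) for c in stored])
print("unrelated item:", round(unrelated, 3))
```

Each stored item should show a similarity of roughly 1/sqrt(5) to the decoded vector, whereas an unrelated random vector stays near zero; the gap is what makes membership queries on a bundled hypervector possible.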

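Learnable Scorer. After freezing the hypervectors $H$, we train a linear scorer on the link prediction task. The score for a candidate edge $(u, v)$ is given by:

$$s(u, v) = W^\top \left(H_u \odot H_v\right) = \sum_{j=1}^{d} W_j \, H_{u,j} \, H_{v,j}$$

where $W \in \mathbb{R}^d$ is a learnable weight vector. This model learns a weighted similarity metric, identifying the dimensions of the interaction vector $H_u \odot H_v$ that are most predictive of a link. The model is trained to minimize the margin ranking loss $\mathcal{L}$ over a set of positive edges $E^+$ and negative edges $E^-$, with each positive edge $e_i^+$ paired with a sampled negative edge $e_i^-$:

$$\mathcal{L} = \frac{1}{|E^+|} \sum_{i=1}^{|E^+|} \max\left(0,\; 1 - s(e_i^+) + s(e_i^-)\right)$$

This objective function encourages the scores of positive examples to be greater than the scores of negative examples by at least a margin of 1.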
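
To illustrate how such a scorer can be fitted, here is a minimal NumPy sketch that minimizes the margin ranking loss with plain subgradient descent. This swaps in a simpler optimizer than the Adam optimizer used in the article's PyTorch code, and the interaction vectors are synthetic data invented for the example (positives carry a weak signal in the first eight dimensions); the learning rate and step count are likewise assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_pairs = 64, 512

# Synthetic interaction vectors x = H_u * H_v (assumed data, for illustration only).
x_pos = rng.normal(size=(n_pairs, d))
x_pos[:, :8] += 1.0                      # weak signal in the first 8 dimensions
x_neg = rng.normal(size=(n_pairs, d))


def loss_and_grad(W, x_pos, x_neg, margin=1.0):
    """Margin ranking loss over paired examples and its subgradient w.r.t. W."""
    violation = margin - x_pos @ W + x_neg @ W
    active = violation > 0               # pairs that still violate the margin
    loss = float(np.maximum(0.0, violation).mean())
    grad = (x_neg[active] - x_pos[active]).sum(axis=0) / len(violation)
    return loss, grad


W = np.zeros(d)
initial_loss, _ = loss_and_grad(W, x_pos, x_neg)
for _ in range(200):
    _, g = loss_and_grad(W, x_pos, x_neg)
    W -= 0.05 * g                        # fixed step size, chosen for the toy data
final_loss, _ = loss_and_grad(W, x_pos, x_neg)

print("loss before:", round(initial_loss, 4))
print("loss after :", round(final_loss, 4))
print("mean weight on signal dims:", round(float(W[:8].mean()), 3))
print("mean |weight| elsewhere   :", round(float(np.abs(W[8:]).mean()), 3))
```

Only pairs that still violate the margin contribute to the gradient, so the updates shrink as more positives are ranked above their paired negatives, and the learned weights concentrate on the dimensions that actually separate the two classes.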

3. Experimental Results

We use the OGB Link Property Prediction (ogbl-collab) dataset, a widely recognized benchmark for link prediction. This dataset is a large, undirected co-authorship graph derived from the Microsoft Academic Graph (MAG). Nodes represent authors, and an edge between two authors indicates that they have co-authored a paper. The dataset is temporally split, making it ideal for evaluating models in a dynamic setting:

  • Training Set: Edges from papers published up to and including 2017.

  • Validation Set: Edges from papers published in 2018.

  • Test Set: Edges from papers published in 2019.

The graph consists of 235,868 nodes and approximately 1,285,000 edges, the large majority of which fall in the training split. The task is to predict the collaborative links for the validation and test years. Performance is measured using the Hits@K metric, which measures the fraction of true positive edges that rank within the top K when scored against a fixed set of randomly sampled negative edges.
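
For readers unfamiliar with the metric, the following NumPy sketch shows one way to compute Hits@K. It follows my understanding of how the OGB evaluator behaves (a positive counts as a hit when its score is strictly higher than the K-th highest negative score); the function name `hits_at_k` and the toy scores are illustrative, and the official `Evaluator` remains the reference for reported numbers.

```python
import numpy as np


def hits_at_k(y_pred_pos, y_pred_neg, k):
    """Fraction of positive edges scored strictly above the k-th highest negative."""
    y_pred_pos = np.asarray(y_pred_pos, dtype=float)
    y_pred_neg = np.asarray(y_pred_neg, dtype=float)
    if len(y_pred_neg) < k:
        return 1.0                         # fewer than k negatives: every positive is a hit
    kth_neg = np.sort(y_pred_neg)[-k]      # k-th highest negative score
    return float(np.mean(y_pred_pos > kth_neg))


pos = [0.9, 0.8, 0.3, 0.1]
neg = [0.85, 0.5, 0.4, 0.2, 0.05]

print(hits_at_k(pos, neg, k=1))  # 0.25: only 0.9 beats the top negative 0.85
print(hits_at_k(pos, neg, k=3))  # 0.5: 0.9 and 0.8 beat the third-highest negative 0.4
```

Because every positive is compared against the same pool of negatives, a handful of high-scoring negatives is enough to lower the metric for all positives at once, which makes Hits@K sensitive to false positives at the top of the ranking.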

All experiments were conducted on a single machine equipped with an NVIDIA RTX 3090 GPU with 24GB of VRAM. The implementation leverages PyTorch for all vector operations and model training. Following the methodology described in Section 2, we first build the VSA sketch in a single pass over the training edges (≤ 2017) and then train the lightweight scorer, using the validation set (2018) for monitoring and the test set (2019) for final evaluation.

This method achieves Hits@50 of 0.665 and Hits@100 of 0.701. State-of-the-art (SoA) methods reach Hits@50 of about 0.7129. Although the SoA approaches are more accurate, the key distinction lies in the computational model. GNN-based methods require multiple passes (epochs) over the data with expensive backpropagation to learn node embeddings, whereas our method constructs its embeddings in a single, non-iterative pass. This makes our approach much cheaper to run and directly applicable to streaming scenarios where retraining an entire GNN is computationally infeasible. For instance, it takes less than 3 minutes to run the algorithm end to end on the ogbl-collab dataset. These results suggest that the proposed VSA-Sketch paradigm offers an attractive trade-off between predictive accuracy and computational efficiency for link prediction on large, evolving graphs.

4. Conclusion

This article presented a novel and highly efficient method for link prediction on streaming graphs. By decoupling the process into a single-pass VSA sketch construction phase and a fast, lightweight scorer training phase, the approach sidesteps the primary scalability limitations of traditional GNN-based models. Our experimental results on the ogbl-collab dataset show that this method comes close to state-of-the-art accuracy (Hits@50 of 0.665 versus 0.7129) without the need for expensive, iterative training epochs. This work points toward practical, high-performance link prediction systems in real-time, resource-constrained environments. Future work could explore the application of this paradigm to other dynamic graph tasks and investigate hardware-specific optimizations for VSA operations.

5. Code

Below is the code for the end-to-end pipeline of the proposed method:

#!/usr/bin/env python3
"""
vsa_ogbl_collab.py  —  VSA-based link prediction with a LEARNABLE SCORER.

Key ideas:
- Streaming VSA sketch H built from train edges (<=2017)
- H vectors are then FROZEN.
- A lightweight scorer is trained on the FULL TRAINING EDGES (with negative samples)
  while monitoring performance on the validation set (2018).
- Final evaluation on the test set (2019).

CLI example:
    python vsa_ogbl_collab.py --device cuda --d 4096 --K 1 --gamma 0.6 \
        --iters 4 --decay 0.05 --edge-batch 40000 --degree-norm --fp16 \
        --fold-valid --scorer-lr 1e-3 --scorer-epochs 5 --beta-deg 0.0
"""

from __future__ import annotations  # lets `X | None` type hints work on Python < 3.10

import argparse
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

# --------------------------- Utils ---------------------------


def bipolar(shape, generator, device, dtype=torch.float32):
    # generator MUST match `device`
    t = torch.rand(shape, generator=generator, device=device, dtype=torch.float32)
    return (t > 0.5).to(dtype).mul_(2).sub_(1)


def make_undirected_edge_index(edge_uv: torch.Tensor, num_nodes: int) -> torch.Tensor:
    # edge_uv: [E, 2]
    u, v = edge_uv[:, 0], edge_uv[:, 1]
    uu = torch.cat([u, v], dim=0)
    vv = torch.cat([v, u], dim=0)
    return torch.stack([uu, vv], dim=0)  # [2, 2E]


# --------------------------- Learnable Scorer ---------------------------


class VSAScorer(torch.nn.Module):
    """Learns a linear projection over the Hadamard product of two vectors."""

    def __init__(self, d, dtype=torch.float32):
        super().__init__()
        self.d = d
        # ALWAYS use float32 for the learnable weight for stability
        self.W = torch.nn.Parameter(torch.randn(d, dtype=torch.float32) * 0.1)

    def forward(self, Hu, Hv):
        # Hu, Hv: [batch, d]
        hadamard_product = Hu * Hv
        return torch.einsum(
            "bd,d->b", hadamard_product.to(torch.float32), self.W.to(torch.float32)
        )


# --------------------------- VSA Sketch ---------------------------


class VSASketch(torch.nn.Module):
    def __init__(
        self,
        num_nodes,
        d=4096,
        K=2,
        gamma=0.7,
        decay=0.985,
        device="cuda",
        seed=42,
        fp16=False,
        degree_norm=False,
        inv_sqrt_deg=None,
    ):
        super().__init__()
        self.n = num_nodes
        self.d = d
        self.K = K
        self.gamma = gamma
        self.decay = decay
        self.device = torch.device(
            device if (device == "cpu" or torch.cuda.is_available()) else "cpu"
        )
        self.dtype = torch.float16 if fp16 else torch.float32

        # generator tied to device
        g = torch.Generator(device=self.device.type).manual_seed(seed)

        # Bases + binding vectors
        self.B = bipolar((self.n, d), g, self.device, torch.float32).to(self.dtype)
        self.Rk = torch.stack(
            [bipolar((d,), g, self.device, torch.float32) for _ in range(K)], dim=0
        ).to(self.dtype)

        # Permutations
        base_perm = torch.randperm(d, generator=g, device=self.device)
        self.perm_k = []
        idx = base_perm
        for k in range(1, K + 1):
            if k > 1:
                idx = idx[base_perm]
            self.perm_k.append(idx.clone())

        # Node signatures
        self.H = torch.nn.Parameter(
            F.normalize(self.B.clone(), p=2, dim=1), requires_grad=False
        )
        self.scorer = VSAScorer(d, self.dtype)

        # Degree norm
        self.degree_norm = degree_norm
        if degree_norm and inv_sqrt_deg is not None:
            self.register_buffer(
                "inv_sqrt_deg", inv_sqrt_deg.to(self.device).to(self.dtype)
            )
        else:
            self.inv_sqrt_deg = None

    @torch.no_grad()
    def update_from_edges_timed(
        self,
        edge_uv: torch.Tensor,  # [E, 2] with u,v indices (train edges <=2017)
        year: torch.Tensor,  # [E] years for these edges
        weight: torch.Tensor,  # [E] counts/weights for these edges
        t_ref: int,  # 2018 for valid build; 2019 when folding valid for test
        lam_time: float = 0.25,
        iterations: int = 1,
        edge_batch: int = 300_000,
    ):
        dev = self.device
        dtype = self.dtype
        E = edge_uv.size(0)

        # Undirected augmentation for message passing
        ei = make_undirected_edge_index(edge_uv, self.n).to(dev)  # [2, 2E]
        src, dst = ei[0], ei[1]

        # Per-edge weights: temporal decay * sqrt(count)
        year = year.to(dev).to(torch.float32)
        weight = weight.to(dev).to(torch.float32)
        time_w = torch.exp(-lam_time * (t_ref - year).clamp(min=0.0))  # newer ~1.0
        cnt_w = torch.sqrt(weight.clamp(min=1.0))
        base_w = (time_w * cnt_w).to(dtype)  # [E]
        base_w = torch.cat([base_w, base_w], dim=0).view(-1, 1)  # [2E, 1]

        # Iterate
        for _ in range(iterations):
            for start in tqdm(range(0, src.numel(), edge_batch), desc="Building H"):
                end = min(start + edge_batch, src.numel())
                s = src[start:end]
                t = dst[start:end]
                w = base_w[start:end]  # [e,1]
                acc = torch.zeros((self.n, self.d), device=dev, dtype=dtype)
                Hprev = self.H  # [n,d]
                for hop in range(self.K):
                    msg = Hprev[s] * self.Rk[hop]
                    msg = msg.index_select(dim=-1, index=self.perm_k[hop])
                    if self.degree_norm and (self.inv_sqrt_deg is not None):
                        scale = (self.inv_sqrt_deg[s] * self.inv_sqrt_deg[t]).view(
                            -1, 1
                        )
                        msg = msg * scale
                    wk = self.gamma ** (hop + 1)
                    acc.index_add_(0, t, wk * w * msg)
                self.H.mul_(self.decay)
                self.H.add_(acc, alpha=(1.0 - self.decay))
                self.H.copy_(F.normalize(self.H, p=2, dim=1))

    def pair_score(
        self,
        u: torch.Tensor,
        v: torch.Tensor,
        deg: torch.Tensor | None = None,
        beta_deg: float = 0.0,
    ):
        Hu = self.H[u.to(self.device)]
        Hv = self.H[v.to(self.device)]
        sim = self.scorer(Hu, Hv)
        if (deg is not None) and (beta_deg != 0.0):
            deg_dev = deg.to(self.device)
            sim = sim - beta_deg * torch.log1p(deg_dev[v].to(sim.dtype))
        return sim


# --------------------------- Scorer Training ---------------------------


def train_scorer(
    sketch: VSASketch,
    pos_train_edges: torch.Tensor,
    neg_train_edges: torch.Tensor,
    pos_val_edges: torch.Tensor,
    neg_val_edges: torch.Tensor,
    evaluator: Evaluator,
    epochs: int = 5,
    lr: float = 1e-3,
    batch_size: int = 32768,
):
    print("\n--- Training Scorer on Training Set (monitoring with Validation Set) ---")
    optimizer = torch.optim.Adam(sketch.scorer.parameters(), lr=lr, weight_decay=1e-5)
    sketch.scorer.to(sketch.device)
    sketch.scorer.train()

    pos_train_edges = pos_train_edges.to(sketch.device)
    neg_train_edges = neg_train_edges.to(sketch.device)

    for epoch in range(1, epochs + 1):
        total_loss = 0
        num_batches = 0

        # Shuffle data for each epoch
        perm = torch.randperm(pos_train_edges.size(0), device=sketch.device)
        pos_train_edges_shuffled = pos_train_edges[perm]

        perm = torch.randperm(neg_train_edges.size(0), device=sketch.device)
        neg_train_edges_shuffled = neg_train_edges[perm]

        pbar = tqdm(
            range(0, pos_train_edges.size(0), batch_size),
            desc=f"Epoch {epoch}/{epochs} Training",
        )
        for st in pbar:
            optimizer.zero_grad()

            # Positive batch
            ed = min(st + batch_size, pos_train_edges.size(0))
            pos_batch = pos_train_edges_shuffled[st:ed]
            Hu_pos = sketch.H[pos_batch[:, 0]].to(torch.float32)
            Hv_pos = sketch.H[pos_batch[:, 1]].to(torch.float32)
            pos_scores = sketch.scorer(Hu_pos, Hv_pos)

            # Negative batch (match size)
            neg_batch = neg_train_edges_shuffled[st:ed]
            Hu_neg = sketch.H[neg_batch[:, 0]].to(torch.float32)
            Hv_neg = sketch.H[neg_batch[:, 1]].to(torch.float32)
            neg_scores = sketch.scorer(Hu_neg, Hv_neg)

            # Margin ranking loss
            target = torch.ones_like(pos_scores)
            loss = F.margin_ranking_loss(pos_scores, neg_scores, target, margin=1.0)

            if torch.isnan(loss):
                print("\n[FATAL] Loss became NaN. Stopping training.")
                return

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix({"loss": loss.item()})

        avg_loss = total_loss / num_batches if num_batches > 0 else 0

        # --- Validation monitoring at the end of each epoch ---
        sketch.scorer.eval()  # Set to evaluation mode
        print(f"\n--- Running validation for Epoch {epoch} ---")
        y_pos_v, y_neg_v = score_pos_neg_flat(
            sketch, pos_val_edges, neg_val_edges, deg=None, batch=1_000_000
        )
        evaluator.K = 50  # Monitor with a single K value for speed
        hits50 = evaluator.eval({"y_pred_pos": y_pos_v, "y_pred_neg": y_neg_v})[
            "hits@50"
        ]
        print(
            f"Epoch {epoch}/{epochs} | Avg Train Loss: {avg_loss:.4f} | Valid Hits@50: {hits50:.4f}"
        )
        sketch.scorer.train()  # Set back to train mode for the next epoch


# --------------------------- Evaluation (official shape) ---------------------------


@torch.inference_mode()
def score_pos_neg_flat(
    sketch: VSASketch,
    pos_edges: torch.Tensor,
    neg_edges: torch.Tensor,
    deg: torch.Tensor | None,
    beta_deg: float = 0.0,
    batch: int = 1_000_000,
):
    sketch.scorer.eval()  # Set scorer to evaluation mode
    pos_edges = pos_edges.to(sketch.device)
    neg_edges = neg_edges.to(sketch.device)

    # Positives
    y_pos = []
    for st in tqdm(range(0, pos_edges.size(0), batch), desc="Scoring Positives"):
        ed = min(st + batch, pos_edges.size(0))
        u, v = pos_edges[st:ed, 0], pos_edges[st:ed, 1]
        y = sketch.pair_score(u, v, deg=deg, beta_deg=beta_deg)
        y_pos.append(y.detach().cpu())
    y_pred_pos = torch.cat(y_pos, dim=0)

    # Negatives
    y_neg = []
    for st in tqdm(range(0, neg_edges.size(0), batch), desc="Scoring Negatives"):
        ed = min(st + batch, neg_edges.size(0))
        u, v = neg_edges[st:ed, 0], neg_edges[st:ed, 1]
        y = sketch.pair_score(u, v, deg=deg, beta_deg=beta_deg)
        y_neg.append(y.detach().cpu())
    y_pred_neg = torch.cat(y_neg, dim=0)

    return y_pred_pos, y_pred_neg


# --------------------------- Main ---------------------------


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    ap.add_argument("--d", type=int, default=8192)
    ap.add_argument("--K", type=int, default=3)
    ap.add_argument("--gamma", type=float, default=0.6)
    ap.add_argument(
        "--decay",
        type=float,
        default=0.0,
        help="Update decay for H vectors. 0.0 is recommended for single-pass.",
    )
    ap.add_argument("--iters", type=int, default=1)
    ap.add_argument("--edge-batch", type=int, default=400_000)
    ap.add_argument("--fp16", action="store_true")
    ap.add_argument("--degree-norm", action="store_true")
    ap.add_argument("--lam-time", type=float, default=0.25)
    ap.add_argument(
        "--beta-deg", type=float, default=0.05, help="degree bias correction at scoring"
    )
    ap.add_argument(
        "--fold-valid",
        action="store_true",
        help="fold 2018 edges before scoring 2019 (test)",
    )
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument(
        "--scorer-lr", type=float, default=1e-3, help="Learning rate for the scorer."
    )
    ap.add_argument(
        "--scorer-epochs",
        type=int,
        default=5,
        help="Number of epochs to train the scorer.",
    )
    ap.add_argument(
        "--scorer-batch-size",
        type=int,
        default=32768,
        help="Batch size for scorer training.",
    )
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    if args.device == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    t0 = time.time()
    dataset = PygLinkPropPredDataset(name="ogbl-collab", root="./ogbdata")
    split_edge = dataset.get_edge_split()
    data = dataset[0]
    n = int(data.num_nodes)

    # Train edges (<=2017)
    train_edge = split_edge["train"]["edge"].to(torch.long)
    train_year = split_edge["train"].get(
        "year", torch.full((train_edge.size(0),), 2017, dtype=torch.long)
    )
    train_wt = split_edge["train"].get(
        "weight", torch.ones(train_edge.size(0), dtype=torch.long)
    )

    deg = torch.bincount(
        torch.cat([train_edge[:, 0], train_edge[:, 1]], dim=0), minlength=n
    ).to(torch.long)
    inv_sqrt_deg = deg.clamp(min=1).to(torch.float32).pow(-0.5)

    device = torch.device(
        args.device if (args.device == "cpu" or torch.cuda.is_available()) else "cpu"
    )
    print(
        f"[info] ogbl-collab loaded: n={n}, E_train={train_edge.size(0)} (build in {time.time() - t0:.2f}s)"
    )

    sketch = VSASketch(
        num_nodes=n,
        d=args.d,
        K=args.K,
        gamma=args.gamma,
        decay=args.decay,
        device=device.type,
        seed=args.seed,
        fp16=args.fp16,
        degree_norm=args.degree_norm,
        inv_sqrt_deg=inv_sqrt_deg,
    ).to(device)

    # 1. Build VSA sketch on training data (2017 and older)
    t1 = time.time()
    sketch.update_from_edges_timed(
        edge_uv=train_edge,
        year=train_year,
        weight=train_wt,
        t_ref=2018,
        lam_time=args.lam_time,
        iterations=args.iters,
        edge_batch=args.edge_batch,
    )
    print(f"[info] VSA build: {time.time() - t1:.2f}s (d={args.d}, K={args.K})")

    evaluator = Evaluator(name="ogbl-collab")
    pos_val = split_edge["valid"]["edge"].to(torch.long)
    neg_val = split_edge["valid"]["edge_neg"].to(torch.long)

    # --- Generate negative samples for the training set ---
    # The training split only provides positive edges. To train the scorer
    # with a ranking loss, we need negative examples. We generate them by
    # randomly sampling pairs of nodes.
    print("[info] Generating negative training samples...")
    num_neg_samples = train_edge.size(0)
    neg_train_edge = torch.randint(0, n, (num_neg_samples, 2), dtype=torch.long)
    print(f"[info] Generated {neg_train_edge.size(0)} negative samples.")

    # 2. Train scorer on TRAINING set, monitor with VALIDATION set
    # H vectors are fixed. We only learn scoring weights.
    train_scorer(
        sketch,
        train_edge,
        neg_train_edge,  # Training data
        pos_val,
        neg_val,  # Monitoring data
        evaluator,  # Evaluator for monitoring
        epochs=args.scorer_epochs,
        lr=args.scorer_lr,
        batch_size=args.scorer_batch_size,
    )

    # 3. Evaluate FINAL model on VALID (2018) with the trained scorer
    y_pos_v, y_neg_v = score_pos_neg_flat(
        sketch, pos_val, neg_val, deg=deg, beta_deg=args.beta_deg, batch=1_000_000
    )

    print("\n=== FINAL VALID (2018) PERFORMANCE")
    for K in [10, 50, 100]:
        evaluator.K = K
        hits = evaluator.eval({"y_pred_pos": y_pos_v, "y_pred_neg": y_neg_v})[
            f"hits@{K}"
        ]
        print(f"Hits@{K}: {hits:.4f}")

    # 4. Optionally fold 2018 edges into H before scoring TEST (2019)
    if args.fold_valid:
        vt1 = time.time()
        val_year = split_edge["valid"].get(
            "year", torch.full((pos_val.size(0),), 2018, dtype=torch.long)
        )
        val_wt = split_edge["valid"].get(
            "weight", torch.ones(pos_val.size(0), dtype=torch.long)
        )
        sketch.update_from_edges_timed(
            edge_uv=pos_val,
            year=val_year,
            weight=val_wt,
            t_ref=2019,
            lam_time=args.lam_time,
            iterations=1,
            edge_batch=args.edge_batch,
        )
        print(f"[info] Folded 2018 edges before TEST in {time.time() - vt1:.2f}s")

    # 5. Evaluate FINAL model on TEST (2019)
    pos_test = split_edge["test"]["edge"].to(torch.long)
    neg_test = split_edge["test"]["edge_neg"].to(torch.long)
    y_pos_t, y_neg_t = score_pos_neg_flat(
        sketch, pos_test, neg_test, deg=deg, beta_deg=args.beta_deg, batch=1_000_000
    )

    print("\n=== FINAL TEST (2019) PERFORMANCE")
    for K in [10, 50, 100]:
        evaluator.K = K
        hits = evaluator.eval({"y_pred_pos": y_pos_t, "y_pred_neg": y_neg_t})[
            f"hits@{K}"
        ]
        print(f"Hits@{K}: {hits:.4f}")

    total_mb = n * args.d * (2 if args.fp16 else 4) / (1024 * 1024)
    print(
        f"\n[info] Approx signature memory: ~{total_mb:.1f} MB on device={device.type}"
    )


if __name__ == "__main__":
    main()