diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4e6789b9a86d..795e8247bffd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
 - Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
 - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
+- Added `PPRGo` implementation and example ([#9847](https://github.com/pyg-team/pytorch_geometric/pull/9847))
+- Allowed top-k sparsification in `utils.get_ppr` and `transforms.GDC` ([#9847](https://github.com/pyg-team/pytorch_geometric/pull/9847))
 
 ### Changed
diff --git a/examples/pprgo.py b/examples/pprgo.py
new file mode 100644
index 000000000000..d9fdad1413a8
--- /dev/null
+++ b/examples/pprgo.py
@@ -0,0 +1,112 @@
+# Example usage of PPRGo (from the paper "Scaling Graph Neural Networks
+# with Approximate PageRank"). This trains on a small subset of the Reddit
+# graph dataset and predicts on the rest.
+
+import argparse
+import os.path as osp
+import time
+
+import torch
+import torch.nn.functional as F
+from sklearn.metrics import accuracy_score, f1_score
+
+from torch_geometric.data import Data
+from torch_geometric.datasets import Reddit
+from torch_geometric.nn.models import PPRGo, pprgo_prune_features
+from torch_geometric.transforms import GDC
+from torch_geometric.utils import subgraph
+
+# Parameters
+parser = argparse.ArgumentParser()
+parser.add_argument('--alpha', type=float, default=0.5)
+parser.add_argument('--eps', type=float, default=1e-3)
+parser.add_argument('--topk', type=int, default=32)
+parser.add_argument('--n_train', type=int, default=1000)
+parser.add_argument('--hidden_size', type=int, default=64)
+parser.add_argument('--n_layers', type=int, default=2)
+parser.add_argument('--dropout', type=float, default=0.1)
+parser.add_argument('--lr', type=float, default=0.005)
+parser.add_argument('--n_epochs', type=int, default=50)
+parser.add_argument('--n_power_iters', type=int, default=2)
+parser.add_argument('--frac_predict', type=float, default=0.2)
+args = parser.parse_args()
+
+# Load the Reddit dataset
+s = time.time()
+name = 'Reddit'
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
+data = Reddit(path)[0]
+print(f'Loaded Reddit dataset in {round(time.time() - s, 2)}s')
+
+# Get a random split of train/test nodes (|train| << |test|)
+s = time.time()
+ind = torch.randperm(len(data.x), dtype=torch.long)
+train_idx = ind[:args.n_train]
+test_idx = ind[args.n_train:]
+
+train_edge_index, _ = subgraph(train_idx, data.edge_index, relabel_nodes=True)
+train = Data(x=data.x[train_idx], edge_index=train_edge_index,
+             y=data.y[train_idx])
+print(f'Split data into {len(train_idx)} train and {len(test_idx)} test nodes '
+      f'in {round(time.time() - s, 2)}s')
+
+# Set up the PPR transform via GDC
+s = time.time()
+ppr = GDC(
+    exact=False, normalization_in='sym', normalization_out=None,
+    diffusion_kwargs=dict(
+        method='ppr',
+        alpha=args.alpha,
+        eps=args.eps,
+    ), sparsification_kwargs=dict(method='topk', topk=args.topk))
+train = ppr(train)
+print(f'Ran PPR on {args.n_train} train nodes in {round(time.time() - s, 2)}s')
+
+# Set up the model and optimizer
+model = PPRGo(num_features=data.x.shape[1], num_classes=data.y.max() + 1,
+              hidden_size=args.hidden_size, n_layers=args.n_layers,
+              dropout=args.dropout)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
+
+# Run training as a single batch on the full train graph
+s = time.time()
+print()
+print('=' * 40)
+model.train()
+train = pprgo_prune_features(train)
+for i in range(1, args.n_epochs + 1):
+    optimizer.zero_grad()
+    pred = model(train.x, train.edge_index, train.edge_attr)
+
+    loss = F.cross_entropy(pred, train.y)
+    top1 = torch.argmax(pred, dim=1)
+    acc = (torch.sum(top1 == train.y) / len(train.y)).item()
+
+    loss.backward()
+    optimizer.step()
+    if i % 10 == 0:
+        print(f'Epoch {i}: loss = {round(loss.item(), 3)}, '
+              f'acc = {round(acc, 3)}')
+
+print('=' * 40)
+print(f'Finished training in {round(time.time() - s, 2)}s')
+
+# Run inference on the full graph and report performance
+s = time.time()
+preds = model.predict_power_iter(data.x, data.edge_index, args.n_power_iters,
+                                 args.alpha, args.frac_predict)
+preds = preds.argmax(1)
+labels = data.y
+
+acc_train = 100 * accuracy_score(labels[train_idx], preds[train_idx])
+acc_test = 100 * accuracy_score(labels[test_idx], preds[test_idx])
+
+f1_train = f1_score(labels[train_idx], preds[train_idx], average='macro')
+f1_test = f1_score(labels[test_idx], preds[test_idx], average='macro')
+
+print()
+print('=' * 40)
+print(f'Accuracy ... train: {acc_train:.1f}%, test: {acc_test:.1f}%')
+print(f'F1 score ... train: {f1_train:.3f}, test: {f1_test:.3f}')
+print('=' * 40)
+print(f'Ran inference in {round(time.time() - s, 2)}s')
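Note: the GDC transform used above routes the new `topk` option through `utils.get_ppr`, so the sparsified PPR computation can also be exercised directly. A minimal sketch, assuming `numba` is installed (dataset and hyperparameter values are purely illustrative):

import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import get_ppr

data = KarateClub()[0]
# Keep at most the 8 largest PPR entries per source node:
edge_index, edge_weight = get_ppr(data.edge_index, alpha=0.15, eps=1e-4,
                                  topk=8)
assert edge_index.size(1) <= 8 * data.num_nodes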
diff --git a/test/nn/models/test_pprgo.py b/test/nn/models/test_pprgo.py
new file mode 100644
index 000000000000..013de2209aeb
--- /dev/null
+++ b/test/nn/models/test_pprgo.py
@@ -0,0 +1,69 @@
+import pytest
+import torch
+
+from torch_geometric.datasets import KarateClub
+from torch_geometric.nn.models.pprgo import PPRGo, pprgo_prune_features
+from torch_geometric.testing import withPackage
+
+
+@pytest.mark.parametrize('n_layers', [1, 4])
+@pytest.mark.parametrize('dropout', [0.0, 0.2])
+def test_pprgo_forward(n_layers, dropout):
+    num_nodes = 100
+    num_edges = 500
+    num_features = 64
+    num_classes = 16
+    hidden_size = 64
+
+    # edge_index must have sorted source indices that cover every node
+    # (i.e., every node keeps at least one outgoing edge, as is the case
+    # after top-k sparsification), so that the dimension checks below hold
+    edge_index = torch.stack([
+        torch.arange(0, num_nodes).repeat_interleave(num_edges // num_nodes),
+        torch.randint(0, num_nodes, [num_edges])
+    ], dim=0)
+
+    # Mimic the behavior of pprgo_prune_features manually,
+    # i.e., we expect node_embeds to be |E| x d
+    node_embeds = torch.rand((num_nodes, num_features))
+    node_embeds = node_embeds[edge_index[1], :]
+
+    edge_weight = torch.rand(num_edges)
+
+    model = PPRGo(num_features, num_classes, hidden_size, n_layers, dropout)
+    pred = model(node_embeds, edge_index, edge_weight)
+    assert pred.size() == (num_nodes, num_classes)
+
+
+def test_pprgo_karate():
+    data = KarateClub()[0]
+    num_nodes = data.num_nodes
+
+    data = pprgo_prune_features(data)
+    data.edge_weight = torch.ones((data.edge_index.shape[1], ))
+
+    assert data.x.shape[0] == data.edge_index.shape[1]
+    num_classes = 16
+    model = PPRGo(num_nodes, num_classes, hidden_size=64, n_layers=3,
+                  dropout=0.0)
+    pred = model(data.x, data.edge_index, data.edge_weight)
+    assert pred.shape == (num_nodes, num_classes)
+
+
+@pytest.mark.parametrize('n_power_iters', [1, 3])
+@pytest.mark.parametrize('frac_predict', [1.0, 0.5])
+@pytest.mark.parametrize('batch_size', [1, 9999])
+@withPackage('torch_sparse')
+def test_pprgo_inference(n_power_iters, frac_predict, batch_size):
+    data = KarateClub()[0]
+    num_nodes = data.num_nodes
+
+    data = pprgo_prune_features(data)
+    data.edge_weight = torch.rand(data.edge_index.shape[1])
+
+    num_classes = 16
+    model = PPRGo(num_nodes, num_classes, 64, 3, 0.0)
+    logits = model.predict_power_iter(data.x, data.edge_index, n_power_iters,
+                                      frac_predict=frac_predict,
+                                      batch_size=batch_size)
+    assert torch.all(~torch.isnan(logits))
+    assert logits.shape == (data.x.shape[0], num_classes)
diff --git a/test/utils/test_ppr.py b/test/utils/test_ppr.py
index ca1a268d7735..1262df4649da 100644
--- a/test/utils/test_ppr.py
+++ b/test/utils/test_ppr.py
@@ -8,15 +8,12 @@
 @withPackage('numba')
 @pytest.mark.parametrize('target', [None, torch.tensor([0, 4, 5, 6])])
-def test_get_ppr(target):
+@pytest.mark.parametrize('topk', [None, 1, 3, 7])
+def test_get_ppr(target, topk):
     data = KarateClub()[0]
 
-    edge_index, edge_weight = get_ppr(
-        data.edge_index,
-        alpha=0.1,
-        eps=1e-5,
-        target=target,
-    )
+    edge_index, edge_weight = get_ppr(data.edge_index, alpha=0.1, eps=1e-5,
+                                      target=target, topk=topk)
     assert edge_index.size(0) == 2
     assert edge_index.size(1) == edge_weight.numel()
@@ -26,3 +23,4 @@ def test_get_ppr(target):
     assert edge_index[0].min() == min_row and edge_index[0].max() == max_row
     assert edge_index[1].min() >= 0 and edge_index[1].max() < data.num_nodes
     assert edge_weight.min() >= 0.0 and edge_weight.max() <= 1.0
+    assert len(edge_weight) % (topk or 1) == 0
diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py
index 9ade58cebc05..94220623696a 100644
--- a/torch_geometric/nn/models/__init__.py
+++ b/torch_geometric/nn/models/__init__.py
@@ -32,54 +32,20 @@
 from .git_mol import GITMol
 from .molecule_gpt import MoleculeGPT
 from .glem import GLEM
+from .pprgo import PPRGo, pprgo_prune_features
 
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
 
 __all__ = classes = [
-    'MLP',
-    'GCN',
-    'GraphSAGE',
-    'GIN',
-    'GAT',
-    'PNA',
-    'EdgeCNN',
-    'JumpingKnowledge',
-    'HeteroJumpingKnowledge',
-    'MetaLayer',
-    'Node2Vec',
-    'DeepGraphInfomax',
-    'InnerProductDecoder',
-    'GAE',
-    'VGAE',
-    'ARGA',
-    'ARGVA',
-    'SignedGCN',
-    'RENet',
-    'GraphUNet',
-    'SchNet',
-    'DimeNet',
-    'DimeNetPlusPlus',
-    'to_captum_model',
-    'to_captum_input',
-    'captum_output_to_dicts',
-    'MetaPath2Vec',
-    'DeepGCNLayer',
-    'TGNMemory',
-    'LabelPropagation',
-    'CorrectAndSmooth',
-    'AttentiveFP',
-    'RECT_L',
-    'LINKX',
-    'LightGCN',
-    'MaskLabel',
-    'GroupAddRev',
-    'GNNFF',
-    'PMLP',
-    'NeuralFingerprint',
-    'ViSNet',
-    'GRetriever',
-    'GITMol',
-    'MoleculeGPT',
-    'GLEM',
+    'MLP', 'GCN', 'GraphSAGE', 'GIN', 'GAT', 'PNA', 'EdgeCNN',
+    'JumpingKnowledge', 'HeteroJumpingKnowledge', 'MetaLayer', 'Node2Vec',
+    'DeepGraphInfomax', 'InnerProductDecoder', 'GAE', 'VGAE', 'ARGA', 'ARGVA',
+    'SignedGCN', 'RENet', 'GraphUNet', 'SchNet', 'DimeNet', 'DimeNetPlusPlus',
+    'to_captum_model', 'to_captum_input', 'captum_output_to_dicts',
+    'MetaPath2Vec', 'DeepGCNLayer', 'TGNMemory', 'LabelPropagation',
+    'CorrectAndSmooth', 'AttentiveFP', 'RECT_L', 'LINKX', 'LightGCN',
+    'MaskLabel', 'GroupAddRev', 'GNNFF', 'PMLP', 'NeuralFingerprint', 'ViSNet',
+    'GRetriever', 'GITMol', 'MoleculeGPT', 'GLEM', 'PPRGo',
+    'pprgo_prune_features'
 ]
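The implementation that follows hinges on `pprgo_prune_features` re-indexing the node feature matrix to one row per PPR edge, keyed by each edge's destination node. A toy sketch of that behavior (sizes are made up for illustration):

import torch
from torch_geometric.data import Data
from torch_geometric.nn.models import pprgo_prune_features  # added below

x = torch.eye(4)                           # 4 nodes with one-hot features
edge_index = torch.tensor([[0, 0, 1, 2],   # PPR source nodes
                           [1, 2, 3, 3]])  # PPR destination nodes
data = pprgo_prune_features(Data(x=x, edge_index=edge_index))
assert data.x.shape[0] == edge_index.shape[1]  # one feature row per edge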
diff --git a/torch_geometric/nn/models/pprgo.py b/torch_geometric/nn/models/pprgo.py
new file mode 100644
index 000000000000..052ae1898d29
--- /dev/null
+++ b/torch_geometric/nn/models/pprgo.py
@@ -0,0 +1,225 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from torch_geometric.data import Data
+from torch_geometric.typing import Adj, Tensor
+from torch_geometric.utils import is_sparse, scatter
+
+
+def pprgo_prune_features(data: Data) -> Data:
+    r"""Prunes the node features in :obj:`data` so that the only feature
+    vectors loaded into memory are those of target nodes, as prescribed by
+    :obj:`data.edge_index`. Useful for saving memory during PPRGo training.
+
+    Args:
+        data (Data): The graph object.
+
+    :rtype: :class:`Data`
+    """
+    data.x = data.x[data.edge_index[1], :]
+    return data
+
+
+class PPRGo(nn.Module):
+    r"""The PPRGo model, based on efficient propagation of approximate
+    personalized PageRank vectors from the `"Scaling Graph Neural Networks
+    with Approximate PageRank" <https://arxiv.org/abs/2007.01570>`_ paper.
+
+    Propagates pointwise predictions on node embeddings according to
+    truncated, sparse PageRank vectors. Because this model considers all
+    :math:`K`-hop neighborhoods simultaneously, it needs only a single
+    layer and is fast.
+
+    Prior to training, a sparse approximation of the PPR vectors should be
+    precomputed efficiently via :class:`torch_geometric.transforms.GDC`.
+    This information is expected as :obj:`edge_index` and :obj:`edge_attr`
+    during the forward pass.
+
+    Args:
+        num_features (int): Number of dimensions of the node features.
+        num_classes (int): Number of output classes.
+        hidden_size (int): Number of hidden dimensions.
+        n_layers (int): Number of linear layers for pointwise node
+            projections. Must be at least 1.
+        dropout (float, optional): Node dropout probability for the
+            diffused graph. (default: :obj:`0.0`)
+        **kwargs (optional): Additional arguments of
+            :class:`torch.nn.Module`.
+    """
+    def __init__(self, num_features: int, num_classes: int, hidden_size: int,
+                 n_layers: int, dropout: float = 0.0, **kwargs):
+        super().__init__(**kwargs)
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.hidden_size = hidden_size
+        self.n_layers = n_layers
+        self.dropout = dropout
+
+        # Initialize the MLP for feature computation (our only parameters)
+        assert n_layers >= 1, "Must have at least 1 layer to ensure dims work"
+        if n_layers == 1:
+            fcs = [nn.Linear(num_features, num_classes, bias=False)]
+        else:
+            fcs = [nn.Linear(num_features, hidden_size, bias=False)]
+            for _ in range(n_layers - 2):
+                fcs.append(nn.Linear(hidden_size, hidden_size, bias=False))
+            fcs.append(nn.Linear(hidden_size, num_classes, bias=False))
+
+        self.fcs = nn.ModuleList(fcs)
+        self.drop = nn.Dropout(dropout)
+
+    def _features(self, x: Tensor) -> Tensor:
+        r"""Computes the MLP features on :obj:`x`.
+
+        :rtype: :class:`Tensor`
+        """
+        x = self.fcs[0](self.drop(x))
+        for fc in self.fcs[1:]:
+            x = fc(self.drop(F.relu(x)))
+        return x
+
+    def forward(self, x: Tensor, edge_index: Adj,
+                edge_attr: Tensor) -> Tensor:
+        r"""Computes the forward pass through PPRGo, assuming the PPR edges
+        and scores are precomputed and accessible via :obj:`edge_index` and
+        :obj:`edge_attr`. Note that this expects a truncated node feature
+        matrix, which can be produced by calling
+        :obj:`torch_geometric.nn.models.pprgo.pprgo_prune_features` on the
+        train graph.
+
+        Args:
+            x (Tensor): Truncated node feature matrix of shape
+                :obj:`[|E|, d]`, where :obj:`|E|` is the number of PPR
+                edges. Each row holds the embedding of an edge's
+                destination node.
+            edge_index (Adj): Dense matrix of PPR edges of shape
+                :obj:`[2, |E|]`.
+            edge_attr (Tensor): Vector of PPR scores of shape :obj:`[|E|]`.
+
+        :rtype: :class:`Tensor`
+        """
+        if is_sparse(edge_index):  # pragma: no cover
+            # We expect only the [2, |E|] dense format for now
+            raise ValueError("Sparse tensors not supported yet")
+
+        # First, perform the pointwise feature computation (the logits)
+        x = self._features(x)
+
+        # Next, manually scatter along the correct edges via PPR.
+        # It would be nice to use the MessagePassing base class here,
+        # but because of its safety checks and the fact that x is
+        # pre-indexed during data loading, this is not possible.
+        weighted = x * edge_attr[:, None]
+        src_idx = edge_index[0]
+        dim_size = int(src_idx[-1]) + 1
+
+        # We expect a src_idx in which every node appears as a source node
+        # (in order), since the top-k threshold leaves k > 0 outgoing
+        # edges per node
+        return scatter(weighted, src_idx, dim=0, dim_size=dim_size,
+                       reduce='sum')
+
+    @torch.no_grad()
+    def predict_power_iter(self, x: Tensor, edge_index: Adj,
+                           n_power_iters: int = 1, alpha: float = 0.15,
+                           frac_predict: float = 1.0,
+                           normalization: str = "sym",
+                           batch_size: int = 8192) -> Tensor:
+        r"""Forward pass through PPRGo using power iteration instead of
+        precomputed sparse PPR vectors. Useful for large graphs.
+
+        During inference, we only need to compute predictions for a small
+        set of nodes. These labels are then propagated to the remaining
+        nodes under the assumption of graph homophily. This propagation is
+        expressed as
+
+        .. math::
+            \mathbf{Q}^{(0)} &= \mathbf{H}
+
+            \mathbf{Q}^{(p+1)} &= (1 - \alpha) \mathbf{D}^{-1}
+            \mathbf{A} \mathbf{Q}^{(p)} + \alpha \mathbf{H}
+
+        where :math:`\mathbf{H}` holds the predictions on the reduced set.
+        (With the default :obj:`'sym'` normalization, the symmetric form
+        :math:`\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}` is used in
+        place of :math:`\mathbf{D}^{-1} \mathbf{A}`.)
+
+        Since inference may run on a large graph, the prediction
+        computations are batched via :obj:`batch_size`. The returned
+        logits have shape :obj:`[|V|, num_classes]`.
+
+        Args:
+            x (Tensor): Node features.
+            edge_index (Tensor): Adjacency matrix of shape :obj:`[2, |E|]`.
+            n_power_iters (int, optional): Number of power iterations.
+                (default: :obj:`1`)
+            alpha (float, optional): Teleportation probability.
+                (default: :obj:`0.15`)
+            frac_predict (float, optional): Fraction of nodes to run the
+                feature computation for. All other features are diffused
+                during message propagation but are initially set to 0.
+                (default: :obj:`1.0`)
+            normalization (str, optional): Normalization of
+                :math:`\mathbf{A}` during power iteration. Should match
+                :obj:`normalization_in` in
+                :class:`torch_geometric.transforms.GDC`.
+                For now, only :obj:`'sym'` normalization is supported.
+                (default: :obj:`'sym'`)
+            batch_size (int, optional): Batch size for computing
+                predictions before label propagation.
+                (default: :obj:`8192`)
+
+        :rtype: :class:`Tensor`
+        """
+        assert n_power_iters >= 1, "Number of iterations must be positive"
+
+        # First, sample the subset of nodes to compute predictions for,
+        # according to frac_predict
+        n_nodes = x.shape[0]
+        if frac_predict != 1.0:
+            ind = torch.randperm(n_nodes)[:int(frac_predict * n_nodes)]
+            ind, _ = torch.sort(ind)
+            x = x[ind]
+        else:
+            ind = torch.arange(0, n_nodes)
+
+        # Then, compute logits on the selected nodes (on GPU if possible);
+        # they are later propagated to the non-selected nodes as well
+        # (assuming graph homophily). Since even the number of selected
+        # nodes may be large, batch the computation
+        device = next(self.fcs.parameters()).device
+        train = self.fcs.training
+        if train:
+            self.fcs.eval()
+
+        sele_logits = []
+        for j in range(0, x.shape[0], batch_size):
+            x_batch = x[j:j + batch_size].to(device)
+            preds = self._features(x_batch).cpu()
+            sele_logits.append(preds)
+
+        sele_logits = torch.vstack(sele_logits)
+        if train:
+            self.fcs.train()
+
+        # Set all other logits in the graph to zero; they will get filled
+        # in when we propagate the selected nodes
+        logits_init = torch.zeros((n_nodes, sele_logits.shape[1]),
+                                  dtype=torch.float32)
+        logits_init[ind] = sele_logits.to(torch.float32)
+
+        # Finally, run the power iteration (depending on normalization)
+        try:
+            from torch_sparse import SparseTensor
+        except ImportError:  # pragma: no cover
+            raise ImportError("Cannot find the torch_sparse package, "
+                              "which is needed for inference")
+
+        logits = logits_init.clone()
+        adj = SparseTensor(row=edge_index[0], col=edge_index[1],
+                           sparse_sizes=(n_nodes, n_nodes))
+
+        if normalization == 'sym':
+            # Assume an undirected (symmetric) adjacency matrix.
+            # (In practice, top-k sparsification usually leads to some
+            # rounding errors which slightly violate this symmetry.)
+            denom = torch.maximum(adj.sum(1).flatten(),
+                                  torch.tensor([1e-12]))
+            deg_sqrt_inv = torch.unsqueeze(1. / torch.sqrt(denom), dim=1)
+            for _ in range(n_power_iters):
+                deg_adj_logits = adj @ (deg_sqrt_inv * logits)
+                logits = ((1 - alpha) * deg_sqrt_inv * deg_adj_logits +
+                          alpha * logits_init)
+        else:  # pragma: no cover
+            raise NotImplementedError(f"'{normalization}' norm not implemented")
+
+        return logits
diff --git a/torch_geometric/transforms/gdc.py b/torch_geometric/transforms/gdc.py
index 0b0ecddf0260..cd347b081d5a 100644
--- a/torch_geometric/transforms/gdc.py
+++ b/torch_geometric/transforms/gdc.py
@@ -306,11 +306,9 @@ def diffusion_matrix_approx(  # noqa: D417
             deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum')
 
             edge_index, edge_weight = get_ppr(
-                edge_index,
-                alpha=kwargs['alpha'],
-                eps=kwargs['eps'],
+                edge_index, alpha=kwargs['alpha'], eps=kwargs['eps'],
                 num_nodes=num_nodes,
-            )
+                topk=self.sparsification_kwargs.get('topk'))
 
             if normalization == 'col':
                 edge_index, edge_weight = sort_edge_index(
@@ -456,8 +454,8 @@ def sparsify_sparse(  # noqa: D417
             edge_index = edge_index[:, remaining_edge_idx]
             edge_weight = edge_weight[remaining_edge_idx]
         elif method == 'topk':
-            raise NotImplementedError(
-                'Sparse topk sparsification not implemented')
+            # Top-k is handled directly in the approximate PPR computation
+            pass
         else:
             raise ValueError(f"GDC sparsification '{method}' unknown")
diff --git a/torch_geometric/utils/ppr.py b/torch_geometric/utils/ppr.py
index e6b4043dc303..f91afb7cad13 100644
--- a/torch_geometric/utils/ppr.py
+++ b/torch_geometric/utils/ppr.py
@@ -21,6 +21,7 @@ def _get_ppr(  # pragma: no cover
     alpha: float,
     eps: float,
     target: Optional[np.ndarray] = None,
+    topk: Optional[int] = None,
 ) -> Tuple[List[List[int]], List[List[float]]]:
     num_nodes = len(rowptr) - 1 if target is None else len(target)
@@ -65,8 +66,15 @@ def _get_ppr(  # pragma: no cover
                 if vnode not in q:
                     q.append(vnode)
 
-        js[inode_uint] = list(p.keys())
-        vals[inode_uint] = list(p.values())
+        p_keys = np.array(list(p.keys()))
+        p_vals = np.array(list(p.values()))
+        if topk is not None and len(p_keys) > topk:
+            topk_ind = np.argsort(p_vals)[-topk:]
+            p_keys = p_keys[topk_ind]
+            p_vals = p_vals[topk_ind]
+
+        js[inode_uint] = [k for k in p_keys]
+        vals[inode_uint] = [v for v in p_vals]
 
     return js, vals
 
@@ -74,13 +82,9 @@ def _get_ppr(  # pragma: no cover
 _get_ppr_numba: Optional[Callable] = None
 
 
-def get_ppr(
-    edge_index: Tensor,
-    alpha: float = 0.2,
-    eps: float = 1e-5,
-    target: Optional[Tensor] = None,
-    num_nodes: Optional[int] = None,
-) -> Tuple[Tensor, Tensor]:
+def get_ppr(edge_index: Tensor, alpha: float = 0.2, eps: float = 1e-5,
+            target: Optional[Tensor] = None, num_nodes: Optional[int] = None,
+            topk: Optional[int] = None) -> Tuple[Tensor, Tensor]:
     r"""Calculates the personalized PageRank (PPR) vector for all or a subset
     of nodes using a variant of the `Andersen algorithm`_.
@@ -95,6 +99,8 @@ def get_ppr(
             If not given, calculates PPR vectors for all nodes.
             (default: :obj:`None`)
         num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
+        topk (int, optional): If not :obj:`None`, only store the :obj:`topk`
+            largest entries per PPR vector. (default: :obj:`None`)
 
     :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`)
     """
@@ -112,11 +118,8 @@ def get_ppr(
 
     cols, weights = _get_ppr_numba(
         rowptr.cpu().numpy(),
-        col.cpu().numpy(),
-        alpha,
-        eps,
-        None if target is None else target.cpu().numpy(),
-    )
+        col.cpu().numpy(), alpha, eps,
+        None if target is None else target.cpu().numpy(), topk)
 
     device = edge_index.device
     col = torch.tensor(list(chain.from_iterable(cols)), device=device)
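Putting the pieces together, a minimal end-to-end sketch of the API added in this PR (hyperparameter values are illustrative and untuned; the approximate GDC path needs `numba`, and `predict_power_iter` needs `torch_sparse`):

import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.nn.models import PPRGo, pprgo_prune_features
from torch_geometric.transforms import GDC

raw = KarateClub()[0]  # 34 nodes with 34-dim one-hot features, 4 classes

# Precompute sparse top-k PPR vectors and prune the feature matrix:
gdc = GDC(exact=False, normalization_in='sym', normalization_out=None,
          diffusion_kwargs=dict(method='ppr', alpha=0.15, eps=1e-4),
          sparsification_kwargs=dict(method='topk', topk=8))
train = pprgo_prune_features(gdc(raw.clone()))

model = PPRGo(num_features=34, num_classes=4, hidden_size=32, n_layers=2)
logits = model(train.x, train.edge_index, train.edge_attr)  # [num_nodes, 4]

# Full-graph inference by label propagation over the raw adjacency:
preds = model.predict_power_iter(raw.x, raw.edge_index, n_power_iters=2,
                                 alpha=0.15).argmax(1)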