diff --git a/torch_geometric/datasets/wikipedia_network.py b/torch_geometric/datasets/wikipedia_network.py index 726f389eed96..9b0a4804fa05 100644 --- a/torch_geometric/datasets/wikipedia_network.py +++ b/torch_geometric/datasets/wikipedia_network.py @@ -27,6 +27,9 @@ class WikipediaNetwork(InMemoryDataset): into five categories to predict. If set to :obj:`True`, the dataset :obj:`"crocodile"` is not available. + If set to :obj:`True`, train/validation/test splits will be + available as masks for multiple splits with shape + :obj:`[num_nodes, num_splits]`. (default: :obj:`True`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. @@ -42,9 +45,14 @@ class WikipediaNetwork(InMemoryDataset): processed_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/' 'geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f') - def __init__(self, root: str, name: str, geom_gcn_preprocess: bool = True, - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): + def __init__( + self, + root: str, + name: str, + geom_gcn_preprocess: bool = True, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + ): self.name = name.lower() self.geom_gcn_preprocess = geom_gcn_preprocess assert self.name in ['chameleon', 'crocodile', 'squirrel']