Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for GraphSAINT datasets ('yelp', etc.) #19

Open
hengruizhang98 opened this issue Oct 17, 2023 · 0 comments
Open

Support for GraphSAINT datasets ('yelp', etc.) #19

hengruizhang98 opened this issue Oct 17, 2023 · 0 comments

Comments

@hengruizhang98
Copy link
Contributor

We need to add support for graph datasets with multiple labels (tasks), e.g., the yelp dataset from the GraphSAINT paper.

Below are instructions for downloading the raw datasets and processing them into the DGL format:

pip install gdown
mkdir data
cd data

# download datasets from https://drive.google.com/drive/folders/1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz

# The following commands fail due to permission issues

# gdown --id "1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz" --folder
# gdown --id "1apP2Qn8r6G0jQXykZHyNT6Lz2pgzcQyL" --folder

# You can download the required datasets manually and place them in the data folder

# Use the following command to process the dataset into the DGL format
python process_graphsaint_data.py 
import scipy.sparse as sp
import numpy as np
import networkx as nx
import sys
import json
import os
import dgl
import torch
from networkx.readwrite import json_graph

def to_bidirected(graph):
    """Build a graph containing each non-self-loop edge of ``graph`` in both directions.

    Self-loops are dropped first, then every remaining edge ``(u, v)`` is
    emitted together with its reverse ``(v, u)``. The result is constructed
    from the edge endpoints only, with the same number of nodes as the input.
    No de-duplication is performed, so an edge that is already present in both
    directions appears twice in each direction.

    Args:
        graph: Graph object exposing ``num_nodes()``, ``remove_self_loop()``
            and ``edges()`` (presumably a ``dgl.DGLGraph`` -- confirm with
            callers).

    Returns:
        A new graph created by ``dgl.graph``.
    """
    node_count = graph.num_nodes()

    # Work on a self-loop-free copy; the caller's graph object is untouched.
    heads, tails = graph.remove_self_loop().edges()

    # First half: original direction; second half: reversed direction.
    both_src = torch.cat([heads, tails])
    both_dst = torch.cat([tails, heads])

    return dgl.graph((both_src, both_dst), num_nodes=node_count)

def load_graphsaint_dataset(name, root):
    """Load a GraphSAINT-format dataset as a DGL graph, caching the processed graph.

    On the first call the raw files under ``{root}/{name}/`` are read
    (``adj_full.npz``, ``role.json``, ``feats.npy``, ``class_map.json``),
    assembled into a DGL graph with ``feat``, ``label`` and split-mask node
    data, and saved to ``{root}/{name}/processed/graph.bin``. Later calls load
    that file instead of rebuilding.

    Args:
        name: Dataset directory name under ``root`` (e.g. ``'yelp'``).
        root: Directory that contains the dataset directories.

    Returns:
        A tuple ``(graph, feat, label, num_class, train_idx, val_idx, test_idx)``:

        * ``graph`` -- the bidirected, self-loop-free graph returned by
          ``to_bidirected``; it carries no node data.
        * ``feat`` -- float32 node feature tensor read from ``feats.npy``.
        * ``label`` -- int64 label tensor. It is 1-D when ``class_map.json``
          maps each node to a single value and 2-D when it maps each node to
          a list (the multi-label case, as expected for ``yelp``).
        * ``num_class`` -- 0-d tensor. For 2-D labels it is the number of
          label columns; for 1-D labels it is ``label.max() + 1``.
        * ``train_idx``, ``val_idx``, ``test_idx`` -- 1-D index tensors. Every
          node not listed under ``'va'`` or ``'te'`` in ``role.json`` is
          treated as a training node.

    Raises:
        FileNotFoundError: If a raw file is missing when the cache has to be
            built.
    """
    dataset_dir = f'{root}/{name}/'
    save_dir = f'{root}/{name}/processed'

    if os.path.exists(f'{save_dir}/graph.bin'):
        print('loading saved graph')
        graph = dgl.load_graphs(f'{save_dir}/graph.bin')[0][0]

    else:
        # The directory may survive an earlier run that failed before
        # graph.bin was written; do not crash on it.
        os.makedirs(save_dir, exist_ok=True)

        adj_full = sp.load_npz(dataset_dir + 'adj_full.npz')
        # networkx >= 3.0 removed from_scipy_sparse_matrix in favour of
        # from_scipy_sparse_array (available since 2.7); support both.
        from_sparse = getattr(nx, 'from_scipy_sparse_array', None)
        if from_sparse is None:
            from_sparse = nx.from_scipy_sparse_matrix
        G = from_sparse(adj_full)
        print('nx: finish load graph')

        with open(dataset_dir + 'role.json', 'r') as role_file:
            role = json.load(role_file)
        test_nodes = torch.tensor(role['te'], dtype=torch.long)
        val_nodes = torch.tensor(role['va'], dtype=torch.long)

        num_nodes = G.number_of_nodes()

        # Everything that is neither validation nor test is a training node.
        train_mask = torch.ones(num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(num_nodes, dtype=torch.bool)

        train_mask[test_nodes] = False
        train_mask[val_nodes] = False
        val_mask[val_nodes] = True
        test_mask[test_nodes] = True

        feat = torch.tensor(np.load(dataset_dir + 'feats.npy').astype(np.float32))

        with open(dataset_dir + 'class_map.json', 'r') as class_map_file:
            class_map = json.load(class_map_file)
        # class_map is keyed by the node id as a string; order labels by id.
        label = torch.tensor([class_map[str(i)] for i in range(num_nodes)]).long()

        # reshape(-1, 2) keeps the (E, 2) layout valid even when E == 0.
        edge_index = torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2)
        edge_index = edge_index.t().contiguous()

        graph = dgl.graph((edge_index[0], edge_index[1]), num_nodes=num_nodes)

        graph.ndata['feat'] = feat
        graph.ndata['label'] = label

        graph.ndata['train_mask'] = train_mask
        graph.ndata['val_mask'] = val_mask
        graph.ndata['test_mask'] = test_mask

        dgl.save_graphs(f"{save_dir}/graph.bin", graph)

    feat = graph.ndata.pop('feat')
    label = graph.ndata.pop('label')

    train_idx = torch.nonzero(graph.ndata['train_mask'], as_tuple=True)[0]
    val_idx = torch.nonzero(graph.ndata['val_mask'], as_tuple=True)[0]
    test_idx = torch.nonzero(graph.ndata['test_mask'], as_tuple=True)[0]

    print(f'Total nodes: {graph.num_nodes()}, Train nodes: {len(train_idx)}, Val nodes: {len(val_idx)}, Test nodes: {len(test_idx)}')

    if label.dim() > 1:
        # Multi-label: each column is one label/task, so max() + 1 would
        # wrongly report 2 for a 0/1 indicator matrix.
        num_class = torch.tensor(label.shape[1])
    else:
        num_class = label.max() + 1

    graph = to_bidirected(graph)

    return (graph, feat, label, num_class, train_idx, val_idx, test_idx)

if __name__ == '__main__':
    # Script entry point: process each listed dataset found under ./data.
    # Return values are discarded; the call is made for the cached
    # processed/graph.bin that load_graphsaint_dataset writes.
    data_root = 'data'
    dataset_names = ('yelp',)

    for dataset_name in dataset_names:
        load_graphsaint_dataset(dataset_name, data_root)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant