-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
134 lines (108 loc) · 5.03 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
'''
adapted from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/io/tu.html#read_tu_data
'''
import os
import os.path as osp
import glob
import torch
import torch.nn.functional as F
import numpy as np
from torch_sparse import coalesce
from torch_geometric.io import read_txt_array
from torch_geometric.utils import remove_self_loops
from torch_geometric.data import Data
# Default set of per-dataset text files that `read_cell_data` loads; each entry
# corresponds to a file named '<prefix>_<name>.txt' inside the dataset folder.
NAMES = ['A', 'graph_indicator', 'edge_attributes', 'graph_labels', 'graph_attributes','node_labels']
# 'edge_labels' and 'node_attributes' are deliberately left out of the default
# list: `read_cell_data` only reads a file when its name is present in `names`,
# so those two optional inputs are skipped unless the caller passes them in.
# NOTE: because 'graph_labels' is listed, `read_cell_data` takes the
# classification branch and never reads 'graph_attributes' by default.
def parse_node_txt_array(src, sep=None, start=0, end=None, dtype=None, device=None):
    """Parse a text file of bracketed rows such as ``[ 80 80 111]`` into a tensor.

    Each non-blank line is one row. An optional leading ``[`` and trailing
    ``]`` are removed, the remainder is split on ``sep`` and the tokens in
    ``[start:end]`` that consist only of digits are converted to numbers.

    Args:
        src: Path of the text file to read.
        sep: Token separator handed to ``str.split`` (``None`` splits on any
            run of whitespace).
        start: Index of the first token of every row to keep.
        end: Index one past the last token to keep (``None`` keeps the rest).
        dtype: Optional ``torch.dtype`` of the returned tensor.
        device: Optional device on which the returned tensor is created.

    Returns:
        A tensor with one row per non-blank line, with size-1 dimensions
        squeezed away (a single-column file yields a 1-D tensor).

    Raises:
        ValueError: If the rows do not all contain the same number of values
            (raised by ``torch.tensor``).
    """
    with open(src, 'r') as f:
        # splitlines() keeps the final row even when the file does not end
        # with a newline; split('\n')[:-1] used to drop it silently.
        lines = f.read().splitlines()

    rows = []
    for line in lines:
        line = line.strip()
        if not line:
            # Blank (e.g. trailing) lines carry no row and would otherwise
            # produce a ragged list that torch.tensor rejects.
            continue
        # Remove the brackets only when they are really there, so stray
        # whitespace or bracket-less rows do not lose real characters.
        if line.startswith('['):
            line = line[1:]
        if line.endswith(']'):
            line = line[:-1]
        tokens = (tok.strip() for tok in line.split(sep)[start:end])
        # Only plain non-negative integer tokens are kept; anything else
        # (empty strings, signs, decimals) is skipped, as before.
        rows.append([float(tok) for tok in tokens if tok.isdigit()])

    return torch.tensor(rows, dtype=dtype, device=device).squeeze()
def _one_hot_columns(labels):
    """One-hot encode every column of an integer label tensor and concatenate them."""
    if labels.dim() == 1:
        labels = labels.unsqueeze(-1)
    # Shift each column so that its smallest label becomes class 0.
    labels = labels - labels.min(dim=0)[0]
    encoded = [F.one_hot(column, num_classes=-1) for column in labels.unbind(dim=-1)]
    return torch.cat(encoded, dim=-1).to(torch.float)


def read_cell_data(folder, prefix, names=NAMES):
    """Load a graph dataset stored as '<prefix>_<name>.txt' files in ``folder``.

    The edge list ('A') and the node-to-graph assignment ('graph_indicator')
    are always read. Every other file is read only if its name appears in
    ``names``.

    Args:
        folder: Directory containing the dataset text files.
        prefix: Common file-name prefix of the dataset.
        names: Collection of file names (without prefix/extension) to load.

    Returns:
        The ``(data, slices)`` pair produced by ``split``: one ``Data`` object
        holding all graphs, and the per-attribute slice boundaries.
    """
    print('Loading Edge Index Data...')
    # Both files are shifted by one, presumably because the on-disk indices
    # are 1-based (as in the TU dataset format this loader was adapted from).
    edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
    batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1

    # Node features: continuous attributes followed by one-hot node labels.
    node_attributes = node_labels = None
    if 'node_attributes' in names:
        print('Loading Node Attribute Data...')
        node_attributes = read_file(folder, prefix, 'node_attributes')
    if 'node_labels' in names:
        print('Loading Node Labels Data...')
        node_labels = _one_hot_columns(
            read_node_label_file(folder, prefix, 'node_labels', torch.long))
    x = cat([node_attributes, node_labels])

    # Edge features: continuous attributes followed by one-hot edge labels.
    edge_attributes, edge_labels = None, None
    if 'edge_attributes' in names:
        print('Loading Edge Attribute Data...')
        edge_attributes = read_file(folder, prefix, 'edge_attributes')
    if 'edge_labels' in names:
        print('Loading Edge Label Data...')
        edge_labels = _one_hot_columns(
            read_file(folder, prefix, 'edge_labels', torch.long))
    edge_attr = cat([edge_attributes, edge_labels])

    y = None
    if 'graph_labels' in names:  # Classification problem.
        print('Loading Graph Label Data...')
        y = read_file(folder, prefix, 'graph_labels', torch.long)
        # Re-index the labels to consecutive integers starting at 0.
        _, y = y.unique(sorted=True, return_inverse=True)
    elif 'graph_attributes' in names:  # Regression problem.
        print('Loading Graph Attribute Data...')
        y = read_file(folder, prefix, 'graph_attributes')

    # Without node features the node count is inferred from the largest
    # node index that appears in an edge.
    num_nodes = edge_index.max().item() + 1 if x is None else x.size(0)
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
                                     num_nodes)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    data, slices = split(data, batch)
    return data, slices
def read_file(folder, prefix, name, dtype=None):
    """Read the comma-separated file '<prefix>_<name>.txt' in ``folder`` as a tensor."""
    filename = f'{prefix}_{name}.txt'
    return read_txt_array(osp.join(folder, filename), sep=',', dtype=dtype)
def read_node_label_file(folder, prefix, name, dtype=None):
    """Read the bracketed node-label file '<prefix>_<name>.txt' in ``folder``.

    Unlike ``read_file``, the rows are parsed with ``parse_node_txt_array``.
    """
    label_path = osp.join(folder, f'{prefix}_{name}.txt')
    return parse_node_txt_array(label_path, dtype=dtype)
def cat(seq):
    """Concatenate the non-``None`` tensors of ``seq`` along the last dimension.

    1-D tensors are first turned into single-column 2-D tensors. Returns
    ``None`` when ``seq`` contains no tensor at all.
    """
    columns = []
    for tensor in seq:
        if tensor is None:
            continue
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(-1)
        columns.append(tensor)

    if not columns:
        return None
    return torch.cat(columns, dim=-1)
def split(data, batch):
    """Compute per-graph slice boundaries for a batch of concatenated graphs.

    Args:
        data: Object holding the concatenated graphs through its
            ``edge_index``, ``x``, ``edge_attr`` and ``y`` attributes
            (``read_cell_data`` passes a ``torch_geometric`` ``Data``).
            It is modified in place.
        batch: 1-D long tensor assigning every node to its graph index.
            Nodes of the same graph are assumed to be stored contiguously
            and in increasing graph order.

    Returns:
        ``(data, slices)`` where ``data.edge_index`` has been shifted so node
        indices restart at zero in every graph, and ``slices`` maps each
        present attribute name to a tensor of cumulative boundaries.
    """
    node_counts = torch.bincount(batch)
    num_graphs = node_counts.numel()
    zero = torch.tensor([0])
    node_slice = torch.cat([zero, torch.cumsum(node_counts, 0)])

    row, _ = data.edge_index
    # minlength keeps one entry per graph even when the last graph(s) have no
    # edges; otherwise the edge slices end up shorter than the node slices
    # and those graphs cannot be sliced out.
    edge_counts = torch.bincount(batch[row], minlength=num_graphs)
    edge_slice = torch.cat([zero, torch.cumsum(edge_counts, 0)])

    # Edge indices should start at zero for every graph.
    data.edge_index -= node_slice[batch[row]].unsqueeze(0)
    data.__num_nodes__ = node_counts.tolist()

    slices = {'edge_index': edge_slice}
    if data.x is not None:
        slices['x'] = node_slice
    if data.edge_attr is not None:
        slices['edge_attr'] = edge_slice
    if data.y is not None:
        if data.y.size(0) == batch.size(0):
            # One target per node: slice it like the node features.
            slices['y'] = node_slice
        else:
            # One target per graph.
            slices['y'] = torch.arange(0, num_graphs + 1, dtype=torch.long)
    return data, slices