-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
42 lines (30 loc) · 1.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import argparse
from text_processing import preprocessing
from DepGAT import Dependency_GAT
def tk2onehot(_tk_list):
    """Map each token to a one-hot vector.

    Args:
        _tk_list: sequence of tokens; only its length and ordering matter.

    Returns:
        list[torch.Tensor]: one float tensor of shape (len(_tk_list),) per
        token, with 1.0 at that token's index and 0.0 elsewhere. Empty input
        yields an empty list.
    """
    # torch.eye builds the identity matrix in one call; its rows are exactly
    # the one-hot vectors the original loop constructed one at a time.
    # (This also removes the original's local variable that shadowed the
    # function name.)
    return list(torch.eye(len(_tk_list)))
def main(args):
    """End-to-end demo: preprocess one hard-coded sentence and run it
    through a Dependency_GAT model, printing the raw model output.

    Args:
        args: parsed CLI namespace providing nlp_pipeline, alpha, num_layers.
    """
    text = "My dog likes eating sausage"
    tokens, dep_edges = preprocessing(text, args.nlp_pipeline)

    # Simple one-hot token representations; swap in a real embedding
    # language model here if desired.
    reps = tk2onehot(tokens)

    # NOTE(review): with one-hot inputs, len(tokens) and len(reps[0]) are
    # the same number, so in_dim == out_dim here — confirm the intended
    # dimensions if the representation is replaced by learned embeddings.
    model = Dependency_GAT(
        in_dim=len(tokens),
        out_dim=len(reps[0]),
        alpha=args.alpha,
        num_layers=args.num_layers,
    )
    print(model(reps, dep_edges))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--nlp_pipeline", default="stanza", type=str, help="NLP preprocessing pipeline.")
parser.add_argument("--num_layers", default=1, type=int, help="The number of hidden layers of GCN.")
parser.add_argument("--alpha", default=0.01, type=float, help="Negative slope that controls the angle of the negative slope of LeakyReLU")
args = parser.parse_args()
main(args)