-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhf_similarity_wic.py
166 lines (134 loc) · 6.38 KB
/
hf_similarity_wic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from transformer_infrastructure.hf_utils import parse_fasta, get_hidden_states, build_index
from transformer_infrastructure.hf_embed import get_embeddings, parse_fasta_for_embed
import pandas as pd
import time
from sentence_transformers import util
from Bio import SeqIO
import pickle
import argparse
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import logging
import faiss
#fasta = '/scratch/gpfs/cmcwhite/quantest2/QuanTest2/Test/zf-CCHH.vie'
def get_seqsim_args(argv=None):
    """Parse command-line arguments for the sequence-similarity script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, which is the original behaviour.

    Returns
    -------
    argparse.Namespace
        Attributes: ``model_name``, ``sequence_path``, ``fasta_path``,
        ``dont_add_spaces``, ``outfile`` and ``k``.
    """
    parser = argparse.ArgumentParser(
        description="Embed sequences and write pairwise nearest-neighbour scores.")
    parser.add_argument("-m", "--model", dest="model_name", type=str, required=True,
                        help="Model directory Ex. /path/to/model_dir")
    # The single-dash long spellings (-sequence_table, -fasta) are kept for
    # backward compatibility; the double-dash aliases follow argparse convention.
    parser.add_argument("-st", "-sequence_table", "--sequence_table",
                        dest="sequence_path", type=str, required=False,
                        help="Path to table of sequences to evaluate in csv (id,sequence) "
                             "no header. Output of utils.parse_fasta")
    parser.add_argument("-f", "-fasta", "--fasta", dest="fasta_path", type=str, required=False,
                        help="Path to fasta of sequences to evaluate")
    parser.add_argument("-n", "--dont_add_spaces", action="store_true",
                        help="Flag if sequences already have spaces")
    parser.add_argument("-o", "--outfile", dest="outfile", type=str, required=True,
                        help="Output file for the tab-separated table of pairwise "
                             "sequence scores (name1, name2, score)")
    parser.add_argument("-k", dest="k", type=int, required=False,
                        help="If present, limit to k closest sequences")
    return parser.parse_args(argv)
def get_sequence_similarity(layers, model_name, seqs, seqs_spaced, seq_names, outfile, logging, k):
    """Embed sequences, search each one's nearest neighbours, and write pair scores.

    Each sequence is embedded with ``get_embeddings``, an index is built over
    the sequence embeddings with ``build_index``, and every sequence is used as
    a query against that index. Each unordered pair of distinct sequences found
    in the neighbour lists is written once to ``outfile``.

    Memory note (from the original author): embedding uses a lot of memory;
    add more CPUs on memory errors (roughly 4 CPUs per 1000 sequences).

    Parameters
    ----------
    layers : list of int
        Hidden-state layers passed through to ``get_embeddings``.
    model_name : str
        Model directory/name passed through to ``get_embeddings``.
    seqs : list of str
        Sequences; only used here to compute per-sequence lengths.
    seqs_spaced : list of str
        Space-separated sequences fed to the model.
    seq_names : list of str
        Sequence identifiers, in the same order as ``seqs_spaced``.
    outfile : str
        Path of the tab-separated output (``name1``, ``name2``, ``score``).
    logging : object
        Anything with an ``info`` method (the ``logging`` module is passed in).
    k : int or None
        Neighbours to retrieve per query. The query itself is normally among
        them and is never written. ``None`` or a value larger than the number
        of sequences means "all sequences".

    Returns
    -------
    int
        Always 1.

    Raises
    ------
    ValueError
        If ``k`` is given and is smaller than 1.
    """
    n_seqs = len(seq_names)
    # Previously k was unconditionally overwritten with n_seqs, so -k was ignored.
    if k is None or k > n_seqs:
        k = n_seqs
    elif k < 1:
        raise ValueError("k must be a positive integer, got {}".format(k))

    seqlens = [len(x) for x in seqs]
    padding = 5
    logging.info("start encoding")
    embedding_dict = get_embeddings(seqs_spaced,
                                    model_name,
                                    seqlens=seqlens,
                                    get_sequence_embeddings=True,
                                    get_aa_embeddings=False,
                                    layers=layers,
                                    padding=padding)
    enc_hidden_states = embedding_dict['sequence_embeddings']
    logging.info(enc_hidden_states.shape)

    logging.info("Start comparison")
    start = time.time()
    index = build_index(enc_hidden_states)
    logging.info("searching {} nearest neighbours per sequence".format(k))
    # scores[i, j] is the index's score for query i vs. its j-th hit neighbors[i, j].
    # NOTE(review): presumably an inner-product (cosine) score if build_index
    # builds an IP index over normalised vectors - confirm in hf_utils.
    scores, neighbors = index.search(enc_hidden_states, k)
    logging.info("compare complete in {} s".format(time.time() - start))

    _write_pairs(outfile, seq_names, scores, neighbors)
    return 1


def _write_pairs(outfile, seq_names, scores, neighbors):
    """Write each unordered (query, neighbour) pair once as ``name1\\tname2\\tD``.

    ``D`` is ``1 - score`` rounded to two decimals. Pairs are deduplicated by
    index, so a pair found only from the higher-index side (possible when
    k < number of sequences) is still written. With a full neighbour list the
    output is the same as writing only pairs whose neighbour has not yet been
    processed as a query.
    """
    written = set()
    with open(outfile, "w") as o:
        for i, row in enumerate(neighbors):
            for j, nb in enumerate(row):
                nb = int(nb)
                # Skip self-matches and faiss' -1 placeholder for missing hits.
                if nb < 0 or nb == i:
                    continue
                pair = (min(i, nb), max(i, nb))
                if pair in written:
                    continue
                written.add(pair)
                # Round after subtracting (on a Python float) so the output has
                # no float noise such as 0.7100000000000001.
                D = round(1 - float(scores[i, j]), 2)
                o.write("{}\t{}\t{}\n".format(seq_names[i], seq_names[nb], D))
if __name__ == '__main__':
    # Known limitation (author's note): embeddings are weak on short sequences
    # without context. Ex. HEIAI vs. HELAI will select terminal I for middle I,
    # instead of the context match L.
    log_format = "%(asctime)s::%(levelname)s::%(name)s::" \
                 "%(filename)s::%(lineno)d::%(message)s"
    logging.basicConfig(filename="seqsim_logger.txt", level='DEBUG', format=log_format)
    logging.info("is this running?")
    args = get_seqsim_args()
    logging.info("load sequences")
    fasta_path = args.fasta_path
    model_name = args.model_name
    outfile = args.outfile
    k = args.k

    # Only fasta input is implemented. Fail with a clear message instead of
    # passing None on to parse_fasta_for_embed.
    if fasta_path is None:
        raise SystemExit("error: a fasta file is required (-f/-fasta); "
                         "-st/-sequence_table input is not supported by this script")

    seq_names, seqs, seqs_spaced = parse_fasta_for_embed(fasta_path, extra_padding=True)

    # Truncate the space-separated sequences to 2*max_length - 2 characters
    # (about max_length - 1 residues, assuming one character per residue plus
    # a separating space).
    # NOTE(review): only seqs_spaced is truncated; seqs (used for seqlens in
    # get_sequence_similarity) keeps its full length - confirm get_embeddings
    # tolerates that mismatch.
    max_length = 1024
    seqs_spaced = [x[:2 * max_length - 2] for x in seqs_spaced]

    # Use the last ten hidden layers.
    layers = [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1]
    get_sequence_similarity(layers, model_name, seqs, seqs_spaced, seq_names, outfile, logging, k)