-
Notifications
You must be signed in to change notification settings - Fork 3
/
semanticSimFunctions.py
157 lines (113 loc) · 6.51 KB
/
semanticSimFunctions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# https://huggingface.co/michiyasunaga/BioLinkBERT-large
# https://medium.com/@adriensieg/text-similarities-da019229c894
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only positions the attention mask keeps.

    model_output: model result whose first element holds the per-token
        embeddings (presumably shaped (batch, seq_len, hidden) -- confirm
        against the model in use).
    attention_mask: tensor whose nonzero entries mark tokens to include
        (presumably shaped (batch, seq_len)).

    Returns one pooled vector per sequence. The token count is clamped to
    1e-9 so a fully masked sequence yields zeros rather than dividing by zero.
    """
    token_vectors = model_output[0]
    # Broadcast the mask over the hidden dimension so padded tokens get weight 0.
    weights = attention_mask.unsqueeze(-1).expand(token_vectors.size()).float()
    summed = (token_vectors * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def getSentenceEmbedding(sentence, tokenizer, model):
    """Return the mean-pooled embedding of `sentence`.

    sentence: text (or list of texts) handed straight to `tokenizer`.
    tokenizer: callable accepting padding/truncation/return_tensors='pt'
        keyword arguments and returning a mapping that includes
        'attention_mask' (presumably a Hugging Face tokenizer -- confirm).
    model: callable invoked with the tokenizer's output as keyword arguments.

    The forward pass runs under torch.no_grad(), so the result carries no
    autograd history. Pooling is delegated to mean_pooling.
    """
    encoded = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**encoded)
    return mean_pooling(outputs, encoded['attention_mask'])
def getSentenceSimilarity(sentence1, sentence2, tokenizer, model, simMetric):
    """Embed two sentences and score their similarity.

    sentence1, sentence2: texts to compare.
    tokenizer, model: forwarded unchanged to getSentenceEmbedding.
    simMetric: name of the similarity metric. Only "cosine_similarity" is
        currently supported.

    Returns (similarity, sentence1_embedding, sentence2_embedding).

    Raises ValueError if simMetric is not supported.
    """
    # Validate before running the model twice, so a typo fails fast and cheaply
    # (otherwise the similarity variable would simply never be assigned).
    if simMetric != "cosine_similarity":
        raise ValueError(
            f"Unsupported simMetric {simMetric!r}; expected 'cosine_similarity'"
        )
    # TODO: add other similarity metrics here.
    sentence1_embedding = getSentenceEmbedding(sentence1, tokenizer, model)
    sentence2_embedding = getSentenceEmbedding(sentence2, tokenizer, model)
    # cosine_similarity returns an (n1, n2) matrix; [0][0] is the score between
    # the first rows (presumably the only rows when each input is a single string).
    sentenceSim = cosine_similarity(sentence1_embedding, sentence2_embedding)[0][0]
    return sentenceSim, sentence1_embedding, sentence2_embedding
def getNameSimilarities_noExpertName(names_DF, LLM_name_col, GO_name_col, tokenizer, model, simMetric, epsilon= 0.05):
    """Score each row's LLM-generated name against its GO term name.

    names_DF: data frame with columns containing the names from various sources
        (each row is a different gene set). It is modified in place: the column
        'LLM_name_GO_term_sim' is added (or overwritten).
    LLM_name_col, GO_name_col: strings naming the columns that hold the LLM name
        and the GO term name.
    tokenizer, model, simMetric: forwarded to getSentenceSimilarity.
    epsilon: unused here; kept so existing callers that pass it keep working.

    Returns names_DF (the same object that was passed in).
    """
    sims = []
    name_pairs = zip(names_DF[LLM_name_col], names_DF[GO_name_col])
    for systemInd, (LLM_name, GO_term) in enumerate(name_pairs):
        print(systemInd)  # progress indicator
        LLM_name_GO_term_sim, _, _ = getSentenceSimilarity(LLM_name, GO_term,
                                                           tokenizer, model,
                                                           simMetric)
        sims.append(LLM_name_GO_term_sim)
    # Assign positionally rather than via .loc[row_counter, ...]: .loc is
    # label-based, so with a non-default index it would hit the wrong rows
    # (or append new ones).
    names_DF['LLM_name_GO_term_sim'] = sims
    return names_DF
def getNameSimilarities_no_repeat(names_DF, LLM_name_col, GO_name_col, tokenizer, model, llm_name_embedding_dict = None,
                                  go_term_embedding_dict = None, simMetric = 'cosine_similarity', epsilon= 0.05):
    """Score each row's LLM name against its GO term, embedding each distinct name once.

    names_DF: data frame with columns containing the names from various sources
        (each row is a different gene set). It is not modified.
    LLM_name_col, GO_name_col: strings naming the columns that hold the LLM name
        and the GO term name.
    tokenizer, model: forwarded to getSentenceEmbedding.
    llm_name_embedding_dict, go_term_embedding_dict: optional caches mapping a
        name to its embedding. Pass the dicts returned by an earlier call (made
        with the same tokenizer/model) to reuse embeddings. They are updated in
        place. When omitted, a fresh empty cache is created for this call.
    simMetric: only 'cosine_similarity' is supported.
    epsilon: unused; kept for signature compatibility.

    Returns (df, llm_name_embedding_dict, go_term_embedding_dict), where df is a
    new DataFrame with a 0..n-1 index and an 'LLM_name_GO_term_sim' column.

    Raises ValueError for an unsupported simMetric.
    """
    if simMetric != 'cosine_similarity':
        raise ValueError(
            f"Unsupported simMetric {simMetric!r}; expected 'cosine_similarity'"
        )
    # Fresh caches per call: a shared mutable default would leak embeddings
    # between calls, even across different models.
    if llm_name_embedding_dict is None:
        llm_name_embedding_dict = {}
    if go_term_embedding_dict is None:
        go_term_embedding_dict = {}

    def _cached_embedding(name, cache):
        # Run the model only on a cache miss; otherwise reuse the stored embedding.
        if name not in cache:
            cache[name] = getSentenceEmbedding(name, tokenizer, model)
        return cache[name]

    # reset_index returns a new frame, so adding the column below does not
    # touch the caller's DataFrame.
    names_DF = names_DF.reset_index(drop = True)
    sims = []
    name_pairs = zip(names_DF[LLM_name_col], names_DF[GO_name_col])
    for systemInd, (LLM_name, GO_term) in enumerate(name_pairs):
        print(systemInd)  # progress indicator
        LLM_name_embedding = _cached_embedding(LLM_name, llm_name_embedding_dict)
        GO_term_embedding = _cached_embedding(GO_term, go_term_embedding_dict)
        sims.append(cosine_similarity(LLM_name_embedding, GO_term_embedding)[0][0])
    names_DF['LLM_name_GO_term_sim'] = sims
    return names_DF, llm_name_embedding_dict, go_term_embedding_dict
def _pick_winner(llm_sim, go_sim, epsilon, min_sim):
    """Label which name is closer to the human name: "Neither", "Tied", "LLM", "GO", or None."""
    if (go_sim < min_sim) and (llm_sim < min_sim):
        return "Neither"
    if abs(go_sim - llm_sim) <= epsilon:
        return "Tied"
    if llm_sim > go_sim:
        return "LLM"
    if go_sim > llm_sim:
        return "GO"
    # Only reachable when a score is NaN (every comparison above is then False).
    print("Impossible!")
    return None


def getNameSimilarities(names_DF, LLM_name_col, GO_name_col, human_name_col, tokenizer, model, simMetric, epsilon= 0.05,
                        min_sim=0.4):
    """Compare the LLM name and the GO term to the human (expert) name for every row.

    names_DF: data frame with columns containing the names from various sources
        (each row is a different gene set). It is modified in place: the columns
        'LLM_name_human_name_sim', 'GO_term_human_name_sim' and 'winner' are
        added (or overwritten).
    LLM_name_col, GO_name_col, human_name_col: strings naming the columns that
        hold the LLM name, the GO term and the human name.
    tokenizer, model, simMetric: forwarded to getSentenceSimilarity.
    epsilon: score differences no larger than this are labelled "Tied".
    min_sim: if both scores fall below this, the row is labelled "Neither"
        (defaults to the previously hard-coded 0.4).

    Returns names_DF (the same object that was passed in).
    """
    LLM_sims, GO_sims, winners = [], [], []
    rows = zip(names_DF[LLM_name_col], names_DF[human_name_col], names_DF[GO_name_col])
    for systemInd, (LLM_name, human_name, GO_term) in enumerate(rows):
        print(systemInd)  # progress indicator
        # NOTE(review): the human name is embedded once per comparison (twice per
        # row); caching its embedding would save one model call per row.
        LLM_name_human_name_sim, _, _ = getSentenceSimilarity(LLM_name, human_name,
                                                              tokenizer, model,
                                                              simMetric)
        GO_term_human_name_sim, _, _ = getSentenceSimilarity(GO_term, human_name,
                                                             tokenizer, model,
                                                             simMetric)
        LLM_sims.append(LLM_name_human_name_sim)
        GO_sims.append(GO_term_human_name_sim)
        winners.append(_pick_winner(LLM_name_human_name_sim, GO_term_human_name_sim,
                                    epsilon, min_sim))
    # Assign positionally rather than via .loc[row_counter, ...]: .loc is
    # label-based, so with a non-default index it would hit the wrong rows
    # (or append new ones).
    names_DF['LLM_name_human_name_sim'] = LLM_sims
    names_DF['GO_term_human_name_sim'] = GO_sims
    names_DF['winner'] = winners
    return names_DF