-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpca_viz.py
62 lines (52 loc) · 1.71 KB
/
pca_viz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw
from sklearn.decomposition import PCA
from tdc import Oracle
# Module-level TDC oracles, built once at import time and shared by the
# check_* helpers below. Each is looked up by its TDC oracle name.
GSK3B_scorer = Oracle(name='GSK3B')
SA_scorer = Oracle(name='SA')
DRD2_scorer = Oracle(name='DRD2')
JNK3_scorer = Oracle(name='JNK3')
def check_SA(gen_smiles):
    """Return the SA oracle score for a molecule.

    ``gen_smiles`` is handed to ``Chem.MolToSmiles`` first, so despite its
    name it is presumably an RDKit molecule object rather than a SMILES
    string -- confirm against callers.
    """
    smiles = Chem.MolToSmiles(gen_smiles)
    return SA_scorer(smiles)
def check_DRD2(gen_smiles):
    """Return the DRD2 oracle score for a molecule.

    ``gen_smiles`` is handed to ``Chem.MolToSmiles`` first, so despite its
    name it is presumably an RDKit molecule object rather than a SMILES
    string -- confirm against callers.
    """
    smiles = Chem.MolToSmiles(gen_smiles)
    return DRD2_scorer(smiles)
def check_JNK3(gen_smiles):
    """Return the JNK3 oracle score for a molecule.

    ``gen_smiles`` is handed to ``Chem.MolToSmiles`` first, so despite its
    name it is presumably an RDKit molecule object rather than a SMILES
    string -- confirm against callers.
    """
    smiles = Chem.MolToSmiles(gen_smiles)
    return JNK3_scorer(smiles)
def check_GSK3B(gen_smiles):
    """Return the GSK3B oracle score for a molecule.

    ``gen_smiles`` is handed to ``Chem.MolToSmiles`` first, so despite its
    name it is presumably an RDKit molecule object rather than a SMILES
    string -- confirm against callers.
    """
    smiles = Chem.MolToSmiles(gen_smiles)
    return GSK3B_scorer(smiles)
def cache_prop_pred():
    """Build the table mapping property names to their scoring callables.

    The table starts with every ``(name, function)`` pair listed in
    ``Descriptors.descList`` and is then extended with the four oracle-backed
    checks under the keys ``'sa'``, ``'drd2'``, ``'jnk3'`` and ``'gsk3b'``.

    Returns:
        dict: property name -> callable, in insertion order (descriptors
        first, then the oracle checks).
    """
    # descList is iterated as (name, function) pairs, so dict() consumes it
    # directly; later duplicates overwrite earlier ones just as a loop would.
    table = dict(Descriptors.descList)
    table.update(
        sa=check_SA,
        drd2=check_DRD2,
        jnk3=check_JNK3,
        gsk3b=check_GSK3B,
    )
    return table
# --- Script: project saved latents to 2-D with PCA and export one scatter
# --- plot per property, colored by that property's values.

prop_pred = cache_prop_pred()
# Property names, in the insertion order of the predictor table.
prop_name = np.array(list(prop_pred))

dataset = 'zinc250k'
latent = np.load(f'./saved_latent/{dataset}_z.npy')
prop = np.load(f'./saved_latent/{dataset}_props.npy')
# NOTE(review): names are derived from cache_prop_pred() rather than loaded
# from disk; this assumes row i of `prop` corresponds to prop_name[i] and that
# each row holds one value per latent sample -- confirm against the code that
# wrote the .npy files.
# prop_name = np.load(f'./saved_latent/{dataset}_prop_name.npy')
print(latent.shape, prop.shape, prop_name.shape)

n_components = 2
pca = PCA(n_components=n_components)
components = pca.fit_transform(latent)
print(components.shape)

# The export directory is not created anywhere else; without this the first
# write_image call fails on a fresh checkout.
out_dir = f'./{dataset}_pca_prop'
os.makedirs(out_dir, exist_ok=True)

for i in range(prop.shape[0]):
    fig = px.scatter(
        components,
        color=prop[i],
        x=0, y=1,  # column indices of the two PCA components
        title=f'{dataset} {prop_name[i]}',
    )
    fig.write_image(f'{out_dir}/{dataset}_{prop_name[i]}_pca.png')