From 9f45e12b7c07dd6ad94503804e88fbbb1a9c6b7e Mon Sep 17 00:00:00 2001
From: elilaird
Date: Mon, 14 Oct 2024 13:24:51 -0500
Subject: [PATCH] updated unit conversions for y

---
Notes (review; this region is ignored by `git am`):
    * Line structure was restored from a copy in which the whole patch had
      been collapsed onto one line. Blank context lines were re-inserted so
      that each hunk matches its @@ line counts; their exact positions could
      not be recovered -- run `git apply --check` before applying.
    * `conversion` holds 16 per-column factors and is multiplied
      elementwise with `row.iloc[2:]`. The column count/order of the CSV is
      not visible in this patch -- TODO confirm they line up.
    * `y * conversion.view(1, -1)` broadcasts, so a 1-D `y` is stored with a
      leading dimension of size 1 -- confirm consumers expect that shape.
    * `x2` no longer includes `atomic_numbers`, so `x` has one column fewer;
      confirm nothing downstream relies on the previous feature width.

 torch_geometric/datasets/qm40.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/torch_geometric/datasets/qm40.py b/torch_geometric/datasets/qm40.py
index 1d3dbc07b8b9..2d6a7340b716 100644
--- a/torch_geometric/datasets/qm40.py
+++ b/torch_geometric/datasets/qm40.py
@@ -22,6 +22,27 @@
 HAR2EV = 27.211386246
 KCALMOL2EV = 0.04336414
 
+conversion = torch.tensor(
+    [
+        HAR2EV,
+        HAR2EV,
+        HAR2EV,
+        HAR2EV,
+        1.0,
+        1.0,
+        1.0,
+        KCALMOL2EV,
+        1.0,
+        1.0,
+        1.0,
+        HAR2EV,
+        HAR2EV,
+        HAR2EV,
+        KCALMOL2EV,
+        KCALMOL2EV,
+    ]
+)
+
 
 class QM40(InMemoryDataset):
     raw_url = None
@@ -118,7 +139,7 @@ def process(self) -> None:
         for mol_idx, row in tqdm(main_df.iterrows(), total=len(main_df)):
             ID = row['Zinc_id']
             SMILES = row['smile']
-            y = row.iloc[2:].values.astype(np.float32)
+            y = torch.tensor(row.iloc[2:].values.astype(np.float32), dtype=torch.float)
 
             mol_xyz = xyz_df_grouped.get_group(ID).reset_index(drop=True)
             mol_bonds = bond_df_grouped.get_group(ID).reset_index(drop=True)
@@ -171,12 +192,12 @@ def process(self) -> None:
 
             # Create node features
             x1 = one_hot(torch.tensor(type_idx), num_classes=len(types))
-            x2 = torch.tensor(np.array([atomic_numbers, aromatic, sp, sp2, sp3, num_hs]), dtype=torch.float).t().contiguous()
+            x2 = torch.tensor(np.array([aromatic, sp, sp2, sp3, num_hs]), dtype=torch.float).t().contiguous()
             x = torch.cat([x1, x2], dim=-1)
 
             data = Data(
                 x=x, z=z, pos=pos, edge_index=edge_index, smiles=SMILES,
-                edge_attr=edge_attr, edge_attr2=edge_attr2, y=y, name=ID, idx=mol_idx,
+                edge_attr=edge_attr, edge_attr2=edge_attr2, y=y * conversion.view(1, -1), name=ID, idx=mol_idx,
             )
 
             if self.pre_filter is not None and not self.pre_filter(data):