diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..4452eb05 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/MAT"] + path = third_party/MAT + url = https://github.com/ardigen/MAT diff --git a/fs_mol/data/mat.py b/fs_mol/data/mat.py new file mode 100644 index 00000000..a547b3f2 --- /dev/null +++ b/fs_mol/data/mat.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch +from dpu_utils.utils import RichPath + +from fs_mol.data import ( + FSMolBatcher, + FSMolTask, + MoleculeDatapoint, + default_reader_fn, +) + +# Assumes that MAT is in the python lib path: +from featurization.data_utils import construct_dataset, load_data_from_smiles, mol_collate_func + + +@dataclass(frozen=True) +class FSMolMATBatch: + node_features: torch.Tensor + adjacency_matrix: torch.Tensor + distance_matrix: torch.Tensor + + +@dataclass(frozen=True) +class MATMoleculeDatapoint(MoleculeDatapoint): + mat_features: np.ndarray + + +def mat_process_samples(samples: List[MoleculeDatapoint]) -> List[MATMoleculeDatapoint]: + # Set `one_hot_formal_charge` for compatibility with pretrained weights (see README.md in MAT). + all_features, _ = load_data_from_smiles( + x_smiles=[sample.smiles for sample in samples], + labels=[sample.bool_label for sample in samples], + one_hot_formal_charge=True, + ) + + # MAT can internally decide that there is something wrong with a sample and reject it. Our + # dataset is clean, so this shouldn't happen (or at least shouldn't happen silently!). + if len(all_features) < len(samples): + raise ValueError("MAT rejected some samples; can't continue, as that may skew results.") + + # Note that `sample.__dict__` is almost like `dataclasses.asdict(sample)`, but shallow, i.e. it + # doesn't dict-ify the inner dataclass describing molecular graph. 
+ return [ + MATMoleculeDatapoint(mat_features=features, **sample.__dict__) + for sample, features in zip(samples, all_features) + ] + + +def mat_batcher_init_fn(batch_data: Dict[str, Any]): + batch_data["mat_features"] = [] + + +def mat_batcher_add_sample_fn( + batch_data: Dict[str, Any], sample_id: int, sample: MATMoleculeDatapoint +): + batch_data["mat_features"].append(sample.mat_features) + + +def mat_batcher_finalizer_fn(batch_data: Dict[str, Any]) -> Tuple[FSMolMATBatch, np.ndarray]: + adjacency_matrix, node_features, distance_matrix, labels = mol_collate_func( + construct_dataset( + batch_data["mat_features"], [[label] for label in batch_data["bool_labels"]] + ) + ) + + batch = FSMolMATBatch( + node_features=node_features, + adjacency_matrix=adjacency_matrix, + distance_matrix=distance_matrix, + ) + + return batch, labels.squeeze(dim=-1).cpu().detach().numpy() + + +def mat_task_reader_fn(paths: List[RichPath], idx: int) -> List[FSMolTask]: + [task] = default_reader_fn(paths, idx) + return [FSMolTask(name=task.name, samples=mat_process_samples(task.samples))] + + +def get_mat_batcher(max_num_graphs: int): + return FSMolBatcher( + max_num_graphs=max_num_graphs, + init_callback=mat_batcher_init_fn, + per_datapoint_callback=mat_batcher_add_sample_fn, + finalizer_callback=mat_batcher_finalizer_fn, + ) diff --git a/fs_mol/mat_test.py b/fs_mol/mat_test.py index 43bb7557..fcf44cb8 100644 --- a/fs_mol/mat_test.py +++ b/fs_mol/mat_test.py @@ -2,47 +2,33 @@ import argparse import logging +import os import sys import warnings -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional -import numpy as np import torch -from dpu_utils.utils import RichPath from rdkit import RDLogger from pyprojroot import here as project_root -sys.path.insert(0, "./MAT/src") sys.path.insert(0, str(project_root())) +sys.path.insert(0, os.path.join(str(project_root()), "third_party", "MAT", "src")) -from fs_mol.data import 
( - FSMolBatcher, - FSMolTask, - FSMolTaskSample, - MoleculeDatapoint, - default_reader_fn, -) +from fs_mol.data import FSMolTaskSample +from fs_mol.data.mat import FSMolMATBatch, get_mat_batcher, mat_task_reader_fn from fs_mol.models.interface import AbstractTorchModel from fs_mol.multitask_train import eval_model_by_finetuning_on_task from fs_mol.utils.metrics import BinaryEvalMetrics from fs_mol.utils.multitask_utils import resolve_starting_model_file from fs_mol.utils.test_utils import add_eval_cli_args, eval_model, set_up_test_run -from featurization.data_utils import construct_dataset, load_data_from_smiles, mol_collate_func +# Assumes that MAT is in the python lib path: from transformer import GraphTransformer, make_model logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class FSMolMATBatch: - node_features: torch.Tensor - adjacency_matrix: torch.Tensor - distance_matrix: torch.Tensor - - class MATModel(GraphTransformer, AbstractTorchModel[FSMolMATBatch]): def forward(self, batch: FSMolMATBatch) -> Any: mask = torch.sum(torch.abs(batch.node_features), dim=-1) != 0 @@ -111,58 +97,6 @@ def build_from_model_file( return model -@dataclass(frozen=True) -class MATMoleculeDatapoint(MoleculeDatapoint): - mat_features: np.ndarray - - -def mat_process_samples(samples: List[MoleculeDatapoint]) -> List[MATMoleculeDatapoint]: - # Set `one_hot_formal_charge` for compatibilitiy with pretrained weights (see README.md in MAT). - all_features, _ = load_data_from_smiles( - x_smiles=[sample.smiles for sample in samples], - labels=[sample.bool_label for sample in samples], - one_hot_formal_charge=True, - ) - - # MAT can internally decide that there is something wrong with a sample and reject it. Our - # dataset is clean, so this shouldn't happen (or at least shouldn't happen silently!). 
- if len(all_features) < len(samples): - raise ValueError("MAT rejected some samples; can't continue, as that may skew results.") - - # Note that `sample.__dict__` is almost like `dataclasses.asdict(sample)`, but shallow, i.e. it - # doesn't dict-ify the inner dataclass describing molecular graph. - return [ - MATMoleculeDatapoint(mat_features=features, **sample.__dict__) - for sample, features in zip(samples, all_features) - ] - - -def mat_batcher_init_fn(batch_data: Dict[str, Any]): - batch_data["mat_features"] = [] - - -def mat_batcher_add_sample_fn( - batch_data: Dict[str, Any], sample_id: int, sample: MATMoleculeDatapoint -): - batch_data["mat_features"].append(sample.mat_features) - - -def mat_batcher_finalizer_fn(batch_data: Dict[str, Any]) -> Tuple[FSMolMATBatch, np.ndarray]: - adjacency_matrix, node_features, distance_matrix, labels = mol_collate_func( - construct_dataset( - batch_data["mat_features"], [[label] for label in batch_data["bool_labels"]] - ) - ) - - batch = FSMolMATBatch( - node_features=node_features, - adjacency_matrix=adjacency_matrix, - distance_matrix=distance_matrix, - ) - - return batch, labels.squeeze(dim=-1).cpu().detach().numpy() - - def turn_off_warnings(): # Ignore rdkit warnings. 
RDLogger.DisableLog("rdApp.*") @@ -224,26 +158,15 @@ def main(): device=device, ) - def task_reader_fn(paths: List[RichPath], idx: int) -> List[FSMolTask]: - [task] = default_reader_fn(paths, idx) - return [FSMolTask(name=task.name, samples=mat_process_samples(task.samples))] - def test_model_fn( task_sample: FSMolTaskSample, temp_out_folder: str, seed: int ) -> BinaryEvalMetrics: - batcher = FSMolBatcher( - max_num_graphs=args.batch_size, - init_callback=mat_batcher_init_fn, - per_datapoint_callback=mat_batcher_add_sample_fn, - finalizer_callback=mat_batcher_finalizer_fn, - ) - return eval_model_by_finetuning_on_task( model_weights_file, model_cls=MATModel, task_sample=task_sample, temp_out_folder=temp_out_folder, - batcher=batcher, + batcher=get_mat_batcher(args.batch_size), learning_rate=args.learning_rate, task_specific_learning_rate=args.task_specific_lr, metric_to_use="avg_precision", @@ -259,7 +182,7 @@ def test_model_fn( out_dir=args.save_dir, num_samples=args.num_runs, valid_size_or_ratio=0.2, - task_reader_fn=task_reader_fn, + task_reader_fn=mat_task_reader_fn, seed=args.seed, ) diff --git a/third_party/MAT b/third_party/MAT new file mode 160000 index 00000000..1c7af06e --- /dev/null +++ b/third_party/MAT @@ -0,0 +1 @@ +Subproject commit 1c7af06e440d967064763dc765847413b5770f05