Skip to content

Commit

Permalink
refactor(mat_test): break out data loading & include MAT as git submo…
Browse files Browse the repository at this point in the history
…dule
  • Loading branch information
mmjb committed Aug 23, 2021
1 parent b4d7f06 commit b83beca
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 85 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/MAT"]
path = third_party/MAT
url = https://github.com/ardigen/MAT
91 changes: 91 additions & 0 deletions fs_mol/data/mat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from dpu_utils.utils import RichPath

from fs_mol.data import (
FSMolBatcher,
FSMolTask,
MoleculeDatapoint,
default_reader_fn,
)

# Assumes that MAT is in the python lib path:
from featurization.data_utils import construct_dataset, load_data_from_smiles, mol_collate_func


@dataclass(frozen=True)
class FSMolMATBatch:
    """Immutable bundle of the three tensors that make up one MAT input batch.

    Instances are built in `mat_batcher_finalizer_fn` from the values returned by MAT's
    `mol_collate_func`. The dataclass is frozen, so fields cannot be reassigned after
    construction.
    """

    # Node feature tensor for the batch, as returned by `mol_collate_func`.
    # NOTE(review): shape/padding conventions are defined by MAT — confirm there before relying on them.
    node_features: torch.Tensor
    # Adjacency tensor for the batch, as returned by `mol_collate_func`.
    adjacency_matrix: torch.Tensor
    # Distance tensor for the batch, as returned by `mol_collate_func`.
    distance_matrix: torch.Tensor


@dataclass(frozen=True)
class MATMoleculeDatapoint(MoleculeDatapoint):
    """A `MoleculeDatapoint` that additionally carries its MAT featurisation.

    Created by `mat_process_samples`, which copies every field of the original datapoint and
    attaches the matching entry returned by MAT's `load_data_from_smiles`.
    """

    # Per-molecule features exactly as returned by `load_data_from_smiles`.
    # NOTE(review): annotated as `np.ndarray`; the concrete layout is decided by MAT — verify.
    mat_features: np.ndarray


def mat_process_samples(samples: List[MoleculeDatapoint]) -> List[MATMoleculeDatapoint]:
    """Featurise `samples` with MAT and wrap each one in a `MATMoleculeDatapoint`.

    Args:
        samples: datapoints whose `smiles` and `bool_label` fields are handed to MAT.

    Returns:
        One `MATMoleculeDatapoint` per input sample, in input order, holding all fields of the
        original datapoint plus its `mat_features`.

    Raises:
        ValueError: if MAT returned fewer feature entries than there were samples.
    """
    smiles_strings = [datapoint.smiles for datapoint in samples]
    bool_labels = [datapoint.bool_label for datapoint in samples]

    # `one_hot_formal_charge=True` is needed for compatibility with the pretrained weights
    # (see README.md in MAT). The second return value is not used here.
    featurised, _ = load_data_from_smiles(
        x_smiles=smiles_strings,
        labels=bool_labels,
        one_hot_formal_charge=True,
    )

    # MAT may internally decide a sample is broken and drop it. The dataset is expected to be
    # clean, so a shorter result must fail loudly rather than silently misalign the `zip` below.
    if len(samples) > len(featurised):
        raise ValueError("MAT rejected some samples; can't continue, as that may skew results.")

    processed = []
    for datapoint, mat_features in zip(samples, featurised):
        # `datapoint.__dict__` acts as a shallow `dataclasses.asdict`: the nested dataclass
        # describing the molecular graph is passed through untouched, not turned into a dict.
        processed.append(MATMoleculeDatapoint(mat_features=mat_features, **datapoint.__dict__))
    return processed


def mat_batcher_init_fn(batch_data: Dict[str, Any]):
    """Prepare `batch_data` for a new batch by installing an empty MAT feature list.

    Whatever was previously stored under the `"mat_features"` key is replaced; other keys are
    left alone.
    """
    empty_features = []
    batch_data["mat_features"] = empty_features


def mat_batcher_add_sample_fn(
    batch_data: Dict[str, Any], sample_id: int, sample: MATMoleculeDatapoint
):
    """Record the MAT features of `sample` in the batch under construction.

    Args:
        batch_data: batch state; its `"mat_features"` list is extended in place.
        sample_id: position of the sample within the batch (not used here).
        sample: datapoint whose `mat_features` are appended.
    """
    collected = batch_data["mat_features"]
    collected.append(sample.mat_features)


def mat_batcher_finalizer_fn(batch_data: Dict[str, Any]) -> Tuple[FSMolMATBatch, np.ndarray]:
    """Turn the accumulated batch state into a MAT batch plus a numpy label array.

    Args:
        batch_data: batch state holding the `"mat_features"` collected per sample and the
            `"bool_labels"` of those samples.

    Returns:
        A pair `(batch, labels)`: the `FSMolMATBatch` built from MAT's collate output, and the
        collated labels with their last dimension squeezed, moved to CPU and converted to numpy.
    """
    features = batch_data["mat_features"]
    # Each label is wrapped in a single-element list before being handed to `construct_dataset`.
    wrapped_labels = [[label] for label in batch_data["bool_labels"]]
    dataset = construct_dataset(features, wrapped_labels)

    # Mind the order of the collate outputs: the adjacency matrix comes before node features.
    collated = mol_collate_func(dataset)
    adjacency_matrix, node_features, distance_matrix, labels = collated

    mat_batch = FSMolMATBatch(
        node_features=node_features,
        adjacency_matrix=adjacency_matrix,
        distance_matrix=distance_matrix,
    )

    # Drop the trailing dimension introduced by the per-sample label wrapping above.
    flat_labels = labels.squeeze(dim=-1)
    return mat_batch, flat_labels.cpu().detach().numpy()


def mat_task_reader_fn(paths: List[RichPath], idx: int) -> List[FSMolTask]:
    """Read one task with the default reader and attach MAT features to its samples.

    Args:
        paths: task file paths, forwarded unchanged to `default_reader_fn`.
        idx: index forwarded unchanged to `default_reader_fn`.

    Returns:
        A single-element list with a new `FSMolTask` that keeps the task's name and whose
        samples went through `mat_process_samples`.

    Raises:
        ValueError: if `default_reader_fn` does not yield exactly one task, or if
            `mat_process_samples` reports samples rejected by MAT.
    """
    loaded_tasks = default_reader_fn(paths, idx)
    # Exactly one task is expected; tuple-unpacking fails loudly on any other count.
    (task,) = loaded_tasks

    task_name = task.name
    mat_samples = mat_process_samples(task.samples)
    return [FSMolTask(name=task_name, samples=mat_samples)]


def get_mat_batcher(max_num_graphs: int):
    """Build an `FSMolBatcher` wired up with the MAT-specific batching callbacks.

    Args:
        max_num_graphs: forwarded to `FSMolBatcher` as its `max_num_graphs` argument.

    Returns:
        An `FSMolBatcher` using `mat_batcher_init_fn`, `mat_batcher_add_sample_fn` and
        `mat_batcher_finalizer_fn` as its init, per-datapoint and finalizer callbacks.
    """
    mat_callbacks = {
        "init_callback": mat_batcher_init_fn,
        "per_datapoint_callback": mat_batcher_add_sample_fn,
        "finalizer_callback": mat_batcher_finalizer_fn,
    }
    return FSMolBatcher(max_num_graphs=max_num_graphs, **mat_callbacks)
93 changes: 8 additions & 85 deletions fs_mol/mat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,33 @@

import argparse
import logging
import os
import sys
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Optional

import numpy as np
import torch
from dpu_utils.utils import RichPath
from rdkit import RDLogger
from pyprojroot import here as project_root

sys.path.insert(0, "./MAT/src")
sys.path.insert(0, str(project_root()))
sys.path.insert(0, os.path.join(str(project_root()), "third_party", "MAT", "src"))

from fs_mol.data import (
FSMolBatcher,
FSMolTask,
FSMolTaskSample,
MoleculeDatapoint,
default_reader_fn,
)
from fs_mol.data import FSMolTaskSample
from fs_mol.data.mat import FSMolMATBatch, get_mat_batcher, mat_task_reader_fn
from fs_mol.models.interface import AbstractTorchModel
from fs_mol.multitask_train import eval_model_by_finetuning_on_task
from fs_mol.utils.metrics import BinaryEvalMetrics
from fs_mol.utils.multitask_utils import resolve_starting_model_file
from fs_mol.utils.test_utils import add_eval_cli_args, eval_model, set_up_test_run

from featurization.data_utils import construct_dataset, load_data_from_smiles, mol_collate_func
# Assumes that MAT is in the python lib path:
from transformer import GraphTransformer, make_model


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class FSMolMATBatch:
node_features: torch.Tensor
adjacency_matrix: torch.Tensor
distance_matrix: torch.Tensor


class MATModel(GraphTransformer, AbstractTorchModel[FSMolMATBatch]):
def forward(self, batch: FSMolMATBatch) -> Any:
mask = torch.sum(torch.abs(batch.node_features), dim=-1) != 0
Expand Down Expand Up @@ -111,58 +97,6 @@ def build_from_model_file(
return model


@dataclass(frozen=True)
class MATMoleculeDatapoint(MoleculeDatapoint):
mat_features: np.ndarray


def mat_process_samples(samples: List[MoleculeDatapoint]) -> List[MATMoleculeDatapoint]:
# Set `one_hot_formal_charge` for compatibility with pretrained weights (see README.md in MAT).
all_features, _ = load_data_from_smiles(
x_smiles=[sample.smiles for sample in samples],
labels=[sample.bool_label for sample in samples],
one_hot_formal_charge=True,
)

# MAT can internally decide that there is something wrong with a sample and reject it. Our
# dataset is clean, so this shouldn't happen (or at least shouldn't happen silently!).
if len(all_features) < len(samples):
raise ValueError("MAT rejected some samples; can't continue, as that may skew results.")

# Note that `sample.__dict__` is almost like `dataclasses.asdict(sample)`, but shallow, i.e. it
# doesn't dict-ify the inner dataclass describing molecular graph.
return [
MATMoleculeDatapoint(mat_features=features, **sample.__dict__)
for sample, features in zip(samples, all_features)
]


def mat_batcher_init_fn(batch_data: Dict[str, Any]):
batch_data["mat_features"] = []


def mat_batcher_add_sample_fn(
batch_data: Dict[str, Any], sample_id: int, sample: MATMoleculeDatapoint
):
batch_data["mat_features"].append(sample.mat_features)


def mat_batcher_finalizer_fn(batch_data: Dict[str, Any]) -> Tuple[FSMolMATBatch, np.ndarray]:
adjacency_matrix, node_features, distance_matrix, labels = mol_collate_func(
construct_dataset(
batch_data["mat_features"], [[label] for label in batch_data["bool_labels"]]
)
)

batch = FSMolMATBatch(
node_features=node_features,
adjacency_matrix=adjacency_matrix,
distance_matrix=distance_matrix,
)

return batch, labels.squeeze(dim=-1).cpu().detach().numpy()


def turn_off_warnings():
# Ignore rdkit warnings.
RDLogger.DisableLog("rdApp.*")
Expand Down Expand Up @@ -224,26 +158,15 @@ def main():
device=device,
)

def task_reader_fn(paths: List[RichPath], idx: int) -> List[FSMolTask]:
[task] = default_reader_fn(paths, idx)
return [FSMolTask(name=task.name, samples=mat_process_samples(task.samples))]

def test_model_fn(
task_sample: FSMolTaskSample, temp_out_folder: str, seed: int
) -> BinaryEvalMetrics:
batcher = FSMolBatcher(
max_num_graphs=args.batch_size,
init_callback=mat_batcher_init_fn,
per_datapoint_callback=mat_batcher_add_sample_fn,
finalizer_callback=mat_batcher_finalizer_fn,
)

return eval_model_by_finetuning_on_task(
model_weights_file,
model_cls=MATModel,
task_sample=task_sample,
temp_out_folder=temp_out_folder,
batcher=batcher,
batcher=get_mat_batcher(args.batch_size),
learning_rate=args.learning_rate,
task_specific_learning_rate=args.task_specific_lr,
metric_to_use="avg_precision",
Expand All @@ -259,7 +182,7 @@ def test_model_fn(
out_dir=args.save_dir,
num_samples=args.num_runs,
valid_size_or_ratio=0.2,
task_reader_fn=task_reader_fn,
task_reader_fn=mat_task_reader_fn,
seed=args.seed,
)

Expand Down
1 change: 1 addition & 0 deletions third_party/MAT
Submodule MAT added at 1c7af0

0 comments on commit b83beca

Please sign in to comment.