Skip to content

Commit

Permalink
Change default masking rate values and update example notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
akensert committed Jun 21, 2024
1 parent 7b010c3 commit 784e169
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 2 additions & 2 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def collate_fn(
@staticmethod
def masked_collate_fn(
data: list[types.MolecularGraph],
node_masking_rate: float = 0.25,
edge_masking_rate: float = 0.25,
node_masking_rate: float = 0.0,
edge_masking_rate: float = 0.0,
) -> tuple[types.MolecularGraph, np.ndarray]:
"""
Merge list of graphs into a single disjoint graph.
Expand Down
10 changes: 9 additions & 1 deletion notebooks/examples-masking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
"import os\n",
"os.environ[\"KERAS_BACKEND\"] = \"torch\" # Comment out for tensorflow backend\n",
"\n",
"from functools import partial\n",
"\n",
"from molexpress import layers\n",
"from molexpress.datasets import featurizers\n",
"from molexpress.datasets import encoders\n",
Expand Down Expand Up @@ -117,8 +119,14 @@
" \n",
"torch_dataset = Dataset(x_dummy)\n",
"\n",
"# We should implement the collate functions differently (and probably not as a staticmethod \n",
"# of the peptide graph encoder); but for now, we keep it as it is, and we can do a partial \n",
"# for additional arguments:\n",
"partial_collate_fn = partial(\n",
" peptide_graph_encoder.masked_collate_fn, node_masking_rate=0.25, edge_masking_rate=0.25)\n",
"\n",
"dataset = torch.utils.data.DataLoader(\n",
" torch_dataset, batch_size=2, collate_fn=peptide_graph_encoder.masked_collate_fn)\n"
" torch_dataset, batch_size=2, collate_fn=partial_collate_fn)\n"
]
},
{
Expand Down

0 comments on commit 784e169

Please sign in to comment.