Skip to content

Commit

Permalink
improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
akensert committed Apr 15, 2024
1 parent 155bed1 commit cd50b19
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ def __init__(
self.node_encoder = MolecularNodeEncoder(atom_featurizers)
self.edge_encoder = MolecularEdgeEncoder(bond_featurizers, self_loops=self_loops)

def __call__(self, molecules: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray:
molecular_graphs = []
def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray:
residue_graphs = []
residue_sizes = []
for molecule in molecules:
molecule = chem_ops.get_molecule(molecule)
molecular_graph = {
**self.node_encoder(molecule),
**self.edge_encoder(molecule)
for residue in residues:
residue = chem_ops.get_molecule(residue)
residue_graph = {
**self.node_encoder(residue),
**self.edge_encoder(residue)
}
molecular_graphs.append(molecular_graph)
residue_sizes.append(molecule.GetNumAtoms())
graph = self._merge_molecular_graphs(molecular_graphs)
graph["residue_size"] = np.array(residue_sizes)
return graph
residue_graphs.append(residue_graph)
residue_sizes.append(residue.GetNumAtoms())
disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
return disjoint_peptide_graph

@staticmethod
def _collate_fn(
Expand All @@ -43,16 +43,18 @@ def _collate_fn(
Merges list of graphs into a single disjoint graph.
"""

x, y = list(zip(*data))
disjoint_peptide_graphs, y = list(zip(*data))

disjoint_graph = PeptideGraphEncoder._merge_molecular_graphs(x)
disjoint_graph["peptide_size"] = np.concatenate([
graph["residue_size"].shape[:1] for graph in x
disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
)
disjoint_peptide_batch_graph["peptide_size"] = np.concatenate([
g["residue_size"].shape[:1] for g in disjoint_peptide_graphs
]).astype("int32")
disjoint_graph["residue_size"] = np.concatenate([
graph["residue_size"] for graph in x
disjoint_peptide_batch_graph["residue_size"] = np.concatenate([
g["residue_size"] for g in disjoint_peptide_graphs
]).astype("int32")
return disjoint_graph, np.stack(y)
return disjoint_peptide_batch_graph, np.stack(y)

@staticmethod
def _merge_molecular_graphs(
Expand Down

0 comments on commit cd50b19

Please sign in to comment.