diff --git a/openfe/setup/ligand_network_planning.py b/openfe/setup/ligand_network_planning.py index 8835c9735..8cd120314 100644 --- a/openfe/setup/ligand_network_planning.py +++ b/openfe/setup/ligand_network_planning.py @@ -313,7 +313,7 @@ def generate_minimal_redundant_network( def generate_network_from_names( ligands: list[SmallMoleculeComponent], - mappers: Union[AtomMapper, Iterable[AtomMapper]], + mapper: AtomMapper, names: list[tuple[str, str]], ) -> LigandNetwork: """ @@ -344,7 +344,7 @@ def generate_network_from_names( """ nodes = list(ligands) - network_planner = ExplicitNetworkGenerator(mappers=mappers, scorer=None) + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) network = network_planner.generate_network_from_names( ligands=nodes, names=names @@ -355,7 +355,7 @@ def generate_network_from_names( def generate_network_from_indices( ligands: list[SmallMoleculeComponent], - mappers: Union[AtomMapper, Iterable[AtomMapper]], + mapper: AtomMapper, indices: list[tuple[int, int]], ) -> LigandNetwork: """ @@ -383,7 +383,7 @@ def generate_network_from_indices( """ nodes = list(ligands) - network_planner = ExplicitNetworkGenerator(mappers=mappers, scorer=None) + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) network = network_planner.generate_network_from_indices( ligands=nodes, indices=indices ) @@ -392,7 +392,7 @@ def generate_network_from_indices( def load_orion_network( ligands: list[SmallMoleculeComponent], - mappers: Union[AtomMapper, Iterable[AtomMapper]], + mapper: AtomMapper, network_file: Union[str, Path], ) -> LigandNetwork: """Load a :class:`.LigandNetwork` from an Orion NES network file. @@ -431,7 +431,7 @@ def load_orion_network( names.append((entry[0], entry[2])) - network_planner = ExplicitNetworkGenerator(mappers=mappers, scorer=None) + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) network = network_planner.generate_network_from_names( ligands=ligands, names=names ) @@ -441,7 +441,7 @@ def load_orion_network( def load_fepplus_network( ligands: list[SmallMoleculeComponent], - mappers: Union[AtomMapper, Iterable[AtomMapper]], + mapper: AtomMapper, network_file: Union[str, Path], ) -> LigandNetwork: """Load a :class:`.LigandNetwork` from an FEP+ edges network file. @@ -480,7 +480,7 @@ def load_fepplus_network( names.append((entry[2], entry[4])) - network_planner = ExplicitNetworkGenerator(mappers=mappers, scorer=None) + network_planner = ExplicitNetworkGenerator(mappers=mapper, scorer=None) network = network_planner.generate_network_from_names( ligands=ligands, names=names ) diff --git a/openfe/tests/setup/test_network_planning.py b/openfe/tests/setup/test_network_planning.py index feb456359..9992fc0c5 100644 --- a/openfe/tests/setup/test_network_planning.py +++ b/openfe/tests/setup/test_network_planning.py @@ -440,7 +440,7 @@ def test_network_from_names(atom_mapping_basic_test_files, lomap_old_mapper): network = openfe.setup.ligand_network_planning.generate_network_from_names( ligands=ligs, names=requested, - mappers=lomap_old_mapper, + mapper=lomap_old_mapper, ) assert len(network.nodes) == len(ligs) @@ -464,7 +464,7 @@ def test_network_from_names_bad_name( _ = openfe.setup.ligand_network_planning.generate_network_from_names( ligands=ligs, names=requested, - mappers=lomap_old_mapper, + mapper=lomap_old_mapper, ) @@ -483,7 +483,7 @@ def test_network_from_names_duplicate_name( _ = openfe.setup.ligand_network_planning.generate_network_from_names( ligands=ligs, names=requested, - mappers=lomap_old_mapper, + mapper=lomap_old_mapper, ) @@ -497,7 +497,7 @@ def test_network_from_indices( network = openfe.setup.ligand_network_planning.generate_network_from_indices( ligands=ligs, indices=requested, - mappers=lomap_old_mapper, + mapper=lomap_old_mapper, ) assert len(network.nodes) == len(ligs) @@ -522,7 +522,7 @@ def test_network_from_indices_indexerror( network = openfe.setup.ligand_network_planning.generate_network_from_indices( ligands=ligs, indices=requested, - mappers=lomap_old_mapper, + mapper=lomap_old_mapper, ) @@ -536,7 +536,7 @@ def test_network_from_indices_disconnected_warning( _ = openfe.setup.ligand_network_planning.generate_network_from_indices( ligands=ligs, indices=requested, - mappers=lomap_old_mapper, + mapper=lomap_old_mapper, ) @@ -553,7 +553,7 @@ def test_network_from_external(file_fixture, loader, request, network = loader( ligands=[l for l in benzene_modifications.values()], - mappers=openfe.LomapAtomMapper(), + mapper=openfe.LomapAtomMapper(), network_file=network_file, ) @@ -589,7 +589,7 @@ def test_network_from_external_unknown_edge(file_fixture, loader, request, with pytest.raises(KeyError, match="Invalid name"): network = loader( ligands=ligs, - mappers=openfe.LomapAtomMapper(), + mapper=openfe.LomapAtomMapper(), network_file=network_file, ) @@ -614,7 +614,7 @@ def test_bad_orion_network(benzene_modifications, tmpdir): with pytest.raises(KeyError, match="line does not match"): network = openfe.setup.ligand_network_planning.load_orion_network( ligands=[l for l in benzene_modifications.values()], - mappers=openfe.LomapAtomMapper(), + mapper=openfe.LomapAtomMapper(), network_file='bad_orion_net.dat', ) @@ -637,6 +637,6 @@ def test_bad_edges_network(benzene_modifications, tmpdir): with pytest.raises(KeyError, match="line does not match"): network = openfe.setup.ligand_network_planning.load_fepplus_network( ligands=[l for l in benzene_modifications.values()], - mappers=openfe.LomapAtomMapper(), + mapper=openfe.LomapAtomMapper(), network_file='bad_edges.edges', )