-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32 from timbernat/openff-cleanup
OpenFF overhaul
- Loading branch information
Showing
27 changed files
with
678 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,3 +107,6 @@ ENV/ | |
|
||
# In-tree generated files | ||
*/_version.py | ||
|
||
# Espaloma junk output | ||
**/.model.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,10 @@ | |
__author__ = 'Timotej Bernat' | ||
__email__ = '[email protected]' | ||
|
||
from typing import Callable, Iterable, Optional, Type, Union | ||
from typing import Callable, Concatenate, Iterable, Iterator, Optional, ParamSpec, TypeVar, Union | ||
|
||
T = TypeVar('T') | ||
Params = ParamSpec('Params') | ||
|
||
from inspect import signature, Parameter | ||
from functools import wraps, partial | ||
|
@@ -13,20 +16,18 @@ | |
|
||
from .meta import extend_to_methods | ||
from . import signatures | ||
from ..typetools.parametric import T, Args, KWArgs | ||
from ..typetools.categorical import ListLike | ||
from ..fileutils.pathutils import aspath, asstrpath | ||
|
||
|
||
@extend_to_methods | ||
def optional_in_place(funct : Callable[[object, Args, KWArgs], None]) -> Callable[[object, Args, bool, KWArgs], Optional[object]]: | ||
def optional_in_place(funct : Callable[[Concatenate[object, Params]], None]) -> Callable[[Concatenate[object, Params]], Optional[object]]: | ||
'''Decorator function for allowing in-place (writeable) functions which modify object attributes | ||
to be not performed in-place (i.e. read-only), specified by a boolean flag''' | ||
# TODO : add assertion that the wrapped function has at least one arg AND that the first arg is of the desired (limited) type | ||
old_sig = signature(funct) | ||
|
||
@wraps(funct) # for preserving docstring and type annotations / signatures | ||
def in_place_wrapper(obj : object, *args : Args, in_place : bool=False, **kwargs : KWArgs) -> Optional[object]: # read-only by default | ||
def in_place_wrapper(obj : object, *args : Params.args, in_place : bool=False, **kwargs : Params.kwargs) -> Optional[object]: # read-only by default | ||
'''If not in-place, create a clone on which the method is executed''' # NOTE : old_sig.bind screws up arg passing | ||
if in_place: | ||
funct(obj, *args, **kwargs) # default call to writeable method - implicitly returns None | ||
|
@@ -54,9 +55,9 @@ def in_place_wrapper(obj : object, *args : Args, in_place : bool=False, **kwargs | |
return in_place_wrapper | ||
|
||
# TODO : implement support for extend_to_methods (current mechanism is broken by additional deocrator parameters) | ||
def flexible_listlike_input(funct : Callable[[ListLike], T]=None, CastType : Type[ListLike]=list, valid_member_types : Union[Type, tuple[Type]]=object) -> Callable[[Iterable], T]: | ||
def flexible_listlike_input(funct : Callable[[Iterator], T]=None, CastType : type[Iterator]=list, valid_member_types : Union[type, tuple[type]]=object) -> Callable[[Iterable], T]: | ||
'''Wrapper which allows a function which expects a single list-initializable, Container-like object to accept any Iterable (or even star-unpacked arguments)''' | ||
if not issubclass(CastType, ListLike): | ||
if not issubclass(CastType, Iterator): | ||
raise TypeError(f'Cannot wrap listlike input with non-listlike type "{CastType.__name__}"') | ||
|
||
@wraps(funct) | ||
|
@@ -79,13 +80,13 @@ def wrapper(*args) -> T: # wrapper which accepts an arbitrary number of non-keyw | |
return wrapper | ||
|
||
@extend_to_methods | ||
def allow_string_paths(funct : Callable[[Path, Args, KWArgs], T]) -> Callable[[Union[Path, str], Args, KWArgs], T]: | ||
def allow_string_paths(funct : Callable[[Concatenate[Path, Params]], T]) -> Callable[[Concatenate[Union[Path, str], Params]], T]: | ||
'''Modifies a function which expects a Path as its first argument to also accept string-paths''' | ||
# TODO : add assertion that the wrapped function has at least one arg AND that the first arg is of the desired (limited) type | ||
old_sig = signature(funct) # lookup old type signature | ||
|
||
@wraps(funct) # for preserving docstring and type annotations / signatures | ||
def str_path_wrapper(flex_path : Union[str, Path], *args : Args, **kwargs : KWArgs) -> T: | ||
def str_path_wrapper(flex_path : Union[str, Path], *args : Params.args, **kwargs : Params.kwargs) -> T: | ||
'''First converts stringy paths into normal Paths, then executes the original function''' | ||
return funct(aspath(flex_path), *args, **kwargs) | ||
|
||
|
@@ -99,13 +100,13 @@ def str_path_wrapper(flex_path : Union[str, Path], *args : Args, **kwargs : KWAr | |
return str_path_wrapper | ||
|
||
@extend_to_methods | ||
def allow_pathlib_paths(funct : Callable[[str, Args, KWArgs], T]) -> Callable[[Union[Path, str], Args, KWArgs], T]: | ||
def allow_pathlib_paths(funct : Callable[[Concatenate[str, Params]], T]) -> Callable[[Concatenate[Union[Path, str], Params]], T]: | ||
'''Modifies a function which expects a string path as its first argument to also accept canonical pathlib Paths''' | ||
# TODO : add assertion that the wrapped function has at least one arg AND that the first arg is of the desired (limited) type | ||
old_sig = signature(funct) # lookup old type signature | ||
|
||
@wraps(funct) # for preserving docstring and type annotations / signatures | ||
def str_path_wrapper(flex_path : Union[str, Path], *args : Args, **kwargs : KWArgs) -> T: | ||
def str_path_wrapper(flex_path : Union[str, Path], *args : Params.args, **kwargs : Params.kwargs) -> T: | ||
'''First converts normal Paths into stringy paths, then executes the original function''' | ||
return funct(asstrpath(flex_path), *args, **kwargs) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
'''Utilities for checking and enforcing module dependencies within code''' | ||
|
||
__author__ = 'Timotej Bernat' | ||
__email__ = '[email protected]' | ||
|
||
from typing import Callable, ParamSpec, TypeVar | ||
|
||
Params = ParamSpec('Params') | ||
ReturnType = TypeVar('ReturnType') | ||
TCall = Callable[Params, ReturnType] # generic function of callable class | ||
|
||
# from importlib import import_module | ||
from importlib.util import find_spec | ||
from functools import wraps | ||
|
||
|
||
def module_installed(module_name : str) -> bool: | ||
''' | ||
Check whether a module of the given name is present on the system | ||
Parameters | ||
---------- | ||
module_name : str | ||
The name of the module, as it would occur in an import statement | ||
Do not support direct passing of module objects to avoid circularity | ||
(i.e. no reason to check if a module is present if one has already imported it elsewhere) | ||
Returns | ||
------- | ||
module_found : bool | ||
Whether or not the module was found to be installed in the current working environment | ||
''' | ||
# try: | ||
# package = import_module(module_name) | ||
# except ModuleNotFoundError: | ||
# return False | ||
# else: | ||
# return True | ||
|
||
try: # NOTE: opted for this implementation, as it never actually imports the package in question (faster and fewer side-effects) | ||
return find_spec(module_name) is not None | ||
except (ValueError, AttributeError, ModuleNotFoundError): # these could all be raised by | ||
return False | ||
|
||
def modules_installed(*module_names : list[str]) -> bool: | ||
''' | ||
Check whether one or more modules are all present | ||
Will only return true if ALL specified modules are found | ||
Parameters | ||
---------- | ||
module_names : *str | ||
Any number of module names, passed as a comma-separated sequence of strings | ||
Returns | ||
------- | ||
all_modules_found : bool | ||
Whether or not all modules were found to be installed in the current working environment | ||
''' | ||
return all(module_installed(module_name) for module_name in module_names) | ||
|
||
def requires_modules( | ||
*required_module_names : list[str], | ||
missing_module_error : type[Exception]=ImportError, | ||
) -> Callable[[TCall[..., ReturnType]], TCall[..., ReturnType]]: | ||
''' | ||
Decorator which enforces optional module dependencies prior to function execution | ||
Parameters | ||
---------- | ||
module_names : *str | ||
Any number of module names, passed as a comma-separated sequence of strings | ||
missing_module_error : type[Exception], default ImportError | ||
The type of Exception to raise if a module is not found installed | ||
Defaults to ImportError | ||
Raises | ||
------ | ||
ImportError : Exception | ||
Raised if any of the specified packages is not found to be installed | ||
Exception message will indicate the name of the specific package found missing | ||
''' | ||
def decorator(func) -> TCall[..., ReturnType]: | ||
@wraps(func) | ||
def req_wrapper(*args : Params.args, **kwargs : Params.kwargs) -> ReturnType: | ||
for module_name in required_module_names: | ||
if not module_installed(module_name): | ||
raise missing_module_error(f'No installation found for module "{module_name}"') | ||
else: | ||
return func(*args, **kwargs) | ||
|
||
return req_wrapper | ||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,34 @@ | ||
'''Tools for manipulating and extending OpenFF objects, and for interfacing with other tools and formats''' | ||
'''Extensions, interfaces, and convenience methods built around the functionality in the OpenFF software stack''' | ||
|
||
__author__ = 'Timotej Bernat' | ||
__email__ = '[email protected]' | ||
|
||
from typing import Any | ||
from pathlib import Path | ||
|
||
import openforcefields | ||
from openff.toolkit import ToolkitRegistry | ||
from openff.toolkit import GLOBAL_TOOLKIT_REGISTRY as GTR | ||
from openff.toolkit.utils.base_wrapper import ToolkitWrapper | ||
from openff.toolkit.utils.utils import all_subclasses | ||
from openff.toolkit.utils.exceptions import LicenseError, ToolkitUnavailableException | ||
from openff.toolkit.typing.engines.smirnoff.forcefield import _get_installed_offxml_dir_paths | ||
|
||
from openff.toolkit.utils.openeye_wrapper import OpenEyeToolkitWrapper | ||
from espaloma_charge.openff_wrapper import EspalomaChargeToolkitWrapper | ||
from openff.nagl.toolkits import NAGLRDKitToolkitWrapper, NAGLOpenEyeToolkitWrapper | ||
|
||
|
||
# FORCE FIELD AND ToolkitWrapper REFERENCE | ||
FFDIR = Path(openforcefields.get_forcefield_dirs_paths()[0]) # Locate path where OpenFF forcefields are installed | ||
FF_DIR_REGISTRY : dict[Path, Path] = {} | ||
FF_PATH_REGISTRY : dict[Path, Path] = {} | ||
for ffdir_str in _get_installed_offxml_dir_paths(): | ||
ffdir = Path(ffdir_str) | ||
ffdir_name = ffdir.parent.stem | ||
|
||
FF_DIR_REGISTRY[ ffdir_name] = ffdir | ||
FF_PATH_REGISTRY[ffdir_name] = [path for path in ffdir.glob('*.offxml')] | ||
|
||
# CHECKING FOR OpenEye | ||
ALL_IMPORTABLE_TKWRAPPERS = all_subclasses(ToolkitWrapper) # References to every registered ToolkitWrapper and ToolkitRegistry | ||
try: | ||
_ = OpenEyeToolkitWrapper() | ||
_OE_TKWRAPPER_IS_AVAILABLE = True | ||
OEUnavailableException = None | ||
except (LicenseError, ToolkitUnavailableException) as error: | ||
_OE_TKWRAPPER_IS_AVAILABLE = False | ||
OEUnavailableException = error # catch and record relevant error message for use (rather than trying to replicate it elsewhere) | ||
|
||
# Register OpenFF-compatible GNN ToolkitWrappers | ||
GTR.register_toolkit(EspalomaChargeToolkitWrapper) | ||
GTR.register_toolkit(NAGLRDKitToolkitWrapper) | ||
if _OE_TKWRAPPER_IS_AVAILABLE: | ||
GTR.register_toolkit(NAGLOpenEyeToolkitWrapper) | ||
|
||
|
||
# GENERATE LOOKUP DICTS FOR EVERY REGISTERED ToolkitWrappers and ToolkitRegistry | ||
REGISTERED_TKWRAPPER_TYPES = [type(tkwrapper) for tkwrapper in GTR.registered_toolkits] | ||
TKWRAPPERS = { # NOTE : this must be done AFTER any new registrations to thr GTR (e.g. after registering GNN ToolkitWrappers) | ||
tk_wrap.toolkit_name : tk_wrap | ||
for tk_wrap in GTR.registered_toolkits | ||
} | ||
TKREGS = {} # individually register toolkit wrappers for cases where a registry must be passed | ||
for tk_name, tk_wrap in TKWRAPPERS.items(): | ||
tk_reg = ToolkitRegistry() | ||
tk_reg.register_toolkit(tk_wrap) | ||
TKREGS[tk_name] = tk_reg | ||
# Subpackage-wide precheck to see if OpenFF is even usable in the first place | ||
from ...genutils.importutils.dependencies import modules_installed | ||
if not modules_installed('openff', 'openff.toolkit'): | ||
raise ModuleNotFoundError( | ||
f''' | ||
OpenFF packages which are required to utilitize {__name__} not found in current environment | ||
Please follow installation instructions at https://docs.openforcefield.org/projects/toolkit/en/stable/installation.html, then retry import | ||
''' | ||
) | ||
|
||
# Import of toplevel OpenFF object registries | ||
from ._forcefields import ( | ||
FFDIR, | ||
FF_DIR_REGISTRY, | ||
FF_PATH_REGISTRY, | ||
) | ||
from ._toolkits import ( | ||
## toolkit registries | ||
GLOBAL_TOOLKIT_REGISTRY, GTR, | ||
POLYMERIST_TOOLKIT_REGISTRY, | ||
## catalogues of available toolkit wrappers | ||
ALL_IMPORTABLE_TKWRAPPERS, | ||
ALL_AVAILABLE_TKWRAPPERS, | ||
TKWRAPPERS, | ||
TKWRAPPER_TYPES, | ||
## registry of partial charge methods by | ||
CHARGE_METHODS_BY_TOOLKIT, | ||
TOOLKITS_BY_CHARGE_METHOD, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
'''For dynamically determining and cataloging which SMIRNOFF-copatible force fields are installed (and accompanying functionality) are available''' | ||
|
||
__author__ = 'Timotej Bernat' | ||
__email__ = '[email protected]' | ||
|
||
from typing import Optional | ||
from pathlib import Path | ||
|
||
from ...genutils.importutils.dependencies import modules_installed | ||
|
||
|
||
# Force field and ToolkitWrapper reference | ||
FFDIR : Optional[Path] = None | ||
if modules_installed('openff.toolkit'): | ||
from openforcefields import get_forcefield_dirs_paths | ||
|
||
FFDIR = Path(get_forcefield_dirs_paths()[0]) # Locate path where OpenFF forcefields are installed | ||
|
||
FF_DIR_REGISTRY : dict[Path, Path] = {} | ||
FF_PATH_REGISTRY : dict[Path, Path] = {} | ||
if modules_installed('openforcefields'): | ||
from openff.toolkit.typing.engines.smirnoff.forcefield import _get_installed_offxml_dir_paths | ||
|
||
for ffdir_str in _get_installed_offxml_dir_paths(): | ||
ffdir = Path(ffdir_str) | ||
ffdir_name = ffdir.parent.stem | ||
|
||
FF_DIR_REGISTRY[ ffdir_name] = ffdir | ||
FF_PATH_REGISTRY[ffdir_name] = [path for path in ffdir.glob('*.offxml')] |
Oops, something went wrong.