diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index 27c6c556..0f8c878b 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -115,7 +115,7 @@ class CorrelationKwargs(TypedDict, total=True): # Janus specific Architectures = Literal[ - "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet" + "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "orb" ] Devices = Literal["cpu", "cuda", "mps", "xpu"] Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"] diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 537dc027..a1310c2c 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -225,6 +225,36 @@ def choose_calculator( kwargs.setdefault("sevennet_config", None) calculator = SevenNetCalculator(model=model_path, device=device, **kwargs) + elif arch == "orb": + from orb_models import __version__ + from orb_models.forcefield.calculator import ORBCalculator + from orb_models.forcefield.graph_regressor import GraphRegressor + import orb_models.forcefield.pretrained as orb_ff + + if isinstance(model_path, str): + match model_path: + case "orb-v1": + model = orb_ff.orb_v1() + case "orb-mptraj-only-v1": + model = orb_ff.orb_v1_mptraj_only() + case "orb-d3-v1": + model = orb_ff.orb_d3_v1() + case "orb-d3-xs-v1": + model = orb_ff.orb_d3_xs_v1() + case "orb-d3-sm-v1": + model = orb_ff.orb_d3_sm_v1() + case _: + raise ValueError( + "Please specify `model_path`, as there is no " + f"default model for {arch}" + ) + elif isinstance(model_path, GraphRegressor): + model = model_path + else: + model = orb_ff.orb_v1_mptraj_only() + + calculator = ORBCalculator(model=model, device=device, **kwargs) + else: raise ValueError( f"Unrecognized {arch=}. Suported architectures " diff --git a/pyproject.toml b/pyproject.toml index a481a7d9..2097bb14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,10 @@ m3gnet = [ "matgl == 1.1.3", "dgl == 2.1.0", ] +orb = [ + "orb-models == 0.4.1", + "pynanoflann", +] sevennet = [ "sevenn == 0.10.0", ] @@ -55,6 +59,7 @@ all = [ "janus-core[alignn]", "janus-core[chgnet]", "janus-core[m3gnet]", + "janus-core[orb]", "janus-core[sevennet]", ] @@ -164,3 +169,6 @@ default-groups = [ "docs", "pre-commit", ] + +[tool.uv.sources] +pynanoflann = { git = "https://github.com/dwastberg/pynanoflann", rev = "af434039ae14bedcbb838a7808924d6689274168" }