-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add orb support... tricky #303
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -217,6 +217,36 @@ def choose_calculator( | |||||||||||||||||||||||||||||||||||||||||||||
kwargs.setdefault("sevennet_config", None) | ||||||||||||||||||||||||||||||||||||||||||||||
calculator = SevenNetCalculator(model=model, device=device, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
elif arch == "orb": | ||||||||||||||||||||||||||||||||||||||||||||||
__version__ = "0.3" | ||||||||||||||||||||||||||||||||||||||||||||||
from orb_models.forcefield.calculator import ORBCalculator | ||||||||||||||||||||||||||||||||||||||||||||||
from orb_models.forcefield.graph_regressor import GraphRegressor | ||||||||||||||||||||||||||||||||||||||||||||||
import orb_models.forcefield.pretrained as orb_ff | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(model_path, str): | ||||||||||||||||||||||||||||||||||||||||||||||
match model_path: | ||||||||||||||||||||||||||||||||||||||||||||||
case "orb-v1": | ||||||||||||||||||||||||||||||||||||||||||||||
model = orb_ff.orb_v1() | ||||||||||||||||||||||||||||||||||||||||||||||
case "orb-mptraj-only-v1": | ||||||||||||||||||||||||||||||||||||||||||||||
model = orb_ff.orb_v1_mptraj_only() | ||||||||||||||||||||||||||||||||||||||||||||||
case "orb-d3-v1": | ||||||||||||||||||||||||||||||||||||||||||||||
model = orb_ff.orb_d3_v1() | ||||||||||||||||||||||||||||||||||||||||||||||
case "orb-d3-xs-v1": | ||||||||||||||||||||||||||||||||||||||||||||||
model = orb_ff.orb_d3_xs_v1() | ||||||||||||||||||||||||||||||||||||||||||||||
case "orb-d3-sm-v1": | ||||||||||||||||||||||||||||||||||||||||||||||
model = orb_ff.orb_d3_sm_v1() | ||||||||||||||||||||||||||||||||||||||||||||||
case _: | ||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
"Please specify `model_path`, as there is no " | ||||||||||||||||||||||||||||||||||||||||||||||
f"default model for {arch}" | ||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+227
to
+242
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be split into multiple lines, but this is the gist
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be nice but orb-mptraj does not match the pattern. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we choose what "model_path" is? You could also easily special case that. Here, it's not clear that that is different at first glance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no really I user the canonical names from their website... which i suspect is what people will expect to call them |
||||||||||||||||||||||||||||||||||||||||||||||
elif isinstance(model_path, GraphRegressor): | ||||||||||||||||||||||||||||||||||||||||||||||
model = model_path | ||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||
model = orb_ff.orb_v1_mptraj_only() | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
calculator = ORBCalculator(model=model, device=device, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
f"Unrecognized {arch=}. Suported architectures " | ||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,16 +46,18 @@ chgnet = {version = "0.3.8", optional = true} | |
dgl = { version = "2.1.0", optional = true } # Pin due to matgl installation issues | ||
matgl = { version = "1.1.3", optional = true} | ||
sevenn = { version = "0.9.3", optional = true } | ||
orb_models = { version ="0.3.0", optional = true } | ||
torchdata = {version = "0.7.1", optional = true} # Pin due to dgl issue | ||
torch_geometric = { version = "^2.5.3", optional = true } | ||
ruff = "^0.5.7" | ||
|
||
[tool.poetry.extras] | ||
all = ["alignn", "chgnet", "matgl", "dgl", "torchdata", "sevenn", "torch_geometric"] | ||
all = ["alignn", "chgnet", "matgl", "dgl", "torchdata", "sevenn", "torch_geometric", "orb"] | ||
alignn = ["alignn"] | ||
chgnet = ["chgnet"] | ||
m3gnet = ["matgl", "dgl", "torchdata"] | ||
sevennet = ["sevenn", "torch_geometric"] | ||
orb = ["orb_models"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we not need to add Pynanoflann too? |
||
|
||
[tool.poetry.group.dev.dependencies] | ||
coverage = {extras = ["toml"], version = "^7.4.1"} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we can use match case with 3.9?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes it does not but it breaks my heart to write if elif elif elif like we are in fotran 77 world... is not pyhton3.9 eol this month?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we shall catch the guilty one and then deal with it separately?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think actually even 3.8 isn't eol until next month, so a little way to go until 3.9.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my good intentions destroyed...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even then, I would probably recommend building a dict, then, with the keys and pull the function and call it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in fairness they have a dict already we can use that.