-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add integration with HF Hub #7833
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,3 +151,4 @@ | |
get_numpy_dtype_from_string, | ||
get_torch_dtype_from_string, | ||
) | ||
from .hub_mixin import MonaiHubMixin |
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,37 @@ | ||||||||||||||||
from monai.utils import OptionalImportError | ||||||||||||||||
|
||||||||||||||||
__all__ = ["MonaiHubMixin"] | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class DummyPyTorchModelHubMixin: | ||||||||||||||||
|
||||||||||||||||
error_message = "To use `{}` method please required packages: `pip install huggingface_hub safetensors`." | ||||||||||||||||
|
||||||||||||||||
def __init_subclass__(cls, *args, **kwargs): | ||||||||||||||||
super().__init_subclass__() | ||||||||||||||||
|
||||||||||||||||
@classmethod | ||||||||||||||||
def from_pretrained(cls, *args, **kwargs): | ||||||||||||||||
raise OptionalImportError(cls.error_message.format("from_pretrained")) | ||||||||||||||||
|
||||||||||||||||
def save_pretrained(self, *args, **kwargs): | ||||||||||||||||
raise OptionalImportError(self.error_message.format("save_pretrained")) | ||||||||||||||||
|
||||||||||||||||
def push_to_hub(self, *args, **kwargs): | ||||||||||||||||
raise OptionalImportError(self.error_message.format("push_to_hub")) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
try: | ||||||||||||||||
from huggingface_hub import PyTorchModelHubMixin | ||||||||||||||||
except ImportError: | ||||||||||||||||
PyTorchModelHubMixin = DummyPyTorchModelHubMixin | ||||||||||||||||
Comment on lines
+24
to
+27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @qubvel Thanks for the contribution. We have our own optional_import function to do something similar.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, great, I will update it! |
||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class MonaiHubMixin( | ||||||||||||||||
PyTorchModelHubMixin, | ||||||||||||||||
library_name="monai", | ||||||||||||||||
repo_url="https://github.com/Project-MONAI/MONAI", | ||||||||||||||||
docs_url="https://docs.monai.io/en/", | ||||||||||||||||
tags=["monai"], | ||||||||||||||||
): | ||||||||||||||||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flake8 is correct in that this file is missing the license header that's in all our other source files.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it 👍