Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration with HF Hub #7833

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monai/networks/nets/ahnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from monai.networks.blocks.fcn import FCN
from monai.networks.layers.factories import Act, Conv, Norm, Pool
from monai.utils import MonaiHubMixin


__all__ = ["AHnet", "Ahnet", "AHNet"]

Expand Down Expand Up @@ -300,7 +302,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class AHNet(nn.Module):
class AHNet(nn.Module, MonaiHubMixin):
"""
AHNet based on `Anisotropic Hybrid Network <https://arxiv.org/pdf/1711.08580.pdf>`_.
Adapted from `lsqshr's official code <https://github.com/lsqshr/AH-Net/blob/master/net3d.py>`_.
Expand Down
5 changes: 3 additions & 2 deletions monai/networks/nets/basic_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

from monai.networks.blocks import Convolution, UpSample
from monai.networks.layers.factories import Conv, Pool
from monai.utils import ensure_tuple_rep
from monai.utils import ensure_tuple_rep, MonaiHubMixin


__all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"]

Expand Down Expand Up @@ -175,7 +176,7 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
return x


class BasicUNet(nn.Module):
class BasicUNet(nn.Module, MonaiHubMixin):

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,4 @@
get_numpy_dtype_from_string,
get_torch_dtype_from_string,
)
from .hub_mixin import MonaiHubMixin
37 changes: 37 additions & 0 deletions monai/utils/hub_mixin.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flake8 is correct in that this file is missing the license header that's in all our other source files.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it 👍

Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from monai.utils import OptionalImportError

__all__ = ["MonaiHubMixin"]


class DummyPyTorchModelHubMixin:

error_message = "To use `{}` method please required packages: `pip install huggingface_hub safetensors`."

def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__()

@classmethod
def from_pretrained(cls, *args, **kwargs):
raise OptionalImportError(cls.error_message.format("from_pretrained"))

def save_pretrained(self, *args, **kwargs):
raise OptionalImportError(self.error_message.format("save_pretrained"))

def push_to_hub(self, *args, **kwargs):
raise OptionalImportError(self.error_message.format("push_to_hub"))


try:
from huggingface_hub import PyTorchModelHubMixin
except ImportError:
PyTorchModelHubMixin = DummyPyTorchModelHubMixin
Comment on lines +24 to +27
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @qubvel Thanks for the contribution. We have our own optional_import function to do something similar.

Suggested change
try:
from huggingface_hub import PyTorchModelHubMixin
except ImportError:
PyTorchModelHubMixin = DummyPyTorchModelHubMixin
PyTorchModelHubMixin, has_hg_hub = optional_import("huggingface_hub", name="PyTorchModelHubMixin")
if not has_hg_hub:
PyTorchModelHubMixin = DummyPyTorchModelHubMixin # could put DummyPyTorchModelHubMixin definition here instead

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional_import provides its own dummy class which raises an exception whenever a member is requested, but this doesn't have __init_subclass__ so your version is needed.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, great, I will update it!



class MonaiHubMixin(
PyTorchModelHubMixin,
library_name="monai",
repo_url="https://github.com/Project-MONAI/MONAI",
docs_url="https://docs.monai.io/en/",
tags=["monai"],
):
pass
Loading