diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 5e280d7f24..daae4418e2 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -21,6 +21,8 @@ from monai.networks.blocks.fcn import FCN from monai.networks.layers.factories import Act, Conv, Norm, Pool +from monai.utils import MonaiHubMixin + __all__ = ["AHnet", "Ahnet", "AHNet"] @@ -300,7 +302,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class AHNet(nn.Module): +class AHNet(nn.Module, MonaiHubMixin): """ AHNet based on `Anisotropic Hybrid Network `_. Adapted from `lsqshr's official code `_. diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index b9970d4113..77f00e8d3f 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -19,7 +19,8 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.factories import Conv, Pool -from monai.utils import ensure_tuple_rep +from monai.utils import ensure_tuple_rep, MonaiHubMixin + __all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"] @@ -175,7 +176,7 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): return x -class BasicUNet(nn.Module): +class BasicUNet(nn.Module, MonaiHubMixin): def __init__( self, diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2c32eb2cf4..1167ea2b93 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -151,3 +151,4 @@ get_numpy_dtype_from_string, get_torch_dtype_from_string, ) +from .hub_mixin import MonaiHubMixin diff --git a/monai/utils/hub_mixin.py b/monai/utils/hub_mixin.py new file mode 100644 index 0000000000..2a3bdbbdf7 --- /dev/null +++ b/monai/utils/hub_mixin.py @@ -0,0 +1,37 @@ +from monai.utils import OptionalImportError + +__all__ = ["MonaiHubMixin"] + + +class DummyPyTorchModelHubMixin: + + error_message = "To use `{}` method please required packages: `pip install huggingface_hub safetensors`." + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__() + + @classmethod + def from_pretrained(cls, *args, **kwargs): + raise OptionalImportError(cls.error_message.format("from_pretrained")) + + def save_pretrained(self, *args, **kwargs): + raise OptionalImportError(self.error_message.format("save_pretrained")) + + def push_to_hub(self, *args, **kwargs): + raise OptionalImportError(self.error_message.format("push_to_hub")) + + +try: + from huggingface_hub import PyTorchModelHubMixin +except ImportError: + PyTorchModelHubMixin = DummyPyTorchModelHubMixin + + +class MonaiHubMixin( + PyTorchModelHubMixin, + library_name="monai", + repo_url="https://github.com/Project-MONAI/MONAI", + docs_url="https://docs.monai.io/en/", + tags=["monai"], +): + pass