diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py index bad8ecc56f0..d1e786edadf 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py @@ -41,7 +41,10 @@ def load(self, data_type: Type[Any]) -> Any: The loaded PyTorch object. """ with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f: - return torch.load(f) + # NOTE (security): The `torch.load` function uses `pickle` as + # the default unpickler, which is NOT secure. This materializer + # is intended for use with trusted data sources. + return torch.load(f) # nosec def save(self, obj: Any) -> None: """Uses `torch.save` to save a PyTorch object. @@ -50,7 +53,10 @@ def save(self, obj: Any) -> None: obj: The PyTorch object to save. """ with fileio.open(os.path.join(self.uri, self.FILENAME), "wb") as f: - torch.save(obj, f, pickle_module=cloudpickle) + # NOTE (security): The `torch.save` function uses `cloudpickle` as + # the default unpickler, which is NOT secure. This materializer + # is intended for use with trusted data sources. + torch.save(obj, f, pickle_module=cloudpickle) # nosec # Alias for the BasePyTorchMaterializer class, allowing users that have already used diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py index 38ffea18de6..86e36572c8d 100644 --- a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py @@ -61,7 +61,10 @@ def save(self, model: Module) -> None: with fileio.open( os.path.join(self.uri, CHECKPOINT_FILENAME), "wb" ) as f: - torch.save(model.state_dict(), f, pickle_module=cloudpickle) + # NOTE (security): The `torch.save` function uses `cloudpickle` as + # the default unpickler, which is NOT secure. This materializer + # is intended for use with trusted data sources. + torch.save(model.state_dict(), f, pickle_module=cloudpickle) # nosec def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]: """Extract metadata from the given `Model` object.