From 1362cda809281ebb4fa2999d6522657154e70e4b Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Wed, 25 Sep 2024 15:54:39 +0200 Subject: [PATCH] Suppress pickle security issues in pytorch materializer --- .../pytorch/materializers/base_pytorch_materializer.py | 10 ++++++++-- .../materializers/pytorch_module_materializer.py | 5 ++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py index bad8ecc56f0..d1e786edadf 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py @@ -41,7 +41,10 @@ def load(self, data_type: Type[Any]) -> Any: The loaded PyTorch object. """ with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f: - return torch.load(f) + # NOTE (security): The `torch.load` function uses `pickle` as + # the default unpickler, which is NOT secure. This materializer + # is intended for use with trusted data sources. + return torch.load(f) # nosec def save(self, obj: Any) -> None: """Uses `torch.save` to save a PyTorch object. @@ -50,7 +53,10 @@ def save(self, obj: Any) -> None: obj: The PyTorch object to save. """ with fileio.open(os.path.join(self.uri, self.FILENAME), "wb") as f: - torch.save(obj, f, pickle_module=cloudpickle) + # NOTE (security): The `torch.save` function uses `cloudpickle` as + # the default unpickler, which is NOT secure. This materializer + # is intended for use with trusted data sources. + torch.save(obj, f, pickle_module=cloudpickle) # nosec # Alias for the BasePyTorchMaterializer class, allowing users that have already used diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py index 38ffea18de6..86e36572c8d 100644 --- a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py @@ -61,7 +61,10 @@ def save(self, model: Module) -> None: with fileio.open( os.path.join(self.uri, CHECKPOINT_FILENAME), "wb" ) as f: - torch.save(model.state_dict(), f, pickle_module=cloudpickle) + # NOTE (security): The `torch.save` function uses `cloudpickle` as + # the default unpickler, which is NOT secure. This materializer + # is intended for use with trusted data sources. + torch.save(model.state_dict(), f, pickle_module=cloudpickle) # nosec def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]: """Extract metadata from the given `Model` object.