From 77363892d4d40935e4f51233b57c8b73eb13fbc3 Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Mon, 28 Oct 2024 16:51:10 -0700 Subject: [PATCH] fixing tests and lints --- streaming/base/storage/download.py | 39 +++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index e1bfa0763..607fece94 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -101,6 +101,9 @@ def download(self, remote: Optional[str], local: str, timeout: float = 60.0) -> remote = pathlib.PureWindowsPath(remote).as_posix() local = pathlib.PureWindowsPath(local).as_posix() + local_dir = os.path.dirname(local) + os.makedirs(local_dir, exist_ok=True) + self._validate_remote_path(remote) self._download_impl(remote, local, timeout) @@ -150,6 +153,7 @@ class S3Downloader(CloudDownloader): """Download files from AWS S3 to local filesystem.""" def __init__(self): + """Initialize the S3 downloader.""" super().__init__() self._s3_client: Optional[Any] = None # Hard to tell exactly what the typing of this is @@ -245,7 +249,7 @@ class SFTPDownloader(CloudDownloader): """Download files from SFTP to local filesystem.""" def __init__(self): - """Initialize the downloader.""" + """Initialize the SFTP downloader.""" super().__init__() from urllib.parse import SplitResult @@ -327,8 +331,10 @@ def _create_ssh_client(self, url: urllib.parse.SplitResult) -> None: class GCSDownloader(CloudDownloader): + """Download files from Google Cloud Storage to local filesystem.""" def __init__(self): + """Initialize the GCS downloader.""" super().__init__() from google.cloud.storage import Client @@ -349,6 +355,7 @@ def clean_up(self) -> None: self._gcs_client = None def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" from google.cloud.storage import Client if self._gcs_client is None: @@ -377,6 +384,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: raise e def _create_gcs_client(self) -> None: + """Create a GCS client.""" if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ: from boto3.session import Session @@ -398,8 +406,10 @@ def _create_gcs_client(self) -> None: class OCIDownloader(CloudDownloader): + """Download files from Oracle Cloud Infrastructure to local filesystem.""" def __init__(self): + """Initialize the OCI downloader.""" super().__init__() import oci @@ -419,6 +429,7 @@ def clean_up(self) -> None: self._oci_client = None def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" if self._oci_client is None: self._create_oci_client() assert self._oci_client is not None @@ -435,6 +446,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: os.rename(local_tmp, local) def _create_oci_client(self) -> None: + """Create an OCI client.""" import oci config = oci.config.from_file() @@ -443,8 +455,10 @@ def _create_oci_client(self) -> None: class HFDownloader(CloudDownloader): + """Download files from Hugging Face to local filesystem.""" def __init__(self): + """Initialize the Hugging Face downloader.""" super().__init__() @staticmethod @@ -461,6 +475,7 @@ def clean_up(self) -> None: pass def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" from huggingface_hub import hf_hub_download _, _, _, repo_org, repo_name, path = remote.split('/', 5) @@ -475,8 +490,10 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: class AzureDownloader(CloudDownloader): + """Download files from Azure to local filesystem.""" def __init__(self): + """Initialize the Azure downloader.""" super().__init__() from azure.storage.blob import BlobServiceClient @@ -497,6 +514,7 @@ def clean_up(self) -> None: self._azure_client = None def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" if self._azure_client is None: self._create_azure_client() assert self._azure_client is not None @@ -513,6 +531,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: os.rename(local_tmp, local) def _create_azure_client(self) -> None: + """Create an Azure client.""" from azure.storage.blob import BlobServiceClient self._azure_client = BlobServiceClient( @@ -521,8 +540,10 @@ def _create_azure_client(self) -> None: class AzureDataLakeDownloader(CloudDownloader): + """Download files from Azure Data Lake to local filesystem.""" def __init__(self): + """Initialize the Azure Data Lake downloader.""" super().__init__() from azure.storage.filedatalake import DataLakeServiceClient @@ -543,6 +564,7 @@ def clean_up(self) -> None: self._azure_dl_client = None def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" from azure.core.exceptions import ResourceNotFoundError if self._azure_dl_client is None: @@ -564,6 +586,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: raise e def _create_azure_dl_client(self) -> None: + """Create an Azure Data Lake client.""" from azure.storage.filedatalake import DataLakeServiceClient self._azure_dl_client = DataLakeServiceClient( @@ -573,8 +596,10 @@ def _create_azure_dl_client(self) -> None: class DatabricksUnityCatalogDownloader(CloudDownloader): + """Download files from Databricks Unity Catalog to local filesystem.""" def __init__(self): + """Initialize the Databricks Unity Catalog downloader.""" super().__init__() try: @@ -599,6 +624,7 @@ def clean_up(self) -> None: self._db_uc_client = None def _validate_remote_path(self, remote: str): + """Validates the remote path for Databricks Unity Catalog client.""" path = pathlib.Path(remote) provider_prefix = os.path.join(path.parts[0], path.parts[1]) if provider_prefix != 'dbfs:/Volumes': @@ -607,6 +633,7 @@ def _validate_remote_path(self, remote: str): f'Catalog, instead, got {provider_prefix} for remote={remote}.') def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" from databricks.sdk.core import DatabricksError if self._db_uc_client is None: @@ -639,13 +666,16 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: os.rename(local_tmp, local) def _create_db_uc_client(self) -> None: + """Create a Databricks Unity Catalog client.""" from databricks.sdk import WorkspaceClient self._db_uc_client = WorkspaceClient() class DBFSDownloader(CloudDownloader): + """Download files from Databricks File System to local filesystem.""" def __init__(self): + """Initialize the Databricks File System downloader.""" super().__init__() try: @@ -670,6 +700,7 @@ def clean_up(self) -> None: self._dbfs_client = None def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" from databricks.sdk.core import DatabricksError if self._dbfs_client is None: @@ -700,13 +731,16 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: os.rename(local_tmp, local) def _create_dbfs_client(self) -> None: + """Create a Databricks File System client.""" from databricks.sdk import WorkspaceClient self._dbfs_client = WorkspaceClient() class AlipanDownloader(CloudDownloader): + """Download files from Alipan to local filesystem.""" def __init__(self): + """Initialize the Alipan downloader.""" super().__init__() @staticmethod @@ -723,6 +757,7 @@ def clean_up(self) -> None: pass def _download_impl(self, remote: str, local: str, timeout: float) -> None: + """Implementation of the download function.""" from alipcs_py.alipcs import AliPCSApiMix from alipcs_py.commands.download import download_file @@ -762,8 +797,10 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None: class LocalDownloader(CloudDownloader): + """Download files from local filesystem to local filesystem.""" def __init__(self): + """Initialize the Local file system downloader.""" super().__init__() @staticmethod