Skip to content

Commit

Permalink
fixing tests and lints
Browse files Browse the repository at this point in the history
  • Loading branch information
ethantang-db committed Oct 28, 2024
1 parent 571b056 commit 7736389
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion streaming/base/storage/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def download(self, remote: Optional[str], local: str, timeout: float = 60.0) ->
remote = pathlib.PureWindowsPath(remote).as_posix()
local = pathlib.PureWindowsPath(local).as_posix()

local_dir = os.path.dirname(local)
os.makedirs(local_dir, exist_ok=True)

self._validate_remote_path(remote)
self._download_impl(remote, local, timeout)

Expand Down Expand Up @@ -150,6 +153,7 @@ class S3Downloader(CloudDownloader):
"""Download files from AWS S3 to local filesystem."""

def __init__(self):
"""Initialize the S3 downloader."""
super().__init__()

self._s3_client: Optional[Any] = None # Hard to tell exactly what the typing of this is
Expand Down Expand Up @@ -245,7 +249,7 @@ class SFTPDownloader(CloudDownloader):
"""Download files from SFTP to local filesystem."""

def __init__(self):
"""Initialize the downloader."""
"""Initialize the SFTP downloader."""
super().__init__()

from urllib.parse import SplitResult
Expand Down Expand Up @@ -327,8 +331,10 @@ def _create_ssh_client(self, url: urllib.parse.SplitResult) -> None:


class GCSDownloader(CloudDownloader):
"""Download files from Google Cloud Storage to local filesystem."""

def __init__(self):
"""Initialize the GCS downloader."""
super().__init__()

from google.cloud.storage import Client
Expand All @@ -349,6 +355,7 @@ def clean_up(self) -> None:
self._gcs_client = None

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
from google.cloud.storage import Client

if self._gcs_client is None:
Expand Down Expand Up @@ -377,6 +384,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:
raise e

def _create_gcs_client(self) -> None:
"""Create a GCS client."""
if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
from boto3.session import Session

Expand All @@ -398,8 +406,10 @@ def _create_gcs_client(self) -> None:


class OCIDownloader(CloudDownloader):
"""Download files from Oracle Cloud Infrastructure to local filesystem."""

def __init__(self):
"""Initialize the OCI downloader."""
super().__init__()

import oci
Expand All @@ -419,6 +429,7 @@ def clean_up(self) -> None:
self._oci_client = None

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
if self._oci_client is None:
self._create_oci_client()
assert self._oci_client is not None
Expand All @@ -435,6 +446,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:
os.rename(local_tmp, local)

def _create_oci_client(self) -> None:
"""Create an OCI client."""
import oci

config = oci.config.from_file()
Expand All @@ -443,8 +455,10 @@ def _create_oci_client(self) -> None:


class HFDownloader(CloudDownloader):
"""Download files from Hugging Face to local filesystem."""

def __init__(self):
"""Initialize the Hugging Face downloader."""
super().__init__()

@staticmethod
Expand All @@ -461,6 +475,7 @@ def clean_up(self) -> None:
pass

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
from huggingface_hub import hf_hub_download

_, _, _, repo_org, repo_name, path = remote.split('/', 5)
Expand All @@ -475,8 +490,10 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:


class AzureDownloader(CloudDownloader):
"""Download files from Azure to local filesystem."""

def __init__(self):
"""Initialize the Azure downloader."""
super().__init__()

from azure.storage.blob import BlobServiceClient
Expand All @@ -497,6 +514,7 @@ def clean_up(self) -> None:
self._azure_client = None

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
if self._azure_client is None:
self._create_azure_client()
assert self._azure_client is not None
Expand All @@ -513,6 +531,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:
os.rename(local_tmp, local)

def _create_azure_client(self) -> None:
"""Create an Azure client."""
from azure.storage.blob import BlobServiceClient

self._azure_client = BlobServiceClient(
Expand All @@ -521,8 +540,10 @@ def _create_azure_client(self) -> None:


class AzureDataLakeDownloader(CloudDownloader):
"""Download files from Azure Data Lake to local filesystem."""

def __init__(self):
"""Initialize the Azure Data Lake downloader."""
super().__init__()

from azure.storage.filedatalake import DataLakeServiceClient
Expand All @@ -543,6 +564,7 @@ def clean_up(self) -> None:
self._azure_dl_client = None

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
from azure.core.exceptions import ResourceNotFoundError

if self._azure_dl_client is None:
Expand All @@ -564,6 +586,7 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:
raise e

def _create_azure_dl_client(self) -> None:
"""Create an Azure Data Lake client."""
from azure.storage.filedatalake import DataLakeServiceClient

self._azure_dl_client = DataLakeServiceClient(
Expand All @@ -573,8 +596,10 @@ def _create_azure_dl_client(self) -> None:


class DatabricksUnityCatalogDownloader(CloudDownloader):
"""Download files from Databricks Unity Catalog to local filesystem."""

def __init__(self):
"""Initialize the Databricks Unity Catalog downloader."""
super().__init__()

try:
Expand All @@ -599,6 +624,7 @@ def clean_up(self) -> None:
self._db_uc_client = None

def _validate_remote_path(self, remote: str):
"""Validates the remote path for Databricks Unity Catalog client."""
path = pathlib.Path(remote)
provider_prefix = os.path.join(path.parts[0], path.parts[1])
if provider_prefix != 'dbfs:/Volumes':
Expand All @@ -607,6 +633,7 @@ def _validate_remote_path(self, remote: str):
f'Catalog, instead, got {provider_prefix} for remote={remote}.')

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
from databricks.sdk.core import DatabricksError

if self._db_uc_client is None:
Expand Down Expand Up @@ -639,13 +666,16 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:
os.rename(local_tmp, local)

def _create_db_uc_client(self) -> None:
"""Create a Databricks Unity Catalog client."""
from databricks.sdk import WorkspaceClient
self._db_uc_client = WorkspaceClient()


class DBFSDownloader(CloudDownloader):
"""Download files from Databricks File System to local filesystem."""

def __init__(self):
"""Initialize the Databricks File System downloader."""
super().__init__()

try:
Expand All @@ -670,6 +700,7 @@ def clean_up(self) -> None:
self._dbfs_client = None

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
from databricks.sdk.core import DatabricksError

if self._dbfs_client is None:
Expand Down Expand Up @@ -700,13 +731,16 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:
os.rename(local_tmp, local)

def _create_dbfs_client(self) -> None:
"""Create a Databricks File System client."""
from databricks.sdk import WorkspaceClient
self._dbfs_client = WorkspaceClient()


class AlipanDownloader(CloudDownloader):
"""Download files from Alipan to local filesystem."""

def __init__(self):
"""Initialize the Alipan downloader."""
super().__init__()

@staticmethod
Expand All @@ -723,6 +757,7 @@ def clean_up(self) -> None:
pass

def _download_impl(self, remote: str, local: str, timeout: float) -> None:
"""Implementation of the download function."""
from alipcs_py.alipcs import AliPCSApiMix
from alipcs_py.commands.download import download_file

Expand Down Expand Up @@ -762,8 +797,10 @@ def _download_impl(self, remote: str, local: str, timeout: float) -> None:


class LocalDownloader(CloudDownloader):
"""Download files from local filesystem to local filesystem."""

def __init__(self):
"""Initialize the Local file system downloader."""
super().__init__()

@staticmethod
Expand Down

0 comments on commit 7736389

Please sign in to comment.