From 96d0f5d7e1879d852f8a0369626b991cfd8dcf7c Mon Sep 17 00:00:00 2001 From: Ethan Tang Date: Sat, 2 Nov 2024 09:00:36 -0700 Subject: [PATCH] added more tests and fixes wrt comments --- streaming/base/storage/download.py | 2 +- streaming/base/util.py | 2 +- tests/test_download.py | 63 ++++++++++++++++++++++++------ 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index e572f0768..d92df6a1e 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -60,7 +60,7 @@ def get(cls, remote_dir: Optional[str] = None) -> 'CloudDownloader': if remote_dir is None: return _LOCAL_DOWNLOADER() - logger.info('Acquiring downloader client for remote directory %s', remote_dir) + logger.debug('Acquiring downloader client for remote directory %s', remote_dir) prefix = urllib.parse.urlparse(remote_dir).scheme if prefix == 'dbfs' and remote_dir.startswith('dbfs:/Volumes'): diff --git a/streaming/base/util.py b/streaming/base/util.py index 16913cae0..e4c4095ad 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -490,7 +490,7 @@ def retry( # type: ignore num_tries = 0 def clean_up(): - print("cleaning up") + # Do clean up stuff here @retry(RuntimeError, clean_up_fn=clean_up, num_attempts=3, initial_backoff=0.1) def flaky_function(): diff --git a/tests/test_download.py b/tests/test_download.py index b48ec711c..6fa55532d 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -19,13 +19,14 @@ from tests.conftest import GCS_URL, MY_BUCKET, R2_URL MY_PREFIX = 'train' +TEST_FILE = 'file.txt' @pytest.fixture(scope='function') def remote_local_file() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" - def _method(cloud_prefix: str = '', filename: str = 'file.txt') -> tuple[str, str]: + def _method(cloud_prefix: str = '', filename: str = TEST_FILE) -> tuple[str, str]: try: mock_local_dir = tempfile.TemporaryDirectory() mock_local_filepath = os.path.join(mock_local_dir.name, filename) @@ -114,15 +115,14 @@ class TestGCSClient: @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') def test_download_from_gcs(self, remote_local_file: Any): with tempfile.TemporaryDirectory() as tmp_dir: - file_name = 'file.txt' - tmp = os.path.join(tmp_dir, file_name) - mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://', filename=file_name) + tmp = os.path.join(tmp_dir, TEST_FILE) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://', filename=TEST_FILE) client = boto3.client('s3', region_name='us-east-1', endpoint_url=GCS_URL, aws_access_key_id=os.environ['GCS_KEY'], aws_secret_access_key=os.environ['GCS_SECRET']) - client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') + client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, TEST_FILE), Body='') downloader = GCSDownloader() downloader.download(mock_remote_filepath, tmp) assert os.path.isfile(tmp) @@ -176,20 +176,58 @@ class TestDatabricksUnityCatalog: @pytest.mark.parametrize('cloud_prefix', ['dbfs:/Volumess', 'dbfs:/Content']) def test_invalid_prefix_from_db_uc(self, remote_local_file: Any, cloud_prefix: str): with tempfile.TemporaryDirectory() as tmp_dir: - file_name = os.path.join(tmp_dir, 'file.txt') - mock_remote_filepath, _ = remote_local_file(cloud_prefix=cloud_prefix, - filename=file_name) + file_name = os.path.join(tmp_dir, TEST_FILE) + mock_remote_filepath, _ = remote_local_file(cloud_prefix=cloud_prefix) with pytest.raises(Exception, match='Expected path prefix to be `dbfs:/Volumes`.*'): downloader = DatabricksUnityCatalogDownloader() downloader.download(mock_remote_filepath, file_name) + @patch('databricks.sdk.WorkspaceClient', autospec=True) + def test_databricks_error_file_not_found(self, workspace_client_mock: Mock, + remote_local_file: Any): + from databricks.sdk.core import DatabricksError + workspace_client_mock_instance = workspace_client_mock.return_value + workspace_client_mock_instance.files = Mock() + workspace_client_mock_instance.files.download = Mock() + download_return_val = workspace_client_mock_instance.files.download.return_value + download_return_val.contents = Mock() + download_return_val.contents.__enter__ = Mock( + side_effect=DatabricksError('Error', error_code='NOT_FOUND')) + download_return_val.contents.__exit__ = Mock() + + with tempfile.TemporaryDirectory() as tmp_dir: + file_name = os.path.join(tmp_dir, TEST_FILE) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfs:/Volumes') + with pytest.raises(FileNotFoundError): + downloader = DatabricksUnityCatalogDownloader() + downloader.download(mock_remote_filepath, file_name) + + @patch('databricks.sdk.WorkspaceClient', autospec=True) + def test_databricks_error(self, workspace_client_mock: Mock, remote_local_file: Any): + from databricks.sdk.core import DatabricksError + workspace_client_mock_instance = workspace_client_mock.return_value + workspace_client_mock_instance.files = Mock() + workspace_client_mock_instance.files.download = Mock() + download_return_val = workspace_client_mock_instance.files.download.return_value + download_return_val.contents = Mock() + download_return_val.contents.__enter__ = Mock( + side_effect=DatabricksError('Error', error_code='REQUEST_LIMIT_EXCEEDED')) + download_return_val.contents.__exit__ = Mock() + + with tempfile.TemporaryDirectory() as tmp_dir: + file_name = os.path.join(tmp_dir, TEST_FILE) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfs:/Volumes') + with pytest.raises(DatabricksError): + downloader = DatabricksUnityCatalogDownloader() + downloader.download(mock_remote_filepath, file_name) + class TestDatabricksFileSystem: def test_invalid_prefix_from_dbfs(self, remote_local_file: Any): with tempfile.TemporaryDirectory() as tmp_dir: - file_name = os.path.join(tmp_dir, 'file.txt') - mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfsx:/', filename=file_name) + file_name = os.path.join(tmp_dir, TEST_FILE) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfsx:/') with pytest.raises(Exception, match='Expected remote path to start with.*'): downloader = DBFSDownloader() downloader.download(mock_remote_filepath, file_name) @@ -198,9 +236,8 @@ def test_invalid_prefix_from_dbfs(self, remote_local_file: Any): def test_download_from_local(): mock_remote_dir = tempfile.TemporaryDirectory() mock_local_dir = tempfile.TemporaryDirectory() - file_name = 'file.txt' - mock_remote_file = os.path.join(mock_remote_dir.name, file_name) - mock_local_file = os.path.join(mock_local_dir.name, file_name) + mock_remote_file = os.path.join(mock_remote_dir.name, TEST_FILE) + mock_local_file = os.path.join(mock_local_dir.name, TEST_FILE) # Creates a new empty file with open(mock_remote_file, 'w') as _: pass