diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 0c7ffe6e2c..2b8d148781 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -692,6 +692,7 @@ def fetch_DT( formatted_delta_table_name = format_tablename(delta_table_name) import grpc + import pyspark.errors.exceptions.connect as spark_errors try: fetch( method, @@ -702,8 +703,16 @@ def fetch_DT( sparkSession, dbsql, ) - except grpc.RpcError as e: - if e.code( + except (grpc.RpcError, spark_errors.SparkConnectGrpcException) as e: + if isinstance( + e, + spark_errors.SparkConnectGrpcException, + ) and 'Cannot start cluster' in str(e): + raise FaultyDataPrepCluster( + message= + f'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive. {e}', + ) from e + if isinstance(e, grpc.RpcError) and e.code( ) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details( ): raise FaultyDataPrepCluster( diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 3b8a874a9d..5292f86e6d 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -9,6 +9,7 @@ import grpc from pyspark.errors import AnalysisException +from pyspark.errors.exceptions.connect import SparkConnectGrpcException from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( FaultyDataPrepCluster, @@ -584,6 +585,58 @@ def test_fetch_DT_grpc_error_handling( # Verify that fetch was called mock_fetch.assert_called_once() + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info', + ) + def test_fetch_DT_catches_cluster_failed_to_start( + self, + mock_validate_cluster_info: MagicMock, + mock_fetch: MagicMock, + ): + # Arrange + # Mock the validate_and_get_cluster_info to return test values + mock_validate_cluster_info.return_value = ('dbconnect', None, None) + + # Create a SparkConnectGrpcException indicating that the cluster failed to start + + grpc_error = SparkConnectGrpcException( + message='Cannot start cluster etc...', + ) + + # Configure the fetch function to raise the SparkConnectGrpcException + mock_fetch.side_effect = grpc_error + + # Test inputs + delta_table_name = 'test_table' + json_output_folder = '/tmp/to/jsonl' + http_path = None + cluster_id = None + use_serverless = False + DATABRICKS_HOST = 'https://test-host' + DATABRICKS_TOKEN = 'test-token' + + # Act & Assert + with self.assertRaises(FaultyDataPrepCluster) as context: + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + use_serverless=use_serverless, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + ) + + # Verify that the FaultyDataPrepCluster contains the expected message + self.assertIn( + 'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.', + str(context.exception), + ) + + # Verify that fetch was called + mock_fetch.assert_called_once() + @patch( 'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows', )