Skip to content

Commit

Permalink
Catch delta table not found error (#1625)
Browse files Browse the repository at this point in the history
🚢
  • Loading branch information
milocress authored Oct 30, 2024
1 parent e7dddd2 commit 30ab45d
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 20 deletions.
24 changes: 24 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from llmfoundry.utils.exceptions import (
ClusterDoesNotExistError,
ClusterInvalidAccessMode,
DeltaTableNotFoundError,
FailedToConnectToDatabricksError,
FailedToCreateSQLConnectionError,
FaultyDataPrepCluster,
Expand Down Expand Up @@ -503,6 +504,29 @@ def fetch(
raise InsufficientPermissionsError(str(e)) from e
elif 'UC_NOT_ENABLED' in str(e):
raise UCNotEnabledError() from e
elif 'DELTA_TABLE_NOT_FOUND' in str(e):
err_str = str(e)
# Error string should be in this format:
# ---
# Error processing `catalog`.`volume_name`.`table_name`:
# [DELTA_TABLE_NOT_FOUND] Delta table `volume_name`.`table_name`
# doesn't exist.
# ---
parts = err_str.split('`')
if len(parts) < 7:
# Failed to parse error, our codebase is brittle
# with respect to the string representations of
# errors in the spark library.
catalog_name, volume_name, table_name = ['unknown'] * 3
else:
catalog_name = parts[1]
volume_name = parts[3]
table_name = parts[5]
raise DeltaTableNotFoundError(
catalog_name,
volume_name,
table_name,
) from e

if isinstance(e, InsufficientPermissionsError):
raise
Expand Down
22 changes: 21 additions & 1 deletion llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
'MisconfiguredHfDatasetError',
'DatasetTooSmallError',
'RunTimeoutError',
'UCNotEnabledError',
'DeltaTableNotFoundError',
]

ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
Expand Down Expand Up @@ -530,5 +532,23 @@ class UCNotEnabledError(UserError):
"""Error thrown when user does not have UC enabled on their cluster."""

def __init__(self) -> None:
message = f'Unity Catalog is not enabled on your cluster.'
message = 'Unity Catalog is not enabled on your cluster.'
super().__init__(message)


class DeltaTableNotFoundError(UserError):
    """Error thrown when the delta table passed in training doesn't exist."""

    def __init__(
        self,
        catalog_name: str,
        volume_name: str,
        table_name: str,
    ) -> None:
        # Reassemble the fully-qualified path so the user sees exactly
        # which catalog.volume.table could not be found.
        full_path = f'{catalog_name}.{volume_name}.{table_name}'
        message = f'Your data path {full_path} does not exist. Please double check your delta table name'
        # Forward the individual name components as structured kwargs so
        # they remain available on the exception (UserError convention).
        super().__init__(
            message=message,
            catalog_name=catalog_name,
            volume_name=volume_name,
            table_name=table_name,
        )
76 changes: 57 additions & 19 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unittest.mock import MagicMock, mock_open, patch

import grpc
from pyspark.errors import AnalysisException

from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
FaultyDataPrepCluster,
Expand All @@ -19,6 +20,7 @@
iterative_combine_jsons,
run_query,
)
from llmfoundry.utils.exceptions import DeltaTableNotFoundError


class TestConvertDeltaToJsonl(unittest.TestCase):
Expand Down Expand Up @@ -139,25 +141,24 @@ def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any):

mock_listdir.assert_called_once_with(json_directory)
mock_file.assert_called()
"""
Diagnostic print
for call_args in mock_file().write.call_args_list:
print(call_args)
--------------------
call('{')
call('"key"')
call(': ')
call('"value"')
call('}')
call('\n')
call('{')
call('"key"')
call(': ')
call('"value"')
call('}')
call('\n')
--------------------
"""
# Diagnostic print
# for call_args in mock_file().write.call_args_list:
# print(call_args)
# --------------------
# call('{')
# call('"key"')
# call(': ')
# call('"value"')
# call('}')
# call('\n')
# call('{')
# call('"key"')
# call(': ')
# call('"value"')
# call('}')
# call('\n')
# --------------------

self.assertEqual(mock_file().write.call_count, 2)

@patch(
Expand Down Expand Up @@ -582,3 +583,40 @@ def test_fetch_DT_grpc_error_handling(

# Verify that fetch was called
mock_fetch.assert_called_once()

@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows',
)
def test_fetch_nonexistent_table_error(
self,
mock_gtr: MagicMock,
):
# Create a spark.AnalysisException with specific details
analysis_exception = AnalysisException(
message='[DELTA_TABLE_NOT_FOUND] yada yada',
)

# Configure the fetch function to raise the AnalysisException
mock_gtr.side_effect = analysis_exception

# Test inputs
method = 'dbsql'
delta_table_name = 'test_table'
json_output_folder = '/tmp/to/jsonl'

# Act & Assert
with self.assertRaises(DeltaTableNotFoundError) as context:
fetch(
method=method,
tablename=delta_table_name,
json_output_folder=json_output_folder,
)

# Verify that the DeltaTableNotFoundError contains the expected message
self.assertIn(
'Please double check your delta table name',
str(context.exception),
)

# Verify that get_total_rows was called
mock_gtr.assert_called_once()

0 comments on commit 30ab45d

Please sign in to comment.