[ PECO-2065 ] Create the async execution flow for the PySQL Connector (#463)

* Built the basic flow for the async pipeline; tests still pending

* Implemented the flow for get_execution_result, but the invalid operation handle problem still persists

* Missed adding some files in previous commit

* Working prototype of execute_async, get_query_state and get_execution_result

* Added integration tests for execute_async

* add docs for functions

* Refactored the async code

* Fixed docstrings

* Reformatted
jprakash-db authored Nov 26, 2024
1 parent 43fa964 commit 328aeb5
Showing 3 changed files with 203 additions and 1 deletion.
105 changes: 105 additions & 0 deletions src/databricks/sql/client.py
@@ -1,3 +1,4 @@
import time
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
@@ -47,6 +48,7 @@

from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
TOperationState,
)


@@ -430,6 +432,8 @@ def __init__(
self.escaper = ParamEscaper()
self.lastrowid = None

self.ASYNC_DEFAULT_POLLING_INTERVAL = 2

# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
def __enter__(self) -> "Cursor":
return self
@@ -796,6 +800,7 @@ def execute(
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=False,
)
self.active_result_set = ResultSet(
self.connection,
@@ -812,6 +817,106 @@ def execute(

return self

def execute_async(
self,
operation: str,
parameters: Optional[TParameterCollection] = None,
) -> "Cursor":
"""
Execute a query asynchronously and return immediately, without waiting for the query to complete
:param operation: The SQL statement to execute
:param parameters: Optional parameters to bind to the statement
:return: This cursor, with the operation handle stored for later polling
"""
param_approach = self._determine_parameter_approach(parameters)
if param_approach == ParameterApproach.NONE:
prepared_params = NO_NATIVE_PARAMS
prepared_operation = operation

elif param_approach == ParameterApproach.INLINE:
prepared_operation, prepared_params = self._prepare_inline_parameters(
operation, parameters
)
elif param_approach == ParameterApproach.NATIVE:
normalized_parameters = self._normalize_tparametercollection(parameters)
param_structure = self._determine_parameter_structure(normalized_parameters)
transformed_operation = transform_paramstyle(
operation, normalized_parameters, param_structure
)
prepared_operation, prepared_params = self._prepare_native_parameters(
transformed_operation, normalized_parameters, param_structure
)

self._check_not_closed()
self._close_and_clear_active_result_set()
self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=True,
)

return self

def get_query_state(self) -> "TOperationState":
"""
Poll the server for the current state of the asynchronously executing query
:return: The current TOperationState of the operation
"""
self._check_not_closed()
return self.thrift_backend.get_query_state(self.active_op_handle)

def get_async_execution_result(self):
"""
Poll the asynchronously executing query until it leaves the PENDING/RUNNING states,
then fetch and attach the result set if the query finished successfully
:return: This cursor, with active_result_set populated on success
"""
self._check_not_closed()

def is_executing(operation_state) -> "bool":
return not operation_state or operation_state in [
ttypes.TOperationState.RUNNING_STATE,
ttypes.TOperationState.PENDING_STATE,
]

while is_executing(self.get_query_state()):
# Wait for the default polling interval between status checks
time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)

operation_state = self.get_query_state()
if operation_state == ttypes.TOperationState.FINISHED_STATE:
execute_response = self.thrift_backend.get_execution_result(
self.active_op_handle, self
)
self.active_result_set = ResultSet(
self.connection,
execute_response,
self.thrift_backend,
self.buffer_size_bytes,
self.arraysize,
)

if execute_response.is_staging_operation:
self._handle_staging_operation(
staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
)

return self
else:
raise Error(
f"get_execution_result failed with Operation status {operation_state}"
)

def executemany(self, operation, seq_of_parameters):
"""
Execute the operation once for every set of passed in parameters.
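Taken together, the three new cursor methods form a submit-poll-fetch pattern. A minimal usage sketch (the connection arguments are placeholders for real workspace values, and the fixed 2-second sleep mirrors ASYNC_DEFAULT_POLLING_INTERVAL):

import time

from databricks import sql
from databricks.sql.thrift_api.TCLIService import ttypes

# Placeholder credentials; substitute real workspace values.
with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<http-path>",
    access_token="<token>",
) as connection:
    with connection.cursor() as cursor:
        # Submit the statement and return immediately.
        cursor.execute_async("SELECT COUNT(*) FROM RANGE(1000)")

        # Poll until the operation leaves the PENDING/RUNNING states.
        while cursor.get_query_state() in (
            ttypes.TOperationState.PENDING_STATE,
            ttypes.TOperationState.RUNNING_STATE,
        ):
            time.sleep(2)

        # Attach the finished result set to the cursor, then fetch.
        cursor.get_async_execution_result()
        print(cursor.fetchall())

Calling get_async_execution_result without the explicit polling loop also works, since that method polls internally; the loop above simply keeps the waiting under the caller's control.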
76 changes: 75 additions & 1 deletion src/databricks/sql/thrift_backend.py
@@ -7,6 +7,8 @@
import threading
from typing import List, Union

from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState

try:
import pyarrow
except ImportError:
@@ -769,6 +771,63 @@ def _results_message_to_execute_response(self, resp, operation_state):
arrow_schema_bytes=schema_bytes,
)

def get_execution_result(self, op_handle, cursor):
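"""Fetch the result and metadata of a completed operation and assemble an ExecuteResponse."""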

assert op_handle is not None

req = ttypes.TFetchResultsReq(
operationHandle=ttypes.TOperationHandle(
op_handle.operationId,
op_handle.operationType,
False,
op_handle.modifiedRowCount,
),
maxRows=cursor.arraysize,
maxBytes=cursor.buffer_size_bytes,
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
includeResultSetMetadata=True,
)

resp = self.make_request(self._client.FetchResults, req)

t_result_set_metadata_resp = resp.resultSetMetadata

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
has_more_rows = resp.hasMoreRows
description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema
)

schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
.serialize()
.to_pybytes()
)

queue = ResultSetQueueFactory.build_queue(
row_set_type=resp.resultSetMetadata.resultFormat,
t_row_set=resp.results,
arrow_schema_bytes=schema_bytes,
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_options=self._ssl_options,
)

return ExecuteResponse(
arrow_queue=queue,
status=resp.status,
has_been_closed_server_side=False,
has_more_rows=has_more_rows,
lz4_compressed=lz4_compressed,
is_staging_operation=is_staging_operation,
command_handle=op_handle,
description=description,
arrow_schema_bytes=schema_bytes,
)

def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
if initial_operation_status_resp:
self._check_command_not_in_error_or_closed_state(
@@ -787,6 +846,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
return operation_state

def get_query_state(self, op_handle) -> "TOperationState":
poll_resp = self._poll_for_status(op_handle)
operation_state = poll_resp.operationState
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
return operation_state

@staticmethod
def _check_direct_results_for_error(t_spark_direct_results):
if t_spark_direct_results:
@@ -817,6 +882,7 @@ def execute_command(
cursor,
use_cloud_fetch=True,
parameters=[],
async_op=False,
):
assert session_handle is not None

@@ -846,7 +912,11 @@ def execute_command(
parameters=parameters,
)
resp = self.make_request(self._client.ExecuteStatement, req)
return self._handle_execute_response(resp, cursor)

if async_op:
self._handle_execute_response_async(resp, cursor)
else:
return self._handle_execute_response(resp, cursor)

def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
assert session_handle is not None
@@ -945,6 +1015,10 @@ def _handle_execute_response(self, resp, cursor):

return self._results_message_to_execute_response(resp, final_operation_state)

def _handle_execute_response_async(self, resp, cursor):
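"""Record the operation handle on the cursor and check direct results for errors, without waiting for completion."""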
cursor.active_op_handle = resp.operationHandle
self._check_direct_results_for_error(resp.directResults)

def fetch_results(
self,
op_handle,
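The key backend design point: with async_op=True, execute_command skips _wait_until_command_done and only records the operation handle on the cursor, leaving the polling cadence to the caller. A hypothetical helper built on get_query_state (wait_for_completion is illustrative, not part of this commit) could layer a timeout on top:

import time

from databricks.sql.exc import Error
from databricks.sql.thrift_api.TCLIService import ttypes

TERMINAL_STATES = (
    ttypes.TOperationState.FINISHED_STATE,
    ttypes.TOperationState.ERROR_STATE,
    ttypes.TOperationState.CANCELED_STATE,
    ttypes.TOperationState.CLOSED_STATE,
)

def wait_for_completion(cursor, poll_interval=2.0, timeout=300.0):
    # Poll an async query until it reaches a terminal state or the timeout expires.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        state = cursor.get_query_state()
        if state in TERMINAL_STATES:
            return state
        time.sleep(poll_interval)
    raise Error("Timed out waiting for the async query to complete")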
23 changes: 23 additions & 0 deletions tests/e2e/test_driver.py
@@ -36,6 +36,7 @@
compare_dbr_versions,
is_thrift_v5_plus,
)
from databricks.sql.thrift_api.TCLIService import ttypes
from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
from tests.e2e.common.large_queries_mixin import LargeQueriesMixin
from tests.e2e.common.timestamp_tests import TimestampTestsMixin
@@ -78,6 +79,7 @@ class PySQLPytestTestCase:
}
arraysize = 1000
buffer_size_bytes = 104857600
POLLING_INTERVAL = 2

@pytest.fixture(autouse=True)
def get_details(self, connection_details):
@@ -175,6 +177,27 @@ def test_cloud_fetch(self):
for i in range(len(cf_result)):
assert cf_result[i] == noop_result[i]

def test_execute_async(self):
def isExecuting(operation_state):
return not operation_state or operation_state in [
ttypes.TOperationState.RUNNING_STATE,
ttypes.TOperationState.PENDING_STATE,
]

long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'"
with self.cursor() as cursor:
cursor.execute_async(long_running_query)

# Poll every POLLING_INTERVAL seconds until the query stops executing
while isExecuting(cursor.get_query_state()):
time.sleep(self.POLLING_INTERVAL)
log.info("Polling the status in test_execute_async")

cursor.get_async_execution_result()
result = cursor.fetchall()

assert result[0].asDict() == {"count(1)": 0}


# Exclude Retry tests because they require specific setups, and LargeQueries too slow for core
# tests
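The committed test covers only the happy path. Since Cursor.cancel() already operates on active_op_handle, a hypothetical companion test (a sketch, not part of this commit) could exercise cancelling an in-flight async query:

def test_execute_async_can_be_canceled(self):
    # Hypothetical sketch: cancel an in-flight async query, then poll
    # until the operation reaches a terminal state.
    terminal_states = {
        ttypes.TOperationState.FINISHED_STATE,
        ttypes.TOperationState.ERROR_STATE,
        ttypes.TOperationState.CANCELED_STATE,
        ttypes.TOperationState.CLOSED_STATE,
    }
    long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'"
    with self.cursor() as cursor:
        cursor.execute_async(long_running_query)
        cursor.cancel()
        while cursor.get_query_state() not in terminal_states:
            time.sleep(self.POLLING_INTERVAL)
        assert cursor.get_query_state() == ttypes.TOperationState.CANCELED_STATE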
