From d295b327281daad9154a45e8b8c8943f9bf2455e Mon Sep 17 00:00:00 2001 From: jingz-db Date: Tue, 24 Sep 2024 13:23:23 -0700 Subject: [PATCH] send with binary type --- python/pyspark/sql/pandas/group_ops.py | 49 ++++++++++------- .../stateful_processor_api_client.py | 53 ++++++++++--------- .../test_pandas_transform_with_state.py | 8 ++- ...ransformWithStateInPandasStateServer.scala | 8 +-- 4 files changed, 67 insertions(+), 51 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index f71943a802ccf..9cf44d85434bd 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -503,18 +503,8 @@ def transformWithStateUDF( statefulProcessorApiClient.set_implicit_key(key) - if timeMode != "none": - batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() - watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() - else: - batch_timestamp = -1 - watermark_timestamp = -1 - # process with invalid expiry timer info and emit data rows - data_iter = statefulProcessor.handleInputRows( - key, inputRows, TimerValues(batch_timestamp, watermark_timestamp), ExpiredTimerInfo(False)) - statefulProcessorApiClient.set_handle_state( - StatefulProcessorHandleState.DATA_PROCESSED - ) + batch_timestamp = statefulProcessorApiClient.get_batch_timestamp() + watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp() if timeMode == "processingtime": expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(batch_timestamp) @@ -523,17 +513,36 @@ def transformWithStateUDF( else: expiry_list_iter = [] + # process with invalid expiry timer info and emit data rows + data_iter = statefulProcessor.handleInputRows( + key, inputRows, TimerValues(batch_timestamp, watermark_timestamp), ExpiredTimerInfo(False)) + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.DATA_PROCESSED + ) + result_iter_list = [data_iter] + if len(expiry_list_iter) > 0: + raise Exception(f"i wonder key equals to row, key is: {key}, " + f"key type: {type(key)}" + f"row is: {expiry_list_iter[0][0]}, " + f"equals: {key[0] == expiry_list_iter[0][0]}") # process with valid expiry time info and with empty input rows, # only timer related rows will be emitted - for expiry_list in expiry_list_iter: - for key_obj, expiry_timestamp in expiry_list: - if (timeMode == "processingtime" and expiry_timestamp < batch_timestamp) or\ - (timeMode == "eventtime" and expiry_timestamp < watermark_timestamp): - result_iter_list.append(statefulProcessor.handleInputRows( - (key_obj,), iter([]), - TimerValues(batch_timestamp, watermark_timestamp), - ExpiredTimerInfo(True, expiry_timestamp))) + """ + if expiry_list_iter is not None: + for expiry_list in expiry_list_iter: + for key_obj, expiry_timestamp in expiry_list: + if timeMode == "processingtime" and expiry_timestamp < batch_timestamp: + result_iter_list.append(statefulProcessor.handleInputRows( + (key_obj,), iter([]), + TimerValues(batch_timestamp, watermark_timestamp), + ExpiredTimerInfo(True, expiry_timestamp))) + elif timeMode == "eventtime" and expiry_timestamp < watermark_timestamp: + result_iter_list.append(statefulProcessor.handleInputRows( + (key_obj,), iter([]), + TimerValues(batch_timestamp, watermark_timestamp), + ExpiredTimerInfo(True, expiry_timestamp))) + """ # TODO(SPARK-49603) set the handle state in the lazily initialized iterator diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 9e1ca6490eac6..223972412040b 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -185,33 +185,36 @@ def list_timers(self) -> Iterator[list[int]]: # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}") - def get_expiry_timers_iterator(self, expiry_timestamp: int) -> Iterator[list[Any, int]]: + def get_expiry_timers_iterator(self, expiry_timestamp: int) -> list[Any, int]: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage - while True: - expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp) - timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call) - message = stateMessage.StateRequest(timerRequest=timer_request) + # while True: + expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp) + timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call) + message = stateMessage.StateRequest(timerRequest=timer_request) - self._send_proto_message(message.SerializeToString()) - response_message = self._receive_proto_message() - status = response_message[0] - if status == 1: - break - elif status == 0: - iterator = self._read_arrow_state() - batch = next(iterator) - result_list = [] - key_fields = [field.name for field in self.key_schema.fields] - # TODO any better way to restore a grouping object from a batch? - batch_df = batch.to_pandas() - for i in range(batch.num_rows): - key = batch_df.at[i, 'key'].get(key_fields[0]) - timestamp = batch_df.at[i, 'timestamp'].item() - result_list.append((key, timestamp)) - yield result_list - else: - # TODO(SPARK-49233): Classify user facing errors. - raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}") + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status == 1: + # break + return [] + elif status == 0: + iterator = self._read_arrow_state() + batch = next(iterator) + result_list = [] + key_fields = [field.name for field in self.key_schema.fields] + # TODO any better way to restore a grouping object from a batch? + batch_df = batch.to_pandas() + for i in range(batch.num_rows): + d_k = self.pickleSer.loads(batch_df.at[i, 'key']) + # raise Exception(f"I am in expiry timestamp list, {d_k}") + timestamp = batch_df.at[i, 'timestamp'].item() + result_list.append((d_k, timestamp)) + # yield result_list + return result_list + else: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}") def get_batch_timestamp(self) -> int: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 77751fb2942d7..bc7adec0ef283 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -88,6 +88,7 @@ def _build_test_df(self, input_path): ) return df_final + """ def _test_transform_with_state_in_pandas_basic( self, stateful_processor, check_results, single_batch=False, timeMode="None" ): @@ -309,6 +310,7 @@ def check_results(batch_df, batch_id): finally: input_dir.cleanup() + """ def _test_transform_with_state_in_pandas_proc_timer( self, stateful_processor, check_results): input_path = tempfile.mkdtemp() @@ -366,6 +368,7 @@ def check_timestamp(batch_df): def check_results(batch_df, batch_id): if batch_id == 0: + print(f"batch_df here: {batch_df.show()}\n") assert set(batch_df.sort("id").select("id", "countAsString").collect()) == { Row(id="0", countAsString="1"), Row(id="1", countAsString="1"), @@ -396,7 +399,8 @@ def check_results(batch_df, batch_id): assert(current_batch_expired_timestamp > self.first_expired_timestamp) self._test_transform_with_state_in_pandas_proc_timer(ProcTimeStatefulProcessor(), check_results) - + """ + def _test_transform_with_state_in_pandas_event_time(self, stateful_processor, check_results): import pyspark.sql.functions as f @@ -478,7 +482,7 @@ def check_results(batch_df, batch_id): } self._test_transform_with_state_in_pandas_event_time(EventTimeStatefulProcessor(), check_results) - + """ # A stateful processor that output the max event time it has seen. Register timer for # current watermark. Clear max state if timer expires. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 63a4efccb0272..385a5769606f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateResponseWithLongTypeVal, StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest, ValueStateCall} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} -import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, LongType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils /** @@ -160,7 +160,7 @@ class TransformWithStateInPandasStateServer( val expiryTimestamp = expiryRequest.getExpiryTimestampMs if (!expiryTimestampIter.isDefined) { expiryTimestampIter = - Option(statefulProcessorHandle.getExpiredTimersWithKeyRow(expiryTimestamp)) + Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp)) } // expiryTimestampIter could be None in the TWSPandasServerSuite if (!expiryTimestampIter.isDefined || !expiryTimestampIter.get.hasNext) { @@ -169,10 +169,10 @@ class TransformWithStateInPandasStateServer( } else { sendResponse(0) val outputSchema = new StructType() - .add("key", groupingKeySchema) + .add("key", BinaryType) .add(StructField("timestamp", LongType)) sendIteratorAsArrowBatches(expiryTimestampIter.get, outputSchema) { data => - InternalRow(data._1, data._2) + InternalRow(PythonSQLUtils.toPyRow(data._1.asInstanceOf[Row]), data._2) } }