
send with binary type
jingz-db committed Sep 24, 2024
1 parent ccba830 commit d295b32
Showing 4 changed files with 67 additions and 51 deletions.
49 changes: 29 additions & 20 deletions python/pyspark/sql/pandas/group_ops.py
@@ -503,18 +503,8 @@ def transformWithStateUDF(

statefulProcessorApiClient.set_implicit_key(key)

if timeMode != "none":
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
else:
batch_timestamp = -1
watermark_timestamp = -1
# process with invalid expiry timer info and emit data rows
data_iter = statefulProcessor.handleInputRows(
key, inputRows, TimerValues(batch_timestamp, watermark_timestamp), ExpiredTimerInfo(False))
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.DATA_PROCESSED
)
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()

if timeMode == "processingtime":
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(batch_timestamp)
@@ -523,17 +513,36 @@ def transformWithStateUDF(
else:
expiry_list_iter = []

# process with invalid expiry timer info and emit data rows
data_iter = statefulProcessor.handleInputRows(
key, inputRows, TimerValues(batch_timestamp, watermark_timestamp), ExpiredTimerInfo(False))
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.DATA_PROCESSED
)

result_iter_list = [data_iter]
# process with valid expiry timer info and with empty input rows;
# only timer-related rows will be emitted.
# get_expiry_timers_iterator now returns a flat list of (key, timestamp)
# tuples, so a single loop suffices.
for key_obj, expiry_timestamp in expiry_list_iter:
    if (timeMode == "processingtime" and expiry_timestamp < batch_timestamp) or \
            (timeMode == "eventtime" and expiry_timestamp < watermark_timestamp):
        result_iter_list.append(statefulProcessor.handleInputRows(
            (key_obj,), iter([]),
            TimerValues(batch_timestamp, watermark_timestamp),
            ExpiredTimerInfo(True, expiry_timestamp)))

# TODO(SPARK-49603) set the handle state in the lazily initialized iterator

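Taken together, the new flow in transformWithStateUDF is: fetch the batch and watermark timestamps unconditionally, run the regular data pass, then replay each expired timer through the same handleInputRows entry point with an empty input iterator. A minimal sketch of that dispatch, assuming the TimerValues and ExpiredTimerInfo constructors used in the diff (the import path and the dispatch helper below are illustrative, not part of the patch):

# Assumed import path for the handle types; adjust to the actual module.
from pyspark.sql.streaming.stateful_processor import ExpiredTimerInfo, TimerValues

def dispatch(processor, key, input_rows, batch_ts, watermark_ts, expiry_list):
    timer_values = TimerValues(batch_ts, watermark_ts)
    # Data pass: regular input rows, no expired-timer info attached.
    yield from processor.handleInputRows(
        key, input_rows, timer_values, ExpiredTimerInfo(False))
    # Timer pass: one call per expired (key, timestamp) pair, with an empty
    # input iterator, so only timer-driven rows are emitted.
    for key_obj, expiry_ts in expiry_list:
        yield from processor.handleInputRows(
            (key_obj,), iter([]), timer_values, ExpiredTimerInfo(True, expiry_ts))
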
53 changes: 28 additions & 25 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -185,33 +185,36 @@ def list_timers(self) -> Iterator[list[int]]:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error getting expiry timers: {response_message[1]}")

def get_expiry_timers_iterator(self, expiry_timestamp: int) -> Iterator[list[Any, int]]:
def get_expiry_timers_iterator(self, expiry_timestamp: int) -> list[tuple[Any, int]]:
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
while True:
expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
message = stateMessage.StateRequest(timerRequest=timer_request)
expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
message = stateMessage.StateRequest(timerRequest=timer_request)

self._send_proto_message(message.SerializeToString())
response_message = self._receive_proto_message()
status = response_message[0]
if status == 1:
break
elif status == 0:
iterator = self._read_arrow_state()
batch = next(iterator)
result_list = []
key_fields = [field.name for field in self.key_schema.fields]
# TODO any better way to restore a grouping object from a batch?
batch_df = batch.to_pandas()
for i in range(batch.num_rows):
key = batch_df.at[i, 'key'].get(key_fields[0])
timestamp = batch_df.at[i, 'timestamp'].item()
result_list.append((key, timestamp))
yield result_list
else:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}")
self._send_proto_message(message.SerializeToString())
response_message = self._receive_proto_message()
status = response_message[0]
if status == 1:
return []
elif status == 0:
iterator = self._read_arrow_state()
batch = next(iterator)
result_list = []
key_fields = [field.name for field in self.key_schema.fields]
# TODO any better way to restore a grouping object from a batch?
batch_df = batch.to_pandas()
for i in range(batch.num_rows):
d_k = self.pickleSer.loads(batch_df.at[i, 'key'])
timestamp = batch_df.at[i, 'timestamp'].item()
result_list.append((d_k, timestamp))
return result_list
else:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error getting expiry timers: {response_message[1]}")

def get_batch_timestamp(self) -> int:
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
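The rewrite above also changes how the grouping key comes back from the JVM: instead of reassembling it from the key schema, the client now unpickles the bytes in the "key" column via self.pickleSer.loads(...). A round-trip sketch of that encoding, with plain pickle standing in for PySpark's serializer (the key value and timestamp below are made up for illustration):

import pickle

# Server side (conceptually): the grouping key is pickled into a BinaryType
# column, alongside a LongType expiry timestamp in milliseconds.
key_row = ("user-1",)                  # illustrative one-field grouping key
wire_value = pickle.dumps(key_row)     # contents of the "key" column
timestamp_ms = 1727000000000           # contents of the "timestamp" column

# Client side, mirroring get_expiry_timers_iterator above:
restored_key = pickle.loads(wire_value)
assert restored_key == key_row
print((restored_key, timestamp_ms))
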
@@ -88,6 +88,7 @@ def _build_test_df(self, input_path):
)
return df_final

"""
def _test_transform_with_state_in_pandas_basic(
self, stateful_processor, check_results, single_batch=False, timeMode="None"
):
@@ -309,6 +310,7 @@ def check_results(batch_df, batch_id):
finally:
input_dir.cleanup()
"""
def _test_transform_with_state_in_pandas_proc_timer(
self, stateful_processor, check_results):
input_path = tempfile.mkdtemp()
@@ -366,6 +368,7 @@ def check_timestamp(batch_df):

def check_results(batch_df, batch_id):
if batch_id == 0:
assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
Row(id="0", countAsString="1"),
Row(id="1", countAsString="1"),
@@ -396,7 +399,8 @@ def check_results(batch_df, batch_id):
assert(current_batch_expired_timestamp > self.first_expired_timestamp)

self._test_transform_with_state_in_pandas_proc_timer(ProcTimeStatefulProcessor(), check_results)

"""
def _test_transform_with_state_in_pandas_event_time(self, stateful_processor, check_results):
import pyspark.sql.functions as f
@@ -478,7 +482,7 @@ def check_results(batch_df, batch_id):
}
self._test_transform_with_state_in_pandas_event_time(EventTimeStatefulProcessor(), check_results)

"""

# A stateful processor that output the max event time it has seen. Register timer for
# current watermark. Clear max state if timer expires.
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState}
import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateResponseWithLongTypeVal, StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest, ValueStateCall}
import org.apache.spark.sql.streaming.{TTLConfig, ValueState}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.types.{BinaryType, LongType, StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils

/**
@@ -160,7 +160,7 @@ class TransformWithStateInPandasStateServer(
val expiryTimestamp = expiryRequest.getExpiryTimestampMs
if (!expiryTimestampIter.isDefined) {
expiryTimestampIter =
Option(statefulProcessorHandle.getExpiredTimersWithKeyRow(expiryTimestamp))
Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp))
}
// expiryTimestampIter could be None in the TWSPandasServerSuite
if (!expiryTimestampIter.isDefined || !expiryTimestampIter.get.hasNext) {
@@ -169,10 +169,10 @@
} else {
sendResponse(0)
val outputSchema = new StructType()
.add("key", groupingKeySchema)
.add("key", BinaryType)
.add(StructField("timestamp", LongType))
sendIteratorAsArrowBatches(expiryTimestampIter.get, outputSchema) { data =>
InternalRow(data._1, data._2)
InternalRow(PythonSQLUtils.toPyRow(data._1.asInstanceOf[Row]), data._2)
}
}

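With this change the expired-timer Arrow batches carry the pickled key bytes (produced by PythonSQLUtils.toPyRow) next to the expiry timestamp, rather than a struct matching groupingKeySchema. A sketch of the equivalent schema as the Python client sees it, assuming pyarrow (which the Arrow state reads already require):

import pyarrow as pa

# Per-row layout of expired-timer batches after this change:
# "key"       -> grouping-key Row pickled via PythonSQLUtils.toPyRow
# "timestamp" -> expiry time in milliseconds
expiry_schema = pa.schema([
    pa.field("key", pa.binary()),
    pa.field("timestamp", pa.int64()),
])
print(expiry_schema)
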
