fix: DIA-1584: Use send_and_wait + batches for output topic #245

Draft · wants to merge 4 commits into master

17 changes: 7 additions & 10 deletions adala/environments/kafka.py
@@ -55,9 +55,9 @@ async def initialize(self):
             self.kafka_input_topic,
             bootstrap_servers=self.kafka_bootstrap_servers,
             value_deserializer=lambda v: json.loads(v.decode("utf-8")),
-            enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
+            #enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
             auto_offset_reset="earliest",
-            group_id=self.kafka_input_topic, # ensuring unique group_id to not mix up offsets between topics
+            #group_id=self.kafka_input_topic, # ensuring unique group_id to not mix up offsets between topics
         )
         await self.consumer.start()
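For context, commenting out `enable_auto_commit` and `group_id` means the consumer falls back to aiokafka's defaults: `group_id=None` (no consumer group, so offsets are not committed to Kafka) and `enable_auto_commit=True`, which is moot without a group. A minimal sketch of the resulting configuration, with a placeholder topic and broker address:

```python
import json

from aiokafka import AIOKafkaConsumer

# Sketch only: "input_topic" and the broker address are placeholders.
consumer = AIOKafkaConsumer(
    "input_topic",
    bootstrap_servers="localhost:9092",
    value_deserializer=lambda v: json.loads(v.decode("utf-8")),
    auto_offset_reset="earliest",
    # Defaults now apply: group_id=None and enable_auto_commit=True,
    # so no consumer-group offsets are tracked across restarts.
)
```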

@@ -95,12 +95,9 @@ async def message_sender(
     ):
         record_no = 0
         try:
-            for record in data:
-                await producer.send(topic, value=record)
-                record_no += 1
-                # print_text(f"Sent message: {record} to {topic=}")
+            await producer.send_and_wait(topic, value=data)
             logger.info(
-                f"The number of records sent to topic:{topic}, record_no:{record_no}"
+                f"The number of records sent to topic:{topic}, record_no:{len(data)}"
             )
         finally:
             pass
@@ -110,7 +107,7 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
         batch = await self.consumer.getmany(
             timeout_ms=self.timeout_ms, max_records=batch_size
         )
-        await self.consumer.commit()
+        #await self.consumer.commit()

         if len(batch) == 0:
             batch_data = []
@@ -129,7 +126,7 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
         return InternalDataFrame(batch_data)

     async def set_predictions(self, predictions: InternalDataFrame):
-        predictions_iter = (r.to_dict() for _, r in predictions.iterrows())
+        predictions = [r.to_dict() for _, r in predictions.iterrows()]
         await self.message_sender(
-            self.producer, predictions_iter, self.kafka_output_topic
+            self.producer, predictions, self.kafka_output_topic
         )
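The behavioural change in `message_sender`: instead of fire-and-forget `producer.send()` calls per record (which can be dropped if the producer stops before its buffer is flushed), the whole batch is published as a single message and awaited until the broker acknowledges it. A rough sketch of the before/after pattern, assuming a JSON value serializer and placeholder topic, broker, and records (not the PR code itself):

```python
import asyncio
import json

from aiokafka import AIOKafkaProducer


async def send_batch(producer: AIOKafkaProducer, topic: str, records: list) -> None:
    # Old approach: one buffered send() per record, never awaited to completion.
    # for record in records:
    #     await producer.send(topic, value=record)

    # New approach: publish the whole batch as one message and wait for the
    # broker acknowledgement before returning.
    await producer.send_and_wait(topic, value=records)


async def main() -> None:
    producer = AIOKafkaProducer(
        bootstrap_servers="localhost:9092",  # placeholder broker
        value_serializer=lambda v: json.dumps(v).encode("utf-8"),
    )
    await producer.start()
    try:
        await send_batch(producer, "output_topic", [{"id": 1}, {"id": 2}])
    finally:
        await producer.stop()


if __name__ == "__main__":
    asyncio.run(main())
```

The flip side is that each consumed message's value is now a list of records rather than a single record, which is what the stream_inference changes below account for.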
71 changes: 36 additions & 35 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ gspread = "^5.12.3"
 datasets = "^2.16.1"
 aiohttp = "^3.9.3"
 boto3 = "^1.34.38"
-aiokafka = "^0.10.0"
+aiokafka = "^0.11.0"
 # these are for the server
 # they would be installed as `extras` if poetry supported version strings for extras, but it doesn't
 # https://github.com/python-poetry/poetry/issues/834
20 changes: 12 additions & 8 deletions server/tasks/stream_inference.py
@@ -71,7 +71,7 @@ async def run_streaming(
     task_time_limit=settings.task_time_limit_sec,
 )
 def streaming_parent_task(
-    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 10
+    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 1
 ):
     """
     This task is used to launch the two tasks that are doing the real work, so that
@@ -140,9 +140,9 @@ async def async_process_streaming_output(
         output_topic_name,
         bootstrap_servers=settings.kafka_bootstrap_servers,
         value_deserializer=lambda v: json.loads(v.decode("utf-8")),
-        enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
+        #enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
         auto_offset_reset="earliest",
-        group_id=output_topic_name, # ensuring unique group_id to not mix up offsets between topics
+        #group_id=output_topic_name, # ensuring unique group_id to not mix up offsets between topics
     )
     await consumer.start()
     logger.info(f"consumer started {output_topic_name=}")
@@ -158,14 +158,18 @@ async def async_process_streaming_output(
     try:
         while not input_done.is_set():
             data = await consumer.getmany(timeout_ms=timeout_ms, max_records=batch_size)
-            await consumer.commit()
+            #await consumer.commit()
             for topic_partition, messages in data.items():
                 topic = topic_partition.topic
                 # messages is a list of ConsumerRecord
                 if messages:
-                    logger.info(f"Processing messages in output job {topic=} number of messages: {len(messages)}")
-                    data = [msg.value for msg in messages]
-                    result_handler(data)
-                    logger.info(f"Processed messages in output job {topic=} number of messages: {len(messages)}")
+                    # batches is a list of lists
+                    batches = [msg.value for msg in messages]
+                    # records is a list of records to send to LSE
+                    for records in batches:
+                        logger.info(f"Processing messages in output job {topic=} number of messages: {len(records)}")
+                        result_handler(records)
+                        logger.info(f"Processed messages in output job {topic=} number of messages: {len(records)}")
                 else:
                     logger.info(f"Consumer pulled data, but no messages in {topic=}")
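Because every message value is now a full batch (a JSON list), the output consumer unpacks one extra level: each `msg.value` is itself a list of records handed to the result handler. An illustrative sketch of that unpacking with a stand-in handler (names here are placeholders, not the task's actual code):

```python
from types import SimpleNamespace
from typing import Any, Dict, List


def result_handler(records: List[Dict[str, Any]]) -> None:
    # Stand-in for the real handler that forwards a batch of records onward.
    print(f"handling {len(records)} records")


def process_getmany_result(data: dict) -> None:
    # `data` mirrors the mapping returned by AIOKafkaConsumer.getmany():
    # {TopicPartition: [ConsumerRecord, ...]}
    for topic_partition, messages in data.items():
        # Each ConsumerRecord.value deserializes to a list of records (one batch).
        batches = [msg.value for msg in messages]
        for records in batches:
            result_handler(records)


if __name__ == "__main__":
    fake_message = SimpleNamespace(value=[{"task_id": 1}, {"task_id": 2}])
    process_getmany_result({("output_topic", 0): [fake_message]})
```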
9 changes: 4 additions & 5 deletions tests/test_stream_inference.py
@@ -89,7 +89,7 @@ async def getmany_side_effect(*args, **kwargs):
     await PRODUCER_SENT_DATA.wait()
     return {
         AsyncMock(topic="output_topic_partition"): [
-            AsyncMock(value=row) for row in TEST_OUTPUT_DATA
+            AsyncMock(value=TEST_OUTPUT_DATA)
         ]
     }
@@ -159,11 +159,10 @@ async def test_run_streaming(
     await run_streaming(
         agent=agent,
         result_handler=result_handler,
-        batch_size=10,
+        batch_size=1,
         output_topic_name="output_topic",
     )

     # Verify that producer is called with the correct amount of send_and_wait calls and data
-    assert mock_kafka_producer.send.call_count == 1
-    for row in TEST_OUTPUT_DATA:
-        mock_kafka_producer.send.assert_any_call("output_topic", value=row)
+    assert mock_kafka_producer.send_and_wait.call_count == 1
+    mock_kafka_producer.send_and_wait.assert_any_call("output_topic", value=TEST_OUTPUT_DATA)
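With the producer mocked out, the updated assertions reduce to: exactly one `send_and_wait` call, carrying the entire batch as its value. A self-contained illustration of that pattern using `unittest.mock.AsyncMock` (the names below are placeholders, not this module's fixtures):

```python
import asyncio
from unittest.mock import AsyncMock

TEST_OUTPUT_DATA = [{"task_id": 1, "output": "a"}, {"task_id": 2, "output": "b"}]


async def produce(producer) -> None:
    # Stand-in for the code under test: publish the whole batch in one call.
    await producer.send_and_wait("output_topic", value=TEST_OUTPUT_DATA)


def main() -> None:
    mock_kafka_producer = AsyncMock()
    asyncio.run(produce(mock_kafka_producer))

    # One send_and_wait call, and it carried the full batch as its value.
    assert mock_kafka_producer.send_and_wait.call_count == 1
    mock_kafka_producer.send_and_wait.assert_any_call("output_topic", value=TEST_OUTPUT_DATA)


if __name__ == "__main__":
    main()
```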