fix: DIA-1556: Not getting predictions for all tasks (#240)
Hakan Erol authored Oct 30, 2024
1 parent c34fd15 commit 3665a80
Showing 4 changed files with 20 additions and 5 deletions.
7 changes: 5 additions & 2 deletions adala/environments/kafka.py
@@ -55,14 +55,16 @@ async def initialize(self):
             self.kafka_input_topic,
             bootstrap_servers=self.kafka_bootstrap_servers,
             value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+            enable_auto_commit=False,  # True by default which causes messages to be missed when using getmany()
             auto_offset_reset="earliest",
-            group_id="adala-consumer-group",  # TODO: make it configurable based on the environment
+            group_id=self.kafka_input_topic,  # ensuring unique group_id to not mix up offsets between topics
         )
         await self.consumer.start()

         self.producer = AIOKafkaProducer(
             bootstrap_servers=self.kafka_bootstrap_servers,
             value_serializer=lambda v: json.dumps(v).encode("utf-8"),
+            acks='all'  # waits for all replicas to respond that they have written the message
         )
         await self.producer.start()
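Taken together, these three settings form a common aiokafka recipe: turn off auto-commit so offsets only advance when the consumer says so, give each topic its own consumer group so offsets are never shared across topics, and have the producer wait for full replication before treating a message as sent. A minimal, self-contained sketch of the same setup (topic name and bootstrap address are placeholders, not values from this repo):

    import json
    from aiokafka import AIOKafkaConsumer, AIOKafkaProducer

    async def make_clients(topic: str, bootstrap: str = "localhost:9092"):
        consumer = AIOKafkaConsumer(
            topic,
            bootstrap_servers=bootstrap,
            value_deserializer=lambda v: json.loads(v.decode("utf-8")),
            enable_auto_commit=False,   # offsets advance only on explicit commit()
            auto_offset_reset="earliest",
            group_id=topic,             # one consumer group per topic
        )
        producer = AIOKafkaProducer(
            bootstrap_servers=bootstrap,
            value_serializer=lambda v: json.dumps(v).encode("utf-8"),
            acks="all",                 # broker replies only after all in-sync replicas have the record
        )
        await consumer.start()
        await producer.start()
        return consumer, producer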
@@ -94,7 +96,7 @@ async def message_sender(
         record_no = 0
         try:
             for record in data:
-                await producer.send_and_wait(topic, value=record)
+                await producer.send(topic, value=record)
                 record_no += 1
                 # print_text(f"Sent message: {record} to {topic=}")
                 logger.info(
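The switch from send_and_wait() to send() changes the sending model rather than the destination: send_and_wait() blocks the loop for a full broker round trip per record, while send() hands the record to the producer's internal batcher and returns a delivery future immediately, so records are pipelined. A sketch of the two patterns, assuming an already-started AIOKafkaProducer and a placeholder topic name:

    import asyncio
    from aiokafka import AIOKafkaProducer

    async def send_all(producer: AIOKafkaProducer, records: list[dict]) -> None:
        # Before (one awaited round trip per record):
        #     for record in records:
        #         await producer.send_and_wait("topic", value=record)

        # After: enqueue everything; each send() returns a delivery future.
        futures = [await producer.send("topic", value=record) for record in records]
        await asyncio.gather(*futures)  # optional: wait for broker confirmations
        # producer.flush() would likewise drain anything still pending.

With acks='all' now set on the producer itself, durability is enforced by the client configuration even though individual sends are no longer awaited one by one.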
@@ -108,6 +110,7 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
         batch = await self.consumer.getmany(
             timeout_ms=self.timeout_ms, max_records=batch_size
         )
+        await self.consumer.commit()

         if len(batch) == 0:
             batch_data = []
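This commit() is the counterpart of enable_auto_commit=False above: offsets for the batch returned by getmany() are advanced explicitly, immediately after the fetch, instead of by an auto-commit timer firing at arbitrary points between polls. A minimal consumption loop in the same style (names and timeout values are illustrative):

    from aiokafka import AIOKafkaConsumer

    async def consume_batches(consumer: AIOKafkaConsumer, batch_size: int = 50):
        while True:
            # getmany() returns {TopicPartition: [ConsumerRecord, ...]}
            batch = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
            await consumer.commit()  # mark everything just fetched as consumed
            for tp, messages in batch.items():
                for msg in messages:
                    print(tp.topic, msg.value)  # stand-in for real processing

Note the design choice: committing before processing gives at-most-once delivery (a crash after the commit loses the in-flight batch), whereas committing after processing would give at-least-once delivery with possible duplicates.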
1 change: 1 addition & 0 deletions server/app.py
@@ -190,6 +190,7 @@ async def submit_batch(batch: BatchData):
     producer = AIOKafkaProducer(
         bootstrap_servers=settings.kafka_bootstrap_servers,
         value_serializer=lambda v: json.dumps(v).encode("utf-8"),
+        acks='all'  # waits for all replicas to respond that they have written the message
     )
     await producer.start()
5 changes: 4 additions & 1 deletion server/tasks/stream_inference.py
@@ -71,7 +71,7 @@ async def run_streaming(
     task_time_limit=settings.task_time_limit_sec,
 )
 def streaming_parent_task(
-    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 10
+    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 50
 ):
     """
     This task is used to launch the two tasks that are doing the real work, so that
@@ -140,7 +140,9 @@ async def async_process_streaming_output(
         output_topic_name,
         bootstrap_servers=settings.kafka_bootstrap_servers,
         value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+        enable_auto_commit=False,  # True by default which causes messages to be missed when using getmany()
         auto_offset_reset="earliest",
+        group_id=output_topic_name,  # ensuring unique group_id to not mix up offsets between topics
     )
     await consumer.start()
     logger.info(f"consumer started {output_topic_name=}")
@@ -156,6 +158,7 @@ async def async_process_streaming_output(
     try:
         while not input_done.is_set():
             data = await consumer.getmany(timeout_ms=timeout_ms, max_records=batch_size)
+            await consumer.commit()
             for topic_partition, messages in data.items():
                 topic = topic_partition.topic
                 if messages:
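For orientation: per the docstring above, streaming_parent_task launches two tasks that do the real work, and the input_done event in this loop is what lets the output consumer exit once the input side has finished producing. A stripped-down sketch of that coordination (both worker bodies are placeholders, not code from this repo):

    import asyncio

    async def parent():
        input_done = asyncio.Event()

        async def input_task():
            await asyncio.sleep(0.1)  # stand-in for sending all records to the input topic
            input_done.set()          # signal the output loop to wind down

        async def output_task():
            while not input_done.is_set():
                await asyncio.sleep(0.1)  # stand-in for consumer.getmany() + commit() + handling

        await asyncio.gather(input_task(), output_task())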
12 changes: 10 additions & 2 deletions tests/test_stream_inference.py
@@ -67,6 +67,7 @@ def mock_kafka_consumer_input():
     mock_consumer = MockConsumer.return_value
     mock_consumer.start = AsyncMock()
     mock_consumer.stop = AsyncMock()
+    mock_consumer.commit = AsyncMock()
     mock_consumer.getmany = AsyncMock(
         side_effect=[
             # first call return batch
@@ -96,6 +97,7 @@ async def getmany_side_effect(*args, **kwargs):
     mock_consumer = MockConsumer.return_value
     mock_consumer.start = AsyncMock()
     mock_consumer.stop = AsyncMock()
+    mock_consumer.commit = AsyncMock()
     mock_consumer.getmany = AsyncMock(side_effect=getmany_side_effect)
     yield mock_consumer

@@ -111,7 +113,7 @@ async def send_and_wait_side_effect(*args, **kwargs):
         PRODUCER_SENT_DATA.set()
         return AsyncMock()

+    async def send_side_effect(*args, **kwargs):
+        PRODUCER_SENT_DATA.set()
+        return AsyncMock()
+
     mock_producer.send_and_wait = AsyncMock(side_effect=send_and_wait_side_effect)
+    mock_producer.send = AsyncMock(side_effect=send_side_effect)
+
     yield mock_producer


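Because the production code now calls producer.send(), the fixture mocks both methods: send_and_wait stays mocked for any remaining callers, and send is added so the assertions below can observe the new call path. For reference, a sketch of how such a fixture is typically wired with unittest.mock (the patch target below is illustrative; patch AIOKafkaProducer wherever the code under test imports it):

    import pytest
    from unittest.mock import AsyncMock, patch

    @pytest.fixture
    def mock_kafka_producer():
        with patch("adala.environments.kafka.AIOKafkaProducer") as MockProducer:
            mock_producer = MockProducer.return_value
            mock_producer.start = AsyncMock()
            mock_producer.stop = AsyncMock()
            mock_producer.send = AsyncMock()
            yield mock_producer

AsyncMock matters here: plain Mock attributes are not awaitable, so `await producer.send(...)` in the code under test would fail.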
@@ -156,6 +164,6 @@ async def test_run_streaming(
     )

     # Verify that producer is called with the correct amount of send_and_wait calls and data
-    assert mock_kafka_producer.send_and_wait.call_count == 1
+    assert mock_kafka_producer.send.call_count == 1
     for row in TEST_OUTPUT_DATA:
-        mock_kafka_producer.send_and_wait.assert_any_call("output_topic", value=row)
+        mock_kafka_producer.send.assert_any_call("output_topic", value=row)
