From 3c13f1ce206907dcd2a1a24edd4733473396530b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Casaj=C3=BAs?= Date: Thu, 17 Oct 2024 11:16:33 +0200 Subject: [PATCH] Have the dead letter also take events to avoid race conditions (#2267) * Have the dead letter also take events to avoid race conditions * Ensure we take the event * Tests for event taken * Rename --- app/models.py | 15 +++-- events/event_source.py | 32 +++++----- oneshot/send_plan_change_events.py | 63 +++++++++++++++++++ tests/events/test_dead_letter_event_source.py | 54 ++++++++++++++++ tests/events/test_sent_events.py | 31 ++++++++- 5 files changed, 172 insertions(+), 23 deletions(-) create mode 100644 oneshot/send_plan_change_events.py create mode 100644 tests/events/test_dead_letter_event_source.py diff --git a/app/models.py b/app/models.py index 5e3c55080..536f99042 100644 --- a/app/models.py +++ b/app/models.py @@ -3771,15 +3771,14 @@ class SyncEvent(Base, ModelMixin): sa.Index("ix_sync_event_taken_time", "taken_time"), ) - def mark_as_taken(self) -> bool: - sql = """ - UPDATE sync_event - SET taken_time = :taken_time - WHERE id = :sync_event_id - AND taken_time IS NULL - """ + def mark_as_taken(self, allow_taken_older_than: Optional[Arrow] = None) -> bool: + taken_condition = ["taken_time IS NULL"] args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id} - + if allow_taken_older_than: + taken_condition.append("taken_time < :taken_older_than") + args["taken_older_than"] = allow_taken_older_than.datetime + sql_taken_condition = "({})".format(" OR ".join(taken_condition)) + sql = f"UPDATE sync_event SET taken_time = :taken_time WHERE id = :sync_event_id AND {sql_taken_condition}" res = Session.execute(sql, args) Session.commit() diff --git a/events/event_source.py b/events/event_source.py index 3633ec742..2b5920c3f 100644 --- a/events/event_source.py +++ b/events/event_source.py @@ -85,24 +85,28 @@ class DeadLetterEventSource(EventSource): def __init__(self, max_retries: int): self.__max_retries = max_retries + def execute_loop( + self, on_event: Callable[[SyncEvent], NoReturn] + ) -> list[SyncEvent]: + threshold = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES) + events = SyncEvent.get_dead_letter( + older_than=threshold, max_retries=self.__max_retries + ) + if events: + LOG.info(f"Got {len(events)} dead letter events") + newrelic.agent.record_custom_metric( + "Custom/dead_letter_events_to_process", len(events) + ) + for event in events: + if event.mark_as_taken(allow_taken_older_than=threshold): + on_event(event) + return events + @newrelic.agent.background_task() def run(self, on_event: Callable[[SyncEvent], NoReturn]): while True: try: - threshold = arrow.utcnow().shift( - minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - ) - events = SyncEvent.get_dead_letter( - older_than=threshold, max_retries=self.__max_retries - ) - if events: - LOG.info(f"Got {len(events)} dead letter events") - if events: - newrelic.agent.record_custom_metric( - "Custom/dead_letter_events_to_process", len(events) - ) - for event in events: - on_event(event) + events = self.execute_loop(on_event) Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock if not events: LOG.debug("No dead letter events") diff --git a/oneshot/send_plan_change_events.py b/oneshot/send_plan_change_events.py new file mode 100644 index 000000000..4fcaad2dd --- /dev/null +++ b/oneshot/send_plan_change_events.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +import argparse +import time + +from sqlalchemy import func + +from app.events.event_dispatcher import EventDispatcher +from app.events.generated.event_pb2 import UserPlanChanged, EventContent +from app.models import PartnerUser +from app.db import Session + +parser = argparse.ArgumentParser( + prog="Backfill alias", description="Update alias notes and backfill flag" +) +parser.add_argument( + "-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id" +) +parser.add_argument( + "-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id" +) + +args = parser.parse_args() +pu_id_start = args.start_pu_id +max_pu_id = args.end_pu_id +if max_pu_id == 0: + max_pu_id = Session.query(func.max(PartnerUser.id)).scalar() + +print(f"Checking partner user {pu_id_start} to {max_pu_id}") +step = 100 +updated = 0 +start_time = time.time() +with_premium = 0 +for batch_start in range(pu_id_start, max_pu_id, step): + partner_users = ( + Session.query(PartnerUser).filter( + PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step + ) + ).all() + for partner_user in partner_users: + subscription_end = partner_user.user.get_active_subscription_end( + include_partner_subscription=False + ) + end_timestamp = None + if subscription_end: + with_premium += 1 + end_timestamp = subscription_end.timestamp + event = UserPlanChanged(plan_end_time=end_timestamp) + EventDispatcher.send_event( + partner_user.user, EventContent(user_plan_change=event) + ) + Session.flush() + updated += 1 + Session.commit() + elapsed = time.time() - start_time + last_batch_id = batch_start + step + time_per_alias = elapsed / (last_batch_id) + remaining = max_pu_id - last_batch_id + time_remaining = remaining / time_per_alias + hours_remaining = time_remaining / 60.0 + print( + f"\PartnerUser {batch_start}/{max_pu_id} {updated} {hours_remaining:.2f} mins remaining" + ) +print(f"With SL premium {with_premium}") diff --git a/tests/events/test_dead_letter_event_source.py b/tests/events/test_dead_letter_event_source.py new file mode 100644 index 000000000..1186d59af --- /dev/null +++ b/tests/events/test_dead_letter_event_source.py @@ -0,0 +1,54 @@ +import arrow + +from app.db import Session +from app.models import SyncEvent +from events.event_source import DeadLetterEventSource, _DEAD_LETTER_THRESHOLD_MINUTES + + +class EventCounter: + def __init__(self): + self.processed_events = 0 + + def on_event(self, event: SyncEvent): + self.processed_events += 1 + + +def setup_function(func): + Session.query(SyncEvent).delete() + + +def test_dead_letter_does_not_take_untaken_events(): + source = DeadLetterEventSource(1) + counter = EventCounter() + threshold_time = arrow.utcnow().shift(minutes=-(_DEAD_LETTER_THRESHOLD_MINUTES) + 1) + SyncEvent.create( + content="test".encode("utf-8"), created_at=threshold_time, flush=True + ) + SyncEvent.create( + content="test".encode("utf-8"), taken_time=threshold_time, flush=True + ) + events_processed = source.execute_loop(on_event=counter.on_event) + assert len(events_processed) == 0 + assert counter.processed_events == 0 + + +def test_dead_letter_takes_untaken_events_created_older_than_threshold(): + source = DeadLetterEventSource(1) + counter = EventCounter() + old_create = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1) + SyncEvent.create(content="test".encode("utf-8"), created_at=old_create, flush=True) + events_processed = source.execute_loop(on_event=counter.on_event) + assert len(events_processed) == 1 + assert events_processed[0].taken_time > old_create + assert counter.processed_events == 1 + + +def test_dead_letter_takes_taken_events_created_older_than_threshold(): + source = DeadLetterEventSource(1) + counter = EventCounter() + old_taken = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1) + SyncEvent.create(content="test".encode("utf-8"), taken_time=old_taken, flush=True) + events_processed = source.execute_loop(on_event=counter.on_event) + assert len(events_processed) == 1 + assert events_processed[0].taken_time > old_taken + assert counter.processed_events == 1 diff --git a/tests/events/test_sent_events.py b/tests/events/test_sent_events.py index 805abdf91..367cdf32c 100644 --- a/tests/events/test_sent_events.py +++ b/tests/events/test_sent_events.py @@ -1,7 +1,9 @@ +import arrow + from app import config, alias_utils from app.db import Session from app.events.event_dispatcher import GlobalDispatcher -from app.models import Alias +from app.models import Alias, SyncEvent from tests.utils import random_token from .event_test_utils import ( OnMemoryDispatcher, @@ -26,6 +28,33 @@ def setup_function(func): on_memory_dispatcher.clear() +def test_event_taken_updates(): + event = SyncEvent.create(content="test".encode("utf-8"), flush=True) + assert event.taken_time is None + assert event.mark_as_taken() + assert event.taken_time is not None + + +def test_event_mark_as_taken_does_nothing_for_taken_events(): + now = arrow.utcnow() + event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True) + assert not event.mark_as_taken() + + +def test_event_mark_as_taken_does_nothing_for_not_before_events(): + now = arrow.utcnow() + event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True) + older_than = now.shift(minutes=-1) + assert not event.mark_as_taken(allow_taken_older_than=older_than) + + +def test_event_mark_as_taken_works_for_before_events(): + now = arrow.utcnow() + event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True) + older_than = now.shift(minutes=+1) + assert event.mark_as_taken(allow_taken_older_than=older_than) + + def test_fire_event_on_alias_creation(): (user, pu) = _create_linked_user() alias = Alias.create_new_random(user)