From d6ca865f4c620f4c5100c4c27ef482288a1b4cb1 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Fri, 10 Jan 2025 14:14:45 -0800 Subject: [PATCH] Fix SelectConsumer reconnection with unit test coverage --- .../consumers/select_consumer.py | 19 ++++--- tests/test_consumers.py | 49 +++++++++++++++++-- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/neon_mq_connector/consumers/select_consumer.py b/neon_mq_connector/consumers/select_consumer.py index c18f782..bd77e46 100644 --- a/neon_mq_connector/consumers/select_consumer.py +++ b/neon_mq_connector/consumers/select_consumer.py @@ -30,7 +30,7 @@ import threading import time -from asyncio import Event, run +from asyncio import Event, Lock, run from typing import Optional import pika.exceptions @@ -96,6 +96,7 @@ def __init__(self, self.exchange_reset = exchange_reset self.connection: Optional[pika.SelectConnection] = None + self._connection_lock = Lock() self.connection_failed_attempts = 0 self.max_connection_failed_attempts = 3 @@ -185,10 +186,14 @@ def on_message(self, channel, method, properties, body): self.error_func(self, e) def on_close(self, _, e): + self._consumer_started.clear() if isinstance(e, pika.exceptions.ConnectionClosed): LOG.info(f"Connection closed normally: {e}") - if not self._stopping: + else: LOG.error(f"Closing MQ connection due to exception: {e}") + if not self._stopping: + # Connection was gracefully closed by the server. Try to re-connect + LOG.info(f"Trying to reconnect after server closed connection") self.reconnect() @property @@ -200,10 +205,9 @@ def is_consuming(self) -> bool: return self._consumer_started.is_set() def run(self): - """Starting connnection io loop """ + """Starting connection io loop """ if not self.is_consuming: try: - super(SelectConsumerThread, self).run() self.connection: pika.SelectConnection = self.create_connection() self.connection.ioloop.start() except (pika.exceptions.ChannelClosed, @@ -217,6 +221,8 @@ def run(self): LOG.error(f"Failed to start io loop on consumer thread {self.name!r}: {e}") self._close_connection() self.error_func(self, e) + else: + LOG.warning("Consumer already running!") def _close_connection(self, mark_consumer_as_dead: bool = True): try: @@ -230,7 +236,7 @@ def _close_connection(self, mark_consumer_as_dead: bool = True): LOG.info(f"Channel closed") if self.connection: self.connection.ioloop.stop() - self.connection = None + # self.connection = None except Exception as e: LOG.error(f"Failed to close connection for Consumer {self.name!r}: {e}") self._is_consuming = False @@ -240,8 +246,9 @@ def _close_connection(self, mark_consumer_as_dead: bool = True): else: self._stopping = False - def reconnect(self, wait_interval: int = 1): + def reconnect(self, wait_interval: int = 5): self._close_connection(mark_consumer_as_dead=False) + # TODO: Find a better way to wait for shutdown/server restart time.sleep(wait_interval) self.run() diff --git a/tests/test_consumers.py b/tests/test_consumers.py index c034e47..8f38eb8 100644 --- a/tests/test_consumers.py +++ b/tests/test_consumers.py @@ -25,14 +25,12 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import asyncio -from time import sleep -from unittest.mock import Mock import pytest +from time import sleep +from unittest.mock import Mock from unittest import TestCase - from pika.connection import ConnectionParameters from pika.credentials import PlainCredentials from pika.exchange_type import ExchangeType @@ -171,3 +169,46 @@ def test_select_consumer_thread(self): self.assertFalse(test_thread.is_consuming) self.assertFalse(test_thread.is_consumer_alive) test_thread.on_close.assert_not_called() + + def test_handle_reconnection(self): + from neon_mq_connector.consumers.select_consumer import SelectConsumerThread + connection_params = ConnectionParameters(host='localhost', + port=self.rmq_instance.port, + virtual_host="/neon_testing", + credentials=PlainCredentials( + "test_user", + "test_password")) + queue = "test_q" + callback = Mock() + error = Mock() + + # Valid thread + test_thread = SelectConsumerThread(connection_params, queue, callback, + error) + test_thread.on_connected = Mock(side_effect=test_thread.on_connected) + test_thread.on_channel_open = Mock(side_effect=test_thread.on_channel_open) + test_thread.on_close = Mock(side_effect=test_thread.on_close) + + test_thread.start() + while not test_thread.is_consuming: + sleep(0.1) + + test_thread.on_connected.assert_called_once() + test_thread.on_channel_open.assert_called_once() + test_thread.on_close.assert_not_called() + + self.rmq_instance.stop() + test_thread.on_close.assert_called_once() + self.assertFalse(test_thread.is_consuming) + self.assertTrue(test_thread.is_consumer_alive) + + self.rmq_instance.start() + # TODO: Wait for re-connection + while not test_thread.is_consuming: + sleep(0.1) + self.assertTrue(test_thread.is_consuming) + self.assertTrue(test_thread.is_consumer_alive) + + test_thread.join(30) + self.assertFalse(test_thread.is_consuming) + self.assertFalse(test_thread.is_consumer_alive) \ No newline at end of file