Skip to content

Commit

Permalink
Fix SelectConsumer reconnection with unit test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Jan 10, 2025
1 parent 989429a commit d6ca865
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
19 changes: 13 additions & 6 deletions neon_mq_connector/consumers/select_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import threading
import time

from asyncio import Event, run
from asyncio import Event, Lock, run
from typing import Optional

import pika.exceptions
Expand Down Expand Up @@ -96,6 +96,7 @@ def __init__(self,
self.exchange_reset = exchange_reset

self.connection: Optional[pika.SelectConnection] = None
self._connection_lock = Lock()
self.connection_failed_attempts = 0
self.max_connection_failed_attempts = 3

Expand Down Expand Up @@ -185,10 +186,14 @@ def on_message(self, channel, method, properties, body):
self.error_func(self, e)

def on_close(self, _, e):
self._consumer_started.clear()
if isinstance(e, pika.exceptions.ConnectionClosed):
LOG.info(f"Connection closed normally: {e}")
if not self._stopping:
else:
LOG.error(f"Closing MQ connection due to exception: {e}")
if not self._stopping:
# Connection was gracefully closed by the server. Try to re-connect
LOG.info(f"Trying to reconnect after server closed connection")
self.reconnect()

@property
Expand All @@ -200,10 +205,9 @@ def is_consuming(self) -> bool:
return self._consumer_started.is_set()

def run(self):
"""Starting connnection io loop """
"""Starting connection io loop """
if not self.is_consuming:
try:
super(SelectConsumerThread, self).run()
self.connection: pika.SelectConnection = self.create_connection()
self.connection.ioloop.start()
except (pika.exceptions.ChannelClosed,
Expand All @@ -217,6 +221,8 @@ def run(self):
LOG.error(f"Failed to start io loop on consumer thread {self.name!r}: {e}")
self._close_connection()
self.error_func(self, e)
else:
LOG.warning("Consumer already running!")

def _close_connection(self, mark_consumer_as_dead: bool = True):
try:
Expand All @@ -230,7 +236,7 @@ def _close_connection(self, mark_consumer_as_dead: bool = True):
LOG.info(f"Channel closed")
if self.connection:
self.connection.ioloop.stop()
self.connection = None
# self.connection = None
except Exception as e:
LOG.error(f"Failed to close connection for Consumer {self.name!r}: {e}")
self._is_consuming = False
Expand All @@ -240,8 +246,9 @@ def _close_connection(self, mark_consumer_as_dead: bool = True):
else:
self._stopping = False

def reconnect(self, wait_interval: int = 1):
def reconnect(self, wait_interval: int = 5):
self._close_connection(mark_consumer_as_dead=False)
# TODO: Find a better way to wait for shutdown/server restart
time.sleep(wait_interval)
self.run()

Expand Down
49 changes: 45 additions & 4 deletions tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
from time import sleep
from unittest.mock import Mock

import pytest

from time import sleep
from unittest.mock import Mock
from unittest import TestCase

from pika.connection import ConnectionParameters
from pika.credentials import PlainCredentials
from pika.exchange_type import ExchangeType
Expand Down Expand Up @@ -171,3 +169,46 @@ def test_select_consumer_thread(self):
self.assertFalse(test_thread.is_consuming)
self.assertFalse(test_thread.is_consumer_alive)
test_thread.on_close.assert_not_called()

def test_handle_reconnection(self):
from neon_mq_connector.consumers.select_consumer import SelectConsumerThread
connection_params = ConnectionParameters(host='localhost',
port=self.rmq_instance.port,
virtual_host="/neon_testing",
credentials=PlainCredentials(
"test_user",
"test_password"))
queue = "test_q"
callback = Mock()
error = Mock()

# Valid thread
test_thread = SelectConsumerThread(connection_params, queue, callback,
error)
test_thread.on_connected = Mock(side_effect=test_thread.on_connected)
test_thread.on_channel_open = Mock(side_effect=test_thread.on_channel_open)
test_thread.on_close = Mock(side_effect=test_thread.on_close)

test_thread.start()
while not test_thread.is_consuming:
sleep(0.1)

test_thread.on_connected.assert_called_once()
test_thread.on_channel_open.assert_called_once()
test_thread.on_close.assert_not_called()

self.rmq_instance.stop()
test_thread.on_close.assert_called_once()
self.assertFalse(test_thread.is_consuming)
self.assertTrue(test_thread.is_consumer_alive)

self.rmq_instance.start()
# TODO: Wait for re-connection
while not test_thread.is_consuming:
sleep(0.1)
self.assertTrue(test_thread.is_consuming)
self.assertTrue(test_thread.is_consumer_alive)

test_thread.join(30)
self.assertFalse(test_thread.is_consuming)
self.assertFalse(test_thread.is_consumer_alive)

0 comments on commit d6ca865

Please sign in to comment.