diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6237716e..ecec1b3e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,6 +14,17 @@ jobs: build: runs-on: ubuntu-latest name: Python${{ matrix.python-version }}/Django${{ matrix.django-version }} + + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: django_rq_test_db + ports: + - 5432:5432 + strategy: matrix: python-version: ["3.10", "3.11", "3.12"] @@ -34,8 +45,8 @@ jobs: run: | python -m pip install --upgrade pip pip install django==${{ matrix.django-version }} - pip install redis django-redis rq sentry-sdk rq-scheduler + pip install redis django-redis rq sentry-sdk rq-scheduler psycopg2-binary - name: Run Test run: | - `which django-admin` test django_rq --settings=django_rq.tests.settings --pythonpath=. + `which django-admin` test django_rq --settings=django_rq.tests.settings --pythonpath=. --noinput diff --git a/django_rq/management/commands/rqworker-pool.py b/django_rq/management/commands/rqworker-pool.py index a7329801..c8b2a4a4 100644 --- a/django_rq/management/commands/rqworker-pool.py +++ b/django_rq/management/commands/rqworker-pool.py @@ -1,16 +1,17 @@ +import multiprocessing as mp import os import sys -from rq.serializers import resolve_serializer -from rq.worker_pool import WorkerPool from rq.logutils import setup_loghandlers +from rq.serializers import resolve_serializer from django.core.management.base import BaseCommand from ...jobs import get_job_class -from ...utils import configure_sentry +from ...utils import configure_sentry, reset_db_connections from ...queues import get_queues from ...workers import get_worker_class +from ...worker_pool import DjangoWorkerPool class Command(BaseCommand): @@ -89,7 +90,7 @@ def handle(self, *args, **options): worker_class = get_worker_class(options.get('worker_class', None)) serializer = resolve_serializer(options['serializer']) - pool = WorkerPool( + pool = DjangoWorkerPool( queues=queues, connection=queues[0].connection, num_workers=options['num_workers'], @@ -97,4 +98,7 @@ def handle(self, *args, **options): worker_class=worker_class, job_class=job_class, ) + # Close any opened DB connection before any fork + reset_db_connections() + mp.set_start_method('fork', force=True) pool.start(burst=options.get('burst', False), logging_level=logging_level) diff --git a/django_rq/tests/settings.py b/django_rq/tests/settings.py index c54273f0..430f64a7 100644 --- a/django_rq/tests/settings.py +++ b/django_rq/tests/settings.py @@ -39,8 +39,15 @@ DATABASES = { 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': ':memory:', + 'ENGINE': 'django.db.backends.postgresql_psycopg2', + 'NAME': 'django_rq_db', + 'USER': 'postgres', + 'PASSWORD': 'postgres', + 'HOST': 'localhost', + 'PORT': '5432', + 'TEST': { + 'NAME': 'django_rq_test_db', + } }, } diff --git a/django_rq/tests/tests.py b/django_rq/tests/tests.py index 50f2733d..164079ed 100644 --- a/django_rq/tests/tests.py +++ b/django_rq/tests/tests.py @@ -1,5 +1,6 @@ -import sys import datetime +import multiprocessing +import sys import time from unittest import skipIf, mock from unittest.mock import patch, PropertyMock, MagicMock @@ -37,6 +38,8 @@ from django_rq.utils import get_jobs, get_statistics, get_scheduler_pid from django_rq.workers import get_worker, get_worker_class +from .utils import query_user + try: from rq_scheduler import Scheduler from ..queues import get_scheduler @@ -303,6 +306,18 @@ def test_pass_queue_via_commandline_args(self): self.assertTrue(job['job'].is_finished) self.assertIn(job['job'].id, job['finished_job_registry'].get_job_ids()) + def test_rqworker_pool_process_start_method(self) -> None: + for start_method in ['spawn', 'fork']: + with mock.patch.object(multiprocessing, 'get_start_method', return_value=start_method): + queue_name = 'django_rq_test' + queue = get_queue(queue_name) + job = queue.enqueue(query_user) + finished_job_registry = FinishedJobRegistry(queue.name, queue.connection) + call_command('rqworker-pool', queue_name, burst=True) + + self.assertTrue(job.is_finished) + self.assertIn(job.id, finished_job_registry.get_job_ids()) + def test_configure_sentry(self): rqworker.configure_sentry('https://1@sentry.io/1') self.mock_sdk.init.assert_called_once_with( diff --git a/django_rq/tests/utils.py b/django_rq/tests/utils.py index afe4df2a..f2091f58 100644 --- a/django_rq/tests/utils.py +++ b/django_rq/tests/utils.py @@ -1,7 +1,11 @@ +from typing import Optional + from django_rq.queues import get_connection, get_queue_by_index +from django.contrib.auth.models import User + -def get_queue_index(name='default'): +def get_queue_index(name: str = 'default') -> int: """ Returns the position of Queue for the named queue in QUEUES_LIST """ @@ -17,3 +21,11 @@ def get_queue_index(name='default'): queue_index = i break return queue_index + + +def query_user() -> Optional[User]: + try: + return User.objects.first() + except Exception as e: + print('Exception caught when querying user: ', e) + raise e diff --git a/django_rq/worker_pool.py b/django_rq/worker_pool.py new file mode 100644 index 00000000..68876cb4 --- /dev/null +++ b/django_rq/worker_pool.py @@ -0,0 +1,38 @@ +import django +from multiprocessing import Process, get_start_method +from typing import Any + +from rq.worker_pool import WorkerPool, run_worker + + +class DjangoWorkerPool(WorkerPool): + def get_worker_process( + self, + name: str, + burst: bool, + _sleep: float = 0, + logging_level: str = "INFO", + ) -> Process: + """Returns the worker process""" + return Process( + target=run_django_worker, + args=(name, self._queue_names, self._connection_class, self._pool_class, self._pool_kwargs), + kwargs={ + '_sleep': _sleep, + 'burst': burst, + 'logging_level': logging_level, + 'worker_class': self.worker_class, + 'job_class': self.job_class, + 'serializer': self.serializer, + }, + name=f'Worker {name} (WorkerPool {self.name})', + ) + + +def run_django_worker(*args: Any, **kwargs: Any) -> None: + # multiprocessing library default process start method may be + # `spawn` or `fork` depending on the host OS + if get_start_method() == 'spawn': + django.setup() + + run_worker(*args, **kwargs)