diff --git a/connectors/protocol/connectors.py b/connectors/protocol/connectors.py index e1e753ccb..c69cd9dea 100644 --- a/connectors/protocol/connectors.py +++ b/connectors/protocol/connectors.py @@ -819,6 +819,10 @@ async def error(self, error): } await self.index.update(doc_id=self.id, doc=doc) + async def connected(self): + doc = {"status": Status.CONNECTED.value, "error": None} + await self.index.update(doc_id=self.id, doc=doc) + async def sync_done(self, job, cursor=None): job_status = JobStatus.ERROR if job is None else job.status job_error = JOB_NOT_FOUND_ERROR if job is None else job.error diff --git a/connectors/services/job_scheduling.py b/connectors/services/job_scheduling.py index c4c135ea7..250144871 100644 --- a/connectors/services/job_scheduling.py +++ b/connectors/services/job_scheduling.py @@ -106,6 +106,11 @@ async def _schedule(self, connector): if connector.features.sync_rules_enabled(): await connector.validate_filtering(validator=data_source) + + self.logger.info( + "Connector is configured correctly and can reach the data source" + ) + await connector.connected() except Exception as e: connector.log_error(e, exc_info=True) await connector.error(e) diff --git a/tests/commons.py b/tests/commons.py index 18629b508..b3f17e085 100644 --- a/tests/commons.py +++ b/tests/commons.py @@ -15,18 +15,28 @@ class AsyncIterator: Async documents generator fake class, which records the args and kwargs it was called with. """ - def __init__(self, items): + def __init__(self, items, reusable=False): + """ + AsyncIterator is a test-only abstraction to mock async iterables. + By default it's usable only once: once iterated over, he iterator will not + iterate again any more. + If reusable is True, then iterator can be re-used, but only if it's used by a single coroutine. + If AsyncIterator is used in several coroutines, it'll not work correctly at all + """ self.items = items self.call_args = [] self.call_kwargs = [] self.i = 0 self.call_count = 0 + self.reusable = reusable def __aiter__(self): return self async def __anext__(self): if self.i >= len(self.items): + if self.reusable: + self.i = 0 raise StopAsyncIteration item = self.items[self.i] diff --git a/tests/services/test_job_scheduling.py b/tests/services/test_job_scheduling.py index 5cd9e5586..1d2a66f90 100644 --- a/tests/services/test_job_scheduling.py +++ b/tests/services/test_job_scheduling.py @@ -93,6 +93,7 @@ def mock_connector( connector.heartbeat = AsyncMock() connector.reload = AsyncMock() connector.error = AsyncMock() + connector.connected = AsyncMock() connector.update_last_sync_scheduled_at_by_job_type = AsyncMock() return connector @@ -377,6 +378,45 @@ def _source_klass(config): connector.error.assert_awaited_with(actual_error) +@pytest.mark.asyncio +@patch("connectors.services.job_scheduling.get_source_klass") +async def test_run_when_connector_failed_validation_then_succeeded( + get_source_klass_mock, connector_index_mock, set_env +): + error_message = "Something invalid is in config!" + actual_error = Exception(error_message) + + data_source_mock = Mock() + + def _source_klass(config): + return data_source_mock + + def _error_once(): + data_source_mock.validate_config.reset_mock(side_effect=True) + raise actual_error + + get_source_klass_mock.return_value = _source_klass + + data_source_mock.validate_config_fields = Mock() + data_source_mock.validate_config = AsyncMock(side_effect=_error_once) + data_source_mock.ping = AsyncMock() + data_source_mock.close = AsyncMock() + + connector = mock_connector(next_sync=datetime.now(timezone.utc)) + connector_index_mock.supported_connectors.return_value = AsyncIterator( + [connector], reusable=True + ) + await create_and_run_service(JobSchedulingService, stop_after=0.15) + + data_source_mock.validate_config_fields.assert_called() + data_source_mock.validate_config.assert_awaited() + data_source_mock.ping.assert_awaited() + data_source_mock.close.assert_awaited() + + connector.error.assert_awaited_with(actual_error) + connector.connected.assert_awaited() + + @pytest.mark.asyncio @patch("connectors.services.job_scheduling.get_source_klass") async def test_run_when_connector_ping_fails(