From ce61f0539b95c268eb94914783198d6652f5a8d0 Mon Sep 17 00:00:00 2001 From: Jedr Blaszyk Date: Mon, 7 Oct 2024 09:39:58 +0200 Subject: [PATCH] Create connector record in agent component (#2861) --- connectors/agent/component.py | 4 +- connectors/agent/config.py | 86 ++++---- connectors/agent/connector_record_manager.py | 104 ++++++++++ connectors/agent/protocol.py | 100 +++++++-- connectors/es/index.py | 13 ++ connectors/protocol/connectors.py | 10 + tests/agent/test_agent_config.py | 205 ++++++++++++++++--- tests/agent/test_connector_record_manager.py | 153 ++++++++++++++ tests/agent/test_protocol.py | 107 +++++++--- 9 files changed, 681 insertions(+), 101 deletions(-) create mode 100644 connectors/agent/connector_record_manager.py create mode 100644 tests/agent/test_connector_record_manager.py diff --git a/connectors/agent/component.py b/connectors/agent/component.py index 92356f27e..8c65fb0d2 100644 --- a/connectors/agent/component.py +++ b/connectors/agent/component.py @@ -59,7 +59,9 @@ async def run(self): action_handler = ConnectorActionHandler() self.connector_service_manager = ConnectorServiceManager(self.config_wrapper) checkin_handler = ConnectorCheckinHandler( - client, self.config_wrapper, self.connector_service_manager + client, + self.config_wrapper, + self.connector_service_manager, ) self.multi_service = MultiService( diff --git a/connectors/agent/config.py b/connectors/agent/config.py index 80bccf946..33d4eb7a7 100644 --- a/connectors/agent/config.py +++ b/connectors/agent/config.py @@ -24,57 +24,30 @@ class ConnectorsAgentConfigurationWrapper: def __init__(self): """Inits the class. - There's default config that allows us to run connectors natively (see _force_allow_native flag), - when final configuration is reported these defaults will be merged with defaults from Connectors - Service config and specific config coming from Agent. + There's default config that allows us to run connectors service. When final + configuration is reported these defaults will be merged with defaults from + Connectors Service config and specific config coming from Agent. """ self._default_config = { "service": { "log_level": "INFO", - "_use_native_connector_api_keys": False, }, - "_force_allow_native": True, - "native_service_types": [ - "azure_blob_storage", - "box", - "confluence", - "dropbox", - "github", - "gmail", - "google_cloud_storage", - "google_drive", - "jira", - "mongodb", - "mssql", - "mysql", - "notion", - "onedrive", - "oracle", - "outlook", - "network_drive", - "postgresql", - "s3", - "salesforce", - "servicenow", - "sharepoint_online", - "slack", - "microsoft_teams", - "zoom", - ], + "connectors": [], } self.specific_config = {} - def try_update(self, unit): + def try_update(self, connector_id, service_type, output_unit): """Try update the configuration and see if it changed. - This method takes the check-in event coming from Agent and checks if config needs an update. + This method takes the check-in event data (connector_id, service_type and output) coming + from Agent and checks if config needs an update. If update is needed, configuration is updated and method returns True. If no update is needed the method returns False. """ - source = unit.config.source + source = output_unit.config.source # TODO: find a good link to what this object is. has_hosts = source.fields.get("hosts") @@ -83,9 +56,17 @@ def try_update(self, unit): assumed_configuration = {} + # Connector-related + assumed_configuration["connectors"] = [ + { + "connector_id": connector_id, + "service_type": service_type, + } + ] + # Log-related assumed_configuration["service"] = {} - assumed_configuration["service"]["log_level"] = unit.log_level + assumed_configuration["service"]["log_level"] = output_unit.log_level # Auth-related if has_hosts and (has_api_key or has_basic_auth): @@ -154,6 +135,32 @@ def _elasticsearch_config_changed(): "elasticsearch" ) + def _connectors_config_changes(): + current_connectors = current_config.get("connectors", []) + new_connectors = new_config.get("connectors", []) + + if len(current_connectors) != len(new_connectors): + return True + + current_connectors_dict = { + connector["connector_id"]: connector for connector in current_connectors + } + new_connectors_dict = { + connector["connector_id"]: connector for connector in new_connectors + } + + if set(current_connectors_dict.keys()) != set(new_connectors_dict.keys()): + return True + + for connector_id in current_connectors_dict: + current_connector = current_connectors_dict[connector_id] + new_connector = new_connectors_dict[connector_id] + + if current_connector != new_connector: + return True + + return False + if _log_level_changed(): logger.debug("log_level changed") return True @@ -162,6 +169,10 @@ def _elasticsearch_config_changed(): logger.debug("elasticsearch changed") return True + if _connectors_config_changes(): + logger.debug("connectors changed") + return True + return False def get(self): @@ -182,3 +193,6 @@ def get(self): configuration = dict(add_defaults(config)) return configuration + + def get_specific_config(self): + return self.specific_config diff --git a/connectors/agent/connector_record_manager.py b/connectors/agent/connector_record_manager.py new file mode 100644 index 000000000..039784e62 --- /dev/null +++ b/connectors/agent/connector_record_manager.py @@ -0,0 +1,104 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +import secrets +import string + +from connectors.agent.logger import get_logger +from connectors.es.index import DocumentNotFoundError +from connectors.protocol import ConnectorIndex + +logger = get_logger("agent_connector_record_manager") + + +class ConnectorRecordManager: + """ + Manages connector records in Elasticsearch, ensuring that connectors tied to agent components + exist in the connector index. It creates the connector record if necessary. + """ + + def __init__(self): + self.connector_index = None + + async def ensure_connector_records_exist(self, agent_config, connector_name=None): + """ + Ensure that connector records exist for all connectors specified in the agent configuration. + + If the connector record with a given ID doesn't exist, create a new one. + """ + + if not self._agent_config_ready(agent_config): + return + + # Initialize the ES client if it's not already initialized + if not self.connector_index: + self.connector_index = ConnectorIndex(agent_config.get("elasticsearch")) + + for connector_config in self._get_connectors(agent_config): + connector_id, service_type = ( + connector_config["connector_id"], + connector_config["service_type"], + ) + + if not connector_name: + random_connector_name_id = self._generate_random_connector_name_id( + length=4 + ) + connector_name = f"[Elastic-managed] {service_type} connector {random_connector_name_id}" + + if not await self._connector_exists(connector_id): + try: + await self.connector_index.connector_put( + connector_id=connector_id, + service_type=service_type, + connector_name=connector_name, + ) + logger.info(f"Created connector record for {connector_id}") + except Exception as e: + logger.error( + f"Failed to create connector record for {connector_id}: {e}" + ) + raise e + + def _agent_config_ready(self, agent_config): + """ + Validates the agent configuration to check if all info is present to create a connector record. + """ + connectors = agent_config.get("connectors") + if connectors is None or len(connectors) == 0: + return False + + for connector in connectors: + if "connector_id" not in connector or "service_type" not in connector: + return False + + elasticsearch_config = agent_config.get("elasticsearch") + if not elasticsearch_config: + return False + + if "host" not in elasticsearch_config or "api_key" not in elasticsearch_config: + return False + + return True + + async def _connector_exists(self, connector_id): + try: + doc = await self.connector_index.fetch_by_id(connector_id) + return doc is not None + except DocumentNotFoundError: + return False + except Exception as e: + logger.error( + f"Error while checking existence of connector '{connector_id}': {e}" + ) + raise e + + def _get_connectors(self, agent_config): + return agent_config.get("connectors") + + def _generate_random_connector_name_id(self, length=4): + return "".join( + secrets.choice(string.ascii_letters + string.digits) for _ in range(length) + ) diff --git a/connectors/agent/protocol.py b/connectors/agent/protocol.py index 1fc9b2424..726f1c73a 100644 --- a/connectors/agent/protocol.py +++ b/connectors/agent/protocol.py @@ -8,11 +8,16 @@ from elastic_agent_client.handler.action import BaseActionHandler from elastic_agent_client.handler.checkin import BaseCheckinHandler +from connectors.agent.connector_record_manager import ConnectorRecordManager from connectors.agent.logger import get_logger logger = get_logger("protocol") +CONNECTORS_INPUT_TYPE = "connectors-py" +ELASTICSEARCH_OUTPUT_TYPE = "elasticsearch" + + class ConnectorActionHandler(BaseActionHandler): """Class handling Agent actions. @@ -38,9 +43,16 @@ class ConnectorCheckinHandler(BaseCheckinHandler): This class reads the events, sees if there's a reported change to connector-specific settings, tries to update the configuration and, if the configuration is updated, restarts the Connectors Service. + + If the connector document with given ID doesn't exist, it creates a new one. """ - def __init__(self, client, agent_connectors_config_wrapper, service_manager): + def __init__( + self, + client, + agent_connectors_config_wrapper, + service_manager, + ): """Inits the class. Initing this class should not produce side-effects. @@ -48,6 +60,7 @@ def __init__(self, client, agent_connectors_config_wrapper, service_manager): super().__init__(client) self.agent_connectors_config_wrapper = agent_connectors_config_wrapper self.service_manager = service_manager + self.connector_record_manager = ConnectorRecordManager() async def apply_from_client(self): """Implementation of BaseCheckinHandler.apply_from_client @@ -73,26 +86,83 @@ async def apply_from_client(self): # Filter Elasticsearch outputs from the available outputs elasticsearch_outputs = [ - output - for output in outputs - if output.config and output.config.type == "elasticsearch" + output_unit + for output_unit in outputs + if output_unit.config + and output_unit.config.type == ELASTICSEARCH_OUTPUT_TYPE ] - if elasticsearch_outputs: - if len(elasticsearch_outputs) > 1: + inputs = [ + unit + for unit in self.client.units + if unit.unit_type == proto.UnitType.INPUT + ] + + # Ensure only the single valid connector input is selected from the inputs + connector_inputs = [ + input_unit + for input_unit in inputs + if input_unit.config.type == CONNECTORS_INPUT_TYPE + ] + + if connector_inputs: + if len(connector_inputs) > 1: logger.warning( - "Multiple Elasticsearch outputs detected. The first ES output defined in the agent policy will be used." + "Multiple connector inputs detected. The first connector input defined in the agent policy will be used." ) - logger.debug("Elasticsearch outputs found.") + logger.debug("Connector input found.") + + connector_input = connector_inputs[0] + + def _extract_unit_config_value(unit, field_name): + field_value = unit.config.source.fields.get(field_name) + return field_value.string_value if field_value else None - configuration_changed = self.agent_connectors_config_wrapper.try_update( - elasticsearch_outputs[0] + service_type = _extract_unit_config_value( + connector_input, "service_type" ) - if configuration_changed: - logger.info( - "Connector service manager config updated. Restarting service manager." + connector_name = _extract_unit_config_value( + connector_input, "connector_name" + ) + connector_id = _extract_unit_config_value(connector_input, "id") + + logger.info( + f"Connector input found. Service type: {service_type}, Connector ID: {connector_id}, Connector Name: {connector_name}" + ) + + if elasticsearch_outputs: + if len(elasticsearch_outputs) > 1: + logger.warning( + "Multiple Elasticsearch outputs detected. The first ES output defined in the agent policy will be used." + ) + + logger.debug("Elasticsearch outputs found.") + + elasticsearch_output = elasticsearch_outputs[0] + + configuration_changed = ( + self.agent_connectors_config_wrapper.try_update( + connector_id=connector_id, + service_type=service_type, + output_unit=elasticsearch_output, + ) ) - self.service_manager.restart() + + # After updating the configuration, ensure all connector records exist in the connector index + await self.connector_record_manager.ensure_connector_records_exist( + agent_config=self.agent_connectors_config_wrapper.get_specific_config(), + connector_name=connector_name, + ) + + if configuration_changed: + logger.info( + "Connector service manager config updated. Restarting service manager." + ) + self.service_manager.restart() + else: + logger.debug("No changes to connectors config") else: - logger.debug("No changes to connectors config") + logger.warning("No Elasticsearch output found") + else: + logger.warning("No connector integration input found") diff --git a/connectors/es/index.py b/connectors/es/index.py index e27858df9..610fbaf89 100644 --- a/connectors/es/index.py +++ b/connectors/es/index.py @@ -98,6 +98,19 @@ async def connector_check_in(self, connector_id): partial(self._api_wrapper.connector_check_in, connector_id) ) + async def connector_put( + self, connector_id, service_type, connector_name, index_name + ): + return await self._retrier.execute_with_retry( + partial( + self.client.connector.put, + connector_id=connector_id, + service_type=service_type, + name=connector_name, + index_name=index_name, + ) + ) + async def connector_update_filtering_draft_validation( self, connector_id, validation_result ): diff --git a/connectors/protocol/connectors.py b/connectors/protocol/connectors.py index 5a05bf385..fb36c1e65 100644 --- a/connectors/protocol/connectors.py +++ b/connectors/protocol/connectors.py @@ -162,6 +162,16 @@ async def heartbeat(self, doc_id): else: await self.update(doc_id=doc_id, doc={"last_seen": iso_utc()}) + async def connector_put( + self, connector_id, service_type, connector_name=None, index_name=None + ): + await self.api.connector_put( + connector_id=connector_id, + service_type=service_type, + connector_name=connector_name, + index_name=index_name, + ) + async def supported_connectors(self, native_service_types=None, connector_ids=None): if native_service_types is None: native_service_types = [] diff --git a/tests/agent/test_agent_config.py b/tests/agent/test_agent_config.py index 3da5a4530..dc79bca25 100644 --- a/tests/agent/test_agent_config.py +++ b/tests/agent/test_agent_config.py @@ -7,6 +7,9 @@ from connectors.agent.config import ConnectorsAgentConfigurationWrapper +CONNECTOR_ID = "test-connector" +SERVICE_TYPE = "test-service-type" + def prepare_unit_mock(fields, log_level): if not fields: @@ -22,22 +25,48 @@ def prepare_unit_mock(fields, log_level): return unit_mock -def test_try_update_without_auth_data(): +def prepare_config_wrapper(): + # populate with connectors list, so that we can test for changes in other config properties config_wrapper = ConnectorsAgentConfigurationWrapper() + initial_config_unit = prepare_unit_mock({}, None) + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=initial_config_unit, + ) + return config_wrapper + + +def test_try_update_without_auth_data(): + config_wrapper = prepare_config_wrapper() unit_mock = prepare_unit_mock({}, None) - assert config_wrapper.try_update(unit_mock) is False + assert ( + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=unit_mock, + ) + is False + ) def test_try_update_with_api_key_auth_data(): hosts = ["https://localhost:9200"] api_key = "lemme_in" - config_wrapper = ConnectorsAgentConfigurationWrapper() + config_wrapper = prepare_config_wrapper() unit_mock = prepare_unit_mock({"hosts": hosts, "api_key": api_key}, None) - assert config_wrapper.try_update(unit_mock) is True + assert ( + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=unit_mock, + ) + is True + ) assert config_wrapper.get()["elasticsearch"]["host"] == hosts[0] assert config_wrapper.get()["elasticsearch"]["api_key"] == api_key @@ -47,10 +76,17 @@ def test_try_update_with_non_encoded_api_key_auth_data(): api_key = "something:else" encoded = "c29tZXRoaW5nOmVsc2U=" - config_wrapper = ConnectorsAgentConfigurationWrapper() + config_wrapper = prepare_config_wrapper() source_mock = prepare_unit_mock({"hosts": hosts, "api_key": api_key}, None) - assert config_wrapper.try_update(source_mock) is True + assert ( + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=source_mock, + ) + is True + ) assert config_wrapper.get()["elasticsearch"]["host"] == hosts[0] assert config_wrapper.get()["elasticsearch"]["api_key"] == encoded @@ -60,12 +96,19 @@ def test_try_update_with_basic_auth_auth_data(): username = "elastic" password = "hold the door" - config_wrapper = ConnectorsAgentConfigurationWrapper() + config_wrapper = prepare_config_wrapper() unit_mock = prepare_unit_mock( {"hosts": hosts, "username": username, "password": password}, None ) - assert config_wrapper.try_update(unit_mock) is True + assert ( + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=unit_mock, + ) + is True + ) assert config_wrapper.get()["elasticsearch"]["host"] == hosts[0] assert config_wrapper.get()["elasticsearch"]["username"] == username assert config_wrapper.get()["elasticsearch"]["password"] == password @@ -76,7 +119,7 @@ def test_try_update_multiple_times_does_not_reset_config_values(): api_key = "lemme_in" log_level = "DEBUG" - config_wrapper = ConnectorsAgentConfigurationWrapper() + config_wrapper = prepare_config_wrapper() # First unit comes with elasticsearch data first_unit_mock = prepare_unit_mock({"hosts": hosts, "api_key": api_key}, None) @@ -84,8 +127,22 @@ def test_try_update_multiple_times_does_not_reset_config_values(): # Second unit comes only with a log_level second_unit_mock = prepare_unit_mock({}, log_level) - assert config_wrapper.try_update(first_unit_mock) is True - assert config_wrapper.try_update(second_unit_mock) is True + assert ( + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=first_unit_mock, + ) + is True + ) + assert ( + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=second_unit_mock, + ) + is True + ) assert config_wrapper.get()["elasticsearch"]["host"] == hosts[0] assert config_wrapper.get()["elasticsearch"]["api_key"] == api_key @@ -96,9 +153,12 @@ def test_config_changed_when_new_variables_are_passed(): hosts = ["https://localhost:9200"] api_key = "lemme_in_lalala" - new_config = {"elasticsearch": {"hosts": hosts, "api_key": api_key}} + new_config = { + "connectors": [{"connector_id": CONNECTOR_ID, "service_type": SERVICE_TYPE}], + "elasticsearch": {"hosts": hosts, "api_key": api_key}, + } - config_wrapper = ConnectorsAgentConfigurationWrapper() + config_wrapper = prepare_config_wrapper() assert config_wrapper.config_changed(new_config) is True @@ -114,10 +174,17 @@ def test_config_changed_when_elasticsearch_config_changed(): "password": "hey-im-a-password", } } - new_config = {"elasticsearch": {"hosts": hosts, "api_key": api_key}} + new_config = { + "connectors": [{"connector_id": CONNECTOR_ID, "service_type": SERVICE_TYPE}], + "elasticsearch": {"hosts": hosts, "api_key": api_key}, + } - config_wrapper = ConnectorsAgentConfigurationWrapper() - config_wrapper.try_update(prepare_unit_mock(starting_config, None)) + config_wrapper = prepare_config_wrapper() + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock(starting_config, None), + ) assert config_wrapper.config_changed(new_config) is True @@ -126,27 +193,117 @@ def test_config_changed_when_elasticsearch_config_did_not_change(): hosts = ["https://localhost:9200"] api_key = "lemme_in_lalala" - new_config = {"elasticsearch": {"hosts": hosts, "api_key": api_key}} + new_config = { + "connectors": [{"connector_id": CONNECTOR_ID, "service_type": SERVICE_TYPE}], + "elasticsearch": {"hosts": hosts, "api_key": api_key}, + } - config_wrapper = ConnectorsAgentConfigurationWrapper() - config_wrapper.try_update(prepare_unit_mock(new_config, None)) + config_wrapper = prepare_config_wrapper() + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock(new_config, None), + ) assert config_wrapper.config_changed(new_config) is True def test_config_changed_when_log_level_config_changed(): - config_wrapper = ConnectorsAgentConfigurationWrapper() - config_wrapper.try_update(prepare_unit_mock({}, "INFO")) + config_wrapper = prepare_config_wrapper() + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock({}, "INFO"), + ) - new_config = {"service": {"log_level": "DEBUG"}} + new_config = { + "connectors": [{"connector_id": CONNECTOR_ID, "service_type": SERVICE_TYPE}], + "service": {"log_level": "DEBUG"}, + } assert config_wrapper.config_changed(new_config) is True def test_config_changed_when_log_level_config_did_not_change(): + config_wrapper = prepare_config_wrapper() + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock({}, "INFO"), + ) + + new_config = { + "connectors": [{"connector_id": CONNECTOR_ID, "service_type": SERVICE_TYPE}], + "service": {"log_level": "INFO"}, + } + + assert config_wrapper.config_changed(new_config) is False + + +def test_config_changed_when_connectors_changed(): + config_wrapper = ConnectorsAgentConfigurationWrapper() + + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock({}, None), + ) + + new_config = { + "connectors": [ + {"connector_id": "test-connector-2", "service_type": "test-service-type-2"} + ], + } + + assert config_wrapper.config_changed(new_config) is True + + +def test_config_changed_when_connectors_list_is_cleared(): + config_wrapper = ConnectorsAgentConfigurationWrapper() + + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock({}, None), + ) + + new_config = { + "connectors": [], + } + + assert config_wrapper.config_changed(new_config) is True + + +def test_config_changed_when_connectors_list_is_extended(): + config_wrapper = ConnectorsAgentConfigurationWrapper() + + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock({}, None), + ) + + new_config = { + "connectors": [ + {"connector_id": CONNECTOR_ID, "service_type": SERVICE_TYPE}, + {"connector_id": "test-connector-2", "service_type": "test-service-type-2"}, + ], + } + + assert config_wrapper.config_changed(new_config) is True + + +def test_config_changed_when_connectors_did_not_change(): config_wrapper = ConnectorsAgentConfigurationWrapper() - config_wrapper.try_update(prepare_unit_mock({}, "INFO")) - new_config = {"service": {"log_level": "INFO"}} + config_wrapper.try_update( + connector_id=CONNECTOR_ID, + service_type=SERVICE_TYPE, + output_unit=prepare_unit_mock({}, None), + ) + + new_config = { + "connectors": [{"connector_id": CONNECTOR_ID, "service_type": SERVICE_TYPE}], + } assert config_wrapper.config_changed(new_config) is False diff --git a/tests/agent/test_connector_record_manager.py b/tests/agent/test_connector_record_manager.py new file mode 100644 index 000000000..46ae9f87c --- /dev/null +++ b/tests/agent/test_connector_record_manager.py @@ -0,0 +1,153 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from connectors.agent.connector_record_manager import ( + ConnectorRecordManager, +) +from connectors.es.index import DocumentNotFoundError +from connectors.protocol import ConnectorIndex + + +@pytest.fixture +def mock_connector_index(): + return AsyncMock(ConnectorIndex) + + +@pytest.fixture +def mock_agent_config(): + return { + "elasticsearch": {"host": "http://localhost:9200", "api_key": "dummy_key"}, + "connectors": [{"connector_id": "1", "service_type": "service1"}], + } + + +@pytest.fixture +def connector_record_manager(mock_connector_index): + manager = ConnectorRecordManager() + manager.connector_index = mock_connector_index + return manager + + +@pytest.mark.asyncio +@patch("connectors.protocol.ConnectorIndex", new_callable=AsyncMock) +async def test_ensure_connector_records_exist_creates_connectors_if_not_exist( + mock_connector_index, mock_agent_config +): + manager = ConnectorRecordManager() + manager.connector_index = mock_connector_index + mock_connector_index.fetch_by_id.side_effect = DocumentNotFoundError + mock_connector_index.connector_put = AsyncMock() + connector_ui_id = "1234" + manager._generate_random_connector_name_id = Mock(return_value=connector_ui_id) + + await manager.ensure_connector_records_exist(mock_agent_config) + assert mock_connector_index.connector_put.call_count == 1 + mock_connector_index.connector_put.assert_any_await( + connector_id="1", + service_type="service1", + connector_name=f"[Elastic-managed] service1 connector {connector_ui_id}", + ) + + +@pytest.mark.asyncio +async def test_ensure_connector_records_exist_connector_already_exists( + connector_record_manager, mock_agent_config +): + connector_record_manager._connector_exists = AsyncMock(return_value=True) + await connector_record_manager.ensure_connector_records_exist(mock_agent_config) + assert connector_record_manager.connector_index.connector_put.call_count == 0 + + +@pytest.mark.asyncio +@patch("connectors.protocol.ConnectorIndex", new_callable=AsyncMock) +async def test_ensure_connector_records_raises_on_non_404_error( + mock_connector_index, mock_agent_config +): + manager = ConnectorRecordManager() + manager.connector_index = mock_connector_index + mock_connector_index.fetch_by_id.side_effect = Exception("Unexpected error") + mock_connector_index.connector_put = AsyncMock() + + with pytest.raises(Exception, match="Unexpected error"): + await manager.ensure_connector_records_exist(mock_agent_config) + + assert mock_connector_index.connector_put.call_count == 0 + + +@pytest.mark.asyncio +async def test_ensure_connector_records_exist_agent_config_not_ready( + connector_record_manager, +): + invalid_config = {"connectors": []} + await connector_record_manager.ensure_connector_records_exist(invalid_config) + assert connector_record_manager.connector_index.connector_put.call_count == 0 + + +@pytest.mark.asyncio +async def test_ensure_connector_records_exist_exception_on_create( + connector_record_manager, mock_agent_config +): + connector_record_manager._connector_exists = AsyncMock(return_value=False) + connector_record_manager.connector_index.connector_put = AsyncMock( + side_effect=Exception("Failed to create") + ) + with pytest.raises(Exception, match="Failed to create"): + await connector_record_manager.ensure_connector_records_exist(mock_agent_config) + + +@pytest.mark.asyncio +async def test_connector_exists_returns_true_when_found(connector_record_manager): + connector_record_manager.connector_index.fetch_by_id = AsyncMock( + return_value={"id": "1"} + ) + exists = await connector_record_manager._connector_exists("1") + assert exists is True + + +@pytest.mark.asyncio +async def test_connector_exists_returns_false_when_not_found(connector_record_manager): + connector_record_manager.connector_index.fetch_by_id = AsyncMock( + side_effect=DocumentNotFoundError + ) + exists = await connector_record_manager._connector_exists("1") + assert exists is False + + +@pytest.mark.asyncio +async def test_connector_exists_raises_non_404_exception(connector_record_manager): + connector_record_manager.connector_index.fetch_by_id = AsyncMock( + side_effect=Exception("Fetch error") + ) + with pytest.raises(Exception, match="Fetch error"): + await connector_record_manager._connector_exists("1") + + +def test_agent_config_ready_with_valid_config( + connector_record_manager, mock_agent_config +): + ready = connector_record_manager._agent_config_ready(mock_agent_config) + assert ready is True + + +def test_agent_config_ready_with_invalid_config_missing_connectors( + connector_record_manager, +): + invalid_config = { + "elasticsearch": {"host": "http://localhost:9200", "api_key": "dummy_key"} + } + ready = connector_record_manager._agent_config_ready(invalid_config) + assert ready is False + + +def test_agent_config_ready_with_invalid_config_missing_elasticsearch( + connector_record_manager, +): + invalid_config = {"connectors": [{"connector_id": "1", "service_type": "service1"}]} + ready = connector_record_manager._agent_config_ready(invalid_config) + assert ready is False diff --git a/tests/agent/test_protocol.py b/tests/agent/test_protocol.py index f9e917b4e..ac78eb833 100644 --- a/tests/agent/test_protocol.py +++ b/tests/agent/test_protocol.py @@ -3,7 +3,7 @@ # or more contributor license agreements. Licensed under the Elastic License 2.0; # you may not use this file except in compliance with the Elastic License 2.0. # -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest from elastic_agent_client.client import Unit @@ -14,6 +14,34 @@ from connectors.agent.protocol import ConnectorActionHandler, ConnectorCheckinHandler +@pytest.fixture(autouse=True) +def input_mock(): + unit_mock = Mock() + unit_mock.unit_type = proto.UnitType.INPUT + + def _string_config_field_mock(value): + field = Mock() + field.string_value = value + return field + + unit_mock.config.source.fields = { + "service_type": _string_config_field_mock("test-service"), + "connector_name": _string_config_field_mock("test-connector"), + "id": _string_config_field_mock("test-id"), + } + unit_mock.config.type = "connectors-py" + return unit_mock + + +@pytest.fixture(autouse=True) +def connector_record_manager_mock(): + connector_record_manager_mock = Mock() + connector_record_manager_mock.ensure_connector_records_exist = AsyncMock( + return_value=True + ) + return connector_record_manager_mock + + class TestConnectorActionHandler: @pytest.mark.asyncio async def test_handle_action(self): @@ -25,7 +53,9 @@ async def test_handle_action(self): class TestConnectorCheckingHandler: @pytest.mark.asyncio - async def test_apply_from_client_when_no_units_received(self): + async def test_apply_from_client_when_no_units_received( + self, connector_record_manager_mock, input_mock + ): client_mock = Mock() config_wrapper_mock = Mock() service_manager_mock = Mock() @@ -33,8 +63,11 @@ async def test_apply_from_client_when_no_units_received(self): client_mock.units = [] checkin_handler = ConnectorCheckinHandler( - client_mock, config_wrapper_mock, service_manager_mock + client_mock, + config_wrapper_mock, + service_manager_mock, ) + checkin_handler.connector_record_manager = connector_record_manager_mock await checkin_handler.apply_from_client() @@ -42,18 +75,23 @@ async def test_apply_from_client_when_no_units_received(self): assert not service_manager_mock.restart.called @pytest.mark.asyncio - async def test_apply_from_client_when_units_with_no_output(self): + async def test_apply_from_client_when_units_with_no_output( + self, connector_record_manager_mock, input_mock + ): client_mock = Mock() config_wrapper_mock = Mock() service_manager_mock = Mock() unit_mock = Mock() unit_mock.unit_type = "Something else" - client_mock.units = [unit_mock] + client_mock.units = [unit_mock, input_mock] checkin_handler = ConnectorCheckinHandler( - client_mock, config_wrapper_mock, service_manager_mock + client_mock, + config_wrapper_mock, + service_manager_mock, ) + checkin_handler.connector_record_manager = connector_record_manager_mock await checkin_handler.apply_from_client() @@ -62,7 +100,7 @@ async def test_apply_from_client_when_units_with_no_output(self): @pytest.mark.asyncio async def test_apply_from_client_when_units_with_output_and_non_updating_config( - self, + self, connector_record_manager_mock, input_mock ): client_mock = Mock() config_wrapper_mock = Mock() @@ -74,11 +112,14 @@ async def test_apply_from_client_when_units_with_output_and_non_updating_config( unit_mock.unit_type = proto.UnitType.OUTPUT unit_mock.config.source = {"elasticsearch": {"api_key": 123}} - client_mock.units = [unit_mock] + client_mock.units = [unit_mock, input_mock] checkin_handler = ConnectorCheckinHandler( - client_mock, config_wrapper_mock, service_manager_mock + client_mock, + config_wrapper_mock, + service_manager_mock, ) + checkin_handler.connector_record_manager = connector_record_manager_mock await checkin_handler.apply_from_client() @@ -86,7 +127,9 @@ async def test_apply_from_client_when_units_with_output_and_non_updating_config( assert not service_manager_mock.restart.called @pytest.mark.asyncio - async def test_apply_from_client_when_units_with_output_and_updating_config(self): + async def test_apply_from_client_when_units_with_output_and_updating_config( + self, connector_record_manager_mock, input_mock + ): client_mock = Mock() config_wrapper_mock = Mock() @@ -98,11 +141,14 @@ async def test_apply_from_client_when_units_with_output_and_updating_config(self unit_mock.config.source = {"elasticsearch": {"api_key": 123}} unit_mock.config.type = "elasticsearch" - client_mock.units = [unit_mock] + client_mock.units = [unit_mock, input_mock] checkin_handler = ConnectorCheckinHandler( - client_mock, config_wrapper_mock, service_manager_mock + client_mock, + config_wrapper_mock, + service_manager_mock, ) + checkin_handler.connector_record_manager = connector_record_manager_mock await checkin_handler.apply_from_client() @@ -111,7 +157,7 @@ async def test_apply_from_client_when_units_with_output_and_updating_config(self @pytest.mark.asyncio async def test_apply_from_client_when_units_with_multiple_outputs_and_updating_config( - self, + self, connector_record_manager_mock, input_mock ): client_mock = Mock() config_wrapper_mock = Mock() @@ -130,24 +176,28 @@ async def test_apply_from_client_when_units_with_multiple_outputs_and_updating_c unit_kafka.config.type = "kafka" unit_kafka.config.id = "config-kafka" - client_mock.units = [unit_kafka, unit_es] + client_mock.units = [unit_kafka, unit_es, input_mock] checkin_handler = ConnectorCheckinHandler( - client_mock, config_wrapper_mock, service_manager_mock + client_mock, + config_wrapper_mock, + service_manager_mock, ) + checkin_handler.connector_record_manager = connector_record_manager_mock await checkin_handler.apply_from_client() # Only ES output from the policy should be used by connectors component assert config_wrapper_mock.try_update.called_once() - called_unit = config_wrapper_mock.try_update.call_args[0][0] - assert called_unit.config.id == "config-es" + _, called_kwargs = config_wrapper_mock.try_update.call_args + called_output_unit = called_kwargs.get("output_unit") + assert called_output_unit.config.id == "config-es" assert service_manager_mock.restart.called @pytest.mark.asyncio async def test_apply_from_client_when_units_with_multiple_mixed_outputs_and_updating_config( - self, + self, connector_record_manager_mock, input_mock ): client_mock = Mock() config_wrapper_mock = Mock() @@ -171,24 +221,28 @@ async def test_apply_from_client_when_units_with_multiple_mixed_outputs_and_upda unit_kafka.config.type = "kafka" unit_kafka.config.id = "config-kafka" - client_mock.units = [unit_kafka, unit_es_2, unit_es_1] + client_mock.units = [unit_kafka, unit_es_2, unit_es_1, input_mock] checkin_handler = ConnectorCheckinHandler( - client_mock, config_wrapper_mock, service_manager_mock + client_mock, + config_wrapper_mock, + service_manager_mock, ) + checkin_handler.connector_record_manager = connector_record_manager_mock await checkin_handler.apply_from_client() # First ES output from the policy should be used by connectors component assert config_wrapper_mock.try_update.called_once() - called_unit = config_wrapper_mock.try_update.call_args[0][0] - assert called_unit.config.id == "config-es-2" + _, called_kwargs = config_wrapper_mock.try_update.call_args + called_output_unit = called_kwargs.get("output_unit") + assert called_output_unit.config.id == "config-es-2" assert service_manager_mock.restart.called @pytest.mark.asyncio async def test_apply_from_client_when_units_with_output_and_updating_log_level( - self, + self, connector_record_manager_mock, input_mock ): client_mock = Mock() config_wrapper = ConnectorsAgentConfigurationWrapper() @@ -213,11 +267,14 @@ async def test_apply_from_client_when_units_with_output_and_updating_log_level( log_level=proto.UnitLogLevel.DEBUG, ) - client_mock.units = [unit] + client_mock.units = [unit, input_mock] checkin_handler = ConnectorCheckinHandler( - client_mock, config_wrapper, service_manager_mock + client_mock, + config_wrapper, + service_manager_mock, ) + checkin_handler.connector_record_manager = connector_record_manager_mock await checkin_handler.apply_from_client() assert service_manager_mock.restart.called