diff --git a/catalog/dags/common/constants.py b/catalog/dags/common/constants.py index 13a7dab1857..c2de110341d 100644 --- a/catalog/dags/common/constants.py +++ b/catalog/dags/common/constants.py @@ -15,6 +15,9 @@ STAGING = "staging" PRODUCTION = "production" +Environment = Literal["staging", "production"] +ENVIRONMENTS = [STAGING, PRODUCTION] + CONTACT_EMAIL = os.getenv("CONTACT_EMAIL") DAG_DEFAULT_ARGS = { diff --git a/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index.py b/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index.py index 9fbc58192c5..1cb2b970eda 100644 --- a/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index.py +++ b/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index.py @@ -2,7 +2,6 @@ from datetime import timedelta from airflow.decorators import task, task_group -from airflow.models.connection import Connection from airflow.providers.elasticsearch.hooks.elasticsearch import ElasticsearchPythonHook from airflow.sensors.python import PythonSensor @@ -21,12 +20,6 @@ GET_CURRENT_INDEX_CONFIG_TASK_NAME = "get_current_index_configuration" -@task -def get_es_host(environment: str): - conn = Connection.get_connection_from_secrets(f"elasticsearch_http_{environment}") - return conn.host - - @task def get_index_name(media_type: str, index_suffix: str): return f"{media_type}-{index_suffix}".lower() diff --git a/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py b/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py index ad4c3229c3a..5e2e517a042 100644 --- a/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py +++ b/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py @@ -109,6 +109,7 @@ CREATE_NEW_INDEX_CONFIGS, CreateNewIndex, ) +from elasticsearch_cluster.shared import get_es_host logger = logging.getLogger(__name__) @@ -188,7 +189,7 @@ def create_new_es_index_dag(config: CreateNewIndex): with dag: prevent_concurrency = prevent_concurrency_with_dags(config.blocking_dags) - es_host = es.get_es_host(environment=config.environment) + es_host = get_es_host(environment=config.environment) index_name = es.get_index_name( media_type="{{ params.media_type }}", diff --git a/catalog/dags/elasticsearch_cluster/shared.py b/catalog/dags/elasticsearch_cluster/shared.py new file mode 100644 index 00000000000..ef53d0fade8 --- /dev/null +++ b/catalog/dags/elasticsearch_cluster/shared.py @@ -0,0 +1,11 @@ +from airflow.decorators import task +from airflow.models.connection import Connection +from airflow.models.xcom_arg import XComArg + +from common.constants import Environment + + +@task +def get_es_host(environment: Environment) -> XComArg: + conn = Connection.get_connection_from_secrets(f"elasticsearch_http_{environment}") + return conn.host