diff --git a/dbt-athena/test.env.example b/dbt-athena/.env.example similarity index 88% rename from dbt-athena/test.env.example rename to dbt-athena/.env.example index 7182e9c0..ca571cfb 100644 --- a/dbt-athena/test.env.example +++ b/dbt-athena/.env.example @@ -1,10 +1,10 @@ DBT_TEST_ATHENA_S3_STAGING_DIR= DBT_TEST_ATHENA_S3_TMP_TABLE_DIR= DBT_TEST_ATHENA_REGION_NAME= -DBT_TEST_ATHENA_THREADS= -DBT_TEST_ATHENA_POLL_INTERVAL= DBT_TEST_ATHENA_DATABASE= DBT_TEST_ATHENA_SCHEMA= DBT_TEST_ATHENA_WORK_GROUP= +DBT_TEST_ATHENA_THREADS= +DBT_TEST_ATHENA_POLL_INTERVAL= +DBT_TEST_ATHENA_NUM_RETRIES= DBT_TEST_ATHENA_AWS_PROFILE_NAME= -DBT_TEST_ATHENA_SPARK_WORK_GROUP= diff --git a/dbt-athena/pyproject.toml b/dbt-athena/pyproject.toml index a56d057c..27b84498 100644 --- a/dbt-athena/pyproject.toml +++ b/dbt-athena/pyproject.toml @@ -98,7 +98,6 @@ check-sdist = [ ] [tool.pytest] -env_files = ["test.env"] testpaths = [ "tests/unit", "tests/functional", diff --git a/dbt-athena/tests/conftest.py b/dbt-athena/tests/conftest.py deleted file mode 100644 index e94bd8cf..00000000 --- a/dbt-athena/tests/conftest.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -from io import StringIO -from unittest.mock import MagicMock, patch - -import boto3 -import pytest -from dbt_common.events import get_event_manager -from dbt_common.events.base_types import EventLevel -from dbt_common.events.logger import LineFormat, LoggerConfig, NoFilter - -from dbt.adapters.athena import connections -from dbt.adapters.athena.connections import AthenaCredentials - -from .unit.constants import ( - ATHENA_WORKGROUP, - AWS_REGION, - DATA_CATALOG_NAME, - DATABASE_NAME, - S3_STAGING_DIR, - SPARK_WORKGROUP, -) - -# Import the functional fixtures as a plugin -# Note: fixtures with session scope need to be local - -pytest_plugins = ["dbt.tests.fixtures.project"] - - -# The profile dictionary, used to write out profiles.yml -@pytest.fixture(scope="class") -def dbt_profile_target(): - return { - "type": "athena", - "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"), - "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"), - "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"), - "database": os.getenv("DBT_TEST_ATHENA_DATABASE"), - "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"), - "threads": int(os.getenv("DBT_TEST_ATHENA_THREADS", "1")), - "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")), - "num_retries": int(os.getenv("DBT_TEST_ATHENA_NUM_RETRIES", "2")), - "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"), - "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None, - "spark_work_group": os.getenv("DBT_TEST_ATHENA_SPARK_WORK_GROUP"), - } - - -@pytest.fixture(scope="function") -def dbt_error_caplog() -> StringIO: - return _setup_custom_caplog("dbt_error", EventLevel.ERROR) - - -@pytest.fixture(scope="function") -def dbt_debug_caplog() -> StringIO: - return _setup_custom_caplog("dbt_debug", EventLevel.DEBUG) - - -def _setup_custom_caplog(name: str, level: EventLevel): - string_buf = StringIO() - capture_config = LoggerConfig( - name=name, - level=level, - use_colors=False, - line_format=LineFormat.PlainText, - filter=NoFilter, - output_stream=string_buf, - ) - event_manager = get_event_manager() - event_manager.add_logger(capture_config) - return string_buf - - -@pytest.fixture(scope="class") -def athena_client(): - with patch.object(boto3.session.Session, "client", return_value=MagicMock()) as mock_athena_client: - return mock_athena_client - - -@patch.object(connections, "AthenaCredentials") -@pytest.fixture(scope="class") -def athena_credentials(): - return AthenaCredentials( - database=DATA_CATALOG_NAME, - schema=DATABASE_NAME, - s3_staging_dir=S3_STAGING_DIR, - region_name=AWS_REGION, - work_group=ATHENA_WORKGROUP, - spark_work_group=SPARK_WORKGROUP, - ) diff --git a/dbt-athena/tests/functional/conftest.py b/dbt-athena/tests/functional/conftest.py new file mode 100644 index 00000000..1591459b --- /dev/null +++ b/dbt-athena/tests/functional/conftest.py @@ -0,0 +1,26 @@ +import os + +import pytest + +# Import the functional fixtures as a plugin +# Note: fixtures with session scope need to be local +pytest_plugins = ["dbt.tests.fixtures.project"] + + +# The profile dictionary, used to write out profiles.yml +@pytest.fixture(scope="class") +def dbt_profile_target(): + return { + "type": "athena", + "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"), + "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"), + "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"), + "database": os.getenv("DBT_TEST_ATHENA_DATABASE"), + "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"), + "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"), + "threads": int(os.getenv("DBT_TEST_ATHENA_THREADS", "1")), + "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")), + "num_retries": int(os.getenv("DBT_TEST_ATHENA_NUM_RETRIES", "2")), + "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None, + "spark_work_group": os.getenv("DBT_TEST_ATHENA_SPARK_WORK_GROUP"), + } diff --git a/dbt-athena/tests/unit/conftest.py b/dbt-athena/tests/unit/conftest.py index 90cf0765..21d051c7 100644 --- a/dbt-athena/tests/unit/conftest.py +++ b/dbt-athena/tests/unit/conftest.py @@ -1,9 +1,25 @@ +from io import StringIO import os +from unittest.mock import MagicMock, patch +import boto3 import pytest -from .constants import AWS_REGION -from .utils import MockAWSService +from dbt_common.events import get_event_manager +from dbt_common.events.base_types import EventLevel +from dbt_common.events.logger import LineFormat, LoggerConfig, NoFilter + +from dbt.adapters.athena import connections +from dbt.adapters.athena.connections import AthenaCredentials + +from tests.unit.utils import MockAWSService +from tests.unit import constants + + +@pytest.fixture(scope="class") +def athena_client(): + with patch.object(boto3.session.Session, "client", return_value=MagicMock()) as mock_athena_client: + return mock_athena_client @pytest.fixture(scope="function") @@ -13,9 +29,47 @@ def aws_credentials(): os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" os.environ["AWS_SECURITY_TOKEN"] = "testing" os.environ["AWS_SESSION_TOKEN"] = "testing" - os.environ["AWS_DEFAULT_REGION"] = AWS_REGION + os.environ["AWS_DEFAULT_REGION"] = constants.AWS_REGION + + +@patch.object(connections, "AthenaCredentials") +@pytest.fixture(scope="class") +def athena_credentials(): + return AthenaCredentials( + database=constants.DATA_CATALOG_NAME, + schema=constants.DATABASE_NAME, + s3_staging_dir=constants.S3_STAGING_DIR, + region_name=constants.AWS_REGION, + work_group=constants.ATHENA_WORKGROUP, + spark_work_group=constants.SPARK_WORKGROUP, + ) @pytest.fixture() def mock_aws_service(aws_credentials) -> MockAWSService: return MockAWSService() + + +@pytest.fixture(scope="function") +def dbt_error_caplog() -> StringIO: + return _setup_custom_caplog("dbt_error", EventLevel.ERROR) + + +@pytest.fixture(scope="function") +def dbt_debug_caplog() -> StringIO: + return _setup_custom_caplog("dbt_debug", EventLevel.DEBUG) + + +def _setup_custom_caplog(name: str, level: EventLevel): + string_buf = StringIO() + capture_config = LoggerConfig( + name=name, + level=level, + use_colors=False, + line_format=LineFormat.PlainText, + filter=NoFilter, + output_stream=string_buf, + ) + event_manager = get_event_manager() + event_manager.add_logger(capture_config) + return string_buf