diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c6354a37..8cdee2f4 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko @thenaturalist +* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0e2fef2a..ab75579a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.8', '3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore index 36b127c9..6a7cf974 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,6 @@ cython_debug/ # Project specific test.py + +# OS +.DS_Store diff --git a/README.md b/README.md index fc1ef08a..6a72ad96 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,19 @@
+ + + - +
## Features -* Supports dbt version `1.5.*` +* Supports dbt version `1.6.*` +* Supports Python models * Supports [seeds][seeds] * Correctly detects views and their columns * Supports [table materialization][table] @@ -79,6 +83,7 @@ A dbt profile can be configured to run against AWS Athena using the following co | schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` | | database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` | | poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` | +| debug_query_state | Whether to log a debug message with the Athena query state on each poll | Optional | `false` | | aws_access_key_id | Access key ID of the user performing requests. | Optional | `AKIAIOSFODNN7EXAMPLE` | | aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` | | aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` | @@ -86,6 +91,8 @@ A dbt profile can be configured to run against AWS Athena using the following co | num_retries | Number of times to retry a failing query | Optional | `3` | | spark_work_group | Identifier of Athena Spark workgroup | Optional | `my-spark-workgroup` | | spark_threads | Number of spark sessions to create. Recommended to be same as threads. 
| Optional | `4` | +| seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` | +| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` | **Example profiles.yml entry:** ```yaml @@ -105,6 +112,8 @@ athena: work_group: my-workgroup spark_work_group: my-spark-workgroup spark_threads: 4 + seed_s3_upload_args: + ACL: bucket-owner-full-control ``` _Additional information_ diff --git a/dbt/adapters/athena/__version__.py b/dbt/adapters/athena/__version__.py index e3a0f015..cead7e89 100644 --- a/dbt/adapters/athena/__version__.py +++ b/dbt/adapters/athena/__version__.py @@ -1 +1 @@ -version = "1.5.0" +version = "1.6.1" diff --git a/dbt/adapters/athena/config.py b/dbt/adapters/athena/config.py index 4837acbd..62ed6392 100644 --- a/dbt/adapters/athena/config.py +++ b/dbt/adapters/athena/config.py @@ -1,7 +1,7 @@ +import importlib.metadata from functools import lru_cache from typing import Any, Dict -import pkg_resources from botocore import config from dbt.adapters.athena.constants import ( @@ -14,9 +14,7 @@ @lru_cache() def get_boto3_config() -> config.Config: - return config.Config( - user_agent_extra="dbt-athena-community/" + pkg_resources.get_distribution("dbt-athena-community").version - ) + return config.Config(user_agent_extra="dbt-athena-community/" + importlib.metadata.version("dbt-athena-community")) class AthenaSparkSessionConfig: diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index f8ce2462..c639abf7 100644 --- a/dbt/adapters/athena/connections.py +++ b/dbt/adapters/athena/connections.py @@ -1,4 +1,7 @@ import hashlib +import json +import re +import time from concurrent.futures.thread import ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy @@ -33,6 +36,13 @@ from dbt.contracts.connection import AdapterResponse, Connection, ConnectionState from dbt.exceptions 
import ConnectionError, DbtRuntimeError +logger = AdapterLogger("Athena") + + +@dataclass +class AthenaAdapterResponse(AdapterResponse): + data_scanned_in_bytes: Optional[int] = None + @dataclass class AthenaCredentials(Credentials): @@ -44,13 +54,17 @@ class AthenaCredentials(Credentials): aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None poll_interval: float = 1.0 + debug_query_state: bool = False _ALIASES = {"catalog": "database"} num_retries: Optional[int] = 5 s3_data_dir: Optional[str] = None s3_data_naming: Optional[str] = "schema_table_unique" spark_work_group: Optional[str] = None spark_threads: Optional[int] = DEFAULT_THREAD_COUNT - lf_tags: Optional[Dict[str, str]] = None + seed_s3_upload_args: Optional[Dict[str, Any]] = None + # Unfortunately we cannot just use dict, it must be Dict, because otherwise we'll get the following error: + # Credentials in profile "athena", target "athena" invalid: Unable to create schema for 'dict' + lf_tags_database: Optional[Dict[str, str]] = None @property def type(self) -> str: @@ -74,7 +88,9 @@ def _connection_keys(self) -> Tuple[str, ...]: "endpoint_url", "s3_data_dir", "s3_data_naming", - "lf_tags", + "debug_query_state", + "seed_s3_upload_args", + "lf_tags_database", "spark_work_group", "spark_threads", ) @@ -95,6 +111,32 @@ def _collect_result_set(self, query_id: str) -> AthenaResultSet: retry_config=self._retry_config, ) + def _poll(self, query_id: str) -> AthenaQueryExecution: + try: + query_execution = self.__poll(query_id) + except KeyboardInterrupt as e: + if self._kill_on_interrupt: + logger.warning("Query canceled by user.") + self._cancel(query_id) + query_execution = self.__poll(query_id) + else: + raise e + return query_execution + + def __poll(self, query_id: str) -> AthenaQueryExecution: + while True: + query_execution = self._get_query_execution(query_id) + if query_execution.state in [ + AthenaQueryExecution.STATE_SUCCEEDED, + AthenaQueryExecution.STATE_FAILED, 
AthenaQueryExecution.STATE_CANCELLED, + ]: + return query_execution + else: + if self.connection.cursor_kwargs.get("debug_query_state", False): + logger.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...") + time.sleep(self._poll_interval) + def execute( # type: ignore self, operation: str, @@ -104,6 +146,7 @@ def execute( # type: ignore endpoint_url: Optional[str] = None, cache_size: int = 0, cache_expiration_time: int = 0, + catch_partitions_limit: bool = False, **kwargs, ): def inner() -> AthenaCursor: @@ -130,7 +173,12 @@ def inner() -> AthenaCursor: return self retry = tenacity.Retrying( - retry=retry_if_exception(lambda _: True), + # No need to retry if TOO_MANY_OPEN_PARTITIONS occurs. + # Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry, + # because not all files are removed immediately after first try to create table + retry=retry_if_exception( + lambda e: False if catch_partitions_limit and "TOO_MANY_OPEN_PARTITIONS" in str(e) else True + ), stop=stop_after_attempt(self._retry_config.attempt), wait=wait_exponential( multiplier=self._retry_config.attempt, @@ -175,9 +223,11 @@ def open(cls, connection: Connection) -> Connection: handle = AthenaConnection( s3_staging_dir=creds.s3_staging_dir, endpoint_url=creds.endpoint_url, + catalog_name=creds.database, schema_name=creds.schema, work_group=creds.work_group, cursor_class=AthenaCursor, + cursor_kwargs={"debug_query_state": creds.debug_query_state}, formatter=AthenaParameterFormatter(), poll_interval=creds.poll_interval, session=get_boto3_session(connection), @@ -200,12 +250,39 @@ def open(cls, connection: Connection) -> Connection: return connection @classmethod - def get_response(cls, cursor: AthenaCursor) -> AdapterResponse: + def get_response(cls, cursor: AthenaCursor) -> AthenaAdapterResponse: code = "OK" if cursor.state == AthenaQueryExecution.STATE_SUCCEEDED else "ERROR" - return AdapterResponse(_message=f"{code} {cursor.rowcount}", 
rows_affected=cursor.rowcount, code=code) + rowcount, data_scanned_in_bytes = cls.process_query_stats(cursor) + return AthenaAdapterResponse( + _message=f"{code} {rowcount}", + rows_affected=rowcount, + code=code, + data_scanned_in_bytes=data_scanned_in_bytes, + ) + + @staticmethod + def process_query_stats(cursor: AthenaCursor) -> Tuple[int, int]: + """ + Helper function to parse query statistics from SELECT statements. + The function looks for statements that contain both rowcount and data_scanned_in_bytes, + then takes the text after the last SELECT and parses the value between curly brackets. + """ + if all(map(cursor.query.__contains__, ["rowcount", "data_scanned_in_bytes"])): + try: + query_split = cursor.query.lower().split("select")[-1] + # query statistics are in the format {"rowcount":1, "data_scanned_in_bytes": 3} + # the following statement extracts the content between { and } + query_stats = re.search("{(.*)}", query_split) + if query_stats: + stats = json.loads("{" + query_stats.group(1) + "}") + return stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0) + except Exception as err: + logger.debug(f"There was an error parsing query stats {err}") + return -1, 0 + return cursor.rowcount, cursor.data_scanned_in_bytes def cancel(self, connection: Connection) -> None: - connection.handle.cancel() + pass def add_begin_query(self) -> None: pass diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index d1ef4978..faf7ac5c 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -1,7 +1,9 @@ import csv import os import posixpath as path +import re import tempfile +from dataclasses import dataclass from itertools import chain from textwrap import dedent from threading import Lock @@ -15,9 +17,11 @@ from mypy_boto3_glue.type_defs import ( ColumnTypeDef, GetTableResponseTypeDef, + TableInputTypeDef, TableTypeDef, TableVersionTypeDef, ) +from pyathena.error import OperationalError from dbt.adapters.athena import 
AthenaConnectionManager from dbt.adapters.athena.column import AthenaColumn @@ -42,8 +46,15 @@ get_table_type, ) from dbt.adapters.athena.s3 import S3DataNaming -from dbt.adapters.athena.utils import clean_sql_comment, get_catalog_id, get_chunks +from dbt.adapters.athena.utils import ( + AthenaCatalogType, + clean_sql_comment, + get_catalog_id, + get_catalog_type, + get_chunks, +) from dbt.adapters.base import ConstraintSupport, PythonJobHelper, available +from dbt.adapters.base.impl import AdapterConfig from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter from dbt.contracts.connection import AdapterResponse @@ -54,12 +65,58 @@ boto3_client_lock = Lock() +@dataclass +class AthenaConfig(AdapterConfig): + """Database and relation-level configs. + + Args: + work_group (str) : Identifier of Athena workgroup. + s3_staging_dir (str) : S3 location to store Athena query results and metadata. + external_location (str) : If set, the full S3 path in which the table will be saved. + partitioned_by (str) : An array list of columns by which the table will be partitioned. + bucketed_by (str) : An array list of columns to bucket data, ignored if using Iceberg. + bucket_count (str) : The number of buckets for bucketing your data, ignored if using Iceberg. + table_type (str) : The type of table, supports hive or iceberg. + ha (bool) : If the table should be built using the high-availability method. + format (str) : The data format for the table. Supports ORC, PARQUET, AVRO, JSON, TEXTFILE. + write_compression (str) : The compression type to use for any storage format + that allows compression to be specified. + field_delimiter (str) : Custom field delimiter, for when format is set to TEXTFILE. + table_properties (str) : Table properties to add to the table, valid for Iceberg only. + native_drop (str) : Relation drop operations will be performed with SQL, not direct Glue API calls. 
+ seed_by_insert (bool) : default behaviour uploads seed data to S3. + lf_tags_config (Dict[str, Any]) : AWS lakeformation tags to associate with the table and columns. + seed_s3_upload_args (Dict[str, Any]) : Dictionary containing boto3 ExtraArgs when uploading to S3. + partitions_limit (int) : Maximum numbers of partitions when batching. + + """ + + work_group: Optional[str] = None + s3_staging_dir: Optional[str] = None + external_location: Optional[str] = None + partitioned_by: Optional[str] = None + bucketed_by: Optional[str] = None + bucket_count: Optional[str] = None + table_type: Optional[str] = "hive" + ha: Optional[bool] = False + format: Optional[str] = "parquet" + write_compression: Optional[str] = None + field_delimiter: Optional[str] = None + table_properties: Optional[str] = None + native_drop: Optional[str] = None + seed_by_insert: Optional[bool] = False + lf_tags_config: Optional[Dict[str, Any]] = None + seed_s3_upload_args: Optional[Dict[str, Any]] = None + partitions_limit: Optional[int] = None + + class AthenaAdapter(SQLAdapter): BATCH_CREATE_PARTITION_API_LIMIT = 100 BATCH_DELETE_PARTITION_API_LIMIT = 25 ConnectionManager = AthenaConnectionManager Relation = AthenaRelation + AdapterSpecificConfigs = AthenaConfig # There is no such concept as constraints in Athena CONSTRAINT_SUPPORT = { @@ -87,6 +144,19 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" + @available + def add_lf_tags_to_database(self, relation: AthenaRelation) -> None: + conn = self.connections.get_thread_connection() + client = conn.handle + if lf_tags := conn.credentials.lf_tags_database: + config = LfTagsConfig(enabled=True, tags=lf_tags) + with boto3_client_lock: + lf_client = client.session.client("lakeformation", client.region_name, config=get_boto3_config()) + manager = LfTagsManager(lf_client, relation, config) + manager.process_lf_tags_database() 
+ else: + LOGGER.debug(f"Lakeformation is disabled for {relation}") + @available def add_lf_tags(self, relation: AthenaRelation, lf_tags_config: Dict[str, Any]) -> None: config = LfTagsConfig(**lf_tags_config) @@ -203,11 +273,15 @@ def get_glue_table(self, relation: AthenaRelation) -> Optional[GetTableResponseT """ conn = self.connections.get_thread_connection() client = conn.handle + + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + with boto3_client_lock: glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) try: - table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier) + table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier) except ClientError as e: if e.response["Error"]["Code"] == "EntityNotFoundException": LOGGER.debug(f"Table {relation.render()} does not exists - Ignoring") @@ -290,6 +364,7 @@ def upload_seed_to_s3( s3_data_dir: Optional[str] = None, s3_data_naming: Optional[str] = None, external_location: Optional[str] = None, + seed_s3_upload_args: Optional[Dict[str, Any]] = None, ) -> str: conn = self.connections.get_thread_connection() client = conn.handle @@ -308,7 +383,7 @@ def upload_seed_to_s3( # This ensures cross-platform support, tempfile.NamedTemporaryFile does not tmpfile = os.path.join(tempfile.gettempdir(), os.urandom(24).hex()) table.to_csv(tmpfile, quoting=csv.QUOTE_NONNUMERIC) - s3_client.upload_file(tmpfile, bucket, object_name) + s3_client.upload_file(tmpfile, bucket, object_name, ExtraArgs=seed_s3_upload_args) os.remove(tmpfile) return str(s3_location) @@ -410,6 +485,29 @@ def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List for idx, col in enumerate(table["StorageDescriptor"]["Columns"] + table.get("PartitionKeys", [])) ] + def _get_one_table_for_non_glue_catalog( + self, table: TableTypeDef, schema: str, database: str + ) -> 
List[Dict[str, Any]]: + table_catalog = { + "table_database": database, + "table_schema": schema, + "table_name": table["Name"], + "table_type": RELATION_TYPE_MAP[table.get("TableType", "EXTERNAL_TABLE")].value, + "table_comment": table.get("Parameters", {}).get("comment", ""), + } + return [ + { + **table_catalog, + **{ + "column_name": col["Name"], + "column_index": idx, + "column_type": col["Type"], + "column_comment": col.get("Comment", ""), + }, + } + for idx, col in enumerate(table["Columns"] + table.get("PartitionKeys", [])) + ] + def _get_one_catalog( self, information_schema: InformationSchema, @@ -417,29 +515,55 @@ def _get_one_catalog( manifest: Manifest, ) -> agate.Table: data_catalog = self._get_data_catalog(information_schema.path.database) - catalog_id = get_catalog_id(data_catalog) + data_catalog_type = get_catalog_type(data_catalog) + conn = self.connections.get_thread_connection() client = conn.handle - with boto3_client_lock: - glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) - - catalog = [] - paginator = glue_client.get_paginator("get_tables") - for schema, relations in schemas.items(): - kwargs = { - "DatabaseName": schema, - "MaxResults": 100, - } - # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 infers it from the account Id. - if catalog_id: - kwargs["CatalogId"] = catalog_id + if data_catalog_type == AthenaCatalogType.GLUE: + with boto3_client_lock: + glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) + + catalog = [] + paginator = glue_client.get_paginator("get_tables") + for schema, relations in schemas.items(): + kwargs = { + "DatabaseName": schema, + "MaxResults": 100, + } + # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 + # infers it from the account Id. 
+ catalog_id = get_catalog_id(data_catalog) + if catalog_id: + kwargs["CatalogId"] = catalog_id + + for page in paginator.paginate(**kwargs): + for table in page["TableList"]: + if relations and table["Name"] in relations: + catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database)) + table = agate.Table.from_object(catalog) + else: + with boto3_client_lock: + athena_client = client.session.client( + "athena", region_name=client.region_name, config=get_boto3_config() + ) - for page in paginator.paginate(**kwargs): - for table in page["TableList"]: - if relations and table["Name"] in relations: - catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database)) + catalog = [] + paginator = athena_client.get_paginator("list_table_metadata") + for schema, relations in schemas.items(): + for page in paginator.paginate( + CatalogName=information_schema.path.database, + DatabaseName=schema, + MaxResults=50, # Limit supported by this operation + ): + for table in page["TableMetadataList"]: + if relations and table["Name"] in relations: + catalog.extend( + self._get_one_table_for_non_glue_catalog( + table, schema, information_schema.path.database + ) + ) + table = agate.Table.from_object(catalog) - table = agate.Table.from_object(catalog) filtered_table = self._catalog_filter_table(table, manifest) return self._join_catalog_table_owners(filtered_table, manifest) @@ -468,6 +592,7 @@ def _get_data_catalog(self, database: str) -> Optional[DataCatalogTypeDef]: return athena.get_data_catalog(Name=database)["DataCatalog"] return None + @available def list_relations_without_caching(self, schema_relation: AthenaRelation) -> List[BaseRelation]: data_catalog = self._get_data_catalog(schema_relation.database) catalog_id = get_catalog_id(data_catalog) @@ -526,16 +651,25 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati conn = self.connections.get_thread_connection() client = conn.handle + 
data_catalog = self._get_data_catalog(src_relation.database) + src_catalog_id = get_catalog_id(data_catalog) + with boto3_client_lock: glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) - src_table = glue_client.get_table(DatabaseName=src_relation.schema, Name=src_relation.identifier).get("Table") + src_table = glue_client.get_table( + CatalogId=src_catalog_id, DatabaseName=src_relation.schema, Name=src_relation.identifier + ).get("Table") + src_table_partitions = glue_client.get_partitions( - DatabaseName=src_relation.schema, TableName=src_relation.identifier + CatalogId=src_catalog_id, DatabaseName=src_relation.schema, TableName=src_relation.identifier ).get("Partitions") + data_catalog = self._get_data_catalog(src_relation.database) + target_catalog_id = get_catalog_id(data_catalog) + target_table_partitions = glue_client.get_partitions( - DatabaseName=target_relation.schema, TableName=target_relation.identifier + CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableName=target_relation.identifier ).get("Partitions") target_table_version = { @@ -548,7 +682,9 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati } # perform a table swap - glue_client.update_table(DatabaseName=target_relation.schema, TableInput=target_table_version) + glue_client.update_table( + CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableInput=target_table_version + ) LOGGER.debug(f"Table {target_relation.render()} swapped with the content of {src_relation.render()}") # we delete the target table partitions in any case @@ -557,6 +693,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati if target_table_partitions: for partition_batch in get_chunks(target_table_partitions, AthenaAdapter.BATCH_DELETE_PARTITION_API_LIMIT): glue_client.batch_delete_partition( + CatalogId=target_catalog_id, DatabaseName=target_relation.schema, 
TableName=target_relation.identifier, PartitionsToDelete=[{"Values": partition["Values"]} for partition in partition_batch], @@ -565,6 +702,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati if src_table_partitions: for partition_batch in get_chunks(src_table_partitions, AthenaAdapter.BATCH_CREATE_PARTITION_API_LIMIT): glue_client.batch_create_partition( + CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableName=target_relation.identifier, PartitionInputList=[ @@ -606,6 +744,9 @@ def expire_glue_table_versions( conn = self.connections.get_thread_connection() client = conn.handle + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + with boto3_client_lock: glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) @@ -618,7 +759,10 @@ def expire_glue_table_versions( location = v["Table"]["StorageDescriptor"]["Location"] try: glue_client.delete_table_version( - DatabaseName=relation.schema, TableName=relation.identifier, VersionId=str(version) + CatalogId=catalog_id, + DatabaseName=relation.schema, + TableName=relation.identifier, + VersionId=str(version), ) deleted_versions.append(version) LOGGER.debug(f"Deleted version {version} of table {relation.render()} ") @@ -638,35 +782,72 @@ def persist_docs_to_glue( model: Dict[str, Any], persist_relation_docs: bool = False, persist_column_docs: bool = False, + skip_archive_table_version: bool = False, ) -> None: + """Save model/columns description to Glue Table metadata. + + :param skip_archive_table_version: if True, current table version will not be archived before creating new one. + The purpose is to avoid creating redundant table version if it already was created during the same dbt run + after CREATE OR REPLACE VIEW or ALTER TABLE statements. + Every dbt run should create not more than one table version. 
+ """ conn = self.connections.get_thread_connection() client = conn.handle + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + with boto3_client_lock: glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) - table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.name).get("Table") - updated_table = { - "Name": table["Name"], - "StorageDescriptor": table["StorageDescriptor"], - "PartitionKeys": table.get("PartitionKeys", []), - "TableType": table["TableType"], - "Parameters": table.get("Parameters", {}), - "Description": table.get("Description", ""), - } + # By default, there is no need to update Glue Table + need_udpate_table = False + # Get Table from Glue + table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.name)["Table"] + # Prepare new version of Glue Table picking up significant fields + updated_table = self._get_table_input(table) + # Update table description if persist_relation_docs: - table_comment = clean_sql_comment(model["description"]) - updated_table["Description"] = table_comment - updated_table["Parameters"]["comment"] = table_comment - + # Prepare dbt description + clean_table_description = clean_sql_comment(model["description"]) + # Get current description from Glue + glue_table_description = table.get("Description", "") + # Get current description parameter from Glue + glue_table_comment = table["Parameters"].get("comment", "") + # Update description if it's different + if clean_table_description != glue_table_description or clean_table_description != glue_table_comment: + updated_table["Description"] = clean_table_description + updated_table_parameters: Dict[str, str] = dict(updated_table["Parameters"]) + updated_table_parameters["comment"] = clean_table_description + updated_table["Parameters"] = updated_table_parameters + need_udpate_table = True + + # Update column comments if 
persist_column_docs: + # Process every column for col_obj in updated_table["StorageDescriptor"]["Columns"]: + # Get column description from dbt col_name = col_obj["Name"] - col_comment = model["columns"].get(col_name, {}).get("description") - if col_comment: - col_obj["Comment"] = clean_sql_comment(col_comment) - - glue_client.update_table(DatabaseName=relation.schema, TableInput=updated_table) + if col_name in model["columns"]: + col_comment = model["columns"][col_name]["description"] + # Prepare column description from dbt + clean_col_comment = clean_sql_comment(col_comment) + # Get current column comment from Glue + glue_col_comment = col_obj.get("Comment", "") + # Update column description if it's different + if glue_col_comment != clean_col_comment: + col_obj["Comment"] = clean_col_comment + need_udpate_table = True + + # Update Glue Table only if table/column description is modified. + # It prevents redundant schema version creating after incremental runs. + if need_udpate_table: + glue_client.update_table( + CatalogId=catalog_id, + DatabaseName=relation.schema, + TableInput=updated_table, + SkipArchive=skip_archive_table_version, + ) def generate_python_submission_response(self, submission_result: Any) -> AdapterResponse: if not submission_result: @@ -709,11 +890,16 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn conn = self.connections.get_thread_connection() client = conn.handle + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + with boto3_client_lock: glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) try: - table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)["Table"] + table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)[ + "Table" + ] except ClientError as e: if e.response["Error"]["Code"] == "EntityNotFoundException": 
LOGGER.debug("table not exist, catching the error") @@ -741,11 +927,14 @@ def delete_from_glue_catalog(self, relation: AthenaRelation) -> None: conn = self.connections.get_thread_connection() client = conn.handle + data_catalog = self._get_data_catalog(relation.database) + catalog_id = get_catalog_id(data_catalog) + with boto3_client_lock: glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) try: - glue_client.delete_table(DatabaseName=schema_name, Name=table_name) + glue_client.delete_table(CatalogId=catalog_id, DatabaseName=schema_name, Name=table_name) LOGGER.debug(f"Deleted table from glue catalog: {relation.render()}") except ClientError as e: if e.response["Error"]["Code"] == "EntityNotFoundException": @@ -848,3 +1037,38 @@ def is_list(self, value: Any) -> bool: a list since this is complicated with purely Jinja syntax. """ return isinstance(value, list) + + @staticmethod + def _get_table_input(table: TableTypeDef) -> TableInputTypeDef: + """ + Prepare Glue Table dictionary to be a table_input argument of update_table() method. + + This is needed because update_table() does not accept some read-only fields of table dictionary + returned by get_table() method. 
+ """ + return {k: v for k, v in table.items() if k in TableInputTypeDef.__annotations__} + + @available + def run_query_with_partitions_limit_catching(self, sql: str) -> str: + query = self.connections._add_query_comment(sql) + conn = self.connections.get_thread_connection() + cursor = conn.handle.cursor() + LOGGER.debug(f"Running Athena query:\n{query}") + try: + cursor.execute(query, catch_partitions_limit=True) + except OperationalError as e: + LOGGER.debug(f"CAUGHT EXCEPTION: {e}") + if "TOO_MANY_OPEN_PARTITIONS" in str(e): + return "TOO_MANY_OPEN_PARTITIONS" + raise e + return f'{{"rowcount":{cursor.rowcount},"data_scanned_in_bytes":{cursor.data_scanned_in_bytes}}}' + + @available + def format_partition_keys(self, partition_keys: List[str]) -> str: + return ", ".join([self.format_one_partition_key(k) for k in partition_keys]) + + @available + def format_one_partition_key(self, partition_key: str) -> str: + """Check if partition key uses Iceberg hidden partitioning""" + hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower()) + return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower() diff --git a/dbt/adapters/athena/lakeformation.py b/dbt/adapters/athena/lakeformation.py index cd29e7e6..9fc8047b 100644 --- a/dbt/adapters/athena/lakeformation.py +++ b/dbt/adapters/athena/lakeformation.py @@ -34,6 +34,14 @@ def __init__(self, lf_client: LakeFormationClient, relation: AthenaRelation, lf_ self.lf_tags = lf_tags_config.tags self.lf_tags_columns = lf_tags_config.tags_columns + def process_lf_tags_database(self) -> None: + if self.lf_tags: + database_resource = {"Database": {"Name": self.database}} + response = self.lf_client.add_lf_tags_to_resource( + Resource=database_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()] + ) + self._parse_and_log_lf_response(response, None, self.lf_tags) + def process_lf_tags(self) -> None: table_resource = {"Table": {"DatabaseName": 
self.database, "Name": self.table}} existing_lf_tags = self.lf_client.get_resource_lf_tags(Resource=table_resource) @@ -65,7 +73,7 @@ def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTyp response = self.lf_client.remove_lf_tags_from_resource( Resource=resource, LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}] ) - logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value}, "remove")) + self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}, "remove") def _apply_lf_tags_table( self, table_resource: ResourceTypeDef, existing_lf_tags: GetResourceLFTagsResponseTypeDef @@ -84,13 +92,13 @@ def _apply_lf_tags_table( response = self.lf_client.remove_lf_tags_from_resource( Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": v} for k, v in to_remove.items()] ) - logger.debug(self._parse_lf_response(response, None, self.lf_tags, "remove")) + self._parse_and_log_lf_response(response, None, self.lf_tags, "remove") if self.lf_tags: response = self.lf_client.add_lf_tags_to_resource( Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()] ) - logger.debug(self._parse_lf_response(response, None, self.lf_tags)) + self._parse_and_log_lf_response(response, None, self.lf_tags) def _apply_lf_tags_columns(self) -> None: if self.lf_tags_columns: @@ -103,25 +111,26 @@ def _apply_lf_tags_columns(self) -> None: Resource=resource, LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}], ) - logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value})) + self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}) - def _parse_lf_response( + def _parse_and_log_lf_response( self, response: Union[AddLFTagsToResourceResponseTypeDef, RemoveLFTagsFromResourceResponseTypeDef], columns: Optional[List[str]] = None, lf_tags: Optional[Dict[str, str]] = None, verb: str = "add", - ) -> str: - failures = response.get("Failures", []) + ) -> None: + table_appendix 
= f".{self.table}" if self.table else "" columns_appendix = f" for columns {columns}" if columns else "" - if failures: - base_msg = f"Failed to {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix + resource_msg = self.database + table_appendix + columns_appendix + if failures := response.get("Failures", []): + base_msg = f"Failed to {verb} LF tags: {lf_tags} to " + resource_msg for failure in failures: tag = failure.get("LFTag", {}).get("TagKey") error = failure.get("Error", {}).get("ErrorMessage") - logger.error(f"Failed to {verb} {tag} for {self.database}.{self.table}" + f" - {error}") + logger.error(f"Failed to {verb} {tag} for " + resource_msg + f" - {error}") raise DbtRuntimeError(base_msg) - return f"Success: {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix + logger.debug(f"Success: {verb} LF tags {lf_tags} to " + resource_msg) class FilterConfig(BaseModel): diff --git a/dbt/adapters/athena/relation.py b/dbt/adapters/athena/relation.py index bb07393f..680f9e09 100644 --- a/dbt/adapters/athena/relation.py +++ b/dbt/adapters/athena/relation.py @@ -79,6 +79,7 @@ def add(self, relation: AthenaRelation) -> None: RELATION_TYPE_MAP = { "EXTERNAL_TABLE": TableType.TABLE, + "EXTERNAL": TableType.TABLE, # type returned by federated query tables "MANAGED_TABLE": TableType.TABLE, "VIRTUAL_VIEW": TableType.VIEW, "table": TableType.TABLE, diff --git a/dbt/adapters/athena/utils.py b/dbt/adapters/athena/utils.py index 778fb4c2..dcd74916 100644 --- a/dbt/adapters/athena/utils.py +++ b/dbt/adapters/athena/utils.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Generator, List, Optional, TypeVar from mypy_boto3_athena.type_defs import DataCatalogTypeDef @@ -9,7 +10,17 @@ def clean_sql_comment(comment: str) -> str: def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]: - return catalog["Parameters"]["catalog-id"] if catalog else None + return catalog["Parameters"]["catalog-id"] if catalog and 
catalog["Type"] == AthenaCatalogType.GLUE.value else None + + +class AthenaCatalogType(Enum): + GLUE = "GLUE" + LAMBDA = "LAMBDA" + HIVE = "HIVE" + + +def get_catalog_type(catalog: Optional[DataCatalogTypeDef]) -> Optional[AthenaCatalogType]: + return AthenaCatalogType(catalog["Type"]) if catalog else None T = TypeVar("T") diff --git a/dbt/include/athena/macros/adapters/columns.sql b/dbt/include/athena/macros/adapters/columns.sql index f3abce88..bb106d3b 100644 --- a/dbt/include/athena/macros/adapters/columns.sql +++ b/dbt/include/athena/macros/adapters/columns.sql @@ -9,8 +9,10 @@ {%- set col = columns[i] -%} {%- if col['data_type'] is not defined -%} {{ col_err.append(col['name']) }} + {%- else -%} + {% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %} + cast(null as {{ dml_data_type(col['data_type']) }}) as {{ col_name }}{{ ", " if not loop.last }} {%- endif -%} - cast(null as {{ dml_data_type(col['data_type']) }}) as {{ col['name'] }}{{ ", " if not loop.last }} {%- endfor -%} {%- if (col_err | length) > 0 -%} {{ exceptions.column_type_missing(column_names=col_err) }} diff --git a/dbt/include/athena/macros/adapters/persist_docs.sql b/dbt/include/athena/macros/adapters/persist_docs.sql index 2a7b8748..78503ba9 100644 --- a/dbt/include/athena/macros/adapters/persist_docs.sql +++ b/dbt/include/athena/macros/adapters/persist_docs.sql @@ -1,10 +1,12 @@ {% macro athena__persist_docs(relation, model, for_relation, for_columns) -%} {% set persist_relation_docs = for_relation and config.persist_relation_docs() and model.description %} {% set persist_column_docs = for_columns and config.persist_column_docs() and model.columns %} - {% if (persist_relation_docs or persist_column_docs) and relation.type != 'view' %} + {% set skip_archive_table_version = not is_incremental() %} + {% if persist_relation_docs or persist_column_docs %} {% do adapter.persist_docs_to_glue(relation, model, persist_relation_docs, - persist_column_docs) %}} + 
persist_column_docs, + skip_archive_table_version=skip_archive_table_version) %}} {% endif %} {% endmacro %} diff --git a/dbt/include/athena/macros/adapters/schema.sql b/dbt/include/athena/macros/adapters/schema.sql index 777a3690..2750a7f9 100644 --- a/dbt/include/athena/macros/adapters/schema.sql +++ b/dbt/include/athena/macros/adapters/schema.sql @@ -2,6 +2,9 @@ {%- call statement('create_schema') -%} create schema if not exists {{ relation.without_identifier().render_hive() }} {% endcall %} + + {{ adapter.add_lf_tags_to_database(relation) }} + {% endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql new file mode 100644 index 00000000..3f64cc59 --- /dev/null +++ b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql @@ -0,0 +1,61 @@ +{% macro get_partition_batches(sql, as_subquery=True) -%} + {%- set partitioned_by = config.get('partitioned_by') -%} + {%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%} + {%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%} + {% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %} + + {% call statement('get_partitions', fetch_result=True) %} + {%- if as_subquery -%} + select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }}; + {%- else -%} + select distinct {{ partitioned_keys }} from {{ sql }} order by {{ partitioned_keys }}; + {%- endif -%} + {% endcall %} + + {%- set table = load_result('get_partitions').table -%} + {%- set rows = table.rows -%} + {%- set partitions = {} -%} + {% do log('TOTAL PARTITIONS TO PROCESS: ' ~ rows | length) %} + {%- set partitions_batches = [] -%} + + {%- for row in rows -%} + {%- set single_partition = [] -%} + {%- for col in row -%} + + + {%- set column_type = adapter.convert_type(table, loop.index0) -%} + {%- set comp_func = '=' -%} + {%- if col 
is none -%} + {%- set value = 'null' -%} + {%- set comp_func = ' is ' -%} + {%- elif column_type == 'integer' or column_type is none -%} + {%- set value = col | string -%} + {%- elif column_type == 'string' -%} + {%- set value = "'" + col + "'" -%} + {%- elif column_type == 'date' -%} + {%- set value = "DATE'" + col | string + "'" -%} + {%- elif column_type == 'timestamp' -%} + {%- set value = "TIMESTAMP'" + col | string + "'" -%} + {%- else -%} + {%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%} + {%- endif -%} + {%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%} + {%- do single_partition.append(partition_key + comp_func + value) -%} + {%- endfor -%} + + {%- set single_partition_expression = single_partition | join(' and ') -%} + + {%- set batch_number = (loop.index0 / athena_partitions_limit) | int -%} + {% if not batch_number in partitions %} + {% do partitions.update({batch_number: []}) %} + {% endif %} + + {%- do partitions[batch_number].append('(' + single_partition_expression + ')') -%} + {%- if partitions[batch_number] | length == athena_partitions_limit or loop.last -%} + {%- do partitions_batches.append(partitions[batch_number] | join(' or ')) -%} + {%- endif -%} + {%- endfor -%} + + {{ return(partitions_batches) }} + +{%- endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql b/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql index 53b9bfc0..e8263295 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/column_helpers.sql @@ -48,3 +48,11 @@ {{ return(run_query(sql)) }} {% endif %} {% endmacro %} + +{% macro alter_relation_rename_column(relation, source_column, target_column, target_column_type) -%} + {% set sql -%} + alter {{ relation.type }} {{ relation.render_pure() }} + change 
column {{ source_column }} {{ target_column }} {{ ddl_data_type(target_column_type) }} + {%- endset -%} + {{ return(run_query(sql)) }} +{% endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql index a965e85c..76c2ed73 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql @@ -22,19 +22,42 @@ {% endmacro %} {% macro incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation, statement_name="main") %} - {% set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} - {% if not dest_columns %} + {%- set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%} + {%- if not dest_columns -%} {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} - {% endif %} + {%- endif -%} {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} - insert into {{ target_relation }} ({{ dest_cols_csv }}) - ( - select {{ dest_cols_csv }} - from {{ tmp_relation }} - ); + {%- set insert_full -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ tmp_relation }} + ); + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(insert_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches|length) -%} + {%- set insert_batch_partitions -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + 
from {{ tmp_relation }} + where {{ batch }} + ); + {%- endset -%} + {%- do run_query(insert_batch_partitions) -%} + {%- endfor -%} + {%- endif -%} + SELECT '{{query_result}}' {%- endmacro %} + {% macro delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} {%- set partitioned_keys = partitioned_by | tojson | replace('\"', '') | replace('[', '') | replace(']', '') -%} {% call statement('get_partitions', fetch_result=True) %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql index 96a565d4..c1730dd6 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -8,7 +8,7 @@ {% set lf_tags_config = config.get('lf_tags_config') %} {% set lf_grants = config.get('lf_grants') %} - {% set partitioned_by = config.get('partitioned_by', default=none) %} + {% set partitioned_by = config.get('partitioned_by') %} {% set target_relation = this.incorporate(type='table') %} {% set existing_relation = load_relation(this) %} {% set tmp_relation = make_temp_relation(this) %} @@ -25,22 +25,18 @@ {% set to_drop = [] %} {% if existing_relation is none %} - {% set build_sql = create_table_as(False, target_relation, compiled_code, model_language) -%} + {% set query_result = safe_create_table_as(False, target_relation, compiled_code, language=model_language) -%} + {% set build_sql = "select '" ~ query_result ~ "'" -%} {% elif existing_relation.is_view or should_full_refresh() %} {% do drop_relation(existing_relation) %} - {% set build_sql = create_table_as(False, target_relation, compiled_code, model_language) -%} + {% set query_result = safe_create_table_as(False, target_relation, compiled_code, language=model_language) -%} + {% set build_sql = "select '" ~ query_result ~ "'" -%} {% elif partitioned_by is not none and strategy == 
'insert_overwrite' %} {% set tmp_relation = make_temp_relation(target_relation) %} {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% if model_language == "sql" %} - {% do run_query(create_table_as(True, tmp_relation, compiled_code, model_language)) %} - {% else %} - {% call statement('py_save_table', language=model_language) -%} - {{ create_table_as(False, target_relation, compiled_code, model_language) }} - {%- endcall %} - {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, language=model_language) -%} {% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} @@ -49,13 +45,7 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% if model_language == "sql" %} - {% do run_query(create_table_as(True, tmp_relation, compiled_code, model_language)) %} - {% else %} - {% call statement('py_save_table', language=model_language) -%} - {{ create_table_as(False, target_relation, compiled_code, model_language) }} - {%- endcall %} - {% endif %} + {% set query_result = safe_create_table_as(True, tmp_relation, compiled_code, language=model_language) -%} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% elif strategy == 'merge' and table_type == 'iceberg' %} @@ -80,13 +70,7 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% if model_language == "sql" %} - {% do run_query(create_table_as(True, tmp_relation, compiled_code, model_language)) %} - {% else %} - {% call statement('py_save_table', language=model_language) -%} - {{ create_table_as(True, target_relation, compiled_code, model_language) }} - {%- endcall %} - {% endif %} + {% set query_result = 
safe_create_table_as(True, tmp_relation, compiled_code, language=model_language) -%} {% set build_sql = iceberg_merge(on_schema_change, tmp_relation, target_relation, unique_key, incremental_predicates, existing_relation, delete_condition) %} {% do to_drop.append(tmp_relation) %} {% endif %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/merge.sql b/dbt/include/athena/macros/materializations/models/incremental/merge.sql index cd06cb03..adae970c 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/merge.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/merge.sql @@ -73,33 +73,68 @@ {%- endfor -%} {%- set update_columns = get_merge_update_columns(merge_update_columns, merge_exclude_columns, dest_columns_wo_keys) -%} {%- set src_cols_csv = src_columns_quoted | join(', ') -%} - merge into {{ target_relation }} as target using {{ tmp_relation }} as src - on ( - {%- for key in unique_key_cols %} - target.{{ key }} = src.{{ key }} {{ "and " if not loop.last }} - {%- endfor %} - ) - {% if incremental_predicates is not none -%} - and ( - {%- for inc_predicate in incremental_predicates %} - {{ inc_predicate }} {{ "and " if not loop.last }} - {%- endfor %} - ) - {%- endif %} - {% if delete_condition is not none -%} - when matched and ({{ delete_condition }}) - then delete - {%- endif %} - when matched - then update set - {%- for col in update_columns %} - {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %} - {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }} - {%- else -%} - {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }} - {%- endif -%} - {%- endfor %} - when not matched - then insert ({{ dest_cols_csv }}) - values ({{ src_cols_csv }}); + + {%- set src_part -%} + merge into {{ target_relation }} as target using {{ tmp_relation }} as src + {%- endset -%} + + {%- set merge_part -%} + on ( + {%- for key in 
unique_key_cols -%} + target.{{ key }} = src.{{ key }} + {{ " and " if not loop.last }} + {%- endfor -%} + {% if incremental_predicates is not none -%} + and ( + {%- for inc_predicate in incremental_predicates %} + {{ inc_predicate }} {{ "and " if not loop.last }} + {%- endfor %} + ) + {%- endif %} + ) + {% if delete_condition is not none -%} + when matched and ({{ delete_condition }}) + then delete + {%- endif %} + {% if update_columns -%} + when matched + then update set + {%- for col in update_columns %} + {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %} + {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }} + {%- else -%} + {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }} + {%- endif -%} + {%- endfor %} + {%- endif %} + when not matched + then insert ({{ dest_cols_csv }}) + values ({{ src_cols_csv }}) + {%- endset -%} + + {%- set merge_full -%} + {{ src_part }} + {{ merge_part }} + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(merge_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + {%- set src_batch_part -%} + merge into {{ target_relation }} as target + using (select * from {{ tmp_relation }} where {{ batch }}) as src + {%- endset -%} + {%- set merge_batch -%} + {{ src_batch_part }} + {{ merge_part }} + {%- endset -%} + {%- do run_query(merge_batch) -%} + {%- endfor -%} + {%- endif -%} + + SELECT '{{query_result}}' {%- endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql 
b/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql index 3d9f8137..ef641b70 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/on_schema_change.sql @@ -13,12 +13,32 @@ {%- set remove_from_target_arr = schema_changes_dict['target_not_in_source'] -%} {%- set new_target_types = schema_changes_dict['new_target_types'] -%} {% if table_type == 'iceberg' %} + {# + If last run of alter_column_type was failed on rename tmp column to origin. + Do rename to protect origin column from deletion and losing data. + #} + {% for remove_col in remove_from_target_arr if remove_col.column.endswith('__dbt_alter') %} + {%- set origin_col_name = remove_col.column | replace('__dbt_alter', '') -%} + {% for add_col in add_to_target_arr if add_col.column == origin_col_name %} + {%- do alter_relation_rename_column(target_relation, remove_col.name, add_col.name, add_col.data_type) -%} + {%- do remove_from_target_arr.remove(remove_col) -%} + {%- do add_to_target_arr.remove(add_col) -%} + {% endfor %} + {% endfor %} + {% if add_to_target_arr | length > 0 %} {%- do alter_relation_add_columns(target_relation, add_to_target_arr) -%} {% endif %} {% if remove_from_target_arr | length > 0 %} {%- do alter_relation_drop_columns(target_relation, remove_from_target_arr) -%} {% endif %} + {% if new_target_types != [] %} + {% for ntt in new_target_types %} + {% set column_name = ntt['column_name'] %} + {% set new_type = ntt['new_type'] %} + {% do alter_column_type(target_relation, column_name, new_type) %} + {% endfor %} + {% endif %} {% else %} {%- set replace_with_target_arr = remove_partitions_from_columns(schema_changes_dict['source_columns'], partitioned_by) -%} {% if add_to_target_arr | length > 0 or remove_from_target_arr | length > 0 or new_target_types | length > 0 %} @@ -35,3 +55,34 @@ {% endset %} {% do log(schema_change_message) %} {% endmacro %} + 
+{% macro athena__alter_column_type(relation, column_name, new_column_type) -%} + {# + 1. Create a new column (w/ temp name and correct type) + 2. Copy data over to it + 3. Drop the existing column + 4. Rename the new column to existing column + #} + {%- set tmp_column = column_name + '__dbt_alter' -%} + {%- set new_ddl_data_type = ddl_data_type(new_column_type) -%} + + {#- do alter_relation_add_columns(relation, [ tmp_column ]) -#} + {%- set add_column_query -%} + alter {{ relation.type }} {{ relation.render_pure() }} add columns({{ tmp_column }} {{ new_ddl_data_type }}); + {%- endset -%} + {%- do run_query(add_column_query) -%} + + {%- set update_query -%} + update {{ relation.render_pure() }} set {{ tmp_column }} = cast({{ column_name }} as {{ new_column_type }}); + {%- endset -%} + {%- do run_query(update_query) -%} + + {#- do alter_relation_drop_columns(relation, [ column_name ]) -#} + {%- set drop_column_query -%} + alter {{ relation.type }} {{ relation.render_pure() }} drop column {{ column_name }}; + {%- endset -%} + {%- do run_query(drop_column_query) -%} + + {%- do alter_relation_rename_column(relation, tmp_column, column_name, new_column_type) -%} + +{% endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index 3cd8238b..d1b305e4 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -1,7 +1,7 @@ -{% macro athena__create_table_as(temporary, relation, compiled_code, language='sql') -%} +{% macro athena__create_table_as(temporary, relation, compiled_code, skip_partitioning=False, language='sql') -%} {%- set materialized = config.get('materialized', default='table') -%} {%- set external_location = config.get('external_location', default=none) -%} - {%- set partitioned_by = config.get('partitioned_by', default=none) 
-%} + {%- set partitioned_by = config.get('partitioned_by', default=none) if not skip_partitioning else none -%} {%- set bucketed_by = config.get('bucketed_by', default=none) -%} {%- set bucket_count = config.get('bucket_count', default=none) -%} {%- set field_delimiter = config.get('field_delimiter', default=none) -%} @@ -24,7 +24,7 @@ {%- set contract_config = config.get('contract') -%} {%- if contract_config.enforced -%} - {{ get_assert_columns_equivalent(sql) }} + {{ get_assert_columns_equivalent(compiled_code) }} {%- endif -%} {%- if table_type == 'iceberg' -%} @@ -106,3 +106,79 @@ {% do exceptions.raise_compiler_error("athena__create_table_as macro doesn't support the provided language, it got %s" % language) %} {%- endif -%} {%- endmacro -%} + +{% macro create_table_as_with_partitions(temporary, relation, compiled_code, language='sql') -%} + + {%- set tmp_relation = api.Relation.create( + identifier=relation.identifier ~ '__tmp_not_partitioned', + schema=relation.schema, + database=relation.database, + s3_path_table_part=relation.identifier ~ '__tmp_not_partitioned' , + type='table' + ) + -%} + + {%- if tmp_relation is not none -%} + {%- do drop_relation(tmp_relation) -%} + {%- endif -%} + + {%- do log('CREATE NON-PARTIONED STAGING TABLE: ' ~ tmp_relation) -%} + {%- do run_query(create_table_as(temporary, tmp_relation, compiled_code, True, language=language)) -%} + + {% set partitions_batches = get_partition_batches(sql=tmp_relation, as_subquery=False) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + + {%- set dest_columns = adapter.get_columns_in_relation(tmp_relation) -%} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + + {%- if loop.index == 1 -%} + {%- set create_target_relation_sql -%} + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + 
{%- endset -%} + {%- do run_query(create_table_as(temporary, relation, create_target_relation_sql, language=language)) -%} + {%- else -%} + {%- set insert_batch_partitions_sql -%} + insert into {{ relation }} ({{ dest_cols_csv }}) + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + {%- endset -%} + + {%- do run_query(insert_batch_partitions_sql) -%} + {%- endif -%} + + + {%- endfor -%} + + {%- do drop_relation(tmp_relation) -%} + + select 'SUCCESSFULLY CREATED TABLE {{ relation }}' + +{%- endmacro %} + +{% macro safe_create_table_as(temporary, relation, compiled_code, language='sql') -%} + {%- if language != 'sql' -%} + {% call statement('py_save_table', language=language) -%} + {{ create_table_as(temporary, relation, compiled_code, language=language) }} + {%- endcall %} + {%- set compiled_code_result = relation ~ ' created with spark' -%} + {%- endif -%} + {%- if temporary -%} + {%- do run_query(create_table_as(temporary, relation, compiled_code, True, language=language)) -%} + {%- set compiled_code_result = relation ~ ' as temporary relation without partitioning created' -%} + {%- else -%} + {%- set compiled_code_result = adapter.run_query_with_partitions_limit_catching(create_table_as(temporary, relation, compiled_code, language=language)) -%} + {%- do log('COMPILED CODE RESULT: ' ~ compiled_code_result) -%} + {%- if compiled_code_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {%- do create_table_as_with_partitions(temporary, relation, compiled_code, language=language) -%} + {%- set compiled_code_result = relation ~ ' with many partitions created' -%} + {%- endif -%} + {%- endif -%} + + {{ return(compiled_code_result) }} +{%- endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql index ff32f06e..8471a33e 100644 --- a/dbt/include/athena/macros/materializations/models/table/table.sql +++ 
b/dbt/include/athena/macros/materializations/models/table/table.sql @@ -50,28 +50,22 @@ {%- endif -%} -- create tmp table - {%- call statement('main', language=language) -%} - {{ create_table_as(False, tmp_relation, compiled_code, language) }} - {%- endcall %} + {%- set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language=language) -%} -- swap table - {%- set swap_table = adapter.swap_table(tmp_relation, - target_relation) -%} + {%- set swap_table = adapter.swap_table(tmp_relation, target_relation) -%} -- delete glue tmp table, do not use drop_relation, as it will remove data of the target table {%- do adapter.delete_from_glue_catalog(tmp_relation) -%} - {% do adapter.expire_glue_table_versions(target_relation, - versions_to_keep, - True) %} + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, True) %} + {%- else -%} -- Here we are in the case of non-ha tables or ha tables but in case of full refresh. {%- if old_relation is not none -%} {{ drop_relation(old_relation) }} {%- endif -%} - {%- call statement('main', language=language) -%} - {{ create_table_as(False, target_relation, compiled_code, language) }} - {%- endcall %} + {%- set query_result = safe_create_table_as(False, target_relation, compiled_code, language=language) -%} {%- endif -%} {{ set_table_classification(target_relation) }} @@ -79,35 +73,42 @@ {%- else -%} {%- if old_relation is none -%} - {%- call statement('main', language=language) -%} - {{ create_table_as(False, target_relation, compiled_code, language) }} - {%- endcall %} + {%- set query_result = safe_create_table_as(False, target_relation, compiled_code, language=language) -%} {%- else -%} - {%- if tmp_relation is not none -%} - {%- do drop_relation(tmp_relation) -%} - {%- endif -%} - - {%- set old_relation_bkp = make_temp_relation(old_relation, '__bkp') -%} - -- If we have this, it means that at least the first renaming occurred but there was an issue - -- afterwards, therefore we are in 
weird state. The easiest and cleanest should be to remove - -- the backup relation. It won't have an impact because since we are in the else condition, - -- that means that old relation exists therefore no downtime yet. - {%- if old_relation_bkp is not none -%} - {%- do drop_relation(old_relation_bkp) -%} + {%- if old_relation.is_view -%} + {%- set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language=language) -%} + {%- do drop_relation(old_relation) -%} + {%- do rename_relation(tmp_relation, target_relation) -%} + {%- else -%} + + {%- if tmp_relation is not none -%} + {%- do drop_relation(tmp_relation) -%} + {%- endif -%} + + {%- set old_relation_bkp = make_temp_relation(old_relation, '__bkp') -%} + -- If we have this, it means that at least the first renaming occurred but there was an issue + -- afterwards, therefore we are in weird state. The easiest and cleanest should be to remove + -- the backup relation. It won't have an impact because since we are in the else condition, + -- that means that old relation exists therefore no downtime yet. 
+ {%- if old_relation_bkp is not none -%} + {%- do drop_relation(old_relation_bkp) -%} + {%- endif -%} + + {% set query_result = safe_create_table_as(False, tmp_relation, compiled_code, language=language) %} + + {{ rename_relation(old_relation, old_relation_bkp) }} + {{ rename_relation(tmp_relation, target_relation) }} + + {{ drop_relation(old_relation_bkp) }} {%- endif -%} - - {%- call statement('main', language=language) -%} - {{ create_table_as(False, tmp_relation, compiled_code, language) }} - {%- endcall -%} - - {{ rename_relation(old_relation, old_relation_bkp) }} - {{ rename_relation(tmp_relation, target_relation) }} - - {{ drop_relation(old_relation_bkp) }} {%- endif -%} {%- endif -%} + {% call statement("main") %} + SELECT '{{ query_result }}'; + {% endcall %} + {{ run_hooks(post_hooks) }} {% if lf_tags_config is not none %} diff --git a/dbt/include/athena/macros/materializations/models/view/view.sql b/dbt/include/athena/macros/materializations/models/view/view.sql index 1cf3b337..3b1a4a89 100644 --- a/dbt/include/athena/macros/materializations/models/view/view.sql +++ b/dbt/include/athena/macros/materializations/models/view/view.sql @@ -2,6 +2,7 @@ {% set to_return = create_or_replace_view(run_outside_transaction_hooks=False) %} {% set target_relation = this.incorporate(type='view') %} + {% do persist_docs(target_relation, model) %} {% do return(to_return) %} {%- endmaterialization %} diff --git a/dbt/include/athena/macros/materializations/seeds/helpers.sql b/dbt/include/athena/macros/materializations/seeds/helpers.sql index 48582f6d..93e374bc 100644 --- a/dbt/include/athena/macros/materializations/seeds/helpers.sql +++ b/dbt/include/athena/macros/materializations/seeds/helpers.sql @@ -96,6 +96,7 @@ {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} {%- set s3_data_naming = config.get('s3_data_naming', target.s3_data_naming) -%} {%- set external_location = config.get('external_location', default=none) -%} + {%- set 
seed_s3_upload_args = config.get('seed_s3_upload_args', default=target.seed_s3_upload_args) -%} {%- set tmp_relation = api.Relation.create( identifier=identifier + "__dbt_tmp", @@ -110,6 +111,7 @@ s3_data_dir, s3_data_naming, external_location, + seed_s3_upload_args=seed_s3_upload_args ) -%} -- create target relation @@ -197,17 +199,6 @@ {%- set sql_table = create_csv_table_upload(model, agate_table) -%} {%- endif -%} - {%- set lf_tags_config = config.get('lf_tags_config') -%} - {%- set lf_grants = config.get('lf_grants') -%} - - {% if lf_tags_config is not none %} - {{ adapter.add_lf_tags(relation, lf_tags_config) }} - {% endif %} - - {% if lf_grants is not none %} - {{ adapter.apply_lf_grants(relation, lf_grants) }} - {% endif %} - {{ return(sql_table) }} {% endmacro %} diff --git a/dbt/include/athena/macros/materializations/snapshots/snapshot.sql b/dbt/include/athena/macros/materializations/snapshots/snapshot.sql index 208f3e9d..08d41c57 100644 --- a/dbt/include/athena/macros/materializations/snapshots/snapshot.sql +++ b/dbt/include/athena/macros/materializations/snapshots/snapshot.sql @@ -137,11 +137,6 @@ identifier=target_table, type='table') -%} - - {% if not adapter.check_schema_exists(model.database, model.schema) %} - {% do create_schema(model.database, model.schema) %} - {% endif %} - {%- if not target_relation.is_table -%} {% do exceptions.relation_wrong_type(target_relation, 'table') %} {%- endif -%} diff --git a/dbt/include/athena/macros/utils/ddl_dml_data_type.sql b/dbt/include/athena/macros/utils/ddl_dml_data_type.sql index a97c489c..ddd0fcb0 100644 --- a/dbt/include/athena/macros/utils/ddl_dml_data_type.sql +++ b/dbt/include/athena/macros/utils/ddl_dml_data_type.sql @@ -1,7 +1,9 @@ {# Athena has different types between DML and DDL #} {# ref: https://docs.aws.amazon.com/athena/latest/ug/data-types.html #} {% macro ddl_data_type(col_type) -%} - -- transform varchar + {%- set table_type = config.get('table_type', 'hive') -%} + + -- transform varchar {% 
set re = modules.re %} {% set data_type = re.sub('(?:varchar|character varying)(?:\(\d+\))?', 'string', col_type) %} @@ -15,6 +17,17 @@ {% set data_type = data_type.replace('integer', 'int') -%} {%- endif -%} + -- transform timestamp + {%- if table_type == 'iceberg' -%} + {%- if 'timestamp' in data_type -%} + {% set data_type = 'timestamp' -%} + {%- endif -%} + + {%- if 'binary' in data_type -%} + {% set data_type = 'binary' -%} + {%- endif -%} + {%- endif -%} + {{ return(data_type) }} {% endmacro %} diff --git a/dbt/include/athena/profile_template.yml b/dbt/include/athena/profile_template.yml index 87be52ed..72d94710 100644 --- a/dbt/include/athena/profile_template.yml +++ b/dbt/include/athena/profile_template.yml @@ -17,6 +17,9 @@ prompts: hint: Specify the database (Data catalog) to build models into (lowercase only) default: awsdatacatalog + seed_s3_upload_args: + hint: Specify any extra arguments to use in the S3 Upload, e.g. ACL, SSEKMSKeyId + threads: hint: '1 or more' type: 'int' diff --git a/dev-requirements.txt b/dev-requirements.txt index 54ef92f8..59d22a3c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,14 +1,15 @@ autoflake~=1.7 -black~=23.3 -dbt-tests-adapter~=1.5.1 -flake8~=5.0 +black~=23.9 +boto3-stubs[s3]~=1.28 +dbt-tests-adapter~=1.6.3 +flake8~=6.1 Flake8-pyproject~=1.2 isort~=5.11 -moto~=4.1.11 -pre-commit~=2.21 -pyparsing~=3.0.9 -pytest~=7.3 +moto~=4.2.4 +pre-commit~=3.4 +pyparsing~=3.1.1 +pytest~=7.4 pytest-cov~=4.1 pytest-dotenv~=0.5 pytest-xdist~=3.3 -pyupgrade~=3.3 +pyupgrade~=3.11 diff --git a/setup.py b/setup.py index dae9a695..bfaacbfa 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def _get_package_version() -> str: return f'{parts["major"]}.{parts["minor"]}.{parts["patch"]}' -dbt_version = "1.5" +dbt_version = "1.6" package_version = _get_package_version() description = "The athena adapter plugin for dbt (data build tool)" @@ -55,9 +55,9 @@ def _get_package_version() -> str: # In order to control dbt-core 
version and package version "boto3~=1.26", "boto3-stubs[athena,glue,lakeformation,sts]~=1.26", - "dbt-core~=1.5.0", + "dbt-core~=1.6.0", "pyathena>=2.25,<4.0", - "pydantic~=1.10", + "pydantic>=1.10,<3.0", "tenacity~=8.2", ], classifiers=[ @@ -66,10 +66,10 @@ def _get_package_version() -> str: "Operating System :: Microsoft :: Windows", "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ], + python_requires=">=3.8", ) diff --git a/tests/conftest.py b/tests/conftest.py index 3ca4d14a..dfb7eed4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ SPARK_WORKGROUP, ) -# Import the fuctional fixtures as a plugin +# Import the functional fixtures as a plugin # Note: fixtures with session scope need to be local pytest_plugins = ["dbt.tests.fixtures.project"] @@ -58,14 +58,17 @@ def dbt_debug_caplog() -> StringIO: def _setup_custom_caplog(name: str, level: EventLevel): capture_config = _get_stdout_config( - line_format=LineFormat.PlainText, level=level, use_colors=False, debug=True, log_cache_events=True, quiet=False + line_format=LineFormat.PlainText, + level=level, + use_colors=False, + log_cache_events=True, ) capture_config.name = name capture_config.filter = NoFilter - stringbuf = StringIO() - capture_config.output_stream = stringbuf + string_buf = StringIO() + capture_config.output_stream = string_buf EVENT_MANAGER.add_logger(capture_config) - return stringbuf + return string_buf @pytest.fixture(scope="class") diff --git a/tests/functional/adapter/fixture_split_parts.py b/tests/functional/adapter/fixture_split_parts.py new file mode 100644 index 00000000..c9ff3d7c --- /dev/null +++ b/tests/functional/adapter/fixture_split_parts.py @@ -0,0 +1,39 @@ +models__test_split_part_sql = """ +with data as ( + + select * from 
{{ ref('data_split_part') }} + +) + +select + {{ split_part('parts', 'split_on', 1) }} as actual, + result_1 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 2) }} as actual, + result_2 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 3) }} as actual, + result_3 as expected + +from data +""" + +models__test_split_part_yml = """ +version: 2 +models: + - name: test_split_part + tests: + - assert_equal: + actual: actual + expected: expected +""" diff --git a/tests/functional/adapter/test_change_relation_types.py b/tests/functional/adapter/test_change_relation_types.py new file mode 100644 index 00000000..047cac75 --- /dev/null +++ b/tests/functional/adapter/test_change_relation_types.py @@ -0,0 +1,26 @@ +import pytest + +from dbt.tests.adapter.relations.test_changing_relation_type import ( + BaseChangeRelationTypeValidator, +) + + +class TestChangeRelationTypesHive(BaseChangeRelationTypeValidator): + pass + + +class TestChangeRelationTypesIceberg(BaseChangeRelationTypeValidator): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + } + } + + def test_changing_materialization_changes_relation_type(self, project): + self._run_and_check_materialization("view") + self._run_and_check_materialization("table") + self._run_and_check_materialization("view") + # skip incremental that doesn't work with Iceberg + self._run_and_check_materialization("table", extra_args=["--full-refresh"]) diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py new file mode 100644 index 00000000..9295f681 --- /dev/null +++ b/tests/functional/adapter/test_constraints.py @@ -0,0 +1,26 @@ +import pytest + +from dbt.tests.adapter.constraints.fixtures import ( + model_quoted_column_schema_yml, + my_model_with_quoted_column_name_sql, +) +from dbt.tests.adapter.constraints.test_constraints import 
BaseConstraintQuotedColumn + + +class TestAthenaConstraintQuotedColumn(BaseConstraintQuotedColumn): + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_with_quoted_column_name_sql, + "constraints_schema.yml": model_quoted_column_schema_yml.replace("text", "string"), + } + + @pytest.fixture(scope="class") + def expected_sql(self): + # FIXME: dbt-athena outputs a query about stats into `target/run/` directory. + # dbt-core expects the query to be a ddl statement to create a table. + # This is a workaround to pass the test for now. + + # NOTE: for the above reason, this test just checks the query can be executed without errors. + # The query itself is not checked. + return 'SELECT \'{"rowcount":1,"data_scanned_in_bytes":0}\';' diff --git a/tests/functional/adapter/test_incremental_iceberg.py b/tests/functional/adapter/test_incremental_iceberg.py index d2494599..ad4d21d1 100644 --- a/tests/functional/adapter/test_incremental_iceberg.py +++ b/tests/functional/adapter/test_incremental_iceberg.py @@ -6,6 +6,9 @@ import pytest +from dbt.tests.adapter.incremental.test_incremental_merge_exclude_columns import ( + BaseMergeExcludeColumns, +) from dbt.tests.adapter.incremental.test_incremental_predicates import ( BaseIncrementalPredicates, ) @@ -46,6 +49,42 @@ 3,anyway,purple """ +models__merge_exclude_all_columns_sql = """ +{{ config( + materialized = 'incremental', + unique_key = 'id', + incremental_strategy='merge', + merge_exclude_columns=['msg', 'color'] +) }} + +{% if not is_incremental() %} + +-- data for first invocation of model + +select 1 as id, 'hello' as msg, 'blue' as color +union all +select 2 as id, 'goodbye' as msg, 'red' as color + +{% else %} + +-- data for subsequent incremental update + +select 1 as id, 'hey' as msg, 'blue' as color +union all +select 2 as id, 'yo' as msg, 'green' as color +union all +select 3 as id, 'anyway' as msg, 'purple' as color + +{% endif %} +""" + + 
+seeds__expected_merge_exclude_all_columns_csv = """id,msg,color +1,hello,blue +2,goodbye,red +3,anyway,purple +""" + class TestIcebergIncrementalUniqueKey(BaseIncrementalUniqueKey): @pytest.fixture(scope="class") @@ -185,6 +224,36 @@ def test__incremental_predicates(self, project): self.check_scenario_correctness(expected_fields, test_case_fields, project) +class TestIcebergMergeExcludeColumns(BaseMergeExcludeColumns): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + } + } + + +class TestIcebergMergeExcludeAllColumns(BaseMergeExcludeColumns): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+incremental_strategy": "merge", + "+table_type": "iceberg", + } + } + + @pytest.fixture(scope="class") + def models(self): + return {"merge_exclude_columns.sql": models__merge_exclude_all_columns_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_exclude_columns.csv": seeds__expected_merge_exclude_all_columns_csv} + + def replace_cast_date(model: str) -> str: """Wrap all date strings with a cast date function""" diff --git a/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py b/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py new file mode 100644 index 00000000..039da2d8 --- /dev/null +++ b/tests/functional/adapter/test_incremental_iceberg_merge_no_updates.py @@ -0,0 +1,115 @@ +from collections import namedtuple + +import pytest + +from dbt.tests.util import check_relations_equal, run_dbt + +models__merge_no_updates_sql = """ +{{ config( + materialized = 'incremental', + unique_key = 'id', + incremental_strategy = 'merge', + merge_update_columns = ['id'], + table_type = 'iceberg', +) }} + +{% if not is_incremental() %} + +-- data for first invocation of model + +select 1 as id, 'hello' as msg, 'blue' as color +union all +select 2 as id, 'goodbye' as 
msg, 'red' as color + +{% else %} + +-- data for subsequent incremental update + +select 1 as id, 'hey' as msg, 'blue' as color +union all +select 2 as id, 'yo' as msg, 'green' as color +union all +select 3 as id, 'anyway' as msg, 'purple' as color + +{% endif %} +""" + +seeds__expected_merge_no_updates_csv = """id,msg,color +1,hello,blue +2,goodbye,red +3,anyway,purple +""" + +ResultHolder = namedtuple( + "ResultHolder", + [ + "seed_count", + "model_count", + "seed_rows", + "inc_test_model_count", + "relation", + ], +) + + +class TestIncrementalIcebergMergeNoUpdates: + @pytest.fixture(scope="class") + def models(self): + return {"merge_no_updates.sql": models__merge_no_updates_sql} + + @pytest.fixture(scope="class") + def seeds(self): + return {"expected_merge_no_updates.csv": seeds__expected_merge_no_updates_csv} + + def update_incremental_model(self, incremental_model): + """update incremental model after the seed table has been updated""" + model_result_set = run_dbt(["run", "--select", incremental_model]) + return len(model_result_set) + + def get_test_fields(self, project, seed, incremental_model, update_sql_file): + seed_count = len(run_dbt(["seed", "--select", seed, "--full-refresh"])) + + model_count = len(run_dbt(["run", "--select", incremental_model, "--full-refresh"])) + + relation = incremental_model + # update seed in anticipation of incremental model update + row_count_query = f"select * from {project.test_schema}.{seed}" + + seed_rows = len(project.run_sql(row_count_query, fetch="all")) + + # propagate seed state to incremental model according to unique keys + inc_test_model_count = self.update_incremental_model(incremental_model=incremental_model) + + return ResultHolder(seed_count, model_count, seed_rows, inc_test_model_count, relation) + + def check_scenario_correctness(self, expected_fields, test_case_fields, project): + """Invoke assertions to verify correct build functionality""" + # 1. 
test seed(s) should build afresh + assert expected_fields.seed_count == test_case_fields.seed_count + # 2. test model(s) should build afresh + assert expected_fields.model_count == test_case_fields.model_count + # 3. seeds should have intended row counts post update + assert expected_fields.seed_rows == test_case_fields.seed_rows + # 4. incremental test model(s) should be updated + assert expected_fields.inc_test_model_count == test_case_fields.inc_test_model_count + # 5. result table should match intended result set (itself a relation) + check_relations_equal(project.adapter, [expected_fields.relation, test_case_fields.relation]) + + def test__merge_no_updates(self, project): + """seed should match model after incremental run""" + + expected_fields = ResultHolder( + seed_count=1, + model_count=1, + inc_test_model_count=1, + seed_rows=3, + relation="expected_merge_no_updates", + ) + + test_case_fields = self.get_test_fields( + project, + seed="expected_merge_no_updates", + incremental_model="merge_no_updates", + update_sql_file=None, + ) + self.check_scenario_correctness(expected_fields, test_case_fields, project) diff --git a/tests/functional/adapter/test_partitions.py b/tests/functional/adapter/test_partitions.py new file mode 100644 index 00000000..f5f1e6d3 --- /dev/null +++ b/tests/functional/adapter/test_partitions.py @@ -0,0 +1,266 @@ +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +# this query generates 212 records +test_partitions_model_sql = """ +select + random() as rnd, + cast(date_column as date) as date_column, + doy(date_column) as doy +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + +test_single_nullable_partition_model_sql = """ +with data as ( + select + random() as col_1, + row_number() over() as id + from + unnest(sequence(1, 200)) +) + +select + 
col_1, id +from data +union all +select random() as col_1, NULL as id +union all +select random() as col_1, NULL as id +""" + +test_nullable_partitions_model_sql = """ +{{ config( + materialized='table', + format='parquet', + s3_data_naming='table', + partitioned_by=['id', 'date_column'] +) }} + +with data as ( + select + random() as rnd, + row_number() over() as id, + cast(date_column as date) as date_column +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +) + +select + rnd, + case when id <= 50 then null else id end as id, + date_column +from data +union all +select + random() as rnd, + NULL as id, + NULL as date_column +union all +select + random() as rnd, + NULL as id, + cast('2023-09-02' as date) as date_column +union all +select + random() as rnd, + 40 as id, + NULL as date_column +""" + + +class TestHiveTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"models": {"+table_type": "hive", "+materialized": "table", "+partitioned_by": ["date_column", "doy"]}} + + @pytest.fixture(scope="class") + def models(self): + return { + "test_hive_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_hive_partitions" + model_run_result_row_count_query = "select count(*) as records from {}.{}".format( + project.test_schema, relation_name + ) + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + 
return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_iceberg_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergIncrementalPartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "incremental", + "+incremental_strategy": "merge", + "+unique_key": "doy", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions_incremental.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + """ + Check that the incremental run works with iceberg and partitioned datasets + """ + + relation_name = "test_iceberg_partitions_incremental" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name, "--full-refresh"]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert 
records_count_first_run == 212 + + incremental_model_run = run_dbt(["run", "--select", relation_name]) + + incremental_model_run_result = incremental_model_run.results[0] + + # check that the model run successfully after incremental run + assert incremental_model_run_result.status == RunStatus.Success + + incremental_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert incremental_records_count == 212 + + +class TestHiveNullValuedPartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "hive", + "+materialized": "table", + "+partitioned_by": ["id", "date_column"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_nullable_partitions_model.sql": test_nullable_partitions_model_sql, + } + + def test__check_run_with_partitions(self, project): + relation_name = "test_nullable_partitions_model" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + model_run_result_null_id_count_query = ( + f"select count(*) as records from {project.test_schema}.{relation_name} where id is null" + ) + model_run_result_null_date_count_query = ( + f"select count(*) as records from {project.test_schema}.{relation_name} where date_column is null" + ) + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 215 + + null_id_count_first_run = project.run_sql(model_run_result_null_id_count_query, fetch="all")[0][0] + + assert null_id_count_first_run == 52 + + null_date_count_first_run = project.run_sql(model_run_result_null_date_count_query, fetch="all")[0][0] + + assert null_date_count_first_run == 2 + 
+ +class TestHiveSingleNullValuedPartition: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "hive", + "+materialized": "table", + "+partitioned_by": ["id"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_single_nullable_partition_model.sql": test_single_nullable_partition_model_sql, + } + + def test__check_run_with_partitions(self, project): + relation_name = "test_single_nullable_partition_model" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 202 diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py index b998b6c7..83f743e9 100644 --- a/tests/functional/adapter/utils/test_utils.py +++ b/tests/functional/adapter/utils/test_utils.py @@ -3,6 +3,10 @@ models__test_datediff_sql, seeds__data_datediff_csv, ) +from tests.functional.adapter.fixture_split_parts import ( + models__test_split_part_sql, + models__test_split_part_yml, +) from dbt.tests.adapter.utils.fixture_datediff import models__test_datediff_yml from dbt.tests.adapter.utils.test_any_value import BaseAnyValue @@ -100,7 +104,12 @@ class TestRight(BaseRight): class TestSplitPart(BaseSplitPart): - pass + @pytest.fixture(scope="class") + def models(self): + return { + "test_split_part.yml": models__test_split_part_yml, + "test_split_part.sql": self.interpolate_macro_namespace(models__test_split_part_sql, "split_part"), + } class TestStringLiteral(BaseStringLiteral): diff --git a/tests/unit/constants.py b/tests/unit/constants.py index 51deab6f..4aff969a 
100644 --- a/tests/unit/constants.py +++ b/tests/unit/constants.py @@ -1,6 +1,7 @@ CATALOG_ID = "12345678910" DATA_CATALOG_NAME = "awsdatacatalog" SHARED_DATA_CATALOG_NAME = "9876543210" +FEDERATED_QUERY_CATALOG_NAME = "federated_query_data_source" DATABASE_NAME = "test_dbt_athena" BUCKET = "test-dbt-athena" AWS_REGION = "eu-west-1" diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 30d39e4b..aaef3d94 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -4,7 +4,10 @@ import agate import boto3 +import botocore import pytest + +# from botocore.client.BaseClient import _make_api_call from moto import mock_athena, mock_glue, mock_s3, mock_sts from moto.core import DEFAULT_ACCOUNT_ID @@ -14,6 +17,7 @@ from dbt.adapters.athena.connections import AthenaCursor, AthenaParameterFormatter from dbt.adapters.athena.exceptions import S3LocationException from dbt.adapters.athena.relation import AthenaRelation, TableType +from dbt.adapters.athena.utils import AthenaCatalogType from dbt.clients import agate_helper from dbt.contracts.connection import ConnectionState from dbt.contracts.files import FileHash @@ -28,6 +32,7 @@ BUCKET, DATA_CATALOG_NAME, DATABASE_NAME, + FEDERATED_QUERY_CATALOG_NAME, S3_STAGING_DIR, SHARED_DATA_CATALOG_NAME, ) @@ -66,6 +71,7 @@ def setup_method(self, _): ("awsdatacatalog", "quux"), ("awsdatacatalog", "baz"), (SHARED_DATA_CATALOG_NAME, "foo"), + (FEDERATED_QUERY_CATALOG_NAME, "foo"), } self.mock_manifest.nodes = { "model.root.model1": CompiledNode( @@ -212,6 +218,42 @@ def setup_method(self, _): raw_code="select * from source_table", language="", ), + "model.root.model5": CompiledNode( + name="model5", + database=FEDERATED_QUERY_CATALOG_NAME, + schema="foo", + resource_type=NodeType.Model, + unique_id="model.root.model5", + alias="bar", + fqn=["root", "model5"], + package_name="root", + refs=[], + sources=[], + depends_on=DependsOn(), + config=NodeConfig.from_dict( + { + "enabled": True, + "materialized": 
"table", + "persist_docs": {}, + "post-hook": [], + "pre-hook": [], + "vars": {}, + "meta": {"owner": "data-engineers"}, + "quoting": {}, + "column_types": {}, + "tags": [], + } + ), + tags=[], + path="model5.sql", + original_file_path="model5.sql", + compiled=True, + extra_ctes_injected=False, + extra_ctes=[], + checksum=FileHash.from_contents(""), + raw_code="select * from source_table", + language="", + ), } @property @@ -359,6 +401,7 @@ def test_generate_s3_location( @mock_glue @mock_s3 @mock_athena + @mock_sts def test_get_table_location(self, dbt_debug_caplog, mock_aws_service): table_name = "test_table" self.adapter.acquire_connection("dummy") @@ -375,6 +418,7 @@ def test_get_table_location(self, dbt_debug_caplog, mock_aws_service): @mock_glue @mock_s3 @mock_athena + @mock_sts def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, mock_aws_service): table_name = "test_table" self.adapter.acquire_connection("dummy") @@ -396,6 +440,7 @@ def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, @mock_glue @mock_s3 @mock_athena + @mock_sts def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service): view_name = "view" self.adapter.acquire_connection("dummy") @@ -410,6 +455,7 @@ def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service): @mock_glue @mock_s3 @mock_athena + @mock_sts def test_get_table_location_with_failure(self, dbt_debug_caplog, mock_aws_service): table_name = "test_table" self.adapter.acquire_connection("dummy") @@ -458,6 +504,7 @@ def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service) @mock_glue @mock_athena + @mock_sts def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -475,6 +522,7 @@ def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_se @mock_glue @mock_athena + @mock_sts def 
test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -492,6 +540,7 @@ def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service): @mock_glue @mock_s3 @mock_athena + @mock_sts def test_clean_up_table_delete_table(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -612,9 +661,84 @@ def test__get_one_catalog_shared_catalog(self, mock_aws_service): for row in actual.rows.values(): assert row.values() in expected_rows + @mock_athena + def test__get_one_catalog_federated_query_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog( + catalog_name=FEDERATED_QUERY_CATALOG_NAME, catalog_type=AthenaCatalogType.LAMBDA + ) + mock_information_schema = mock.MagicMock() + mock_information_schema.path.database = FEDERATED_QUERY_CATALOG_NAME + + # Original botocore _make_api_call function + orig = botocore.client.BaseClient._make_api_call + + # Mocking this as list_table_metadata and creating non glue tables is not supported by moto. 
+ # Followed this guide: http://docs.getmoto.org/en/latest/docs/services/patching_other_services.html + def mock_athena_list_table_metadata(self, operation_name, kwarg): + if operation_name == "ListTableMetadata": + return { + "TableMetadataList": [ + { + "Name": "bar", + "TableType": "EXTERNAL_TABLE", + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + } + ], + } + # If we don't want to patch the API call + return orig(self, operation_name, kwarg) + + self.adapter.acquire_connection("dummy") + with patch("botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata): + actual = self.adapter._get_one_catalog( + mock_information_schema, + { + "foo": {"bar"}, + }, + self.mock_manifest, + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + "table_owner", + ) + expected_rows = [ + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"), + ] + + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + def test__get_catalog_schemas(self): res = self.adapter._get_catalog_schemas(self.mock_manifest) - assert len(res.keys()) == 2 + assert len(res.keys()) == 3 information_schema_0 = list(res.keys())[0] assert information_schema_0.name == "INFORMATION_SCHEMA" @@ -632,6 +756,14 @@ def test__get_catalog_schemas(self): assert set(relations.keys()) == {"foo"} assert list(relations.values()) == [{"bar"}] + 
information_schema_1 = list(res.keys())[2] + assert information_schema_1.name == "INFORMATION_SCHEMA" + assert information_schema_1.schema is None + assert information_schema_1.database == FEDERATED_QUERY_CATALOG_NAME + relations = list(res.values())[1] + assert set(relations.keys()) == {"foo"} + assert list(relations.values()) == [{"bar"}] + @mock_athena @mock_sts def test__get_data_catalog(self, mock_aws_service): @@ -696,7 +828,7 @@ def test_list_relations_without_caching_with_non_glue_data_catalog( self, parent_list_relations_without_caching, mock_aws_service ): data_catalog_name = "other_data_catalog" - mock_aws_service.create_data_catalog(data_catalog_name, "HIVE") + mock_aws_service.create_data_catalog(data_catalog_name, AthenaCatalogType.HIVE) schema_relation = self.adapter.Relation.create( database=data_catalog_name, schema=DATABASE_NAME, @@ -719,6 +851,7 @@ def test_parse_s3_path(self, s3_path, expected): @mock_athena @mock_glue @mock_s3 + @mock_sts def test_swap_table_with_partitions(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -745,6 +878,7 @@ def test_swap_table_with_partitions(self, mock_aws_service): @mock_athena @mock_glue @mock_s3 + @mock_sts def test_swap_table_without_partitions(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -769,6 +903,7 @@ def test_swap_table_without_partitions(self, mock_aws_service): @mock_athena @mock_glue @mock_s3 + @mock_sts def test_swap_table_with_partitions_to_one_without(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -806,6 +941,7 @@ def test_swap_table_with_partitions_to_one_without(self, mock_aws_service): @mock_athena @mock_glue @mock_s3 + @mock_sts def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -865,6 +1001,7 @@ def 
test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_ca @mock_athena @mock_glue @mock_s3 + @mock_sts def test_expire_glue_table_versions(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -976,6 +1113,7 @@ def test_get_work_group_output_location_not_enforced(self, mock_aws_service): @mock_athena @mock_glue @mock_s3 + @mock_sts def test_persist_docs_to_glue_no_comment(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1017,6 +1155,7 @@ def test_persist_docs_to_glue_no_comment(self, mock_aws_service): @mock_athena @mock_glue @mock_s3 + @mock_sts def test_persist_docs_to_glue_comment(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1069,6 +1208,7 @@ def test_list_schemas(self, mock_aws_service): @mock_athena @mock_glue + @mock_sts def test_get_columns_in_relation(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1089,6 +1229,7 @@ def test_get_columns_in_relation(self, mock_aws_service): @mock_athena @mock_glue + @mock_sts def test_get_columns_in_relation_not_found_table(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1104,6 +1245,7 @@ def test_get_columns_in_relation_not_found_table(self, mock_aws_service): @mock_athena @mock_glue + @mock_sts def test_delete_from_glue_catalog(self, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1117,6 +1259,7 @@ def test_delete_from_glue_catalog(self, mock_aws_service): @mock_athena @mock_glue + @mock_sts def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1133,6 +1276,7 @@ def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_a @mock_glue @mock_s3 @mock_athena + 
@mock_sts def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1147,6 +1291,7 @@ def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service): @mock_glue @mock_s3 @mock_athena + @mock_sts def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1161,6 +1306,7 @@ def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_servic @mock_glue @mock_s3 @mock_athena + @mock_sts def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() @@ -1175,6 +1321,7 @@ def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service): @mock_glue @mock_s3 @mock_athena + @mock_sts def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service): mock_aws_service.create_data_catalog() mock_aws_service.create_database() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index aebcc8fe..54775d10 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,6 +1,5 @@ +import importlib.metadata from unittest.mock import Mock - -import pkg_resources import pytest from dbt.adapters.athena.config import AthenaSparkSessionConfig, get_boto3_config @@ -8,7 +7,7 @@ class TestConfig: def test_get_boto3_config(self): - pkg_resources.get_distribution = Mock(return_value=pkg_resources.Distribution(version="2.4.6")) + importlib.metadata.version = Mock(return_value="2.4.6") get_boto3_config.cache_clear() config = get_boto3_config() assert config._user_provided_options["user_agent_extra"] == "dbt-athena-community/2.4.6" diff --git a/tests/unit/test_connection_manager.py b/tests/unit/test_connection_manager.py index a0ede751..c37a4792 100644 --- a/tests/unit/test_connection_manager.py +++ b/tests/unit/test_connection_manager.py @@ 
-4,7 +4,7 @@ from pyathena.model import AthenaQueryExecution from dbt.adapters.athena import AthenaConnectionManager -from dbt.contracts.connection import AdapterResponse +from dbt.adapters.athena.connections import AthenaAdapterResponse class TestAthenaConnectionManager: @@ -19,11 +19,13 @@ def test_get_response(self, state, result): cursor = mock.MagicMock() cursor.rowcount = 1 cursor.state = state + cursor.data_scanned_in_bytes = 123 cm = AthenaConnectionManager(mock.MagicMock()) response = cm.get_response(cursor) - assert isinstance(response, AdapterResponse) + assert isinstance(response, AthenaAdapterResponse) assert response.code == result assert response.rows_affected == 1 + assert response.data_scanned_in_bytes == 123 def test_data_type_code_to_name(self): cm = AthenaConnectionManager(mock.MagicMock()) diff --git a/tests/unit/test_lakeformation.py b/tests/unit/test_lakeformation.py index d5025e2b..ab061c09 100644 --- a/tests/unit/test_lakeformation.py +++ b/tests/unit/test_lakeformation.py @@ -11,7 +11,7 @@ # get_resource_lf_tags class TestLfTagsManager: @pytest.mark.parametrize( - "response,columns,lf_tags,verb,expected", + "response,identifier,columns,lf_tags,verb,expected", [ pytest.param( { @@ -22,6 +22,7 @@ class TestLfTagsManager: } ] }, + "tbl_name", ["column1", "column2"], {"tag_key": "tag_value"}, "add", @@ -31,32 +32,45 @@ class TestLfTagsManager: ), pytest.param( {"Failures": []}, + "tbl_name", None, {"tag_key": "tag_value"}, "add", - "Success: add LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name", id="add lf_tag", ), pytest.param( {"Failures": []}, None, + None, + {"tag_key": "tag_value"}, + "add", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena", + id="add lf_tag_to_database", + ), + pytest.param( + {"Failures": []}, + "tbl_name", + None, {"tag_key": "tag_value"}, "remove", - "Success: remove LF tags: {'tag_key': 'tag_value'} to 
test_dbt_athena.tbl_name", + "Success: remove LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name", id="remove lf_tag", ), pytest.param( {"Failures": []}, + "tbl_name", ["c1", "c2"], {"tag_key": "tag_value"}, "add", - "Success: add LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']", + "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']", id="lf_tag database table and columns", ), ], ) - def test__parse_lf_response(self, response, columns, lf_tags, verb, expected): - relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_name") + def test__parse_lf_response(self, dbt_debug_caplog, response, identifier, columns, lf_tags, verb, expected): + relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=identifier) lf_client = boto3.client("lakeformation", region_name=AWS_REGION) manager = LfTagsManager(lf_client, relation, LfTagsConfig()) - assert manager._parse_lf_response(response, columns, lf_tags, verb) == expected + manager._parse_and_log_lf_response(response, columns, lf_tags, verb) + assert expected in dbt_debug_caplog.getvalue() diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 60ec02ed..097f6ce9 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -5,6 +5,7 @@ import agate import boto3 +from dbt.adapters.athena.utils import AthenaCatalogType from dbt.config.project import PartialProject from .constants import AWS_REGION, BUCKET, CATALOG_ID, DATA_CATALOG_NAME, DATABASE_NAME @@ -146,10 +147,18 @@ def _make_table_of(self, rows, column_types): class MockAWSService: def create_data_catalog( - self, catalog_name: str = DATA_CATALOG_NAME, catalog_type: str = "GLUE", catalog_id: str = CATALOG_ID + self, + catalog_name: str = DATA_CATALOG_NAME, + catalog_type: AthenaCatalogType = AthenaCatalogType.GLUE, + catalog_id: str = CATALOG_ID, ): athena = boto3.client("athena", 
region_name=AWS_REGION) - athena.create_data_catalog(Name=catalog_name, Type=catalog_type, Parameters={"catalog-id": catalog_id}) + parameters = {} + if catalog_type == AthenaCatalogType.GLUE: + parameters = {"catalog-id": catalog_id} + else: + parameters = {"catalog": catalog_name} + athena.create_data_catalog(Name=catalog_name, Type=catalog_type.value, Parameters=parameters) def create_database(self, name: str = DATABASE_NAME, catalog_id: str = CATALOG_ID): glue = boto3.client("glue", region_name=AWS_REGION)