Skip to content

Commit

Permalink
Merge branch 'main' into support-python-submissions
Browse files Browse the repository at this point in the history
  • Loading branch information
Personal authored and Personal committed Sep 26, 2023
2 parents d044c35 + 23cfe4b commit cad13db
Show file tree
Hide file tree
Showing 43 changed files with 1,533 additions and 226 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko @thenaturalist
* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,6 @@ cython_debug/

# Project specific
test.py

# OS
.DS_Store
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
<p align="center">
<img src="https://raw.githubusercontent.com/dbt-athena/dbt-athena/main/static/images/dbt-athena-long.png" />
<a href="https://pypi.org/project/dbt-athena-community/"><img src="https://badge.fury.io/py/dbt-athena-community.svg" /></a>
<a target="_blank" href="https://pypi.org/project/dlt/" style="background:none">
<img src="https://img.shields.io/pypi/pyversions/dbt-athena-community">
</a>
<a href="https://pycqa.github.io/isort/"><img src="https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336" /></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a>
<a href="https://github.com/python/mypy"><img src="https://www.mypy-lang.org/static/mypy_badge.svg" /></a>
<a href="https://pepy.tech/project/dbt-athena-community"><img src="https://pepy.tech/badge/dbt-athena-community/month" /></a>
<a href="https://pepy.tech/project/dbt-athena-community"><img src="https://static.pepy.tech/badge/dbt-athena-community/month" /></a>
</p>

## Features

* Supports dbt version `1.5.*`
* Supports dbt version `1.6.*`
* Supports Python models
* Supports [seeds][seeds]
* Correctly detects views and their columns
* Supports [table materialization][table]
Expand Down Expand Up @@ -79,13 +83,16 @@ A dbt profile can be configured to run against AWS Athena using the following co
| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` |
| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` |
| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` |
| debug_query_state | Flag if debug message with Athena query state is needed | Optional | `false` |
| aws_access_key_id | Access key ID of the user performing requests. | Optional | `AKIAIOSFODNN7EXAMPLE` |
| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` |
| aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| num_retries | Number of times to retry a failing query | Optional | `3` |
| spark_work_group | Identifier of Athena Spark workgroup | Optional | `my-spark-workgroup` |
| spark_threads | Number of spark sessions to create. Recommended to be same as threads. | Optional | `4` |
| seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` |
| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` |

**Example profiles.yml entry:**
```yaml
Expand All @@ -105,6 +112,8 @@ athena:
work_group: my-workgroup
spark_work_group: my-spark-workgroup
spark_threads: 4
seed_s3_upload_args:
ACL: bucket-owner-full-control
```
_Additional information_
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/athena/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.5.0"
version = "1.6.1"
6 changes: 2 additions & 4 deletions dbt/adapters/athena/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib.metadata
from functools import lru_cache
from typing import Any, Dict

import pkg_resources
from botocore import config

from dbt.adapters.athena.constants import (
Expand All @@ -14,9 +14,7 @@

@lru_cache()
def get_boto3_config() -> config.Config:
return config.Config(
user_agent_extra="dbt-athena-community/" + pkg_resources.get_distribution("dbt-athena-community").version
)
return config.Config(user_agent_extra="dbt-athena-community/" + importlib.metadata.version("dbt-athena-community"))


class AthenaSparkSessionConfig:
Expand Down
89 changes: 83 additions & 6 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import hashlib
import json
import re
import time
from concurrent.futures.thread import ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
Expand Down Expand Up @@ -33,6 +36,13 @@
from dbt.contracts.connection import AdapterResponse, Connection, ConnectionState
from dbt.exceptions import ConnectionError, DbtRuntimeError

logger = AdapterLogger("Athena")


@dataclass
class AthenaAdapterResponse(AdapterResponse):
    # Extends dbt's AdapterResponse with Athena-specific query metadata.
    # data_scanned_in_bytes: number of bytes Athena reports as scanned by
    # the query; None when the statistic was not reported.
    data_scanned_in_bytes: Optional[int] = None


@dataclass
class AthenaCredentials(Credentials):
Expand All @@ -44,13 +54,17 @@ class AthenaCredentials(Credentials):
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
poll_interval: float = 1.0
debug_query_state: bool = False
_ALIASES = {"catalog": "database"}
num_retries: Optional[int] = 5
s3_data_dir: Optional[str] = None
s3_data_naming: Optional[str] = "schema_table_unique"
spark_work_group: Optional[str] = None
spark_threads: Optional[int] = DEFAULT_THREAD_COUNT
lf_tags: Optional[Dict[str, str]] = None
seed_s3_upload_args: Optional[Dict[str, Any]] = None
# Unfortunately we cannot just use dict; it must be Dict, because otherwise we get the following error:
# Credentials in profile "athena", target "athena" invalid: Unable to create schema for 'dict'
lf_tags_database: Optional[Dict[str, str]] = None

@property
def type(self) -> str:
Expand All @@ -74,7 +88,9 @@ def _connection_keys(self) -> Tuple[str, ...]:
"endpoint_url",
"s3_data_dir",
"s3_data_naming",
"lf_tags",
"debug_query_state",
"seed_s3_upload_args",
"lf_tags_database",
"spark_work_group",
"spark_threads",
)
Expand All @@ -95,6 +111,32 @@ def _collect_result_set(self, query_id: str) -> AthenaResultSet:
retry_config=self._retry_config,
)

def _poll(self, query_id: str) -> AthenaQueryExecution:
    """Wait for the query to finish, optionally cancelling it on Ctrl-C.

    When ``_kill_on_interrupt`` is set, a ``KeyboardInterrupt`` cancels the
    running Athena query and then polls once more to fetch its final state;
    otherwise the interrupt propagates to the caller unchanged.
    """
    try:
        return self.__poll(query_id)
    except KeyboardInterrupt:
        if not self._kill_on_interrupt:
            raise
        logger.warning("Query canceled by user.")
        self._cancel(query_id)
        return self.__poll(query_id)

def __poll(self, query_id: str) -> AthenaQueryExecution:
    """Block until the Athena query reaches a terminal state.

    Checks the execution status once per ``_poll_interval`` seconds and,
    when the ``debug_query_state`` cursor flag is set, logs each
    intermediate state observed while waiting.
    """
    terminal_states = (
        AthenaQueryExecution.STATE_SUCCEEDED,
        AthenaQueryExecution.STATE_FAILED,
        AthenaQueryExecution.STATE_CANCELLED,
    )
    while True:
        query_execution = self._get_query_execution(query_id)
        if query_execution.state in terminal_states:
            return query_execution
        if self.connection.cursor_kwargs.get("debug_query_state", False):
            logger.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...")
        time.sleep(self._poll_interval)

def execute( # type: ignore
self,
operation: str,
Expand All @@ -104,6 +146,7 @@ def execute( # type: ignore
endpoint_url: Optional[str] = None,
cache_size: int = 0,
cache_expiration_time: int = 0,
catch_partitions_limit: bool = False,
**kwargs,
):
def inner() -> AthenaCursor:
Expand All @@ -130,7 +173,12 @@ def inner() -> AthenaCursor:
return self

retry = tenacity.Retrying(
retry=retry_if_exception(lambda _: True),
# No need to retry if TOO_MANY_OPEN_PARTITIONS occurs.
# Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry,
# because not all files are removed immediately after first try to create table
retry=retry_if_exception(
lambda e: False if catch_partitions_limit and "TOO_MANY_OPEN_PARTITIONS" in str(e) else True
),
stop=stop_after_attempt(self._retry_config.attempt),
wait=wait_exponential(
multiplier=self._retry_config.attempt,
Expand Down Expand Up @@ -175,9 +223,11 @@ def open(cls, connection: Connection) -> Connection:
handle = AthenaConnection(
s3_staging_dir=creds.s3_staging_dir,
endpoint_url=creds.endpoint_url,
catalog_name=creds.database,
schema_name=creds.schema,
work_group=creds.work_group,
cursor_class=AthenaCursor,
cursor_kwargs={"debug_query_state": creds.debug_query_state},
formatter=AthenaParameterFormatter(),
poll_interval=creds.poll_interval,
session=get_boto3_session(connection),
Expand All @@ -200,12 +250,39 @@ def open(cls, connection: Connection) -> Connection:
return connection

@classmethod
def get_response(cls, cursor: AthenaCursor) -> AdapterResponse:
def get_response(cls, cursor: AthenaCursor) -> AthenaAdapterResponse:
code = "OK" if cursor.state == AthenaQueryExecution.STATE_SUCCEEDED else "ERROR"
return AdapterResponse(_message=f"{code} {cursor.rowcount}", rows_affected=cursor.rowcount, code=code)
rowcount, data_scanned_in_bytes = cls.process_query_stats(cursor)
return AthenaAdapterResponse(
_message=f"{code} {rowcount}",
rows_affected=rowcount,
code=code,
data_scanned_in_bytes=data_scanned_in_bytes,
)

@staticmethod
def process_query_stats(cursor: AthenaCursor) -> Tuple[int, int]:
"""
Helper function to parse query statistics from SELECT statements.
The function looks for all statements that contains rowcount or data_scanned_in_bytes,
then strip the SELECT statements, and pick the value between curly brackets.
"""
if all(map(cursor.query.__contains__, ["rowcount", "data_scanned_in_bytes"])):
try:
query_split = cursor.query.lower().split("select")[-1]
# query statistics are in the format {"rowcount":1, "data_scanned_in_bytes": 3}
# the following statement extract the content between { and }
query_stats = re.search("{(.*)}", query_split)
if query_stats:
stats = json.loads("{" + query_stats.group(1) + "}")
return stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0)
except Exception as err:
logger.debug(f"There was an error parsing query stats {err}")
return -1, 0
return cursor.rowcount, cursor.data_scanned_in_bytes

def cancel(self, connection: Connection) -> None:
connection.handle.cancel()
pass

def add_begin_query(self) -> None:
pass
Expand Down
Loading

0 comments on commit cad13db

Please sign in to comment.