feat: Implement support bucket function for more than 100 partitions #549

Merged
Changes from 6 commits
53 changes: 51 additions & 2 deletions dbt/adapters/athena/impl.py
@@ -2,8 +2,10 @@
import os
import posixpath as path
import re
import struct
import tempfile
from dataclasses import dataclass
from datetime import date, datetime
from itertools import chain
from textwrap import dedent
from threading import Lock
@@ -12,6 +14,7 @@
from uuid import uuid4

import agate
import mmh3
from botocore.exceptions import ClientError
from mypy_boto3_athena.type_defs import DataCatalogTypeDef
from mypy_boto3_glue.type_defs import (
@@ -118,6 +121,7 @@ class AthenaConfig(AdapterConfig):
class AthenaAdapter(SQLAdapter):
BATCH_CREATE_PARTITION_API_LIMIT = 100
BATCH_DELETE_PARTITION_API_LIMIT = 25
INTEGER_MAX_VALUE_32_BIT_SIGNED = 0x7FFFFFFF

ConnectionManager = AthenaConnectionManager
Relation = AthenaRelation
@@ -1262,6 +1266,51 @@ def format_partition_keys(self, partition_keys: List[str]) -> str:

@available
def format_one_partition_key(self, partition_key: str) -> str:
"""Check if partition key uses Iceberg hidden partitioning"""
"""Check if partition key uses Iceberg hidden partitioning or bucket partitioning"""
hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower())
return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower()
bucket = re.search(r"bucket\((.+),", partition_key.lower())
if hidden:
return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})"
elif bucket:
return bucket.group(1)
else:
return partition_key.lower()

@available
def murmur3_hash(self, value: Any, num_buckets: int) -> int:
"""
Computes a hash for the given value using the MurmurHash3 algorithm and returns a bucket number.

This method was adapted from https://github.com/apache/iceberg-python/blob/main/pyiceberg/transforms.py#L240
"""
if isinstance(value, int): # int, long
hash_value = mmh3.hash(struct.pack("<q", value))
elif isinstance(value, (datetime, date)): # date, time, timestamp, timestampz
timestamp = int(value.timestamp()) if isinstance(value, datetime) else int(value.strftime("%s"))
hash_value = mmh3.hash(struct.pack("<q", timestamp))
elif isinstance(value, (str, bytes)): # string
hash_value = mmh3.hash(value)
else:
raise TypeError(f"Need to add support data type for hashing: {type(value)}")

return int((hash_value & self.INTEGER_MAX_VALUE_32_BIT_SIGNED) % num_buckets)

@available
def format_value_for_partition(self, value: Any, column_type: str) -> Tuple[str, str]:
"""Formats a value based on its column type for inclusion in a SQL query."""
comp_func = "=" # Default comparison function
if value is None:
return "null", " is "
elif column_type == "integer":
return str(value), comp_func
elif column_type == "string":
# Properly escape single quotes in the string value
escaped_value = str(value).replace("'", "''")
return f"'{escaped_value}'", comp_func
elif column_type == "date":
return f"DATE'{value}'", comp_func
elif column_type == "timestamp":
return f"TIMESTAMP'{value}'", comp_func
else:
# Raise an error for unsupported column types
raise ValueError(f"Unsupported column type: {column_type}")
@@ -1,9 +1,11 @@
{% macro get_partition_batches(sql, as_subquery=True) -%}
{# Retrieve partition configuration and set default partition limit #}
{%- set partitioned_by = config.get('partitioned_by') -%}
{%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%}
{%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%}
{% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %}

{# Retrieve distinct partitions from the given SQL #}
{% call statement('get_partitions', fetch_result=True) %}
{%- if as_subquery -%}
select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }};
@@ -12,48 +14,80 @@
{%- endif -%}
{% endcall %}

{# Initialize variables to store partition info #}
{%- set table = load_result('get_partitions').table -%}
{%- set rows = table.rows -%}
{%- set partitions = {} -%}
{% do log('TOTAL PARTITIONS TO PROCESS: ' ~ rows | length) %}
{%- set partitions_batches = [] -%}
{%- set ns = namespace(partitions = [], bucket_conditions = {}, bucket_numbers = [], bucket_column = None, is_bucketed = false) -%}

{# Process each partition row #}
{%- for row in rows -%}
{%- set single_partition = [] -%}
{%- for col in row -%}
{%- for col, partition_key in zip(row, partitioned_by) -%}
{# Determine the column type and check if it's a bucketed column #}
{%- set column_type = adapter.convert_type(table, loop.index0) -%}
{%- set bucket_match = modules.re.search('bucket\((.+),.+([0-9]+)\)', partition_key) -%}

{%- if bucket_match -%}
{# For bucketed columns, compute bucket numbers and conditions #}
{%- set ns.is_bucketed = true -%}
{%- set ns.bucket_column = bucket_match[1] -%}
{%- set bucket_num = adapter.murmur3_hash(col, bucket_match[2] | int) -%}
{%- set formatted_value, comp_func = adapter.format_value_for_partition(col, column_type) -%}

{%- if bucket_num not in ns.bucket_numbers %}
{%- do ns.bucket_numbers.append(bucket_num) %}
{%- do ns.bucket_conditions.update({bucket_num: [formatted_value]}) -%}
{%- elif formatted_value not in ns.bucket_conditions[bucket_num] %}
{%- do ns.bucket_conditions[bucket_num].append(formatted_value) -%}
{%- endif -%}

{%- set column_type = adapter.convert_type(table, loop.index0) -%}
{%- set comp_func = '=' -%}
{%- if col is none -%}
{%- set value = 'null' -%}
{%- set comp_func = ' is ' -%}
{%- elif column_type == 'integer' or column_type is none -%}
{%- set value = col | string -%}
{%- elif column_type == 'string' -%}
{%- set value = "'" + col + "'" -%}
{%- elif column_type == 'date' -%}
{%- set value = "DATE'" + col | string + "'" -%}
{%- elif column_type == 'timestamp' -%}
{%- set value = "TIMESTAMP'" + col | string + "'" -%}
{%- else -%}
{%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
{# For non-bucketed columns, format partition key and value #}
{%- set value, comp_func = adapter.format_value_for_partition(col, column_type) -%}
{%- set partition_key_formatted = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%}
{%- do single_partition.append(partition_key_formatted + comp_func + value) -%}
{%- endif -%}
{%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%}
{%- do single_partition.append(partition_key + comp_func + value) -%}
{%- endfor -%}

{# Concatenate conditions for a single partition #}
{%- set single_partition_expression = single_partition | join(' and ') -%}
{%- if single_partition_expression not in ns.partitions %}
{%- do ns.partitions.append(single_partition_expression) -%}
{%- endif -%}
{%- endfor -%}

{# Calculate total batches based on bucketing and partitioning #}
{%- if ns.is_bucketed -%}
{%- set total_batches = ns.partitions | length * ns.bucket_numbers | length -%}
{%- else -%}
{%- set total_batches = ns.partitions | length -%}
{%- endif -%}

{%- set batch_number = (loop.index0 / athena_partitions_limit) | int -%}
{% if not batch_number in partitions %}
{% do partitions.update({batch_number: []}) %}
{% endif %}
{# Determine the number of batches per partition limit #}
{%- set batches_per_partition_limit = (total_batches // athena_partitions_limit) + (total_batches % athena_partitions_limit > 0) -%}
{% do log('TOTAL PARTITIONS TO PROCESS: ' ~ total_batches) %}

{%- do partitions[batch_number].append('(' + single_partition_expression + ')') -%}
{%- if partitions[batch_number] | length == athena_partitions_limit or loop.last -%}
{%- do partitions_batches.append(partitions[batch_number] | join(' or ')) -%}
{# Create conditions for each batch #}
{%- set partitions_batches = [] -%}
{%- for i in range(batches_per_partition_limit) -%}
{%- set batch_conditions = [] -%}
{%- if ns.is_bucketed -%}
{# Combine partition and bucket conditions for each batch #}
{%- for partition_expression in ns.partitions -%}
{%- for bucket_num in ns.bucket_numbers -%}
{%- set bucket_condition = ns.bucket_column + " IN (" + ns.bucket_conditions[bucket_num] | join(", ") + ")" -%}
{%- set combined_condition = "(" + partition_expression + ' and ' + bucket_condition + ")" -%}
{%- do batch_conditions.append(combined_condition) -%}
{%- endfor -%}
{%- endfor -%}
{%- else -%}
{# Extend batch conditions with partitions for non-bucketed columns #}
{%- do batch_conditions.extend(ns.partitions) -%}
{%- endif -%}
{# Calculate batch start and end index and append batch conditions #}
{%- set start_index = i * athena_partitions_limit -%}
{%- set end_index = start_index + athena_partitions_limit -%}
{%- do partitions_batches.append(batch_conditions[start_index:end_index] | join(' or ')) -%}
{%- endfor -%}

{{ return(partitions_batches) }}
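
A rough illustration of the batching arithmetic the rewritten macro performs, in plain Python with made-up counts (none of these numbers come from the PR): when bucketing is used, every distinct partition expression is paired with every observed bucket condition, and the resulting list of conditions is sliced into OR-joined batches of at most athena_partitions_limit.

# Hypothetical counts, for illustration only
num_partition_expressions = 205   # len(ns.partitions)
num_bucket_numbers = 5            # len(ns.bucket_numbers)
athena_partitions_limit = 100     # config.get('partitions_limit', 100)

# With bucketing, every partition expression is combined with every bucket condition
total_batches = num_partition_expressions * num_bucket_numbers        # 1025 conditions

# Ceiling division, mirroring (total // limit) + (total % limit > 0) in the macro
batches_per_partition_limit = (total_batches // athena_partitions_limit) + (
    total_batches % athena_partitions_limit > 0
)                                                                      # 11 batches

for i in range(batches_per_partition_limit):
    start_index = i * athena_partitions_limit
    end_index = start_index + athena_partitions_limit
    # conditions[start_index:end_index] is OR-joined into one predicate and run as a
    # separate statement, keeping each statement under the 100-partition limit
    print(f"batch {i}: conditions {start_index}..{min(end_index, total_batches) - 1}")
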
1 change: 1 addition & 0 deletions setup.py
@@ -56,6 +56,7 @@ def _get_package_version() -> str:
"boto3~=1.26",
"boto3-stubs[athena,glue,lakeformation,sts]~=1.26",
"dbt-core~=1.7.0",
"mmh3~=4.0.1",
"pyathena>=2.25,<4.0",
"pydantic>=1.10,<3.0",
"tenacity~=8.2",
54 changes: 54 additions & 0 deletions tests/functional/adapter/test_partitions.py
@@ -78,6 +78,28 @@
NULL as date_column
"""

test_bucket_partitions_sql = """
with non_random_strings as (
select
chr(cast(65 + (row_number() over () % 26) as bigint)) ||
chr(cast(65 + ((row_number() over () + 1) % 26) as bigint)) ||
chr(cast(65 + ((row_number() over () + 4) % 26) as bigint)) as non_random_str
from
(select 1 union all select 2 union all select 3) as temp_table
)
select
cast(date_column as date) as date_column,
doy(date_column) as doy,
nrnd.non_random_str
from (
values (
sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-24'), interval '1' day)
)
) as t1(date_array)
cross join unnest(date_array) as t2(date_column)
join non_random_strings nrnd on true
"""


class TestHiveTablePartitions:
@pytest.fixture(scope="class")
@@ -264,3 +286,35 @@ def test__check_run_with_partitions(self, project):
records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert records_count_first_run == 202


class TestIcebergTablePartitionsBuckets:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+table_type": "iceberg",
"+materialized": "table",
"+partitioned_by": ["DAY(date_column)", "doy", "bucket(non_random_str, 5)"],
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"test_bucket_partitions.sql": test_bucket_partitions_sql,
}

def test__check_incremental_run_with_bucket_in_partitions(self, project):
relation_name = "test_bucket_partitions"
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"

first_model_run = run_dbt(["run", "--select", relation_name])
first_model_run_result = first_model_run.results[0]

# check that the model ran successfully
assert first_model_run_result.status == RunStatus.Success

records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert records_count_first_run == 615
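
The expected row count in the assertion above follows directly from the seed query: an inclusive date sequence from 2023-01-01 to 2023-07-24 cross-joined with the three generated strings. A quick sanity check of that arithmetic (plain Python, not part of the PR):

from datetime import date

days = (date(2023, 7, 24) - date(2023, 1, 1)).days + 1  # sequence() bounds are inclusive -> 205 days
assert days * 3 == 615  # 3 rows from non_random_strings -> 615 records
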
70 changes: 70 additions & 0 deletions tests/unit/test_adapter.py
@@ -1,3 +1,4 @@
import datetime
import decimal
from unittest import mock
from unittest.mock import patch
@@ -1442,6 +1443,75 @@ def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service):
def test__is_current_column(self, column, expected):
assert self.adapter._is_current_column(column) == expected

@pytest.mark.parametrize(
"partition_keys, expected_result",
[
(
["year(date_col)", "bucket(col_name, 10)", "default_partition_key"],
"date_trunc('year', date_col), col_name, default_partition_key",
),
],
)
def test_format_partition_keys(self, partition_keys, expected_result):
assert self.adapter.format_partition_keys(partition_keys) == expected_result

@pytest.mark.parametrize(
"partition_key, expected_result",
[
("month(hidden)", "date_trunc('month', hidden)"),
("bucket(bucket_col, 10)", "bucket_col"),
("regular_col", "regular_col"),
],
)
def test_format_one_partition_key(self, partition_key, expected_result):
assert self.adapter.format_one_partition_key(partition_key) == expected_result

def test_murmur3_hash_with_int(self):
bucket_number = self.adapter.murmur3_hash(123, 100)
assert isinstance(bucket_number, int)
assert 0 <= bucket_number < 100
assert bucket_number == 54

def test_murmur3_hash_with_datetime(self):
dt = datetime.datetime.now()
bucket_number = self.adapter.murmur3_hash(dt, 100)
assert isinstance(bucket_number, int)
assert 0 <= bucket_number < 100

def test_murmur3_hash_with_str(self):
bucket_number = self.adapter.murmur3_hash("test_string", 100)
assert isinstance(bucket_number, int)
assert 0 <= bucket_number < 100
assert bucket_number == 88

def test_murmur3_hash_uniqueness(self):
# Ensuring different inputs produce different hashes
hash1 = self.adapter.murmur3_hash("string1", 100)
hash2 = self.adapter.murmur3_hash("string2", 100)
assert hash1 != hash2

def test_murmur3_hash_with_unsupported_type(self):
with pytest.raises(TypeError):
self.adapter.murmur3_hash([1, 2, 3], 100)

@pytest.mark.parametrize(
"value, column_type, expected_result",
[
(None, "integer", ("null", " is ")),
(42, "integer", ("42", "=")),
("O'Reilly", "string", ("'O''Reilly'", "=")),
("test", "string", ("'test'", "=")),
("2021-01-01", "date", ("DATE'2021-01-01'", "=")),
("2021-01-01 12:00:00", "timestamp", ("TIMESTAMP'2021-01-01 12:00:00'", "=")),
],
)
def test_format_value_for_partition(self, value, column_type, expected_result):
assert self.adapter.format_value_for_partition(value, column_type) == expected_result

def test_format_unsupported_type(self):
with pytest.raises(ValueError):
self.adapter.format_value_for_partition("test", "unsupported_type")


class TestAthenaFilterCatalog:
def test__catalog_filter_table(self):