Skip to content

Commit

Permalink
get_column_schema_from_query macro (#6986)
Browse files Browse the repository at this point in the history
Add adapter.get_column_schema_from_query
  • Loading branch information
MichelleArk authored Mar 3, 2023
1 parent 72076b3 commit b681908
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 41 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230222-130632.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: get_column_schema_from_query_macro
time: 2023-02-22T13:06:32.583743-05:00
custom:
Author: jtcohen6 michelleark
Issue: "6751"
20 changes: 16 additions & 4 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Iterator,
Set,
)

import agate
import pytz

Expand All @@ -37,10 +38,7 @@
UnexpectedNonTimestampError,
)

from dbt.adapters.protocol import (
AdapterConfig,
ConnectionManagerProtocol,
)
from dbt.adapters.protocol import AdapterConfig, ConnectionManagerProtocol
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.manifest import Manifest, MacroManifest
Expand Down Expand Up @@ -176,6 +174,7 @@ class BaseAdapter(metaclass=AdapterMeta):
- truncate_relation
- rename_relation
- get_columns_in_relation
- get_column_schema_from_query
- expand_column_types
- list_relations_without_caching
- is_cancelable
Expand Down Expand Up @@ -268,6 +267,19 @@ def execute(
"""
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)

@available.parse(lambda *a, **k: [])
def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
    """Run *sql* as a select and describe the columns of its result set.

    Executes the query through the connection manager without fetching
    rows, then translates each entry of the DB-API cursor description
    into an adapter Column carrying the column name and a human-readable
    data type name.
    """
    _, cursor = self.connections.add_select_query(sql)
    schema: List[BaseColumn] = []
    # Each cursor.description entry is (name, type_code, ...) per
    # https://peps.python.org/pep-0249/#description
    for description in cursor.description:
        name, type_code = description[0], description[1]
        type_name = self.connections.data_type_code_to_name(type_code)
        schema.append(self.Column.create(name, type_name))
    return schema

@available.parse(lambda *a, **k: ("", empty_table()))
def get_partitions_metadata(self, table: str) -> Tuple[agate.Table]:
"""Obtain partitions metadata for a BigQuery partitioned table.
Expand Down
15 changes: 14 additions & 1 deletion core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import time
from typing import List, Optional, Tuple, Any, Iterable, Dict
from typing import List, Optional, Tuple, Any, Iterable, Dict, Union

import agate

Expand Down Expand Up @@ -52,6 +52,7 @@ def add_query(
bindings: Optional[Any] = None,
abridge_sql_log: bool = False,
) -> Tuple[Connection, Any]:

connection = self.get_thread_connection()
if auto_begin and connection.transaction_open is False:
self.begin()
Expand Down Expand Up @@ -128,6 +129,14 @@ def get_result_from_cursor(cls, cursor: Any) -> agate.Table:

return dbt.clients.agate_helper.table_from_data_flat(data, column_names)

@classmethod
def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
    """Translate a DB-API cursor ``type_code`` into a data type name.

    The base implementation only raises: each adapter must override this
    to decode its driver's type codes.
    See https://peps.python.org/pep-0249/#type-objects
    """
    message = "`data_type_code_to_name` is not implemented for this adapter!"
    raise dbt.exceptions.NotImplementedError(message)

def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[AdapterResponse, agate.Table]:
Expand All @@ -146,6 +155,10 @@ def add_begin_query(self):
def add_commit_query(self):
return self.add_query("COMMIT", auto_begin=False)

def add_select_query(self, sql: str) -> Tuple[Connection, Any]:
    """Attach the configured query comment to *sql* and run it without opening a transaction."""
    commented_sql = self._add_query_comment(sql)
    return self.add_query(commented_sql, auto_begin=False)

def begin(self):
connection = self.get_thread_connection()
if connection.transaction_open is True:
Expand Down
Binary file modified core/dbt/docs/build/doctrees/environment.pickle
Binary file not shown.
46 changes: 39 additions & 7 deletions core/dbt/include/global_project/macros/adapters/columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,55 @@
{% endmacro %}


{#- Dispatch wrapper: adapters may override get_empty_subquery_sql; otherwise default__get_empty_subquery_sql is used. -#}
{% macro get_empty_subquery_sql(select_sql) -%}
{{ return(adapter.dispatch('get_empty_subquery_sql', 'dbt')(select_sql)) }}
{% endmacro %}

{#
Builds a query that results in the same schema as the given select_sql statement, without necessitating a data scan.
Useful for running a query in a 'pre-flight' context, such as model contract enforcement (assert_columns_equivalent macro).
#}
{#- Default implementation: wrap select_sql so it returns zero rows ("where false limit 0")
    while still exposing the same column names/types to the cursor. -#}
{% macro default__get_empty_subquery_sql(select_sql) %}
select * from (
{{ select_sql }}
) as __dbt_sbq
where false
limit 0
{% endmacro %}


{#- Dispatch wrapper for get_empty_schema_sql; `columns` is the model's columns mapping from the schema file. -#}
{% macro get_empty_schema_sql(columns) -%}
{{ return(adapter.dispatch('get_empty_schema_sql', 'dbt')(columns)) }}
{% endmacro %}

{#- Render "select cast(null as <data_type>) as <name>, ..." from the columns mapping, so the
    resulting cursor description mirrors the schema-file definitions without scanning data. -#}
{% macro default__get_empty_schema_sql(columns) %}
select
{% for i in columns %}
{%- set col = columns[i] -%}
cast(null as {{ col['data_type'] }}) as {{ col['name'] }}{{ ", " if not loop.last }}
{%- endfor -%}
{% endmacro %}

{#- Return a list of Column objects (name + data type) describing select_sql's result schema.
    Fix: dropped the `{% set columns = [] %}` line, which was never read. -#}
{% macro get_column_schema_from_query(select_sql) -%}
{# -- Using an 'empty subquery' here to get the same schema as the given select_sql statement, without necessitating a data scan.#}
{% set sql = get_empty_subquery_sql(select_sql) %}
{% set column_schema = adapter.get_column_schema_from_query(sql) %}
{{ return(column_schema) }}
{% endmacro %}

-- here for back compat
{#- Back-compat dispatch wrapper: returns only column NAMES; newer code should prefer get_column_schema_from_query. -#}
{% macro get_columns_in_query(select_sql) -%}
{{ return(adapter.dispatch('get_columns_in_query', 'dbt')(select_sql)) }}
{% endmacro %}

{#- Default: execute the zero-row subquery and return the column names from the result table.
    Fix: this span interleaved the pre-change inline "select * from (...) where false limit 0"
    with the new get_empty_subquery_sql call (a garbled diff); only the new call is kept. -#}
{% macro default__get_columns_in_query(select_sql) %}
{% call statement('get_columns_in_query', fetch_result=True, auto_begin=False) -%}
{{ get_empty_subquery_sql(select_sql) }}
{% endcall %}
{{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }}
{% endmacro %}


{#- Dispatch wrapper for alter_column_type; the adapter-specific implementation is not shown in this chunk. -#}
{% macro alter_column_type(relation, column_name, new_column_type) -%}
{{ return(adapter.dispatch('alter_column_type', 'dbt')(relation, column_name, new_column_type)) }}
{% endmacro %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,34 @@
{{ return(assert_columns_equivalent(sql)) }}
{%- endmacro %}

{#
Compares the column schema provided by a model's sql file to the column schema provided by a model's schema file.
If any differences in name, data_type or order of columns exist between the two schemas, raises a compiler error
#}
{#- Compare the column schema of the model's SQL against the schema-file columns and raise a
    compiler error on any name/data_type/order mismatch.
    Fix: this span merged the pre-change name-only comparison (user_provided_columns loop,
    uppercase/join/split formatting, and a second conflicting raise) with the new schema-based
    comparison (another garbled diff); only the post-change logic is kept. -#}
{% macro assert_columns_equivalent(sql) %}
{#-- Obtain the column schema provided by sql file. #}
{%- set sql_file_provided_columns = get_column_schema_from_query(sql) -%}
{#--Obtain the column schema provided by the schema file by generating an 'empty schema' query from the model's columns. #}
{%- set schema_file_provided_columns = get_column_schema_from_query(get_empty_schema_sql(model['columns'])) -%}

{%- set sql_file_provided_columns_formatted = format_columns(sql_file_provided_columns) -%}
{%- set schema_file_provided_columns_formatted = format_columns(schema_file_provided_columns) -%}

{%- if sql_file_provided_columns_formatted != schema_file_provided_columns_formatted -%}
{%- do exceptions.raise_compiler_error('Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file.\nSchema File Columns: ' ~ (schema_file_provided_columns_formatted|trim) ~ '\n\nSQL File Columns: ' ~ (sql_file_provided_columns_formatted|trim) ~ ' ' ) %}
{%- endif -%}

{% endmacro %}

{#- Join each Column through the adapter-dispatched format_column into one comma-separated
    string (with the default implementation, e.g. "id integer, name text"). -#}
{% macro format_columns(columns) %}
{% set formatted_columns = [] %}
{% for column in columns %}
{%- set formatted_column = adapter.dispatch('format_column', 'dbt')(column) -%}
{%- do formatted_columns.append(formatted_column) -%}
{% endfor %}
{{ return(formatted_columns|join(', ')) }}
{%- endmacro -%}

{#- Default rendering of a single Column: lowercased name, a space, then its dtype. -#}
{% macro default__format_column(column) -%}
{{ return(column.column.lower() ~ " " ~ column.dtype) }}
{%- endmacro -%}
5 changes: 5 additions & 0 deletions plugins/postgres/dbt/adapters/postgres/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import contextmanager

import psycopg2
from psycopg2.extensions import string_types

import dbt.exceptions
from dbt.adapters.base import Credentials
Expand Down Expand Up @@ -190,3 +191,7 @@ def get_response(cls, cursor) -> AdapterResponse:
status_messsage_strings = [part for part in status_message_parts if not part.isdigit()]
code = " ".join(status_messsage_strings)
return AdapterResponse(_message=message, code=code, rows_affected=rows)

@classmethod
def data_type_code_to_name(cls, type_code: int) -> str:
    """Map a psycopg2 cursor type_code to its type name via psycopg2's string_types registry."""
    type_object = string_types[type_code]
    return type_object.name
26 changes: 24 additions & 2 deletions tests/adapter/dbt/tests/adapter/constraints/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
}}
select
1 as color,
'blue' as id,
'blue' as color,
1 as id,
cast('2019-01-01' as date) as date_day
"""

Expand All @@ -37,6 +37,17 @@
cast('2019-01-01' as date) as date_day
"""

# Model SQL template with a placeholder value; `{sql_value}` is substituted via
# str.format, and the quadrupled braces render as literal `{{ ... }}` so the
# config() call survives formatting and is later evaluated by Jinja.
my_model_data_type_sql = """
{{{{
config(
materialized = "table"
)
}}}}
select
{sql_value} as wrong_data_type_column_name
"""

my_model_with_nulls_sql = """
{{
config(
Expand Down Expand Up @@ -117,3 +128,14 @@
- name: date_day
data_type: date
"""

# Schema yml template for the data_type contract tests; `{data_type}` is
# substituted via str.format with the adapter-specific type under test.
model_data_type_schema_yml = """
version: 2
models:
- name: my_model_data_type
config:
contract: true
columns:
- name: wrong_data_type_column_name
data_type: {data_type}
"""
Loading

0 comments on commit b681908

Please sign in to comment.