Skip to content

Commit

Permalink
Use column_types config for seeds (#46)
Browse files Browse the repository at this point in the history
* Use adapter to convert agate types for seeds

* Print schema if node success when failure expected

* Load `column_types` when dry running seeds
  • Loading branch information
ccharlesgb authored Nov 15, 2023
1 parent 0b771b9 commit f6eca2b
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 47 deletions.
6 changes: 5 additions & 1 deletion dbt_dry_run/adapter/service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import os
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional

from dbt.adapters.base import BaseAdapter
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
from dbt.config import RuntimeConfig
from dbt.contracts.connection import Connection
Expand Down Expand Up @@ -62,3 +62,7 @@ def get_dbt_manifest(self) -> Manifest:
@property
def threads(self) -> int:
return self._profile.threads

@property
def adapter(self) -> BaseAdapter:
    """The dbt ``BaseAdapter`` registered for this project.

    Exposed so collaborators (e.g. the SQL runner's ``convert_agate_type``)
    can delegate warehouse-specific work to the adapter.
    """
    return self._adapter
20 changes: 17 additions & 3 deletions dbt_dry_run/columns_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
STRUCT_SEPERATOR_LENGTH = len(STRUCT_SEPERATOR)


def _extract_fields(table_fields: List[TableField], prefix: str = "") -> List[str]:
def _extract_fields(
table_fields: List[TableField], prefix: str = ""
) -> List[Tuple[str, BigQueryFieldType]]:
field_names = []
for field in table_fields:
field_names.append(f"{prefix}{field.name}")
field_names.append((f"{prefix}{field.name}", field.type_))
if field.fields:
new_prefix = f"{prefix}{field.name}."
field_names.extend(_extract_fields(field.fields, prefix=new_prefix))
Expand All @@ -27,7 +29,19 @@ def expand_table_fields(table: Table) -> Set[str]:
Eg: TableField(name="a", fields=[TableField(name="a1")])
Returns: ["a", "a.a1"]
"""
return set(_extract_fields(table.fields))
name_type_pairs = _extract_fields(table.fields)
return set(name for name, _ in name_type_pairs)


def expand_table_fields_with_types(table: Table) -> Dict[str, BigQueryFieldType]:
    """Expand table fields to dot notation mapped to their BigQuery type.

    Nested fields are flattened using ``.`` as the separator, mirroring dbt
    metadata naming.

    Eg: TableField(name="a", fields=[TableField(name="a1")])
    Returns: {"a": <type of a>, "a.a1": <type of a1>}
    """
    # _extract_fields yields (dotted_name, type) pairs, so the mapping is
    # just a dict() over them.
    return dict(_extract_fields(table.fields))


def _column_is_repeated(data_type: str) -> bool:
Expand Down
1 change: 1 addition & 0 deletions dbt_dry_run/models/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class NodeConfig(BaseModel):
partition_by: Optional[PartitionBy]
meta: Optional[NodeMeta]
full_refresh: Optional[bool]
column_types: Dict[str, str] = Field(default_factory=dict)


class ManifestColumn(BaseModel):
Expand Down
31 changes: 18 additions & 13 deletions dbt_dry_run/node_runner/seed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,15 @@
from typing import List, Optional

import agate as ag
from agate import data_types

from dbt_dry_run.exception import UnknownSchemaException
from dbt_dry_run.models import BigQueryFieldType, Table, TableField
from dbt_dry_run.models.manifest import Node
from dbt_dry_run.node_runner import NodeRunner
from dbt_dry_run.results import DryRunResult, DryRunStatus


class SeedRunner(NodeRunner):
DEFAULT_TYPE = BigQueryFieldType.STRING
TYPE_MAP = {
data_types.Text: BigQueryFieldType.STRING,
data_types.Number: BigQueryFieldType.FLOAT64,
data_types.Boolean: BigQueryFieldType.BOOLEAN,
data_types.Date: BigQueryFieldType.DATE,
data_types.DateTime: BigQueryFieldType.DATETIME,
}

def run(self, node: Node) -> DryRunResult:
if not node.root_path:
raise ValueError(f"Node {node.unique_id} does not have `root_path`")
Expand All @@ -28,9 +19,23 @@ def run(self, node: Node) -> DryRunResult:
csv_table = ag.Table.from_csv(f)

fields: List[TableField] = []
for column in csv_table.columns:
type_ = self.TYPE_MAP.get(column.data_type.__class__, self.DEFAULT_TYPE)
new_field = TableField(name=column.name, type=type_)
for idx, column in enumerate(csv_table.columns):
override_type = node.config.column_types.get(column.name)
new_type = override_type or self._sql_runner.convert_agate_type(
csv_table, idx
)
if new_type is None:
msg = f"Unknown Big Query schema for seed '{node.unique_id}' Column '{column.name}'"
exception = UnknownSchemaException(msg)
return DryRunResult(
node=node,
table=None,
status=DryRunStatus.FAILURE,
exception=exception,
)
new_field = TableField(
name=column.name, type=BigQueryFieldType[new_type.upper()]
)
fields.append(new_field)

schema = Table(fields=fields)
Expand Down
11 changes: 11 additions & 0 deletions dbt_dry_run/sql_runner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from abc import ABCMeta, abstractmethod
from typing import Optional, Tuple

import agate

from dbt_dry_run.adapter.service import ProjectService
from dbt_dry_run.models import Table
from dbt_dry_run.models.manifest import Node
from dbt_dry_run.results import DryRunStatus
Expand All @@ -11,6 +14,9 @@ class SQLRunner(metaclass=ABCMeta):
Used to adapt to multiple warehouse backends
"""

def __init__(self, project: ProjectService):
    # Keep a handle on the project so runners can reach the dbt adapter
    # (used by `convert_agate_type` below).
    self._project = project

@abstractmethod
def node_exists(self, node: Node) -> bool:
...
Expand All @@ -24,3 +30,8 @@ def query(
self, sql: str
) -> Tuple[DryRunStatus, Optional[Table], Optional[Exception]]:
...

def convert_agate_type(
    self, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
    """Return the warehouse type name for column ``col_idx`` of ``agate_table``.

    Delegates to the project's dbt adapter; returns ``None`` when the
    adapter cannot determine a type for the column.
    """
    return self._project.adapter.convert_agate_type(agate_table, col_idx)
2 changes: 0 additions & 2 deletions dbt_dry_run/sql_runner/big_query_sql_runner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from contextlib import contextmanager
from typing import List, Optional, Tuple

from google.cloud.bigquery import (
Client,
DatasetReference,
QueryJob,
QueryJobConfig,
SchemaField,
TableReference,
Expand Down
52 changes: 39 additions & 13 deletions dbt_dry_run/test/node_runner/test_seed_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pathlib import Path
from typing import Set
from typing import Optional, Set
from unittest.mock import MagicMock

from dbt_dry_run import flags
from dbt_dry_run.exception import NotCompiledException
from dbt_dry_run.exception import NotCompiledException, UnknownSchemaException
from dbt_dry_run.flags import Flags
from dbt_dry_run.models import BigQueryFieldType
from dbt_dry_run.models.manifest import Node
Expand All @@ -13,17 +13,22 @@
from dbt_dry_run.test.utils import SimpleNode


def get_result(node: Node, column_type: Optional[str] = "string") -> DryRunResult:
    """Dry-run ``node`` through a SeedRunner backed by a mocked SQL runner.

    The mock's ``convert_agate_type`` always returns ``column_type``, so
    tests can control (or, with ``None``, suppress) the inferred column type.
    """
    mock_sql_runner = MagicMock()
    mock_sql_runner.convert_agate_type.return_value = column_type
    seed_runner = SeedRunner(mock_sql_runner, MagicMock())
    return seed_runner.run(node)


def assert_success_and_columns_equal(
node: Node, expected_columns: Set[str]
node: Node, expected_columns: Set[str], column_type: Optional[str] = "string"
) -> DryRunResult:
seed_runner = SeedRunner(MagicMock(), MagicMock())
result: DryRunResult = seed_runner.run(node)
result = get_result(node, column_type)
assert result.status == DryRunStatus.SUCCESS
assert result.table
fields = result.table.fields
field_names: Set[str] = set(f.name for f in fields)
assert field_names == expected_columns

return result


Expand All @@ -45,11 +50,11 @@ def test_seed_runner_loads_file(tmp_path: Path) -> None:
assert_success_and_columns_equal(node, expected_columns)


def test_seed_runner_infers_dates(tmp_path: Path) -> None:
def test_seed_runner_fails_if_type_returns_none(tmp_path: Path) -> None:
p = tmp_path / "seed1.csv"
csv_content = """a,b,c
foo,bar,2021-01-01
foo2,bar2,2021-01-01
foo,bar,baz
foo2,bar2,baz2
"""
p.write_text(csv_content)

Expand All @@ -59,11 +64,32 @@ def test_seed_runner_infers_dates(tmp_path: Path) -> None:
resource_type=ManifestScheduler.SEED,
original_file_path=p.as_posix(),
).to_node()
expected_columns = set(csv_content.splitlines()[0].split(","))
result = assert_success_and_columns_equal(node, expected_columns)
result = get_result(node, None)
assert result.status == DryRunStatus.FAILURE
assert type(result.exception) == UnknownSchemaException

assert result.table
assert result.table.fields[2].type_ == BigQueryFieldType.DATE

def test_seed_runner_uses_column_overrides(tmp_path: Path) -> None:
    """Columns listed in ``config.column_types`` override the inferred type."""
    p = tmp_path / "seed1.csv"
    csv_content = """a,b,c
foo,bar,baz
foo2,bar2,baz2
"""
    p.write_text(csv_content)

    node = SimpleNode(
        unique_id="node1",
        depends_on=[],
        resource_type=ManifestScheduler.SEED,
        original_file_path=p.as_posix(),
    ).to_node()
    # Only column `a` is overridden; `b` and `c` should fall back to the
    # mocked adapter-inferred type ("STRING").
    node.config.column_types = {"a": "NUMERIC"}
    result = get_result(node, "STRING")

    expected_fields = {"a": "NUMERIC", "b": "STRING", "c": "STRING"}
    assert result.table, "Expected result to have a table"
    mapped_fields = {field.name: field.type_ for field in result.table.fields}
    assert mapped_fields == expected_fields


def test_validate_node_returns_none_if_node_is_not_compiled() -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
seed_number,seed_string
1,hello
2,world
13 changes: 13 additions & 0 deletions integration/projects/test_models_are_executed/seeds/meta.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
seeds:
- name: my_seed
config:
column_types:
seed_a: STRING
seed_b: FLOAT64
seed_c: BIGNUMERIC

- name: badly_configured_seed
config:
column_types:
seed_number: NUMERIC
seed_string: NUMERIC
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
a,seed_b
hello,1
world,2
a,seed_b,seed_c
hello,1,1.2
world,2,2.4
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import pytest

from dbt_dry_run.columns_metadata import expand_table_fields_with_types
from dbt_dry_run.models import BigQueryFieldType
from dbt_dry_run.results import DryRunStatus
from integration.conftest import DryRunResult
from integration.utils import (
assert_report_success,
get_report_node_by_id,
assert_report_node_has_columns,
assert_report_produced,
)


def test_success(dry_run_result: DryRunResult):
    # The integration project should dry run with every node succeeding.
    assert_report_success(dry_run_result)


def test_ran_correct_number_of_nodes(dry_run_result: DryRunResult):
report = assert_report_success(dry_run_result)
assert report.node_count == 4
report = assert_report_produced(dry_run_result)
assert report.node_count == 5


def test_table_of_nodes_is_returned(dry_run_result: DryRunResult):
report = assert_report_success(dry_run_result)
report = assert_report_produced(dry_run_result)
seed_node = get_report_node_by_id(report, "seed.test_models_are_executed.my_seed")
assert_report_node_has_columns(seed_node, {"a", "seed_b"})
columns = expand_table_fields_with_types(seed_node.table)
assert columns == {
"a": BigQueryFieldType.STRING,
"seed_b": BigQueryFieldType.FLOAT64,
"seed_c": BigQueryFieldType.BIGNUMERIC,
}

first_layer = get_report_node_by_id(
report, "model.test_models_are_executed.first_layer"
Expand All @@ -28,16 +34,27 @@ def test_table_of_nodes_is_returned(dry_run_result: DryRunResult):
second_layer = get_report_node_by_id(
report, "model.test_models_are_executed.second_layer"
)
assert_report_node_has_columns(second_layer, {"a", "b", "c", "seed_b"})
assert_report_node_has_columns(second_layer, {"a", "b", "c", "seed_b", "seed_c"})


def test_disabled_model_not_run(dry_run_result: DryRunResult):
report = assert_report_success(dry_run_result)
report = assert_report_produced(dry_run_result)
assert "model.test_models_are_executed.disabled_model" not in set(
n.unique_id for n in report.nodes
), "Found disabled model in dry run output"


@pytest.mark.xfail(
    reason="Seed type compatibility not checked. (Trying to convert string to number)"
)
def test_badly_configured_seed_fails(dry_run_result: DryRunResult):
    """A seed whose ``column_types`` conflict with its data should fail.

    Marked xfail: the dry runner does not yet validate that the configured
    type can actually represent the seed's values.
    """
    report = assert_report_produced(dry_run_result)
    seed_node = get_report_node_by_id(
        report, "seed.test_models_are_executed.badly_configured_seed"
    )
    assert seed_node.status == DryRunStatus.FAILURE


def test_model_with_all_column_types_succeeds(dry_run_result: DryRunResult):
node = get_report_node_by_id(
dry_run_result.report,
Expand Down
2 changes: 1 addition & 1 deletion integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def assert_node_failed_with_error(report: Report, unique_id: str, error: str) ->
node = get_report_node_by_id(report, unique_id)
assert (
not node.success
), f"Expected node {node.unique_id} to fail but it was successful"
), f"Expected node {node.unique_id} to fail but it was successful. Schema {node.table}"
assert (
node.error_message == error
), f"Node failed but error message '{node.error_message}' did not match expected: '{error}'"
Expand Down

0 comments on commit f6eca2b

Please sign in to comment.