diff --git a/dbt_dry_run/adapter/service.py b/dbt_dry_run/adapter/service.py index 916b9e8..26c33ec 100644 --- a/dbt_dry_run/adapter/service.py +++ b/dbt_dry_run/adapter/service.py @@ -1,9 +1,9 @@ -import json import os from argparse import Namespace from dataclasses import asdict, dataclass, field from typing import Any, Dict, Optional +from dbt.adapters.base import BaseAdapter from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters from dbt.config import RuntimeConfig from dbt.contracts.connection import Connection @@ -62,3 +62,7 @@ def get_dbt_manifest(self) -> Manifest: @property def threads(self) -> int: return self._profile.threads + + @property + def adapter(self) -> BaseAdapter: + return self._adapter diff --git a/dbt_dry_run/columns_metadata.py b/dbt_dry_run/columns_metadata.py index 4d3ca89..5d15860 100644 --- a/dbt_dry_run/columns_metadata.py +++ b/dbt_dry_run/columns_metadata.py @@ -10,10 +10,12 @@ STRUCT_SEPERATOR_LENGTH = len(STRUCT_SEPERATOR) -def _extract_fields(table_fields: List[TableField], prefix: str = "") -> List[str]: +def _extract_fields( + table_fields: List[TableField], prefix: str = "" +) -> List[Tuple[str, BigQueryFieldType]]: field_names = [] for field in table_fields: - field_names.append(f"{prefix}{field.name}") + field_names.append((f"{prefix}{field.name}", field.type_)) if field.fields: new_prefix = f"{prefix}{field.name}." 
field_names.extend(_extract_fields(field.fields, prefix=new_prefix)) @@ -27,7 +29,19 @@ def expand_table_fields(table: Table) -> Set[str]: Eg: TableField(name="a", fields=[TableField(name="a1")]) Returns: ["a", "a.a1"] """ - return set(_extract_fields(table.fields)) + name_type_pairs = _extract_fields(table.fields) + return set(name for name, _ in name_type_pairs) + + +def expand_table_fields_with_types(table: Table) -> Dict[str, BigQueryFieldType]: + """ + Expand table fields to dot notation (like in dbt metadata) mapped to their BigQuery field types + + Eg: TableField(name="a", fields=[TableField(name="a1")]) + Returns: {"a": <type>, "a.a1": <type>} + """ + name_type_pairs = _extract_fields(table.fields) + return {name: type_ for name, type_ in name_type_pairs} def _column_is_repeated(data_type: str) -> bool: diff --git a/dbt_dry_run/models/manifest.py b/dbt_dry_run/models/manifest.py index f749d10..bb4c83a 100644 --- a/dbt_dry_run/models/manifest.py +++ b/dbt_dry_run/models/manifest.py @@ -69,6 +69,7 @@ class NodeConfig(BaseModel): partition_by: Optional[PartitionBy] meta: Optional[NodeMeta] full_refresh: Optional[bool] + column_types: Dict[str, str] = Field(default_factory=dict) class ManifestColumn(BaseModel): diff --git a/dbt_dry_run/node_runner/seed_runner.py b/dbt_dry_run/node_runner/seed_runner.py index a9bce91..d1fd2ec 100644 --- a/dbt_dry_run/node_runner/seed_runner.py +++ b/dbt_dry_run/node_runner/seed_runner.py @@ -2,8 +2,8 @@ from typing import List, Optional import agate as ag -from agate import data_types + +from dbt_dry_run.exception import UnknownSchemaException from dbt_dry_run.models import BigQueryFieldType, Table, TableField from dbt_dry_run.models.manifest import Node from dbt_dry_run.node_runner import NodeRunner @@ -11,15 +11,6 @@ class SeedRunner(NodeRunner): - DEFAULT_TYPE = BigQueryFieldType.STRING - TYPE_MAP = { - data_types.Text: BigQueryFieldType.STRING, - data_types.Number: BigQueryFieldType.FLOAT64, - data_types.Boolean: BigQueryFieldType.BOOLEAN, - data_types.Date: 
BigQueryFieldType.DATE, - data_types.DateTime: BigQueryFieldType.DATETIME, - } - def run(self, node: Node) -> DryRunResult: if not node.root_path: raise ValueError(f"Node {node.unique_id} does not have `root_path`") @@ -28,9 +19,23 @@ def run(self, node: Node) -> DryRunResult: csv_table = ag.Table.from_csv(f) fields: List[TableField] = [] - for column in csv_table.columns: - type_ = self.TYPE_MAP.get(column.data_type.__class__, self.DEFAULT_TYPE) - new_field = TableField(name=column.name, type=type_) + for idx, column in enumerate(csv_table.columns): + override_type = node.config.column_types.get(column.name) + new_type = override_type or self._sql_runner.convert_agate_type( + csv_table, idx + ) + if new_type is None: + msg = f"Unknown Big Query schema for seed '{node.unique_id}' Column '{column.name}'" + exception = UnknownSchemaException(msg) + return DryRunResult( + node=node, + table=None, + status=DryRunStatus.FAILURE, + exception=exception, + ) + new_field = TableField( + name=column.name, type=BigQueryFieldType[new_type.upper()] + ) fields.append(new_field) schema = Table(fields=fields) diff --git a/dbt_dry_run/sql_runner/__init__.py b/dbt_dry_run/sql_runner/__init__.py index 36fdef2..cd875c5 100644 --- a/dbt_dry_run/sql_runner/__init__.py +++ b/dbt_dry_run/sql_runner/__init__.py @@ -1,6 +1,9 @@ from abc import ABCMeta, abstractmethod from typing import Optional, Tuple +import agate + +from dbt_dry_run.adapter.service import ProjectService from dbt_dry_run.models import Table from dbt_dry_run.models.manifest import Node from dbt_dry_run.results import DryRunStatus @@ -11,6 +14,9 @@ class SQLRunner(metaclass=ABCMeta): Used to adapt to multiple warehouse backends """ + def __init__(self, project: ProjectService): + self._project = project + @abstractmethod def node_exists(self, node: Node) -> bool: ... @@ -24,3 +30,8 @@ def query( self, sql: str ) -> Tuple[DryRunStatus, Optional[Table], Optional[Exception]]: ... 
+ + def convert_agate_type( + self, agate_table: agate.Table, col_idx: int + ) -> Optional[str]: + return self._project.adapter.convert_agate_type(agate_table, col_idx) diff --git a/dbt_dry_run/sql_runner/big_query_sql_runner.py b/dbt_dry_run/sql_runner/big_query_sql_runner.py index d6c7012..7f98aaf 100644 --- a/dbt_dry_run/sql_runner/big_query_sql_runner.py +++ b/dbt_dry_run/sql_runner/big_query_sql_runner.py @@ -1,10 +1,8 @@ -from contextlib import contextmanager from typing import List, Optional, Tuple from google.cloud.bigquery import ( Client, DatasetReference, - QueryJob, QueryJobConfig, SchemaField, TableReference, diff --git a/dbt_dry_run/test/node_runner/test_seed_runner.py b/dbt_dry_run/test/node_runner/test_seed_runner.py index cb0b3e9..2b4cc92 100644 --- a/dbt_dry_run/test/node_runner/test_seed_runner.py +++ b/dbt_dry_run/test/node_runner/test_seed_runner.py @@ -1,9 +1,9 @@ from pathlib import Path -from typing import Set +from typing import Optional, Set from unittest.mock import MagicMock from dbt_dry_run import flags -from dbt_dry_run.exception import NotCompiledException +from dbt_dry_run.exception import NotCompiledException, UnknownSchemaException from dbt_dry_run.flags import Flags from dbt_dry_run.models import BigQueryFieldType from dbt_dry_run.models.manifest import Node @@ -13,17 +13,22 @@ from dbt_dry_run.test.utils import SimpleNode +def get_result(node: Node, column_type: Optional[str] = "string") -> DryRunResult: + mock_sql_runner = MagicMock() + mock_sql_runner.convert_agate_type.return_value = column_type + seed_runner = SeedRunner(mock_sql_runner, MagicMock()) + return seed_runner.run(node) + + def assert_success_and_columns_equal( - node: Node, expected_columns: Set[str] + node: Node, expected_columns: Set[str], column_type: Optional[str] = "string" ) -> DryRunResult: - seed_runner = SeedRunner(MagicMock(), MagicMock()) - result: DryRunResult = seed_runner.run(node) + result = get_result(node, column_type) assert result.status == 
DryRunStatus.SUCCESS assert result.table fields = result.table.fields field_names: Set[str] = set(f.name for f in fields) assert field_names == expected_columns - return result @@ -45,11 +50,11 @@ def test_seed_runner_loads_file(tmp_path: Path) -> None: assert_success_and_columns_equal(node, expected_columns) -def test_seed_runner_infers_dates(tmp_path: Path) -> None: +def test_seed_runner_fails_if_type_returns_none(tmp_path: Path) -> None: p = tmp_path / "seed1.csv" csv_content = """a,b,c - foo,bar,2021-01-01 - foo2,bar2,2021-01-01 + foo,bar,baz + foo2,bar2,baz2 """ p.write_text(csv_content) @@ -59,11 +64,32 @@ def test_seed_runner_infers_dates(tmp_path: Path) -> None: resource_type=ManifestScheduler.SEED, original_file_path=p.as_posix(), ).to_node() - expected_columns = set(csv_content.splitlines()[0].split(",")) - result = assert_success_and_columns_equal(node, expected_columns) + result = get_result(node, None) + assert result.status == DryRunStatus.FAILURE + assert type(result.exception) == UnknownSchemaException - assert result.table - assert result.table.fields[2].type_ == BigQueryFieldType.DATE + +def test_seed_runner_uses_column_overrides(tmp_path: Path) -> None: + p = tmp_path / "seed1.csv" + csv_content = """a,b,c + foo,bar,baz + foo2,bar2,baz2 + """ + p.write_text(csv_content) + + node = SimpleNode( + unique_id="node1", + depends_on=[], + resource_type=ManifestScheduler.SEED, + original_file_path=p.as_posix(), + ).to_node() + node.config.column_types = {"a": "NUMERIC"} + result = get_result(node, "STRING") + + expected_fields = {"a": "NUMERIC", "b": "STRING", "c": "STRING"} + assert result.table, "Expected result to have a table" + mapped_fields = {field.name: field.type_ for field in result.table.fields} + assert mapped_fields == expected_fields def test_validate_node_returns_none_if_node_is_not_compiled() -> None: diff --git a/integration/projects/test_models_are_executed/seeds/badly_configured_seed.csv 
b/integration/projects/test_models_are_executed/seeds/badly_configured_seed.csv new file mode 100644 index 0000000..107cdb3 --- /dev/null +++ b/integration/projects/test_models_are_executed/seeds/badly_configured_seed.csv @@ -0,0 +1,3 @@ +seed_number,seed_string +1,hello +2,world \ No newline at end of file diff --git a/integration/projects/test_models_are_executed/seeds/meta.yaml b/integration/projects/test_models_are_executed/seeds/meta.yaml new file mode 100644 index 0000000..d94fa79 --- /dev/null +++ b/integration/projects/test_models_are_executed/seeds/meta.yaml @@ -0,0 +1,13 @@ +seeds: + - name: my_seed + config: + column_types: + seed_a: STRING + seed_b: FLOAT64 + seed_c: BIGNUMERIC + + - name: badly_configured_seed + config: + column_types: + seed_number: NUMERIC + seed_string: NUMERIC \ No newline at end of file diff --git a/integration/projects/test_models_are_executed/seeds/my_seed.csv b/integration/projects/test_models_are_executed/seeds/my_seed.csv index 17dc561..73c0d9c 100644 --- a/integration/projects/test_models_are_executed/seeds/my_seed.csv +++ b/integration/projects/test_models_are_executed/seeds/my_seed.csv @@ -1,3 +1,3 @@ -a,seed_b -hello,1 -world,2 \ No newline at end of file +a,seed_b,seed_c +hello,1,1.2 +world,2,2.4 \ No newline at end of file diff --git a/integration/projects/test_models_are_executed/test_models_are_executed.py b/integration/projects/test_models_are_executed/test_models_are_executed.py index d06f546..9f9c9ee 100644 --- a/integration/projects/test_models_are_executed/test_models_are_executed.py +++ b/integration/projects/test_models_are_executed/test_models_are_executed.py @@ -1,24 +1,30 @@ +import pytest + +from dbt_dry_run.columns_metadata import expand_table_fields_with_types +from dbt_dry_run.models import BigQueryFieldType +from dbt_dry_run.results import DryRunStatus from integration.conftest import DryRunResult from integration.utils import ( - assert_report_success, get_report_node_by_id, 
assert_report_node_has_columns, + assert_report_produced, ) -def test_success(dry_run_result: DryRunResult): - assert_report_success(dry_run_result) - - def test_ran_correct_number_of_nodes(dry_run_result: DryRunResult): - report = assert_report_success(dry_run_result) - assert report.node_count == 4 + report = assert_report_produced(dry_run_result) + assert report.node_count == 5 def test_table_of_nodes_is_returned(dry_run_result: DryRunResult): - report = assert_report_success(dry_run_result) + report = assert_report_produced(dry_run_result) seed_node = get_report_node_by_id(report, "seed.test_models_are_executed.my_seed") - assert_report_node_has_columns(seed_node, {"a", "seed_b"}) + columns = expand_table_fields_with_types(seed_node.table) + assert columns == { + "a": BigQueryFieldType.STRING, + "seed_b": BigQueryFieldType.FLOAT64, + "seed_c": BigQueryFieldType.BIGNUMERIC, + } first_layer = get_report_node_by_id( report, "model.test_models_are_executed.first_layer" @@ -28,16 +34,27 @@ def test_table_of_nodes_is_returned(dry_run_result: DryRunResult): second_layer = get_report_node_by_id( report, "model.test_models_are_executed.second_layer" ) - assert_report_node_has_columns(second_layer, {"a", "b", "c", "seed_b"}) + assert_report_node_has_columns(second_layer, {"a", "b", "c", "seed_b", "seed_c"}) def test_disabled_model_not_run(dry_run_result: DryRunResult): - report = assert_report_success(dry_run_result) + report = assert_report_produced(dry_run_result) assert "model.test_models_are_executed.disabled_model" not in set( n.unique_id for n in report.nodes ), "Found disabled model in dry run output" +@pytest.mark.xfail( + reason="Seed type compatibility not checked. 
(Trying to convert string to number)" +) +def test_badly_configured_seed_fails(dry_run_result: DryRunResult): + report = assert_report_produced(dry_run_result) + seed_node = get_report_node_by_id( + report, "seed.test_models_are_executed.badly_configured_seed" + ) + assert seed_node.status == DryRunStatus.FAILURE + + def test_model_with_all_column_types_succeeds(dry_run_result: DryRunResult): node = get_report_node_by_id( dry_run_result.report, diff --git a/integration/utils.py b/integration/utils.py index 18c1c9b..1233da0 100644 --- a/integration/utils.py +++ b/integration/utils.py @@ -48,7 +48,7 @@ def assert_node_failed_with_error(report: Report, unique_id: str, error: str) -> node = get_report_node_by_id(report, unique_id) assert ( not node.success - ), f"Expected node {node.unique_id} to fail but it was successful" + ), f"Expected node {node.unique_id} to fail but it was successful. Schema {node.table}" assert ( node.error_message == error ), f"Node failed but error message '{node.error_message}' did not match expected: '{error}'"