diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py b/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py
index 8864395cdc29e..7b685a656d7d5 100644
--- a/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py
+++ b/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py
@@ -1,6 +1,5 @@
 import contextlib
 import copy
-import dataclasses
 import os
 import shutil
 import signal
@@ -8,6 +7,7 @@
 import sys
 import uuid
 from argparse import Namespace
+from collections import abc
 from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import suppress
 from dataclasses import InitVar, dataclass, field
@@ -73,7 +73,7 @@
 from sqlglot.expressions import normalize_table_name
 from sqlglot.lineage import lineage
 from sqlglot.optimizer import optimize
-from typing_extensions import Final, Literal
+from typing_extensions import Final, Literal, TypeVar

 from ..asset_utils import (
     DAGSTER_DBT_EXCLUDE_METADATA_KEY,
@@ -539,7 +539,6 @@ class DbtCliInvocation:
         init=False, default=DAGSTER_DBT_TERMINATION_TIMEOUT_SECONDS
     )
     adapter: Optional[BaseAdapter] = field(default=None)
-    should_fetch_row_count: bool = field(default=False)

     _stdout: List[str] = field(init=False, default_factory=list)
     _error_messages: List[str] = field(init=False, default_factory=list)
@@ -597,7 +596,6 @@ def run(
             raise_on_error=raise_on_error,
             context=context,
             adapter=adapter,
-            should_fetch_row_count=False,
         )

         logger.info(f"Running dbt command: `{dbt_cli_invocation.dbt_command}`.")
@@ -658,108 +656,12 @@ def _stream_asset_events(
             target_path=self.target_path,
         )

-    def _get_dbt_resource_props_from_event(self, event: DbtDagsterEventType) -> Dict[str, Any]:
-        unique_id = cast(TextMetadataValue, event.metadata["unique_id"]).text
-        return check.not_none(self.manifest["nodes"].get(unique_id))
-
-    def _attach_post_materialization_metadata(
-        self,
-        event: DbtDagsterEventType,
-    ) -> DbtDagsterEventType:
-        """Threaded task which runs any postprocessing steps on the given event before it's
-        emitted to user code.
-
-        This is used to, for example, query the row count of a table after it has been
-        materialized by dbt.
-        """
-        adapter = check.not_none(self.adapter)
-
-        dbt_resource_props = self._get_dbt_resource_props_from_event(event)
-        is_view = dbt_resource_props["config"]["materialized"] == "view"
-
-        # Avoid counting rows for views, since they may include complex SQL queries
-        # that are costly to execute. We can revisit this in the future if there is
-        # a demand for it.
-        if is_view:
-            return event
-
-        # If the adapter is DuckDB, we need to wait for the dbt CLI process to complete
-        # so that the DuckDB lock is released. This is because DuckDB does not allow for
-        # opening multiple connections to the same database when a write connection, such
-        # as the one dbt uses, is open.
-        try:
-            from dbt.adapters.duckdb import DuckDBAdapter
-
-            if isinstance(adapter, DuckDBAdapter):
-                self._dbt_run_thread.result()
-        except ImportError:
-            pass
-
-        unique_id = dbt_resource_props["unique_id"]
-        logger.debug("Fetching row count for %s", unique_id)
-        table_str = f"{dbt_resource_props['database']}.{dbt_resource_props['schema']}.{dbt_resource_props['name']}"
-
-        with adapter.connection_named(f"row_count_{unique_id}"):
-            query_result = adapter.execute(
-                f"""
-                    SELECT
-                        count(*) as row_count
-                    FROM
-                        {table_str}
-                """,
-                fetch=True,
-            )
-        row_count = query_result[1][0]["row_count"]
-        additional_metadata = {**TableMetadataSet(row_count=row_count)}
-
-        if isinstance(event, Output):
-            return event.with_metadata(metadata={**event.metadata, **additional_metadata})
-        else:
-            return event._replace(metadata={**event.metadata, **additional_metadata})
-
-    def _stream_dbt_events_and_enqueue_postprocessing(
-        self,
-        output_events_and_futures: List[Union[Future, DbtDagsterEventType]],
-        executor: ThreadPoolExecutor,
-    ) -> None:
-        """Task which streams dbt events and either directly places them in
-        the output_events list to be emitted to user code, or enqueues post-processing tasks
-        where needed.
-        """
-        for event in self._stream_asset_events():
-            # For any materialization or output event, we run postprocessing steps
-            # to attach additional metadata to the event.
-            if self.should_fetch_row_count and isinstance(event, (AssetMaterialization, Output)):
-                output_events_and_futures.append(
-                    executor.submit(
-                        self._attach_post_materialization_metadata,
-                        event,
-                    )
-                )
-            else:
-                output_events_and_futures.append(event)
-
-    @experimental
-    def enable_fetch_row_count(
-        self,
-    ) -> "DbtCliInvocation":
-        """Experimental functionality which will fetch row counts for materialized dbt
-        models in a dbt run once they are built. Note that row counts will not be fetched
-        for views, since this requires running the view's SQL query which may be costly.
-        """
-        return dataclasses.replace(self, should_fetch_row_count=True)
-
     @public
     def stream(
         self,
-    ) -> Iterator[
-        Union[
-            Output,
-            AssetMaterialization,
-            AssetObservation,
-            AssetCheckResult,
-        ]
-    ]:
+    ) -> (
+        "DbtEventIterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]"
+    ):
         """Stream the events from the dbt CLI process and convert them to Dagster events.

         Returns:
@@ -785,59 +687,7 @@ def stream(
                 def my_dbt_assets(context, dbt: DbtCliResource):
                     yield from dbt.cli(["run"], context=context).stream()
         """
-        has_any_parallel_tasks = self.should_fetch_row_count
-
-        if not has_any_parallel_tasks:
-            # If we're not enqueuing any parallel tasks, we can just stream the events in
-            # the main thread.
-            yield from self._stream_asset_events()
-            return
-
-        if self.should_fetch_row_count:
-            logger.info(
-                "Row counts will be fetched for non-view models once they are materialized."
-            )
-
-        # We keep a list of emitted Dagster events and pending futures which augment
-        # emitted events with additional metadata. This ensures we can yield events in the order
-        # they are emitted by dbt.
-        output_events_and_futures: List[Union[Future, DbtDagsterEventType]] = []
-
-        # Point at project directory to ensure dbt adapters run correctly
-        with pushd(str(self.project_dir)), ThreadPoolExecutor(
-            max_workers=STREAM_EVENTS_THREADPOOL_SIZE
-        ) as executor:
-            self._dbt_run_thread = executor.submit(
-                self._stream_dbt_events_and_enqueue_postprocessing,
-                output_events_and_futures,
-                executor,
-            )
-
-            # Step through the list of output events and futures, yielding them in order
-            # once they are ready to be emitted
-            event_to_emit_idx = 0
-            while True:
-                all_work_complete = get_future_completion_state_or_err(
-                    [self._dbt_run_thread, *output_events_and_futures]
-                )
-                if all_work_complete and event_to_emit_idx >= len(output_events_and_futures):
-                    break
-
-                if event_to_emit_idx < len(output_events_and_futures):
-                    event_to_emit = output_events_and_futures[event_to_emit_idx]
-
-                    if isinstance(event_to_emit, Future):
-                        # If the next event to emit is a Future (waiting on postprocessing),
-                        # we need to wait for it to complete before yielding the event.
-                        try:
-                            event = event_to_emit.result(timeout=0.1)
-                            yield event
-                            event_to_emit_idx += 1
-                        except:
-                            pass
-                    else:
-                        yield event_to_emit
-                        event_to_emit_idx += 1
+        return DbtEventIterator(self._stream_asset_events(), self)

     @public
     def stream_raw_events(self) -> Iterator[DbtCliEventMessage]:
@@ -993,6 +843,191 @@ def _raise_on_error(self) -> None:
         )


+# We define DbtEventIterator as a generic type for the sake of type hinting.
+# This is so that users who inspect the type of the return value of `DbtCliInvocation.stream()`
+# will be able to see the inner type of the iterator, rather than just `DbtEventIterator`.
+T = TypeVar("T", bound=DbtDagsterEventType)
+
+
+class DbtEventIterator(abc.Iterator[T]):
+    """A wrapper around an iterator of dbt events which contains additional methods for
+    post-processing the events, such as fetching row counts for materialized tables.
+    """
+
+    def __init__(self, events: Iterator[T], dbt_cli_invocation: DbtCliInvocation) -> None:
+        self._inner_iterator = events
+        self._dbt_cli_invocation = dbt_cli_invocation
+
+    def __next__(self) -> T:
+        return next(self._inner_iterator)
+
+    def __iter__(self) -> "DbtEventIterator[T]":
+        return self
+
+    def _get_dbt_resource_props_from_event(self, event: DbtDagsterEventType) -> Dict[str, Any]:
+        unique_id = cast(TextMetadataValue, event.metadata["unique_id"]).text
+        return check.not_none(self._dbt_cli_invocation.manifest["nodes"].get(unique_id))
+
+    def _fetch_and_attach_row_count_metadata(
+        self,
+        event: DbtDagsterEventType,
+    ) -> DbtDagsterEventType:
+        """Threaded task which fetches row counts for materialized dbt models in a dbt run
+        once they are built, and attaches the row count as metadata to the event.
+        """
+        adapter = check.not_none(self._dbt_cli_invocation.adapter)
+
+        dbt_resource_props = self._get_dbt_resource_props_from_event(event)
+        is_view = dbt_resource_props["config"]["materialized"] == "view"
+
+        # Avoid counting rows for views, since they may include complex SQL queries
+        # that are costly to execute. We can revisit this in the future if there is
+        # a demand for it.
+        if is_view:
+            return event
+
+        # If the adapter is DuckDB, we need to wait for the dbt CLI process to complete
+        # so that the DuckDB lock is released. This is because DuckDB does not allow for
+        # opening multiple connections to the same database when a write connection, such
+        # as the one dbt uses, is open.
+        try:
+            from dbt.adapters.duckdb import DuckDBAdapter
+
+            if isinstance(adapter, DuckDBAdapter):
+                self._dbt_run_thread.result()
+        except ImportError:
+            pass
+
+        unique_id = dbt_resource_props["unique_id"]
+        logger.debug("Fetching row count for %s", unique_id)
+        table_str = f"{dbt_resource_props['database']}.{dbt_resource_props['schema']}.{dbt_resource_props['name']}"
+
+        try:
+            with adapter.connection_named(f"row_count_{unique_id}"):
+                query_result = adapter.execute(
+                    f"""
+                        SELECT
+                            count(*) as row_count
+                        FROM
+                            {table_str}
+                    """,
+                    fetch=True,
+                )
+            row_count = query_result[1][0]["row_count"]
+            additional_metadata = {**TableMetadataSet(row_count=row_count)}
+
+            if isinstance(event, Output):
+                return event.with_metadata(metadata={**event.metadata, **additional_metadata})
+            else:
+                return event._replace(metadata={**event.metadata, **additional_metadata})
+        except Exception as e:
+            logger.exception(
+                f"An error occurred while fetching row count for {unique_id}. Row count metadata"
+                " will not be included in the event.\n\n"
+                f"Exception: {e}"
+            )
+            return event
+
+    def _stream_dbt_events_and_enqueue_postprocessing(
+        self,
+        output_events_and_futures: List[Union[Future, DbtDagsterEventType]],
+        executor: ThreadPoolExecutor,
+    ) -> None:
+        """Task which streams dbt events and either directly places them in
+        the output_events_and_futures list to be emitted to user code, or enqueues
+        post-processing tasks where needed.
+        """
+        for event in self:
+            # For any materialization or output event, we run postprocessing steps
+            # to attach additional metadata to the event.
+            if isinstance(event, (AssetMaterialization, Output)):
+                output_events_and_futures.append(
+                    executor.submit(
+                        self._fetch_and_attach_row_count_metadata,
+                        event,
+                    )
+                )
+            else:
+                output_events_and_futures.append(event)
+
+    @public
+    @experimental
+    def fetch_row_counts(
+        self,
+    ) -> (
+        "DbtEventIterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]"
+    ):
+        """Experimental functionality which will fetch row counts for materialized dbt
+        models in a dbt run once they are built. Note that row counts will not be fetched
+        for views, since this requires running the view's SQL query which may be costly.
+
+        Returns:
+            Iterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]:
+                A set of corresponding Dagster events for dbt models, with row counts attached,
+                yielded in the order they are emitted by dbt.
+        """
+        return DbtEventIterator(
+            self._fetch_row_counts_inner(),
+            dbt_cli_invocation=self._dbt_cli_invocation,
+        )
+
+    def _fetch_row_counts_inner(
+        self,
+    ) -> Iterator[
+        Union[
+            Output,
+            AssetMaterialization,
+            AssetObservation,
+            AssetCheckResult,
+        ]
+    ]:
+        logger.info("Row counts will be fetched for non-view models once they are materialized.")
+
+        # We keep a list of emitted Dagster events and pending futures which augment
+        # emitted events with additional metadata. This ensures we can yield events in the order
+        # they are emitted by dbt.
+        output_events_and_futures: List[Union[Future, DbtDagsterEventType]] = []
+
+        # Point at project directory to ensure dbt adapters run correctly
+        with pushd(str(self._dbt_cli_invocation.project_dir)), ThreadPoolExecutor(
+            max_workers=STREAM_EVENTS_THREADPOOL_SIZE
+        ) as executor:
+            self._dbt_run_thread = executor.submit(
+                self._stream_dbt_events_and_enqueue_postprocessing,
+                output_events_and_futures,
+                executor,
+            )
+
+            # Step through the list of output events and futures, yielding them in order
+            # once they are ready to be emitted
+            event_to_emit_idx = 0
+            while True:
+                all_work_complete = get_future_completion_state_or_err(
+                    [self._dbt_run_thread, *output_events_and_futures]
+                )
+                if all_work_complete and event_to_emit_idx >= len(output_events_and_futures):
+                    break
+
+                if event_to_emit_idx < len(output_events_and_futures):
+                    event_to_emit = output_events_and_futures[event_to_emit_idx]
+
+                    try:
+                        # If the next event to emit is a Future (waiting on postprocessing),
+                        # we need to wait for it to complete before yielding the event.
+                        event = (
+                            event_to_emit.result(timeout=0.1)
+                            if isinstance(event_to_emit, Future)
+                            else event_to_emit
+                        )
+                        yield event
+                        event_to_emit_idx += 1
+                    except:
+                        # If the Future has not completed, it will raise a TimeoutError.
+                        # Any other exception will be reraised in the main thread as part
+                        # of get_future_completion_state_or_err.
+                        pass
+
+
 class DbtCliResource(ConfigurableResource):
     """A resource used to execute dbt CLI commands.

diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt_tests/core/test_row_count_postprocessing.py b/python_modules/libraries/dagster-dbt/dagster_dbt_tests/core/test_row_count_postprocessing.py
index 941cc3063b2dd..f865ad0ad6599 100644
--- a/python_modules/libraries/dagster-dbt/dagster_dbt_tests/core/test_row_count_postprocessing.py
+++ b/python_modules/libraries/dagster-dbt/dagster_dbt_tests/core/test_row_count_postprocessing.py
@@ -1,6 +1,7 @@
 import os
 from typing import Any, Dict, cast

+import mock
 import pytest
 from dagster import (
     AssetExecutionContext,
@@ -64,7 +65,7 @@ def test_row_count(
 ) -> None:
     @dbt_assets(manifest=test_jaffle_shop_manifest_standalone_duckdb_dbfile)
     def my_dbt_assets(context: AssetExecutionContext, dbt: DbtCliResource):
-        yield from dbt.cli(["build"], context=context).enable_fetch_row_count().stream()
+        yield from dbt.cli(["build"], context=context).stream().fetch_row_counts()

     result = materialize(
         [my_dbt_assets],
@@ -92,3 +93,33 @@ def my_dbt_assets(context: AssetExecutionContext, dbt: DbtCliResource):
         if "stg" not in check.not_none(event.asset_key).path[-1]
     ]
     assert all(row_count and row_count > 0 for row_count in row_counts), row_counts
+
+
+def test_row_count_err(
+    test_jaffle_shop_manifest_standalone_duckdb_dbfile: Dict[str, Any],
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    # test that we can handle exceptions in row count fetching
+    # and still complete the dbt assets materialization
+    with mock.patch("dbt.adapters.duckdb.DuckDBAdapter.execute") as mock_execute:
+        mock_execute.side_effect = Exception("mock_execute exception")
+
+        @dbt_assets(manifest=test_jaffle_shop_manifest_standalone_duckdb_dbfile)
+        def my_dbt_assets(context: AssetExecutionContext, dbt: DbtCliResource):
+            yield from dbt.cli(["build"], context=context).stream().fetch_row_counts()
+
+        result = materialize(
+            [my_dbt_assets],
+            resources={"dbt": DbtCliResource(project_dir=os.fspath(test_jaffle_shop_path))},
+        )
+
+        assert result.success
+
+        # Validate that no row counts were fetched due to the exception
+        assert not any(
+            "dagster/row_count" in event.materialization.metadata
+            for event in result.get_asset_materialization_events()
+        )
+
+        # assert that the warning message appears in the logs
+        assert "An error occurred while fetching row count for " in caplog.text
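
Reviewer note: a minimal usage sketch of the API change above, assuming a hypothetical project path and manifest (neither is part of this diff). Before this change, row-count fetching was enabled on the invocation via the now-removed `enable_fetch_row_count()`; after it, `stream()` returns a `DbtEventIterator` and row counts are requested by chaining the experimental `fetch_row_counts()` onto the stream.

from pathlib import Path

from dagster import AssetExecutionContext
from dagster_dbt import DbtCliResource, dbt_assets

# Hypothetical paths, for illustration only.
MY_DBT_PROJECT_DIR = Path("path/to/my_dbt_project")
MY_DBT_MANIFEST = MY_DBT_PROJECT_DIR / "target" / "manifest.json"


@dbt_assets(manifest=MY_DBT_MANIFEST)
def my_dbt_assets(context: AssetExecutionContext, dbt: DbtCliResource):
    # Old API (removed in this diff):
    #   yield from dbt.cli(["build"], context=context).enable_fetch_row_count().stream()
    # New API: chain fetch_row_counts() onto the event stream. Row counts are
    # attached as "dagster/row_count" metadata, and only for non-view models.
    yield from dbt.cli(["build"], context=context).stream().fetch_row_counts()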
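
The `TypeVar` plus `abc.Iterator[T]` combination introduced above is what lets type checkers see through the wrapper, so that chained post-processing methods return an iterator with the same inner event type. A standalone sketch of that pattern, with hypothetical names (`EventIterator`, `with_logging`), runnable on Python 3.9+ where `collections.abc.Iterator` is subscriptable:

from collections import abc
from typing import Iterator, List, TypeVar

T = TypeVar("T")


class EventIterator(abc.Iterator[T]):
    """Generic iterator wrapper: chained methods preserve the element type T."""

    def __init__(self, events: Iterator[T]) -> None:
        self._inner = events

    def __next__(self) -> T:
        return next(self._inner)

    def __iter__(self) -> "EventIterator[T]":
        return self

    def with_logging(self) -> "EventIterator[T]":
        # Post-processing step analogous to fetch_row_counts(): wrap the
        # underlying events in a generator and return a new wrapper of the
        # same element type.
        def gen() -> Iterator[T]:
            for event in self:
                print(f"emitting {event!r}")
                yield event

        return EventIterator(gen())


# A type checker infers EventIterator[int] here, not a bare EventIterator.
events: List[int] = [1, 2, 3]
assert list(EventIterator(iter(events)).with_logging()) == [1, 2, 3]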