From c7366b9a38d8e73a782ff68423d1209237633265 Mon Sep 17 00:00:00 2001 From: Rex Ledesma Date: Thu, 8 Feb 2024 20:39:08 -0500 Subject: [PATCH] feat(dbt): collect column metadata using the underlying dbt adapter --- .../dagster_dbt/core/resources_v2.py | 70 +++++++++++++++++++ .../dagster-dbt/dagster_dbt/utils.py | 8 +-- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py b/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py index c02ae257145c6..6a1e6b6a712a8 100644 --- a/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py +++ b/python_modules/libraries/dagster-dbt/dagster_dbt/core/resources_v2.py @@ -30,11 +30,20 @@ ConfigurableResource, OpExecutionContext, Output, + _check as check, get_dagster_logger, ) from dagster._annotations import public from dagster._config.pythonic_config.pydantic_compat_layer import compat_model_validator +from dagster._core.definitions.metadata import MetadataValue +from dagster._core.definitions.metadata.table import TableColumn, TableSchema from dagster._core.errors import DagsterExecutionInterruptedError, DagsterInvalidPropertyError +from dbt.adapters.base import BaseRelation +from dbt.adapters.base.impl import BaseAdapter +from dbt.adapters.factory import FACTORY +from dbt.cli.context import make_context +from dbt.cli.requires import preflight, profile, project +from dbt.config import RuntimeConfig from dbt.contracts.results import NodeStatus, TestStatus from dbt.node_types import NodeType from dbt.version import __version__ as dbt_version @@ -108,6 +117,7 @@ def to_default_asset_events( manifest: DbtManifestParam, dagster_dbt_translator: DagsterDbtTranslator = DagsterDbtTranslator(), context: Optional[OpExecutionContext] = None, + adapter: Optional[BaseAdapter] = None, ) -> Iterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]: """Convert a dbt CLI event to a set of corresponding Dagster events. @@ -167,6 +177,27 @@ def to_default_asset_events( finished_at = dateutil.parser.isoparse(event_node_info["node_finished_at"]) duration_seconds = (finished_at - started_at).total_seconds() + # Actually get the column metadata! Programmatically! Horray! + column_metadata = {} + if adapter: + node_relation = self.raw_event["data"]["node_info"]["node_relation"] + relation = BaseRelation.create( + database=node_relation["database"], + schema=node_relation["schema"], + identifier=node_relation["alias"], + ) + adapter.acquire_connection() + column_metadata = { + "columns": MetadataValue.table_schema( + TableSchema( + columns=[ + TableColumn(name=column.column, type=column.data_type) + for column in adapter.get_columns_in_relation(relation) + ] + ) + ) + } + if has_asset_def: yield Output( value=None, @@ -175,6 +206,7 @@ def to_default_asset_events( "unique_id": unique_id, "invocation_id": invocation_id, "Execution Duration": duration_seconds, + **column_metadata, **adapter_response_metadata, }, ) @@ -188,6 +220,7 @@ def to_default_asset_events( "unique_id": unique_id, "invocation_id": invocation_id, "Execution Duration": duration_seconds, + **column_metadata, **adapter_response_metadata, }, ) @@ -301,6 +334,7 @@ class DbtCliInvocation: target_path: Path raise_on_error: bool log_level: Literal["info", "debug"] + adapter: BaseAdapter context: Optional[OpExecutionContext] = field(default=None, repr=False) termination_timeout_seconds: float = field( init=False, default=DAGSTER_DBT_TERMINATION_TIMEOUT_SECONDS @@ -319,6 +353,7 @@ def run( target_path: Path, raise_on_error: bool, log_level: Literal["info", "debug"], + adapter: BaseAdapter, context: Optional[OpExecutionContext], ) -> "DbtCliInvocation": # Attempt to take advantage of partial parsing. If there is a `partial_parse.msgpack` in @@ -361,6 +396,7 @@ def run( target_path=target_path, raise_on_error=raise_on_error, log_level=log_level, + adapter=adapter, context=context, ) logger.info(f"Running dbt command: `{dbt_cli_invocation.dbt_command}`.") @@ -451,6 +487,7 @@ def my_dbt_assets(context, dbt: DbtCliResource): manifest=self.manifest, dagster_dbt_translator=self.dagster_dbt_translator, context=self.context, + adapter=self.adapter, ) @public @@ -564,6 +601,8 @@ def _raise_on_error(self) -> None: logger.info(f"Finished dbt command: `{self.dbt_command}`.") + self.adapter.release_connection() + if not is_successful and self.raise_on_error: log_path = self.target_path.joinpath("dbt.log") extra_description = "" @@ -1031,6 +1070,36 @@ def my_dbt_op(dbt: DbtCliResource): if not target_path.is_absolute(): target_path = project_dir.joinpath(target_path) + # Begin cursed code: invoke the internal dbt functions that the CLI uses under the hood + # to setup the dbt context and initialize the dbt adapter. + + # Change the current working directory to the dbt project directory, for Click CLI to + # properly parse the dbt project. + current_path = os.getcwd() + os.chdir(path=self.project_dir) + ctx = check.not_none(make_context(args=args[1:])) + + # When invoking Dagster from the CLI, we use Click. dbt also uses Click. We need + # sys.argv to reference the dbt CLI command that's running, not the Dagster CLI command. + sys.argv = list(args) + + # Do the dbt setup. + preflight(lambda _: None)(ctx) + profile(lambda _: None)(ctx) + project(lambda _: None)(ctx) + + # Retrieve the dbt credentials as specified from the profiles.yml file. + config = RuntimeConfig.from_parts( + ctx.obj["project"], + ctx.obj["profile"], + ctx.obj["flags"], + ) + + # Initialize the dbt adapter. + FACTORY.register_adapter(config) + adapter = cast(BaseAdapter, FACTORY.lookup_adapter(config.credentials.type)) + os.chdir(path=current_path) + return DbtCliInvocation.run( args=args, env=env, @@ -1040,6 +1109,7 @@ def my_dbt_op(dbt: DbtCliResource): target_path=target_path, raise_on_error=raise_on_error, log_level=log_level, + adapter=adapter, context=context, ) diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt/utils.py b/python_modules/libraries/dagster-dbt/dagster_dbt/utils.py index 53a6eccfe769c..29f0467de7e54 100644 --- a/python_modules/libraries/dagster-dbt/dagster_dbt/utils.py +++ b/python_modules/libraries/dagster-dbt/dagster_dbt/utils.py @@ -1,4 +1,5 @@ from typing import ( + TYPE_CHECKING, AbstractSet, Any, Callable, @@ -24,6 +25,9 @@ from .types import DbtOutput +if TYPE_CHECKING: + from dbt.graph.selector_spec import SelectionSpec + # dbt resource types that may be considered assets ASSET_RESOURCE_TYPES = ["model", "seed", "snapshot"] @@ -242,8 +246,6 @@ def select_unique_ids_from_manifest( import dbt.graph.cli as graph_cli import dbt.graph.selector as graph_selector from dbt.contracts.graph.manifest import Manifest - from dbt.flags import GLOBAL_FLAGS - from dbt.graph.selector_spec import IndirectSelection, SelectionSpec from networkx import DiGraph manifest = Manifest.from_dict(manifest_json) @@ -252,8 +254,6 @@ def select_unique_ids_from_manifest( graph = graph_selector.Graph(DiGraph(incoming_graph_data=child_map)) # create a parsed selection from the select string - setattr(GLOBAL_FLAGS, "INDIRECT_SELECTION", IndirectSelection.Eager) - setattr(GLOBAL_FLAGS, "WARN_ERROR", True) parsed_spec: SelectionSpec = graph_cli.parse_union([select], True) if exclude: