Skip to content

Commit

Permalink
feat(dbt): collect column metadata using the underlying dbt adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
rexledesma committed Feb 12, 2024
1 parent addbfc4 commit c7366b9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,20 @@
ConfigurableResource,
OpExecutionContext,
Output,
_check as check,
get_dagster_logger,
)
from dagster._annotations import public
from dagster._config.pythonic_config.pydantic_compat_layer import compat_model_validator
from dagster._core.definitions.metadata import MetadataValue
from dagster._core.definitions.metadata.table import TableColumn, TableSchema
from dagster._core.errors import DagsterExecutionInterruptedError, DagsterInvalidPropertyError
from dbt.adapters.base import BaseRelation
from dbt.adapters.base.impl import BaseAdapter
from dbt.adapters.factory import FACTORY
from dbt.cli.context import make_context
from dbt.cli.requires import preflight, profile, project
from dbt.config import RuntimeConfig
from dbt.contracts.results import NodeStatus, TestStatus
from dbt.node_types import NodeType
from dbt.version import __version__ as dbt_version
Expand Down Expand Up @@ -108,6 +117,7 @@ def to_default_asset_events(
manifest: DbtManifestParam,
dagster_dbt_translator: DagsterDbtTranslator = DagsterDbtTranslator(),
context: Optional[OpExecutionContext] = None,
adapter: Optional[BaseAdapter] = None,
) -> Iterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]:
"""Convert a dbt CLI event to a set of corresponding Dagster events.
Expand Down Expand Up @@ -167,6 +177,27 @@ def to_default_asset_events(
finished_at = dateutil.parser.isoparse(event_node_info["node_finished_at"])
duration_seconds = (finished_at - started_at).total_seconds()

# Actually get the column metadata! Programmatically! Hooray!
column_metadata = {}
if adapter:
node_relation = self.raw_event["data"]["node_info"]["node_relation"]
relation = BaseRelation.create(
database=node_relation["database"],
schema=node_relation["schema"],
identifier=node_relation["alias"],
)
adapter.acquire_connection()
column_metadata = {
"columns": MetadataValue.table_schema(
TableSchema(
columns=[
TableColumn(name=column.column, type=column.data_type)
for column in adapter.get_columns_in_relation(relation)
]
)
)
}

if has_asset_def:
yield Output(
value=None,
Expand All @@ -175,6 +206,7 @@ def to_default_asset_events(
"unique_id": unique_id,
"invocation_id": invocation_id,
"Execution Duration": duration_seconds,
**column_metadata,
**adapter_response_metadata,
},
)
Expand All @@ -188,6 +220,7 @@ def to_default_asset_events(
"unique_id": unique_id,
"invocation_id": invocation_id,
"Execution Duration": duration_seconds,
**column_metadata,
**adapter_response_metadata,
},
)
Expand Down Expand Up @@ -301,6 +334,7 @@ class DbtCliInvocation:
target_path: Path
raise_on_error: bool
log_level: Literal["info", "debug"]
adapter: BaseAdapter
context: Optional[OpExecutionContext] = field(default=None, repr=False)
termination_timeout_seconds: float = field(
init=False, default=DAGSTER_DBT_TERMINATION_TIMEOUT_SECONDS
Expand All @@ -319,6 +353,7 @@ def run(
target_path: Path,
raise_on_error: bool,
log_level: Literal["info", "debug"],
adapter: BaseAdapter,
context: Optional[OpExecutionContext],
) -> "DbtCliInvocation":
# Attempt to take advantage of partial parsing. If there is a `partial_parse.msgpack` in
Expand Down Expand Up @@ -361,6 +396,7 @@ def run(
target_path=target_path,
raise_on_error=raise_on_error,
log_level=log_level,
adapter=adapter,
context=context,
)
logger.info(f"Running dbt command: `{dbt_cli_invocation.dbt_command}`.")
Expand Down Expand Up @@ -451,6 +487,7 @@ def my_dbt_assets(context, dbt: DbtCliResource):
manifest=self.manifest,
dagster_dbt_translator=self.dagster_dbt_translator,
context=self.context,
adapter=self.adapter,
)

@public
Expand Down Expand Up @@ -564,6 +601,8 @@ def _raise_on_error(self) -> None:

logger.info(f"Finished dbt command: `{self.dbt_command}`.")

self.adapter.release_connection()

if not is_successful and self.raise_on_error:
log_path = self.target_path.joinpath("dbt.log")
extra_description = ""
Expand Down Expand Up @@ -1031,6 +1070,36 @@ def my_dbt_op(dbt: DbtCliResource):
if not target_path.is_absolute():
target_path = project_dir.joinpath(target_path)

# Begin cursed code: invoke the internal dbt functions that the CLI uses under the hood
# to set up the dbt context and initialize the dbt adapter.

# Change the current working directory to the dbt project directory, for Click CLI to
# properly parse the dbt project.
current_path = os.getcwd()
os.chdir(path=self.project_dir)
ctx = check.not_none(make_context(args=args[1:]))

# When invoking Dagster from the CLI, we use Click. dbt also uses Click. We need
# sys.argv to reference the dbt CLI command that's running, not the Dagster CLI command.
sys.argv = list(args)

# Do the dbt setup.
preflight(lambda _: None)(ctx)
profile(lambda _: None)(ctx)
project(lambda _: None)(ctx)

# Retrieve the dbt credentials as specified from the profiles.yml file.
config = RuntimeConfig.from_parts(
ctx.obj["project"],
ctx.obj["profile"],
ctx.obj["flags"],
)

# Initialize the dbt adapter.
FACTORY.register_adapter(config)
adapter = cast(BaseAdapter, FACTORY.lookup_adapter(config.credentials.type))
os.chdir(path=current_path)

return DbtCliInvocation.run(
args=args,
env=env,
Expand All @@ -1040,6 +1109,7 @@ def my_dbt_op(dbt: DbtCliResource):
target_path=target_path,
raise_on_error=raise_on_error,
log_level=log_level,
adapter=adapter,
context=context,
)

Expand Down
8 changes: 4 additions & 4 deletions python_modules/libraries/dagster-dbt/dagster_dbt/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Expand All @@ -24,6 +25,9 @@

from .types import DbtOutput

if TYPE_CHECKING:
from dbt.graph.selector_spec import SelectionSpec

# dbt resource types that may be considered assets
ASSET_RESOURCE_TYPES = ["model", "seed", "snapshot"]

Expand Down Expand Up @@ -242,8 +246,6 @@ def select_unique_ids_from_manifest(
import dbt.graph.cli as graph_cli
import dbt.graph.selector as graph_selector
from dbt.contracts.graph.manifest import Manifest
from dbt.flags import GLOBAL_FLAGS
from dbt.graph.selector_spec import IndirectSelection, SelectionSpec
from networkx import DiGraph

manifest = Manifest.from_dict(manifest_json)
Expand All @@ -252,8 +254,6 @@ def select_unique_ids_from_manifest(
graph = graph_selector.Graph(DiGraph(incoming_graph_data=child_map))

# create a parsed selection from the select string
setattr(GLOBAL_FLAGS, "INDIRECT_SELECTION", IndirectSelection.Eager)
setattr(GLOBAL_FLAGS, "WARN_ERROR", True)
parsed_spec: SelectionSpec = graph_cli.parse_union([select], True)

if exclude:
Expand Down

0 comments on commit c7366b9

Please sign in to comment.