Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PrefectDbtSettings to prefect-dbt #16834

Merged
merged 8 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/integrations/prefect-dbt/prefect_dbt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import _version

from .core import PrefectDbtSettings
from .cloud import DbtCloudCredentials, DbtCloudJob # noqa
from .cli import ( # noqa
DbtCliProfile,
Expand Down
33 changes: 22 additions & 11 deletions src/integrations/prefect-dbt/prefect_dbt/cli/configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

from prefect.blocks.core import Block
Expand Down Expand Up @@ -40,15 +40,15 @@ class DbtConfigs(Block, abc.ABC):

def _populate_configs_json(
self,
configs_json: Dict[str, Any],
fields: Dict[str, Any],
model: BaseModel = None,
) -> Dict[str, Any]:
configs_json: dict[str, Any],
fields: dict[str, Any],
model: Optional[BaseModel] = None,
) -> dict[str, Any]:
"""
Recursively populate configs_json.
"""
# if allow_field_overrides is True keys from TargetConfigs take precedence
override_configs_json = {}
override_configs_json: dict[str, Any] = {}

for field_name, field in fields.items():
if model is not None:
Expand Down Expand Up @@ -93,7 +93,7 @@ def _populate_configs_json(
configs_json.update(override_configs_json)
return configs_json

def get_configs(self) -> Dict[str, Any]:
def get_configs(self) -> dict[str, Any]:
"""
Returns the dbt configs, likely used eventually for writing to profiles.yml.

Expand All @@ -120,6 +120,19 @@ class BaseTargetConfigs(DbtConfigs, abc.ABC):
),
)

@model_validator(mode="before")
@classmethod
def handle_target_configs(cls, v: Any) -> Any:
"""Handle target configs field aliasing during validation"""
if isinstance(v, dict):
if "schema_" in v:
v["schema"] = v.pop("schema_")
# Handle nested blocks
for value in v.values():
if isinstance(value, dict) and "schema_" in value:
value["schema"] = value.pop("schema_")
return v


class TargetConfigs(BaseTargetConfigs):
"""
Expand Down Expand Up @@ -289,7 +302,7 @@ class GlobalConfigs(DbtConfigs):
write_json: Optional[bool] = Field(
default=None,
description=(
"Determines whether dbt writes JSON artifacts to " "the target/ directory."
"Determines whether dbt writes JSON artifacts to the target/ directory."
),
)
warn_error: Optional[bool] = Field(
Expand Down Expand Up @@ -321,9 +334,7 @@ class GlobalConfigs(DbtConfigs):
)
use_experimental_parser: Optional[bool] = Field(
default=None,
description=(
"Opt into the latest experimental version " "of the static parser."
),
description=("Opt into the latest experimental version of the static parser."),
)
static_parser: Optional[bool] = Field(
default=None,
Expand Down
29 changes: 22 additions & 7 deletions src/integrations/prefect-dbt/prefect_dbt/cli/credentials.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Module containing credentials for interacting with dbt CLI"""

from typing import Any, Dict, Optional, Union
from typing import Annotated, Any, Dict, Optional, Union

from pydantic import Field
from pydantic import Discriminator, Field, Tag

from prefect.blocks.core import Block
from prefect_dbt.cli.configs import GlobalConfigs, TargetConfigs
Expand All @@ -23,6 +23,18 @@
PostgresTargetConfigs = None


def target_configs_discriminator(v: Any) -> str:
"""
Discriminator function for target configs. Returns the block type slug.
"""
if isinstance(v, dict):
return v.get("block_type_slug", "dbt-cli-target-configs")
if isinstance(v, Block):
# When creating a new instance, we get a concrete Block type
return v.get_block_type_slug()
return "dbt-cli-target-configs" # Default to base type


class DbtCliProfile(Block):
"""
Profile for use across dbt CLI tasks and flows.
Expand Down Expand Up @@ -116,11 +128,14 @@ class DbtCliProfile(Block):
target: str = Field(
default=..., description="The default target your dbt project will use."
)
target_configs: Union[
SnowflakeTargetConfigs,
BigQueryTargetConfigs,
PostgresTargetConfigs,
TargetConfigs,
target_configs: Annotated[
Union[
Annotated[SnowflakeTargetConfigs, Tag("dbt-cli-snowflake-target-configs")],
Annotated[BigQueryTargetConfigs, Tag("dbt-cli-bigquery-target-configs")],
Annotated[PostgresTargetConfigs, Tag("dbt-cli-postgres-target-configs")],
Annotated[TargetConfigs, Tag("dbt-cli-target-configs")],
],
Discriminator(target_configs_discriminator),
] = Field(
default=...,
description=(
Expand Down
3 changes: 3 additions & 0 deletions src/integrations/prefect-dbt/prefect_dbt/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from prefect_dbt.core.settings import PrefectDbtSettings

__all__ = ["PrefectDbtSettings"]
17 changes: 17 additions & 0 deletions src/integrations/prefect-dbt/prefect_dbt/core/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
A class for configuring or automatically discovering settings to be used with PrefectDbtRunner.
"""

from pathlib import Path

from dbt_common.events.base_types import EventLevel
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class PrefectDbtSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="DBT_")

profiles_dir: Path = Field(default=Path.home() / ".dbt")
project_dir: Path = Field(default_factory=Path.cwd)
log_level: EventLevel = Field(default=EventLevel.INFO)
11 changes: 6 additions & 5 deletions src/integrations/prefect-dbt/tests/cli/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
import os
from pathlib import Path
from unittest.mock import MagicMock

Expand Down Expand Up @@ -409,12 +408,14 @@ def test_flow():


@pytest.mark.usefixtures("dbt_runner_ls_result")
def test_trigger_dbt_cli_command_find_env(profiles_dir, dbt_cli_profile_bare):
def test_trigger_dbt_cli_command_find_env(
profiles_dir, dbt_cli_profile_bare, monkeypatch
):
@flow
def test_flow():
return trigger_dbt_cli_command("ls", dbt_cli_profile=dbt_cli_profile_bare)

os.environ["DBT_PROFILES_DIR"] = str(profiles_dir)
monkeypatch.setenv("DBT_PROFILES_DIR", str(profiles_dir))
result = test_flow()
assert isinstance(result, dbtRunnerResult)

Expand Down Expand Up @@ -474,9 +475,9 @@ def dbt_cli_profile(self):
)

def test_find_valid_profiles_dir_default_env(
self, tmp_path, mock_open_process, mock_shell_process
self, tmp_path, mock_open_process, mock_shell_process, monkeypatch
):
os.environ["DBT_PROFILES_DIR"] = str(tmp_path)
monkeypatch.setenv("DBT_PROFILES_DIR", str(tmp_path))
(tmp_path / "profiles.yml").write_text("test")
DbtCoreOperation(commands=["dbt debug"]).run()
actual = str(mock_open_process.call_args_list[0][1]["env"]["DBT_PROFILES_DIR"])
Expand Down
7 changes: 5 additions & 2 deletions src/integrations/prefect-dbt/tests/cli/test_credentials.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from prefect_dbt.cli.credentials import DbtCliProfile, GlobalConfigs, TargetConfigs
from pydantic import ValidationError
from typing_extensions import Literal


@pytest.mark.parametrize("configs_type", ["dict", "model"])
def test_dbt_cli_profile_init(configs_type):
def test_dbt_cli_profile_init(configs_type: Literal["dict", "model"]):
target_configs = dict(type="snowflake", schema="schema")
global_configs = dict(use_colors=False)
if configs_type == "model":
Expand Down Expand Up @@ -60,7 +61,9 @@ def test_dbt_cli_profile_get_profile():
"class_target_configs",
],
)
async def test_dbt_cli_profile_save_load_roundtrip(target_configs_request, request):
async def test_dbt_cli_profile_save_load_roundtrip(
target_configs_request: str, request: pytest.FixtureRequest
):
target_configs = request.getfixturevalue(target_configs_request)
dbt_cli_profile = DbtCliProfile(
name="my_name",
Expand Down
37 changes: 37 additions & 0 deletions src/integrations/prefect-dbt/tests/core/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pathlib import Path

from dbt_common.events.base_types import EventLevel
from prefect_dbt.core.settings import PrefectDbtSettings
from pytest import MonkeyPatch


def test_default_settings():
settings = PrefectDbtSettings()
assert settings.profiles_dir == Path.home() / ".dbt"
assert settings.project_dir == Path.cwd()
assert settings.log_level == EventLevel.INFO


def test_custom_settings():
custom_profiles_dir = Path("/custom/profiles/dir")
custom_project_dir = Path("/custom/project/dir")

settings = PrefectDbtSettings(
profiles_dir=custom_profiles_dir, project_dir=custom_project_dir
)

assert settings.profiles_dir == custom_profiles_dir
assert settings.project_dir == custom_project_dir


def test_env_var_override(monkeypatch: MonkeyPatch):
env_profiles_dir = "/env/profiles/dir"
env_project_dir = "/env/project/dir"

monkeypatch.setenv("DBT_PROFILES_DIR", env_profiles_dir)
monkeypatch.setenv("DBT_PROJECT_DIR", env_project_dir)

settings = PrefectDbtSettings()

assert settings.profiles_dir == Path(env_profiles_dir)
assert settings.project_dir == Path(env_project_dir)
2 changes: 1 addition & 1 deletion src/prefect/blocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def _generate_code_example(cls) -> str:
module_str = ".".join(qualified_name.split(".")[:-1])
origin = cls.__pydantic_generic_metadata__.get("origin") or cls
class_name = origin.__name__
block_variable_name = f'{cls.get_block_type_slug().replace("-", "_")}_block'
block_variable_name = f"{cls.get_block_type_slug().replace('-', '_')}_block"

return dedent(
f"""\
Expand Down
Loading