diff --git a/src/integrations/prefect-dbt/prefect_dbt/__init__.py b/src/integrations/prefect-dbt/prefect_dbt/__init__.py index 5e39db25c398..ec50612a3bef 100644 --- a/src/integrations/prefect-dbt/prefect_dbt/__init__.py +++ b/src/integrations/prefect-dbt/prefect_dbt/__init__.py @@ -1,5 +1,6 @@ from . import _version +from .core import PrefectDbtSettings from .cloud import DbtCloudCredentials, DbtCloudJob # noqa from .cli import ( # noqa DbtCliProfile, diff --git a/src/integrations/prefect-dbt/prefect_dbt/cli/configs/base.py b/src/integrations/prefect-dbt/prefect_dbt/cli/configs/base.py index 1104fc0b7a6c..f55aabd1e4ae 100644 --- a/src/integrations/prefect-dbt/prefect_dbt/cli/configs/base.py +++ b/src/integrations/prefect-dbt/prefect_dbt/cli/configs/base.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Type -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from typing_extensions import Self from prefect.blocks.core import Block @@ -40,15 +40,15 @@ class DbtConfigs(Block, abc.ABC): def _populate_configs_json( self, - configs_json: Dict[str, Any], - fields: Dict[str, Any], - model: BaseModel = None, - ) -> Dict[str, Any]: + configs_json: dict[str, Any], + fields: dict[str, Any], + model: Optional[BaseModel] = None, + ) -> dict[str, Any]: """ Recursively populate configs_json. """ # if allow_field_overrides is True keys from TargetConfigs take precedence - override_configs_json = {} + override_configs_json: dict[str, Any] = {} for field_name, field in fields.items(): if model is not None: @@ -93,7 +93,7 @@ def _populate_configs_json( configs_json.update(override_configs_json) return configs_json - def get_configs(self) -> Dict[str, Any]: + def get_configs(self) -> dict[str, Any]: """ Returns the dbt configs, likely used eventually for writing to profiles.yml. @@ -120,6 +120,19 @@ class BaseTargetConfigs(DbtConfigs, abc.ABC): ), ) + @model_validator(mode="before") + @classmethod + def handle_target_configs(cls, v: Any) -> Any: + """Handle target configs field aliasing during validation""" + if isinstance(v, dict): + if "schema_" in v: + v["schema"] = v.pop("schema_") + # Handle nested blocks + for value in v.values(): + if isinstance(value, dict) and "schema_" in value: + value["schema"] = value.pop("schema_") + return v + class TargetConfigs(BaseTargetConfigs): """ @@ -289,7 +302,7 @@ class GlobalConfigs(DbtConfigs): write_json: Optional[bool] = Field( default=None, description=( - "Determines whether dbt writes JSON artifacts to " "the target/ directory." + "Determines whether dbt writes JSON artifacts to the target/ directory." ), ) warn_error: Optional[bool] = Field( @@ -321,9 +334,7 @@ class GlobalConfigs(DbtConfigs): ) use_experimental_parser: Optional[bool] = Field( default=None, - description=( - "Opt into the latest experimental version " "of the static parser." - ), + description=("Opt into the latest experimental version of the static parser."), ) static_parser: Optional[bool] = Field( default=None, diff --git a/src/integrations/prefect-dbt/prefect_dbt/cli/credentials.py b/src/integrations/prefect-dbt/prefect_dbt/cli/credentials.py index 5fa4e9cd9163..715f24cfbf58 100644 --- a/src/integrations/prefect-dbt/prefect_dbt/cli/credentials.py +++ b/src/integrations/prefect-dbt/prefect_dbt/cli/credentials.py @@ -1,8 +1,8 @@ """Module containing credentials for interacting with dbt CLI""" -from typing import Any, Dict, Optional, Union +from typing import Annotated, Any, Dict, Optional, Union -from pydantic import Field +from pydantic import Discriminator, Field, Tag from prefect.blocks.core import Block from prefect_dbt.cli.configs import GlobalConfigs, TargetConfigs @@ -23,6 +23,18 @@ PostgresTargetConfigs = None +def target_configs_discriminator(v: Any) -> str: + """ + Discriminator function for target configs. Returns the block type slug. + """ + if isinstance(v, dict): + return v.get("block_type_slug", "dbt-cli-target-configs") + if isinstance(v, Block): + # When creating a new instance, we get a concrete Block type + return v.get_block_type_slug() + return "dbt-cli-target-configs" # Default to base type + + class DbtCliProfile(Block): """ Profile for use across dbt CLI tasks and flows. @@ -116,11 +128,14 @@ class DbtCliProfile(Block): target: str = Field( default=..., description="The default target your dbt project will use." ) - target_configs: Union[ - SnowflakeTargetConfigs, - BigQueryTargetConfigs, - PostgresTargetConfigs, - TargetConfigs, + target_configs: Annotated[ + Union[ + Annotated[SnowflakeTargetConfigs, Tag("dbt-cli-snowflake-target-configs")], + Annotated[BigQueryTargetConfigs, Tag("dbt-cli-bigquery-target-configs")], + Annotated[PostgresTargetConfigs, Tag("dbt-cli-postgres-target-configs")], + Annotated[TargetConfigs, Tag("dbt-cli-target-configs")], + ], + Discriminator(target_configs_discriminator), ] = Field( default=..., description=( diff --git a/src/integrations/prefect-dbt/prefect_dbt/core/__init__.py b/src/integrations/prefect-dbt/prefect_dbt/core/__init__.py new file mode 100644 index 000000000000..89fc638846f3 --- /dev/null +++ b/src/integrations/prefect-dbt/prefect_dbt/core/__init__.py @@ -0,0 +1,3 @@ +from prefect_dbt.core.settings import PrefectDbtSettings + +__all__ = ["PrefectDbtSettings"] diff --git a/src/integrations/prefect-dbt/prefect_dbt/core/settings.py b/src/integrations/prefect-dbt/prefect_dbt/core/settings.py new file mode 100644 index 000000000000..2ca9a346ccec --- /dev/null +++ b/src/integrations/prefect-dbt/prefect_dbt/core/settings.py @@ -0,0 +1,17 @@ +""" +A class for configuring or automatically discovering settings to be used with PrefectDbtRunner. +""" + +from pathlib import Path + +from dbt_common.events.base_types import EventLevel +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class PrefectDbtSettings(BaseSettings): + model_config = SettingsConfigDict(env_prefix="DBT_") + + profiles_dir: Path = Field(default=Path.home() / ".dbt") + project_dir: Path = Field(default_factory=Path.cwd) + log_level: EventLevel = Field(default=EventLevel.INFO) diff --git a/src/integrations/prefect-dbt/tests/cli/test_commands.py b/src/integrations/prefect-dbt/tests/cli/test_commands.py index 96add6efa380..60381e0f7ede 100644 --- a/src/integrations/prefect-dbt/tests/cli/test_commands.py +++ b/src/integrations/prefect-dbt/tests/cli/test_commands.py @@ -1,5 +1,4 @@ import datetime -import os from pathlib import Path from unittest.mock import MagicMock @@ -409,12 +408,14 @@ def test_flow(): @pytest.mark.usefixtures("dbt_runner_ls_result") -def test_trigger_dbt_cli_command_find_env(profiles_dir, dbt_cli_profile_bare): +def test_trigger_dbt_cli_command_find_env( + profiles_dir, dbt_cli_profile_bare, monkeypatch +): @flow def test_flow(): return trigger_dbt_cli_command("ls", dbt_cli_profile=dbt_cli_profile_bare) - os.environ["DBT_PROFILES_DIR"] = str(profiles_dir) + monkeypatch.setenv("DBT_PROFILES_DIR", str(profiles_dir)) result = test_flow() assert isinstance(result, dbtRunnerResult) @@ -474,9 +475,9 @@ def dbt_cli_profile(self): ) def test_find_valid_profiles_dir_default_env( - self, tmp_path, mock_open_process, mock_shell_process + self, tmp_path, mock_open_process, mock_shell_process, monkeypatch ): - os.environ["DBT_PROFILES_DIR"] = str(tmp_path) + monkeypatch.setenv("DBT_PROFILES_DIR", str(tmp_path)) (tmp_path / "profiles.yml").write_text("test") DbtCoreOperation(commands=["dbt debug"]).run() actual = str(mock_open_process.call_args_list[0][1]["env"]["DBT_PROFILES_DIR"]) diff --git a/src/integrations/prefect-dbt/tests/cli/test_credentials.py b/src/integrations/prefect-dbt/tests/cli/test_credentials.py index c9872e5bcdbf..6db07018aa93 100644 --- a/src/integrations/prefect-dbt/tests/cli/test_credentials.py +++ b/src/integrations/prefect-dbt/tests/cli/test_credentials.py @@ -1,10 +1,11 @@ import pytest from prefect_dbt.cli.credentials import DbtCliProfile, GlobalConfigs, TargetConfigs from pydantic import ValidationError +from typing_extensions import Literal @pytest.mark.parametrize("configs_type", ["dict", "model"]) -def test_dbt_cli_profile_init(configs_type): +def test_dbt_cli_profile_init(configs_type: Literal["dict", "model"]): target_configs = dict(type="snowflake", schema="schema") global_configs = dict(use_colors=False) if configs_type == "model": @@ -60,7 +61,9 @@ def test_dbt_cli_profile_get_profile(): "class_target_configs", ], ) -async def test_dbt_cli_profile_save_load_roundtrip(target_configs_request, request): +async def test_dbt_cli_profile_save_load_roundtrip( + target_configs_request: str, request: pytest.FixtureRequest +): target_configs = request.getfixturevalue(target_configs_request) dbt_cli_profile = DbtCliProfile( name="my_name", diff --git a/src/integrations/prefect-dbt/tests/core/test_settings.py b/src/integrations/prefect-dbt/tests/core/test_settings.py new file mode 100644 index 000000000000..3285ac35c675 --- /dev/null +++ b/src/integrations/prefect-dbt/tests/core/test_settings.py @@ -0,0 +1,37 @@ +from pathlib import Path + +from dbt_common.events.base_types import EventLevel +from prefect_dbt.core.settings import PrefectDbtSettings +from pytest import MonkeyPatch + + +def test_default_settings(): + settings = PrefectDbtSettings() + assert settings.profiles_dir == Path.home() / ".dbt" + assert settings.project_dir == Path.cwd() + assert settings.log_level == EventLevel.INFO + + +def test_custom_settings(): + custom_profiles_dir = Path("/custom/profiles/dir") + custom_project_dir = Path("/custom/project/dir") + + settings = PrefectDbtSettings( + profiles_dir=custom_profiles_dir, project_dir=custom_project_dir + ) + + assert settings.profiles_dir == custom_profiles_dir + assert settings.project_dir == custom_project_dir + + +def test_env_var_override(monkeypatch: MonkeyPatch): + env_profiles_dir = "/env/profiles/dir" + env_project_dir = "/env/project/dir" + + monkeypatch.setenv("DBT_PROFILES_DIR", env_profiles_dir) + monkeypatch.setenv("DBT_PROJECT_DIR", env_project_dir) + + settings = PrefectDbtSettings() + + assert settings.profiles_dir == Path(env_profiles_dir) + assert settings.project_dir == Path(env_project_dir) diff --git a/src/prefect/blocks/core.py b/src/prefect/blocks/core.py index 173b21aa92e1..748c625d81ea 100644 --- a/src/prefect/blocks/core.py +++ b/src/prefect/blocks/core.py @@ -669,7 +669,7 @@ def _generate_code_example(cls) -> str: module_str = ".".join(qualified_name.split(".")[:-1]) origin = cls.__pydantic_generic_metadata__.get("origin") or cls class_name = origin.__name__ - block_variable_name = f'{cls.get_block_type_slug().replace("-", "_")}_block' + block_variable_name = f"{cls.get_block_type_slug().replace('-', '_')}_block" return dedent( f"""\