Skip to content

Commit

Permalink
Feature: dbt profile args (#18)
Browse files Browse the repository at this point in the history
* profile overrides

* make profile optional  everywhere

---------

Co-authored-by: Nico Gelders <nicolas.gelders@telenetgroup.be>
nicogelders and Nico Gelders authored Oct 23, 2023

Verified

This commit was signed with the committer’s verified signature.
farfromrefug farfromrefuge
1 parent ec1dc3b commit 4072b71
Showing 8 changed files with 231 additions and 49 deletions.
5 changes: 4 additions & 1 deletion examples/sample_project/profiles.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
sample_project:
target: prod
target: test
outputs:
prod:
type: duckdb
@@ -10,3 +10,6 @@ sample_project:
test:
type: duckdb
path: "{{ env_var('DUCKDB_DB_FILE') }}"
override_in_test:
type: duckdb
path: /does/not/exist/so/override.duckdb
37 changes: 14 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions prefect_dbt_flow/dbt/__init__.py
Original file line number Diff line number Diff line change
@@ -38,9 +38,11 @@ class DbtProfile:
Args:
target: dbt target, usualy "dev" or "prod"
overrides: dbt profile overrides
"""

target: str
overrides: Optional[dict[str, str]] = None


@dataclass
50 changes: 38 additions & 12 deletions prefect_dbt_flow/dbt/cli.py
Original file line number Diff line number Diff line change
@@ -38,85 +38,111 @@ def dbt_ls(
return cmd.run(" ".join(dbt_ls_cmd))


def dbt_run(project: DbtProject, profile: DbtProfile, model: str) -> str:
def dbt_run(
project: DbtProject,
model: str,
profile: Optional[DbtProfile],
) -> str:
"""
Function that executes `dbt run` command
Args:
project: A class that represents a dbt project configuration.
profile: A class that represents a dbt profile configuration.
model: Name of the model to run.
profile: A class that represents a dbt profile configuration.
Returns:
A string representing the output of the `dbt run` command.
"""
dbt_run_cmd = [DBT_EXE, "run"]
dbt_run_cmd.extend(["--project-dir", str(project.project_dir)])
dbt_run_cmd.extend(["--profiles-dir", str(project.profiles_dir)])
dbt_run_cmd.extend(["-t", profile.target])
dbt_run_cmd.extend(["-m", model])

if profile:
dbt_run_cmd.extend(["-t", profile.target])

return cmd.run(" ".join(dbt_run_cmd))


def dbt_test(project: DbtProject, profile: DbtProfile, model: str) -> str:
def dbt_test(
project: DbtProject,
model: str,
profile: Optional[DbtProfile],
) -> str:
"""
Function that executes `dbt test` command
Args:
project: A class that represents a dbt project configuration.
profile: A class that represents a dbt profile configuration.
model: Name of the model to run.
profile: A class that represents a dbt profile configuration.
Returns:
A string representing the output of the `dbt test` command.
"""
dbt_test_cmd = [DBT_EXE, "test"]
dbt_test_cmd.extend(["--project-dir", str(project.project_dir)])
dbt_test_cmd.extend(["--profiles-dir", str(project.profiles_dir)])
dbt_test_cmd.extend(["-t", profile.target])
dbt_test_cmd.extend(["-m", model])

if profile:
dbt_test_cmd.extend(["-t", profile.target])

return cmd.run(" ".join(dbt_test_cmd))


def dbt_seed(project: DbtProject, profile: DbtProfile, seed: str) -> str:
def dbt_seed(
project: DbtProject,
seed: str,
profile: Optional[DbtProfile],
) -> str:
"""
Function that executes `dbt seed` command
Args:
project: A class that represents a dbt project configuration.
profile: A class that represents a dbt profile configuration.
seed: Name of the seed to run.
profile: A class that represents a dbt profile configuration.
Returns:
A string representing the output of the `dbt seed` command
"""
dbt_seed_cmd = [DBT_EXE, "seed"]
dbt_seed_cmd.extend(["--project-dir", str(project.project_dir)])
dbt_seed_cmd.extend(["--profiles-dir", str(project.profiles_dir)])
dbt_seed_cmd.extend(["-t", profile.target])
dbt_seed_cmd.extend(["--select", seed])

if profile:
dbt_seed_cmd.extend(["-t", profile.target])

return cmd.run(" ".join(dbt_seed_cmd))


def dbt_snapshot(project: DbtProject, profile: DbtProfile, snapshot: str) -> str:
def dbt_snapshot(
project: DbtProject,
snapshot: str,
profile: Optional[DbtProfile],
) -> str:
"""
Function that executes `dbt snapshot` command
Args:
project: A class that represents a dbt project configuration.
profile: A class that represents a dbt profile configuration.
snapshot: Name of the snapshot to run.
profile: A class that represents a dbt profile configuration.
Returns:
A string representing the output of the `dbt snapshot` command
"""
dbt_seed_cmd = [DBT_EXE, "snapshot"]
dbt_seed_cmd.extend(["--project-dir", str(project.project_dir)])
dbt_seed_cmd.extend(["--profiles-dir", str(project.profiles_dir)])
dbt_seed_cmd.extend(["-t", profile.target])
dbt_seed_cmd.extend(["--select", snapshot])

if profile:
dbt_seed_cmd.extend(["-t", profile.target])

return cmd.run(" ".join(dbt_seed_cmd))
75 changes: 75 additions & 0 deletions prefect_dbt_flow/dbt/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Logic to override dbt profiles.yml"""
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator, Optional

import yaml # type: ignore

from prefect_dbt_flow.dbt import DbtProfile, DbtProject


@contextmanager
def override_profile(
project: DbtProject, profile: Optional[DbtProfile]
) -> Generator[DbtProject, None, None]:
"""
Override dbt profiles.yml with the given profile configuration.
Args:
project: A class that represents a dbt project configuration.
profile: A class that represents a dbt profile configuration.
Returns:
dbt_project: DbtProject.
"""
if not profile or not profile.overrides:
yield project
return

dbt_project_name = _get_dbt_project_name(Path(project.project_dir))
dbt_profile_path = Path(project.profiles_dir) / "profiles.yml"

existing_profile_content = {}

if dbt_profile_path.exists():
existing_profile_content = (
yaml.safe_load(dbt_profile_path.read_text())
.get(dbt_project_name, {})
.get("outputs", {})
.get(profile.target, {})
)

with TemporaryDirectory() as tmp_profiles_dir:
tmp_profiles_path = Path(tmp_profiles_dir) / "profiles.yml"
with open(tmp_profiles_path, "w") as tmp_profiles_file:
yaml.safe_dump(
{
dbt_project_name: {
"target": profile.target,
"outputs": {
profile.target: {
**existing_profile_content,
**profile.overrides,
}
},
},
},
tmp_profiles_file,
)

yield DbtProject(
name=project.name,
project_dir=project.project_dir,
profiles_dir=tmp_profiles_dir,
)


def _get_dbt_project_name(project_dir: Path) -> str:
dbt_project_path = project_dir / "dbt_project.yml"

if not dbt_project_path.exists():
raise ValueError(f"dbt_project.yml not found in {project_dir}")

with open(dbt_project_path) as f:
return yaml.safe_load(f)["name"]
31 changes: 18 additions & 13 deletions prefect_dbt_flow/dbt/tasks.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from prefect.futures import PrefectFuture

from prefect_dbt_flow.dbt import DbtNode, DbtProfile, DbtProject, DbtResourceType, cli
from prefect_dbt_flow.dbt.profile import override_profile

DBT_RUN_EMOJI = "🏃"
DBT_TEST_EMOJI = "🧪"
@@ -14,7 +15,7 @@

def _task_dbt_snapshot(
project: DbtProject,
profile: DbtProfile,
profile: Optional[DbtProfile],
dbt_node: DbtNode,
task_kwargs: Optional[Dict] = None,
) -> Task:
@@ -43,15 +44,16 @@ def dbt_snapshot():
Returns:
None
"""
dbt_snapshot_output = cli.dbt_snapshot(project, profile, dbt_node.name)
get_run_logger().info(dbt_snapshot_output)
with override_profile(project, profile) as _project:
dbt_snapshot_output = cli.dbt_snapshot(_project, dbt_node.name, profile)
get_run_logger().info(dbt_snapshot_output)

return dbt_snapshot


def _task_dbt_seed(
project: DbtProject,
profile: DbtProfile,
profile: Optional[DbtProfile],
dbt_node: DbtNode,
task_kwargs: Optional[Dict] = None,
) -> Task:
@@ -80,15 +82,16 @@ def dbt_seed():
Returns:
None
"""
dbt_seed_output = cli.dbt_seed(project, profile, dbt_node.name)
get_run_logger().info(dbt_seed_output)
with override_profile(project, profile) as _project:
dbt_seed_output = cli.dbt_seed(_project, dbt_node.name, profile)
get_run_logger().info(dbt_seed_output)

return dbt_seed


def _task_dbt_run(
project: DbtProject,
profile: DbtProfile,
profile: Optional[DbtProfile],
dbt_node: DbtNode,
task_kwargs: Optional[Dict] = None,
) -> Task:
@@ -117,15 +120,16 @@ def dbt_run():
Returns:
None
"""
dbt_run_output = cli.dbt_run(project, profile, dbt_node.name)
get_run_logger().info(dbt_run_output)
with override_profile(project, profile) as _project:
dbt_run_output = cli.dbt_run(_project, dbt_node.name, profile)
get_run_logger().info(dbt_run_output)

return dbt_run


def _task_dbt_test(
project: DbtProject,
profile: DbtProfile,
profile: Optional[DbtProfile],
dbt_node: DbtNode,
task_kwargs: Optional[Dict] = None,
) -> Task:
@@ -154,8 +158,9 @@ def dbt_test():
Returns:
None
"""
dbt_test_output = cli.dbt_test(project, profile, dbt_node.name)
get_run_logger().info(dbt_test_output)
with override_profile(project, profile) as _project:
dbt_test_output = cli.dbt_test(_project, dbt_node.name, profile)
get_run_logger().info(dbt_test_output)

return dbt_test

@@ -169,7 +174,7 @@ def dbt_test():

def generate_tasks_dag(
project: DbtProject,
profile: DbtProfile,
profile: Optional[DbtProfile],
dbt_graph: List[DbtNode],
run_test_after_model: bool = False,
) -> None:
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ mkdocstrings = { extras = ["python"], version = "^0.23.0" }
coverage = "^7.3.2"
pytest-xdist = "^3.3.1"
pytest-cov = "^4.1.0"
types-pyyaml = "^6.0.12.12"

[build-system]
requires = ["poetry-core"]
79 changes: 79 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
@@ -163,3 +163,82 @@ def test_flow_jaffle_shop(duckdb_db_file: Path):

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 9


def test_flow_sample_project_overrides_new_profile(duckdb_db_file: Path):
dbt_project_path = SAMPLE_PROJECT_PATH

my_dbt_flow = dbt_flow(
project=DbtProject(
name="sample_project",
project_dir=dbt_project_path,
profiles_dir=dbt_project_path,
),
profile=DbtProfile(
target="something_else",
overrides={
"type": "duckdb",
"path": str(duckdb_db_file.absolute()),
},
),
flow_kwargs={
# Ensure only one process has access to the duckdb db
# file at the same time
"task_runner": SequentialTaskRunner(),
},
)

my_dbt_flow()

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4


def test_flow_sample_project_overrides_existing_profile(duckdb_db_file: Path):
dbt_project_path = SAMPLE_PROJECT_PATH

my_dbt_flow = dbt_flow(
project=DbtProject(
name="sample_project",
project_dir=dbt_project_path,
profiles_dir=dbt_project_path,
),
profile=DbtProfile(
target="override_in_test",
overrides={
"path": str(duckdb_db_file.absolute()),
},
),
flow_kwargs={
# Ensure only one process has access to the duckdb db
# file at the same time
"task_runner": SequentialTaskRunner(),
},
)

my_dbt_flow()

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4


def test_flow_sample_project_dont_specify_target(duckdb_db_file: Path):
dbt_project_path = SAMPLE_PROJECT_PATH

my_dbt_flow = dbt_flow(
project=DbtProject(
name="sample_project",
project_dir=dbt_project_path,
profiles_dir=dbt_project_path,
),
flow_kwargs={
# Ensure only one process has access to the duckdb db
# file at the same time
"task_runner": SequentialTaskRunner(),
},
)

my_dbt_flow()

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4

0 comments on commit 4072b71

Please sign in to comment.