diff --git a/examples/sample_project/profiles.yml b/examples/sample_project/profiles.yml index bc94a7a..50956a0 100644 --- a/examples/sample_project/profiles.yml +++ b/examples/sample_project/profiles.yml @@ -1,5 +1,5 @@ sample_project: - target: prod + target: test outputs: prod: type: duckdb @@ -10,3 +10,6 @@ sample_project: test: type: duckdb path: "{{ env_var('DUCKDB_DB_FILE') }}" + override_in_test: + type: duckdb + path: /does/not/exist/so/override.duckdb diff --git a/poetry.lock b/poetry.lock index 10f56b9..93eaef8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "agate" @@ -1453,16 +1453,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2429,7 +2419,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2437,15 +2426,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2462,7 +2444,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2470,7 +2451,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2981,7 +2961,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""} +greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""} typing-extensions = ">=4.2.0" [package.extras] @@ -3096,6 +3076,17 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2 doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + [[package]] name = "typing-extensions" version = "4.8.0" @@ -3374,4 +3365,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "eb1e115cb8e7707012928994e9d0badecb3474049ac72cb5e7caff96936f247b" +content-hash = "f3ec990c2ed0a50a50883371fb26874b4bfbafd19f99d736ba5c6ff1a100f65c" diff --git a/prefect_dbt_flow/dbt/__init__.py b/prefect_dbt_flow/dbt/__init__.py index e79cf81..bbf1732 100644 --- a/prefect_dbt_flow/dbt/__init__.py +++ b/prefect_dbt_flow/dbt/__init__.py @@ -38,9 +38,11 @@ class DbtProfile: Args: target: dbt target, usualy "dev" or "prod" + overrides: dbt profile overrides """ target: str + overrides: Optional[dict[str, str]] = None @dataclass diff --git a/prefect_dbt_flow/dbt/cli.py b/prefect_dbt_flow/dbt/cli.py index fa6cb73..b822f89 100644 --- a/prefect_dbt_flow/dbt/cli.py +++ b/prefect_dbt_flow/dbt/cli.py @@ -38,14 +38,18 @@ def dbt_ls( return cmd.run(" ".join(dbt_ls_cmd)) -def dbt_run(project: DbtProject, profile: DbtProfile, model: str) -> str: +def dbt_run( + project: DbtProject, + model: str, + profile: Optional[DbtProfile], +) -> str: """ Function that executes `dbt run` command Args: project: A class that represents a dbt project configuration. - profile: A class that represents a dbt profile configuration. model: Name of the model to run. + profile: A class that represents a dbt profile configuration. Returns: A string representing the output of the `dbt run` command. @@ -53,20 +57,26 @@ def dbt_run(project: DbtProject, profile: DbtProfile, model: str) -> str: dbt_run_cmd = [DBT_EXE, "run"] dbt_run_cmd.extend(["--project-dir", str(project.project_dir)]) dbt_run_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) - dbt_run_cmd.extend(["-t", profile.target]) dbt_run_cmd.extend(["-m", model]) + if profile: + dbt_run_cmd.extend(["-t", profile.target]) + return cmd.run(" ".join(dbt_run_cmd)) -def dbt_test(project: DbtProject, profile: DbtProfile, model: str) -> str: +def dbt_test( + project: DbtProject, + model: str, + profile: Optional[DbtProfile], +) -> str: """ Function that executes `dbt test` command Args: project: A class that represents a dbt project configuration. - profile: A class that represents a dbt profile configuration. model: Name of the model to run. + profile: A class that represents a dbt profile configuration. Returns: A string representing the output of the `dbt test` command. @@ -74,20 +84,27 @@ def dbt_test(project: DbtProject, profile: DbtProfile, model: str) -> str: dbt_test_cmd = [DBT_EXE, "test"] dbt_test_cmd.extend(["--project-dir", str(project.project_dir)]) dbt_test_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) - dbt_test_cmd.extend(["-t", profile.target]) dbt_test_cmd.extend(["-m", model]) + if profile: + dbt_test_cmd.extend(["-t", profile.target]) + return cmd.run(" ".join(dbt_test_cmd)) -def dbt_seed(project: DbtProject, profile: DbtProfile, seed: str) -> str: +def dbt_seed( + project: DbtProject, + seed: str, + profile: Optional[DbtProfile], +) -> str: """ Function that executes `dbt seed` command Args: project: A class that represents a dbt project configuration. - profile: A class that represents a dbt profile configuration. seed: Name of the seed to run. + profile: A class that represents a dbt profile configuration. + Returns: A string representing the output of the `dbt seed` command @@ -95,20 +112,27 @@ def dbt_seed(project: DbtProject, profile: DbtProfile, seed: str) -> str: dbt_seed_cmd = [DBT_EXE, "seed"] dbt_seed_cmd.extend(["--project-dir", str(project.project_dir)]) dbt_seed_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) - dbt_seed_cmd.extend(["-t", profile.target]) dbt_seed_cmd.extend(["--select", seed]) + if profile: + dbt_seed_cmd.extend(["-t", profile.target]) + return cmd.run(" ".join(dbt_seed_cmd)) -def dbt_snapshot(project: DbtProject, profile: DbtProfile, snapshot: str) -> str: +def dbt_snapshot( + project: DbtProject, + snapshot: str, + profile: Optional[DbtProfile], +) -> str: """ Function that executes `dbt snapshot` command Args: project: A class that represents a dbt project configuration. - profile: A class that represents a dbt profile configuration. snapshot: Name of the snapshot to run. + profile: A class that represents a dbt profile configuration. + Returns: A string representing the output of the `dbt snapshot` command @@ -116,7 +140,9 @@ def dbt_snapshot(project: DbtProject, profile: DbtProfile, snapshot: str) -> str dbt_seed_cmd = [DBT_EXE, "snapshot"] dbt_seed_cmd.extend(["--project-dir", str(project.project_dir)]) dbt_seed_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) - dbt_seed_cmd.extend(["-t", profile.target]) dbt_seed_cmd.extend(["--select", snapshot]) + if profile: + dbt_seed_cmd.extend(["-t", profile.target]) + return cmd.run(" ".join(dbt_seed_cmd)) diff --git a/prefect_dbt_flow/dbt/profile.py b/prefect_dbt_flow/dbt/profile.py new file mode 100644 index 0000000..be6ec96 --- /dev/null +++ b/prefect_dbt_flow/dbt/profile.py @@ -0,0 +1,75 @@ +"""Logic to override dbt profiles.yml""" +from contextlib import contextmanager +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Generator, Optional + +import yaml # type: ignore + +from prefect_dbt_flow.dbt import DbtProfile, DbtProject + + +@contextmanager +def override_profile( + project: DbtProject, profile: Optional[DbtProfile] +) -> Generator[DbtProject, None, None]: + """ + Override dbt profiles.yml with the given profile configuration. + + Args: + project: A class that represents a dbt project configuration. + profile: A class that represents a dbt profile configuration. + + Returns: + dbt_project: DbtProject. + """ + if not profile or not profile.overrides: + yield project + return + + dbt_project_name = _get_dbt_project_name(Path(project.project_dir)) + dbt_profile_path = Path(project.profiles_dir) / "profiles.yml" + + existing_profile_content = {} + + if dbt_profile_path.exists(): + existing_profile_content = ( + yaml.safe_load(dbt_profile_path.read_text()) + .get(dbt_project_name, {}) + .get("outputs", {}) + .get(profile.target, {}) + ) + + with TemporaryDirectory() as tmp_profiles_dir: + tmp_profiles_path = Path(tmp_profiles_dir) / "profiles.yml" + with open(tmp_profiles_path, "w") as tmp_profiles_file: + yaml.safe_dump( + { + dbt_project_name: { + "target": profile.target, + "outputs": { + profile.target: { + **existing_profile_content, + **profile.overrides, + } + }, + }, + }, + tmp_profiles_file, + ) + + yield DbtProject( + name=project.name, + project_dir=project.project_dir, + profiles_dir=tmp_profiles_dir, + ) + + +def _get_dbt_project_name(project_dir: Path) -> str: + dbt_project_path = project_dir / "dbt_project.yml" + + if not dbt_project_path.exists(): + raise ValueError(f"dbt_project.yml not found in {project_dir}") + + with open(dbt_project_path) as f: + return yaml.safe_load(f)["name"] diff --git a/prefect_dbt_flow/dbt/tasks.py b/prefect_dbt_flow/dbt/tasks.py index 666d0bb..8918901 100644 --- a/prefect_dbt_flow/dbt/tasks.py +++ b/prefect_dbt_flow/dbt/tasks.py @@ -5,6 +5,7 @@ from prefect.futures import PrefectFuture from prefect_dbt_flow.dbt import DbtNode, DbtProfile, DbtProject, DbtResourceType, cli +from prefect_dbt_flow.dbt.profile import override_profile DBT_RUN_EMOJI = "๐Ÿƒ" DBT_TEST_EMOJI = "๐Ÿงช" @@ -14,7 +15,7 @@ def _task_dbt_snapshot( project: DbtProject, - profile: DbtProfile, + profile: Optional[DbtProfile], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -43,15 +44,16 @@ def dbt_snapshot(): Returns: None """ - dbt_snapshot_output = cli.dbt_snapshot(project, profile, dbt_node.name) - get_run_logger().info(dbt_snapshot_output) + with override_profile(project, profile) as _project: + dbt_snapshot_output = cli.dbt_snapshot(_project, dbt_node.name, profile) + get_run_logger().info(dbt_snapshot_output) return dbt_snapshot def _task_dbt_seed( project: DbtProject, - profile: DbtProfile, + profile: Optional[DbtProfile], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -80,15 +82,16 @@ def dbt_seed(): Returns: None """ - dbt_seed_output = cli.dbt_seed(project, profile, dbt_node.name) - get_run_logger().info(dbt_seed_output) + with override_profile(project, profile) as _project: + dbt_seed_output = cli.dbt_seed(_project, dbt_node.name, profile) + get_run_logger().info(dbt_seed_output) return dbt_seed def _task_dbt_run( project: DbtProject, - profile: DbtProfile, + profile: Optional[DbtProfile], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -117,15 +120,16 @@ def dbt_run(): Returns: None """ - dbt_run_output = cli.dbt_run(project, profile, dbt_node.name) - get_run_logger().info(dbt_run_output) + with override_profile(project, profile) as _project: + dbt_run_output = cli.dbt_run(_project, dbt_node.name, profile) + get_run_logger().info(dbt_run_output) return dbt_run def _task_dbt_test( project: DbtProject, - profile: DbtProfile, + profile: Optional[DbtProfile], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -154,8 +158,9 @@ def dbt_test(): Returns: None """ - dbt_test_output = cli.dbt_test(project, profile, dbt_node.name) - get_run_logger().info(dbt_test_output) + with override_profile(project, profile) as _project: + dbt_test_output = cli.dbt_test(_project, dbt_node.name, profile) + get_run_logger().info(dbt_test_output) return dbt_test @@ -169,7 +174,7 @@ def dbt_test(): def generate_tasks_dag( project: DbtProject, - profile: DbtProfile, + profile: Optional[DbtProfile], dbt_graph: List[DbtNode], run_test_after_model: bool = False, ) -> None: diff --git a/pyproject.toml b/pyproject.toml index 43a8651..254d105 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ mkdocstrings = { extras = ["python"], version = "^0.23.0" } coverage = "^7.3.2" pytest-xdist = "^3.3.1" pytest-cov = "^4.1.0" +types-pyyaml = "^6.0.12.12" [build-system] requires = ["poetry-core"] diff --git a/tests/test_flow.py b/tests/test_flow.py index 53c8360..a3edeec 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -163,3 +163,82 @@ def test_flow_jaffle_shop(duckdb_db_file: Path): with duckdb.connect(str(duckdb_db_file)) as ddb: assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 9 + + +def test_flow_sample_project_overrides_new_profile(duckdb_db_file: Path): + dbt_project_path = SAMPLE_PROJECT_PATH + + my_dbt_flow = dbt_flow( + project=DbtProject( + name="sample_project", + project_dir=dbt_project_path, + profiles_dir=dbt_project_path, + ), + profile=DbtProfile( + target="something_else", + overrides={ + "type": "duckdb", + "path": str(duckdb_db_file.absolute()), + }, + ), + flow_kwargs={ + # Ensure only one process has access to the duckdb db + # file at the same time + "task_runner": SequentialTaskRunner(), + }, + ) + + my_dbt_flow() + + with duckdb.connect(str(duckdb_db_file)) as ddb: + assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4 + + +def test_flow_sample_project_overrides_existing_profile(duckdb_db_file: Path): + dbt_project_path = SAMPLE_PROJECT_PATH + + my_dbt_flow = dbt_flow( + project=DbtProject( + name="sample_project", + project_dir=dbt_project_path, + profiles_dir=dbt_project_path, + ), + profile=DbtProfile( + target="override_in_test", + overrides={ + "path": str(duckdb_db_file.absolute()), + }, + ), + flow_kwargs={ + # Ensure only one process has access to the duckdb db + # file at the same time + "task_runner": SequentialTaskRunner(), + }, + ) + + my_dbt_flow() + + with duckdb.connect(str(duckdb_db_file)) as ddb: + assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4 + + +def test_flow_sample_project_dont_specify_target(duckdb_db_file: Path): + dbt_project_path = SAMPLE_PROJECT_PATH + + my_dbt_flow = dbt_flow( + project=DbtProject( + name="sample_project", + project_dir=dbt_project_path, + profiles_dir=dbt_project_path, + ), + flow_kwargs={ + # Ensure only one process has access to the duckdb db + # file at the same time + "task_runner": SequentialTaskRunner(), + }, + ) + + my_dbt_flow() + + with duckdb.connect(str(duckdb_db_file)) as ddb: + assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4