Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow models to execute on different warehouses #488

Merged
merged 3 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,10 +925,10 @@ def get_open_for_model(
if not node:
return cls.open

def _open(connection: Connection) -> Connection:
def open_for_model(connection: Connection) -> Connection:
return cls._open(connection, node)

return _open
return open_for_model

@classmethod
def open(cls, connection: Connection) -> Connection:
Expand Down Expand Up @@ -961,9 +961,9 @@ def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Con
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

# If a model specifies a compute resource to use the http path
# If a model specifies a compute resource the http path
# may be different than the http_path property of creds.
http_path = get_http_path(node, creds)
http_path = _get_http_path(node, creds)

def connect() -> DatabricksSQLConnectionWrapper:
try:
Expand Down Expand Up @@ -1110,19 +1110,19 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id:
return msg


def get_compute_name(node: Optional[ResultNode]) -> Optional[str]:
def _get_compute_name(node: Optional[ResultNode]) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
compute_name = None
if node and node.config and node.config.extra:
compute_name = node.config.extra.get("databricks_compute", None)
if node and node.config:
compute_name = node.config.get("databricks_compute", None)
return compute_name


def get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]:
def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]:
# Get the http path of the compute resource specified in the node's config.
# If none is specified return the default path from creds.
compute_name = get_compute_name(node)
compute_name = _get_compute_name(node)
if not node or not compute_name:
return creds.http_path

Expand Down
50 changes: 50 additions & 0 deletions tests/functional/adapter/warehouse_per_model/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# CSV rows for the `source` relation that every model below selects from.
source = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""

# Table model that names its compute resource directly in the model's own
# config block.
target = """
{{config(materialized='table', databricks_compute='alternate_warehouse')}}

select * from {{ ref('source') }}
"""

# Table model whose config block sets no compute resource.
target2 = """
{{config(materialized='table')}}

select * from {{ ref('source') }}
"""

# Same SQL text as target2, kept under its own name so it can be registered
# as a separate model.
target3 = """
{{config(materialized='table')}}

select * from {{ ref('source') }}
"""

# schema.yml contents for the models above. Only `target2` carries a
# `config` entry, which sets `databricks_compute` for that model.
# The nesting is significant: `columns` and `config` must be indented under
# their model's list item, otherwise they parse as repeated top-level keys.
model_schema = """
version: 2

models:
  - name: target
    columns:
      - name: id
      - name: name
      - name: date
  - name: target2
    config:
      databricks_compute: alternate_warehouse
    columns:
      - name: id
      - name: name
      - name: date
  - name: target3
    columns:
      - name: id
      - name: name
      - name: date
"""

# CSV text identical to the `source` rows above; presumably the expected
# contents of a built target table — confirm against the tests that use it.
expected_target = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest
from dbt.tests import util
from tests.functional.adapter.warehouse_per_model import fixtures


class BaseWarehousePerModel:
    """Shared seed and model fixtures for the warehouse-per-model tests."""

    args_formatter = ""

    @pytest.fixture(scope="class")
    def seeds(self):
        """Provide the one seed file that every model selects from."""
        return {"source.csv": fixtures.source}

    @pytest.fixture(scope="class")
    def models(self):
        """Provide three top-level models, their schema file and a sub-folder."""
        return {
            "target.sql": fixtures.target,
            "target2.sql": fixtures.target2,
            "target3.sql": fixtures.target3,
            "schema.yml": fixtures.model_schema,
            # Nested mapping: a `special` folder holding target4, which
            # reuses target3's SQL.
            "special": {"target4.sql": fixtures.target3},
        }


class BaseSpecifyingCompute(BaseWarehousePerModel):
    """Base class for testing various ways to specify a warehouse."""

    def test_wpm(self, project):
        """Run each configured model and expect the unknown-compute error."""
        util.run_dbt(["seed"])
        # Since the profile doesn't define a compute resource named
        # 'alternate_warehouse', a run that correctly picks up the model's
        # warehouse setting must fail with an error naming that resource.
        for selected in project.test_config.get("model_names"):
            outcome = util.run_dbt(["run", "--select", selected], expect_pass=False)
            message = outcome.results[0].message
            assert "Compute resource alternate_warehouse does not exist" in message
            assert selected in message


class TestSpecifyingInConfigBlock(BaseSpecifyingCompute):
    """Select `target`, whose SQL config block names the compute resource."""

    @pytest.fixture(scope="class")
    def test_config(self):
        selected_models = ["target"]
        return {"model_names": selected_models}


class TestSpecifyingInSchemaYml(BaseSpecifyingCompute):
    """Select `target2`, whose compute resource is set in schema.yml."""

    @pytest.fixture(scope="class")
    def test_config(self):
        selected_models = ["target2"]
        return {"model_names": selected_models}


class TestSpecifyingForProjectModels(BaseSpecifyingCompute):
    """Select `target3` with the compute resource set under the project's `models` key."""

    @pytest.fixture(scope="class")
    def project_config_update(self):
        # Placed directly under `models`, i.e. not scoped to any sub-key.
        models_config = {"+databricks_compute": "alternate_warehouse"}
        return {"models": models_config}

    @pytest.fixture(scope="class")
    def test_config(self):
        selected_models = ["target3"]
        return {"model_names": selected_models}


class TestSpecifyingForProjectModelsInFolder(BaseSpecifyingCompute):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know whether a compute resource can be specified for models that carry a particular tag? This came up in a customer call: they would like to tag certain models as, for example, `heavy_compute`.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading the dbt docs on tags, I don't think that would work, which is probably fine. I think the named-compute approach gets us 95% of the way to what they would have if they could target compute by tag.

@pytest.fixture(scope="class")
def project_config_update(self):
    """Set the compute resource only under models -> test -> special."""
    folder_config = {"+databricks_compute": "alternate_warehouse"}
    return {"models": {"test": {"special": folder_config}}}

@pytest.fixture(scope="class")
def test_config(self):
    """Name the model that the shared test should run."""
    selected_models = ["target4"]
    return {"model_names": selected_models}


class TestWarehousePerModel(BaseWarehousePerModel):
    """Run a model against a compute resource that the profile does define."""

    @pytest.fixture(scope="class")
    def profiles_config_update(self, dbt_profile_target):
        # 'alternate_warehouse' points at the target's own http_path.
        # The target dict is extended in place and handed back as the very
        # same object rather than a copy.
        # NOTE(review): other fixtures may rely on that shared identity —
        # confirm before introducing a copy here.
        dbt_profile_target["compute"] = {
            "alternate_warehouse": {"http_path": dbt_profile_target["http_path"]}
        }
        return {
            "test": {
                "outputs": {"default": dbt_profile_target},
                "target": "default",
            }
        }

    def test_wpm(self, project):
        """Seed, build `target`, then compare it with the seed relation."""
        util.run_dbt(["seed"])
        util.run_dbt(["run", "--select", "target"])
        compared_relations = ["target", "source"]
        util.check_relations_equal(project.adapter, compared_relations)
178 changes: 178 additions & 0 deletions tests/unit/test_compute_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import unittest
import dbt.exceptions
from dbt.contracts.graph import nodes, model_config
from dbt.adapters.databricks import connections


class TestDatabricksConnectionHTTPPath(unittest.TestCase):
    """Test the various cases for determining a specified warehouse.

    Each test builds one node type (model, seed, snapshot) and walks it
    through the same sequence: no compute named on the node, a compute named
    that the credentials do not define (or define without an ``http_path``),
    and finally a compute that maps to an alternate path. The shared sequence
    lives in private helpers so the three tests cannot drift apart.
    """

    DEFAULT_PATH = "my_http_path"
    MISSING_COMPUTE_REGEX = "Compute resource foo does not exist, relation: a_relation"

    @staticmethod
    def _node_kwargs():
        """Return fresh constructor arguments shared by every node type."""
        # A new dict (and a new ``fqn`` list) per call, so nodes never share
        # mutable state.
        # NOTE(review): resource_type stays "model" even for the seed and
        # snapshot nodes — confirm whether that is intentional.
        return {
            "relation_name": "a_relation",
            "database": "database",
            "schema": "schema",
            "name": "node_name",
            "resource_type": "model",
            "package_name": "package",
            "path": "path",
            "original_file_path": "orig_path",
            "unique_id": "uniqueID",
            "fqn": [],
            "alias": "alias",
            "checksum": None,
        }

    def _new_creds(self):
        """Create credentials with the default path and check the no-node case."""
        creds = connections.DatabricksCredentials(http_path=self.DEFAULT_PATH)
        # Without a node, the path taken from creds is returned.
        self._assert_default_path(None, creds)
        return creds

    def _assert_default_path(self, node, creds):
        """Assert that ``node`` resolves to the credentials' default path."""
        self.assertEqual(self.DEFAULT_PATH, connections._get_http_path(node, creds))

    def _assert_missing_compute(self, node, creds):
        """Assert that resolving ``node`` raises the unknown-compute error."""
        with self.assertRaisesRegex(
            dbt.exceptions.DbtRuntimeError,
            self.MISSING_COMPUTE_REGEX,
        ):
            connections._get_http_path(node, creds)

    def _check_config_resolution(self, node, config, creds):
        """Attach ``config`` to ``node`` and walk through every compute case."""
        # A config that names no compute falls back to the default path.
        node.config = config
        self._assert_default_path(node, creds)

        node.config._extra = {}
        self._assert_default_path(node, creds)

        # Compute named on the node, before creds.compute has been assigned.
        node.config._extra["databricks_compute"] = "foo"
        self._assert_missing_compute(node, creds)

        # creds.compute assigned, but without an entry for "foo".
        creds.compute = {}
        self._assert_missing_compute(node, creds)

        # An entry for "foo" that carries no http_path raises the same error.
        creds.compute = {"foo": {}}
        self._assert_missing_compute(node, creds)

        # A complete entry finally overrides the default path.
        creds.compute = {"foo": {"http_path": "alternate_path"}}
        self.assertEqual("alternate_path", connections._get_http_path(node, creds))

    def test_get_http_path_model(self):
        creds = self._new_creds()
        node = nodes.ModelNode(**self._node_kwargs())
        # The node as constructed, with no config passed in.
        self._assert_default_path(node, creds)
        self._check_config_resolution(node, model_config.ModelConfig(), creds)

    def test_get_http_path_seed(self):
        creds = self._new_creds()
        node = nodes.SeedNode(**self._node_kwargs())
        # The node as constructed, with no config passed in.
        self._assert_default_path(node, creds)
        self._check_config_resolution(node, model_config.SeedConfig(), creds)

    def test_get_http_path_snapshot(self):
        creds = self._new_creds()
        # The snapshot node is built with config=None and is only resolved
        # once a SnapshotConfig has been attached.
        node = nodes.SnapshotNode(config=None, **self._node_kwargs())
        self._check_config_resolution(node, model_config.SnapshotConfig(), creds)
Loading