Skip to content

Commit

Permalink
Tests for warehouse-per-model
Browse files Browse the repository at this point in the history
Signed-off-by: Raymond Cypher <[email protected]>
  • Loading branch information
rcypher-databricks committed Nov 8, 2023
1 parent 81fa7ba commit e1f89c0
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 9 deletions.
18 changes: 9 additions & 9 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,10 +925,10 @@ def get_open_for_model(
if not node:
return cls.open

def _open(connection: Connection) -> Connection:
def open_for_model(connection: Connection) -> Connection:
return cls._open(connection, node)

return _open
return open_for_model

@classmethod
def open(cls, connection: Connection) -> Connection:
Expand Down Expand Up @@ -961,9 +961,9 @@ def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Con
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

# If a model specifies a compute resource to use the http path
# If a model specifies a compute resource the http path
# may be different than the http_path property of creds.
http_path = get_http_path(node, creds)
http_path = _get_http_path(node, creds)

def connect() -> DatabricksSQLConnectionWrapper:
try:
Expand Down Expand Up @@ -1110,19 +1110,19 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id:
return msg


def get_compute_name(node: Optional[ResultNode]) -> Optional[str]:
def _get_compute_name(node: Optional[ResultNode]) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
compute_name = None
if node and node.config and node.config.extra:
compute_name = node.config.extra.get("databricks_compute", None)
if node and node.config:
compute_name = node.config.get("databricks_compute", None)
return compute_name


def get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]:
def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]:
# Get the http path of the compute resource specified in the node's config.
# If none is specified return the default path from creds.
compute_name = get_compute_name(node)
compute_name = _get_compute_name(node)
if not node or not compute_name:
return creds.http_path

Expand Down
50 changes: 50 additions & 0 deletions tests/functional/adapter/warehouse_per_model/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Seed data loaded as the ``source`` relation.
source = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""

# Model that names its compute resource directly in the SQL config block.
target = """
{{config(materialized='table', databricks_compute='alternate_warehouse')}}
select * from {{ ref('source') }}
"""

# Model with no compute in its SQL; ``model_schema`` assigns one to ``target2``.
target2 = """
{{config(materialized='table')}}
select * from {{ ref('source') }}
"""

# Model with no compute in its SQL and none in its ``model_schema`` entry.
target3 = """
{{config(materialized='table')}}
select * from {{ ref('source') }}
"""

# Schema file for the three models. The nesting matters: ``columns`` and
# ``config`` must sit under their model entry, otherwise they parse as
# top-level keys and ``databricks_compute`` is never attached to ``target2``.
model_schema = """
version: 2
models:
  - name: target
    columns:
      - name: id
      - name: name
      - name: date
  - name: target2
    config:
      databricks_compute: alternate_warehouse
    columns:
      - name: id
      - name: name
      - name: date
  - name: target3
    columns:
      - name: id
      - name: name
      - name: date
"""

# Expected contents of the built target relation (identical to the seed).
expected_target = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest
from dbt.tests import util
from tests.functional.adapter.warehouse_per_model import fixtures


class BaseWarehousePerModel:
    """Shared seed and model fixtures for the warehouse-per-model tests."""

    args_formatter = ""

    @pytest.fixture(scope="class")
    def seeds(self):
        """Provide the single CSV seed, ``source.csv``."""
        return {"source.csv": fixtures.source}

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the model files plus the schema file.

        ``target4.sql`` reuses the ``target3`` SQL but is nested under the
        ``special`` key (presumably a sub-folder of the models directory).
        """
        return {
            "target.sql": fixtures.target,
            "target2.sql": fixtures.target2,
            "target3.sql": fixtures.target3,
            "schema.yml": fixtures.model_schema,
            "special": {"target4.sql": fixtures.target3},
        }


class BaseSpecifyingCompute(BaseWarehousePerModel):
    """Base class for testing various ways to specify a warehouse."""

    def test_wpm(self, project):
        """Seed, then run each configured model and expect the missing-compute error."""
        util.run_dbt(["seed"])
        for model_name in project.test_config.get("model_names"):
            # The profile defines no compute resource called
            # 'alternate_warehouse', so a run that correctly picks up the
            # model's warehouse setting has to fail with this error.
            run_result = util.run_dbt(["run", "--select", model_name], expect_pass=False)
            message = run_result.results[0].message
            assert "Compute resource alternate_warehouse does not exist" in message
            assert model_name in message


class TestSpecifyingInConfigBlock(BaseSpecifyingCompute):
    """Run the check against ``target``."""

    @pytest.fixture(scope="class")
    def test_config(self):
        """Name the models the inherited ``test_wpm`` should run."""
        model_names = ["target"]
        return {"model_names": model_names}


class TestSpecifyingInSchemaYml(BaseSpecifyingCompute):
    """Run the check against ``target2``."""

    @pytest.fixture(scope="class")
    def test_config(self):
        """Name the models the inherited ``test_wpm`` should run."""
        model_names = ["target2"]
        return {"model_names": model_names}


class TestSpecifyingForProjectModels(BaseSpecifyingCompute):
    """Run the check against ``target3`` with the compute set under ``models``."""

    @pytest.fixture(scope="class")
    def project_config_update(self):
        """Set ``+databricks_compute`` directly under the ``models`` key."""
        models_section = {"+databricks_compute": "alternate_warehouse"}
        return {"models": models_section}

    @pytest.fixture(scope="class")
    def test_config(self):
        """Name the models the inherited ``test_wpm`` should run."""
        model_names = ["target3"]
        return {"model_names": model_names}


class TestSpecifyingForProjectModelsInFolder(BaseSpecifyingCompute):
    """Run the check against ``target4`` with the compute set under ``test.special``."""

    @pytest.fixture(scope="class")
    def project_config_update(self):
        """Set ``+databricks_compute`` under ``models`` -> ``test`` -> ``special``."""
        special_section = {"+databricks_compute": "alternate_warehouse"}
        return {"models": {"test": {"special": special_section}}}

    @pytest.fixture(scope="class")
    def test_config(self):
        """Name the models the inherited ``test_wpm`` should run."""
        model_names = ["target4"]
        return {"model_names": model_names}


class TestWarehousePerModel(BaseWarehousePerModel):
    """Build ``target`` when the profile does define ``alternate_warehouse``."""

    @pytest.fixture(scope="class")
    def profiles_config_update(self, dbt_profile_target):
        """Add an ``alternate_warehouse`` compute entry that reuses the target's own http_path.

        The incoming target dict is updated in place and that same object is
        placed in the returned profile; it is intentionally not copied.
        """
        dbt_profile_target["compute"] = {
            "alternate_warehouse": {"http_path": dbt_profile_target["http_path"]}
        }
        return {
            "test": {
                "outputs": {"default": dbt_profile_target},
                "target": "default",
            }
        }

    def test_wpm(self, project):
        """Seed, run ``target`` and compare the ``target`` and ``source`` relations."""
        for dbt_args in (["seed"], ["run", "--select", "target"]):
            util.run_dbt(dbt_args)
        util.check_relations_equal(project.adapter, ["target", "source"])
178 changes: 178 additions & 0 deletions tests/unit/test_compute_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import unittest
import dbt.exceptions
from dbt.contracts.graph import nodes, model_config
from dbt.adapters.databricks import connections


class TestDatabricksConnectionHTTPPath(unittest.TestCase):
    """Test the various cases for determining a specified warehouse."""

    # http_path carried by the credentials; returned whenever no compute applies.
    _DEFAULT_PATH = "my_http_path"
    # Error text expected while compute "foo" cannot be resolved to an http_path.
    _MISSING_COMPUTE = "Compute resource foo does not exist, relation: a_relation"

    @staticmethod
    def _node_kwargs():
        """Return a fresh set of constructor arguments shared by every node under test."""
        return dict(
            relation_name="a_relation",
            database="database",
            schema="schema",
            name="node_name",
            resource_type="model",
            package_name="package",
            path="path",
            original_file_path="orig_path",
            unique_id="uniqueID",
            fqn=[],
            alias="alias",
            checksum=None,
        )

    def _assert_default_path(self, node, creds):
        """Assert that the credentials' own http_path is returned for ``node``."""
        self.assertEqual(self._DEFAULT_PATH, connections._get_http_path(node, creds))

    def _assert_missing_compute(self, node, creds):
        """Assert that resolving the path for ``node`` raises the missing-compute error."""
        with self.assertRaisesRegex(dbt.exceptions.DbtRuntimeError, self._MISSING_COMPUTE):
            connections._get_http_path(node, creds)

    def _new_creds(self):
        """Create default credentials and check the no-node case returns the default path."""
        creds = connections.DatabricksCredentials(http_path=self._DEFAULT_PATH)
        self._assert_default_path(None, creds)
        return creds

    def _check_compute_resolution(self, node, config_cls, creds):
        """Walk ``node`` from a fresh ``config_cls()`` config to a fully defined compute."""
        # No compute requested: fresh config, then an explicitly empty _extra.
        node.config = config_cls()
        self._assert_default_path(node, creds)

        node.config._extra = {}
        self._assert_default_path(node, creds)

        # Compute "foo" requested while creds.compute is still at its default.
        node.config._extra["databricks_compute"] = "foo"
        self._assert_missing_compute(node, creds)

        # An empty compute mapping, and an entry without an http_path, both fail.
        for incomplete in ({}, {"foo": {}}):
            creds.compute = incomplete
            self._assert_missing_compute(node, creds)

        # A complete entry finally yields the alternate path.
        creds.compute = {"foo": {"http_path": "alternate_path"}}
        self.assertEqual("alternate_path", connections._get_http_path(node, creds))

    def test_get_http_path_model(self):
        """Resolve the http path for a ModelNode."""
        creds = self._new_creds()
        node = nodes.ModelNode(**self._node_kwargs())
        # Checked once with the config the constructor supplied.
        self._assert_default_path(node, creds)
        self._check_compute_resolution(node, model_config.ModelConfig, creds)

    def test_get_http_path_seed(self):
        """Resolve the http path for a SeedNode."""
        creds = self._new_creds()
        node = nodes.SeedNode(**self._node_kwargs())
        # Checked once with the config the constructor supplied.
        self._assert_default_path(node, creds)
        self._check_compute_resolution(node, model_config.SeedConfig, creds)

    def test_get_http_path_snapshot(self):
        """Resolve the http path for a SnapshotNode."""
        creds = self._new_creds()
        # The snapshot node is built with config=None, and there is no check
        # on it before the config is assigned.
        node = nodes.SnapshotNode(config=None, **self._node_kwargs())
        self._check_compute_resolution(node, model_config.SnapshotConfig, creds)

0 comments on commit e1f89c0

Please sign in to comment.