Skip to content

Commit

Permalink
Allow models to execute on different warehouses (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
rcypher-databricks authored Nov 10, 2023
2 parents 6e42d38 + e1f89c0 commit 7c9fc66
Show file tree
Hide file tree
Showing 5 changed files with 457 additions and 1 deletion.
110 changes: 109 additions & 1 deletion dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,15 @@
Connection,
ConnectionState,
DEFAULT_QUERY_COMMENT,
Identifier,
LazyHandle,
)
from dbt.events.types import (
NewConnection,
ConnectionReused,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.events import AdapterLogger
from dbt.events.contextvars import get_node_info
from dbt.events.functions import fire_event
Expand Down Expand Up @@ -111,6 +118,10 @@ class DatabricksCredentials(Credentials):
connection_parameters: Optional[Dict[str, Any]] = None
auth_type: Optional[str] = None

# Named compute resources specified in the profile. Used for
# creating a connection when a model specifies a compute resource.
compute: Optional[Dict[str, Any]] = None

connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False
Expand Down Expand Up @@ -741,6 +752,50 @@ def exception_handler(self, sql: str) -> Iterator[None]:
else:
raise dbt.exceptions.DbtRuntimeError(str(exc)) from exc

# override/overload
def set_connection_name(
self, name: Optional[str] = None, node: Optional[ResultNode] = None
) -> Connection:
"""Called by 'acquire_connection' in DatabricksAdapter, which is called by
'connection_named', called by 'connection_for(node)'.
Creates a connection for this thread if one doesn't already
exist, and will rename an existing connection."""

conn_name: str = "master" if name is None else name

# Get a connection for this thread
conn = self.get_if_exists()

if conn and conn.name == conn_name and conn.state == "open":
# Found a connection and nothing to do, so just return it
return conn

if conn is None:
# Create a new connection
conn = Connection(
type=Identifier(self.TYPE),
name=conn_name,
state=ConnectionState.INIT,
transaction_open=False,
handle=None,
credentials=self.profile.credentials,
)
conn.handle = LazyHandle(self.get_open_for_model(node))
# Add the connection to thread_connections for this thread
self.set_thread_connection(conn)
fire_event(
NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info())
)
else: # existing connection either wasn't open or didn't have the right name
if conn.state != "open":
conn.handle = LazyHandle(self.get_open_for_model(node))
if conn.name != conn_name:
orig_conn_name: str = conn.name or ""
conn.name = conn_name
fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name))

return conn

def add_query(
self,
sql: str,
Expand Down Expand Up @@ -861,8 +916,29 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No
),
)

@classmethod
def get_open_for_model(
cls, node: Optional[ResultNode] = None
) -> Callable[[Connection], Connection]:
# If there is no node we can simply return the exsting class method open.
# If there is a node create a closure that will call cls._open with the node.
if not node:
return cls.open

def open_for_model(connection: Connection) -> Connection:
return cls._open(connection, node)

return open_for_model

@classmethod
def open(cls, connection: Connection) -> Connection:
# Simply call _open with no ResultNode argument.
# Because this is an overridden method we can't just add
# a ResultNode parameter to open.
return cls._open(connection)

@classmethod
def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Connection:
if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection
Expand All @@ -885,12 +961,16 @@ def open(cls, connection: Connection) -> Connection:
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

# If a model specifies a compute resource the http path
# may be different than the http_path property of creds.
http_path = _get_http_path(node, creds)

def connect() -> DatabricksSQLConnectionWrapper:
try:
# TODO: what is the error when a user specifies a catalog they don't have access to
conn: DatabricksSQLConnection = dbsql.connect(
server_hostname=creds.host,
http_path=creds.http_path,
http_path=http_path,
credentials_provider=cls.credentials_provider,
http_headers=http_headers if http_headers else None,
session_configuration=creds.session_properties,
Expand Down Expand Up @@ -1028,3 +1108,31 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id:
msg = error_events[0].get("message", "")

return msg


def _get_compute_name(node: Optional[ResultNode]) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
compute_name = None
if node and node.config:
compute_name = node.config.get("databricks_compute", None)
return compute_name


def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]:
# Get the http path of the compute resource specified in the node's config.
# If none is specified return the default path from creds.
compute_name = _get_compute_name(node)
if not node or not compute_name:
return creds.http_path

http_path = None
if creds.compute:
http_path = creds.compute.get(compute_name, {}).get("http_path", None)

if not http_path:
raise dbt.exceptions.DbtRuntimeError(
f"Compute resource {compute_name} does not exist, relation: {node.relation_name}"
)

return http_path
20 changes: 20 additions & 0 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER, empty_table
from dbt.contracts.connection import AdapterResponse, Connection
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.relation import RelationType
import dbt.exceptions
from dbt.events import AdapterLogger
Expand Down Expand Up @@ -118,6 +119,25 @@ class DatabricksAdapter(SparkAdapter):
}
)

# override/overload
def acquire_connection(
self, name: Optional[str] = None, node: Optional[ResultNode] = None
) -> Connection:
return self.connections.set_connection_name(name, node)

# override
@contextmanager
def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iterator[None]:
try:
if self.connections.query_header is not None:
self.connections.query_header.set(name, node)
self.acquire_connection(name, node)
yield
finally:
self.release_connection()
if self.connections.query_header is not None:
self.connections.query_header.reset()

@available.parse(lambda *a, **k: 0)
def compare_dbr_version(self, major: int, minor: int) -> int:
"""
Expand Down
50 changes: 50 additions & 0 deletions tests/functional/adapter/warehouse_per_model/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
source = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""

target = """
{{config(materialized='table', databricks_compute='alternate_warehouse')}}
select * from {{ ref('source') }}
"""

target2 = """
{{config(materialized='table')}}
select * from {{ ref('source') }}
"""

target3 = """
{{config(materialized='table')}}
select * from {{ ref('source') }}
"""

model_schema = """
version: 2
models:
- name: target
columns:
- name: id
- name: name
- name: date
- name: target2
config:
databricks_compute: alternate_warehouse
columns:
- name: id
- name: name
- name: date
- name: target3
columns:
- name: id
- name: name
- name: date
"""

expected_target = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest
from dbt.tests import util
from tests.functional.adapter.warehouse_per_model import fixtures


class BaseWarehousePerModel:
args_formatter = ""

@pytest.fixture(scope="class")
def seeds(self):
return {
"source.csv": fixtures.source,
}

@pytest.fixture(scope="class")
def models(self):
d = dict()
d["target4.sql"] = fixtures.target3
return {
"target.sql": fixtures.target,
"target2.sql": fixtures.target2,
"target3.sql": fixtures.target3,
"schema.yml": fixtures.model_schema,
"special": d,
}


class BaseSpecifyingCompute(BaseWarehousePerModel):
"""Base class for testing various ways to specify a warehouse."""

def test_wpm(self, project):
util.run_dbt(["seed"])
models = project.test_config.get("model_names")
for model_name in models:
# Since the profile doesn't define a compute resource named 'alternate_warehouse'
# we should fail with an error if the warehouse specified for the model is
# correctly handled.
res = util.run_dbt(["run", "--select", model_name], expect_pass=False)
msg = res.results[0].message
assert "Compute resource alternate_warehouse does not exist" in msg
assert model_name in msg


class TestSpecifyingInConfigBlock(BaseSpecifyingCompute):
@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target"]}


class TestSpecifyingInSchemaYml(BaseSpecifyingCompute):
@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target2"]}


class TestSpecifyingForProjectModels(BaseSpecifyingCompute):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+databricks_compute": "alternate_warehouse",
}
}

@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target3"]}


class TestSpecifyingForProjectModelsInFolder(BaseSpecifyingCompute):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"test": {
"special": {
"+databricks_compute": "alternate_warehouse",
},
},
}
}

@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target4"]}


class TestWarehousePerModel(BaseWarehousePerModel):
@pytest.fixture(scope="class")
def profiles_config_update(self, dbt_profile_target):
outputs = {"default": dbt_profile_target}
outputs["default"]["compute"] = {
"alternate_warehouse": {"http_path": dbt_profile_target["http_path"]}
}
return {"test": {"outputs": outputs, "target": "default"}}

def test_wpm(self, project):
util.run_dbt(["seed"])
util.run_dbt(["run", "--select", "target"])
util.check_relations_equal(project.adapter, ["target", "source"])
Loading

0 comments on commit 7c9fc66

Please sign in to comment.