diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index c6a8c2053..98219a177 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -37,8 +37,15 @@ Connection, ConnectionState, DEFAULT_QUERY_COMMENT, + Identifier, + LazyHandle, +) +from dbt.events.types import ( + NewConnection, + ConnectionReused, ) from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.nodes import ResultNode from dbt.events import AdapterLogger from dbt.events.contextvars import get_node_info from dbt.events.functions import fire_event @@ -111,6 +118,10 @@ class DatabricksCredentials(Credentials): connection_parameters: Optional[Dict[str, Any]] = None auth_type: Optional[str] = None + # Named compute resources specified in the profile. Used for + # creating a connection when a model specifies a compute resource. + compute: Optional[Dict[str, Any]] = None + connect_retries: int = 1 connect_timeout: Optional[int] = None retry_all: bool = False @@ -739,6 +750,50 @@ def exception_handler(self, sql: str) -> Iterator[None]: else: raise dbt.exceptions.DbtRuntimeError(str(exc)) from exc + # override/overload + def set_connection_name( + self, name: Optional[str] = None, node: Optional[ResultNode] = None + ) -> Connection: + """Called by 'acquire_connection' in DatabricksAdapter, which is called by + 'connection_named', called by 'connection_for(node)'. + Creates a connection for this thread if one doesn't already + exist, and will rename an existing connection.""" + + conn_name: str = "master" if name is None else name + + # Get a connection for this thread + conn = self.get_if_exists() + + if conn and conn.name == conn_name and conn.state == "open": + # Found a connection and nothing to do, so just return it + return conn + + if conn is None: + # Create a new connection + conn = Connection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) + conn.handle = LazyHandle(self.get_open_for_model(node)) + # Add the connection to thread_connections for this thread + self.set_thread_connection(conn) + fire_event( + NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) + ) + else: # existing connection either wasn't open or didn't have the right name + if conn.state != "open": + conn.handle = LazyHandle(self.get_open_for_model(node)) + if conn.name != conn_name: + orig_conn_name: str = conn.name or "" + conn.name = conn_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + + return conn + def add_query( self, sql: str, @@ -847,8 +902,29 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No ), ) + @classmethod + def get_open_for_model( + cls, node: Optional[ResultNode] = None + ) -> Callable[[Connection], Connection]: + # If there is no node we can simply return the exsting class method open. + # If there is a node create a closure that will call cls._open with the node. + if not node: + return cls.open + + def _open(connection: Connection) -> Connection: + return cls._open(connection, node) + + return _open + @classmethod def open(cls, connection: Connection) -> Connection: + # Simply call _open with no ResultNode argument. + # Because this is an overridden method we can't just add + # a ResultNode parameter to open. + return cls._open(connection) + + @classmethod + def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Connection: if connection.state == ConnectionState.OPEN: logger.debug("Connection is already open, skipping open.") return connection @@ -871,12 +947,16 @@ def open(cls, connection: Connection) -> Connection: creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() ) + # If a model specifies a compute resource to use the http path + # may be different than the http_path property of creds. + http_path = get_http_path(node, creds) + def connect() -> DatabricksSQLConnectionWrapper: try: # TODO: what is the error when a user specifies a catalog they don't have access to conn: DatabricksSQLConnection = dbsql.connect( server_hostname=creds.host, - http_path=creds.http_path, + http_path=http_path, credentials_provider=cls.credentials_provider, http_headers=http_headers if http_headers else None, session_configuration=creds.session_properties, @@ -1014,3 +1094,23 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id: msg = error_events[0].get("message", "") return msg + + +def get_compute_name(node: Optional[ResultNode]) -> Optional[str]: + # Get the name of the specified compute resource from the node's + # config. + compute_name = None + if node and node.config and node.config.extra: + compute_name = node.config.extra.get("databricks_compute", None) + return compute_name + + +def get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]: + # Get the http path of the compute resource specified in the node's config. + # If none is specified return the default path from creds. + compute_name = get_compute_name(node) + http_path = creds.http_path + if compute_name and creds.compute: + http_path = creds.compute.get(compute_name, {}).get("http_path", creds.http_path) + + return http_path diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index e1d6e1973..b33a54924 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -107,6 +107,25 @@ class DatabricksAdapter(SparkAdapter): AdapterSpecificConfigs = DatabricksConfig + # override/overload + def acquire_connection( + self, name: Optional[str] = None, node: Optional[ResultNode] = None + ) -> Connection: + return self.connections.set_connection_name(name, node) + + # override + @contextmanager + def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iterator[None]: + try: + if self.connections.query_header is not None: + self.connections.query_header.set(name, node) + self.acquire_connection(name, node) + yield + finally: + self.release_connection() + if self.connections.query_header is not None: + self.connections.query_header.reset() + @available.parse(lambda *a, **k: 0) def compare_dbr_version(self, major: int, minor: int) -> int: """