From 5af234c0ac16f9f29c14af3e4e71e8c9a4a86406 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Mon, 20 Nov 2023 16:53:31 -0700 Subject: [PATCH 1/8] Extend Connection and add last used time and acquire/release counter Added USE_LONG_SESSIONS flag. Extended Connection with DatabricksDBTConnection. Added properties acquire_release_count and last_used_time. Updated DatabricksConnectionManager.set_connection_name to create instance of DatabricksDBTConnection and update new properties. Override ConnectionManager.release to not close the session and to update the new connection properties. Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 51 ++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 5ce424434..4e0d7e864 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -105,6 +105,8 @@ def emit(self, record: logging.LogRecord) -> None: CLIENT_ID = "dbt-databricks" SCOPES = ["all-apis", "offline_access"] +USE_LONG_SESSIONS = os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "true").upper() == "TRUE" + @dataclass class DatabricksCredentials(Credentials): @@ -714,6 +716,12 @@ class DatabricksAdapterResponse(AdapterResponse): query_id: str = "" +@dataclass(init=False) +class DatabricksDBTConnection(Connection): + last_used_time: Optional[str] = None + acquire_release_count: int = 0 + + class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" credentials_provider: CredentialsProvider = None @@ -773,14 +781,25 @@ def set_connection_name( if conn is None: # Create a new connection - conn = Connection( - type=Identifier(self.TYPE), - name=conn_name, - state=ConnectionState.INIT, - transaction_open=False, - handle=None, - credentials=self.profile.credentials, - ) + if USE_LONG_SESSIONS: + conn = DatabricksDBTConnection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) + else: + conn = Connection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) + conn.handle = LazyHandle(self.get_open_for_model(node)) # Add the connection to thread_connections for this thread self.set_thread_connection(conn) @@ -795,8 +814,24 @@ def set_connection_name( conn.name = conn_name fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + if USE_LONG_SESSIONS: + conn.last_used_time = None + conn.acquire_release_count += 1 + return conn + def release(self) -> None: + if USE_LONG_SESSIONS: + with self.lock: + conn = self.get_if_exists() + if conn is None: + return + + conn.acquire_release_count -= 1 + conn.last_used_time = time.time() + else: + super().release() + def add_query( self, sql: str, From 573c3aceda01694ccee928a76859af29cf27eb9a Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Tue, 21 Nov 2023 14:00:46 -0700 Subject: [PATCH 2/8] Added new connection pool. Added new connection pool (threads_compute_connections) to DatabricksConnectionManager. It is a map of thread ID to map of compute name to DatabricksDBTConnection. Updated DatabricksConnectionManager.set_connection_name() to look for existing connections in the new pool and then update the existing thread_connections connection pool. Overrode cleanup_all in DatabricksConnectionManager to fire events based on the connection acquire/release count, rather than the connection state. Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 112 ++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 4e0d7e864..64d827ddb 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -24,6 +24,7 @@ Tuple, cast, Union, + Hashable, ) from agate import Table @@ -35,6 +36,7 @@ from dbt.clients import agate_helper from dbt.contracts.connection import ( AdapterResponse, + AdapterRequiredConfig, Connection, ConnectionState, DEFAULT_QUERY_COMMENT, @@ -44,7 +46,10 @@ from dbt.events.types import ( NewConnection, ConnectionReused, + ConnectionLeftOpenInCleanup, + ConnectionClosedInCleanup, ) + from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import ResultNode from dbt.events import AdapterLogger @@ -718,14 +723,23 @@ class DatabricksAdapterResponse(AdapterResponse): @dataclass(init=False) class DatabricksDBTConnection(Connection): - last_used_time: Optional[str] = None + last_used_time: Optional[float] = None acquire_release_count: int = 0 + compute_name: str = "" + http_path: str = "" class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" credentials_provider: CredentialsProvider = None + def __init__(self, profile: AdapterRequiredConfig) -> None: + super().__init__(profile) + if USE_LONG_SESSIONS: + self.threads_compute_connections: Dict[ + Hashable, Dict[Hashable, DatabricksDBTConnection] + ] = {} + def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) @@ -773,7 +787,10 @@ def set_connection_name( conn_name: str = "master" if name is None else name # Get a connection for this thread - conn = self.get_if_exists() + if USE_LONG_SESSIONS: + conn = self.get_if_exists_compute(_get_compute_name(node) or "") + else: + conn = self.get_if_exists() if conn and conn.name == conn_name and conn.state == "open": # Found a connection and nothing to do, so just return it @@ -782,6 +799,10 @@ def set_connection_name( if conn is None: # Create a new connection if USE_LONG_SESSIONS: + compute_name = _get_compute_name(node=node) or "" + logger.debug( + f"Creating DatabricksDBTConnection. thread: {self.get_thread_identifier}, compute: `{compute_name}`" + ) conn = DatabricksDBTConnection( type=Identifier(self.TYPE), name=conn_name, @@ -790,6 +811,10 @@ def set_connection_name( handle=None, credentials=self.profile.credentials, ) + conn.compute_name = compute_name + conn.http_path = _get_http_path(node=node, creds=self.profile.credentials) + self.set_thread_compute_connection(conn) + self.clear_thread_connection() else: conn = Connection( type=Identifier(self.TYPE), @@ -813,10 +838,30 @@ def set_connection_name( orig_conn_name: str = conn.name or "" conn.name = conn_name fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + if USE_LONG_SESSIONS: + current_thread_conn = self.get_if_exists() + if current_thread_conn.compute_name != conn.compute_name: + self.clear_thread_connection() + self.set_thread_connection(conn) + logger.debug( + f"Reusing DatabricksDBTConnection. thread: {self.get_thread_identifier}, compute: `{conn.compute_name}`" + ) + if node: + if not conn.compute_name: + logger.debug( + f"On thread {self.get_thread_identifier}: {node.relation_name} using default compute resource." + ) + else: + logger.debug( + f"On thread {self.get_thread_identifier}: {node.relation_name} using compute resource '{conn.compute_name}'." + ) if USE_LONG_SESSIONS: conn.last_used_time = None conn.acquire_release_count += 1 + logger.debug( + f"DatabricksDBTConnection: thread: {self.get_thread_identifier}, compute: `{conn.compute_name}`, acquire_release_count: {conn.acquire_release_count}, last_used: {conn.last_used_time}" + ) return conn @@ -829,9 +874,72 @@ def release(self) -> None: conn.acquire_release_count -= 1 conn.last_used_time = time.time() + logger.debug( + f"release DatabricksDBTConnection: thread: {self.get_thread_identifier}, compute: `{conn.compute_name}`, acquire_release_count: {conn.acquire_release_count}, last_used: {conn.last_used_time}" + ) else: super().release() + def cleanup_all(self) -> None: + if USE_LONG_SESSIONS: + with self.lock: + for thread_connections in self.threads_compute_connections.values(): + for connection in thread_connections.values(): + if connection.acquire_release_count > 0 and connection.state not in { + "init" + }: + fire_event( + ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) + ) + else: + fire_event( + ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) + ) + self.close(connection) + + # garbage collect these connections + self.thread_connections.clear() + self.threads_compute_connections.clear() + else: + super().cleanup_all() + + def set_thread_compute_connection(self, conn: Connection) -> None: + if not USE_LONG_SESSIONS: + raise dbt.exceptions.DbtInternalError( + "set_thread_compute_connection() should not be called when USE_LONG_SESSIONS is False" + ) + + thread_map = self.get_thread_compute_connections() + if conn.compute_name in thread_map: + raise dbt.exceptions.DbtInternalError( + f"In set_thread_compute_connection, connection exists for `{conn.compute_name}`" + ) + thread_map[conn.compute_name] = conn + + def get_thread_compute_connections(self) -> Dict[Hashable, DatabricksDBTConnection]: + if not USE_LONG_SESSIONS: + raise dbt.exceptions.DbtInternalError( + "get_thread_compute_connections() should not be called when USE_LONG_SESSIONS is False" + ) + + thread_id = self.get_thread_identifier() + with self.lock: + thread_map = self.threads_compute_connections.get(thread_id) + if not thread_map: + thread_map = {} + self.threads_compute_connections[thread_id] = thread_map + return thread_map + + def get_if_exists_compute(self, compute_name: str) -> Optional[Connection]: + if not USE_LONG_SESSIONS: + raise dbt.exceptions.DbtInternalError( + "get_if_exists_compute() should not be called when USE_LONG_SESSIONS is False" + ) + + with self.lock: + threads_map = self.get_thread_compute_connections() + return threads_map.get(compute_name) + def add_query( self, sql: str, From b1e1532d72b2b0a700a70d2230b91c2db3f71653 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Wed, 22 Nov 2023 13:53:59 -0700 Subject: [PATCH 3/8] Refactoring of long session code. Added _start_using and _stop_using to DatabricksDBTConnection. These handle logging, last_used_time, and acquire_release_count. Moved long session specific code from DatabricksConnectionManager.set_connection_name to its own method. Refactoring and renaming of the long session code. Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 311 ++++++++++++++++--------- 1 file changed, 203 insertions(+), 108 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 64d827ddb..efdd0eb2e 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -110,7 +110,7 @@ def emit(self, record: logging.LogRecord) -> None: CLIENT_ID = "dbt-databricks" SCOPES = ["all-apis", "offline_access"] -USE_LONG_SESSIONS = os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "true").upper() == "TRUE" +USE_LONG_SESSIONS = os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "FALSE").upper() == "TRUE" @dataclass @@ -727,6 +727,41 @@ class DatabricksDBTConnection(Connection): acquire_release_count: int = 0 compute_name: str = "" http_path: str = "" + thread_identifier: Tuple[int, int] = (0, 0) + + def _acquire(self, node: Optional[ResultNode]): + """Indicate that this connection is in use.""" + logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}") + self._log_usage(node) + # Make sure this won't be cleaned up as being idle for too long. + self.last_used_time = None + self.acquire_release_count += 1 + + def _release(self): + """Indicate that this connection is not in use.""" + logger.debug(f"DatabricksDBTConnection._release: {self._get_conn_info_str()}") + self.last_used_time = time.time() + # Need to check for > 0 because in some situations the dbt code will make an extra + # release call on a connection. + if self.acquire_release_count > 0: + self.acquire_release_count -= 1 + + def _get_conn_info_str(self) -> str: + """Generate a string describing this connection.""" + return f"name: {self.name}, thread: {self.thread_identifier}, compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count}, last_used: {self.last_used_time}" + + def _log_usage(self, node: Optional[ResultNode]) -> None: + if node: + if not self.compute_name: + logger.debug( + f"On thread {self.thread_identifier}: {node.relation_name} using default compute resource." + ) + else: + logger.debug( + f"On thread {self.thread_identifier}: {node.relation_name} using compute resource '{self.compute_name}'." + ) + else: + logger.debug(f"Thread {self.thread_identifier} using default compute resource.") class DatabricksConnectionManager(SparkConnectionManager): @@ -784,13 +819,13 @@ def set_connection_name( Creates a connection for this thread if one doesn't already exist, and will rename an existing connection.""" + if USE_LONG_SESSIONS: + return self._get_compute_connection(name, node) + conn_name: str = "master" if name is None else name # Get a connection for this thread - if USE_LONG_SESSIONS: - conn = self.get_if_exists_compute(_get_compute_name(node) or "") - else: - conn = self.get_if_exists() + conn = self.get_if_exists() if conn and conn.name == conn_name and conn.state == "open": # Found a connection and nothing to do, so just return it @@ -798,34 +833,17 @@ def set_connection_name( if conn is None: # Create a new connection - if USE_LONG_SESSIONS: - compute_name = _get_compute_name(node=node) or "" - logger.debug( - f"Creating DatabricksDBTConnection. thread: {self.get_thread_identifier}, compute: `{compute_name}`" - ) - conn = DatabricksDBTConnection( - type=Identifier(self.TYPE), - name=conn_name, - state=ConnectionState.INIT, - transaction_open=False, - handle=None, - credentials=self.profile.credentials, - ) - conn.compute_name = compute_name - conn.http_path = _get_http_path(node=node, creds=self.profile.credentials) - self.set_thread_compute_connection(conn) - self.clear_thread_connection() - else: - conn = Connection( - type=Identifier(self.TYPE), - name=conn_name, - state=ConnectionState.INIT, - transaction_open=False, - handle=None, - credentials=self.profile.credentials, - ) + conn = Connection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) conn.handle = LazyHandle(self.get_open_for_model(node)) + # Add the connection to thread_connections for this thread self.set_thread_connection(conn) fire_event( @@ -838,89 +856,150 @@ def set_connection_name( orig_conn_name: str = conn.name or "" conn.name = conn_name fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) - if USE_LONG_SESSIONS: - current_thread_conn = self.get_if_exists() - if current_thread_conn.compute_name != conn.compute_name: - self.clear_thread_connection() - self.set_thread_connection(conn) - logger.debug( - f"Reusing DatabricksDBTConnection. thread: {self.get_thread_identifier}, compute: `{conn.compute_name}`" - ) - if node: - if not conn.compute_name: - logger.debug( - f"On thread {self.get_thread_identifier}: {node.relation_name} using default compute resource." + + return conn + + # override + def release(self) -> None: + if not USE_LONG_SESSIONS: + return super().release() + + with self.lock: + conn = self.get_if_exists() + if conn is None: + return + + conn._release() + + # override + def cleanup_all(self) -> None: + if not USE_LONG_SESSIONS: + return super().cleanup_all() + + with self.lock: + for thread_connections in self.threads_compute_connections.values(): + for connection in thread_connections.values(): + if connection.acquire_release_count > 0 and connection.state not in {"init"}: + fire_event( + ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) ) else: - logger.debug( - f"On thread {self.get_thread_identifier}: {node.relation_name} using compute resource '{conn.compute_name}'." + fire_event( + ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) ) + self.close(connection) - if USE_LONG_SESSIONS: - conn.last_used_time = None - conn.acquire_release_count += 1 - logger.debug( - f"DatabricksDBTConnection: thread: {self.get_thread_identifier}, compute: `{conn.compute_name}`, acquire_release_count: {conn.acquire_release_count}, last_used: {conn.last_used_time}" - ) + # garbage collect these connections + self.thread_connections.clear() + self.threads_compute_connections.clear() + + def _get_compute_connection( + self, name: Optional[str] = None, node: Optional[ResultNode] = None + ) -> Connection: + """Called by 'set_connection_name' in DatabricksConnectionManager. + Creates a connection for this thread/node if one doesn't already + exist, and will rename an existing connection.""" + + _long_sessions_only("_set_connection_name_long_sessions") + + conn_name: str = "master" if name is None else name + + # Get a connection for this thread + conn = self._get_if_exists_compute_connection(_get_compute_name(node) or "") + + if conn is None: + conn = self._create_compute_connection(conn_name, node) + else: # existing connection either wasn't open or didn't have the right name + conn = self._update_compute_connection(conn, conn_name, node) + + conn._acquire(node) return conn - def release(self) -> None: - if USE_LONG_SESSIONS: - with self.lock: - conn = self.get_if_exists() - if conn is None: - return + def _update_compute_connection( + self, + conn: DatabricksDBTConnection, + new_name: str, + node: Optional[ResultNode] = None, + ) -> DatabricksDBTConnection: + """Update a connection that is being re-used with a new name, handle, etc.""" + _long_sessions_only("_update_connection_for_compute") + + compute_name = _get_compute_name(node=node) or "" + if conn.name == new_name and conn.state == "open" and conn.compute_name == compute_name: + # Found a connection and nothing to do, so just return it + return conn - conn.acquire_release_count -= 1 - conn.last_used_time = time.time() - logger.debug( - f"release DatabricksDBTConnection: thread: {self.get_thread_identifier}, compute: `{conn.compute_name}`, acquire_release_count: {conn.acquire_release_count}, last_used: {conn.last_used_time}" - ) - else: - super().release() + if conn.state != "open": + conn.handle = LazyHandle(self.get_open_for_model(node)) + if conn.name != new_name: + orig_conn_name: str = conn.name or "" + conn.name = new_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) + + current_thread_conn = self.get_if_exists() + if current_thread_conn.compute_name != conn.compute_name: + self.clear_thread_connection() + self.set_thread_connection(conn) - def cleanup_all(self) -> None: - if USE_LONG_SESSIONS: - with self.lock: - for thread_connections in self.threads_compute_connections.values(): - for connection in thread_connections.values(): - if connection.acquire_release_count > 0 and connection.state not in { - "init" - }: - fire_event( - ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) - ) - else: - fire_event( - ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) - ) - self.close(connection) - - # garbage collect these connections - self.thread_connections.clear() - self.threads_compute_connections.clear() - else: - super().cleanup_all() + logger.debug(f"Reusing DatabricksDBTConnection. {conn._get_conn_info_str()}") - def set_thread_compute_connection(self, conn: Connection) -> None: - if not USE_LONG_SESSIONS: - raise dbt.exceptions.DbtInternalError( - "set_thread_compute_connection() should not be called when USE_LONG_SESSIONS is False" - ) + return conn - thread_map = self.get_thread_compute_connections() + def _create_compute_connection( + self, conn_name: str, node: Optional[ResultNode] = None + ) -> DatabricksDBTConnection: + """Create anew connection for the combination of current thread and compute associated + with the given node.""" + _long_sessions_only("_create_connection_for_compute") + + # Create a new connection + compute_name = _get_compute_name(node=node) or "" + logger.debug( + f"Creating DatabricksDBTConnection. name: {conn_name}, thread: {self.get_thread_identifier()}, compute: `{compute_name}`" + ) + conn = DatabricksDBTConnection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) + conn.compute_name = compute_name + conn.http_path = _get_http_path(node=node, creds=self.profile.credentials) + conn.thread_identifier = self.get_thread_identifier() + + conn.handle = LazyHandle(self.get_open_for_model(node)) + # Add this connection to the thread/compute connection pool. + self._add_compute_connection(conn) + # Remove the connection currently in use by this thread from the thread connection pool. + self.clear_thread_connection() + # Add the connection to thread connection pool. + self.set_thread_connection(conn) + + fire_event( + NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) + ) + + return conn + + def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: + """Add a new connection to the map of connection per thread per compute.""" + _long_sessions_only("_set_thread_compute_connection") + + thread_map = self._get_compute_connections() if conn.compute_name in thread_map: raise dbt.exceptions.DbtInternalError( f"In set_thread_compute_connection, connection exists for `{conn.compute_name}`" ) thread_map[conn.compute_name] = conn - def get_thread_compute_connections(self) -> Dict[Hashable, DatabricksDBTConnection]: - if not USE_LONG_SESSIONS: - raise dbt.exceptions.DbtInternalError( - "get_thread_compute_connections() should not be called when USE_LONG_SESSIONS is False" - ) + def _get_compute_connections( + self, + ) -> Dict[Hashable, DatabricksDBTConnection]: + """Retrieve a map of compute name to connection for the current thread.""" + _long_sessions_only("_get_thread_compute_connections") thread_id = self.get_thread_identifier() with self.lock: @@ -930,14 +1009,14 @@ def get_thread_compute_connections(self) -> Dict[Hashable, DatabricksDBTConnecti self.threads_compute_connections[thread_id] = thread_map return thread_map - def get_if_exists_compute(self, compute_name: str) -> Optional[Connection]: - if not USE_LONG_SESSIONS: - raise dbt.exceptions.DbtInternalError( - "get_if_exists_compute() should not be called when USE_LONG_SESSIONS is False" - ) + def _get_if_exists_compute_connection( + self, compute_name: str + ) -> Optional[DatabricksDBTConnection]: + """Get the connection for the current thread and named compute, if it exists.""" + _long_sessions_only("_get_if_exists_compute") with self.lock: - threads_map = self.get_thread_compute_connections() + threads_map = self._get_compute_connections() return threads_map.get(compute_name) def add_query( @@ -1264,18 +1343,25 @@ def _get_compute_name(node: Optional[ResultNode]) -> Optional[str]: def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]: + """Get the http_path for the compute specified for the node. + If none is specified default will be used.""" + thread_id = (os.getpid(), get_ident()) # If there is no node we return the http_path for the default compute. if not node: - logger.debug(f"Thread {thread_id}: using default compute resource.") + if not USE_LONG_SESSIONS: + logger.debug(f"Thread {thread_id}: using default compute resource.") return creds.http_path # Get the name of the compute resource specified in the node's config. # If none is specified return the http_path for the default compute. compute_name = _get_compute_name(node) if not compute_name: - logger.debug(f"On thread {thread_id}: {node.relation_name} using default compute resource.") + if not USE_LONG_SESSIONS: + logger.debug( + f"On thread {thread_id}: {node.relation_name} using default compute resource." + ) return creds.http_path # Get the http_path for the named compute. @@ -1290,8 +1376,17 @@ def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> f"does not specify http_path, relation: {node.relation_name}" ) - logger.debug( - f"On thread {thread_id}: {node.relation_name} using compute resource '{compute_name}'." - ) + if not USE_LONG_SESSIONS: + logger.debug( + f"On thread {thread_id}: {node.relation_name} using compute resource '{compute_name}'." + ) return http_path + + +def _long_sessions_only(name: str) -> None: + """Helper function to raise exception is USE_LONG_SESSIONS is false.""" + if not USE_LONG_SESSIONS: + raise dbt.exceptions.DbtInternalError( + f"{name}() should not be called when USE_LONG_SESSIONS is False" + ) From 51e09ac43bde6b3d2cf8a6a50d3a25c6b70df000 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Wed, 22 Nov 2023 15:04:21 -0700 Subject: [PATCH 4/8] Added a functional test for long sessions Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 10 ++-- .../adapter/long_sessions/fixtures.py | 22 ++++++++ .../long_sessions/test_long_sessions.py | 50 +++++++++++++++++++ 3 files changed, 77 insertions(+), 5 deletions(-) create mode 100644 tests/functional/adapter/long_sessions/fixtures.py create mode 100644 tests/functional/adapter/long_sessions/test_long_sessions.py diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index efdd0eb2e..8d509493f 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -110,7 +110,7 @@ def emit(self, record: logging.LogRecord) -> None: CLIENT_ID = "dbt-databricks" SCOPES = ["all-apis", "offline_access"] -USE_LONG_SESSIONS = os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "FALSE").upper() == "TRUE" +USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "FALSE").upper() == "TRUE" @dataclass @@ -770,10 +770,10 @@ class DatabricksConnectionManager(SparkConnectionManager): def __init__(self, profile: AdapterRequiredConfig) -> None: super().__init__(profile) - if USE_LONG_SESSIONS: - self.threads_compute_connections: Dict[ - Hashable, Dict[Hashable, DatabricksDBTConnection] - ] = {} + # if USE_LONG_SESSIONS: + self.threads_compute_connections: Dict[ + Hashable, Dict[Hashable, DatabricksDBTConnection] + ] = {} def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) diff --git a/tests/functional/adapter/long_sessions/fixtures.py b/tests/functional/adapter/long_sessions/fixtures.py new file mode 100644 index 000000000..e97d0d3d0 --- /dev/null +++ b/tests/functional/adapter/long_sessions/fixtures.py @@ -0,0 +1,22 @@ +source = """id,name,date +1,Alice,2022-01-01 +2,Bob,2022-01-02 +""" + +target = """ +{{config(materialized='table')}} + +select * from {{ ref('source') }} +""" + +target2 = """ +{{config(materialized='table')}} + +select * from {{ ref('source') }} +""" + +target3 = """ +{{config(materialized='table')}} + +select * from {{ ref('source') }} +""" diff --git a/tests/functional/adapter/long_sessions/test_long_sessions.py b/tests/functional/adapter/long_sessions/test_long_sessions.py new file mode 100644 index 000000000..445013c41 --- /dev/null +++ b/tests/functional/adapter/long_sessions/test_long_sessions.py @@ -0,0 +1,50 @@ +import pytest +import os +from unittest import mock +from dbt.tests import util +from tests.functional.adapter.long_sessions import fixtures +from dbt.adapters.databricks import connections + + +class TestLongSessionsBase: + args_formatter = "" + + @pytest.fixture(scope="class") + def seeds(self): + return { + "source.csv": fixtures.source, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "target.sql": fixtures.target, + "target2.sql": fixtures.target2, + "target3.sql": fixtures.target3, + } + + def test_long_sessions(self, project): + connections.USE_LONG_SESSIONS = True + _, log = util.run_dbt_and_capture(["--debug", "seed"]) + open_count = log.count("Sending request: OpenSession") + assert open_count == 4 + + _, log = util.run_dbt_and_capture(["--debug", "run"]) + open_count = log.count("Sending request: OpenSession") + assert open_count == 4 + + +class TestLongSessionsMultipleThreads(TestLongSessionsBase): + def test_long_sessions(self, project): + connections.USE_LONG_SESSIONS = True + _, log = util.run_dbt_and_capture(["--debug", "seed"]) + open_count = log.count("Sending request: OpenSession") + assert open_count == 4 + + _, log = util.run_dbt_and_capture(["--debug", "run", "--threads", "2"]) + open_count = log.count("Sending request: OpenSession") + assert open_count == 6 + + _, log = util.run_dbt_and_capture(["--debug", "run", "--threads", "3"]) + open_count = log.count("Sending request: OpenSession") + assert open_count == 8 From bd2d5dd426369e4777121eecaab451ddc1d05578 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Mon, 27 Nov 2023 16:12:46 -0700 Subject: [PATCH 5/8] Simplified functional tests Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 22 +++++----- .../long_sessions/test_long_sessions.py | 42 +++++++++---------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 8d509493f..016cc0d05 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -770,10 +770,10 @@ class DatabricksConnectionManager(SparkConnectionManager): def __init__(self, profile: AdapterRequiredConfig) -> None: super().__init__(profile) - # if USE_LONG_SESSIONS: - self.threads_compute_connections: Dict[ - Hashable, Dict[Hashable, DatabricksDBTConnection] - ] = {} + if USE_LONG_SESSIONS: + self.threads_compute_connections: Dict[ + Hashable, Dict[Hashable, DatabricksDBTConnection] + ] = {} def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) @@ -900,7 +900,7 @@ def _get_compute_connection( Creates a connection for this thread/node if one doesn't already exist, and will rename an existing connection.""" - _long_sessions_only("_set_connection_name_long_sessions") + _long_sessions_only("_get_compute_connection") conn_name: str = "master" if name is None else name @@ -923,7 +923,7 @@ def _update_compute_connection( node: Optional[ResultNode] = None, ) -> DatabricksDBTConnection: """Update a connection that is being re-used with a new name, handle, etc.""" - _long_sessions_only("_update_connection_for_compute") + _long_sessions_only("_update_compute_connection") compute_name = _get_compute_name(node=node) or "" if conn.name == new_name and conn.state == "open" and conn.compute_name == compute_name: @@ -951,7 +951,7 @@ def _create_compute_connection( ) -> DatabricksDBTConnection: """Create anew connection for the combination of current thread and compute associated with the given node.""" - _long_sessions_only("_create_connection_for_compute") + _long_sessions_only("_create_compute_connection") # Create a new connection compute_name = _get_compute_name(node=node) or "" @@ -986,7 +986,7 @@ def _create_compute_connection( def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: """Add a new connection to the map of connection per thread per compute.""" - _long_sessions_only("_set_thread_compute_connection") + _long_sessions_only("_add_compute_connection") thread_map = self._get_compute_connections() if conn.compute_name in thread_map: @@ -999,7 +999,7 @@ def _get_compute_connections( self, ) -> Dict[Hashable, DatabricksDBTConnection]: """Retrieve a map of compute name to connection for the current thread.""" - _long_sessions_only("_get_thread_compute_connections") + _long_sessions_only("_get_compute_connections") thread_id = self.get_thread_identifier() with self.lock: @@ -1013,7 +1013,7 @@ def _get_if_exists_compute_connection( self, compute_name: str ) -> Optional[DatabricksDBTConnection]: """Get the connection for the current thread and named compute, if it exists.""" - _long_sessions_only("_get_if_exists_compute") + _long_sessions_only("_get_if_exists_compute_connection") with self.lock: threads_map = self._get_compute_connections() @@ -1385,7 +1385,7 @@ def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> def _long_sessions_only(name: str) -> None: - """Helper function to raise exception is USE_LONG_SESSIONS is false.""" + """Helper function to raise exception when USE_LONG_SESSIONS is false.""" if not USE_LONG_SESSIONS: raise dbt.exceptions.DbtInternalError( f"{name}() should not be called when USE_LONG_SESSIONS is False" diff --git a/tests/functional/adapter/long_sessions/test_long_sessions.py b/tests/functional/adapter/long_sessions/test_long_sessions.py index 445013c41..bdb1211ce 100644 --- a/tests/functional/adapter/long_sessions/test_long_sessions.py +++ b/tests/functional/adapter/long_sessions/test_long_sessions.py @@ -3,6 +3,12 @@ from unittest import mock from dbt.tests import util from tests.functional.adapter.long_sessions import fixtures +from timeit import default_timer as timer +from datetime import timedelta + +with mock.patch.dict(os.environ, {"DBT_DATABRICKS_LONG_SESSIONS": "true"}): + import dbt.adapters.databricks.connections + from dbt.adapters.databricks import connections @@ -17,34 +23,28 @@ def seeds(self): @pytest.fixture(scope="class") def models(self): - return { - "target.sql": fixtures.target, - "target2.sql": fixtures.target2, - "target3.sql": fixtures.target3, - } + m = {} + for i in range(10): + m[f"target{i}.sql"] = fixtures.target + + return m def test_long_sessions(self, project): - connections.USE_LONG_SESSIONS = True + # connections.USE_LONG_SESSIONS = True _, log = util.run_dbt_and_capture(["--debug", "seed"]) - open_count = log.count("Sending request: OpenSession") - assert open_count == 4 + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == 2 _, log = util.run_dbt_and_capture(["--debug", "run"]) - open_count = log.count("Sending request: OpenSession") - assert open_count == 4 + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == 2 class TestLongSessionsMultipleThreads(TestLongSessionsBase): def test_long_sessions(self, project): - connections.USE_LONG_SESSIONS = True - _, log = util.run_dbt_and_capture(["--debug", "seed"]) - open_count = log.count("Sending request: OpenSession") - assert open_count == 4 - - _, log = util.run_dbt_and_capture(["--debug", "run", "--threads", "2"]) - open_count = log.count("Sending request: OpenSession") - assert open_count == 6 + util.run_dbt_and_capture(["seed"]) - _, log = util.run_dbt_and_capture(["--debug", "run", "--threads", "3"]) - open_count = log.count("Sending request: OpenSession") - assert open_count == 8 + for n_threads in [1, 2, 3]: + _, log = util.run_dbt_and_capture(["--debug", "run", "--threads", f"{n_threads}"]) + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == (n_threads + 1) From 526544dccec0d4d33915d3420a2575e9663c1c08 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Mon, 27 Nov 2023 16:47:32 -0700 Subject: [PATCH 6/8] Fixed linting errors Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 30 ++++++++++++------- .../long_sessions/test_long_sessions.py | 8 ++--- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 016cc0d05..511fe085c 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -729,7 +729,7 @@ class DatabricksDBTConnection(Connection): http_path: str = "" thread_identifier: Tuple[int, int] = (0, 0) - def _acquire(self, node: Optional[ResultNode]): + def _acquire(self, node: Optional[ResultNode]) -> None: """Indicate that this connection is in use.""" logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}") self._log_usage(node) @@ -737,7 +737,7 @@ def _acquire(self, node: Optional[ResultNode]): self.last_used_time = None self.acquire_release_count += 1 - def _release(self): + def _release(self) -> None: """Indicate that this connection is not in use.""" logger.debug(f"DatabricksDBTConnection._release: {self._get_conn_info_str()}") self.last_used_time = time.time() @@ -748,17 +748,23 @@ def _release(self): def _get_conn_info_str(self) -> str: """Generate a string describing this connection.""" - return f"name: {self.name}, thread: {self.thread_identifier}, compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count}, last_used: {self.last_used_time}" + return ( + f"name: {self.name}, thread: {self.thread_identifier}, " + f"compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count}," + f" last_used: {self.last_used_time}" + ) def _log_usage(self, node: Optional[ResultNode]) -> None: if node: if not self.compute_name: logger.debug( - f"On thread {self.thread_identifier}: {node.relation_name} using default compute resource." + f"On thread {self.thread_identifier}: {node.relation_name} " + "using default compute resource." ) else: logger.debug( - f"On thread {self.thread_identifier}: {node.relation_name} using compute resource '{self.compute_name}'." + f"On thread {self.thread_identifier}: {node.relation_name} " + "using compute resource '{self.compute_name}'." ) else: logger.debug(f"Thread {self.thread_identifier} using default compute resource.") @@ -865,7 +871,7 @@ def release(self) -> None: return super().release() with self.lock: - conn = self.get_if_exists() + conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) if conn is None: return @@ -937,8 +943,8 @@ def _update_compute_connection( conn.name = new_name fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - current_thread_conn = self.get_if_exists() - if current_thread_conn.compute_name != conn.compute_name: + current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: self.clear_thread_connection() self.set_thread_connection(conn) @@ -956,7 +962,8 @@ def _create_compute_connection( # Create a new connection compute_name = _get_compute_name(node=node) or "" logger.debug( - f"Creating DatabricksDBTConnection. name: {conn_name}, thread: {self.get_thread_identifier()}, compute: `{compute_name}`" + f"Creating DatabricksDBTConnection. name: {conn_name}, " + "thread: {self.get_thread_identifier()}, compute: `{compute_name}`" ) conn = DatabricksDBTConnection( type=Identifier(self.TYPE), @@ -967,8 +974,9 @@ def _create_compute_connection( credentials=self.profile.credentials, ) conn.compute_name = compute_name - conn.http_path = _get_http_path(node=node, creds=self.profile.credentials) - conn.thread_identifier = self.get_thread_identifier() + creds = cast(DatabricksCredentials, self.profile.credentials) + conn.http_path = _get_http_path(node=node, creds=creds) or "" + conn.thread_identifier = cast(Tuple[int, int], self.get_thread_identifier()) conn.handle = LazyHandle(self.get_open_for_model(node)) # Add this connection to the thread/compute connection pool. diff --git a/tests/functional/adapter/long_sessions/test_long_sessions.py b/tests/functional/adapter/long_sessions/test_long_sessions.py index bdb1211ce..551ca0f4b 100644 --- a/tests/functional/adapter/long_sessions/test_long_sessions.py +++ b/tests/functional/adapter/long_sessions/test_long_sessions.py @@ -3,13 +3,9 @@ from unittest import mock from dbt.tests import util from tests.functional.adapter.long_sessions import fixtures -from timeit import default_timer as timer -from datetime import timedelta with mock.patch.dict(os.environ, {"DBT_DATABRICKS_LONG_SESSIONS": "true"}): - import dbt.adapters.databricks.connections - -from dbt.adapters.databricks import connections + import dbt.adapters.databricks.connections # noqa class TestLongSessionsBase: @@ -24,7 +20,7 @@ def seeds(self): @pytest.fixture(scope="class") def models(self): m = {} - for i in range(10): + for i in range(5): m[f"target{i}.sql"] = fixtures.target return m From 3e5f353bfe76aa614f0a1d6feba7810c7d4a95ff Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Tue, 28 Nov 2023 13:12:25 -0700 Subject: [PATCH 7/8] Assert and connection state constants Use ConnectionState constants instead of string literals. Use assert instead of throwing an exception if in unexpected code path when USE_LONG_SESSIONS=False Added long session test with warehouse per model Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 64 +++++++++++-------- .../adapter/long_sessions/fixtures.py | 8 +-- .../long_sessions/test_long_sessions.py | 36 ++++++++++- 3 files changed, 73 insertions(+), 35 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 511fe085c..dc3ba84db 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -833,7 +833,7 @@ def set_connection_name( # Get a connection for this thread conn = self.get_if_exists() - if conn and conn.name == conn_name and conn.state == "open": + if conn and conn.name == conn_name and conn.state == ConnectionState.OPEN: # Found a connection and nothing to do, so just return it return conn @@ -847,16 +847,14 @@ def set_connection_name( handle=None, credentials=self.profile.credentials, ) - conn.handle = LazyHandle(self.get_open_for_model(node)) - # Add the connection to thread_connections for this thread self.set_thread_connection(conn) fire_event( NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) ) else: # existing connection either wasn't open or didn't have the right name - if conn.state != "open": + if conn.state != ConnectionState.OPEN: conn.handle = LazyHandle(self.get_open_for_model(node)) if conn.name != conn_name: orig_conn_name: str = conn.name or "" @@ -885,7 +883,7 @@ def cleanup_all(self) -> None: with self.lock: for thread_connections in self.threads_compute_connections.values(): for connection in thread_connections.values(): - if connection.acquire_release_count > 0 and connection.state not in {"init"}: + if connection.acquire_release_count > 0: fire_event( ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) ) @@ -906,7 +904,9 @@ def _get_compute_connection( Creates a connection for this thread/node if one doesn't already exist, and will rename an existing connection.""" - _long_sessions_only("_get_compute_connection") + assert ( + USE_LONG_SESSIONS + ), "This path, '_get_compute_connection', should only be reachable with USE_LONG_SESSIONS" conn_name: str = "master" if name is None else name @@ -929,14 +929,21 @@ def _update_compute_connection( node: Optional[ResultNode] = None, ) -> DatabricksDBTConnection: """Update a connection that is being re-used with a new name, handle, etc.""" - _long_sessions_only("_update_compute_connection") + assert USE_LONG_SESSIONS, ( + "This path, '_update_compute_connection', should only be " + "reachable with USE_LONG_SESSIONS" + ) compute_name = _get_compute_name(node=node) or "" - if conn.name == new_name and conn.state == "open" and conn.compute_name == compute_name: + if ( + conn.name == new_name + and conn.state == ConnectionState.OPEN + and conn.compute_name == compute_name + ): # Found a connection and nothing to do, so just return it return conn - if conn.state != "open": + if conn.state != ConnectionState.OPEN: conn.handle = LazyHandle(self.get_open_for_model(node)) if conn.name != new_name: orig_conn_name: str = conn.name or "" @@ -957,7 +964,10 @@ def _create_compute_connection( ) -> DatabricksDBTConnection: """Create anew connection for the combination of current thread and compute associated with the given node.""" - _long_sessions_only("_create_compute_connection") + assert USE_LONG_SESSIONS, ( + "This path, '_create_compute_connection', should only be reachable " + "with USE_LONG_SESSIONS" + ) # Create a new connection compute_name = _get_compute_name(node=node) or "" @@ -994,20 +1004,25 @@ def _create_compute_connection( def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: """Add a new connection to the map of connection per thread per compute.""" - _long_sessions_only("_add_compute_connection") + assert ( + USE_LONG_SESSIONS + ), "This path, '_add_compute_connection', should only be reachable with USE_LONG_SESSIONS" - thread_map = self._get_compute_connections() - if conn.compute_name in thread_map: - raise dbt.exceptions.DbtInternalError( - f"In set_thread_compute_connection, connection exists for `{conn.compute_name}`" - ) - thread_map[conn.compute_name] = conn + with self.lock: + thread_map = self._get_compute_connections() + if conn.compute_name in thread_map: + raise dbt.exceptions.DbtInternalError( + f"In set_thread_compute_connection, connection exists for `{conn.compute_name}`" + ) + thread_map[conn.compute_name] = conn def _get_compute_connections( self, ) -> Dict[Hashable, DatabricksDBTConnection]: """Retrieve a map of compute name to connection for the current thread.""" - _long_sessions_only("_get_compute_connections") + assert ( + USE_LONG_SESSIONS + ), "This path, '_get_compute_connections', should only be reachable with USE_LONG_SESSIONS" thread_id = self.get_thread_identifier() with self.lock: @@ -1021,7 +1036,10 @@ def _get_if_exists_compute_connection( self, compute_name: str ) -> Optional[DatabricksDBTConnection]: """Get the connection for the current thread and named compute, if it exists.""" - _long_sessions_only("_get_if_exists_compute_connection") + assert USE_LONG_SESSIONS, ( + "This path, '_get_if_exists_compute_connection', should only be reachable " + "with USE_LONG_SESSIONS" + ) with self.lock: threads_map = self._get_compute_connections() @@ -1390,11 +1408,3 @@ def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> ) return http_path - - -def _long_sessions_only(name: str) -> None: - """Helper function to raise exception when USE_LONG_SESSIONS is false.""" - if not USE_LONG_SESSIONS: - raise dbt.exceptions.DbtInternalError( - f"{name}() should not be called when USE_LONG_SESSIONS is False" - ) diff --git a/tests/functional/adapter/long_sessions/fixtures.py b/tests/functional/adapter/long_sessions/fixtures.py index e97d0d3d0..3fb9fc16d 100644 --- a/tests/functional/adapter/long_sessions/fixtures.py +++ b/tests/functional/adapter/long_sessions/fixtures.py @@ -10,13 +10,7 @@ """ target2 = """ -{{config(materialized='table')}} - -select * from {{ ref('source') }} -""" - -target3 = """ -{{config(materialized='table')}} +{{config(materialized='table', databricks_compute='alternate_warehouse')}} select * from {{ ref('source') }} """ diff --git a/tests/functional/adapter/long_sessions/test_long_sessions.py b/tests/functional/adapter/long_sessions/test_long_sessions.py index 551ca0f4b..9ec2aae6a 100644 --- a/tests/functional/adapter/long_sessions/test_long_sessions.py +++ b/tests/functional/adapter/long_sessions/test_long_sessions.py @@ -26,7 +26,6 @@ def models(self): return m def test_long_sessions(self, project): - # connections.USE_LONG_SESSIONS = True _, log = util.run_dbt_and_capture(["--debug", "seed"]) open_count = log.count("Sending request: OpenSession") / 2 assert open_count == 2 @@ -44,3 +43,38 @@ def test_long_sessions(self, project): _, log = util.run_dbt_and_capture(["--debug", "run", "--threads", f"{n_threads}"]) open_count = log.count("Sending request: OpenSession") / 2 assert open_count == (n_threads + 1) + + +class TestLongSessionsMultipleCompute: + args_formatter = "" + + @pytest.fixture(scope="class") + def seeds(self): + return { + "source.csv": fixtures.source, + } + + @pytest.fixture(scope="class") + def models(self): + m = {} + for i in range(2): + m[f"target{i}.sql"] = fixtures.target + + m["target_alt.sql"] = fixtures.target2 + + return m + + @pytest.fixture(scope="class") + def profiles_config_update(self, dbt_profile_target): + outputs = {"default": dbt_profile_target} + outputs["default"]["compute"] = { + "alternate_warehouse": {"http_path": dbt_profile_target["http_path"]}, + } + return {"test": {"outputs": outputs, "target": "default"}} + + def test_long_sessions(self, project): + util.run_dbt_and_capture(["--debug", "seed"]) + + _, log = util.run_dbt_and_capture(["--debug", "run"]) + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == 3 From 7c1c511f67852ff2117c68cf5fb0070f74122f0f Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Tue, 5 Dec 2023 17:19:22 -0700 Subject: [PATCH 8/8] Cleanup of idle connections Added default max idle time of 600 seconds. Added code to read user specified max idle time from profile or from alternate compute definitions. Updated DatabricksDBTConnection with connect_max_idle property and helper functions to determine if the connection should be cleaned up. In DatabricksConnectionManager added _cleanup_idle_connections(). This is called whenever a connection is acquired for use. Added a new class method _open2() to DatabricksConnectionManager. This is used with USE_LONG_SESSIONS is true and uses the http_path property of DatabricksDBTConnection. Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 157 ++++++++++-- .../adapter/long_sessions/fixtures.py | 30 +++ .../long_sessions/test_long_sessions.py | 36 +++ tests/unit/test_idle_config.py | 238 ++++++++++++++++++ 4 files changed, 447 insertions(+), 14 deletions(-) create mode 100644 tests/unit/test_idle_config.py diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index dc3ba84db..af4c33942 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -26,6 +26,7 @@ Union, Hashable, ) +from numbers import Number from agate import Table @@ -110,7 +111,12 @@ def emit(self, record: logging.LogRecord) -> None: CLIENT_ID = "dbt-databricks" SCOPES = ["all-apis", "offline_access"] -USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "FALSE").upper() == "TRUE" +# toggle for session managements that minimizes the number of sessions opened/closed +USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" + +# Number of idle seconds before a connection is automatically closed. Only applicable if +# USE_LONG_SESSIONS is true. +DEFAULT_MAX_IDLE_TIME = 600 @dataclass @@ -133,6 +139,7 @@ class DatabricksCredentials(Credentials): connect_retries: int = 1 connect_timeout: Optional[int] = None retry_all: bool = False + connect_max_idle: Optional[int] = None _credentials_provider: Optional[Dict[str, Any]] = None _lock = threading.Lock() # to avoid concurrent auth @@ -728,30 +735,37 @@ class DatabricksDBTConnection(Connection): compute_name: str = "" http_path: str = "" thread_identifier: Tuple[int, int] = (0, 0) + max_idle_time: int = DEFAULT_MAX_IDLE_TIME def _acquire(self, node: Optional[ResultNode]) -> None: """Indicate that this connection is in use.""" logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}") self._log_usage(node) - # Make sure this won't be cleaned up as being idle for too long. - self.last_used_time = None self.acquire_release_count += 1 def _release(self) -> None: """Indicate that this connection is not in use.""" logger.debug(f"DatabricksDBTConnection._release: {self._get_conn_info_str()}") - self.last_used_time = time.time() # Need to check for > 0 because in some situations the dbt code will make an extra # release call on a connection. if self.acquire_release_count > 0: self.acquire_release_count -= 1 + if self.acquire_release_count == 0: + self.last_used_time = time.time() + + def _get_idle_time(self) -> float: + return 0 if self.last_used_time is None else time.time() - self.last_used_time + + def _idle_too_long(self) -> bool: + return self.max_idle_time > 0 and self._get_idle_time() > self.max_idle_time + def _get_conn_info_str(self) -> str: """Generate a string describing this connection.""" return ( f"name: {self.name}, thread: {self.thread_identifier}, " f"compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count}," - f" last_used: {self.last_used_time}" + f" idle time: {self._get_idle_time()}s" ) def _log_usage(self, node: Optional[ResultNode]) -> None: @@ -908,6 +922,8 @@ def _get_compute_connection( USE_LONG_SESSIONS ), "This path, '_get_compute_connection', should only be reachable with USE_LONG_SESSIONS" + self._cleanup_idle_connections() + conn_name: str = "master" if name is None else name # Get a connection for this thread @@ -934,17 +950,12 @@ def _update_compute_connection( "reachable with USE_LONG_SESSIONS" ) - compute_name = _get_compute_name(node=node) or "" - if ( - conn.name == new_name - and conn.state == ConnectionState.OPEN - and conn.compute_name == compute_name - ): + if conn.name == new_name and conn.state == ConnectionState.OPEN: # Found a connection and nothing to do, so just return it return conn if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.get_open_for_model(node)) + conn.handle = LazyHandle(self._open2) if conn.name != new_name: orig_conn_name: str = conn.name or "" conn.name = new_name @@ -973,7 +984,7 @@ def _create_compute_connection( compute_name = _get_compute_name(node=node) or "" logger.debug( f"Creating DatabricksDBTConnection. name: {conn_name}, " - "thread: {self.get_thread_identifier()}, compute: `{compute_name}`" + f"thread: {self.get_thread_identifier()}, compute: `{compute_name}`" ) conn = DatabricksDBTConnection( type=Identifier(self.TYPE), @@ -987,8 +998,9 @@ def _create_compute_connection( creds = cast(DatabricksCredentials, self.profile.credentials) conn.http_path = _get_http_path(node=node, creds=creds) or "" conn.thread_identifier = cast(Tuple[int, int], self.get_thread_identifier()) + conn.max_idle_time = _get_max_idle_time(node=node, creds=creds) - conn.handle = LazyHandle(self.get_open_for_model(node)) + conn.handle = LazyHandle(self._open2) # Add this connection to the thread/compute connection pool. self._add_compute_connection(conn) # Remove the connection currently in use by this thread from the thread connection pool. @@ -1045,6 +1057,19 @@ def _get_if_exists_compute_connection( threads_map = self._get_compute_connections() return threads_map.get(compute_name) + def _cleanup_idle_connections(self) -> None: + assert ( + USE_LONG_SESSIONS + ), "This path, '_cleanup_idle_connections', should only be reachable with USE_LONG_SESSIONS" + + with self.lock: + for thread_conns in self.threads_compute_connections.values(): + for conn in thread_conns.values(): + if conn.acquire_release_count == 0 and conn._idle_too_long(): + logger.debug(f"closing idle connection: {conn._get_conn_info_str()}") + self.close(conn) + conn.handle = LazyHandle(self._open2) + def add_query( self, sql: str, @@ -1073,6 +1098,7 @@ def add_query( node_info=get_node_info(), ) ) + pre = time.time() cursor = cast(DatabricksSQLConnectionWrapper, connection.handle).cursor() @@ -1132,6 +1158,7 @@ def _execute_cursor( node_info=get_node_info(), ) ) + pre = time.time() handle: DatabricksSQLConnectionWrapper = connection.handle @@ -1255,6 +1282,81 @@ def exponential_backoff(attempt: int) -> int: retry_timeout=(timeout if timeout is not None else exponential_backoff), ) + @classmethod + def _open2(cls, connection: Connection) -> Connection: + # Once long session management is no longer under the USE_LONG_SESSIONS toggle + # this should be renamed and replace the _open class method. + assert ( + USE_LONG_SESSIONS + ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" + + if connection.state == ConnectionState.OPEN: + logger.debug("Connection is already open, skipping open.") + return connection + + creds: DatabricksCredentials = connection.credentials + timeout = creds.connect_timeout + + # gotta keep this so we don't prompt users many times + cls.credentials_provider = creds.authenticate(cls.credentials_provider) + + user_agent_entry = f"dbt-databricks/{__version__}" + + invocation_env = creds.get_invocation_env() + if invocation_env: + user_agent_entry = f"{user_agent_entry}; {invocation_env}" + + connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + + http_headers: List[Tuple[str, str]] = list( + creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + ) + + # If a model specifies a compute resource the http path + # may be different than the http_path property of creds. + http_path = cast(DatabricksDBTConnection, connection).http_path + + def connect() -> DatabricksSQLConnectionWrapper: + try: + # TODO: what is the error when a user specifies a catalog they don't have access to + conn: DatabricksSQLConnection = dbsql.connect( + server_hostname=creds.host, + http_path=http_path, + credentials_provider=cls.credentials_provider, + http_headers=http_headers if http_headers else None, + session_configuration=creds.session_properties, + catalog=creds.database, + # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. + _user_agent_entry=user_agent_entry, + **connection_parameters, + ) + return DatabricksSQLConnectionWrapper( + conn, + is_cluster=creds.cluster_id is not None, + creds=creds, + user_agent=user_agent_entry, + ) + except Error as exc: + _log_dbsql_errors(exc) + raise + + def exponential_backoff(attempt: int) -> int: + return attempt * attempt + + retryable_exceptions = [] + # this option is for backwards compatibility + if creds.retry_all: + retryable_exceptions = [Error] + + return cls.retry_connection( + connection, + connect=connect, + logger=logger, + retryable_exceptions=retryable_exceptions, + retry_limit=creds.connect_retries, + retry_timeout=(timeout if timeout is not None else exponential_backoff), + ) + @classmethod def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: _query_id = getattr(cursor, "hex_query_id", None) @@ -1408,3 +1510,30 @@ def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> ) return http_path + + +def _get_max_idle_time(node: Optional[ResultNode], creds: DatabricksCredentials) -> int: + """Get the http_path for the compute specified for the node. + If none is specified default will be used.""" + + max_idle_time = ( + DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle + ) + + if node: + compute_name = _get_compute_name(node) + if compute_name and creds.compute: + max_idle_time = creds.compute.get(compute_name, {}).get( + "connect_max_idle", max_idle_time + ) + + if not isinstance(max_idle_time, Number): + if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric(): + return int(max_idle_time.strip()) + else: + raise dbt.exceptions.DbtRuntimeError( + f"{max_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds." + ) + + return max_idle_time diff --git a/tests/functional/adapter/long_sessions/fixtures.py b/tests/functional/adapter/long_sessions/fixtures.py index 3fb9fc16d..f1332e877 100644 --- a/tests/functional/adapter/long_sessions/fixtures.py +++ b/tests/functional/adapter/long_sessions/fixtures.py @@ -14,3 +14,33 @@ select * from {{ ref('source') }} """ + +targetseq1 = """ +{{config(materialized='table', databricks_compute='alternate_warehouse')}} + +select * from {{ ref('source') }} +""" + +targetseq2 = """ +{{config(materialized='table')}} + +select * from {{ ref('targetseq1') }} +""" + +targetseq3 = """ +{{config(materialized='table')}} + +select * from {{ ref('targetseq2') }} +""" + +targetseq4 = """ +{{config(materialized='table')}} + +select * from {{ ref('targetseq3') }} +""" + +targetseq5 = """ +{{config(materialized='table', databricks_compute='alternate_warehouse')}} + +select * from {{ ref('targetseq4') }} +""" diff --git a/tests/functional/adapter/long_sessions/test_long_sessions.py b/tests/functional/adapter/long_sessions/test_long_sessions.py index 9ec2aae6a..d32b53970 100644 --- a/tests/functional/adapter/long_sessions/test_long_sessions.py +++ b/tests/functional/adapter/long_sessions/test_long_sessions.py @@ -70,6 +70,7 @@ def profiles_config_update(self, dbt_profile_target): outputs["default"]["compute"] = { "alternate_warehouse": {"http_path": dbt_profile_target["http_path"]}, } + return {"test": {"outputs": outputs, "target": "default"}} def test_long_sessions(self, project): @@ -78,3 +79,38 @@ def test_long_sessions(self, project): _, log = util.run_dbt_and_capture(["--debug", "run"]) open_count = log.count("Sending request: OpenSession") / 2 assert open_count == 3 + + +class TestLongSessionsIdleCleanup(TestLongSessionsMultipleCompute): + args_formatter = "" + + @pytest.fixture(scope="class") + def models(self): + m = { + "targetseq1.sql": fixtures.targetseq1, + "targetseq2.sql": fixtures.targetseq2, + "targetseq3.sql": fixtures.targetseq3, + "targetseq4.sql": fixtures.targetseq4, + "targetseq5.sql": fixtures.targetseq5, + } + return m + + @pytest.fixture(scope="class") + def profiles_config_update(self, dbt_profile_target): + outputs = {"default": dbt_profile_target} + outputs["default"]["connect_max_idle"] = 1 + outputs["default"]["compute"] = { + "alternate_warehouse": { + "http_path": dbt_profile_target["http_path"], + "connect_max_idle": 1, + }, + } + + return {"test": {"outputs": outputs, "target": "default"}} + + def test_long_sessions(self, project): + util.run_dbt(["--debug", "seed"]) + + _, log = util.run_dbt_and_capture(["--debug", "run"]) + idle_count = log.count("closing idle connection") / 2 + assert idle_count > 0 diff --git a/tests/unit/test_idle_config.py b/tests/unit/test_idle_config.py new file mode 100644 index 000000000..62a4da408 --- /dev/null +++ b/tests/unit/test_idle_config.py @@ -0,0 +1,238 @@ +import unittest +import dbt.exceptions +from dbt.contracts.graph import nodes, model_config +from dbt.adapters.databricks import connections + + +class TestDatabricksConnectionMaxIdleTime(unittest.TestCase): + """Test the various cases for determining a specified warehouse.""" + + errMsg = ( + "Compute resource foo does not exist or does not specify http_path, " "relation: a_relation" + ) + + def test_get_max_idle_default(self): + creds = connections.DatabricksCredentials() + + # No node and nothing specified in creds + time = connections._get_max_idle_time(None, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + node = nodes.ModelNode( + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + # node has no configuration so should get back default + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # empty configuration should return default + node.config = model_config.ModelConfig() + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # node with no extras in configuration should return default + node.config._extra = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # node that specifies a compute with no corresponding definition should return default + node.config._extra["databricks_compute"] = "foo" + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + creds.compute = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # if alternate compute doesn't specify a max time should return default + creds.compute = {"foo": {}} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + # with self.assertRaisesRegex( + # dbt.exceptions.DbtRuntimeError, + # self.errMsg, + # ): + # connections._get_http_path(node, creds) + + # creds.compute = {"foo": {"http_path": "alternate_path"}} + # path = connections._get_http_path(node, creds) + # self.assertEqual("alternate_path", path) + + def test_get_max_idle_creds(self): + creds_idle_time = 77 + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + + # No node so value should come from creds + time = connections._get_max_idle_time(None, creds) + self.assertEqual(creds_idle_time, time) + + node = nodes.ModelNode( + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + # node has no configuration so should get value from creds + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # empty configuration should get value from creds + node.config = model_config.ModelConfig() + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # node with no extras in configuration should get value from creds + node.config._extra = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # node that specifies a compute with no corresponding definition should get value from creds + node.config._extra["databricks_compute"] = "foo" + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + creds.compute = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # if alternate compute doesn't specify a max time should get value from creds + creds.compute = {"foo": {}} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + def test_get_max_idle_compute(self): + creds_idle_time = 88 + compute_idle_time = 77 + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + creds.compute = {"foo": {"connect_max_idle": compute_idle_time}} + + node = nodes.SnapshotNode( + config=None, + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + node.config = model_config.SnapshotConfig() + node.config._extra = {"databricks_compute": "foo"} + + time = connections._get_max_idle_time(node, creds) + self.assertEqual(compute_idle_time, time) + + def test_get_max_idle_invalid(self): + creds_idle_time = "foo" + compute_idle_time = "bar" + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}} + + node = nodes.SnapshotNode( + config=None, + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + node.config = model_config.SnapshotConfig() + + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + f"{creds_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + node.config._extra["databricks_compute"] = "alternate_compute" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + f"{compute_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + creds.compute["alternate_compute"]["connect_max_idle"] = "1.2.3" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "1.2.3 is not a valid value for connect_max_idle. " "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + creds.compute["alternate_compute"]["connect_max_idle"] = "1,002.3" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "1,002.3 is not a valid value for connect_max_idle. " "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + def test_get_max_idle_simple_string_conversion(self): + creds_idle_time = "12" + compute_idle_time = "34" + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}} + + node = nodes.SnapshotNode( + config=None, + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + node.config = model_config.SnapshotConfig() + + time = connections._get_max_idle_time(node, creds) + self.assertEqual(float(creds_idle_time), time) + + node.config._extra["databricks_compute"] = "alternate_compute" + time = connections._get_max_idle_time(node, creds) + self.assertEqual(float(compute_idle_time), time) + + creds.compute["alternate_compute"]["connect_max_idle"] = " 56 " + time = connections._get_max_idle_time(node, creds) + self.assertEqual(56, time)