Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Long sessions #517

Merged
Merged
157 changes: 143 additions & 14 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Union,
Hashable,
)
from numbers import Number

from agate import Table

Expand Down Expand Up @@ -110,7 +111,12 @@ def emit(self, record: logging.LogRecord) -> None:
CLIENT_ID = "dbt-databricks"
SCOPES = ["all-apis", "offline_access"]

USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "FALSE").upper() == "TRUE"
# toggle for session managements that minimizes the number of sessions opened/closed
USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"

# Number of idle seconds before a connection is automatically closed. Only applicable if
# USE_LONG_SESSIONS is true.
DEFAULT_MAX_IDLE_TIME = 600


@dataclass
Expand All @@ -133,6 +139,7 @@ class DatabricksCredentials(Credentials):
connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False
connect_max_idle: Optional[int] = None

_credentials_provider: Optional[Dict[str, Any]] = None
_lock = threading.Lock() # to avoid concurrent auth
Expand Down Expand Up @@ -728,30 +735,37 @@ class DatabricksDBTConnection(Connection):
compute_name: str = ""
http_path: str = ""
thread_identifier: Tuple[int, int] = (0, 0)
max_idle_time: int = DEFAULT_MAX_IDLE_TIME

def _acquire(self, node: Optional[ResultNode]) -> None:
"""Indicate that this connection is in use."""
logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}")
self._log_usage(node)
# Make sure this won't be cleaned up as being idle for too long.
self.last_used_time = None
self.acquire_release_count += 1

def _release(self) -> None:
"""Indicate that this connection is not in use."""
logger.debug(f"DatabricksDBTConnection._release: {self._get_conn_info_str()}")
self.last_used_time = time.time()
# Need to check for > 0 because in some situations the dbt code will make an extra
# release call on a connection.
if self.acquire_release_count > 0:
self.acquire_release_count -= 1

if self.acquire_release_count == 0:
benc-db marked this conversation as resolved.
Show resolved Hide resolved
self.last_used_time = time.time()

def _get_idle_time(self) -> float:
return 0 if self.last_used_time is None else time.time() - self.last_used_time

def _idle_too_long(self) -> bool:
return self.max_idle_time > 0 and self._get_idle_time() > self.max_idle_time

def _get_conn_info_str(self) -> str:
"""Generate a string describing this connection."""
return (
f"name: {self.name}, thread: {self.thread_identifier}, "
f"compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count},"
f" last_used: {self.last_used_time}"
f" idle time: {self._get_idle_time()}s"
)

def _log_usage(self, node: Optional[ResultNode]) -> None:
Expand Down Expand Up @@ -908,6 +922,8 @@ def _get_compute_connection(
USE_LONG_SESSIONS
), "This path, '_get_compute_connection', should only be reachable with USE_LONG_SESSIONS"

self._cleanup_idle_connections()

conn_name: str = "master" if name is None else name

# Get a connection for this thread
Expand All @@ -934,17 +950,12 @@ def _update_compute_connection(
"reachable with USE_LONG_SESSIONS"
)

compute_name = _get_compute_name(node=node) or ""
if (
conn.name == new_name
and conn.state == ConnectionState.OPEN
and conn.compute_name == compute_name
):
if conn.name == new_name and conn.state == ConnectionState.OPEN:
# Found a connection and nothing to do, so just return it
return conn

if conn.state != ConnectionState.OPEN:
conn.handle = LazyHandle(self.get_open_for_model(node))
conn.handle = LazyHandle(self._open2)
if conn.name != new_name:
orig_conn_name: str = conn.name or ""
conn.name = new_name
Expand Down Expand Up @@ -973,7 +984,7 @@ def _create_compute_connection(
compute_name = _get_compute_name(node=node) or ""
logger.debug(
f"Creating DatabricksDBTConnection. name: {conn_name}, "
"thread: {self.get_thread_identifier()}, compute: `{compute_name}`"
f"thread: {self.get_thread_identifier()}, compute: `{compute_name}`"
)
conn = DatabricksDBTConnection(
type=Identifier(self.TYPE),
Expand All @@ -987,8 +998,9 @@ def _create_compute_connection(
creds = cast(DatabricksCredentials, self.profile.credentials)
conn.http_path = _get_http_path(node=node, creds=creds) or ""
conn.thread_identifier = cast(Tuple[int, int], self.get_thread_identifier())
conn.max_idle_time = _get_max_idle_time(node=node, creds=creds)

conn.handle = LazyHandle(self.get_open_for_model(node))
conn.handle = LazyHandle(self._open2)
# Add this connection to the thread/compute connection pool.
self._add_compute_connection(conn)
# Remove the connection currently in use by this thread from the thread connection pool.
Expand Down Expand Up @@ -1045,6 +1057,19 @@ def _get_if_exists_compute_connection(
threads_map = self._get_compute_connections()
return threads_map.get(compute_name)

def _cleanup_idle_connections(self) -> None:
assert (
USE_LONG_SESSIONS
), "This path, '_cleanup_idle_connections', should only be reachable with USE_LONG_SESSIONS"

with self.lock:
for thread_conns in self.threads_compute_connections.values():
for conn in thread_conns.values():
if conn.acquire_release_count == 0 and conn._idle_too_long():
logger.debug(f"closing idle connection: {conn._get_conn_info_str()}")
self.close(conn)
conn.handle = LazyHandle(self._open2)

def add_query(
self,
sql: str,
Expand Down Expand Up @@ -1073,6 +1098,7 @@ def add_query(
node_info=get_node_info(),
)
)

pre = time.time()

cursor = cast(DatabricksSQLConnectionWrapper, connection.handle).cursor()
Expand Down Expand Up @@ -1132,6 +1158,7 @@ def _execute_cursor(
node_info=get_node_info(),
)
)

pre = time.time()

handle: DatabricksSQLConnectionWrapper = connection.handle
Expand Down Expand Up @@ -1255,6 +1282,81 @@ def exponential_backoff(attempt: int) -> int:
retry_timeout=(timeout if timeout is not None else exponential_backoff),
)

@classmethod
def _open2(cls, connection: Connection) -> Connection:
rcypher-databricks marked this conversation as resolved.
Show resolved Hide resolved
# Once long session management is no longer under the USE_LONG_SESSIONS toggle
# this should be renamed and replace the _open class method.
assert (
USE_LONG_SESSIONS
), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS"

if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection

creds: DatabricksCredentials = connection.credentials
timeout = creds.connect_timeout

# gotta keep this so we don't prompt users many times
cls.credentials_provider = creds.authenticate(cls.credentials_provider)

user_agent_entry = f"dbt-databricks/{__version__}"

invocation_env = creds.get_invocation_env()
if invocation_env:
user_agent_entry = f"{user_agent_entry}; {invocation_env}"

connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr]

http_headers: List[Tuple[str, str]] = list(
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

# If a model specifies a compute resource the http path
# may be different than the http_path property of creds.
http_path = cast(DatabricksDBTConnection, connection).http_path

def connect() -> DatabricksSQLConnectionWrapper:
try:
# TODO: what is the error when a user specifies a catalog they don't have access to
conn: DatabricksSQLConnection = dbsql.connect(
server_hostname=creds.host,
http_path=http_path,
credentials_provider=cls.credentials_provider,
http_headers=http_headers if http_headers else None,
session_configuration=creds.session_properties,
catalog=creds.database,
# schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL.
_user_agent_entry=user_agent_entry,
**connection_parameters,
)
return DatabricksSQLConnectionWrapper(
conn,
is_cluster=creds.cluster_id is not None,
creds=creds,
user_agent=user_agent_entry,
)
except Error as exc:
_log_dbsql_errors(exc)
raise

def exponential_backoff(attempt: int) -> int:
return attempt * attempt

retryable_exceptions = []
# this option is for backwards compatibility
if creds.retry_all:
retryable_exceptions = [Error]

return cls.retry_connection(
connection,
connect=connect,
logger=logger,
retryable_exceptions=retryable_exceptions,
retry_limit=creds.connect_retries,
retry_timeout=(timeout if timeout is not None else exponential_backoff),
)

@classmethod
def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse:
_query_id = getattr(cursor, "hex_query_id", None)
Expand Down Expand Up @@ -1408,3 +1510,30 @@ def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) ->
)

return http_path


def _get_max_idle_time(node: Optional[ResultNode], creds: DatabricksCredentials) -> int:
"""Get the http_path for the compute specified for the node.
If none is specified default will be used."""

max_idle_time = (
DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle
)

if node:
compute_name = _get_compute_name(node)
if compute_name and creds.compute:
max_idle_time = creds.compute.get(compute_name, {}).get(
"connect_max_idle", max_idle_time
)

if not isinstance(max_idle_time, Number):
if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric():
return int(max_idle_time.strip())
else:
raise dbt.exceptions.DbtRuntimeError(
f"{max_idle_time} is not a valid value for connect_max_idle. "
"Must be a number of seconds."
)

return max_idle_time
30 changes: 30 additions & 0 deletions tests/functional/adapter/long_sessions/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,33 @@

select * from {{ ref('source') }}
"""

targetseq1 = """
{{config(materialized='table', databricks_compute='alternate_warehouse')}}

select * from {{ ref('source') }}
"""

targetseq2 = """
{{config(materialized='table')}}

select * from {{ ref('targetseq1') }}
"""

targetseq3 = """
{{config(materialized='table')}}

select * from {{ ref('targetseq2') }}
"""

targetseq4 = """
{{config(materialized='table')}}

select * from {{ ref('targetseq3') }}
"""

targetseq5 = """
{{config(materialized='table', databricks_compute='alternate_warehouse')}}

select * from {{ ref('targetseq4') }}
"""
36 changes: 36 additions & 0 deletions tests/functional/adapter/long_sessions/test_long_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def profiles_config_update(self, dbt_profile_target):
outputs["default"]["compute"] = {
"alternate_warehouse": {"http_path": dbt_profile_target["http_path"]},
}

return {"test": {"outputs": outputs, "target": "default"}}

def test_long_sessions(self, project):
Expand All @@ -78,3 +79,38 @@ def test_long_sessions(self, project):
_, log = util.run_dbt_and_capture(["--debug", "run"])
open_count = log.count("Sending request: OpenSession") / 2
assert open_count == 3


class TestLongSessionsIdleCleanup(TestLongSessionsMultipleCompute):
args_formatter = ""

@pytest.fixture(scope="class")
def models(self):
m = {
"targetseq1.sql": fixtures.targetseq1,
"targetseq2.sql": fixtures.targetseq2,
"targetseq3.sql": fixtures.targetseq3,
"targetseq4.sql": fixtures.targetseq4,
"targetseq5.sql": fixtures.targetseq5,
}
return m

@pytest.fixture(scope="class")
def profiles_config_update(self, dbt_profile_target):
outputs = {"default": dbt_profile_target}
outputs["default"]["connect_max_idle"] = 1
outputs["default"]["compute"] = {
"alternate_warehouse": {
"http_path": dbt_profile_target["http_path"],
"connect_max_idle": 1,
},
}

return {"test": {"outputs": outputs, "target": "default"}}

def test_long_sessions(self, project):
util.run_dbt(["--debug", "seed"])

_, log = util.run_dbt_and_capture(["--debug", "run"])
idle_count = log.count("closing idle connection") / 2
assert idle_count > 0
Loading
Loading