diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 509686d7..0b523574 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -59,6 +59,7 @@
     CursorCreate,
 )
 from dbt.adapters.databricks.events.other_events import QueryError
+from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
 from dbt.adapters.databricks.utils import redact_credentials
@@ -86,9 +87,6 @@
 
 DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")
 
-# toggle for session managements that minimizes the number of sessions opened/closed
-USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
-
 # Number of idle seconds before a connection is automatically closed. Only applicable if
 # USE_LONG_SESSIONS is true.
 # Updated when idle times of 180s were causing errors
@@ -475,6 +473,8 @@ def add_query(
         auto_begin: bool = True,
         bindings: Optional[Any] = None,
         abridge_sql_log: bool = False,
+        retryable_exceptions: tuple[type[Exception], ...] = tuple(),
+        retry_limit: int = 1,
         *,
         close_cursor: bool = False,
     ) -> tuple[Connection, Any]:
@@ -707,7 +707,7 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterRe
 class ExtendedSessionConnectionManager(DatabricksConnectionManager):
     def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
         assert (
-            USE_LONG_SESSIONS
+            GlobalState.get_use_long_sessions()
         ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled"
         super().__init__(profile, mp_context)
         self.threads_compute_connections: dict[
@@ -910,7 +910,7 @@ def open(cls, connection: Connection) -> Connection:
     # Once long session management is no longer under the USE_LONG_SESSIONS toggle
     # this should be renamed and replace the _open class method.
     assert (
-        USE_LONG_SESSIONS
+        GlobalState.get_use_long_sessions()
     ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS"
 
     databricks_connection = cast(DatabricksDBTConnection, connection)
@@ -1013,7 +1013,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O
 
     # If there is no node we return the http_path for the default compute.
     if not query_header_context:
-        if not USE_LONG_SESSIONS:
+        if not GlobalState.get_use_long_sessions():
             logger.debug(f"Thread {thread_id}: using default compute resource.")
         return creds.http_path
 
@@ -1021,7 +1021,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O
     # If none is specified return the http_path for the default compute.
     compute_name = _get_compute_name(query_header_context)
     if not compute_name:
-        if not USE_LONG_SESSIONS:
+        if not GlobalState.get_use_long_sessions():
             logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.")
         return creds.http_path
 
@@ -1037,7 +1037,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O
             f"does not specify http_path, relation: {relation_name}"
         )
 
-    if not USE_LONG_SESSIONS:
+    if not GlobalState.get_use_long_sessions():
         logger.debug(
             f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'."
         )
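The `add_query` hunk above only extends the signature; the retry loop that consumes `retryable_exceptions` and `retry_limit` is not shown in this diff. As a rough sketch of the intended semantics only (the helper name and structure below are invented, not the adapter's code), retry-on-whitelisted-exceptions with these defaults typically looks like:

```python
from typing import Any, Callable


def run_with_retries(
    execute: Callable[[], Any],
    retryable_exceptions: tuple[type[Exception], ...] = tuple(),
    retry_limit: int = 1,
) -> Any:
    """Call `execute`, retrying only when it raises a whitelisted exception."""
    for attempt in range(retry_limit + 1):
        try:
            return execute()
        except retryable_exceptions:
            # Out of retries: let the original exception propagate.
            if attempt == retry_limit:
                raise


# With the default empty tuple, `except ()` matches nothing, so any error
# surfaces immediately, matching the signature's conservative defaults.
print(run_with_retries(lambda: 42, (OSError,), retry_limit=1))
```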
diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py
index 7a318cad..387d0e76 100644
--- a/dbt/adapters/databricks/credentials.py
+++ b/dbt/adapters/databricks/credentials.py
@@ -19,10 +19,10 @@
     CredentialSaveError,
     CredentialShardEvent,
 )
+from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.databricks.logging import logger
 
 CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
-DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV"
 DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
 EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
 DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS"
@@ -150,7 +150,7 @@ def validate_creds(self) -> None:
 
     @classmethod
     def get_invocation_env(cls) -> Optional[str]:
-        invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
+        invocation_env = GlobalState.get_invocation_env()
         if invocation_env:
             # Thrift doesn't allow nested () so we need to ensure
             # that the passed user agent is valid.
@@ -160,9 +160,7 @@ def get_invocation_env(cls) -> Optional[str]:
 
     @classmethod
     def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]:
-        http_session_headers_str: Optional[str] = os.environ.get(
-            DBT_DATABRICKS_HTTP_SESSION_HEADERS
-        )
+        http_session_headers_str = GlobalState.get_http_session_headers()
 
         http_session_headers_dict: dict[str, str] = (
             {
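For context on the validation in `get_invocation_env`: the regex shown as a context line above rejects values containing parentheses, because Thrift does not allow nested `()` in user agents. A quick standalone check, using the same pattern and the same values the unit tests at the end of this diff exercise:

```python
import re

# Same pattern as DBT_DATABRICKS_INVOCATION_ENV_REGEX in credentials.py.
DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")

print(bool(DBT_DATABRICKS_INVOCATION_ENV_REGEX.match("databricks-workflows")))  # True
print(bool(DBT_DATABRICKS_INVOCATION_ENV_REGEX.match("(Some-thing)")))  # False: parentheses rejected
```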
+ """ + + __use_long_sessions: ClassVar[Optional[bool]] = None + + @classmethod + def get_use_long_sessions(cls) -> bool: + if cls.__use_long_sessions is None: + cls.__use_long_sessions = ( + os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" + ) + return cls.__use_long_sessions + + __invocation_env: ClassVar[Optional[str]] = None + __invocation_env_set: ClassVar[bool] = False + + @classmethod + def get_invocation_env(cls) -> Optional[str]: + if not cls.__invocation_env_set: + cls.__invocation_env = os.environ.get("DBT_DATABRICKS_INVOCATION_ENV") + cls.__invocation_env_set = True + return cls.__invocation_env + + __session_headers: ClassVar[Optional[str]] = None + __session_headers_set: ClassVar[bool] = False + + @classmethod + def get_http_session_headers(cls) -> Optional[str]: + if not cls.__session_headers_set: + cls.__session_headers = os.environ.get("DBT_DATABRICKS_HTTP_SESSION_HEADERS") + cls.__session_headers_set = True + return cls.__session_headers + + __describe_char_bypass: ClassVar[Optional[bool]] = None + + @classmethod + def get_char_limit_bypass(cls) -> bool: + if cls.__describe_char_bypass is None: + cls.__describe_char_bypass = ( + os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE" + ) + return cls.__describe_char_bypass + + __connector_log_level: ClassVar[Optional[str]] = None + + @classmethod + def get_connector_log_level(cls) -> str: + if cls.__connector_log_level is None: + cls.__connector_log_level = os.environ.get( + "DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN" + ).upper() + return cls.__connector_log_level diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index dce432c9..15c333e2 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -32,10 +32,10 @@ ) from dbt.adapters.databricks.column import DatabricksColumn from dbt.adapters.databricks.connections import ( - USE_LONG_SESSIONS, DatabricksConnectionManager, ExtendedSessionConnectionManager, ) +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.python_models.python_submissions import ( AllPurposeClusterPythonJobHelper, JobClusterPythonJobHelper, @@ -142,7 +142,7 @@ def get_identifier_list_string(table_names: set[str]) -> str: """ _identifier = "|".join(table_names) - bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false") + bypass_2048_char_limit = GlobalState.get_char_limit_bypass() if bypass_2048_char_limit == "true": _identifier = _identifier if len(_identifier) < 2048 else "*" return _identifier @@ -154,7 +154,7 @@ class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation Column = DatabricksColumn - if USE_LONG_SESSIONS: + if GlobalState.get_use_long_sessions(): ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager else: ConnectionManager = DatabricksConnectionManager diff --git a/dbt/adapters/databricks/logging.py b/dbt/adapters/databricks/logging.py index d0f1d42b..81e7449e 100644 --- a/dbt/adapters/databricks/logging.py +++ b/dbt/adapters/databricks/logging.py @@ -1,7 +1,7 @@ -import os from logging import Handler, LogRecord, getLogger from typing import Union +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.events.logging import AdapterLogger logger = AdapterLogger("Databricks") @@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None: dbt_adapter_logger = AdapterLogger("databricks-sql-connector") pysql_logger = getLogger("databricks.sql") 
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 78ae12cb..4ac564be 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -11,8 +11,6 @@
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.credentials import (
     CATALOG_KEY_IN_SESSION_PROPERTIES,
-    DBT_DATABRICKS_HTTP_SESSION_HEADERS,
-    DBT_DATABRICKS_INVOCATION_ENV,
 )
 from dbt.adapters.databricks.impl import get_identifier_list_string
 from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
@@ -114,7 +112,7 @@ def test_invalid_custom_user_agent(self):
         with pytest.raises(DbtValidationError) as excinfo:
             config = self._get_config()
             adapter = DatabricksAdapter(config, get_context("spawn"))
-            with patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}):
+            with patch.dict("os.environ", **{"DBT_DATABRICKS_INVOCATION_ENV": "(Some-thing)"}):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
 
@@ -129,7 +127,7 @@ def test_custom_user_agent(self):
             new=self._connect_func(expected_invocation_env="databricks-workflows"),
         ):
             with patch.dict(
-                "os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "databricks-workflows"}
+                "os.environ", **{"DBT_DATABRICKS_INVOCATION_ENV": "databricks-workflows"}
             ):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
@@ -192,7 +190,7 @@ def _test_environment_http_headers(
     ):
         with patch.dict(
             "os.environ",
-            **{DBT_DATABRICKS_HTTP_SESSION_HEADERS: http_headers_str},
+            **{"DBT_DATABRICKS_HTTP_SESSION_HEADERS": http_headers_str},
         ):
             connection = adapter.acquire_connection("dummy")
             connection.handle  # trigger lazy-load
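Because of that memoization, tests that patch `os.environ` with the literal variable name (as above) only observe the patched value if `GlobalState` has not already cached a read. A hypothetical pytest-style sketch of resetting the cache before such a test; the name-mangled attribute access is illustrative only, not part of the adapter's public API:

```python
from unittest.mock import patch

from dbt.adapters.databricks.global_state import GlobalState


def test_invocation_env_reads_patched_environment() -> None:
    # Reset GlobalState's memoized value via its name-mangled private fields.
    GlobalState._GlobalState__invocation_env = None
    GlobalState._GlobalState__invocation_env_set = False
    with patch.dict("os.environ", {"DBT_DATABRICKS_INVOCATION_ENV": "databricks-workflows"}):
        assert GlobalState.get_invocation_env() == "databricks-workflows"
```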