Skip to content

Commit

Permalink
Refactor reading env vars (#888)
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db authored Dec 19, 2024
1 parent 477b745 commit 5f6412d
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 29 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## dbt-databricks 1.9.2 (TBD)

### Under the Hood

- Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888))

## dbt-databricks 1.9.1 (December 16, 2024)

### Features
Expand Down
16 changes: 8 additions & 8 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
CursorCreate,
)
from dbt.adapters.databricks.events.other_events import QueryError
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
from dbt.adapters.databricks.utils import redact_credentials
Expand Down Expand Up @@ -86,9 +87,6 @@
DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")


# toggle for session managements that minimizes the number of sessions opened/closed
USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"

# Number of idle seconds before a connection is automatically closed. Only applicable if
# USE_LONG_SESSIONS is true.
# Updated when idle times of 180s were causing errors
Expand Down Expand Up @@ -475,6 +473,8 @@ def add_query(
auto_begin: bool = True,
bindings: Optional[Any] = None,
abridge_sql_log: bool = False,
retryable_exceptions: tuple[type[Exception], ...] = tuple(),
retry_limit: int = 1,
*,
close_cursor: bool = False,
) -> tuple[Connection, Any]:
Expand Down Expand Up @@ -707,7 +707,7 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterRe
class ExtendedSessionConnectionManager(DatabricksConnectionManager):
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
assert (
USE_LONG_SESSIONS
GlobalState.get_use_long_sessions()
), "This connection manager should only be used when USE_LONG_SESSIONS is enabled"
super().__init__(profile, mp_context)
self.threads_compute_connections: dict[
Expand Down Expand Up @@ -910,7 +910,7 @@ def open(cls, connection: Connection) -> Connection:
# Once long session management is no longer under the USE_LONG_SESSIONS toggle
# this should be renamed and replace the _open class method.
assert (
USE_LONG_SESSIONS
GlobalState.get_use_long_sessions()
), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS"

databricks_connection = cast(DatabricksDBTConnection, connection)
Expand Down Expand Up @@ -1013,15 +1013,15 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O

# If there is no node we return the http_path for the default compute.
if not query_header_context:
if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(f"Thread {thread_id}: using default compute resource.")
return creds.http_path

# Get the name of the compute resource specified in the node's config.
# If none is specified return the http_path for the default compute.
compute_name = _get_compute_name(query_header_context)
if not compute_name:
if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.")
return creds.http_path

Expand All @@ -1037,7 +1037,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O
f"does not specify http_path, relation: {relation_name}"
)

if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(
f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'."
)
Expand Down
8 changes: 3 additions & 5 deletions dbt/adapters/databricks/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
CredentialSaveError,
CredentialShardEvent,
)
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV"
# NOTE(review): [A-z] spans ASCII 65-122 and also matches "[", "\", "]",
# "^", "_" and "`"; use the explicit two-range class instead.
DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-Za-z0-9\\-]+$")
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS"
Expand Down Expand Up @@ -150,7 +150,7 @@ def validate_creds(self) -> None:

@classmethod
def get_invocation_env(cls) -> Optional[str]:
invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
invocation_env = GlobalState.get_invocation_env()
if invocation_env:
# Thrift doesn't allow nested () so we need to ensure
# that the passed user agent is valid.
Expand All @@ -160,9 +160,7 @@ def get_invocation_env(cls) -> Optional[str]:

@classmethod
def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]:
http_session_headers_str: Optional[str] = os.environ.get(
DBT_DATABRICKS_HTTP_SESSION_HEADERS
)
http_session_headers_str = GlobalState.get_http_session_headers()

http_session_headers_dict: dict[str, str] = (
{
Expand Down
58 changes: 58 additions & 0 deletions dbt/adapters/databricks/global_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
from typing import ClassVar, Optional


class GlobalState:
    """Lazily-cached accessors for the environment variables that tune
    dbt-databricks behavior.

    Global state is a bad idea, but since we don't control instantiation,
    it is better to funnel every ``os.environ`` read through one class than
    to scatter lookups across the codebase.  Each variable is read at most
    once and memoized at class level.
    """

    # Memoized DBT_DATABRICKS_LONG_SESSIONS flag (None until first read).
    _long_sessions: ClassVar[Optional[bool]] = None

    @classmethod
    def get_use_long_sessions(cls) -> bool:
        """Whether long-lived session management is enabled (default: True)."""
        if cls._long_sessions is None:
            raw = os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "True")
            cls._long_sessions = raw.upper() == "TRUE"
        return cls._long_sessions

    # Memoized DBT_DATABRICKS_INVOCATION_ENV.  A separate "loaded" flag is
    # required because None is a legitimate cached value for an unset variable.
    _invocation_env: ClassVar[Optional[str]] = None
    _invocation_env_loaded: ClassVar[bool] = False

    @classmethod
    def get_invocation_env(cls) -> Optional[str]:
        """The invocation-environment string, or None when unset."""
        if not cls._invocation_env_loaded:
            cls._invocation_env = os.environ.get("DBT_DATABRICKS_INVOCATION_ENV")
            cls._invocation_env_loaded = True
        return cls._invocation_env

    # Memoized DBT_DATABRICKS_HTTP_SESSION_HEADERS (same None-vs-unset note).
    _http_headers: ClassVar[Optional[str]] = None
    _http_headers_loaded: ClassVar[bool] = False

    @classmethod
    def get_http_session_headers(cls) -> Optional[str]:
        """The raw HTTP session-headers string, or None when unset."""
        if not cls._http_headers_loaded:
            cls._http_headers = os.environ.get("DBT_DATABRICKS_HTTP_SESSION_HEADERS")
            cls._http_headers_loaded = True
        return cls._http_headers

    # Memoized DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS flag.
    _char_limit_bypass: ClassVar[Optional[bool]] = None

    @classmethod
    def get_char_limit_bypass(cls) -> bool:
        """Whether to bypass the 2048-char describe-table limit (default: False)."""
        if cls._char_limit_bypass is None:
            raw = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False")
            cls._char_limit_bypass = raw.upper() == "TRUE"
        return cls._char_limit_bypass

    # Memoized DBT_DATABRICKS_CONNECTOR_LOG_LEVEL (default "WARN").
    _connector_log_level: ClassVar[Optional[str]] = None

    @classmethod
    def get_connector_log_level(cls) -> str:
        """Upper-cased log level for the databricks-sql-connector logger."""
        if cls._connector_log_level is None:
            raw = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN")
            cls._connector_log_level = raw.upper()
        return cls._connector_log_level
6 changes: 3 additions & 3 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import (
USE_LONG_SESSIONS,
DatabricksConnectionManager,
ExtendedSessionConnectionManager,
)
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.python_models.python_submissions import (
AllPurposeClusterPythonJobHelper,
JobClusterPythonJobHelper,
Expand Down Expand Up @@ -142,7 +142,7 @@ def get_identifier_list_string(table_names: set[str]) -> str:
"""

_identifier = "|".join(table_names)
bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false")
    bypass_2048_char_limit = GlobalState.get_char_limit_bypass()
    # NOTE(review): get_char_limit_bypass() returns a bool; the previous
    # comparison against the string "true" could never be True, silently
    # disabling the bypass. (The unit tests masked this by patching the
    # getter to return the string "true", which is also truthy.)
    if bypass_2048_char_limit:
_identifier = _identifier if len(_identifier) < 2048 else "*"
return _identifier
Expand All @@ -154,7 +154,7 @@ class DatabricksAdapter(SparkAdapter):
Relation = DatabricksRelation
Column = DatabricksColumn

if USE_LONG_SESSIONS:
if GlobalState.get_use_long_sessions():
ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager
else:
ConnectionManager = DatabricksConnectionManager
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/databricks/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from logging import Handler, LogRecord, getLogger
from typing import Union

from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("Databricks")
Expand All @@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None:
dbt_adapter_logger = AdapterLogger("databricks-sql-connector")

pysql_logger = getLogger("databricks.sql")
pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper()
pysql_logger_level = GlobalState.get_connector_log_level()
pysql_logger.setLevel(pysql_logger_level)

pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level)
Expand Down
33 changes: 22 additions & 11 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.credentials import (
CATALOG_KEY_IN_SESSION_PROPERTIES,
DBT_DATABRICKS_HTTP_SESSION_HEADERS,
DBT_DATABRICKS_INVOCATION_ENV,
)
from dbt.adapters.databricks.impl import get_identifier_list_string
from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
Expand Down Expand Up @@ -114,7 +112,10 @@ def test_invalid_custom_user_agent(self):
with pytest.raises(DbtValidationError) as excinfo:
config = self._get_config()
adapter = DatabricksAdapter(config, get_context("spawn"))
with patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
return_value="(Some-thing)",
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load

Expand All @@ -128,8 +129,9 @@ def test_custom_user_agent(self):
"dbt.adapters.databricks.connections.dbsql.connect",
new=self._connect_func(expected_invocation_env="databricks-workflows"),
):
with patch.dict(
"os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "databricks-workflows"}
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
return_value="databricks-workflows",
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
Expand Down Expand Up @@ -190,9 +192,9 @@ def _test_environment_http_headers(
"dbt.adapters.databricks.connections.dbsql.connect",
new=self._connect_func(expected_http_headers=expected_http_headers),
):
with patch.dict(
"os.environ",
**{DBT_DATABRICKS_HTTP_SESSION_HEADERS: http_headers_str},
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers",
return_value=http_headers_str,
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
Expand Down Expand Up @@ -912,7 +914,10 @@ def test_describe_table_extended_2048_char_limit(self):
assert get_identifier_list_string(table_names) == "|".join(table_names)

# If environment variable is set, then limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# Long list of table names is capped
assert get_identifier_list_string(table_names) == "*"

Expand Down Expand Up @@ -941,7 +946,10 @@ def test_describe_table_extended_should_limit(self):
table_names = set([f"customers_{i}" for i in range(200)])

# If environment variable is set, then limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# Long list of table names is capped
assert get_identifier_list_string(table_names) == "*"

Expand All @@ -954,7 +962,10 @@ def test_describe_table_extended_may_limit(self):
table_names = set([f"customers_{i}" for i in range(200)])

# If environment variable is set, then we may limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# But a short list of table names is not capped
assert get_identifier_list_string(list(table_names)[:5]) == "|".join(
list(table_names)[:5]
Expand Down

0 comments on commit 5f6412d

Please sign in to comment.