Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor reading env vars #888

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## dbt-databricks 1.9.2 (TBD)

### Under the Hood

- Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888))

## dbt-databricks 1.9.1 (December 16, 2024)

### Features
Expand Down
16 changes: 8 additions & 8 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
CursorCreate,
)
from dbt.adapters.databricks.events.other_events import QueryError
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
from dbt.adapters.databricks.utils import redact_credentials
Expand Down Expand Up @@ -86,9 +87,6 @@
DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")


# toggle for session managements that minimizes the number of sessions opened/closed
USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"

# Number of idle seconds before a connection is automatically closed. Only applicable if
# USE_LONG_SESSIONS is true.
# Updated when idle times of 180s were causing errors
Expand Down Expand Up @@ -475,6 +473,8 @@ def add_query(
auto_begin: bool = True,
bindings: Optional[Any] = None,
abridge_sql_log: bool = False,
retryable_exceptions: tuple[type[Exception], ...] = tuple(),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new interface in the latest upstream. Hopefully this doesn't break anything; otherwise we (and other adapters) will be in for a world of hurt.

retry_limit: int = 1,
*,
close_cursor: bool = False,
) -> tuple[Connection, Any]:
Expand Down Expand Up @@ -707,7 +707,7 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterRe
class ExtendedSessionConnectionManager(DatabricksConnectionManager):
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
assert (
USE_LONG_SESSIONS
GlobalState.get_use_long_sessions()
), "This connection manager should only be used when USE_LONG_SESSIONS is enabled"
super().__init__(profile, mp_context)
self.threads_compute_connections: dict[
Expand Down Expand Up @@ -910,7 +910,7 @@ def open(cls, connection: Connection) -> Connection:
# Once long session management is no longer under the USE_LONG_SESSIONS toggle
# this should be renamed and replace the _open class method.
assert (
USE_LONG_SESSIONS
GlobalState.get_use_long_sessions()
), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS"

databricks_connection = cast(DatabricksDBTConnection, connection)
Expand Down Expand Up @@ -1013,15 +1013,15 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O

# If there is no node we return the http_path for the default compute.
if not query_header_context:
if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(f"Thread {thread_id}: using default compute resource.")
return creds.http_path

# Get the name of the compute resource specified in the node's config.
# If none is specified return the http_path for the default compute.
compute_name = _get_compute_name(query_header_context)
if not compute_name:
if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.")
return creds.http_path

Expand All @@ -1037,7 +1037,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O
f"does not specify http_path, relation: {relation_name}"
)

if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(
f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'."
)
Expand Down
8 changes: 3 additions & 5 deletions dbt/adapters/databricks/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
CredentialSaveError,
CredentialShardEvent,
)
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV"
DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS"
Expand Down Expand Up @@ -150,7 +150,7 @@ def validate_creds(self) -> None:

@classmethod
def get_invocation_env(cls) -> Optional[str]:
invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
invocation_env = GlobalState.get_invocation_env()
if invocation_env:
# Thrift doesn't allow nested () so we need to ensure
# that the passed user agent is valid.
Expand All @@ -160,9 +160,7 @@ def get_invocation_env(cls) -> Optional[str]:

@classmethod
def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]:
http_session_headers_str: Optional[str] = os.environ.get(
DBT_DATABRICKS_HTTP_SESSION_HEADERS
)
http_session_headers_str = GlobalState.get_http_session_headers()

http_session_headers_dict: dict[str, str] = (
{
Expand Down
58 changes: 58 additions & 0 deletions dbt/adapters/databricks/global_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
from typing import ClassVar, Optional


class GlobalState:
    """Lazily-cached accessors for the environment variables this adapter reads.

    Global state is a bad idea, but since we don't control instantiation, better
    to have it in a single place than scattered throughout the codebase.

    Each value is read from the environment at most once and cached on the
    class, so changes to ``os.environ`` after the first read are not observed.
    Tests should therefore patch these getters rather than the environment.
    """

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Read env var *name* (falling back to *default*) as a boolean flag.

        Comparison is case-insensitive: any casing of "true" enables the flag.
        """
        return os.getenv(name, default).upper() == "TRUE"

    # Toggle for session management that minimizes the number of sessions
    # opened/closed; enabled by default.
    __use_long_sessions: ClassVar[Optional[bool]] = None

    @classmethod
    def get_use_long_sessions(cls) -> bool:
        if cls.__use_long_sessions is None:
            cls.__use_long_sessions = cls._env_flag("DBT_DATABRICKS_LONG_SESSIONS", "True")
        return cls.__use_long_sessions

    # Optional extra identifier for the connection user agent. A separate
    # "set" flag is required because the cached value may legitimately be
    # None (env var absent), so None cannot double as the sentinel.
    __invocation_env: ClassVar[Optional[str]] = None
    __invocation_env_set: ClassVar[bool] = False

    @classmethod
    def get_invocation_env(cls) -> Optional[str]:
        if not cls.__invocation_env_set:
            cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV")
            cls.__invocation_env_set = True
        return cls.__invocation_env

    # Raw string of additional HTTP session headers (parsed by the caller);
    # same None-is-valid caching scheme as the invocation env above.
    __session_headers: ClassVar[Optional[str]] = None
    __session_headers_set: ClassVar[bool] = False

    @classmethod
    def get_http_session_headers(cls) -> Optional[str]:
        if not cls.__session_headers_set:
            cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS")
            cls.__session_headers_set = True
        return cls.__session_headers

    # Opt-in toggle affecting the 2048-character identifier-list limit used
    # when describing tables (see impl.get_identifier_list_string); disabled
    # by default.
    __describe_char_bypass: ClassVar[Optional[bool]] = None

    @classmethod
    def get_char_limit_bypass(cls) -> bool:
        if cls.__describe_char_bypass is None:
            cls.__describe_char_bypass = cls._env_flag(
                "DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False"
            )
        return cls.__describe_char_bypass

    # Log level applied to the databricks-sql-connector logger; always
    # normalized to upper case, defaulting to WARN.
    __connector_log_level: ClassVar[Optional[str]] = None

    @classmethod
    def get_connector_log_level(cls) -> str:
        if cls.__connector_log_level is None:
            cls.__connector_log_level = os.getenv(
                "DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN"
            ).upper()
        return cls.__connector_log_level
6 changes: 3 additions & 3 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import (
USE_LONG_SESSIONS,
DatabricksConnectionManager,
ExtendedSessionConnectionManager,
)
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.python_models.python_submissions import (
AllPurposeClusterPythonJobHelper,
JobClusterPythonJobHelper,
Expand Down Expand Up @@ -142,7 +142,7 @@ def get_identifier_list_string(table_names: set[str]) -> str:
"""

_identifier = "|".join(table_names)
bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false")
bypass_2048_char_limit = GlobalState.get_char_limit_bypass()
if bypass_2048_char_limit == "true":
_identifier = _identifier if len(_identifier) < 2048 else "*"
return _identifier
Expand All @@ -154,7 +154,7 @@ class DatabricksAdapter(SparkAdapter):
Relation = DatabricksRelation
Column = DatabricksColumn

if USE_LONG_SESSIONS:
if GlobalState.get_use_long_sessions():
ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager
else:
ConnectionManager = DatabricksConnectionManager
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/databricks/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from logging import Handler, LogRecord, getLogger
from typing import Union

from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("Databricks")
Expand All @@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None:
dbt_adapter_logger = AdapterLogger("databricks-sql-connector")

pysql_logger = getLogger("databricks.sql")
pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper()
pysql_logger_level = GlobalState.get_connector_log_level()
pysql_logger.setLevel(pysql_logger_level)

pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level)
Expand Down
33 changes: 22 additions & 11 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.credentials import (
CATALOG_KEY_IN_SESSION_PROPERTIES,
DBT_DATABRICKS_HTTP_SESSION_HEADERS,
DBT_DATABRICKS_INVOCATION_ENV,
)
from dbt.adapters.databricks.impl import get_identifier_list_string
from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
Expand Down Expand Up @@ -114,7 +112,10 @@ def test_invalid_custom_user_agent(self):
with pytest.raises(DbtValidationError) as excinfo:
config = self._get_config()
adapter = DatabricksAdapter(config, get_context("spawn"))
with patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
return_value="(Some-thing)",
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load

Expand All @@ -128,8 +129,9 @@ def test_custom_user_agent(self):
"dbt.adapters.databricks.connections.dbsql.connect",
new=self._connect_func(expected_invocation_env="databricks-workflows"),
):
with patch.dict(
"os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "databricks-workflows"}
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
return_value="databricks-workflows",
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
Expand Down Expand Up @@ -190,9 +192,9 @@ def _test_environment_http_headers(
"dbt.adapters.databricks.connections.dbsql.connect",
new=self._connect_func(expected_http_headers=expected_http_headers),
):
with patch.dict(
"os.environ",
**{DBT_DATABRICKS_HTTP_SESSION_HEADERS: http_headers_str},
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers",
return_value=http_headers_str,
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
Expand Down Expand Up @@ -912,7 +914,10 @@ def test_describe_table_extended_2048_char_limit(self):
assert get_identifier_list_string(table_names) == "|".join(table_names)

# If environment variable is set, then limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# Long list of table names is capped
assert get_identifier_list_string(table_names) == "*"

Expand Down Expand Up @@ -941,7 +946,10 @@ def test_describe_table_extended_should_limit(self):
table_names = set([f"customers_{i}" for i in range(200)])

# If environment variable is set, then limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# Long list of table names is capped
assert get_identifier_list_string(table_names) == "*"

Expand All @@ -954,7 +962,10 @@ def test_describe_table_extended_may_limit(self):
table_names = set([f"customers_{i}" for i in range(200)])

# If environment variable is set, then we may limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# But a short list of table names is not capped
assert get_identifier_list_string(list(table_names)[:5]) == "|".join(
list(table_names)[:5]
Expand Down
Loading