Commit 925d59f

Merge branch 'main' into 1.9.latest

benc-db committed Dec 3, 2024
2 parents d2a4e42 + 69aa772
Showing 15 changed files with 238 additions and 192 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@

- Replace array indexing with 'get' in split_part so as not to raise an exception when indexing beyond bounds ([839](https://github.com/databricks/dbt-databricks/pull/839))
- Set queue enabled for Python notebook jobs ([856](https://github.com/databricks/dbt-databricks/pull/856))
- Ensure columns that are added get backticked ([859](https://github.com/databricks/dbt-databricks/pull/859))

### Under the Hood

@@ -34,6 +35,7 @@
- Fix behavior flag use in init of DatabricksAdapter (thanks @VersusFacit!) ([836](https://github.com/databricks/dbt-databricks/pull/836))
- Restrict pydantic to V1 per dbt Labs' request ([843](https://github.com/databricks/dbt-databricks/pull/843))
- Switching to Ruff for formatting and linting ([847](https://github.com/databricks/dbt-databricks/pull/847))
- Refactoring location of DLT polling code ([849](https://github.com/databricks/dbt-databricks/pull/849))
- Switching to Hatch and pyproject.toml for project config ([853](https://github.com/databricks/dbt-databricks/pull/853))

## dbt-databricks 1.8.7 (October 10, 2024)
55 changes: 55 additions & 0 deletions dbt/adapters/databricks/api_client.py
@@ -460,6 +460,60 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str:
return response_json["run_id"]


class DltPipelineApi(PollableApi):
def __init__(self, session: Session, host: str, polling_interval: int):
super().__init__(session, host, "/api/2.0/pipelines", polling_interval, 60 * 60)

def poll_for_completion(self, pipeline_id: str) -> None:
self._poll_api(
url=f"/{pipeline_id}",
params={},
get_state_func=lambda response: response.json()["state"],
terminal_states={"IDLE", "FAILED", "DELETED"},
expected_end_state="IDLE",
unexpected_end_state_func=self._get_exception,
)

def _get_exception(self, response: Response) -> None:
response_json = response.json()
cause = response_json.get("cause")
if cause:
raise DbtRuntimeError(f"Pipeline {response_json.get('pipeline_id')} failed: {cause}")
else:
latest_update = response_json.get("latest_updates")[0]
last_error = self.get_update_error(response_json.get("pipeline_id"), latest_update)
raise DbtRuntimeError(
f"Pipeline {response_json.get('pipeline_id')} failed: {last_error}"
)

def get_update_error(self, pipeline_id: str, update_id: str) -> str:
response = self.session.get(f"/{pipeline_id}/events")
if response.status_code != 200:
raise DbtRuntimeError(
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
)

events = response.json().get("events", [])
update_events = [
e
for e in events
if e.get("event_type", "") == "update_progress"
and e.get("origin", {}).get("update_id") == update_id
]

error_events = [
e
for e in update_events
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]

msg = ""
if error_events:
msg = error_events[0].get("message", "")

return msg


class DatabricksApiClient:
def __init__(
self,
@@ -481,6 +535,7 @@ def __init__(
self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
self.workflows = WorkflowJobApi(session, host)
self.workflow_permissions = JobPermissionsApi(session, host)
self.dlt_pipelines = DltPipelineApi(session, host, polling_interval)

@staticmethod
def create(
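A minimal sketch of how the new DltPipelineApi could be exercised on its own, assuming an authenticated requests.Session and a placeholder workspace host (neither appears in this diff; PollableApi is assumed to join the host, base path, and polling settings internally):

from requests import Session

from dbt.adapters.databricks.api_client import DltPipelineApi

# Placeholders: substitute a real workspace host and access token.
session = Session()
session.headers.update({"Authorization": "Bearer <personal-access-token>"})

pipelines = DltPipelineApi(session, "dbc-example.cloud.databricks.com", polling_interval=10)

# Blocks until the pipeline reaches a terminal state; IDLE returns normally,
# while FAILED/DELETED routes through _get_exception and raises DbtRuntimeError.
pipelines.poll_for_completion("0123456789abcdef")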
16 changes: 15 additions & 1 deletion dbt/adapters/databricks/column.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Any, ClassVar, Optional

from dbt.adapters.databricks.utils import quote
from dbt.adapters.spark.column import SparkColumn


@@ -28,3 +29,16 @@ def data_type(self) -> str:

def __repr__(self) -> str:
return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)

@staticmethod
def get_name(column: dict[str, Any]) -> str:
name = column["name"]
return quote(name) if column.get("quote", False) else name

@staticmethod
def format_remove_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([quote(c.name) for c in columns])

@staticmethod
def format_add_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([f"{quote(c.name)} {c.data_type}" for c in columns])
157 changes: 6 additions & 151 deletions dbt/adapters/databricks/connections.py
@@ -17,7 +17,6 @@
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError
from dbt_common.utils import cast_to_str
from requests import Session

import databricks.sql as dbsql
from databricks.sql.client import Connection as DatabricksSQLConnection
@@ -35,7 +34,6 @@
)
from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.databricks.api_client import DatabricksApiClient
from dbt.adapters.databricks.auth import BearerAuth
from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider
from dbt.adapters.databricks.events.connection_events import (
ConnectionAcquire,
@@ -61,7 +59,6 @@
CursorCreate,
)
from dbt.adapters.databricks.events.other_events import QueryError
from dbt.adapters.databricks.events.pipeline_events import PipelineRefresh, PipelineRefreshError
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
from dbt.adapters.databricks.utils import redact_credentials
@@ -227,97 +224,6 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
bindings = [self._fix_binding(binding) for binding in bindings]
self._cursor.execute(sql, bindings)

def poll_refresh_pipeline(self, pipeline_id: str) -> None:
# interval in seconds
polling_interval = 10

# timeout in seconds
timeout = 60 * 60

stopped_states = ("COMPLETED", "FAILED", "CANCELED")
host: str = self._creds.host or ""
headers = (
self._cursor.connection.thrift_backend._auth_provider._header_factory # type: ignore
)
session = Session()
session.auth = BearerAuth(headers)
session.headers = {"User-Agent": self._user_agent}
pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the most recently created update for the pipeline
latest_update = _find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

state = latest_update.get("state")
# we use update_id to retrieve the update in the polling loop
update_id = latest_update.get("update_id", "")
prev_state = state

logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))

start = time.time()
exceeded_timeout = False
while state not in stopped_states:
if time.time() - start > timeout:
exceeded_timeout = True
break

# should we do exponential backoff?
time.sleep(polling_interval)

pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the update we are currently polling
update = _find_update(pipeline, update_id)
if not update:
raise DbtRuntimeError(
f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
)

state = update.get("state")
if state != prev_state:
logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
prev_state = state

if state == "FAILED":
logger.error(
PipelineRefreshError(
pipeline_id,
update_id,
_get_update_error_msg(session, host, pipeline_id, update_id),
)
)

# another update may have been created due to retry_on_fail settings
# get the latest update and see if it is a new one
latest_update = _find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

latest_update_id = latest_update.get("update_id", "")
if latest_update_id != update_id:
update_id = latest_update_id
state = None

if exceeded_timeout:
raise DbtRuntimeError("timed out waiting for materialized view refresh")

if state == "FAILED":
msg = _get_update_error_msg(session, host, pipeline_id, update_id)
raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}")

if state == "CANCELED":
raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled")

return

@classmethod
def findUpdate(cls, updates: list, id: str) -> Optional[dict]:
matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None

@property
def hex_query_id(self) -> str:
"""Return the hex GUID for this query
@@ -475,12 +381,15 @@ class DatabricksConnectionManager(SparkConnectionManager):
credentials_provider: Optional[TCredentialProvider] = None
_user_agent = f"dbt-databricks/{__version__}"

def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
super().__init__(profile, mp_context)
creds = cast(DatabricksCredentials, self.profile.credentials)
self.api_client = DatabricksApiClient.create(creds, 15 * 60)

def cancel_open(self) -> list[str]:
cancelled = super().cancel_open()
creds = cast(DatabricksCredentials, self.profile.credentials)
api_client = DatabricksApiClient.create(creds, 15 * 60)
logger.info("Cancelling open python jobs")
PythonRunTracker.cancel_runs(api_client)
PythonRunTracker.cancel_runs(self.api_client)
return cancelled

def compare_dbr_version(self, major: int, minor: int) -> int:
@@ -1079,60 +988,6 @@ def exponential_backoff(attempt: int) -> int:
)


def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"

response = session.get(pipeline_url)
if response.status_code != 200:
raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}")

return response.json()


def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
updates = pipeline.get("latest_updates", [])
if not updates:
raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")

if not id:
return updates[0]

matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None


def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str:
events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
response = session.get(events_url)
if response.status_code != 200:
raise DbtRuntimeError(
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
)

events = response.json().get("events", [])
update_events = [
e
for e in events
if e.get("event_type", "") == "update_progress"
and e.get("origin", {}).get("update_id") == update_id
]

error_events = [
e
for e in update_events
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]

msg = ""
if error_events:
msg = error_events[0].get("message", "")

return msg


def _get_compute_name(query_header_context: Any) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
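Paraphrasing the net effect of this file's changes (both call paths below are sketches, not verbatim code): DLT polling no longer rides on a cursor with a hand-rolled Session, and the connection manager now builds one shared API client at construction time.

# Before: polling lived on the cursor and built its own Session/Bearer auth.
#     cursor = connection.handle.cursor()
#     cursor.poll_refresh_pipeline(pipeline_id)
#
# After: DatabricksConnectionManager.__init__ creates self.api_client once,
# and callers reach polling through it.
#     manager.api_client.dlt_pipelines.poll_for_completion(pipeline_id)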
14 changes: 3 additions & 11 deletions dbt/adapters/databricks/impl.py
@@ -34,8 +34,6 @@
from dbt.adapters.databricks.connections import (
USE_LONG_SESSIONS,
DatabricksConnectionManager,
DatabricksDBTConnection,
DatabricksSQLConnectionWrapper,
ExtendedSessionConnectionManager,
)
from dbt.adapters.databricks.python_models.python_submissions import (
@@ -807,19 +805,13 @@ def get_from_relation(
"""Get the relation config from the relation."""

relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation)
connection = cast(DatabricksDBTConnection, adapter.connections.get_thread_connection())
wrapper: DatabricksSQLConnectionWrapper = connection.handle

# Ensure any current refreshes are completed before returning the relation config
tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
if tblproperties.pipeline_id:
# TODO fix this path so that it doesn't need a cursor
# It just calls APIs to poll the pipeline status
cursor = wrapper.cursor()
try:
cursor.poll_refresh_pipeline(tblproperties.pipeline_id)
finally:
cursor.close()
adapter.connections.api_client.dlt_pipelines.poll_for_completion(
tblproperties.pipeline_id
)
return relation_config


4 changes: 4 additions & 0 deletions dbt/adapters/databricks/utils.py
@@ -73,3 +73,7 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T:
if check_not_found_error(errmsg):
return default
raise e


def quote(name: str) -> str:
return f"`{name}`"
17 changes: 17 additions & 0 deletions dbt/include/databricks/macros/adapters/columns.sql
@@ -25,3 +25,20 @@

{% do return(load_result('get_columns_comments_via_information_schema').table) %}
{% endmacro %}

{% macro databricks__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %}
{% if remove_columns %}
{% if not relation.is_delta %}
{{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }}
{% endif %}
{%- call statement('alter_relation_remove_columns') -%}
ALTER TABLE {{ relation }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }})
{%- endcall -%}
{% endif %}

{% if add_columns %}
{%- call statement('alter_relation_add_columns') -%}
ALTER TABLE {{ relation }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }})
{%- endcall -%}
{% endif %}
{% endmacro %}
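A rough Python rendering of the SQL this macro emits for a hypothetical Delta relation (relation and column names are made up; in practice Jinja performs the interpolation):

added = [DatabricksColumn(column="new col", dtype="string")]
removed = [DatabricksColumn(column="legacy col", dtype="int")]

add_sql = (
    "ALTER TABLE main.sandbox.orders ADD COLUMNS "
    f"({DatabricksColumn.format_add_column_list(added)})"
)
# -> ALTER TABLE main.sandbox.orders ADD COLUMNS (`new col` string)

drop_sql = (
    "ALTER TABLE main.sandbox.orders DROP COLUMNS "
    f"({DatabricksColumn.format_remove_column_list(removed)})"
)
# -> ALTER TABLE main.sandbox.orders DROP COLUMNS (`legacy col`)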