diff --git a/.changes/unreleased/Features-20240730-135911.yaml b/.changes/unreleased/Features-20240730-135911.yaml new file mode 100644 index 000000000..52868c2ee --- /dev/null +++ b/.changes/unreleased/Features-20240730-135911.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add support for cancelling queries on keyboard interrupt +time: 2024-07-30T13:59:11.585452-07:00 +custom: + Author: d-cole MichelleArk colin-rogers-dbt + Issue: "917" diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 0d57d22c3..cdd9d17dc 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -1,17 +1,17 @@ +from collections import defaultdict from concurrent.futures import TimeoutError import json import re from contextlib import contextmanager from dataclasses import dataclass, field - -from dbt_common.invocation import get_invocation_id - -from dbt_common.events.contextvars import get_node_info +import uuid from mashumaro.helper import pass_through from functools import lru_cache from requests.exceptions import ConnectionError -from typing import Optional, Any, Dict, Tuple, TYPE_CHECKING + +from multiprocessing.context import SpawnContext +from typing import Optional, Any, Dict, Tuple, Hashable, List, TYPE_CHECKING import google.auth import google.auth.exceptions @@ -24,19 +24,25 @@ service_account as GoogleServiceAccountCredentials, ) -from dbt.adapters.bigquery import gcloud -from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials +from dbt_common.events.contextvars import get_node_info +from dbt_common.events.functions import fire_event from dbt_common.exceptions import ( DbtRuntimeError, DbtConfigError, + DbtDatabaseError, +) +from dbt_common.invocation import get_invocation_id +from dbt.adapters.bigquery import gcloud +from dbt.adapters.contracts.connection import ( + ConnectionState, + AdapterResponse, + Credentials, + AdapterRequiredConfig, ) - -from dbt_common.exceptions import DbtDatabaseError from dbt.adapters.exceptions.connection import FailedToConnectError from dbt.adapters.base import BaseConnectionManager from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.events.types import SQLQuery -from dbt_common.events.functions import fire_event from dbt.adapters.bigquery import __version__ as dbt_version from dbt.adapters.bigquery.utility import is_base64, base64_to_string @@ -231,6 +237,10 @@ class BigQueryConnectionManager(BaseConnectionManager): DEFAULT_INITIAL_DELAY = 1.0 # Seconds DEFAULT_MAXIMUM_DELAY = 3.0 # Seconds + def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): + super().__init__(profile, mp_context) + self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list) + @classmethod def handle_error(cls, error, message): error_msg = "\n".join([item["message"] for item in error.errors]) @@ -284,11 +294,31 @@ def exception_handler(self, sql): exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip() raise DbtRuntimeError(exc_message) - def cancel_open(self) -> None: - pass + def cancel_open(self): + names = [] + this_connection = self.get_if_exists() + with self.lock: + for thread_id, connection in self.thread_connections.items(): + if connection is this_connection: + continue + if connection.handle is not None and connection.state == ConnectionState.OPEN: + client = connection.handle + for job_id in self.jobs_by_thread.get(thread_id, []): + + def fn(): + return client.cancel_job(job_id) + + self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn) + + self.close(connection) + + if connection.name is not None: + names.append(connection.name) + return names @classmethod def close(cls, connection): + connection.handle.close() connection.state = ConnectionState.CLOSED return connection @@ -452,6 +482,18 @@ def get_labels_from_query_comment(cls): return {} + def generate_job_id(self) -> str: + # Generating a fresh job_id for every _query_and_results call to avoid job_id reuse. + # Generating a job id instead of persisting a BigQuery-generated one after client.query is called. + # Using BigQuery's job_id can lead to a race condition if a job has been started and a termination + # is sent before the job_id was stored, leading to a failure to cancel the job. + # By predetermining job_ids (uuid4), we can persist the job_id before the job has been kicked off. + # Doing this, the race condition only leads to attempting to cancel a job that doesn't exist. + job_id = str(uuid.uuid4()) + thread_id = self.get_thread_identifier() + self.jobs_by_thread[thread_id].append(job_id) + return job_id + def raw_execute( self, sql, @@ -488,10 +530,13 @@ def raw_execute( job_execution_timeout = self.get_job_execution_timeout_seconds(conn) def fn(): + job_id = self.generate_job_id() + return self._query_and_results( client, sql, job_params, + job_id, job_creation_timeout=job_creation_timeout, job_execution_timeout=job_execution_timeout, limit=limit, @@ -731,6 +776,7 @@ def _query_and_results( client, sql, job_params, + job_id, job_creation_timeout=None, job_execution_timeout=None, limit: Optional[int] = None, @@ -738,7 +784,9 @@ def _query_and_results( """Query the client and wait for results.""" # Cannot reuse job_config if destination is set and ddl is used job_config = google.cloud.bigquery.QueryJobConfig(**job_params) - query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout) + query_job = client.query( + query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout + ) if ( query_job.location is not None and query_job.job_id is not None diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index a1aaf17eb..0b49c0373 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -164,7 +164,7 @@ def date_function(cls) -> str: @classmethod def is_cancelable(cls) -> bool: - return False + return True def drop_relation(self, relation: BigQueryRelation) -> None: is_cached = self._schema_is_cached(relation.database, relation.schema) @@ -693,8 +693,11 @@ def load_dataframe( load_config.skip_leading_rows = 1 load_config.schema = bq_schema load_config.field_delimiter = field_delimiter + job_id = self.connections.generate_job_id() with open(agate_table.original_abspath, "rb") as f: # type: ignore - job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config) + job = client.load_table_from_file( + f, table_ref, rewind=True, job_config=load_config, job_id=job_id + ) timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300 with self.connections.exception_handler("LOAD TABLE"): diff --git a/tests/functional/test_cancel.py b/tests/functional/test_cancel.py new file mode 100644 index 000000000..50306a6ae --- /dev/null +++ b/tests/functional/test_cancel.py @@ -0,0 +1,127 @@ +import time + +import os +import signal +import subprocess + +import pytest + +from dbt.tests.util import get_connection + +_SEED_CSV = """ +id, name, astrological_sign, moral_alignment +1, Alice, Aries, Lawful Good +2, Bob, Taurus, Neutral Good +3, Thaddeus, Gemini, Chaotic Neutral +4, Zebulon, Cancer, Lawful Evil +5, Yorick, Leo, True Neutral +6, Xavier, Virgo, Chaotic Evil +7, Wanda, Libra, Lawful Neutral +""" + +_LONG_RUNNING_MODEL_SQL = """ + {{ config(materialized='table') }} + with array_1 as ( + select generated_ids from UNNEST(GENERATE_ARRAY(1, 200000)) AS generated_ids + ), + array_2 as ( + select generated_ids from UNNEST(GENERATE_ARRAY(2, 200000)) AS generated_ids + ) + + SELECT array_1.generated_ids + FROM array_1 + LEFT JOIN array_1 as jnd on 1=1 + LEFT JOIN array_2 as jnd2 on 1=1 + LEFT JOIN array_1 as jnd3 on jnd3.generated_ids >= jnd2.generated_ids +""" + + +def _get_info_schema_jobs_query(project_id, dataset_id, table_id): + """ + Running this query requires roles/bigquery.resourceViewer on the project, + see: https://cloud.google.com/bigquery/docs/information-schema-jobs#required_role + :param project_id: + :param dataset_id: + :param table_id: + :return: a single job id that matches the model we tried to create and was cancelled + """ + return f""" + SELECT job_id + FROM `region-us`.`INFORMATION_SCHEMA.JOBS_BY_PROJECT` + WHERE creation_time > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 5 HOUR) + AND statement_type = 'CREATE_TABLE_AS_SELECT' + AND state = 'DONE' + AND job_id IS NOT NULL + AND project_id = '{project_id}' + AND error_result.reason = 'stopped' + AND error_result.message = 'Job execution was cancelled: User requested cancellation' + AND destination_table.table_id = '{table_id}' + AND destination_table.dataset_id = '{dataset_id}' + """ + + +def _run_dbt_in_subprocess(project, dbt_command): + os.chdir(project.project_root) + run_dbt_process = subprocess.Popen( + [ + "dbt", + dbt_command, + "--profiles-dir", + project.profiles_dir, + "--project-dir", + project.project_root, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + ) + std_out_log = "" + while True: + std_out_line = run_dbt_process.stdout.readline().decode("utf-8") + std_out_log += std_out_line + if std_out_line != "": + print(std_out_line) + if "1 of 1 START" in std_out_line: + time.sleep(1) + run_dbt_process.send_signal(signal.SIGINT) + + if run_dbt_process.poll(): + break + + return std_out_log + + +def _get_job_id(project, table_name): + # Because we run this in a subprocess we have to actually call Bigquery and look up the job id + with get_connection(project.adapter): + job_id = project.run_sql( + _get_info_schema_jobs_query(project.database, project.test_schema, table_name) + ) + + return job_id + + +class TestBigqueryCancelsQueriesOnKeyboardInterrupt: + @pytest.fixture(scope="class", autouse=True) + def models(self): + return { + "model.sql": _LONG_RUNNING_MODEL_SQL, + } + + @pytest.fixture(scope="class", autouse=True) + def seeds(self): + return { + "seed.csv": _SEED_CSV, + } + + def test_bigquery_cancels_queries_for_model_on_keyboard_interrupt(self, project): + std_out_log = _run_dbt_in_subprocess(project, "run") + + assert "CANCEL query model.test.model" in std_out_log + assert len(_get_job_id(project, "model")) == 1 + + def test_bigquery_cancels_queries_for_seed_on_keyboard_interrupt(self, project): + std_out_log = _run_dbt_in_subprocess(project, "seed") + + assert "CANCEL query seed.test.seed" in std_out_log + # we can't assert the job id since we can't kill the seed process fast enough to cancel it diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index 19d9dbd08..a922525fd 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -32,6 +32,7 @@ inject_adapter, TestAdapterConversions, load_internal_manifest_macros, + mock_connection, ) @@ -368,23 +369,22 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection): def test_cancel_open_connections_empty(self): adapter = self.get_adapter("oauth") - self.assertEqual(adapter.cancel_open_connections(), None) + self.assertEqual(len(list(adapter.cancel_open_connections())), 0) def test_cancel_open_connections_master(self): adapter = self.get_adapter("oauth") - adapter.connections.thread_connections[0] = object() - self.assertEqual(adapter.cancel_open_connections(), None) + key = adapter.connections.get_thread_identifier() + adapter.connections.thread_connections[key] = mock_connection("master") + self.assertEqual(len(list(adapter.cancel_open_connections())), 0) def test_cancel_open_connections_single(self): adapter = self.get_adapter("oauth") - adapter.connections.thread_connections.update( - { - 0: object(), - 1: object(), - } - ) - # actually does nothing - self.assertEqual(adapter.cancel_open_connections(), None) + master = mock_connection("master") + model = mock_connection("model") + key = adapter.connections.get_thread_identifier() + + adapter.connections.thread_connections.update({key: master, 1: model}) + self.assertEqual(len(list(adapter.cancel_open_connections())), 1) @patch("dbt.adapters.bigquery.impl.google.auth.default") @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 9dc8fe219..d09cb1635 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -104,18 +104,18 @@ def test_drop_dataset(self): @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") def test_query_and_results(self, mock_bq): - self.mock_client.query = Mock(return_value=Mock(state="DONE")) self.connections._query_and_results( self.mock_client, "sql", {"job_param_1": "blah"}, + job_id=1, job_creation_timeout=15, - job_execution_timeout=3, + job_execution_timeout=100, ) mock_bq.QueryJobConfig.assert_called_once() self.mock_client.query.assert_called_once_with( - query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15 + query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15 ) def test_copy_bq_table_appends(self):