From 41092ba6fbc2a75da544a313985b1de8e993996f Mon Sep 17 00:00:00 2001 From: eric wang Date: Wed, 30 Oct 2024 23:21:15 -0700 Subject: [PATCH] update --- dbt/adapters/databricks/credentials.py | 3 +- tests/unit/python/test_python_submissions.py | 250 --------- tests/unit/test_adapter.py | 504 ++++++++++++++++++- 3 files changed, 485 insertions(+), 272 deletions(-) delete mode 100644 tests/unit/python/test_python_submissions.py diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 3fbd7a27..4346a403 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -1,4 +1,3 @@ -from http import client from collections.abc import Iterable import itertools import json @@ -8,7 +7,7 @@ from dataclasses import dataclass from dataclasses import field from typing import Any -from typing import Callable +from typing import Callable, Dict, List from typing import cast from typing import Optional from typing import Tuple diff --git a/tests/unit/python/test_python_submissions.py b/tests/unit/python/test_python_submissions.py deleted file mode 100644 index 7a230579..00000000 --- a/tests/unit/python/test_python_submissions.py +++ /dev/null @@ -1,250 +0,0 @@ -from mock import patch -from unittest.mock import Mock - -from dbt.adapters.databricks.credentials import DatabricksCredentials -from dbt.adapters.databricks.python_models.python_submissions import BaseDatabricksHelper -from dbt.adapters.databricks.python_models.python_submissions import WorkflowPythonJobHelper - - -# class TestDatabricksPythonSubmissions: -# def test_start_cluster_returns_on_receiving_running_state(self): -# session_mock = Mock() -# # Mock the start command -# post_mock = Mock() -# post_mock.status_code = 200 -# session_mock.post.return_value = post_mock -# # Mock the status command -# get_mock = Mock() -# get_mock.status_code = 200 -# get_mock.json.return_value = {"state": "RUNNING"} -# session_mock.get.return_value = get_mock - -# context = DBContext(Mock(), None, None, session_mock) -# context.start_cluster() - -# session_mock.get.assert_called_once() - - -class DatabricksTestHelper(BaseDatabricksHelper): - def __init__(self, parsed_model: dict, credentials: DatabricksCredentials): - self.parsed_model = parsed_model - self.credentials = credentials - self.job_grants = self.workflow_spec.get("grants", {}) - - -@patch("dbt.adapters.databricks.credentials.Config") -class TestAclUpdate: - def test_empty_acl_empty_config(self, _): - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - assert helper._update_with_acls({}) == {} - - def test_empty_acl_non_empty_config(self, _): - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - assert helper._update_with_acls({"a": "b"}) == {"a": "b"} - - def test_non_empty_acl_empty_config(self, _): - expected_access_control = { - "access_control_list": [ - {"user_name": "user2", "permission_level": "CAN_VIEW"}, - ] - } - helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials()) - assert helper._update_with_acls({}) == expected_access_control - - def test_non_empty_acl_non_empty_config(self, _): - expected_access_control = { - "access_control_list": [ - {"user_name": "user2", "permission_level": "CAN_VIEW"}, - ] - } - helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials()) - assert helper._update_with_acls({"a": "b"}) == { - "a": "b", - "access_control_list": expected_access_control["access_control_list"], - } - - 
-class TestJobGrants: - - @patch.object(BaseDatabricksHelper, "_build_job_owner") - def test_job_owner_user(self, mock_job_owner): - mock_job_owner.return_value = ("alighodsi@databricks.com", "user_name") - - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - helper.job_grants = {} - - assert helper._build_job_permissions() == [ - { - "permission_level": "IS_OWNER", - "user_name": "alighodsi@databricks.com", - } - ] - - @patch.object(BaseDatabricksHelper, "_build_job_owner") - def test_job_owner_service_principal(self, mock_job_owner): - mock_job_owner.return_value = ( - "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - "service_principal_name", - ) - - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - helper.job_grants = {} - - assert helper._build_job_permissions() == [ - { - "permission_level": "IS_OWNER", - "service_principal_name": "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - } - ] - - @patch.object(BaseDatabricksHelper, "_build_job_owner") - def test_job_grants(self, mock_job_owner): - mock_job_owner.return_value = ( - "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - "service_principal_name", - ) - helper = DatabricksTestHelper( - { - "config": { - "workflow_job_config": { - "grants": { - "view": [ - {"user_name": "reynoldxin@databricks.com"}, - {"user_name": "alighodsi@databricks.com"}, - ], - "run": [{"group_name": "dbt-developers"}], - "manage": [{"group_name": "dbt-admins"}], - } - } - } - }, - DatabricksCredentials(), - ) - - actual = helper._build_job_permissions() - - expected_owner = { - "service_principal_name": "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - "permission_level": "IS_OWNER", - } - expected_viewer_1 = { - "permission_level": "CAN_VIEW", - "user_name": "reynoldxin@databricks.com", - } - expected_viewer_2 = { - "permission_level": "CAN_VIEW", - "user_name": "alighodsi@databricks.com", - } - expected_runner = {"permission_level": "CAN_MANAGE_RUN", "group_name": "dbt-developers"} - expected_manager = {"permission_level": "CAN_MANAGE", "group_name": "dbt-admins"} - - assert expected_owner in actual - assert expected_viewer_1 in actual - assert expected_viewer_2 in actual - assert expected_runner in actual - assert expected_manager in actual - - -class TestWorkflowConfig: - def default_config(self): - return { - "alias": "test_model", - "database": "test_database", - "schema": "test_schema", - "config": { - "workflow_job_config": { - "email_notifications": "test@example.com", - "max_retries": 2, - "timeout_seconds": 500, - }, - "job_cluster_config": { - "spark_version": "15.3.x-scala2.12", - "node_type_id": "rd-fleet.2xlarge", - "autoscale": {"min_workers": 1, "max_workers": 2}, - }, - }, - } - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_default(self, mock_api_client): - job = WorkflowPythonJobHelper(self.default_config(), Mock()) - result = job._build_job_spec() - - assert result["name"] == "dbt__test_database-test_schema-test_model" - assert len(result["tasks"]) == 1 - - task = result["tasks"][0] - assert task["task_key"] == "inner_notebook" - assert task["new_cluster"]["spark_version"] == "15.3.x-scala2.12" - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_custom_name(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["name"] = "custom_job_name" - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - assert result["name"] == 
"custom_job_name" - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_existing_cluster(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["existing_cluster_id"] = "cluster-123" - del config["config"]["job_cluster_config"] - - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - task = result["tasks"][0] - assert task["existing_cluster_id"] == "cluster-123" - assert "new_cluster" not in task - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_serverless(self, mock_api_client): - config = self.default_config() - del config["config"]["job_cluster_config"] - - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - task = result["tasks"][0] - assert "existing_cluster_id" not in task - assert "new_cluster" not in task - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_with_additional_task_settings(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["additional_task_settings"] = { - "task_key": "my_dbt_task" - } - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - task = result["tasks"][0] - assert task["task_key"] == "my_dbt_task" - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_with_post_hooks(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["post_hook_tasks"] = [ - { - "depends_on": [{"task_key": "inner_notebook"}], - "task_key": "task_b", - "notebook_task": { - "notebook_path": "/Workspace/Shared/test_notebook", - "source": "WORKSPACE", - }, - "new_cluster": { - "spark_version": "14.3.x-scala2.12", - "node_type_id": "rd-fleet.2xlarge", - "autoscale": {"min_workers": 1, "max_workers": 2}, - }, - } - ] - - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - assert len(result["tasks"]) == 2 - assert result["tasks"][1]["task_key"] == "task_b" - assert result["tasks"][1]["new_cluster"]["spark_version"] == "14.3.x-scala2.12" diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 9428f3e2..abdea832 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -246,9 +246,10 @@ def connect( ): assert server_hostname == "yourorg.databricks.com" assert http_path == "sql/protocolv1/o/1234567890123456/1234-567890-test123" - if not (expected_no_token or expected_client_creds): - assert credentials_provider._token == "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + if not (expected_no_token or expected_client_creds): + k = credentials_provider()() + assert credentials_provider()().get("Authorization") == "Bearer dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" if expected_client_creds: assert kwargs.get("client_id") == "foo" assert kwargs.get("client_secret") == "bar" @@ -540,23 +541,486 @@ def test_parse_relation(self): "comment": None, } - def test_non_empty_acl_empty_config(self, _): - expected_access_control = { - "access_control_list": [ - {"user_name": "user2", "permission_level": "CAN_VIEW"}, - ] + def test_parse_relation_with_integer_owner(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.Table + + relation = DatabricksRelation.create( + schema="default_schema", identifier="mytable", type=rel_type + ) + assert relation.database is None 
+ + # Mimics the output of Spark with a DESCRIBE TABLE EXTENDED + plain_rows = [ + ("col1", "decimal(22,0)", "comment"), + ("# Detailed Table Information", None, None), + ("Owner", 1234, None), + ] + + input_cols = [Row(keys=["col_name", "data_type", "comment"], values=r) for r in plain_rows] + + config = self._get_config() + _, rows = DatabricksAdapter(config, get_context("spawn")).parse_describe_extended( + relation, input_cols + ) + + assert rows[0].to_column_dict().get("table_owner") == "1234" + + def test_parse_relation_with_statistics(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.Table + + relation = DatabricksRelation.create( + schema="default_schema", identifier="mytable", type=rel_type + ) + assert relation.database is None + + # Mimics the output of Spark with a DESCRIBE TABLE EXTENDED + plain_rows = [ + ("col1", "decimal(22,0)", "comment"), + ("# Partition Information", "data_type", None), + (None, None, None), + ("# Detailed Table Information", None, None), + ("Database", None, None), + ("Owner", "root", None), + ("Created Time", "Wed Feb 04 18:15:00 UTC 1815", None), + ("Last Access", "Wed May 20 19:25:00 UTC 1925", None), + ("Comment", "Table model description", None), + ("Statistics", "1109049927 bytes, 14093476 rows", None), + ("Type", "MANAGED", None), + ("Provider", "delta", None), + ("Location", "/mnt/vo", None), + ( + "Serde Library", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", + None, + ), + ("InputFormat", "org.apache.hadoop.mapred.SequenceFileInputFormat", None), + ( + "OutputFormat", + "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat", + None, + ), + ("Partition Provider", "Catalog", None), + ] + + input_cols = [Row(keys=["col_name", "data_type", "comment"], values=r) for r in plain_rows] + + config = self._get_config() + metadata, rows = DatabricksAdapter(config, get_context("spawn")).parse_describe_extended( + relation, input_cols + ) + + assert metadata == { + None: None, + "# Detailed Table Information": None, + "Database": None, + "Owner": "root", + "Created Time": "Wed Feb 04 18:15:00 UTC 1815", + "Last Access": "Wed May 20 19:25:00 UTC 1925", + "Comment": "Table model description", + "Statistics": "1109049927 bytes, 14093476 rows", + "Type": "MANAGED", + "Provider": "delta", + "Location": "/mnt/vo", + "Serde Library": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", + "InputFormat": "org.apache.hadoop.mapred.SequenceFileInputFormat", + "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat", + "Partition Provider": "Catalog", + } + + assert len(rows) == 1 + assert rows[0].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": "Table model description", + "column": "col1", + "column_index": 0, + "comment": "comment", + "dtype": "decimal(22,0)", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1109049927, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 14093476, + } + + def test_relation_with_database(self): + config = self._get_config() + adapter = DatabricksAdapter(config, get_context("spawn")) + r1 = adapter.Relation.create(schema="different", identifier="table") + assert r1.database is None + r2 = 
adapter.Relation.create(database="something", schema="different", identifier="table") + assert r2.database == "something" + + def test_parse_columns_from_information_with_table_type_and_delta_provider(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.Table + + # Mimics the output of Spark in the information column + information = ( + "Database: default_schema\n" + "Table: mytable\n" + "Owner: root\n" + "Created Time: Wed Feb 04 18:15:00 UTC 1815\n" + "Last Access: Wed May 20 19:25:00 UTC 1925\n" + "Created By: Spark 3.0.1\n" + "Type: MANAGED\n" + "Provider: delta\n" + "Statistics: 123456789 bytes\n" + "Location: /mnt/vo\n" + "Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe\n" + "InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat\n" + "OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat\n" + "Partition Provider: Catalog\n" + "Partition Columns: [`dt`]\n" + "Schema: root\n" + " |-- col1: decimal(22,0) (nullable = true)\n" + " |-- col2: string (nullable = true)\n" + " |-- dt: date (nullable = true)\n" + " |-- struct_col: struct (nullable = true)\n" + " | |-- struct_inner_col: string (nullable = true)\n" + ) + relation = DatabricksRelation.create( + schema="default_schema", identifier="mytable", type=rel_type + ) + + config = self._get_config() + columns = DatabricksAdapter(config, get_context("spawn")).parse_columns_from_information( + relation, information + ) + assert len(columns) == 4 + assert columns[0].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": None, + "column": "col1", + "column_index": 0, + "dtype": "decimal(22,0)", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 123456789, + "comment": None, + } + + assert columns[3].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": None, + "column": "struct_col", + "column_index": 3, + "dtype": "struct", + "comment": None, + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 123456789, + } + + def test_parse_columns_from_information_with_view_type(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.View + information = ( + "Database: default_schema\n" + "Table: myview\n" + "Owner: root\n" + "Created Time: Wed Feb 04 18:15:00 UTC 1815\n" + "Last Access: UNKNOWN\n" + "Created By: Spark 3.0.1\n" + "Type: VIEW\n" + "View Text: WITH base (\n" + " SELECT * FROM source_table\n" + ")\n" + "SELECT col1, col2, dt FROM base\n" + "View Original Text: WITH base (\n" + " SELECT * FROM source_table\n" + ")\n" + "SELECT col1, col2, dt FROM base\n" + "View Catalog and Namespace: spark_catalog.default\n" + "View Query Output Columns: [col1, col2, dt]\n" + "Table Properties: [view.query.out.col.1=col1, view.query.out.col.2=col2, " + "transient_lastDdlTime=1618324324, view.query.out.col.3=dt, " + "view.catalogAndNamespace.part.0=spark_catalog, " + "view.catalogAndNamespace.part.1=default]\n" + "Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe\n" + "InputFormat: 
org.apache.hadoop.mapred.SequenceFileInputFormat\n" + "OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat\n" + "Storage Properties: [serialization.format=1]\n" + "Schema: root\n" + " |-- col1: decimal(22,0) (nullable = true)\n" + " |-- col2: string (nullable = true)\n" + " |-- dt: date (nullable = true)\n" + " |-- struct_col: struct (nullable = true)\n" + " | |-- struct_inner_col: string (nullable = true)\n" + ) + relation = DatabricksRelation.create( + schema="default_schema", identifier="myview", type=rel_type + ) + + config = self._get_config() + columns = DatabricksAdapter(config, get_context("spawn")).parse_columns_from_information( + relation, information + ) + assert len(columns) == 4 + assert columns[1].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": None, + "column": "col2", + "column_index": 1, + "comment": None, + "dtype": "string", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + } + + assert columns[3].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": None, + "column": "struct_col", + "column_index": 3, + "comment": None, + "dtype": "struct", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + } + + def test_parse_columns_from_information_with_table_type_and_parquet_provider(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.Table + + information = ( + "Database: default_schema\n" + "Table: mytable\n" + "Owner: root\n" + "Created Time: Wed Feb 04 18:15:00 UTC 1815\n" + "Last Access: Wed May 20 19:25:00 UTC 1925\n" + "Created By: Spark 3.0.1\n" + "Type: MANAGED\n" + "Provider: parquet\n" + "Statistics: 1234567890 bytes, 12345678 rows\n" + "Location: /mnt/vo\n" + "Serde Library: org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe\n" + "InputFormat: org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat\n" + "OutputFormat: org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat\n" + "Schema: root\n" + " |-- col1: decimal(22,0) (nullable = true)\n" + " |-- col2: string (nullable = true)\n" + " |-- dt: date (nullable = true)\n" + " |-- struct_col: struct (nullable = true)\n" + " | |-- struct_inner_col: string (nullable = true)\n" + ) + relation = DatabricksRelation.create( + schema="default_schema", identifier="mytable", type=rel_type + ) + + config = self._get_config() + columns = DatabricksAdapter(config, get_context("spawn")).parse_columns_from_information( + relation, information + ) + assert len(columns) == 4 + assert columns[2].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": None, + "column": "dt", + "column_index": 2, + "comment": None, + "dtype": "date", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1234567890, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 12345678, + } + + assert columns[3].to_column_dict(omit_none=False) == { + "table_database": None, + "table_schema": relation.schema, + 
"table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "table_comment": None, + "column": "struct_col", + "column_index": 3, + "comment": None, + "dtype": "struct", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1234567890, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 12345678, + } + + def test_describe_table_extended_2048_char_limit(self): + """GIVEN a list of table_names whos total character length exceeds 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is "true" + THEN the identifier list is replaced with "*" + """ + + table_names = set([f"customers_{i}" for i in range(200)]) + + # By default, don't limit the number of characters + assert get_identifier_list_string(table_names) == "|".join(table_names) + + # If environment variable is set, then limit the number of characters + with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + # Long list of table names is capped + assert get_identifier_list_string(table_names) == "*" + + # Short list of table names is not capped + assert get_identifier_list_string(list(table_names)[:5]) == "|".join( + list(table_names)[:5] + ) + + def test_describe_table_extended_should_not_limit(self): + """GIVEN a list of table_names whos total character length exceeds 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is not set + THEN the identifier list is not truncated + """ + + table_names = set([f"customers_{i}" for i in range(200)]) + + # By default, don't limit the number of characters + assert get_identifier_list_string(table_names) == "|".join(table_names) + + def test_describe_table_extended_should_limit(self): + """GIVEN a list of table_names whos total character length exceeds 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is "true" + THEN the identifier list is replaced with "*" + """ + + table_names = set([f"customers_{i}" for i in range(200)]) + + # If environment variable is set, then limit the number of characters + with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + # Long list of table names is capped + assert get_identifier_list_string(table_names) == "*" + + def test_describe_table_extended_may_limit(self): + """GIVEN a list of table_names whos total character length does not 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is "true" + THEN the identifier list is not truncated + """ + + table_names = set([f"customers_{i}" for i in range(200)]) + + # If environment variable is set, then we may limit the number of characters + with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + # But a short list of table names is not capped + assert get_identifier_list_string(list(table_names)[:5]) == "|".join( + list(table_names)[:5] + ) + + +class TestCheckNotFound: + def test_prefix(self): + assert check_not_found_error("Runtime error \n Database 'dbt' not found") + + def test_no_prefix_or_suffix(self): + assert check_not_found_error("Database not found") + + def test_quotes(self): + assert check_not_found_error("Database '`dbt`' not found") + + def test_suffix(self): + assert check_not_found_error("Database not found and \n foo") + + def test_error_condition(self): + 
assert check_not_found_error("[SCHEMA_NOT_FOUND]") + + def test_unexpected_error(self): + assert not check_not_found_error("[DATABASE_NOT_FOUND]") + assert not check_not_found_error("Schema foo not found") + assert not check_not_found_error("Database 'foo' not there") + + +class TestGetPersistDocColumns(DatabricksAdapterBase): + @pytest.fixture + def adapter(self, setUp) -> DatabricksAdapter: + return DatabricksAdapter(self._get_config(), get_context("spawn")) + + def create_column(self, name, comment) -> DatabricksColumn: + return DatabricksColumn( + column=name, + dtype="string", + comment=comment, + ) + + def test_get_persist_doc_columns_empty(self, adapter): + assert adapter.get_persist_doc_columns([], {}) == {} + + def test_get_persist_doc_columns_no_match(self, adapter): + existing = [self.create_column("col1", "comment1")] + column_dict = {"col2": {"name": "col2", "description": "comment2"}} + assert adapter.get_persist_doc_columns(existing, column_dict) == {} + + def test_get_persist_doc_columns_full_match(self, adapter): + existing = [self.create_column("col1", "comment1")] + column_dict = {"col1": {"name": "col1", "description": "comment1"}} + assert adapter.get_persist_doc_columns(existing, column_dict) == {} + + def test_get_persist_doc_columns_partial_match(self, adapter): + existing = [self.create_column("col1", "comment1")] + column_dict = {"col1": {"name": "col1", "description": "comment2"}} + assert adapter.get_persist_doc_columns(existing, column_dict) == column_dict + + def test_get_persist_doc_columns_mixed(self, adapter): + existing = [ + self.create_column("col1", "comment1"), + self.create_column("col2", "comment2"), + ] + column_dict = { + "col1": {"name": "col1", "description": "comment2"}, + "col2": {"name": "col2", "description": "comment2"}, } - helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials()) - assert helper._update_with_acls({}) == expected_access_control - - def test_non_empty_acl_non_empty_config(self, _): - expected_access_control = { - "access_control_list": [ - {"user_name": "user2", "permission_level": "CAN_VIEW"}, - ] + expected = { + "col1": {"name": "col1", "description": "comment2"}, } - helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials()) - assert helper._update_with_acls({"a": "b"}) == { - "a": "b", - "access_control_list": expected_access_control["access_control_list"], - } \ No newline at end of file + assert adapter.get_persist_doc_columns(existing, column_dict) == expected \ No newline at end of file