diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4088dbc90..6d1fa6d35 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 ## dbt-databricks 1.7.x (TBD)
 
+### Under the Hood
+
+- Compatibility with dbt-spark 1.7.0b2 ([467](https://github.com/databricks/dbt-databricks/pull/467))
+
 ## dbt-databricks 1.6.6 (October 9, 2023)
 
 ### Fixes
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 3889c64f5..c6a8c2053 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -100,7 +100,8 @@ def emit(self, record: logging.LogRecord) -> None:
 
 @dataclass
 class DatabricksCredentials(Credentials):
-    database: Optional[str]  # type: ignore[assignment]
+    database: Optional[str] = None  # type: ignore[assignment]
+    schema: Optional[str] = None  # type: ignore[assignment]
     host: Optional[str] = None
     http_path: Optional[str] = None
     token: Optional[str] = None
@@ -130,7 +131,7 @@ def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]:
         return data
 
     def __post_init__(self) -> None:
-        if "." in self.schema:
+        if "." in (self.schema or ""):
             raise dbt.exceptions.DbtValidationError(
                 f"The schema should not contain '.': {self.schema}\n"
                 "If you are trying to set a catalog, please use `catalog` instead.\n"
diff --git a/dev-requirements.txt b/dev-requirements.txt
index c0c658c2c..fc895f314 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -23,8 +23,8 @@ tox>=3.2.0
 types-requests
 types-mock
 
-dbt-core==1.6.0
-dbt-tests-adapter==1.6.0
+dbt-core==1.7.0b2
+dbt-tests-adapter==1.7.0b2
 # git+https://github.com/dbt-labs/dbt-spark.git@1.5.latest#egg=dbt-spark
 # git+https://github.com/dbt-labs/dbt-core.git@1.5.latest#egg=dbt-core&subdirectory=core
 # git+https://github.com/dbt-labs/dbt-core.git@1.5.latest#egg=dbt-tests-adapter&subdirectory=tests/adapter
diff --git a/requirements.txt b/requirements.txt
index ca7034f96..8d1007a3b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 databricks-sql-connector>=2.9.3, <3.0.0
-dbt-spark==1.6.0
+dbt-spark==1.7.0b2
 databricks-sdk==0.9.0
 keyring>=23.13.0
diff --git a/setup.py b/setup.py
index 8c65a7600..29c35303c 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,7 @@ def _get_plugin_version():
     packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        "dbt-spark==1.6.0",
+        "dbt-spark==1.7.0b2",
         "databricks-sql-connector>=2.9.3, <3.0.0",
         "databricks-sdk>=0.9.0",
         "keyring>=23.13.0",
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 7772b6598..a03b81bff 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -1,9 +1,11 @@
+from typing import Any, Dict, Optional
 import unittest
 from unittest import mock
 
 from agate import Row
 import dbt.flags as flags
 import dbt.exceptions
+from dbt.config import RuntimeConfig
 
 from dbt.adapters.databricks import __version__
 from dbt.adapters.databricks import DatabricksAdapter, DatabricksRelation
@@ -33,97 +35,34 @@ def setUp(self):
             "config-version": 2,
         }
 
-    def _get_target_databricks_sql_connector(self, project):
-        return config_from_parts_or_dicts(
-            project,
-            {
-                "outputs": {
-                    "test": {
-                        "type": "databricks",
-                        "schema": "analytics",
-                        "host": "yourorg.databricks.com",
-                        "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                        "token": "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
-                        "session_properties": {"spark.sql.ansi.enabled": "true"},
-                    }
-                },
-                "target": "test",
-            },
-        )
-
-    def _get_target_databricks_sql_connector_no_token(self, project):
-        return config_from_parts_or_dicts(
-            project,
-            {
-                "outputs": {
-                    "test": {
-                        "type": "databricks",
-                        "schema": "analytics",
-                        "host": "yourorg.databricks.com",
-                        "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                        "session_properties": {"spark.sql.ansi.enabled": "true"},
-                    }
-                },
-                "target": "test",
+        self.profile_cfg = {
+            "outputs": {
+                "test": {
+                    "type": "databricks",
+                    "catalog": "main",
+                    "schema": "analytics",
+                    "host": "yourorg.databricks.com",
+                    "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
+                }
             },
-        )
+            "target": "test",
+        }
 
-    def _get_target_databricks_sql_connector_client_creds(self, project):
-        return config_from_parts_or_dicts(
-            project,
-            {
-                "outputs": {
-                    "test": {
-                        "type": "databricks",
-                        "schema": "analytics",
-                        "host": "yourorg.databricks.com",
-                        "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                        "client_id": "foo",
-                        "client_secret": "bar",
-                        "session_properties": {"spark.sql.ansi.enabled": "true"},
-                    }
-                },
-                "target": "test",
-            },
-        )
+    def _get_config(
+        self,
+        token: Optional[str] = "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
+        session_properties: Optional[Dict[str, str]] = {"spark.sql.ansi.enabled": "true"},
+        **kwargs: Any,
+    ) -> RuntimeConfig:
+        if token:
+            self.profile_cfg["outputs"]["test"]["token"] = token
+        if session_properties:
+            self.profile_cfg["outputs"]["test"]["session_properties"] = session_properties
 
-    def _get_target_databricks_sql_connector_catalog(self, project):
-        return config_from_parts_or_dicts(
-            project,
-            {
-                "outputs": {
-                    "test": {
-                        "type": "databricks",
-                        "schema": "analytics",
-                        "catalog": "main",
-                        "host": "yourorg.databricks.com",
-                        "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                        "token": "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
-                        "session_properties": {"spark.sql.ansi.enabled": "true"},
-                    }
-                },
-                "target": "test",
-            },
-        )
+        for key, val in kwargs.items():
+            self.profile_cfg["outputs"]["test"][key] = val
 
-    def _get_target_databricks_sql_connector_http_header(self, project, http_header):
-        return config_from_parts_or_dicts(
-            project,
-            {
-                "outputs": {
-                    "test": {
-                        "type": "databricks",
-                        "schema": "analytics",
-                        "host": "yourorg.databricks.com",
-                        "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                        "token": "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
-                        "session_properties": {"spark.sql.ansi.enabled": "true"},
-                        "connection_parameters": {"http_headers": http_header},
-                    }
-                },
-                "target": "test",
-            },
-        )
+        return config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
 
     def test_two_catalog_settings(self):
         with self.assertRaisesRegex(
@@ -131,25 +70,11 @@ def test_two_catalog_settings(self):
             "Got duplicate keys: \\(`databricks.catalog` in session_properties\\)"
             ' all map to "database"',
         ):
-            config_from_parts_or_dicts(
-                self.project_cfg,
-                {
-                    "outputs": {
-                        "test": {
-                            "type": "databricks",
-                            "schema": "analytics",
-                            "catalog": "main",
-                            "host": "yourorg.databricks.com",
-                            "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                            "token": "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
-                            "session_properties": {
-                                CATALOG_KEY_IN_SESSION_PROPERTIES: "catalog",
-                                "spark.sql.ansi.enabled": "true",
-                            },
-                        }
-                    },
-                    "target": "test",
-                },
+            self._get_config(
+                session_properties={
+                    CATALOG_KEY_IN_SESSION_PROPERTIES: "catalog",
+                    "spark.sql.ansi.enabled": "true",
+                }
             )
 
     def test_database_and_catalog_settings(self):
@@ -157,46 +82,14 @@ def test_database_and_catalog_settings(self):
             dbt.exceptions.DbtProfileError,
             'Got duplicate keys: \\(catalog\\) all map to "database"',
         ):
-            config_from_parts_or_dicts(
-                self.project_cfg,
-                {
-                    "outputs": {
-                        "test": {
-                            "type": "databricks",
-                            "schema": "analytics",
-                            "catalog": "main",
-                            "database": "database",
-                            "host": "yourorg.databricks.com",
-                            "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                            "token": "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
-                            "session_properties": {"spark.sql.ansi.enabled": "true"},
-                        }
-                    },
-                    "target": "test",
-                },
-            )
+            self._get_config(catalog="main", database="database")
 
     def test_reserved_connection_parameters(self):
         with self.assertRaisesRegex(
             dbt.exceptions.DbtProfileError,
             "The connection parameter `server_hostname` is reserved.",
         ):
-            config_from_parts_or_dicts(
-                self.project_cfg,
-                {
-                    "outputs": {
-                        "test": {
-                            "type": "databricks",
-                            "schema": "analytics",
-                            "host": "yourorg.databricks.com",
-                            "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123",
-                            "token": "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
-                            "connection_parameters": {"server_hostname": "theirorg.databricks.com"},
-                        }
-                    },
-                    "target": "test",
-                },
-            )
+            self._get_config(connection_parameters={"server_hostname": "theirorg.databricks.com"})
 
     def test_invalid_http_headers(self):
         def test_http_headers(http_header):
@@ -204,7 +97,7 @@ def test_http_headers(http_header):
                 dbt.exceptions.DbtProfileError,
                 "The connection parameter `http_headers` should be dict of strings.",
             ):
-                self._get_target_databricks_sql_connector_http_header(self.project_cfg, http_header)
+                self._get_config(connection_parameters={"http_headers": http_header})
 
         test_http_headers("a")
         test_http_headers(["a", "b"])
@@ -215,14 +108,14 @@ def test_invalid_custom_user_agent(self):
             dbt.exceptions.DbtValidationError,
             "Invalid invocation environment",
         ):
-            config = self._get_target_databricks_sql_connector(self.project_cfg)
+            config = self._get_config()
             adapter = DatabricksAdapter(config)
             with mock.patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
 
     def test_custom_user_agent(self):
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         adapter = DatabricksAdapter(config)
 
         with mock.patch(
@@ -282,11 +175,9 @@ def _test_environment_http_headers(
         self, http_headers_str, expected_http_headers, user_http_headers=None
     ):
         if user_http_headers:
-            config = self._get_target_databricks_sql_connector_http_header(
-                self.project_cfg, user_http_headers
-            )
+            config = self._get_config(connection_parameters={"http_headers": user_http_headers})
         else:
-            config = self._get_target_databricks_sql_connector(self.project_cfg)
+            config = self._get_config()
 
         adapter = DatabricksAdapter(config)
 
@@ -303,7 +194,7 @@ def _test_environment_http_headers(
 
     @unittest.skip("not ready")
     def test_oauth_settings(self):
-        config = self._get_target_databricks_sql_connector_no_token(self.project_cfg)
+        config = self._get_config(token=None)
 
         adapter = DatabricksAdapter(config)
 
@@ -316,7 +207,7 @@ def test_oauth_settings(self):
 
     @unittest.skip("not ready")
     def test_client_creds_settings(self):
-        config = self._get_target_databricks_sql_connector_client_creds(self.project_cfg)
+        config = self._get_config(client_id="foo", client_secret="bar")
 
         adapter = DatabricksAdapter(config)
 
@@ -330,7 +221,7 @@ def test_client_creds_settings(self):
     def _connect_func(
         self,
         *,
-        expected_catalog=None,
+        expected_catalog="main",
         expected_invocation_env=None,
         expected_http_headers=None,
         expected_no_token=None,
@@ -378,7 +269,7 @@ def test_databricks_sql_connector_connection(self):
         self._test_databricks_sql_connector_connection(self._connect_func())
 
     def _test_databricks_sql_connector_connection(self, connect):
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         adapter = DatabricksAdapter(config)
 
         with mock.patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect):
@@ -398,7 +289,6 @@ def _test_databricks_sql_connector_connection(self, connect):
                 connection.credentials.session_properties["spark.sql.ansi.enabled"],
                 "true",
             )
-            self.assertIsNone(connection.credentials.database)
 
     def test_databricks_sql_connector_catalog_connection(self):
         self._test_databricks_sql_connector_catalog_connection(
@@ -406,7 +296,7 @@ def test_databricks_sql_connector_catalog_connection(self):
         )
 
     def _test_databricks_sql_connector_catalog_connection(self, connect):
-        config = self._get_target_databricks_sql_connector_catalog(self.project_cfg)
+        config = self._get_config()
         adapter = DatabricksAdapter(config)
 
         with mock.patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect):
@@ -433,9 +323,7 @@ def test_databricks_sql_connector_http_header_connection(self):
         )
 
     def _test_databricks_sql_connector_http_header_connection(self, http_headers, connect):
-        config = self._get_target_databricks_sql_connector_http_header(
-            self.project_cfg, http_headers
-        )
+        config = self._get_config(connection_parameters={"http_headers": http_headers})
         adapter = DatabricksAdapter(config)
 
         with mock.patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect):
@@ -450,7 +338,6 @@ def _test_databricks_sql_connector_http_header_connection(self, http_headers, co
             )
             self.assertEqual(connection.credentials.token, "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
             self.assertEqual(connection.credentials.schema, "analytics")
-            self.assertIsNone(connection.credentials.database)
 
     def test_simple_catalog_relation(self):
         self.maxDiff = None
@@ -505,7 +392,7 @@ def test_parse_relation(self):
 
         input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
 
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         metadata, rows = DatabricksAdapter(config).parse_describe_extended(relation, input_cols)
 
         self.assertDictEqual(
@@ -616,7 +503,7 @@ def test_parse_relation_with_integer_owner(self):
 
         input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
 
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         _, rows = DatabricksAdapter(config).parse_describe_extended(relation, input_cols)
 
         self.assertEqual(rows[0].to_column_dict().get("table_owner"), "1234")
@@ -655,7 +542,7 @@ def test_parse_relation_with_statistics(self):
 
         input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
 
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         metadata, rows = DatabricksAdapter(config).parse_describe_extended(relation, input_cols)
 
         self.assertEqual(
@@ -705,7 +592,7 @@ def test_parse_relation_with_statistics(self):
         )
 
     def test_relation_with_database(self):
-        config = self._get_target_databricks_sql_connector_catalog(self.project_cfg)
+        config = self._get_config()
         adapter = DatabricksAdapter(config)
         r1 = adapter.Relation.create(schema="different", identifier="table")
         assert r1.database is None
@@ -744,7 +631,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
             schema="default_schema", identifier="mytable", type=rel_type
         )
 
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         columns = DatabricksAdapter(config).parse_columns_from_information(relation, information)
         self.assertEqual(len(columns), 4)
         self.assertEqual(
@@ -829,7 +716,7 @@ def test_parse_columns_from_information_with_view_type(self):
             schema="default_schema", identifier="myview", type=rel_type
         )
 
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         columns = DatabricksAdapter(config).parse_columns_from_information(relation, information)
         self.assertEqual(len(columns), 4)
         self.assertEqual(
@@ -895,7 +782,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
             schema="default_schema", identifier="mytable", type=rel_type
         )
 
-        config = self._get_target_databricks_sql_connector(self.project_cfg)
+        config = self._get_config()
         columns = DatabricksAdapter(config).parse_columns_from_information(relation, information)
         self.assertEqual(len(columns), 4)
         self.assertEqual(
@@ -961,7 +848,6 @@ def test_describe_table_extended_2048_char_limit(self):
 
         # If environment variable is set, then limit the number of characters
         with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
-
             # Long list of table names is capped
             self.assertEqual(get_identifier_list_string(table_names), "*")
 
@@ -991,7 +877,6 @@ def test_describe_table_extended_should_limit(self):
 
         # If environment variable is set, then limit the number of characters
         with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
-
             # Long list of table names is capped
             self.assertEqual(get_identifier_list_string(table_names), "*")
 
@@ -1005,7 +890,6 @@ def test_describe_table_extended_may_limit(self):
 
         # If environment variable is set, then we may limit the number of characters
        with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
-
             # But a short list of table names is not capped
             self.assertEqual(
                 get_identifier_list_string(list(table_names)[:5]), "|".join(list(table_names)[:5])