From 5afc839d8f3490c4f22673a0d5b484e96bc01cf4 Mon Sep 17 00:00:00 2001 From: z3z1ma Date: Sat, 4 Jan 2025 23:34:14 -0700 Subject: [PATCH] feat: add working middleware to dbt proxy to sync alter table modify column comment stmts to manifest --- src/dbt_osmosis/sql/proxy.py | 140 ++++++++++++++++------------------- 1 file changed, 64 insertions(+), 76 deletions(-) diff --git a/src/dbt_osmosis/sql/proxy.py b/src/dbt_osmosis/sql/proxy.py index 2937732..b152d75 100644 --- a/src/dbt_osmosis/sql/proxy.py +++ b/src/dbt_osmosis/sql/proxy.py @@ -1,22 +1,25 @@ +# pyright: reportMissingTypeStubs=false, reportAny=false, reportImplicitOverride=false, reportUnknownMemberType=false, reportUnusedImport=false, reportUnknownParameterType=false """Proxy server experiment that any MySQL client (including BI tools) can connect to.""" -# pyright: reportMissingTypeStubs=false, reportAny=false, reportImplicitOverride=false, reportUnknownMemberType=false, reportUnusedImport=false, reportUnknownParameterType=false import asyncio import functools import re import typing as t from collections import defaultdict from collections.abc import Iterator +from itertools import chain from dbt.adapters.contracts.connection import AdapterResponse from mysql_mimic import MysqlServer, Session from mysql_mimic.errors import MysqlError +from mysql_mimic.results import AllowedResult from mysql_mimic.schema import ( Column, InfoSchema, dict_depth, # pyright: ignore[reportUnknownVariableType, reportPrivateLocalImportUsage] info_schema_tables, ) +from mysql_mimic.session import Query from sqlglot import exp import dbt_osmosis.core.logger as logger @@ -29,77 +32,34 @@ execute_sql_code, ) -# TODO: this doesn't capture comment body consistently -ALTER_MODIFY_COL_PATTERN = re.compile( - r""" - ^\s* # start, allow leading whitespace - ALTER\s+TABLE # "ALTER TABLE" (case-insensitive via flags=) - \s+ - (?: # optional schema part: - (?:"(?P[^"]+)" # double-quoted schema - |(?P\w+) # or unquoted schema - ) - \s*\. - )? - (?:"(?P[^"]+)" # table in double-quotes - |(?P\w+) # or unquoted table - ) - \s+ - MODIFY\s+COLUMN\s+ # "MODIFY COLUMN" (case-insensitive via flags=) - (?:"(?P[^"]+)" # column in double-quotes - |(?P\w+) # or unquoted column - ) - .*? # lazily consume anything until we might see COMMENT - (?: # optional comment group - COMMENT\s+ # must have "COMMENT" then space(s) - (["']) # capture the quote symbol in group 1 - (?P - (?:.|[^"'])* # any escaped char or anything that isn't ' or " - ) - \1 # match the same quote symbol (group 1) - )? - \s*;?\s*$ # optional whitespace, optional semicolon - """, - flags=re.IGNORECASE | re.DOTALL | re.VERBOSE, +ALTER_TABLE_MODIFY_COLUMN_COMMENT = re.compile( + r"(?i)(?:/\*.*?\*/\s*)?ALTER TABLE\s+(?:(?P[^\s\.]+)\.)?(?P
[^\s\.]+)\s+MODIFY COLUMN\s+(?P[^\s]+)\s+.*?COMMENT\s+'(?P[^']*)';?" ) +ALTER_TABLE_COMMENT = re.compile( + r"(?i)(?:/\*.*?\*/\s*)?ALTER TABLE\s+(?:(?P[^\s\.]+)\.)?(?P
[^\s\.]+)\s+COMMENT\s*=\s*'(?P[^']*)';" +) -def parse_alter_modify_column(sql: str) -> dict[str, str] | None: - """ - Attempt to parse a statement like: - ALTER TABLE schema.table MODIFY COLUMN col TYPE ... COMMENT 'some text'; - - Returns None if the pattern does not match, otherwise a dict with: - { - "schema": ... or None, - "table": ..., - "column": ..., - "comment": ... or None - } - """ - match = ALTER_MODIFY_COL_PATTERN.match(sql) - if not match: - return None - - # Because we have both quoted and unquoted named groups, pick whichever matched: - schema = match.group("schema") or match.group("schema_unquoted") - table = match.group("table") or match.group("table_unquoted") - column = match.group("column") or match.group("column_unquoted") - comment = match.group("comment") # can be None if COMMENT was not present - return {"schema": schema, "table": table, "column": column, "comment": comment} +def _regex_parse_to_complete_dict(sql: str, pattern: re.Pattern[str]) -> dict[str, str] | None: + """Parse a SQL statement using a regex pattern and return a dict with the matched groups ensuring all are present""" + if match := pattern.match(sql): + result = match.groupdict() + if all(result.values()): + return result class QueryException(MysqlError): def __init__(self, response: AdapterResponse) -> None: - self.response: AdapterResponse = response super().__init__(response._message) # pyright: ignore[reportPrivateUsage] + self.response: AdapterResponse = response class DbtSession(Session): def __init__(self, project: DbtProjectContext, *args: t.Any, **kwargs: t.Any) -> None: - self.project: DbtProjectContext = project super().__init__(*args, **kwargs) + self.project: DbtProjectContext = project + self.middlewares.append(self._alter_table_comment_middleware) def _parse(self, sql: str) -> list[exp.Expression]: if _has_jinja(sql): @@ -107,33 +67,61 @@ def _parse(self, sql: str) -> list[exp.Expression]: sql = node.compiled_code or node.raw_code return [e for e in self.dialect().parse(sql) if e] - async def query(self, expression: exp.Expression, sql: str, attrs: dict[str, t.Any]): + async def _alter_table_comment_middleware(self, q: Query) -> AllowedResult: + """Intercept ALTER TABLE ... MODIFY COLUMN ... COMMENT statements + + This middleware will update the column description in the dbt project manifest. Eventually + it could use our Yaml context class to actually write the changes to disk. + """ + if isinstance(q.expression, exp.Command): + lower_sql = q.sql.lower() + likely_alter_column_comment = all( + k in lower_sql for k in ("alter", "table", "modify", "column", "comment") + ) + if doc_update_req := ( + likely_alter_column_comment + and _regex_parse_to_complete_dict(q.sql, ALTER_TABLE_MODIFY_COLUMN_COMMENT) + ): + ref = (doc_update_req["schema"], doc_update_req["table"]) + for node in chain( + self.project.manifest.sources.values(), self.project.manifest.nodes.values() + ): + if ref == (node.schema, node.name): + for column in node.columns.values(): + if column.name == doc_update_req["column"]: + column.description = doc_update_req["comment"] + break + likely_alter_table_comment = all(k in lower_sql for k in ("alter", "table", "comment")) + if doc_update_req := ( + likely_alter_table_comment + and _regex_parse_to_complete_dict(q.sql, ALTER_TABLE_COMMENT) + ): + ref = (doc_update_req["schema"], doc_update_req["table"]) + for node in chain( + self.project.manifest.sources.values(), self.project.manifest.nodes.values() + ): + if ref == (node.schema, node.name): + node.description = doc_update_req["comment"] + return [], [] + return await q.next() + + async def query( + self, expression: exp.Expression, sql: str, attrs: dict[str, t.Any] + ) -> AllowedResult: logger.info("Query: %s", sql) - if isinstance(expression, exp.Command): - cmd = f"{expression.this} {expression.expression}" - doc_update = "alter" in sql.lower() and parse_alter_modify_column(cmd) - if doc_update: - logger.info("Will update doc: %s", doc_update) - else: - logger.info("Ignoring command %s", sql) - return (), [] # pyright: ignore[reportUnknownVariableType] resp, table = await asyncio.to_thread( execute_sql_code, self.project, expression.sql(dialect=self.project.adapter.type()) ) if resp.code: raise QueryException(resp) - logger.info(table) - return [ - t.cast(tuple[t.Any], row.values()) for row in t.cast(tuple[t.Any], table.rows.values()) - ], t.cast(tuple[str], table.column_names) + rows = t.cast(tuple[t.Any], table.rows.values()) + return [row.values() for row in rows], t.cast(tuple[str], table.column_names) async def schema(self): schema: defaultdict[str, dict[str, dict[str, tuple[str, str | None]]]] = defaultdict(dict) - for source in self.project.manifest.sources.values(): - schema[source.schema][source.name] = { - c.name: (c.data_type or "UNKOWN", c.description) for c in source.columns.values() - } - for node in self.project.manifest.nodes.values(): + for node in chain( + self.project.manifest.sources.values(), self.project.manifest.nodes.values() + ): schema[node.schema][node.name] = { c.name: (c.data_type or "UNKOWN", c.description) for c in node.columns.values() }