Skip to content

Commit

Permalink
feat: add working middleware to dbt proxy to sync alter table modify …
Browse files Browse the repository at this point in the history
…column comment stmts to manifest
  • Loading branch information
z3z1ma committed Jan 5, 2025
1 parent f6051bf commit 5afc839
Showing 1 changed file with 64 additions and 76 deletions.
140 changes: 64 additions & 76 deletions src/dbt_osmosis/sql/proxy.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
# pyright: reportMissingTypeStubs=false, reportAny=false, reportImplicitOverride=false, reportUnknownMemberType=false, reportUnusedImport=false, reportUnknownParameterType=false
"""Proxy server experiment that any MySQL client (including BI tools) can connect to."""

# pyright: reportMissingTypeStubs=false, reportAny=false, reportImplicitOverride=false, reportUnknownMemberType=false, reportUnusedImport=false, reportUnknownParameterType=false
import asyncio
import functools
import re
import typing as t
from collections import defaultdict
from collections.abc import Iterator
from itertools import chain

from dbt.adapters.contracts.connection import AdapterResponse
from mysql_mimic import MysqlServer, Session
from mysql_mimic.errors import MysqlError
from mysql_mimic.results import AllowedResult
from mysql_mimic.schema import (
Column,
InfoSchema,
dict_depth, # pyright: ignore[reportUnknownVariableType, reportPrivateLocalImportUsage]
info_schema_tables,
)
from mysql_mimic.session import Query
from sqlglot import exp

import dbt_osmosis.core.logger as logger
Expand All @@ -29,111 +32,96 @@
execute_sql_code,
)

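# TODO: this doesn't capture the comment body consistently: named groups are numbered too, so the
# \1 backreference in the pattern resolves to the quoted-schema group, not the opening quote.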
ALTER_MODIFY_COL_PATTERN = re.compile(
r"""
^\s* # start, allow leading whitespace
ALTER\s+TABLE # "ALTER TABLE" (case-insensitive via flags=)
\s+
(?: # optional schema part:
(?:"(?P<schema>[^"]+)" # double-quoted schema
|(?P<schema_unquoted>\w+) # or unquoted schema
)
\s*\.
)?
(?:"(?P<table>[^"]+)" # table in double-quotes
|(?P<table_unquoted>\w+) # or unquoted table
)
\s+
MODIFY\s+COLUMN\s+ # "MODIFY COLUMN" (case-insensitive via flags=)
(?:"(?P<column>[^"]+)" # column in double-quotes
|(?P<column_unquoted>\w+) # or unquoted column
)
.*? # lazily consume anything until we might see COMMENT
(?: # optional comment group
COMMENT\s+ # must have "COMMENT" then space(s)
(["']) # capture the quote symbol in group 1
(?P<comment>
(?:.|[^"'])* # any escaped char or anything that isn't ' or "
)
\1 # match the same quote symbol (group 1)
)?
\s*;?\s*$ # optional whitespace, optional semicolon
""",
flags=re.IGNORECASE | re.DOTALL | re.VERBOSE,
# Matches `ALTER TABLE [schema.]table MODIFY COLUMN column ... COMMENT '<text>'[;]`
# case-insensitively, optionally preceded by a /* ... */ block comment. The comment text is
# captured up to the next single quote, so it cannot itself contain a single quote.
ALTER_TABLE_MODIFY_COLUMN_COMMENT = re.compile(
    r"(?i)(?:/\*.*?\*/\s*)?ALTER TABLE\s+(?:(?P<schema>[^\s\.]+)\.)?(?P<table>[^\s\.]+)\s+MODIFY COLUMN\s+(?P<column>[^\s]+)\s+.*?COMMENT\s+'(?P<comment>[^']*)';?"
)

# Matches `ALTER TABLE [schema.]table COMMENT = '<text>'[;]` case-insensitively, optionally
# preceded by a /* ... */ block comment. The trailing semicolon is optional, consistent with
# ALTER_TABLE_MODIFY_COLUMN_COMMENT, because clients commonly send statements without a
# terminator. The comment text cannot contain a single quote.
ALTER_TABLE_COMMENT = re.compile(
    r"(?i)(?:/\*.*?\*/\s*)?ALTER TABLE\s+(?:(?P<schema>[^\s\.]+)\.)?(?P<table>[^\s\.]+)\s+COMMENT\s*=\s*'(?P<comment>[^']*)';?"
)

def parse_alter_modify_column(sql: str) -> dict[str, str] | None:
    """Extract the target of an ``ALTER TABLE ... MODIFY COLUMN`` statement.

    Recognizes statements shaped like:
        ALTER TABLE schema.table MODIFY COLUMN col TYPE ... COMMENT 'some text';

    Returns ``None`` when ``ALTER_MODIFY_COL_PATTERN`` does not match. Otherwise returns a
    dict with the keys ``schema`` (``None`` when the table is unqualified), ``table``,
    ``column`` and ``comment`` (``None`` when no COMMENT clause was captured).
    """
    found = ALTER_MODIFY_COL_PATTERN.match(sql)
    if found is None:
        return None

    # Each identifier has a quoted and an unquoted named group; take whichever one matched.
    parsed = {
        name: found.group(name) or found.group(f"{name}_unquoted")
        for name in ("schema", "table", "column")
    }
    parsed["comment"] = found.group("comment")
    return parsed
def _regex_parse_to_complete_dict(sql: str, pattern: re.Pattern[str]) -> dict[str, str] | None:
"""Parse a SQL statement using a regex pattern and return a dict with the matched groups ensuring all are present"""
if match := pattern.match(sql):
result = match.groupdict()
if all(result.values()):
return result


class QueryException(MysqlError):
    """MySQL error built from a dbt ``AdapterResponse``.

    The error message is taken from the response's private ``_message`` attribute, and the
    full response is kept on ``self.response``.
    """

    def __init__(self, response: AdapterResponse) -> None:
        super().__init__(response._message)  # pyright: ignore[reportPrivateUsage]
        # Assigned once, after the base initializer has run.
        self.response: AdapterResponse = response


class DbtSession(Session):
def __init__(self, project: DbtProjectContext, *args: t.Any, **kwargs: t.Any) -> None:
    """Create a session bound to a dbt project.

    Args:
        project: Project context stored on ``self.project``.
        *args: Positional arguments forwarded to ``Session.__init__``.
        **kwargs: Keyword arguments forwarded to ``Session.__init__``.
    """
    super().__init__(*args, **kwargs)
    self.project: DbtProjectContext = project
    # Registered after the base initializer, which presumably creates ``self.middlewares``.
    self.middlewares.append(self._alter_table_comment_middleware)

def _parse(self, sql: str) -> list[exp.Expression]:
    """Parse ``sql`` with the session dialect, compiling it through dbt first if it has Jinja.

    When compilation yields no compiled code, the node's raw code is parsed instead. Falsy
    entries returned by the parser are dropped.
    """
    if _has_jinja(sql):
        compiled = compile_sql_code(self.project, sql)
        sql = compiled.compiled_code or compiled.raw_code
    statements = self.dialect().parse(sql)
    return list(filter(None, statements))

async def query(self, expression: exp.Expression, sql: str, attrs: dict[str, t.Any]):
async def _alter_table_comment_middleware(self, q: Query) -> AllowedResult:
"""Intercept ALTER TABLE ... MODIFY COLUMN ... COMMENT statements
This middleware will update the column description in the dbt project manifest. Eventually
it could use our Yaml context class to actually write the changes to disk.
"""
if isinstance(q.expression, exp.Command):
lower_sql = q.sql.lower()
likely_alter_column_comment = all(
k in lower_sql for k in ("alter", "table", "modify", "column", "comment")
)
if doc_update_req := (
likely_alter_column_comment
and _regex_parse_to_complete_dict(q.sql, ALTER_TABLE_MODIFY_COLUMN_COMMENT)
):
ref = (doc_update_req["schema"], doc_update_req["table"])
for node in chain(
self.project.manifest.sources.values(), self.project.manifest.nodes.values()
):
if ref == (node.schema, node.name):
for column in node.columns.values():
if column.name == doc_update_req["column"]:
column.description = doc_update_req["comment"]
break
likely_alter_table_comment = all(k in lower_sql for k in ("alter", "table", "comment"))
if doc_update_req := (
likely_alter_table_comment
and _regex_parse_to_complete_dict(q.sql, ALTER_TABLE_COMMENT)
):
ref = (doc_update_req["schema"], doc_update_req["table"])
for node in chain(
self.project.manifest.sources.values(), self.project.manifest.nodes.values()
):
if ref == (node.schema, node.name):
node.description = doc_update_req["comment"]
return [], []
return await q.next()

async def query(
self, expression: exp.Expression, sql: str, attrs: dict[str, t.Any]
) -> AllowedResult:
logger.info("Query: %s", sql)
if isinstance(expression, exp.Command):
cmd = f"{expression.this} {expression.expression}"
doc_update = "alter" in sql.lower() and parse_alter_modify_column(cmd)
if doc_update:
logger.info("Will update doc: %s", doc_update)
else:
logger.info("Ignoring command %s", sql)
return (), [] # pyright: ignore[reportUnknownVariableType]
resp, table = await asyncio.to_thread(
execute_sql_code, self.project, expression.sql(dialect=self.project.adapter.type())
)
if resp.code:
raise QueryException(resp)
logger.info(table)
return [
t.cast(tuple[t.Any], row.values()) for row in t.cast(tuple[t.Any], table.rows.values())
], t.cast(tuple[str], table.column_names)
rows = t.cast(tuple[t.Any], table.rows.values())
return [row.values() for row in rows], t.cast(tuple[str], table.column_names)

async def schema(self):
schema: defaultdict[str, dict[str, dict[str, tuple[str, str | None]]]] = defaultdict(dict)
for source in self.project.manifest.sources.values():
schema[source.schema][source.name] = {
c.name: (c.data_type or "UNKOWN", c.description) for c in source.columns.values()
}
for node in self.project.manifest.nodes.values():
for node in chain(
self.project.manifest.sources.values(), self.project.manifest.nodes.values()
):
schema[node.schema][node.name] = {
c.name: (c.data_type or "UNKOWN", c.description) for c in node.columns.values()
}
Expand Down

0 comments on commit 5afc839

Please sign in to comment.