Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement a janitorclass to automatically format class attributes #1650

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions splink/input_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlglot.errors import ParseError

from .default_from_jsonschema import default_value_from_schema
from .misc import JanitorClass


def sqlglot_tree_signature(tree):
Expand Down Expand Up @@ -223,15 +224,9 @@ def _get_dialect_quotes(dialect):
sqlglot_dialect = sqlglot.Dialect[dialect.lower()]
except KeyError:
return start, end
return _get_sqlglot_dialect_quotes(sqlglot_dialect)
return _get_sqlglot_dialect_quotes(JanitorClass(sqlglot_dialect))


def _get_sqlglot_dialect_quotes(dialect: sqlglot.Dialect):
try:
# For sqlglot >= 16.0.0
start = dialect.IDENTIFIER_START
end = dialect.IDENTIFIER_END
except AttributeError:
start = dialect.identifier_start
end = dialect.identifier_end
return start, end
# From sqlglot >= 16.0.0 all static variables are uppercase
return dialect.IDENTIFIER_START, dialect.IDENTIFIER_END
8 changes: 8 additions & 0 deletions splink/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def default(self, obj):
return obj.__str__()


class JanitorClass:
def __init__(self, RootClass, *args, **kwargs):
self.wrapped = RootClass(*args, **kwargs)

def __getattr__(self, name):
return getattr(self.wrapped, name.upper())


def calculate_cartesian(df_rows, link_type):
"""
Calculates the cartesian product for the input df(s).
Expand Down
Loading