Skip to content

Commit

Permalink
EP-2625 Use a transform to add a frame clause - required for Redshift conversion of generated SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
BartBaddeley committed Jul 13, 2023
1 parent 2911daf commit 1ec00a6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
5 changes: 3 additions & 2 deletions splink/redshift/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ..misc import ensure_is_list
from ..splink_dataframe import SplinkDataFrame
from ..unique_id_concat import _composite_unique_id_from_nodes_sql
from ..sql_transform import sqlglot_transform_sql
from .redshift_helpers.redshift_transforms import add_frame_clause

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -130,8 +132,7 @@ def _execute_sql_against_backend(self, sql, templated_name, physical_name):
# In the case of a table already existing in the database,
# execute sql is only reached if the user has explicitly turned off the cache
self._delete_table_from_database(physical_name)
sql = sql.replace('partition by group_name order by value_count desc',
'partition by group_name order by value_count desc rows between unbounded preceding and current row')
sql = sqlglot_transform_sql(sql, add_frame_clause)

if self.output_sql_to_file:
os.makedirs(self.output_sql_directory, exist_ok=True)
Expand Down
17 changes: 7 additions & 10 deletions splink/redshift/redshift_helpers/redshift_transforms.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import sqlglot
from sqlglot import expressions as exp
import re


def cast_concat_as_varchar(node):
if isinstance(node, exp.Column):
if isinstance(node.parent, exp.Cast):
return node
def add_frame_clause(node):
    """Append an explicit frame clause to ordered window specifications.

    The node is rendered to SQL text, every
    ``partition by <name> order by <name> desc`` fragment that is not
    already followed by a frame clause gets
    ``rows between unbounded preceding and current row`` appended, and the
    rewritten text is parsed back into an expression.

    NOTE(review): the commit message describes this as required for the
    Redshift conversion of generated SQL -- confirm against the caller.

    Args:
        node: Object exposing ``sql()`` that returns SQL text (presumably a
            sqlglot expression handed over by a transform -- confirm).

    Returns:
        The result of ``sqlglot.parse_one`` on the rewritten SQL text.
    """
    sql = re.sub(
        # The negative lookahead skips windows that already carry a frame
        # (rows / range / groups); appending a second frame clause would
        # yield SQL that can no longer be parsed, and it keeps the rewrite
        # idempotent if it is ever applied twice.
        r'(partition by \w+ order by \w+ desc)(?!\s+(?:rows|range|groups)\b)',
        r'\1 rows between unbounded preceding and current row',
        node.sql(),
        flags=re.IGNORECASE,
    )
    return sqlglot.parse_one(sql)

if node.find_ancestor(exp.DPipe):
sql = f"cast({node.sql()} as varchar)"
return sqlglot.parse_one(sql)

return node

0 comments on commit 1ec00a6

Please sign in to comment.