Skip to content

Commit

Permalink
chore(data-warehouse): Put s3table queries into cte so joins work (REVERT) (#16452)
Browse files Browse the repository at this point in the history

Revert "chore(data-warehouse): Put s3table queries into cte so joins work (#16413)"

This reverts commit cfa1048.
  • Loading branch information
mariusandra authored Jul 9, 2023
1 parent 32b0552 commit d8fff21
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 86 deletions.
4 changes: 2 additions & 2 deletions posthog/hogql/database/test/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from posthog.hogql.database.s3_table import S3Table


def create_aapl_stock_s3_table(name="aapl_stock") -> S3Table:
def create_aapl_stock_s3_table() -> S3Table:
return S3Table(
name=name,
name="aapl_stock",
url="https://s3.eu-west-3.amazonaws.com/datasets-documentation/aapl_stock.csv",
format="CSVWithNames",
fields={
Expand Down
3 changes: 1 addition & 2 deletions posthog/hogql/database/test/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def test_database_with_warehouse_tables(self, patch_execute):
"select * from whatever",
team=self.team,
)

self.assertEqual(
response.clickhouse,
f"WITH whatever AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_3)s, %(hogql_val_4)s, %(hogql_val_1)s, %(hogql_val_2)s)) SELECT whatever.id FROM whatever LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=True",
f"SELECT whatever.id FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_3)s, %(hogql_val_4)s, %(hogql_val_1)s, %(hogql_val_2)s) AS whatever LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=True",
)
73 changes: 2 additions & 71 deletions posthog/hogql/database/test/test_s3_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class TestS3Table(BaseTest):
def _init_database(self):
self.database = create_hogql_database(self.team.pk)
self.database.aapl_stock = create_aapl_stock_s3_table()
self.database.aapl_stock_2 = create_aapl_stock_s3_table(name="aapl_stock_2")
self.context = HogQLContext(team_id=self.team.pk, enable_select_queries=True, database=self.database)

def _select(self, query: str, dialect: str = "clickhouse") -> str:
Expand All @@ -23,10 +22,9 @@ def test_s3_table_select(self):
self.assertEqual(hogql, "SELECT Date, Open, High, Low, Close, Volume, OpenInt FROM aapl_stock LIMIT 10")

clickhouse = self._select(query="SELECT * FROM aapl_stock LIMIT 10", dialect="clickhouse")

self.assertEqual(
clickhouse,
"WITH aapl_stock AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_1)s)) SELECT aapl_stock.Date, aapl_stock.Open, aapl_stock.High, aapl_stock.Low, aapl_stock.Close, aapl_stock.Volume, aapl_stock.OpenInt FROM aapl_stock LIMIT 10",
"SELECT aapl_stock.Date, aapl_stock.Open, aapl_stock.High, aapl_stock.Low, aapl_stock.Close, aapl_stock.Volume, aapl_stock.OpenInt FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_1)s) AS aapl_stock LIMIT 10",
)

def test_s3_table_select_with_alias(self):
Expand All @@ -36,74 +34,7 @@ def test_s3_table_select_with_alias(self):
self.assertEqual(hogql, "SELECT High, Low FROM aapl_stock AS a LIMIT 10")

clickhouse = self._select(query="SELECT High, Low FROM aapl_stock AS a LIMIT 10", dialect="clickhouse")

self.assertEqual(
clickhouse,
"WITH a AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_1)s)) SELECT a.High, a.Low FROM aapl_stock AS a LIMIT 10",
)

def test_s3_table_select_join(self):
self._init_database()

hogql = self._select(
query="SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN aapl_stock_2 ON aapl_stock.High = aapl_stock_2.High LIMIT 10",
dialect="hogql",
)
self.assertEqual(
hogql,
"SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN aapl_stock_2 ON equals(aapl_stock.High, aapl_stock_2.High) LIMIT 10",
)

clickhouse = self._select(
query="SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN aapl_stock_2 ON aapl_stock.High = aapl_stock_2.High LIMIT 10",
dialect="clickhouse",
)

self.assertEqual(
clickhouse,
"WITH aapl_stock AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_1)s)), aapl_stock_2 AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_3)s, %(hogql_val_4)s)) SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN aapl_stock_2 ON equals(aapl_stock.High, aapl_stock_2.High) LIMIT 10",
)

def test_s3_table_select_join_with_alias(self):
self._init_database()

hogql = self._select(
query="SELECT a.High, a.Low FROM aapl_stock AS a JOIN aapl_stock AS b ON a.High = b.High LIMIT 10",
dialect="hogql",
)
self.assertEqual(
hogql, "SELECT a.High, a.Low FROM aapl_stock AS a JOIN aapl_stock AS b ON equals(a.High, b.High) LIMIT 10"
)

clickhouse = self._select(
query="SELECT a.High, a.Low FROM aapl_stock AS a JOIN aapl_stock AS b ON a.High = b.High LIMIT 10",
dialect="clickhouse",
)

self.assertEqual(
clickhouse,
"WITH a AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_1)s)), b AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_3)s, %(hogql_val_4)s)) SELECT a.High, a.Low FROM aapl_stock AS a JOIN aapl_stock AS b ON equals(a.High, b.High) LIMIT 10",
)

def test_s3_table_select_and_non_s3_join(self):
self._init_database()

hogql = self._select(
query="SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN events ON aapl_stock.High = events.event LIMIT 10",
dialect="hogql",
)
self.assertEqual(
hogql,
"SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN events ON equals(aapl_stock.High, events.event) LIMIT 10",
)

clickhouse = self._select(
query="SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN events ON aapl_stock.High = events.event LIMIT 10",
dialect="clickhouse",
)

self.maxDiff = None
self.assertEqual(
clickhouse,
f"WITH aapl_stock AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_1)s)) SELECT aapl_stock.High, aapl_stock.Low FROM aapl_stock JOIN events ON equals(aapl_stock.High, events.event) WHERE equals(events.team_id, {self.team.pk}) LIMIT 10",
"SELECT a.High, a.Low FROM s3Cluster('posthog', %(hogql_val_0)s, %(hogql_val_1)s) AS a LIMIT 10",
)
12 changes: 1 addition & 11 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import Table, FunctionCallTable
from posthog.hogql.database.database import create_hogql_database
from posthog.hogql.database.s3_table import S3Table
from posthog.hogql.errors import HogQLException
from posthog.hogql.escape_sql import (
escape_clickhouse_identifier,
Expand Down Expand Up @@ -96,7 +95,6 @@ def print_prepared_ast(
class JoinExprResponse:
printed_sql: str
where: Optional[ast.Expr] = None
ctes: Optional[List[str]] = None


class _Printer(Visitor):
Expand Down Expand Up @@ -158,7 +156,6 @@ def visit_select_query(self, node: ast.SelectQuery):
where = node.where

joined_tables = []
ctes = []
next_join = node.select_from
while isinstance(next_join, ast.JoinExpr):
if next_join.type is None:
Expand All @@ -167,7 +164,6 @@ def visit_select_query(self, node: ast.SelectQuery):

visited_join = self.visit_join_expr(next_join)
joined_tables.append(visited_join.printed_sql)
ctes.extend(visited_join.ctes or [])

# This is an expression we must add to the SELECT's WHERE clause to limit results, like the team ID guard.
extra_where = visited_join.where
Expand Down Expand Up @@ -231,8 +227,6 @@ def visit_select_query(self, node: ast.SelectQuery):

response = " ".join([clause for clause in clauses if clause])

response = f"WITH {', '.join(ctes)} {response}" if ctes else response

# If we are printing a SELECT subquery (not the first AST node we are visiting), wrap it in parentheses.
if not part_of_select_union and not is_top_level_query:
response = f"({response})"
Expand All @@ -244,7 +238,6 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse:
extra_where: Optional[ast.Expr] = None

join_strings = []
ctes = []

if node.join_type is not None:
join_strings.append(node.join_type)
Expand All @@ -264,9 +257,6 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse:

if self.dialect == "clickhouse":
sql = table_type.table.to_printed_clickhouse(self.context)
if isinstance(table_type.table, S3Table):
ctes.append(f"{node.alias} AS (SELECT * FROM {sql})")
sql = table_type.table.to_printed_hogql()
else:
sql = table_type.table.to_printed_hogql()
join_strings.append(sql)
Expand Down Expand Up @@ -305,7 +295,7 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse:
if node.constraint is not None:
join_strings.append(f"ON {self.visit(node.constraint)}")

return JoinExprResponse(printed_sql=" ".join(join_strings), where=extra_where, ctes=ctes if ctes else None)
return JoinExprResponse(printed_sql=" ".join(join_strings), where=extra_where)

def visit_join_constraint(self, node: ast.JoinConstraint):
return self.visit(node.expr)
Expand Down

0 comments on commit d8fff21

Please sign in to comment.