fix(experiments): Hackily apply timestamp to ASOF LEFT JOIN (#26852)
danielbachhuber authored Dec 12, 2024
1 parent d3773c1 commit 0d72070
Showing 2 changed files with 103 additions and 5 deletions.
@@ -36,6 +36,7 @@
from posthog.warehouse.models.credential import DataWarehouseCredential
from posthog.warehouse.models.join import DataWarehouseJoin
from posthog.warehouse.models.table import DataWarehouseTable
from posthog.hogql.query import execute_hogql_query

TEST_BUCKET = "test_storage_bucket-posthog.hogql.datawarehouse.trendquery" + XDIST_SUFFIX

@@ -693,6 +694,87 @@ def test_query_runner_with_data_warehouse_series_avg_amount(self):
[0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0],
)

def test_query_runner_with_data_warehouse_series_expected_query(self):
table_name = self.create_data_warehouse_table_with_payments()

feature_flag = self.create_feature_flag()
experiment = self.create_experiment(
feature_flag=feature_flag,
start_date=datetime(2023, 1, 1),
end_date=datetime(2023, 1, 10),
)

feature_flag_property = f"$feature/{feature_flag.key}"

count_query = TrendsQuery(
series=[
DataWarehouseNode(
id=table_name,
distinct_id_field="dw_distinct_id",
id_field="id",
table_name=table_name,
timestamp_field="dw_timestamp",
math="total",
)
]
)
exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")])

experiment_query = ExperimentTrendsQuery(
experiment_id=experiment.id,
kind="ExperimentTrendsQuery",
count_query=count_query,
exposure_query=exposure_query,
)

experiment.metrics = [{"type": "primary", "query": experiment_query.model_dump()}]
experiment.save()

# Populate exposure events
for variant, count in [("control", 7), ("test", 9)]:
for i in range(count):
_create_event(
team=self.team,
event="$feature_flag_called",
distinct_id=f"user_{variant}_{i}",
properties={feature_flag_property: variant},
timestamp=datetime(2023, 1, i + 1),
)

flush_persons_and_events()

query_runner = ExperimentTrendsQueryRunner(
query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team
)
with freeze_time("2023-01-07"):
# Build and execute the query to get the ClickHouse SQL
queries = query_runner.count_query_runner.to_queries()
response = execute_hogql_query(
query_type="TrendsQuery",
query=queries[0],
team=query_runner.count_query_runner.team,
modifiers=query_runner.count_query_runner.modifiers,
limit_context=query_runner.count_query_runner.limit_context,
)

# Assert the expected join condition in the ClickHouse SQL
expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_7)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_8)s, 6, %(hogql_val_9)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_10)s, 6, %(hogql_val_11)s))))) AS e__events ON"
self.assertIn(expected_join_condition, str(response.clickhouse))

result = query_runner.calculate()

trend_result = cast(ExperimentTrendsQueryResponse, result)

self.assertEqual(len(result.variants), 2)

control_result = next(variant for variant in trend_result.variants if variant.key == "control")
test_result = next(variant for variant in trend_result.variants if variant.key == "test")

self.assertEqual(control_result.count, 1)
self.assertEqual(test_result.count, 3)
self.assertEqual(control_result.absolute_exposure, 7)
self.assertEqual(test_result.absolute_exposure, 9)

def test_query_runner_with_invalid_data_warehouse_table_name(self):
# parquet file isn't created, so we'll get an error
table_name = "invalid_table_name"
26 changes: 21 additions & 5 deletions posthog/warehouse/models/join.py
@@ -113,6 +113,26 @@ def _join_function_for_experiments(
if not timestamp_key:
raise ResolutionError("experiments_timestamp_key is not set for this join")

whereExpr: list[ast.Expr] = [
ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Field(chain=["event"]),
right=ast.Constant(value="$feature_flag_called"),
)
]
# :HACK: We need to pull the timestamp gt/lt values from node.where.exprs[0] because
# we can't reference the parent data warehouse table in the where clause.
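# For example, with the experiment date range used in the test added above
# (2023-01-01 through 2023-01-10), node.where carries GtEq/LtEq comparisons on
# the configured timestamp key, roughly `<timestamp_key> >= start_date` and
# `<timestamp_key> <= end_date`; each bound is mirrored onto events.timestamp below.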
if node.where and hasattr(node.where, "exprs"):
for expr in node.where.exprs:
if isinstance(expr, ast.CompareOperation):
if expr.op == ast.CompareOperationOp.GtEq or expr.op == ast.CompareOperationOp.LtEq:
if isinstance(expr.left, ast.Alias) and expr.left.expr.to_hogql() == timestamp_key:
whereExpr.append(
ast.CompareOperation(
op=expr.op, left=ast.Field(chain=["timestamp"]), right=expr.right
)
)

return ast.JoinExpr(
table=ast.SelectQuery(
select=[
@@ -128,11 +148,7 @@
}.items()
],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
-where=ast.CompareOperation(
-    op=ast.CompareOperationOp.Eq,
-    left=ast.Field(chain=["event"]),
-    right=ast.Constant(value="$feature_flag_called"),
-),
+where=ast.And(exprs=whereExpr),
),
# ASOF JOIN finds the most recent matching event that occurred at or before each data warehouse timestamp.
#
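
For context on the ASOF JOIN comment above: an ASOF LEFT JOIN pairs each data warehouse row with the latest exposure event at or before that row's timestamp. Below is a minimal, self-contained Python sketch of that matching rule only; the function, names, and sample data are made up for illustration and this is not PostHog's or ClickHouse's implementation.

from datetime import datetime

def asof_left_join(payments, exposures):
    # payments: list of (distinct_id, payment_timestamp)
    # exposures: list of (distinct_id, event_timestamp, variant)
    joined = []
    for pid, p_ts in payments:
        # Keep only exposure events for the same distinct_id at or before the payment.
        candidates = [(e_ts, variant) for eid, e_ts, variant in exposures if eid == pid and e_ts <= p_ts]
        # ASOF semantics: take the most recent qualifying event; LEFT keeps the row even with no match.
        best = max(candidates) if candidates else None
        joined.append((pid, p_ts, best[1] if best else None))
    return joined

payments = [("user_control_1", datetime(2023, 1, 2))]
exposures = [
    ("user_control_1", datetime(2023, 1, 1), "control"),  # at or before the payment -> matches
    ("user_control_1", datetime(2023, 1, 3), "control"),  # after the payment -> ignored
]
print(asof_left_join(payments, exposures))
# -> [('user_control_1', datetime.datetime(2023, 1, 2, 0, 0), 'control')]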
