Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(experiments): Hackily apply timestamp to ASOF LEFT JOIN #26852

Merged
merged 5 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from posthog.warehouse.models.credential import DataWarehouseCredential
from posthog.warehouse.models.join import DataWarehouseJoin
from posthog.warehouse.models.table import DataWarehouseTable
from posthog.hogql.query import execute_hogql_query

TEST_BUCKET = "test_storage_bucket-posthog.hogql.datawarehouse.trendquery" + XDIST_SUFFIX

Expand Down Expand Up @@ -693,6 +694,87 @@ def test_query_runner_with_data_warehouse_series_avg_amount(self):
[0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0],
)

def test_query_runner_with_data_warehouse_series_expected_query(self):
    """Verify the ClickHouse SQL generated for a data-warehouse trends series.

    Specifically asserts that the events-table join used for experiment
    exposure carries the experiment's timestamp bounds (the ASOF LEFT JOIN
    "hack"), and that variant counts and absolute exposures are computed
    correctly from the seeded exposure events.
    """
    table_name = self.create_data_warehouse_table_with_payments()

    feature_flag = self.create_feature_flag()
    experiment = self.create_experiment(
        feature_flag=feature_flag,
        start_date=datetime(2023, 1, 1),
        end_date=datetime(2023, 1, 10),
    )

    feature_flag_property = f"$feature/{feature_flag.key}"

    # Count metric reads from the data warehouse table; exposure comes from events.
    count_query = TrendsQuery(
        series=[
            DataWarehouseNode(
                id=table_name,
                distinct_id_field="dw_distinct_id",
                id_field="id",
                table_name=table_name,
                timestamp_field="dw_timestamp",
                math="total",
            )
        ]
    )
    exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")])

    experiment_query = ExperimentTrendsQuery(
        experiment_id=experiment.id,
        kind="ExperimentTrendsQuery",
        count_query=count_query,
        exposure_query=exposure_query,
    )

    experiment.metrics = [{"type": "primary", "query": experiment_query.model_dump()}]
    experiment.save()

    # Populate exposure events: 7 distinct "control" users and 9 "test" users,
    # one $feature_flag_called event each, on consecutive days in Jan 2023.
    for variant, count in [("control", 7), ("test", 9)]:
        for i in range(count):
            _create_event(
                team=self.team,
                event="$feature_flag_called",
                distinct_id=f"user_{variant}_{i}",
                properties={feature_flag_property: variant},
                timestamp=datetime(2023, 1, i + 1),
            )

    flush_persons_and_events()

    query_runner = ExperimentTrendsQueryRunner(
        query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team
    )
    with freeze_time("2023-01-07"):
        # Build and execute the query to get the ClickHouse SQL
        queries = query_runner.count_query_runner.to_queries()
        response = execute_hogql_query(
            query_type="TrendsQuery",
            query=queries[0],
            team=query_runner.count_query_runner.team,
            modifiers=query_runner.count_query_runner.modifiers,
            limit_context=query_runner.count_query_runner.limit_context,
        )

    # Assert the expected join condition in the clickhouse SQL: the event
    # filter AND both timestamp bounds must appear in the joined subquery.
    expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_7)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_8)s, 6, %(hogql_val_9)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_10)s, 6, %(hogql_val_11)s))))) AS e__events ON"
    self.assertIn(expected_join_condition, str(response.clickhouse))

    result = query_runner.calculate()

    trend_result = cast(ExperimentTrendsQueryResponse, result)

    # Use the typed alias consistently for all result assertions.
    self.assertEqual(len(trend_result.variants), 2)

    control_result = next(variant for variant in trend_result.variants if variant.key == "control")
    test_result = next(variant for variant in trend_result.variants if variant.key == "test")

    self.assertEqual(control_result.count, 1)
    self.assertEqual(test_result.count, 3)
    self.assertEqual(control_result.absolute_exposure, 7)
    self.assertEqual(test_result.absolute_exposure, 9)

def test_query_runner_with_invalid_data_warehouse_table_name(self):
# parquet file isn't created, so we'll get an error
table_name = "invalid_table_name"
Expand Down
26 changes: 21 additions & 5 deletions posthog/warehouse/models/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def _join_function_for_experiments(
if not timestamp_key:
raise ResolutionError("experiments_timestamp_key is not set for this join")

whereExpr: list[ast.Expr] = [
ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Field(chain=["event"]),
right=ast.Constant(value="$feature_flag_called"),
)
]
# :HACK: We need to pull the timestamp gt/lt values from node.where.exprs[0] because
# we can't reference the parent data warehouse table in the where clause.
if node.where and hasattr(node.where, "exprs"):
for expr in node.where.exprs:
if isinstance(expr, ast.CompareOperation):
if expr.op == ast.CompareOperationOp.GtEq or expr.op == ast.CompareOperationOp.LtEq:
if isinstance(expr.left, ast.Alias) and expr.left.expr.to_hogql() == timestamp_key:
whereExpr.append(
ast.CompareOperation(
op=expr.op, left=ast.Field(chain=["timestamp"]), right=expr.right
)
)

return ast.JoinExpr(
table=ast.SelectQuery(
select=[
Expand All @@ -128,11 +148,7 @@ def _join_function_for_experiments(
}.items()
],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
where=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Field(chain=["event"]),
right=ast.Constant(value="$feature_flag_called"),
),
where=ast.And(exprs=whereExpr),
),
# ASOF JOIN finds the most recent matching event that occurred at or before each data warehouse timestamp.
#
Expand Down
Loading