fix: Move early return in main workflow activity
tomasfarias committed Jun 10, 2024
1 parent 1fb4d16 commit 94712b6
Showing 9 changed files with 173 additions and 22 deletions.
12 changes: 8 additions & 4 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -250,6 +250,10 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
is_backfill=inputs.is_backfill,
)

first_record_batch, records_iterator = peek_first_and_rewind(records_iterator)
if first_record_batch is None:
return 0

bigquery_table = None
inserted_at = None

@@ -269,8 +273,6 @@ async def flush_to_bigquery(bigquery_table, table_schema):
rows_exported.add(jsonl_file.records_since_last_reset)
bytes_exported.add(jsonl_file.bytes_since_last_reset)

first_record, records_iterator = peek_first_and_rewind(records_iterator)

if inputs.use_json_type is True:
json_type = "JSON"
json_columns = ["properties", "set", "set_once", "person_properties"]
@@ -295,8 +297,10 @@ async def flush_to_bigquery(bigquery_table, table_schema):
]

else:
column_names = [column for column in first_record.schema.names if column != "_inserted_at"]
record_schema = first_record.select(column_names).schema
column_names = [
column for column in first_record_batch.schema.names if column != "_inserted_at"
]
record_schema = first_record_batch.select(column_names).schema
schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns)

bigquery_table = await create_table_in_bigquery(
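The BigQuery hunk above shows the pattern repeated in each destination activity changed by this commit: peek at the first record batch before allocating any destination resources, and finish early with zero records completed when the interval has no events. A minimal sketch of that pattern, assuming only that the `peek_first_and_rewind` helper from posthog/temporal/batch_exports/utils.py is importable and that the iterator yields PyArrow record batches; the activity body itself is an illustrative stand-in, not code from the repository:

from posthog.temporal.batch_exports.utils import peek_first_and_rewind


async def insert_into_destination_activity(record_iterator) -> int:
    """Hypothetical activity body showing where the early return now sits."""
    # Peek before creating any destination resources (tables, stages, multipart uploads).
    first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)
    if first_record_batch is None:
        # Nothing to export in this interval: report zero records completed and skip all setup.
        return 0

    # Only now is it worth deriving a schema from the peeked batch; in the real activities this
    # feeds the destination-specific table creation step.
    column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
    record_schema = first_record_batch.select(column_names).schema  # used for table creation in the real code

    records_completed = 0
    for record_batch in record_iterator:
        records_completed += record_batch.num_rows  # stand-in for the real flush/upload logic
    return records_completed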
9 changes: 5 additions & 4 deletions posthog/temporal/batch_exports/postgres_batch_export.py
@@ -281,6 +281,9 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs) -> Records
extra_query_parameters=query_parameters,
is_backfill=inputs.is_backfill,
)
first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)
if first_record_batch is None:
return 0

if inputs.batch_export_schema is None:
table_fields = [
@@ -298,10 +301,8 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs) -> Records
]

else:
first_record, record_iterator = peek_first_and_rewind(record_iterator)

column_names = [column for column in first_record.schema.names if column != "_inserted_at"]
record_schema = first_record.select(column_names).schema
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
table_fields = get_postgres_fields_from_record_schema(
record_schema, known_json_columns=["properties", "set", "set_once", "person_properties"]
)
9 changes: 5 additions & 4 deletions posthog/temporal/batch_exports/redshift_batch_export.py
@@ -324,6 +324,9 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
extra_query_parameters=query_parameters,
is_backfill=inputs.is_backfill,
)
first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)
if first_record_batch is None:
return 0

known_super_columns = ["properties", "set", "set_once", "person_properties"]

@@ -347,10 +350,8 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
("timestamp", "TIMESTAMP WITH TIME ZONE"),
]
else:
first_record, record_iterator = peek_first_and_rewind(record_iterator)

column_names = [column for column in first_record.schema.names if column != "_inserted_at"]
record_schema = first_record.select(column_names).schema
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
table_fields = get_redshift_fields_from_record_schema(
record_schema, known_super_columns=known_super_columns
)
6 changes: 5 additions & 1 deletion posthog/temporal/batch_exports/s3_batch_export.py
@@ -464,6 +464,11 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
is_backfill=inputs.is_backfill,
)

first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)

if first_record_batch is None:
return 0

async with s3_upload as s3_upload:

async def flush_to_s3(
@@ -487,7 +492,6 @@ async def flush_to_s3(

heartbeater.details = (str(last_inserted_at), s3_upload.to_state())

first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)
first_record_batch = cast_record_batch_json_columns(first_record_batch)
column_names = first_record_batch.column_names
column_names.pop(column_names.index("_inserted_at"))
10 changes: 6 additions & 4 deletions posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -473,6 +473,10 @@ async def flush_to_snowflake(
extra_query_parameters=query_parameters,
is_backfill=inputs.is_backfill,
)
first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)

if first_record_batch is None:
return 0

known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"]
if inputs.batch_export_schema is None:
@@ -491,10 +495,8 @@ async def flush_to_snowflake(
]

else:
first_record, record_iterator = peek_first_and_rewind(record_iterator)

column_names = [column for column in first_record.schema.names if column != "_inserted_at"]
record_schema = first_record.select(column_names).schema
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
table_fields = get_snowflake_fields_from_record_schema(
record_schema,
known_variant_columns=known_variant_columns,
15 changes: 12 additions & 3 deletions posthog/temporal/batch_exports/utils.py
@@ -10,7 +10,7 @@

def peek_first_and_rewind(
gen: collections.abc.Generator[T, None, None],
) -> tuple[T, collections.abc.Generator[T, None, None]]:
) -> tuple[T | None, collections.abc.Generator[T, None, None]]:
"""Peek into the first element in a generator and rewind the advance.
The generator is advanced and cannot be reversed, so we create a new one that first
@@ -19,10 +19,19 @@ def peek_first_and_rewind(
Returns:
A tuple with the first element of the generator and the generator itself.
"""
first = next(gen)
try:
first = next(gen)
except StopIteration:
first = None

def rewind_gen() -> collections.abc.Generator[T, None, None]:
"""Yield the item we popped to rewind the generator."""
"""Yield the item we popped to rewind the generator.
Return early if the generator is empty.
"""
if first is None:
return

yield first
yield from gen

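For reference, a small usage sketch of the updated helper: when the generator is empty, the caller now gets None back instead of an uncaught StopIteration. The import path follows the file shown above; the sample generators are illustrative only.

from posthog.temporal.batch_exports.utils import peek_first_and_rewind


def numbers():
    yield from (1, 2, 3)


first, rewound = peek_first_and_rewind(numbers())
assert first == 1
assert list(rewound) == [1, 2, 3]  # the peeked element is yielded again before the rest

empty = (x for x in [])
first, rewound = peek_first_and_rewind(empty)
assert first is None  # signals an empty interval to the calling activity
assert list(rewound) == []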
@@ -401,6 +401,67 @@ async def test_postgres_export_workflow(
)


@pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
@pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
@pytest.mark.parametrize("batch_export_schema", TEST_SCHEMAS)
async def test_postgres_export_workflow_without_events(
clickhouse_client,
postgres_config,
postgres_connection,
postgres_batch_export,
interval,
exclude_events,
ateam,
table_name,
batch_export_schema,
):
"""Test Postgres Export Workflow end-to-end by using a local PG database.
The workflow should update the batch export run status to completed without producing any
records in the local development PostgreSQL database.
"""
data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")

workflow_id = str(uuid.uuid4())
inputs = PostgresBatchExportInputs(
team_id=ateam.pk,
batch_export_id=str(postgres_batch_export.id),
data_interval_end=data_interval_end.isoformat(),
interval=interval,
batch_export_schema=batch_export_schema,
**postgres_batch_export.destination.config,
)

async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[PostgresBatchExportWorkflow],
activities=[
start_batch_export_run,
insert_into_postgres_activity,
finish_batch_export_run,
],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with override_settings(BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2):
await activity_environment.client.execute_workflow(
PostgresBatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)

runs = await afetch_batch_export_runs(batch_export_id=postgres_batch_export.id)
assert len(runs) == 1

run = runs[0]
assert run.status == "Completed"
assert run.records_completed == 0


async def test_postgres_export_workflow_handles_insert_activity_errors(ateam, postgres_batch_export, interval):
"""Test that Postgres Export Workflow can gracefully handle errors when inserting Postgres data."""
data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
@@ -528,6 +528,75 @@ async def test_s3_export_workflow_with_minio_bucket(
)


@pytest.mark.parametrize("interval", ["hour"], indirect=True)
@pytest.mark.parametrize("compression", [None], indirect=True)
@pytest.mark.parametrize("exclude_events", [None], indirect=True)
@pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS)
async def test_s3_export_workflow_with_minio_bucket_without_events(
clickhouse_client,
minio_client,
ateam,
s3_batch_export,
bucket_name,
interval,
compression,
exclude_events,
s3_key_prefix,
batch_export_schema,
):
"""Test S3BatchExport Workflow end-to-end by using a local MinIO bucket instead of S3.
The workflow should update the batch export run status to completed without producing any
records in the MinIO bucket.
We use a BatchExport model to provide accurate inputs to the Workflow and because the Workflow
will require its presence in the database when running. This model is indirectly parametrized
by several fixtures. Refer to them for more information.
"""
data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")

workflow_id = str(uuid4())
inputs = S3BatchExportInputs(
team_id=ateam.pk,
batch_export_id=str(s3_batch_export.id),
data_interval_end=data_interval_end.isoformat(),
interval=interval,
batch_export_schema=batch_export_schema,
**s3_batch_export.destination.config,
)

async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[
start_batch_export_run,
insert_into_s3_activity,
finish_batch_export_run,
],
workflow_runner=UnsandboxedWorkflowRunner(),
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(minutes=10),
)

runs = await afetch_batch_export_runs(batch_export_id=s3_batch_export.id)
assert len(runs) == 1

run = runs[0]
assert run.status == "Completed"
assert run.records_completed == 0

objects = await minio_client.list_objects_v2(Bucket=bucket_name, Prefix=s3_key_prefix)
assert len(objects.get("Contents", [])) == 0


@pytest_asyncio.fixture
async def s3_client(bucket_name, s3_key_prefix):
"""Manage an S3 client to interact with an S3 bucket.
@@ -275,7 +275,7 @@ def query_request_handler(request: PreparedRequest):
"rowset": rowset,
"total": 1,
"returned": 1,
"queryId": "query-id",
"queryId": str(uuid4()),
"queryResultFormat": "json",
},
}
@@ -463,7 +463,7 @@ async def test_snowflake_export_workflow_exports_events(


@pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
async def test_snowflake_export_workflow_without_events(ateam, snowflake_batch_export, interval):
async def test_snowflake_export_workflow_without_events(ateam, snowflake_batch_export, interval, truncate_events):
workflow_id = str(uuid4())
inputs = SnowflakeBatchExportInputs(
team_id=ateam.pk,
