Skip to content

Commit

Permalink
Add tolerance parameter to asofJoin (#304)
Browse files Browse the repository at this point in the history
* start

* check in some changes

* new tests

* test

* fix black

* remove tolerance window spec

* tolerance unit

* format

* run black
  • Loading branch information
nina-hu authored Apr 17, 2023
1 parent 7ddcdfa commit 8270aa6
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 1 deletion.
28 changes: 28 additions & 0 deletions python/tempo/tsdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ def asofJoin(
skipNulls=True,
sql_join_opt=False,
suppress_null_warning=False,
tolerance=None,
):
"""
Performs an as-of join between two time-series. If a tsPartitionVal is
Expand All @@ -718,6 +719,7 @@ def asofJoin(
:param skipNulls - whether to skip nulls when joining in values
:param sql_join_opt - if set to True, will use standard Spark SQL join if it is estimated to be efficient
:param suppress_null_warning - when tsPartitionVal is specified, will collect min of each column and raise warnings about null values, set to True to avoid
:param tolerance - only join values within this tolerance range, expressed in number of seconds as an int
"""

# first block of logic checks whether a standard range join will suffice
Expand Down Expand Up @@ -865,6 +867,32 @@ def asofJoin(

asofDF = TSDF(df, asofDF.ts_col, combined_df.partitionCols)

if tolerance is not None:
df = asofDF.df
left_ts_col = left_tsdf.ts_col
right_ts_col = right_tsdf.ts_col
tolerance_condition = (
df[left_ts_col].cast("double") - df[right_ts_col].cast("double")
> tolerance
)

for right_col in right_columns:
# First set right non-timestamp columns to null for rows outside of tolerance band
if right_col != right_ts_col:
df = df.withColumn(
right_col,
f.when(tolerance_condition, f.lit(None)).otherwise(
df[right_col]
),
)

# Finally, set right timestamp column to null for rows outside of tolerance band
df = df.withColumn(
right_ts_col,
f.when(tolerance_condition, f.lit(None)).otherwise(df[right_ts_col]),
)
asofDF.df = df

return asofDF

def __baseWindow(self, sort_col=None, reverse=False):
Expand Down
21 changes: 21 additions & 0 deletions python/tests/as_of_join_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,27 @@ def test_asof_join_nanos(self):
# compare
self.assertDataFrameEquality(joined_df, dfExpected)

def test_asof_join_tolerance(self):
    """As of join with tolerance band.

    Joins the "left" and "right" fixtures once per tolerance value and
    compares each result with the matching ``expected_tolerance_<value>``
    fixture.
    """

    # fetch test data
    tsdf_left = self.get_data_as_tsdf("left")
    tsdf_right = self.get_data_as_tsdf("right")

    tolerance_test_values = [0, 7, 10]
    for tolerance in tolerance_test_values:
        # Run each tolerance as its own sub-test so a failure reports which
        # value broke and the remaining values are still exercised.
        # NOTE(review): assumes the enclosing test class derives from
        # unittest.TestCase (which provides subTest) - confirm.
        with self.subTest(tolerance=tolerance):
            # perform join
            joined_df = tsdf_left.asofJoin(
                tsdf_right,
                left_prefix="left",
                right_prefix="right",
                tolerance=tolerance,
            ).df

            # compare
            expected_tolerance = self.get_data_as_sdf(
                f"expected_tolerance_{tolerance}"
            )
            self.assertDataFrameEquality(joined_df, expected_tolerance)


# MAIN
if __name__ == "__main__":
Expand Down
54 changes: 53 additions & 1 deletion python/tests/unit_test_data/as_of_join_tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,58 @@
["S1", "2022-01-01 10:00:01.123456789", 364.31, "2022-01-01 10:00:01.10000001", 365.31, 359.21]
]
}
},
"test_asof_join_tolerance": {
"left": {
"$ref": "#/__SharedData/shared_left"
},
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:01:05", 348.10, 353.13],
["S1", "2020-09-01 00:02:01", 358.93, 365.12],
["S1", "2020-09-01 00:15:01", 359.21, 365.31]
]
},
"expected_tolerance_0": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, null, null, null],
["S1", "2020-08-01 00:01:12", 351.32, null, null, null],
["S1", "2020-09-01 00:02:10", 361.1, null, null, null],
["S1", "2020-09-01 00:19:12", 362.1, null, null, null]
]
},
"expected_tolerance_7": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, null, null, null],
["S1", "2020-08-01 00:01:12", 351.32, "2020-08-01 00:01:05", 348.10, 353.13],
["S1", "2020-09-01 00:02:10", 361.1, null, null, null],
["S1", "2020-09-01 00:19:12", 362.1, null, null, null]
]
},
"expected_tolerance_10": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:01:12", 351.32, "2020-08-01 00:01:05", 348.10, 353.13],
["S1", "2020-09-01 00:02:10", 361.1, "2020-09-01 00:02:01", 358.93, 365.12],
["S1", "2020-09-01 00:19:12", 362.1, null, null, null]
]
}
}
}
}
}

0 comments on commit 8270aa6

Please sign in to comment.