diff --git a/python/tempo/resample.py b/python/tempo/resample.py index 919d6f01..dca1c0dc 100644 --- a/python/tempo/resample.py +++ b/python/tempo/resample.py @@ -4,7 +4,8 @@ from pyspark.sql.window import Window # define global frequency options - +MUSEC = 'microsec' +MS = 'ms' SEC = 'sec' MIN = 'min' HR = 'hr' @@ -17,9 +18,9 @@ average = "mean" ceiling = "ceil" -freq_dict = {'sec' : 'seconds', 'min' : 'minutes', 'hr' : 'hours', 'day' : 'days', 'hour' : 'hours'} +freq_dict = {'microsec' : 'microseconds','ms' : 'milliseconds','sec' : 'seconds', 'min' : 'minutes', 'hr' : 'hours', 'day' : 'days', 'hour' : 'hours'} -allowableFreqs = [SEC, MIN, HR, DAY] +allowableFreqs = [MUSEC, MS, SEC, MIN, HR, DAY] allowableFuncs = [floor, min, max, average, ceiling] def __appendAggKey(tsdf, freq = None): @@ -41,7 +42,7 @@ def aggregate(tsdf, freq, func, metricCols = None, prefix = None, fill = None): :param tsdf: input TSDF object :param func: aggregate function :param metricCols: columns used for aggregates - :param prefix the metric columns with the aggregate named function + :param prefix: the metric columns with the aggregate named function :param fill: upsample based on the time increment for 0s in numeric columns :return: TSDF object with newly aggregated timestamp as ts_col with aggregated values """ @@ -118,13 +119,22 @@ def aggregate(tsdf, freq, func, metricCols = None, prefix = None, fill = None): def checkAllowableFreq(freq): + """ + Parses frequency and checks against allowable frequencies + :param freq: frequncy at which to upsample/downsample, declared in resample function + :return: list of parsed frequency value and time suffix + """ if freq not in allowableFreqs: try: periods = freq.lower().split(" ")[0].strip() units = freq.lower().split(" ")[1].strip() except: - raise ValueError("Allowable grouping frequencies are sec (second), min (minute), hr (hour), day. Reformat your frequency as ") - if units.startswith(SEC): + raise ValueError("Allowable grouping frequencies are microsecond (musec), millisecond (ms), sec (second), min (minute), hr (hour), day. Reformat your frequency as ") + if units.startswith(MUSEC): + return (periods, MUSEC) + elif units.startswith(MS) | units.startswith("millis"): + return (periods, MS) + elif units.startswith(SEC): return (periods, SEC) elif units.startswith(MIN): return (periods, MIN) diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py index 8ae2ce86..b75c9a6e 100644 --- a/python/tests/tsdf_tests.py +++ b/python/tests/tsdf_tests.py @@ -714,6 +714,51 @@ def test_resample(self): # test bars summary self.assertDataFramesEqual(bars, barsExpected) + def test_resample_millis(self): + """Test of resampling for millisecond windows""" + schema = StructType([StructField("symbol", StringType()), + StructField("date", StringType()), + StructField("event_ts", StringType()), + StructField("trade_pr", FloatType()), + StructField("trade_pr_2", FloatType())]) + + expectedSchema = StructType([StructField("symbol", StringType()), + StructField("event_ts", StringType()), + StructField("floor_trade_pr", FloatType()), + StructField("floor_date", StringType()), + StructField("floor_trade_pr_2", FloatType())]) + + expectedSchemaMS = StructType([StructField("symbol", StringType()), + StructField("event_ts", StringType(), True), + StructField("date", DoubleType()), + StructField("trade_pr", DoubleType()), + StructField("trade_pr_2", DoubleType())]) + + + data = [["S1", "SAME_DT", "2020-08-01 00:00:10.12345", 349.21, 10.0], + ["S1", "SAME_DT", "2020-08-01 00:00:10.123", 340.21, 9.0], + ["S1", "SAME_DT", "2020-08-01 00:00:10.124", 353.32, 8.0]] + + expected_data_ms = [ + ["S1", "2020-08-01 00:00:10.123", None, 344.71, 9.5], + ["S1", "2020-08-01 00:00:10.124", None, 353.32, 8.0] + ] + + # construct dataframes + df = self.buildTestDF(schema, data) + dfExpected = self.buildTestDF(expectedSchemaMS, expected_data_ms) + + # convert to TSDF + tsdf_left = TSDF(df, partition_cols=["symbol"]) + + # 30 minute aggregation + resample_ms = tsdf_left.resample(freq="ms", func="mean").df.withColumn("trade_pr", F.round(F.col('trade_pr'), 2)) + + int_df = TSDF(tsdf_left.df.withColumn("event_ts", F.col("event_ts").cast("timestamp")), partition_cols = ['symbol']) + interpolated = int_df.interpolate(freq='ms', func='floor', method='ffill') + self.assertDataFramesEqual(resample_ms, dfExpected) + + def test_upsample(self): """Test of range stats for 20 minute rolling window""" schema = StructType([StructField("symbol", StringType()),