diff --git a/python/tests/ts/test_dataframe.py b/python/tests/ts/test_dataframe.py
index 4269e2ac..7eb175e0 100644
--- a/python/tests/ts/test_dataframe.py
+++ b/python/tests/ts/test_dataframe.py
@@ -120,6 +120,46 @@ def fn_2(rows):
             "with key 7"
         )
 
+    def test_addColumnsForCycleTimeNotFirst(self):
+        import pyspark.sql.types as pyspark_types
+        price = self.flintContext.read.pandas(test_utils.make_pdf([
+            [7, 1000, 0.5],
+            [3, 1000, 1.0],
+            [3, 1050, 1.5],
+            [7, 1050, 2.0],
+            [3, 1100, 2.5],
+            [7, 1100, 3.0],
+            [3, 1150, 3.5],
+            [7, 1150, 4.0],
+            [3, 1200, 4.5],
+            [7, 1200, 5.0],
+            [3, 1250, 5.5],
+            [7, 1250, 6.0],
+        ], ["id", "time", "price"]))
+        expected_pdf = test_utils.make_pdf([
+            [1000, 7, 0.5, 1.0],
+            [1000, 3, 1.0, 2.0],
+            [1050, 3, 1.5, 3.0],
+            [1050, 7, 2.0, 4.0],
+            [1100, 3, 2.5, 5.0],
+            [1100, 7, 3.0, 6.0],
+            [1150, 3, 3.5, 7.0],
+            [1150, 7, 4.0, 8.0],
+            [1200, 3, 4.5, 9.0],
+            [1200, 7, 5.0, 10.0],
+            [1250, 3, 5.5, 11.0],
+            [1250, 7, 6.0, 12.0],
+        ], ["time", "id", "price", "adjustedPrice"])
+
+        def fn_1(rows):
+            size = len(rows)
+            return {row: row.price * size for row in rows}
+
+        new_pdf = price.addColumnsForCycle(
+            {"adjustedPrice": (pyspark_types.DoubleType(), fn_1)}
+        ).toPandas()
+        test_utils.assert_same(new_pdf, expected_pdf)
+
     def test_merge(self):
         price = self.price()
         price1 = price.filter(price.time > 1100)
@@ -171,6 +211,79 @@ def test_leftJoin(self):
         ).toPandas()
         test_utils.assert_same(new_pdf, expected_pdf)
 
+    def test_leftJoinTimeNotFirst(self):
+        # Note that in price we have time as the second column
+        price = self.flintContext.read.pandas(test_utils.make_pdf([
+            [7, 1000, 0.5],
+            [3, 1000, 1.0],
+            [3, 1050, 1.5],
+            [7, 1050, 2.0],
+            [3, 1100, 2.5],
+            [7, 1100, 3.0],
+            [3, 1150, 3.5],
+            [7, 1150, 4.0],
+            [3, 1200, 4.5],
+            [7, 1200, 5.0],
+            [3, 1250, 5.5],
+            [7, 1250, 6.0],
+        ], ["id", "time", "price"]))
+        # Time is also the second column of vol
+        vol = self.flintContext.read.pandas(test_utils.make_pdf([
+            [7, 1000, 100],
+            [3, 1000, 200],
+            [3, 1050, 300],
+            [7, 1050, 400],
+            [3, 1100, 500],
+            [7, 1100, 600],
+            [3, 1150, 700],
+            [7, 1150, 800],
+            [3, 1200, 900],
+            [7, 1200, 1000],
+            [3, 1250, 1100],
+            [7, 1250, 1200],
+        ], ["id", "time", "volume"]))
+        # We expect to get the result with time as the first column
+        expected_pdf = test_utils.make_pdf([
+            (1000, 7, 0.5, 100,),
+            (1000, 3, 1.0, 200,),
+            (1050, 3, 1.5, 300,),
+            (1050, 7, 2.0, 400,),
+            (1100, 3, 2.5, 500,),
+            (1100, 7, 3.0, 600,),
+            (1150, 3, 3.5, 700,),
+            (1150, 7, 4.0, 800,),
+            (1200, 3, 4.5, 900,),
+            (1200, 7, 5.0, 1000,),
+            (1250, 3, 5.5, 1100,),
+            (1250, 7, 6.0, 1200,)
+        ], ["time", "id", "price", "volume"])
+
+        new_pdf = price.leftJoin(vol, key=["id"]).toPandas()
+        test_utils.assert_same(new_pdf, expected_pdf)
+        test_utils.assert_same(
+            new_pdf, price.leftJoin(vol, key="id").toPandas()
+        )
+
+        expected_pdf = test_utils.make_pdf([
+            (1000, 7, 0.5, 100),
+            (1000, 3, 1.0, 200),
+            (1050, 3, 1.5, None),
+            (1050, 7, 2.0, None),
+            (1100, 3, 2.5, 500),
+            (1100, 7, 3.0, 600),
+            (1150, 3, 3.5, 700),
+            (1150, 7, 4.0, 800),
+            (1200, 3, 4.5, 900),
+            (1200, 7, 5.0, 1000),
+            (1250, 3, 5.5, 1100),
+            (1250, 7, 6.0, 1200),
+        ], ["time", "id", "price", "volume"])
+
+        new_pdf = price.leftJoin(
+            vol.filter(vol.time != 1050), key="id"
+        ).toPandas()
+        test_utils.assert_same(new_pdf, expected_pdf)
+
     def test_futureLeftJoin(self):
         import pyspark.sql.types as pyspark_types
         price = self.price()
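Reviewer note: the two tests above pin down behavior when 'time' is not the first column. Below is a minimal sketch of the calls they exercise, for trying the change interactively. It assumes a live FlintContext bound to `flintContext` and integer timestamps interpreted in the default nanosecond unit; the data values are illustrative only, not from a real dataset.

    import pandas as pd

    # 'time' is deliberately the second column, as in the new tests.
    price = flintContext.read.pandas(pd.DataFrame(
        [[7, 1000, 0.5], [3, 1000, 1.0], [3, 1050, 1.5], [7, 1050, 2.0]],
        columns=["id", "time", "price"]))
    vol = flintContext.read.pandas(pd.DataFrame(
        [[7, 1000, 100], [3, 1000, 200], [3, 1050, 300], [7, 1050, 400]],
        columns=["id", "time", "volume"]))

    # leftJoin matches rows on time within each key; the result should
    # come back with 'time' promoted to the first column.
    print(price.leftJoin(vol, key="id").toPandas())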
diff --git a/python/ts/flint/dataframe.py b/python/ts/flint/dataframe.py
index 51eb9e6b..9da15b0c 100644
--- a/python/ts/flint/dataframe.py
+++ b/python/ts/flint/dataframe.py
@@ -85,11 +85,9 @@ class TimeSeriesDataFrame(pyspark.sql.DataFrame):
     DEFAULT_UNIT = "ns"
     '''The units of the timestamps present in :attr:`DEFAULT_TIME_COLUMN`.
 
-    Acceptable values are: ``'s'``, ``'ms'``, ``'us'``, ``'ns'``.
-
     '''
 
-    def __init__(self, df, sql_ctx, *, time_column=DEFAULT_TIME_COLUMN, is_sorted=True, unit=DEFAULT_UNIT, tsrdd_part_info=None):
+    def __init__(self, df, sql_ctx, *, is_sorted=True, tsrdd_part_info=None):
         '''
         :type df: pyspark.sql.DataFrame
         :type sql_ctx: pyspark.sql.SqlContext
@@ -102,7 +100,7 @@ def __init__(self, df, sql_ctx, *, time_column=DEFAULT_TIME_COLUMN, is_sorted=Tr
         :param tsrdd_part_info: Partition info
         :type tsrdd_part_info: Option[com.twosigma.flint.timeseries.PartitionInfo]
         '''
-        self._time_column = time_column
+        self._time_column = self.DEFAULT_TIME_COLUMN
        self._is_sorted = is_sorted
         self._tsrdd_part_info = tsrdd_part_info
 
@@ -112,7 +110,7 @@ def __init__(self, df, sql_ctx, *, time_column=DEFAULT_TIME_COLUMN, is_sorted=Tr
 
         super().__init__(self._jdf, sql_ctx)
         self._jpkg = java.Packages(self._sc)
-        self._junit = utils.junit(self._sc, unit) if isinstance(unit,str) else unit
+        self._junit = utils.junit(self._sc, self.DEFAULT_UNIT)
 
         if tsrdd_part_info:
             if not is_sorted:
@@ -169,9 +167,7 @@ def _new_method(self, *args, **kwargs):
         if self._jpkg.OrderPreservingOperation.isDerivedFrom(self._jdf, df._jdf):
             tsdf_args = {
                 "df": df,
-                "sql_ctx": df.sql_ctx,
-                "time_column": self._time_column,
-                "unit": self._junit
+                "sql_ctx": df.sql_ctx
             }
             tsdf_args['is_sorted'] = self._is_sorted and self._jpkg.OrderPreservingOperation.isOrderPreserving(self._jdf, df._jdf)
 
@@ -226,17 +222,13 @@ def _from_df(df, *, time_column, is_sorted, unit):
         return TimeSeriesDataFrame(df, df.sql_ctx,
-                                   time_column=time_column,
-                                   is_sorted=is_sorted,
-                                   unit=unit)
+                                   is_sorted=is_sorted)
 
     @staticmethod
-    def _from_pandas(df, schema, sql_ctx, *, time_column, is_sorted, unit):
+    def _from_pandas(df, schema, sql_ctx, *, is_sorted):
         df = sql_ctx.createDataFrame(df, schema)
         return TimeSeriesDataFrame(df, sql_ctx,
-                                   time_column=time_column,
-                                   is_sorted=is_sorted,
-                                   unit=unit)
+                                   is_sorted=is_sorted)
 
     def _timedelta_ns(self, varname, timedelta, *, default=None):
         """Transforms pandas.Timedelta to a ns string with appropriate checks
@@ -308,10 +300,10 @@ def addColumnsForCycle(self, columns, *, key=None):
         :returns: a new dataframe with the columns added
         :rtype: :class:`TimeSeriesDataFrame`
         """
-        # Need to make a new StructType to prevent from modifying the original schema object
-        schema = pyspark_types.StructType.fromJson(self.schema.jsonValue())
         tsdf = self.groupByCycle(key)
-        # Don't pickle the whole schema, just the names for the lambda
+        # The last element of tsdf.schema describes the 'rows' column, whose
+        # element type can differ from self.schema when 'time' is not the first column
+        schema = tsdf.schema[len(tsdf.schema)-1].dataType.elementType
         schema_names = list(schema.names)
 
         def flatmap_fn():
@@ -335,8 +327,6 @@ def _(orig_row):
         return TimeSeriesDataFrame(df,
                                    df.sql_ctx,
-                                   time_column=self._time_column,
-                                   unit=self._junit,
                                    tsrdd_part_info=tsdf._tsrdd_part_info)
 
     def merge(self, other):
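Reviewer note: the addColumnsForCycle change above leans on groupByCycle returning the grouped rows as an ArrayType column whose element type is the per-row struct (apparently time-first, which is why it can differ from self.schema). A standalone sketch of that schema extraction using only pyspark.sql.types; the field names here are illustrative, not the literal groupByCycle output:

    import pyspark.sql.types as T

    # Per-row struct as groupByCycle reports it: note 'time' first, even
    # if the source dataframe had it second.
    row_struct = T.StructType([
        T.StructField("time", T.LongType()),
        T.StructField("id", T.IntegerType()),
        T.StructField("price", T.DoubleType()),
    ])
    # Shape of a groupByCycle result: a 'time' column plus a 'rows'
    # column holding an array of the row structs.
    cycle_schema = T.StructType([
        T.StructField("time", T.LongType()),
        T.StructField("rows", T.ArrayType(row_struct)),
    ])

    # The last field's element type recovers the per-row schema, which is
    # what the patched addColumnsForCycle takes its column names from.
    schema = cycle_schema[len(cycle_schema) - 1].dataType.elementType
    assert schema == row_struct
    assert list(schema.names) == ["time", "id", "price"]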
diff --git a/python/ts/flint/readwriter.py b/python/ts/flint/readwriter.py
index 384af21f..d372b1e3 100644
--- a/python/ts/flint/readwriter.py
+++ b/python/ts/flint/readwriter.py
@@ -15,7 +15,6 @@
 #
 
 from pyspark.sql import DataFrame
-from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 
 from . import java
 from . import utils
@@ -57,9 +56,7 @@ def pandas(self, df, schema=None, *,
         return TimeSeriesDataFrame._from_pandas(
             df, schema, self._flintContext._sqlContext,
-            time_column=time_column,
-            is_sorted=is_sorted,
-            unit=unit)
+            is_sorted=is_sorted)
 
     def _df_between(self, df, begin, end, time_column, unit):
         """Filter a Python dataframe to contain data between begin (inclusive) and end (exclusive)
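Reviewer note: with the reader no longer forwarding time_column and unit, frames read from pandas are always interpreted with the default 'time' column and nanosecond unit. A round-trip sketch mirroring fn_1 from the new test, again assuming a live FlintContext bound to `flintContext`:

    import pandas as pd
    import pyspark.sql.types as pyspark_types

    price = flintContext.read.pandas(pd.DataFrame(
        [[7, 1000, 0.5], [3, 1000, 1.0]],
        columns=["id", "time", "price"]))

    # For each cycle (rows sharing a timestamp), scale each price by the
    # cycle size; the declared DoubleType must match the returned values.
    def adjust(rows):
        size = len(rows)
        return {row: row.price * size for row in rows}

    print(price.addColumnsForCycle(
        {"adjustedPrice": (pyspark_types.DoubleType(), adjust)}
    ).toPandas())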