diff --git a/python/tempo/interpol.py b/python/tempo/interpol.py index 952298d4..510dbf2e 100644 --- a/python/tempo/interpol.py +++ b/python/tempo/interpol.py @@ -3,6 +3,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import col, expr, last, lead, lit, when from pyspark.sql.window import Window +from tempo.utils import calculate_time_horizon from tempo.resample import checkAllowableFreq, freq_dict # Interpolation fill options @@ -290,6 +291,9 @@ def interpolate( parsed_freq = checkAllowableFreq(freq) freq = f"{parsed_freq[0]} {freq_dict[parsed_freq[1]]}" + # Throw warning for user to validate that the expected number of output rows is valid. + calculate_time_horizon(tsdf.df, ts_col, freq, partition_cols) + # Only select required columns for interpolation input_cols: List[str] = [*partition_cols, ts_col, *target_cols] sampled_input: DataFrame = tsdf.df.select(*input_cols) diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index 3178d7fd..eafcc2e7 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -14,7 +14,7 @@ import tempo.io as tio import tempo.resample as rs from tempo.interpol import Interpolation -from tempo.utils import ENV_BOOLEAN, PLATFORM +from tempo.utils import ENV_BOOLEAN, PLATFORM, calculate_time_horizon logger = logging.getLogger(__name__) @@ -682,6 +682,11 @@ def resample(self, freq, func=None, metricCols = None, prefix=None, fill = None) :return: TSDF object with sample data using aggregate function """ rs.validateFuncExists(func) + + # Throw warning for user to validate that the expected number of output rows is valid. 
+ if fill is True: + calculate_time_horizon(self.df, self.ts_col, freq, self.partitionCols) + enriched_df:DataFrame = rs.aggregate(self, freq, func, metricCols, prefix, fill) return (_ResampledTSDF(enriched_df, ts_col = self.ts_col, partition_cols = self.partitionCols, freq = freq, func = func)) diff --git a/python/tempo/utils.py b/python/tempo/utils.py index a967e561..7655697b 100644 --- a/python/tempo/utils.py +++ b/python/tempo/utils.py @@ -1,11 +1,14 @@ +from typing import List import logging import os - +import warnings from IPython import get_ipython from IPython.core.display import HTML from IPython.display import display as ipydisplay from pandas import DataFrame as pandasDataFrame from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import expr, max, min, sum, percentile_approx +from tempo.resample import checkAllowableFreq, freq_dict logger = logging.getLogger(__name__) PLATFORM = "DATABRICKS" if "DB_HOME" in os.environ.keys() else "NON_DATABRICKS" @@ -17,6 +20,15 @@ where the code is running from. """ + +class ResampleWarning(Warning): + """ + This class is a warning that is raised when the interpolate or resample with fill methods are called. 
+ """ + + pass + + def __is_capable_of_html_rendering(): """ This method returns a boolean value signifying whether the environment is a notebook environment @@ -34,6 +46,77 @@ def __is_capable_of_html_rendering(): return False +def calculate_time_horizon( + df: DataFrame, ts_col: str, freq: str, partition_cols: List[str] +): + # Convert Frequency using resample dictionary + parsed_freq = checkAllowableFreq(freq) + freq = f"{parsed_freq[0]} {freq_dict[parsed_freq[1]]}" + + # Get max and min timestamp per partition + partitioned_df: DataFrame = df.groupBy(*partition_cols).agg( + max(ts_col).alias("max_ts"), + min(ts_col).alias("min_ts"), + ) + + # Generate upscale metrics + normalized_time_df: DataFrame = ( + partitioned_df.withColumn("min_epoch_ms", expr("unix_millis(min_ts)")) + .withColumn("max_epoch_ms", expr("unix_millis(max_ts)")) + .withColumn( + "interval_ms", + expr( + f"unix_millis(cast('1970-01-01 00:00:00.000+0000' as TIMESTAMP) + INTERVAL {freq})" + ), + ) + .withColumn( + "rounded_min_epoch", expr("min_epoch_ms - (min_epoch_ms % interval_ms)") + ) + .withColumn( + "rounded_max_epoch", expr("max_epoch_ms - (max_epoch_ms % interval_ms)") + ) + .withColumn("diff_ms", expr("rounded_max_epoch - rounded_min_epoch")) + .withColumn("num_values", expr("(diff_ms/interval_ms) +1")) + ) + + ( + min_ts, + max_ts, + min_value_partition, + max_value_partition, + p25_value_partition, + p50_value_partition, + p75_value_partition, + total_values, + ) = normalized_time_df.select( + min("min_ts"), + max("max_ts"), + min("num_values"), + max("num_values"), + percentile_approx("num_values", 0.25), + percentile_approx("num_values", 0.5), + percentile_approx("num_values", 0.75), + sum("num_values"), + ).first() + + warnings.simplefilter("always", ResampleWarning) + warnings.warn( + f""" + Resample Metrics Warning: + Earliest Timestamp: {min_ts} + Latest Timestamp: {max_ts} + No. of Unique Partitions: {normalized_time_df.count()} + Resampled Min No. 
Values in a Single Partition: {min_value_partition} + Resampled Max No. Values in a Single Partition: {max_value_partition} + Resampled P25 No. Values in a Single Partition: {p25_value_partition} + Resampled P50 No. Values in a Single Partition: {p50_value_partition} + Resampled P75 No. Values in a Single Partition: {p75_value_partition} + Resampled Total No. Values Across All Partitions: {total_values} + """, + ResampleWarning, + ) + + def display_html(df): """ Display method capable of displaying the dataframe in a formatted HTML structured output @@ -51,29 +134,38 @@ def display_unavailable(df): """ This method is called when display method is not available in the environment. """ - logger.error("'display' method not available in this environment. Use 'show' method instead.") + logger.error( + "'display' method not available in this environment. Use 'show' method instead." + ) ENV_BOOLEAN = __is_capable_of_html_rendering() -if (PLATFORM == "DATABRICKS") and (type(get_ipython()) != type(None)) and ('display' in get_ipython().user_ns.keys()): - method = get_ipython().user_ns['display'] +if ( + (PLATFORM == "DATABRICKS") + and (type(get_ipython()) != type(None)) + and ("display" in get_ipython().user_ns.keys()) +): + method = get_ipython().user_ns["display"] # Under 'display' key in user_ns the original databricks display method is present # to know more refer: /databricks/python_shell/scripts/db_ipykernel_launcher.py def display_improvised(obj): - if type(obj).__name__ == 'TSDF': + if type(obj).__name__ == "TSDF": method(obj.df) else: method(obj) + display = display_improvised elif ENV_BOOLEAN: + def display_html_improvised(obj): - if type(obj).__name__ == 'TSDF': + if type(obj).__name__ == "TSDF": display_html(obj.df) else: display_html(obj) + display = display_html_improvised else: diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py index b75c9a6e..7ac16ef2 100644 --- a/python/tests/tsdf_tests.py +++ b/python/tests/tsdf_tests.py @@ -619,17 
+619,6 @@ def test_group_stats(self): self.assertDataFramesEqual(featured_df, dfExpected) -class UtilsTest(SparkTest): - - def test_display(self): - """Test of the display utility""" - if PLATFORM == 'DATABRICKS': - self.assertEqual(id(display),id(display_improvised)) - elif ENV_BOOLEAN: - self.assertEqual(id(display),id(display_html_improvised)) - else: - self.assertEqual(id(display),id(display_unavailable)) - class ResampleTest(SparkTest): def test_resample(self): diff --git a/python/tests/utils_tests.py b/python/tests/utils_tests.py new file mode 100644 index 00000000..183fd12e --- /dev/null +++ b/python/tests/utils_tests.py @@ -0,0 +1,91 @@ +from pyspark.sql.types import * +from tests.tsdf_tests import SparkTest +from tempo.utils import calculate_time_horizon +from chispa.dataframe_comparer import * +from tempo.tsdf import TSDF +from tempo.interpol import Interpolation +from tempo.utils import * +import unittest + + +class UtilsTest(SparkTest): + def buildTestingDataFrame(self): + schema = StructType( + [ + StructField("partition_a", StringType()), + StructField("partition_b", StringType()), + StructField("event_ts", StringType()), + StructField("value_a", FloatType()), + StructField("value_b", FloatType()), + ] + ) + + simple_data = [ + ["A", "A-1", "2020-01-01 00:00:10", 0.0, None], + ["A", "A-1", "2020-01-01 00:01:10", 2.0, 2.0], + ["A", "A-1", "2020-01-01 00:01:32", None, None], + ["A", "A-1", "2020-01-01 00:02:03", None, None], + ["A", "A-1", "2020-01-01 00:03:32", None, 7.0], + ["A", "A-1", "2020-01-01 00:04:12", 8.0, 8.0], + ["A", "A-1", "2020-01-01 00:05:31", 11.0, None], + ["A", "A-2", "2020-01-01 00:00:10", 0.0, None], + ["A", "A-2", "2020-01-01 00:01:10", 2.0, 2.0], + ["A", "A-2", "2020-01-01 00:01:32", None, None], + ["A", "A-2", "2020-01-01 00:02:03", None, None], + ["A", "A-2", "2020-01-01 00:04:12", 8.0, 8.0], + ["A", "A-2", "2020-01-01 00:05:31", 11.0, None], + ["B", "A-2", "2020-01-01 00:01:10", 2.0, 2.0], + ["B", "A-2", "2020-01-01 00:01:32", 
None, None], + ["B", "A-2", "2020-01-01 00:02:03", None, None], + ["B", "A-2", "2020-01-01 00:03:32", None, 7.0], + ["B", "A-2", "2020-01-01 00:04:12", 8.0, 8.0], + ] + + # construct dataframes + self.simple_input_df = self.buildTestDF(schema, simple_data) + + self.simple_input_tsdf = TSDF( + self.simple_input_df, + partition_cols=["partition_a", "partition_b"], + ts_col="event_ts", + ) + + +class UtilsTest(UtilsTest): + def test_display(self): + """Test of the display utility""" + if PLATFORM == "DATABRICKS": + self.assertEqual(id(display), id(display_improvised)) + elif ENV_BOOLEAN: + self.assertEqual(id(display), id(display_html_improvised)) + else: + self.assertEqual(id(display), id(display_unavailable)) + + def test_calculate_time_horizon(self): + """Test calculate time horizon warning and number of expected output rows""" + self.buildTestingDataFrame() + with warnings.catch_warnings(record=True) as w: + calculate_time_horizon( + self.simple_input_tsdf.df, + self.simple_input_tsdf.ts_col, + "30 seconds", + ["partition_a", "partition_b"], + ) + warning_message = """ + Resample Metrics Warning: + Earliest Timestamp: 2020-01-01 00:00:10 + Latest Timestamp: 2020-01-01 00:05:31 + No. of Unique Partitions: 3 + Resampled Min No. Values in a Single Partition: 7.0 + Resampled Max No. Values in a Single Partition: 12.0 + Resampled P25 No. Values in a Single Partition: 7.0 + Resampled P50 No. Values in a Single Partition: 12.0 + Resampled P75 No. Values in a Single Partition: 12.0 + Resampled Total No. Values Across All Partitions: 31.0 + """ + assert warning_message.strip() == str(w[-1].message).strip() + + +## MAIN +if __name__ == "__main__": + unittest.main()