From a472ea44a8a8c4f6cc5038d634e097d800f458fa Mon Sep 17 00:00:00 2001 From: femtotrader Date: Sun, 7 Jul 2024 17:53:28 +0200 Subject: [PATCH] SmootherFactory, SmoothedIndicator and SOBV --- talipp/indicators/SOBV.py | 58 +++++++++-------------------------- talipp/indicators/Smoother.py | 48 +++++++++++++++++++++++++++++ test/TalippTest.py | 6 ++-- test/test_SOBV.py | 12 +++----- 4 files changed, 71 insertions(+), 53 deletions(-) create mode 100644 talipp/indicators/Smoother.py diff --git a/talipp/indicators/SOBV.py b/talipp/indicators/SOBV.py index 8a7069d..2c013ee 100644 --- a/talipp/indicators/SOBV.py +++ b/talipp/indicators/SOBV.py @@ -1,44 +1,16 @@ -from typing import List, Any - -from talipp.indicator_util import has_valid_values -from talipp.indicators.Indicator import Indicator, InputModifierType from talipp.indicators.OBV import OBV -from talipp.input import SamplingPeriodType -from talipp.ohlcv import OHLCV - - -class SOBV(Indicator): - """Smoothed On Balance Volume. - - Input type: [OHLCV][talipp.ohlcv.OHLCV] - - Output type: `float` - - Args: - period: Moving average period. - input_values: List of input values. - input_indicator: Input indicator. - input_modifier: Input modifier. - input_sampling: Input sampling type. - """ - - def __init__(self, period: int, - input_values: List[OHLCV] = None, - input_indicator: Indicator = None, - input_modifier: InputModifierType = None, - input_sampling: SamplingPeriodType = None): - super().__init__(input_modifier=input_modifier, - input_sampling=input_sampling) - - self.period = period - - self.obv = OBV() - self.add_sub_indicator(self.obv) - - self.initialize(input_values, input_indicator) - - def _calculate_new_value(self) -> Any: - if not has_valid_values(self.obv, self.period): - return None - - return sum(self.obv[-self.period:]) / float(self.period) +from talipp.indicators.Smoother import SmootherFactory +from talipp.ma import MAType + + +"""Smoothed On Balance Volume. +Input type: [OHLCV][talipp.ohlcv.OHLCV] +Output type: `float` +Args: + period: Moving average period. + input_values: List of input values. + input_indicator: Input indicator. + input_modifier: Input modifier. + input_sampling: Input sampling type. +""" +SOBV = SmootherFactory.get_smoother(OBV, MAType.SMA) diff --git a/talipp/indicators/Smoother.py b/talipp/indicators/Smoother.py new file mode 100644 index 0000000..5ec8e86 --- /dev/null +++ b/talipp/indicators/Smoother.py @@ -0,0 +1,48 @@ +from typing import Any +from talipp.indicators.Indicator import Indicator +from talipp.ma import MAFactory, MAType + + +class SmoothedIndicator(Indicator): + def __init__(self, indicator_class, ma_type: MAType): + super().__init__() + + self.indicator_class = indicator_class + self.ma_type = ma_type + + def __call__( + self, + smoothing_period: int, + input_values=None, + input_indicator=None, + *args: Any, + **kwargs + ) -> Any: + self.ma = MAFactory.get_ma(self.ma_type, smoothing_period) + + self.internal_indicator = self.indicator_class(*args, **kwargs) + self.add_sub_indicator(self.internal_indicator) + + self.initialize(input_values, input_indicator) + + return self + + def _calculate_new_value(self) -> Any: + self.ma.add(self.internal_indicator.output_values) + return self.ma.output_values[-1] + + +class SmootherFactory: + """Smoother factory.""" + + @staticmethod + def get_smoother(indicator_class, ma_type: MAType = MAType.SMA): + """ + Return a smoother indicator + + Args: + indicator_class: indicator class + smoothing_period: Smoothing period. + ma_type: Moving average type. + """ + return SmoothedIndicator(indicator_class, ma_type) diff --git a/test/TalippTest.py b/test/TalippTest.py index 0e45be2..7b1ef48 100644 --- a/test/TalippTest.py +++ b/test/TalippTest.py @@ -31,7 +31,7 @@ def assertIndicatorUpdate(self, indicator: Indicator, iterations_no: int = 20): indicator.update(last_input_value) - self.assertEqual(last_indicator_value, indicator[-1]) + self.assertAlmostEqual(last_indicator_value, indicator[-1], places = 5) def assertIndicatorDelete(self, indicator: Indicator, iterations_no: int = 20): last_indicator_value = indicator[-1] @@ -48,13 +48,13 @@ def assertIndicatorDelete(self, indicator: Indicator, iterations_no: int = 20): indicator.remove() # verify that adding and then removing X input values returns the original output value - self.assertEqual(last_indicator_value, indicator[-1]) + self.assertAlmostEqual(last_indicator_value, indicator[-1], places = 5) # delete the original last input value and add it back and check the original last output value is returned indicator.remove() indicator.add(last_input_value) - self.assertEqual(last_indicator_value, indicator[-1]) + self.assertAlmostEqual(last_indicator_value, indicator[-1], places = 5) def assertIndicatorPurgeOldest(self, indicator: Indicator): # purge oldest 5 values diff --git a/test/test_SOBV.py b/test/test_SOBV.py index 01c2a54..21b317a 100644 --- a/test/test_SOBV.py +++ b/test/test_SOBV.py @@ -10,13 +10,11 @@ def setUp(self) -> None: self.input_values = list(TalippTest.OHLCV_TMPL) def test_init(self): - ind = SOBV(20, self.input_values) + ind = SOBV(20, input_values=self.input_values) - print(ind) - - self.assertAlmostEqual(ind[-3], 90.868499, places = 5) - self.assertAlmostEqual(ind[-2], 139.166499, places = 5) - self.assertAlmostEqual(ind[-1], 187.558499, places = 5) + self.assertAlmostEqual(ind[-3], 90.868499, places=5) + self.assertAlmostEqual(ind[-2], 139.166499, places=5) + self.assertAlmostEqual(ind[-1], 187.558499, places=5) def test_update(self): self.assertIndicatorUpdate(SOBV(20, self.input_values)) @@ -28,5 +26,5 @@ def test_purge_oldest(self): self.assertIndicatorPurgeOldest(SOBV(20, self.input_values)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()