Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom period_start for Input Sampler #1

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions talipp/indicators/Indicator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABCMeta, abstractmethod
from collections.abc import MutableSequence, Sequence
from datetime import datetime
from typing import List, Any, Callable, Union, Type
from warnings import warn

Expand All @@ -26,12 +27,12 @@ class Indicator(Sequence):
def __init__(self,
input_modifier: InputModifierType = None,
output_value_type: Type = float,
input_sampling: SamplingPeriodType = None):
input_sampling: SamplingPeriodType = None, period_start: datetime = None):
self.input_modifier = input_modifier
self.output_value_type = output_value_type
self.input_sampler: Sampler = None
if input_sampling is not None:
self.input_sampler = Sampler(input_sampling)
self.input_sampler = Sampler(input_sampling, period_start)

self.input_values: ListAny = []
self.output_values: ListAny = []
Expand Down
28 changes: 16 additions & 12 deletions talipp/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class SamplingPeriodType(Enum):

DAY_1 = (TimeUnitType.DAY, 1)
"""1 day"""


class Sampler:
"""Implementation of timeframe auto-sampling.

Expand Down Expand Up @@ -120,8 +120,9 @@ class Sampler:
TimeUnitType.DAY: 3600 * 24,
}

def __init__(self, period_type: SamplingPeriodType):
def __init__(self, period_type: SamplingPeriodType, period_start: datetime = None):
self._period_type: SamplingPeriodType = period_type
self._period_start: datetime = period_start

def is_same_period(self, first: OHLCV, second: OHLCV) -> bool:
"""Evaluate whether two [OHLCV][talipp.ohlcv.OHLCV] objects belong to the same period.
Expand All @@ -145,15 +146,18 @@ def _normalize(self, dt: datetime):
period_type = self._period_type.value[0]
period_length = self._period_type.value[1]

if period_type == TimeUnitType.SEC:
period_start = datetime(dt.year, dt.month, dt.day, dt.hour, dt.minute)
elif period_type == TimeUnitType.MIN:
period_start = datetime(dt.year, dt.month, dt.day, dt.hour)
elif period_type == TimeUnitType.HOUR:
period_start = datetime(dt.year, dt.month, dt.day)
elif period_type == TimeUnitType.DAY:
period_start = datetime(dt.year, dt.month, 1)
period_start = period_start.replace(tzinfo=dt.tzinfo)
if self._period_start is None:
if period_type == TimeUnitType.SEC:
period_start = datetime(dt.year, dt.month, dt.day, dt.hour, dt.minute)
elif period_type == TimeUnitType.MIN:
period_start = datetime(dt.year, dt.month, dt.day, dt.hour)
elif period_type == TimeUnitType.HOUR:
period_start = datetime(dt.year, dt.month, dt.day)
elif period_type == TimeUnitType.DAY:
period_start = datetime(dt.year, dt.month, 1)
period_start = period_start.replace(tzinfo=dt.tzinfo)
else:
period_start = self._period_start

delta = dt - period_start
num_periods = delta.total_seconds() // (period_length * Sampler.CONVERSION_TO_SEC[period_type])
Expand Down
8 changes: 8 additions & 0 deletions test/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ def test_sample_normalize_5min(self):
self.assertFalse(sampler.is_same_period(self.get_ohlcv("01/01/2024 12:34:59"),
self.get_ohlcv("01/01/2024 12:35:00")))

self.assertFalse(sampler.is_same_period(self.get_ohlcv("01/01/2024 09:14:00"),
self.get_ohlcv("01/01/2024 09:16:00")))

period_adjusted_sampler = Sampler(SamplingPeriodType.MIN_5, datetime.strptime("01/01/2024 09:14:00", "%d/%m/%Y %H:%M:%S"))

self.assertTrue(period_adjusted_sampler.is_same_period(self.get_ohlcv("01/01/2024 09:14:00"),
self.get_ohlcv("01/01/2024 09:16:00")))

def test_sample_normalize_1hour(self):
sampler = Sampler(SamplingPeriodType.HOUR_1)

Expand Down