diff --git a/qupulse/utils/__init__.py b/qupulse/utils/__init__.py index 43de28c1..326072f4 100644 --- a/qupulse/utils/__init__.py +++ b/qupulse/utils/__init__.py @@ -1,11 +1,12 @@ """This package contains utility functions and classes as well as custom sympy extensions(hacks).""" -from typing import Union, Iterable, Any, Tuple, Mapping, Iterator, TypeVar, Sequence, AbstractSet +from typing import Union, Iterable, Any, Tuple, Mapping, Iterator, TypeVar, Sequence, AbstractSet, Optional, Callable import itertools import re import numbers from collections import OrderedDict from frozendict import frozendict +from qupulse.expressions import ExpressionScalar, ExpressionLike import numpy @@ -25,7 +26,7 @@ __all__ = ["checked_int_cast", "is_integer", "isclose", "pairwise", "replace_multiple", "cached_property", - "forced_hash"] + "forced_hash", "to_next_multiple"] def checked_int_cast(x: Union[float, int, numpy.ndarray], epsilon: float=1e-6) -> int: @@ -122,3 +123,30 @@ def forced_hash(obj) -> int: return hash(tuple(map(forced_hash, obj))) raise + + +def to_next_multiple(sample_rate: ExpressionLike, quantum: int, + min_quanta: Optional[int] = None) -> Callable[[ExpressionLike],ExpressionScalar]: + """Construct a helper function to expand a duration to one corresponding to + valid sample multiples according to the arguments given. + Useful e.g. for PulseTemplate.pad_to's 'to_new_duration'-argument. + + Args: + sample_rate: sample rate with respect to which the duration is evaluated. + quantum: number of samples to whose next integer multiple the duration shall be rounded up to. + min_quanta: number of multiples of quantum not to fall short of. + Returns: + A function that takes a duration (ExpressionLike) as input, and returns + a duration rounded up to the next valid samples count in given sample rate. + The function returns 0 if duration==0, <0 is not checked if min_quanta is None. + + """ + sample_rate = ExpressionScalar(sample_rate) + #is it more efficient to omit the Max call if not necessary? + if min_quanta is None: + #double negative for ceil division. + return lambda duration: -(-(duration*sample_rate)//quantum) * (quantum/sample_rate) + else: + #still return 0 if duration==0 + return lambda duration: ExpressionScalar(f'{quantum}/({sample_rate})*Max({min_quanta},-(-{duration}*{sample_rate}//{quantum}))*Max(0, sign({duration}))') + \ No newline at end of file diff --git a/tests/utils/utils_tests.py b/tests/utils/utils_tests.py index 6ec75092..83e1a26a 100644 --- a/tests/utils/utils_tests.py +++ b/tests/utils/utils_tests.py @@ -2,7 +2,7 @@ from unittest import mock from collections import OrderedDict -from qupulse.utils import checked_int_cast, replace_multiple, _fallback_pairwise +from qupulse.utils import checked_int_cast, replace_multiple, _fallback_pairwise, to_next_multiple class PairWiseTest(unittest.TestCase): @@ -102,3 +102,42 @@ def test_replace_multiple_overlap(self): replacements = OrderedDict(reversed(replacement_list)) result = replace_multiple('asdf', replacements) self.assertEqual(result, '2') + + +class ToNextMultipleTests(unittest.TestCase): + def test_to_next_multiple(self): + from qupulse.utils.types import TimeType + from qupulse.expressions import ExpressionScalar + + duration = TimeType.from_float(47.1415926535) + evaluated = to_next_multiple(sample_rate=TimeType.from_float(2.4),quantum=16)(duration) + expected = ExpressionScalar('160/3') + self.assertEqual(evaluated, expected) + + duration = TimeType.from_float(3.1415926535) + evaluated = to_next_multiple(sample_rate=TimeType.from_float(2.4),quantum=16,min_quanta=13)(duration) + expected = ExpressionScalar('260/3') + self.assertEqual(evaluated, expected) + + duration = 6185240.0000001 + evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration) + expected = 6185248 + self.assertEqual(evaluated, expected) + + duration = 0. + evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration) + expected = 0. + self.assertEqual(evaluated, expected) + + duration = ExpressionScalar('abc') + evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration).evaluate_in_scope(dict(abc=0.)) + expected = 0. + self.assertEqual(evaluated, expected) + + duration = ExpressionScalar('q') + evaluated = to_next_multiple(sample_rate=ExpressionScalar('w'),quantum=16,min_quanta=1)(duration).evaluate_in_scope( + dict(q=3.14159,w=1.0)) + expected = 16. + self.assertEqual(evaluated, expected) + + \ No newline at end of file