From bd50a5da398c44043093aa0de23879da47dba1d4 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 5 Dec 2024 11:14:09 +0100 Subject: [PATCH] Add more test cases for invalid user funcs; check that batch time aggregator returns scalars --- src/optimagic/optimization/history.py | 34 +++++++++++------ src/optimagic/timing.py | 7 ++++ tests/optimagic/optimization/test_history.py | 39 +++++++++++++++++--- tests/optimagic/test_timing.py | 14 +++++++ 4 files changed, 77 insertions(+), 17 deletions(-) create mode 100644 tests/optimagic/test_timing.py diff --git a/src/optimagic/optimization/history.py b/src/optimagic/optimization/history.py index 1ecc54f0e..4016d4152 100644 --- a/src/optimagic/optimization/history.py +++ b/src/optimagic/optimization/history.py @@ -219,12 +219,12 @@ def _get_time( ) time = fun_time + jac_time + fun_and_jac_time - batch_time = _batch_apply( + batch_aware_time = _apply_to_batch( data=time, batch_ids=self.batches, func=cost_model.aggregate_batch_time, ) - return np.cumsum(batch_time) + return np.cumsum(batch_aware_time) def _get_time_per_task( self, task: EvalTask, cost_factor: float | None @@ -383,7 +383,7 @@ def _task_as_categorical(task: list[EvalTask]) -> pd.Categorical: ) -def _batch_apply( +def _apply_to_batch( data: NDArray[np.float64], batch_ids: list[int], func: Callable[[Iterable[float]], float], @@ -392,10 +392,10 @@ def _batch_apply( Args: data: 1d array with data. - batch_ids: A list whose length is equal to the size of data. Values need to be - sorted and can be repeated. + batch_ids: A list with batch ids whose length is equal to the size of data. + Values need to be sorted and can be repeated. func: A reduction function that takes an iterable of floats as input (e.g., a - numpy array or a list) and returns a scalar. + numpy.ndarray or list) and returns a scalar. Returns: The transformed data. Has the same length as data. For each batch, the result of @@ -410,25 +410,37 @@ def _batch_apply( for batch, (start, stop) in zip( batch_ids, zip(batch_starts, batch_stops, strict=False), strict=False ): + batch_data = data[start:stop] + try: - batch_data = data[start:stop] reduced = func(batch_data) - batch_results.append(reduced) except Exception as e: msg = ( f"Calling function {func.__name__} on batch {batch} of the History " - f"History raised an Exception. Please verify that {func.__name__} is " - "properly defined." + f"raised an Exception. Please verify that {func.__name__} is " + "well-defined and takes a list of floats as input and returns a scalar." ) raise ValueError(msg) from e + try: + assert np.isscalar(reduced) + except AssertionError: + msg = ( + f"Function {func.__name__} did not return a scalar for batch {batch}. " + f"Please verify that {func.__name__} returns a scalar when called on a " + "list of floats." + ) + raise ValueError(msg) from None + + batch_results.append(reduced) + out = np.zeros_like(data) out[batch_starts] = batch_results return out def _get_batch_start(batch_ids: list[int]) -> list[int]: - """Get start indices of batch. + """Get start indices of batches. This function assumes that batch_ids non-empty and sorted. diff --git a/src/optimagic/timing.py b/src/optimagic/timing.py index db83a76d2..a9fbe7d88 100644 --- a/src/optimagic/timing.py +++ b/src/optimagic/timing.py @@ -10,6 +10,13 @@ class CostModel: label: str aggregate_batch_time: Callable[[Iterable[float]], float] + def __post_init__(self) -> None: + if not callable(self.aggregate_batch_time): + raise ValueError( + "aggregate_batch_time must be a callable, got " + f"{self.aggregate_batch_time}" + ) + evaluation_time = CostModel( fun=None, diff --git a/tests/optimagic/optimization/test_history.py b/tests/optimagic/optimization/test_history.py index ab92b88a2..bd7137922 100644 --- a/tests/optimagic/optimization/test_history.py +++ b/tests/optimagic/optimization/test_history.py @@ -10,7 +10,7 @@ from optimagic.optimization.history import ( History, HistoryEntry, - _batch_apply, + _apply_to_batch, _calculate_monotone_sequence, _get_batch_start, _get_flat_param_names, @@ -403,6 +403,13 @@ def test_get_time_wall_time(history): assert_array_equal(got, exp) +def test_get_time_invalid_cost_model(history): + with pytest.raises( + ValueError, match="cost_model must be a CostModel or 'wall_time'." + ): + history._get_time(cost_model="invalid") + + def test_start_time_property(history): assert history.start_time == [0, 2, 5, 7, 10, 12] @@ -465,12 +472,18 @@ def test_get_flat_params_fast_path(): assert_array_equal(got, exp) -def test_get_flat_param_names(): +def test_get_flat_param_names_pytree(): got = _get_flat_param_names(param={"a": 0, "b": [0, 1], "c": np.arange(2)}) exp = ["a", "b_0", "b_1", "c_0", "c_1"] assert got == exp +def test_get_flat_param_names_fast_path(): + got = _get_flat_param_names(param=np.arange(2)) + exp = ["0", "1"] + assert got == exp + + def test_calculate_monotone_sequence_maximize(): sequence = [0, 1, 0, 0, 2, 10, 0] exp = [0, 1, 1, 1, 2, 10, 10] @@ -509,17 +522,31 @@ def test_get_batch_start(): assert got == [0, 2, 5, 7] -def test_batch_apply_sum(): +def test_apply_to_batch_sum(): data = np.array([0, 1, 2, 3, 4]) batch_ids = [0, 0, 1, 1, 2] exp = np.array([1, 0, 5, 0, 4]) - got = _batch_apply(data, batch_ids, sum) + got = _apply_to_batch(data, batch_ids, sum) assert_array_equal(exp, got) -def test_batch_apply_max(): +def test_apply_to_batch_max(): data = np.array([0, 1, 2, 3, 4]) batch_ids = [0, 0, 1, 1, 2] exp = np.array([1, 0, 3, 0, 4]) - got = _batch_apply(data, batch_ids, max) + got = _apply_to_batch(data, batch_ids, max) assert_array_equal(exp, got) + + +def test_apply_to_batch_broken_func(): + data = np.array([0, 1, 2, 3, 4]) + batch_ids = [0, 0, 1, 1, 2] + with pytest.raises(ValueError, match="Calling function on batch [0, 0]"): + _apply_to_batch(data, batch_ids, func=lambda _: 1 / 0) + + +def test_apply_to_batch_func_with_non_scalar_return(): + data = np.array([0, 1, 2, 3, 4]) + batch_ids = [0, 0, 1, 1, 2] + with pytest.raises(ValueError, match="Function did not return a scalar"): + _apply_to_batch(data, batch_ids, func=lambda _list: _list) diff --git a/tests/optimagic/test_timing.py b/tests/optimagic/test_timing.py new file mode 100644 index 000000000..fd2edfc3c --- /dev/null +++ b/tests/optimagic/test_timing.py @@ -0,0 +1,14 @@ +import pytest + +from optimagic import timing + + +def test_invalid_aggregate_batch_time(): + with pytest.raises(ValueError, match="aggregate_batch_time must be a callable"): + timing.CostModel( + fun=None, + jac=None, + fun_and_jac=None, + label="label", + aggregate_batch_time="Not callable", + )