Skip to content

Commit

Permalink
Add more test cases for invalid user funcs; check that batch time agg…
Browse files Browse the repository at this point in the history
…regator returns scalars
  • Loading branch information
timmens committed Dec 5, 2024
1 parent e089347 commit bd50a5d
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 17 deletions.
34 changes: 23 additions & 11 deletions src/optimagic/optimization/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ def _get_time(
)

time = fun_time + jac_time + fun_and_jac_time
batch_time = _batch_apply(
batch_aware_time = _apply_to_batch(
data=time,
batch_ids=self.batches,
func=cost_model.aggregate_batch_time,
)
return np.cumsum(batch_time)
return np.cumsum(batch_aware_time)

def _get_time_per_task(
self, task: EvalTask, cost_factor: float | None
Expand Down Expand Up @@ -383,7 +383,7 @@ def _task_as_categorical(task: list[EvalTask]) -> pd.Categorical:
)


def _batch_apply(
def _apply_to_batch(
data: NDArray[np.float64],
batch_ids: list[int],
func: Callable[[Iterable[float]], float],
Expand All @@ -392,10 +392,10 @@ def _batch_apply(
Args:
data: 1d array with data.
batch_ids: A list whose length is equal to the size of data. Values need to be
sorted and can be repeated.
batch_ids: A list with batch ids whose length is equal to the size of data.
Values need to be sorted and can be repeated.
func: A reduction function that takes an iterable of floats as input (e.g., a
numpy array or a list) and returns a scalar.
numpy.ndarray or list) and returns a scalar.
Returns:
The transformed data. Has the same length as data. For each batch, the result of
Expand All @@ -410,25 +410,37 @@ def _batch_apply(
for batch, (start, stop) in zip(
batch_ids, zip(batch_starts, batch_stops, strict=False), strict=False
):
batch_data = data[start:stop]

try:
batch_data = data[start:stop]
reduced = func(batch_data)
batch_results.append(reduced)
except Exception as e:
msg = (
f"Calling function {func.__name__} on batch {batch} of the History "
f"History raised an Exception. Please verify that {func.__name__} is "
"properly defined."
f"raised an Exception. Please verify that {func.__name__} is "
"well-defined and takes a list of floats as input and returns a scalar."
)
raise ValueError(msg) from e

try:
assert np.isscalar(reduced)
except AssertionError:
msg = (
f"Function {func.__name__} did not return a scalar for batch {batch}. "
f"Please verify that {func.__name__} returns a scalar when called on a "
"list of floats."
)
raise ValueError(msg) from None

batch_results.append(reduced)

out = np.zeros_like(data)
out[batch_starts] = batch_results
return out


def _get_batch_start(batch_ids: list[int]) -> list[int]:
"""Get start indices of batch.
"""Get start indices of batches.
This function assumes that batch_ids is non-empty and sorted.
Expand Down
7 changes: 7 additions & 0 deletions src/optimagic/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ class CostModel:
label: str
aggregate_batch_time: Callable[[Iterable[float]], float]

def __post_init__(self) -> None:
    """Validate ``aggregate_batch_time`` once the instance has been set up.

    NOTE(review): presumably invoked automatically as the dataclass
    post-init hook; the class decorator is not visible here, confirm.

    Raises:
        ValueError: If ``aggregate_batch_time`` is not callable.

    """
    aggregator = self.aggregate_batch_time
    # Guard clause: a callable aggregator needs no further checks.
    if callable(aggregator):
        return
    raise ValueError(
        "aggregate_batch_time must be a callable, got "
        f"{aggregator}"
    )


evaluation_time = CostModel(
fun=None,
Expand Down
39 changes: 33 additions & 6 deletions tests/optimagic/optimization/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from optimagic.optimization.history import (
History,
HistoryEntry,
_batch_apply,
_apply_to_batch,
_calculate_monotone_sequence,
_get_batch_start,
_get_flat_param_names,
Expand Down Expand Up @@ -403,6 +403,13 @@ def test_get_time_wall_time(history):
assert_array_equal(got, exp)


def test_get_time_invalid_cost_model(history):
    """An unsupported ``cost_model`` argument must raise a ValueError."""
    expected_message = "cost_model must be a CostModel or 'wall_time'."
    with pytest.raises(ValueError, match=expected_message):
        history._get_time(cost_model="invalid")


def test_start_time_property(history):
assert history.start_time == [0, 2, 5, 7, 10, 12]

Expand Down Expand Up @@ -465,12 +472,18 @@ def test_get_flat_params_fast_path():
assert_array_equal(got, exp)


def test_get_flat_param_names():
def test_get_flat_param_names_pytree():
got = _get_flat_param_names(param={"a": 0, "b": [0, 1], "c": np.arange(2)})
exp = ["a", "b_0", "b_1", "c_0", "c_1"]
assert got == exp


def test_get_flat_param_names_fast_path():
    """A plain 1d array of length two yields the names "0" and "1"."""
    expected = ["0", "1"]
    assert _get_flat_param_names(param=np.arange(2)) == expected


def test_calculate_monotone_sequence_maximize():
sequence = [0, 1, 0, 0, 2, 10, 0]
exp = [0, 1, 1, 1, 2, 10, 10]
Expand Down Expand Up @@ -509,17 +522,31 @@ def test_get_batch_start():
assert got == [0, 2, 5, 7]


def test_batch_apply_sum():
def test_apply_to_batch_sum():
data = np.array([0, 1, 2, 3, 4])
batch_ids = [0, 0, 1, 1, 2]
exp = np.array([1, 0, 5, 0, 4])
got = _batch_apply(data, batch_ids, sum)
got = _apply_to_batch(data, batch_ids, sum)
assert_array_equal(exp, got)


def test_batch_apply_max():
def test_apply_to_batch_max():
data = np.array([0, 1, 2, 3, 4])
batch_ids = [0, 0, 1, 1, 2]
exp = np.array([1, 0, 3, 0, 4])
got = _batch_apply(data, batch_ids, max)
got = _apply_to_batch(data, batch_ids, max)
assert_array_equal(exp, got)


def test_apply_to_batch_broken_func():
    """A reducer that raises is reported as a ValueError naming the batch."""
    data = np.array([0, 1, 2, 3, 4])
    batch_ids = [0, 0, 1, 1, 2]
    # `match` is a regular expression, so square brackets would form a
    # character class rather than match literal text. The reducer fails on
    # the first batch, whose id is 0, so match the message as it is emitted.
    expected_message = "Calling function <lambda> on batch 0 of the History"
    with pytest.raises(ValueError, match=expected_message):
        _apply_to_batch(data, batch_ids, func=lambda _: 1 / 0)


def test_apply_to_batch_func_with_non_scalar_return():
    """A reducer that returns its input unchanged must be rejected."""
    values = np.arange(5)
    ids = [0, 0, 1, 1, 2]
    expected_message = "Function <lambda> did not return a scalar"
    with pytest.raises(ValueError, match=expected_message):
        # The reducer must stay a lambda: its __name__ appears in the message.
        _apply_to_batch(values, ids, func=lambda _list: _list)
14 changes: 14 additions & 0 deletions tests/optimagic/test_timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest

from optimagic import timing


def test_invalid_aggregate_batch_time():
    """CostModel must refuse a non-callable ``aggregate_batch_time``."""
    invalid_kwargs = {
        "fun": None,
        "jac": None,
        "fun_and_jac": None,
        "label": "label",
        "aggregate_batch_time": "Not callable",
    }
    expected_message = "aggregate_batch_time must be a callable"
    with pytest.raises(ValueError, match=expected_message):
        timing.CostModel(**invalid_kwargs)

0 comments on commit bd50a5d

Please sign in to comment.