Skip to content

Commit

Permalink
SacessOptimizer fixes (#1476)
Browse files Browse the repository at this point in the history
* Make SacessOptimizer conform more to the original saCeSS - default settings were shifted by one worker.

* More informative debugging output 

* Add `RefSet.__repr__`

* Test with `SacessOptions`

* Fix x trace in history (`history.update` does not copy x by itself); fval trace was correct
  • Loading branch information
dweindl authored Oct 17, 2024
1 parent 2eee081 commit dba0e59
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def _maybe_update_global_best(self, x, fx):
self.fx_best = fx
self.x_best_has_changed = True
self.history.update(
self.x_best,
self.x_best.copy(),
(0,),
pypesto.C.MODE_FUN,
{pypesto.C.FVAL: self.fx_best},
Expand Down
8 changes: 8 additions & 0 deletions pypesto/optimize/ess/refset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def __init__(
self.n_stuck = np.zeros(shape=[dim])
self.attributes: dict[Any, np.array] = {}

def __repr__(self):
fx = (
f", fx=[{np.min(self.fx)} ... {np.max(self.fx)}]"
if self.fx is not None and len(self.fx) >= 2
else ""
)
return f"RefSet(dim={self.dim}{fx})"

def sort(self):
"""Sort RefSet by quality."""
order = np.argsort(self.fx)
Expand Down
24 changes: 11 additions & 13 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def minimize(
start_time = time.time()
logger.debug(
f"Running {self.__class__.__name__} with {self.num_workers} "
f"workers: {self.ess_init_args}"
f"workers: {self.ess_init_args} and {self.options}."
)
ess_init_args = self.ess_init_args or get_default_ess_options(
num_workers=self.num_workers, dim=problem.dim
Expand Down Expand Up @@ -563,7 +563,8 @@ def run(
self._manager._logger = self._logger

self._logger.debug(
f"#{self._worker_idx} starting " f"({self._ess_kwargs})."
f"#{self._worker_idx} starting "
f"({self._ess_kwargs}, {self._options})."
)

evaluator = create_function_evaluator(
Expand Down Expand Up @@ -694,6 +695,13 @@ def _maybe_adapt(self, problem: Problem):
f"Updated settings on worker {self._worker_idx} to "
f"{self._ess_kwargs}"
)
else:
self._logger.debug(
f"Worker {self._worker_idx} not adapting. "
f"Received: {self._n_received_solutions} <= {self._options.adaptation_sent_coeff * self._n_sent_solutions + self._options.adaptation_sent_offset}, "
f"Sent: {self._n_sent_solutions}, "
f"neval: {self._neval} <= {problem.dim * self._options.adaptation_min_evals}."
)

def maybe_update_best(self, x: np.array, fx: float):
"""Maybe update the best known solution and send it to the manager."""
Expand Down Expand Up @@ -840,13 +848,6 @@ def dim_refset(x):
return max(min_dimrefset, ceil((1 + sqrt(4 * dim * x)) / 2))

settings = [
# settings for first worker
{
"dim_refset": dim_refset(10),
"balance": 0.5,
"local_n2": 10,
},
# for the remaining workers, cycle through these settings
# 1
{
"dim_refset": dim_refset(1),
Expand Down Expand Up @@ -998,10 +999,7 @@ def dim_refset(x):
elif local_optimizer is not False:
cur_settings["local_optimizer"] = local_optimizer

return [
settings[0],
*(itertools.islice(itertools.cycle(settings[1:]), num_workers - 1)),
]
return list(itertools.islice(itertools.cycle(settings), num_workers))


class SacessFidesFactory:
Expand Down
18 changes: 16 additions & 2 deletions test/optimize/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import pypesto.optimize as optimize
from pypesto.optimize.ess import (
ESSOptimizer,
FunctionEvaluatorMP,
RefSet,
SacessFidesFactory,
SacessOptimizer,
SacessOptions,
get_default_ess_options,
)
from pypesto.optimize.util import (
Expand Down Expand Up @@ -490,6 +493,11 @@ def test_ess(problem, local_optimizer, ess_type, request):
sacess_loglevel=logging.DEBUG,
ess_loglevel=logging.WARNING,
ess_init_args=ess_init_args,
options=SacessOptions(
adaptation_min_evals=500,
adaptation_sent_offset=10,
adaptation_sent_coeff=5,
),
)
else:
raise ValueError(f"Unsupported ESS type {ess_type}.")
Expand Down Expand Up @@ -522,8 +530,6 @@ def test_ess_multiprocess(problem, request):

from fides.constants import Options as FidesOptions

from pypesto.optimize.ess import ESSOptimizer, FunctionEvaluatorMP, RefSet

# augment objective with parameter prior to check it's copyable
# https://github.com/ICB-DCM/pyPESTO/issues/1465
# https://github.com/ICB-DCM/pyPESTO/pull/1467
Expand Down Expand Up @@ -563,6 +569,14 @@ def test_ess_multiprocess(problem, request):
print("ESS result: ", res.summary())


def test_ess_refset_repr():
assert RefSet(10, None).__repr__() == "RefSet(dim=10)"
assert (
RefSet(10, None, x=np.zeros(10), fx=np.arange(10)).__repr__()
== "RefSet(dim=10, fx=[0 ... 9])"
)


def test_scipy_integrated_grad():
integrated = True
obj = rosen_for_sensi(max_sensi_order=2, integrated=integrated)["obj"]
Expand Down

0 comments on commit dba0e59

Please sign in to comment.