-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adaptive threshold acceptance (#156)
* Add AdaptiveThreshold acceptance criterion * Add tests
- Loading branch information
1 parent
bdcd64b
commit d2b3871
Showing
2 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import collections | ||
from statistics import mean | ||
from typing import Deque, List | ||
|
||
|
||
class AdaptiveThreshold: | ||
""" | ||
The Adaptive Threshold (AT) criterion accepts solutions | ||
if the candidate solution has a value lower than an | ||
adaptive threshold. The adaptive threshold is computed as: | ||
''adaptive_threshold = best_solution + | ||
eta_parameter * (average_solution - best_solution)'' | ||
where | ||
``best_solution`` is the best solution received so far, | ||
``average_solution`` is the average of the last | ||
``gamma_parameter`` solutions received, and | ||
``eta_parameter`` is a parameter between 0 and 1, | ||
the greater the value of | ||
``eta_parameter``, the more likely it is that a solution | ||
will be accepted. | ||
Each time a new solution is received, | ||
the threshold is updated. The average solution | ||
and best solution are taken by the last "gamma_parameter" | ||
solutions received. If the number of solutions received | ||
is less than"gamma_parameter" then the threshold | ||
is updated with the average of all the solutions | ||
received so far. | ||
The implementation is based on the description of AT in [1]. | ||
Parameters | ||
---------- | ||
eta: float | ||
Used to update/tune the threshold, | ||
the greater the value of ``eta_parameter``, | ||
the more likely it is that a solution will be accepted. | ||
gamma: int | ||
Used to update the threshold, the number of solutions | ||
received to compute the average & best solution. | ||
References | ||
---------- | ||
.. [1] Vinícius R. Máximo, Mariá C.V. Nascimento 2021. | ||
"A hybrid adaptive iterated local search with | ||
diversification control to the capacitated | ||
vehicle routing problem." | ||
*European Journal of Operational Research* | ||
294 (3): 1108 - 1119. | ||
""" | ||
|
||
def __init__(self, eta: float, gamma: int): | ||
if not (0 <= eta <= 1): | ||
raise ValueError("eta must be in [0, 1].") | ||
|
||
if gamma <= 0: | ||
raise ValueError("gamma must be positive.") | ||
|
||
self._eta = eta | ||
self._gamma = gamma | ||
self._history: Deque[float] = collections.deque(maxlen=gamma) | ||
|
||
@property | ||
def eta(self) -> float: | ||
return self._eta | ||
|
||
@property | ||
def gamma(self) -> int: | ||
return self._gamma | ||
|
||
@property | ||
def history(self) -> List[float]: | ||
return list(self._history) | ||
|
||
def __call__(self, rnd, best, current, candidate) -> bool: | ||
self._history.append(candidate.objective()) | ||
best_solution = min(self._history) | ||
avg_solution = mean(self._history) | ||
threshold = best_solution + self._eta * (avg_solution - best_solution) | ||
return candidate.objective() <= threshold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import numpy.random as rnd | ||
from numpy.testing import assert_, assert_equal, assert_raises | ||
from pytest import mark | ||
|
||
from alns.accept.AdaptiveThreshold import AdaptiveThreshold | ||
from alns.tests.states import One, Two, VarObj, Zero | ||
|
||
|
||
@mark.parametrize( | ||
"eta, gamma", | ||
[ | ||
(-1, 3), # eta cannot be < 0 | ||
(2, 3), # eta cannot be > 1 | ||
(0.5, -2), # gamma cannot be < 0 | ||
], | ||
) | ||
def test_raise_invalid_parameters(eta, gamma): | ||
with assert_raises(ValueError): | ||
AdaptiveThreshold(eta=eta, gamma=gamma) | ||
|
||
|
||
@mark.parametrize("eta, gamma", [(1, 3), (0.4, 4)]) | ||
def test_no_raise_valid_parameters(eta, gamma): | ||
AdaptiveThreshold(eta=eta, gamma=gamma) | ||
|
||
|
||
@mark.parametrize("eta", [0, 0.01, 0.5, 0.99, 1]) | ||
def test_eta(eta): | ||
adaptive_threshold = AdaptiveThreshold(eta, 3) | ||
assert_equal(adaptive_threshold.eta, eta) | ||
|
||
|
||
@mark.parametrize("gamma", range(1, 10)) | ||
def test_gamma(gamma): | ||
adaptive_threshold = AdaptiveThreshold(0.5, gamma) | ||
assert_equal(adaptive_threshold.gamma, gamma) | ||
|
||
|
||
def test_accepts_below_threshold(): | ||
adaptive_threshold = AdaptiveThreshold(eta=0.5, gamma=4) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), One()) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), Zero()) | ||
result = adaptive_threshold(rnd.RandomState(), One(), One(), Zero()) | ||
|
||
# The threshold is set at 0 + 0.5 * (0.5 - 0) = 0.25 | ||
assert_(result) | ||
|
||
|
||
def test_rejects_above_threshold(): | ||
adaptive_threshold = AdaptiveThreshold(eta=0.5, gamma=4) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), Two()) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), Zero()) | ||
result = adaptive_threshold(rnd.RandomState(), One(), One(), One()) | ||
|
||
# The threshold is set at 0 + 0.5 * (1 - 0) = 0.5 | ||
assert_(not result) | ||
|
||
|
||
def test_accepts_equal_threshold(): | ||
adaptive_threshold = AdaptiveThreshold(eta=0.5, gamma=4) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7100)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
result = adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7120)) | ||
|
||
# The threshold is set at 7100 + 0.5 * (7140 - 7100) = 7120 | ||
assert_(result) | ||
|
||
|
||
def test_accepts_over_gamma_candidates(): | ||
adaptive_threshold = AdaptiveThreshold(eta=0.2, gamma=3) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7100)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
result = adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7000)) | ||
|
||
# The threshold is set at 7000 + 0.2 * (7133.33 - 7000) = 7013.33 | ||
assert_(result) | ||
|
||
|
||
def test_rejects_over_gamma_candidates(): | ||
adaptive_threshold = AdaptiveThreshold(eta=0.2, gamma=3) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7100)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7000)) | ||
result = adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7100)) | ||
|
||
# The threshold is set at 7000 + 0.2 * (7100 - 7000) = 7020 | ||
assert_(not result) | ||
|
||
|
||
def test_evaluate_consecutive_solutions(): | ||
""" | ||
Test if AT correctly accepts and rejects consecutive solutions. | ||
""" | ||
adaptive_threshold = AdaptiveThreshold(eta=0.5, gamma=4) | ||
|
||
result = adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7100)) | ||
# The threshold is set at 7100, hence the solution is accepted | ||
assert_(result) | ||
|
||
result = adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
# The threshold is set at 7125, hence the solution is accepted | ||
assert_(not result) | ||
|
||
result = adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7120)) | ||
# The threshold is set at 7120, hence the solution is accepted | ||
assert_(result) | ||
|
||
|
||
def test_history(): | ||
""" | ||
Test if AT correctly stores the history of the thresholds correctly. | ||
""" | ||
adaptive_threshold = AdaptiveThreshold(eta=0.5, gamma=4) | ||
|
||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7100)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7120)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7100)) | ||
adaptive_threshold(rnd.RandomState(), One(), One(), VarObj(7200)) | ||
assert_equal(adaptive_threshold.history, [7200, 7120, 7100, 7200]) |