Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace RandomState with Generator #173

Merged
merged 16 commits into from
Oct 21, 2024
Prev Previous commit
Next Next commit
RNG instead of state
leonlan committed Oct 21, 2024
commit 540868fb89fa4eb9182cfca850b9ad11ae56309a
12 changes: 6 additions & 6 deletions alns/accept/tests/test_simulated_annealing.py
Original file line number Diff line number Diff line change
@@ -118,14 +118,14 @@ def test_linear_random_solutions():
"""
simulated_annealing = SimulatedAnnealing(2, 1, 1, "linear")

state = rnd.default_rng(1)
rng = rnd.default_rng(1)

# With the seed above, the first two random numbers drawn are 0.51 and
# 0.95, respectively. The acceptance probability starts at 0.61, so the
# first candidate should be accepted (0.61 > 0.51). It then drops to 0.37,
# so the second candidate should be rejected (0.37 < 0.95).
assert_(simulated_annealing(state, Zero(), Zero(), One()))
assert_(not simulated_annealing(state, Zero(), Zero(), One()))
assert_(simulated_annealing(rng, Zero(), Zero(), One()))
assert_(not simulated_annealing(rng, Zero(), Zero(), One()))


def test_exponential_random_solutions():
@@ -136,10 +136,10 @@ def test_exponential_random_solutions():
"""
simulated_annealing = SimulatedAnnealing(2, 1, 0.5, "exponential")

state = rnd.default_rng(1)
rng = rnd.default_rng(1)

assert_(simulated_annealing(state, Zero(), Zero(), One()))
assert_(not simulated_annealing(state, Zero(), Zero(), One()))
assert_(simulated_annealing(rng, Zero(), Zero(), One()))
assert_(not simulated_annealing(rng, Zero(), Zero(), One()))


@mark.parametrize(
10 changes: 5 additions & 5 deletions alns/select/tests/test_alpha_ucb.py
Original file line number Diff line number Diff line change
@@ -58,27 +58,27 @@ def test_raises_invalid_arguments(
def test_call_with_only_one_operator_pair():
# Only one operator pair, so the algorithm should select (0, 0).
select = AlphaUCB([2, 1, 1, 0], 0.5, 1, 1)
state = rnd.default_rng()
rng = rnd.default_rng()

selected = select(state, Zero(), Zero())
selected = select(rng, Zero(), Zero())
assert_equal(selected, (0, 0))


def test_update_with_two_operator_pairs():
select = AlphaUCB([2, 1, 1, 0], 0.5, 2, 1)
state = rnd.default_rng()
rng = rnd.default_rng()

# After this update the average reward for (0, 0) is 2, while that for
# (1, 0) is still 1 (the default).
select.update(Zero(), 0, 0, outcome=Outcome.BEST)

# So now (0, 0) is selected again.
selected = select(state, Zero(), Zero())
selected = select(rng, Zero(), Zero())
assert_equal(selected, (0, 0))

# One more update: the average reward of (0, 0) drops to 1, and the number
# of updates for (0, 0) rises to 2.
select.update(Zero(), 0, 0, outcome=Outcome.REJECT)

# The Q value of (0, 0) is now approx 1.432, and that of (1, 0) is now
# approx 1.74. So (1, 0) is selected.
selected = select(state, Zero(), Zero())
selected = select(rng, Zero(), Zero())
assert_equal(selected, (1, 0))
24 changes: 12 additions & 12 deletions alns/select/tests/test_mab_selector.py
Original file line number Diff line number Diff line change
@@ -73,42 +73,42 @@ def test_call_with_only_one_operator_pair():
select = MABSelector(
[2, 1, 1, 0], 1, 1, LearningPolicy.EpsilonGreedy(0.15)
)
state = rnd.default_rng()
rng = rnd.default_rng()

for _ in range(10):
selected = select(state, Zero(), Zero())
selected = select(rng, Zero(), Zero())
assert_equal(selected, (0, 0))


def test_mab_epsilon_greedy():
state = rnd.default_rng()
rng = rnd.default_rng()

# epsilon=0 is equivalent to greedy selection
select = MABSelector([2, 1, 1, 0], 2, 1, LearningPolicy.EpsilonGreedy(0.0))

select.update(Zero(), 0, 0, outcome=Outcome.BETTER)
selected = select(state, Zero(), Zero())
selected = select(rng, Zero(), Zero())
for _ in range(10):
selected = select(state, Zero(), Zero())
selected = select(rng, Zero(), Zero())
assert_equal(selected, (0, 0))

select.update(Zero(), 1, 0, outcome=Outcome.BEST)
for _ in range(10):
selected = select(state, Zero(), Zero())
selected = select(rng, Zero(), Zero())
assert_equal(selected, (1, 0))


@mark.parametrize("alpha", [0.25, 0.5])
def test_mab_ucb1(alpha):
state = rnd.default_rng()
rng = rnd.default_rng()
select = MABSelector([2, 1, 1, 0], 2, 1, LearningPolicy.UCB1(alpha))

select.update(Zero(), 0, 0, outcome=Outcome.BEST)
mab_select = select(state, Zero(), Zero())
mab_select = select(rng, Zero(), Zero())
assert_equal(mab_select, (0, 0))

select.update(Zero(), 0, 0, outcome=Outcome.REJECT)
mab_select = select(state, Zero(), Zero())
mab_select = select(rng, Zero(), Zero())
assert_equal(mab_select, (0, 0))


@@ -125,7 +125,7 @@ def test_contextual_mab_requires_context():


def text_contextual_mab_uses_context():
state = rnd.default_rng()
rng = rnd.default_rng()
select = MABSelector(
[2, 1, 1, 0],
2,
@@ -142,8 +142,8 @@ def text_contextual_mab_uses_context():
select.update(ZeroWithOneContext(), 1, 0, outcome=Outcome.REJECT)
select.update(ZeroWithOneContext(), 0, 0, outcome=Outcome.BEST)

mab_select = select(state, ZeroWithZeroContext(), ZeroWithZeroContext())
mab_select = select(rng, ZeroWithZeroContext(), ZeroWithZeroContext())
assert_equal(mab_select, (1, 0))

mab_select = select(state, ZeroWithZeroContext(), ZeroWithZeroContext())
mab_select = select(rng, ZeroWithZeroContext(), ZeroWithZeroContext())
assert_equal(mab_select, (0, 0))
8 changes: 4 additions & 4 deletions alns/tests/test_result.py
Original file line number Diff line number Diff line change
@@ -36,15 +36,15 @@ def get_statistics():
statistics.collect_objective(objective)

# We should make sure these results are reproducible.
state = rnd.default_rng(1)
rng = rnd.default_rng(1)

operators = ["test1", "test2", "test3"]

for _ in range(100):
operator = state.choice(operators)
operator = rng.choice(operators)

statistics.collect_destroy_operator("d_" + operator, state.integers(4))
statistics.collect_repair_operator("r_" + operator, state.integers(4))
statistics.collect_destroy_operator("d_" + operator, rng.integers(4))
statistics.collect_repair_operator("r_" + operator, rng.integers(4))

return statistics