Skip to content

Commit

Permalink
Fixing unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Hollyqui committed Dec 19, 2024
1 parent c94a18b commit 48c73f4
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/prompting/test_weight_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_run_step_with_reward_events():
with (
patch("shared.uids.get_uids") as mock_get_uids,
patch("prompting.weight_setting.weight_setter.TaskRegistry") as MockTaskRegistry,
patch("prompting.weight_setting.weight_setter.mutable_globals") as mock_mutable_globals,
# patch("prompting.weight_setting.weight_setter.mutable_globals") as mock_mutable_globals,
patch("prompting.weight_setting.weight_setter.set_weights") as mock_set_weights,
patch("prompting.weight_setting.weight_setter.logger") as mock_logger,
):
Expand Down Expand Up @@ -75,7 +75,9 @@ def __init__(self, task, uids, rewards, weight):
mock_task_registry.get_task_config = MagicMock(return_value=mock_task_registry.task_configs[0])

# Set up the mock mutable_globals
mock_mutable_globals.reward_events = [

weight_setter = WeightSetter()
reward_events = [
[
WeightedRewardEvent(
task=mock_task_registry.task_configs[0], uids=mock_uids, rewards=[1.0, 2.0, 3.0, 4.0, 5.0], weight=1
Expand All @@ -87,8 +89,7 @@ def __init__(self, task, uids, rewards, weight):
),
],
]

weight_setter = WeightSetter()
weight_setter.reward_events = reward_events
output = asyncio.run(weight_setter.run_step())

print(output)
Expand Down

0 comments on commit 48c73f4

Please sign in to comment.