From 54e5cd1f71cbf5702050252874f49b2aac764809 Mon Sep 17 00:00:00 2001 From: Lyubov Yamshchikova Date: Wed, 5 Jun 2024 13:17:36 +0300 Subject: [PATCH] Fix action name --- .../optimisers/adaptive/mab_agents/mab_agent.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/golem/core/optimisers/adaptive/mab_agents/mab_agent.py b/golem/core/optimisers/adaptive/mab_agents/mab_agent.py index 47382157..59e3279b 100644 --- a/golem/core/optimisers/adaptive/mab_agents/mab_agent.py +++ b/golem/core/optimisers/adaptive/mab_agents/mab_agent.py @@ -2,7 +2,8 @@ import _pickle as pickle import random import re -from typing import Union, Sequence, Optional +from functools import partial +from typing import Union, Sequence, Optional, Callable from mabwiser.mab import MAB, LearningPolicy from scipy.special import softmax @@ -26,7 +27,7 @@ def __init__(self, self.actions = list(actions) self._indices = list(range(len(actions))) # str because parent operator for mutation is stored as string for custom mutations serialisation - self._arm_by_action = dict(map(lambda x, y: (x.__name__, y), actions, self._indices)) + self._arm_by_action = dict(map(lambda x, y: (self._get_callable_name(x), y), actions, self._indices)) self._agent = MAB(arms=self._indices, learning_policy=LearningPolicy.EpsilonGreedy(epsilon=0.4), n_jobs=n_jobs) @@ -35,6 +36,16 @@ def __init__(self, self._initial_fit() self._path_to_save = path_to_save + @staticmethod + def _get_callable_name(action: Callable): + if isinstance(action, partial): + return action.func.__name__ + else: + try: + return action.__name__ + except AttributeError: + return str(action) + def _initial_fit(self): n = len(self.actions) uniform_rewards = [1. / n] * n