-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #55 from y0z/feature/simple-base-sampler
Introduce SimpleBaseSampler
- Loading branch information
Showing
11 changed files
with
158 additions
and
107 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
.. module:: optunahub | ||
|
||
optunahub | ||
========= | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
:nosignatures: | ||
|
||
optunahub.load_module | ||
optunahub.load_local_module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
Reference | ||
========= | ||
|
||
.. automodule:: optunahub | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
.. toctree:: | ||
:maxdepth: 1 | ||
|
||
optunahub | ||
samplers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
.. module:: optunahub.samplers | ||
|
||
optunahub.samplers | ||
================== | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
:nosignatures: | ||
|
||
optunahub.samplers.SimpleBaseSampler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from optunahub import samplers | ||
from optunahub.hub import load_local_module | ||
from optunahub.hub import load_module | ||
from optunahub.version import __version__ | ||
|
||
|
||
__all__ = ["load_module", "load_local_module", "__version__"] | ||
__all__ = ["load_module", "load_local_module", "__version__", "samplers"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from optunahub.samplers._simple_base import SimpleBaseSampler | ||
|
||
|
||
__all__ = ["SimpleBaseSampler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from __future__ import annotations | ||
|
||
import abc | ||
from typing import Any | ||
|
||
from optuna import Study | ||
from optuna.distributions import BaseDistribution | ||
from optuna.samplers import BaseSampler | ||
from optuna.samplers import RandomSampler | ||
from optuna.search_space import IntersectionSearchSpace | ||
from optuna.trial import FrozenTrial | ||
|
||
|
||
class SimpleBaseSampler(BaseSampler, abc.ABC): | ||
"""A simple base class to implement user-defined samplers.""" | ||
|
||
def __init__( | ||
self, search_space: dict[str, BaseDistribution] | None = None, seed: int | None = None | ||
) -> None: | ||
self.search_space = search_space | ||
self._seed = seed | ||
self._init_defaults() | ||
|
||
def infer_relative_search_space( | ||
self, | ||
study: Study, | ||
trial: FrozenTrial, | ||
) -> dict[str, BaseDistribution]: | ||
# This method is optional. | ||
# If you want to optimize the function with the eager search space, | ||
# please implement this method. | ||
if self.search_space is not None: | ||
return self.search_space | ||
return self._default_infer_relative_search_space(study, trial) | ||
|
||
@abc.abstractmethod | ||
def sample_relative( | ||
self, | ||
study: Study, | ||
trial: FrozenTrial, | ||
search_space: dict[str, BaseDistribution], | ||
) -> dict[str, Any]: | ||
# This method is required. | ||
# This method is called at the beginning of each trial in Optuna to sample parameters. | ||
raise NotImplementedError | ||
|
||
def sample_independent( | ||
self, | ||
study: Study, | ||
trial: FrozenTrial, | ||
param_name: str, | ||
param_distribution: BaseDistribution, | ||
) -> Any: | ||
# This method is optional. | ||
# By default, parameter values are sampled by ``optuna.samplers.RandomSampler``. | ||
return self._default_sample_independent(study, trial, param_name, param_distribution) | ||
|
||
def reseed_rng(self) -> None: | ||
self._default_reseed_rng() | ||
|
||
def _init_defaults(self) -> None: | ||
self._intersection_search_space = IntersectionSearchSpace() | ||
self._random_sampler = RandomSampler(seed=self._seed) | ||
|
||
def _default_infer_relative_search_space( | ||
self, study: Study, trial: FrozenTrial | ||
) -> dict[str, BaseDistribution]: | ||
search_space: dict[str, BaseDistribution] = {} | ||
for name, distribution in self._intersection_search_space.calculate(study).items(): | ||
if distribution.single(): | ||
# Single value objects are not sampled with the `sample_relative` method, | ||
# but with the `sample_independent` method. | ||
continue | ||
search_space[name] = distribution | ||
return search_space | ||
|
||
def _default_sample_independent( | ||
self, | ||
study: Study, | ||
trial: FrozenTrial, | ||
param_name: str, | ||
param_distribution: BaseDistribution, | ||
) -> Any: | ||
# Following parameters are randomly sampled here. | ||
# 1. A parameter in the initial population/first generation. | ||
# 2. A parameter to mutate. | ||
# 3. A parameter excluded from the intersection search space. | ||
|
||
return self._random_sampler.sample_independent( | ||
study, trial, param_name, param_distribution | ||
) | ||
|
||
def _default_reseed_rng(self) -> None: | ||
self._random_sampler.reseed_rng() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,4 @@ | ||
from optuna.samplers import RandomSampler | ||
from .sampler import TestSampler | ||
|
||
import optunahub | ||
|
||
from . import implementation | ||
|
||
|
||
ref = optunahub.hub._get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main") | ||
force_reload = optunahub.hub._get_global_variable_from_outer_scopes( | ||
"OPTUNAHUB_FORCE_RELOAD", False | ||
) | ||
|
||
|
||
__all__ = ["RandomSampler", "implementation", "ref", "force_reload"] | ||
__all__ = ["TestSampler"] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import Any | ||
from typing import Dict | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from optuna import Study | ||
from optuna.distributions import BaseDistribution | ||
from optuna.trial import FrozenTrial | ||
|
||
import optunahub | ||
|
||
|
||
class TestSampler(optunahub.samplers.SimpleBaseSampler): | ||
def __init__(self, search_space: Optional[Dict[str, BaseDistribution]] = None) -> None: | ||
super().__init__(search_space) | ||
self._rng = np.random.RandomState() | ||
|
||
def sample_relative( | ||
self, | ||
study: Study, | ||
trial: FrozenTrial, | ||
search_space: Dict[str, BaseDistribution], | ||
) -> Dict[str, Any]: | ||
params = {} | ||
for n, d in search_space.items(): | ||
params[n] = self._rng.uniform(d.low, d.high) | ||
return params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters