From 8873740c5a1d876b5433dc40f6430ca851a84e35 Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 22 Aug 2024 14:20:14 +0900 Subject: [PATCH 1/6] Introduce SimpleBaseSampler --- optunahub/__init__.py | 3 +- optunahub/hub.py | 53 +---------- optunahub/samplers/__init__.py | 4 + optunahub/samplers/_simple_base.py | 92 ++++++++++++++++++++ tests/package_for_test_hub/__init__.py | 12 +-- tests/package_for_test_hub/implementation.py | 7 -- tests/test_hub.py | 31 ------- 7 files changed, 101 insertions(+), 101 deletions(-) create mode 100644 optunahub/samplers/__init__.py create mode 100644 optunahub/samplers/_simple_base.py delete mode 100644 tests/package_for_test_hub/implementation.py diff --git a/optunahub/__init__.py b/optunahub/__init__.py index 6b2c34a..07bd2ef 100644 --- a/optunahub/__init__.py +++ b/optunahub/__init__.py @@ -1,6 +1,7 @@ +from optunahub import samplers from optunahub.hub import load_local_module from optunahub.hub import load_module from optunahub.version import __version__ -__all__ = ["load_module", "load_local_module", "__version__"] +__all__ = ["load_module", "load_local_module", "__version__", "samplers"] diff --git a/optunahub/hub.py b/optunahub/hub.py index b3f475c..df58762 100644 --- a/optunahub/hub.py +++ b/optunahub/hub.py @@ -1,13 +1,11 @@ from __future__ import annotations import importlib.util -import inspect import logging import os import shutil import sys import types -from typing import Any from urllib.parse import urlparse from ga4mp import GtagMP # type: ignore @@ -28,24 +26,6 @@ logging.getLogger("ga4mp.ga4mp").setLevel(logging.WARNING) -def _get_global_variable_from_outer_scopes(key: str, default: Any) -> Any: - """Returns the value of the variable specified by the key defined on the stacks from the innermost caller to the outermost one. - If the value with the key is not found in the stacks, return the default value. - - Args: - key: - The key to get. - default: - The default value. - """ - - for s in inspect.stack(): - outer_globals = s.frame.f_globals - if key in outer_globals: - return outer_globals[key] - return default - - def _report_stats( package: str, ref: str | None, @@ -89,9 +69,9 @@ def load_module( repo_owner: str = "optuna", repo_name: str = "optunahub-registry", registry_root: str = "package", - ref: str | None = None, + ref: str = "main", base_url: str = "https://api.github.com", - force_reload: bool | None = None, + force_reload: bool = False, auth: Auth.Auth | None = None, ) -> types.ModuleType: """Import a package from the OptunaHub registry. @@ -111,17 +91,12 @@ def load_module( The default is "package". ref: The Git reference (branch, tag, or commit SHA) for the package. - This setting will be inherited to the inner `load`-like function. - If :obj:`None`, the setting is inherited from the outer `load`-like function. - For the outermost call, the default is "main". base_url: The base URL for the GitHub API. force_reload: If :obj:`True`, the package will be downloaded from the repository. If :obj:`False`, the package cached in the local directory will be loaded if available. - If :obj:`None`, the setting is inherited from the outer `load`-like function. - For the outermost call, the default is `False`. auth: `The authentication object `__ for the GitHub API. It is required to access private/internal repositories. @@ -129,11 +104,6 @@ def load_module( Returns: The module object of the package. """ - ref = ref or _get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main") - force_reload = force_reload or _get_global_variable_from_outer_scopes( - "OPTUNAHUB_FORCE_RELOAD", False - ) - dir_path = f"{registry_root}/{package}" if registry_root else package hostname = urlparse(base_url).hostname if hostname is None: @@ -174,8 +144,6 @@ def load_module( module = load_local_module( package=package, registry_root=local_registry_root, - ref=ref, - force_reload=force_reload, ) # Statistics are collected only for the official registry. @@ -194,8 +162,6 @@ def load_local_module( package: str, *, registry_root: str = os.sep, - ref: str | None = None, - force_reload: bool | None = None, ) -> types.ModuleType: """Import a package from the local registry. The imported package name is set to ``optunahub_registry.package.``. @@ -207,24 +173,11 @@ def load_local_module( The root directory of the registry. The default is the root directory of the file system, e.g., "/" for UNIX-like systems. - ref: - This setting will be inherited to the inner `load`-like function. - If :obj:`None`, the setting is inherited from the outer `load`-like function. - For the outermost call, the default is "main". - force_reload: - This setting will be inherited to the inner `load`-like function. - If :obj:`None`, the setting is inherited from the outer `load`-like function. - For the outermost call, the default is :obj:`False`. Returns: The module object of the package. """ - ref = ref or _get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main") - force_reload = force_reload or _get_global_variable_from_outer_scopes( - "OPTUNAHUB_FORCE_RELOAD", False - ) - module_path = os.path.join(registry_root, package) module_name = f"optunahub_registry.package.{package.replace('/', '.')}" spec = importlib.util.spec_from_file_location( @@ -235,8 +188,6 @@ def load_local_module( module = importlib.util.module_from_spec(spec) if module is None: raise ImportError(f"Module {module_name} not found in {module_path}") - setattr(module, "OPTUNAHUB_REF", ref) - setattr(module, "OPTUNAHUB_FORCE_RELOAD", force_reload) sys.modules[module_name] = module spec.loader.exec_module(module) diff --git a/optunahub/samplers/__init__.py b/optunahub/samplers/__init__.py new file mode 100644 index 0000000..b33082d --- /dev/null +++ b/optunahub/samplers/__init__.py @@ -0,0 +1,4 @@ +from optunahub.samplers._simple_base import SimpleBaseSampler + + +__all__ = ["SimpleBaseSampler"] diff --git a/optunahub/samplers/_simple_base.py b/optunahub/samplers/_simple_base.py new file mode 100644 index 0000000..0e4f9cf --- /dev/null +++ b/optunahub/samplers/_simple_base.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import abc +from typing import Any + +from optuna import Study +from optuna.distributions import BaseDistribution +from optuna.samplers import BaseSampler +from optuna.samplers import RandomSampler +from optuna.search_space import IntersectionSearchSpace +from optuna.trial import FrozenTrial + + +class SimpleBaseSampler(BaseSampler, abc.ABC): + def __init__( + self, search_space: dict[str, BaseDistribution] | None = None, seed: int | None = None + ) -> None: + self.search_space = search_space + self._seed = seed + self._init_defaults() + + def infer_relative_search_space( + self, + study: Study, + trial: FrozenTrial, + ) -> dict[str, BaseDistribution]: + # This method is optional. + # If you want to optimize the function with the eager search space, + # please implement this method. + if self.search_space is not None: + return self.search_space + return self._default_infer_relative_search_space(study, trial) + + @abc.abstractmethod + def sample_relative( + self, + study: Study, + trial: FrozenTrial, + search_space: dict[str, BaseDistribution], + ) -> dict[str, Any]: + # This method is required. + # This method is called at the beginning of each trial in Optuna to sample parameters. + raise NotImplementedError + + def sample_independent( + self, + study: Study, + trial: FrozenTrial, + param_name: str, + param_distribution: BaseDistribution, + ) -> Any: + # This method is optional. + # By default, parameter values are sampled by ``optuna.samplers.RandomSampler``. + return self._default_sample_independent(study, trial, param_name, param_distribution) + + def reseed_rng(self) -> None: + self._default_reseed_rng() + + def _init_defaults(self) -> None: + self._intersection_search_space = IntersectionSearchSpace() + self._random_sampler = RandomSampler(seed=self._seed) + + def _default_infer_relative_search_space( + self, study: Study, trial: FrozenTrial + ) -> dict[str, BaseDistribution]: + search_space: dict[str, BaseDistribution] = {} + for name, distribution in self._intersection_search_space.calculate(study).items(): + if distribution.single(): + # Single value objects are not sampled with the `sample_relative` method, + # but with the `sample_independent` method. + continue + search_space[name] = distribution + return search_space + + def _default_sample_independent( + self, + study: Study, + trial: FrozenTrial, + param_name: str, + param_distribution: BaseDistribution, + ) -> Any: + # Following parameters are randomly sampled here. + # 1. A parameter in the initial population/first generation. + # 2. A parameter to mutate. + # 3. A parameter excluded from the intersection search space. + + return self._random_sampler.sample_independent( + study, trial, param_name, param_distribution + ) + + def _default_reseed_rng(self) -> None: + self._random_sampler.reseed_rng() diff --git a/tests/package_for_test_hub/__init__.py b/tests/package_for_test_hub/__init__.py index a9cd5f8..e6aa2bf 100644 --- a/tests/package_for_test_hub/__init__.py +++ b/tests/package_for_test_hub/__init__.py @@ -1,14 +1,4 @@ from optuna.samplers import RandomSampler -import optunahub -from . import implementation - - -ref = optunahub.hub._get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main") -force_reload = optunahub.hub._get_global_variable_from_outer_scopes( - "OPTUNAHUB_FORCE_RELOAD", False -) - - -__all__ = ["RandomSampler", "implementation", "ref", "force_reload"] +__all__ = ["RandomSampler"] diff --git a/tests/package_for_test_hub/implementation.py b/tests/package_for_test_hub/implementation.py deleted file mode 100644 index 1add580..0000000 --- a/tests/package_for_test_hub/implementation.py +++ /dev/null @@ -1,7 +0,0 @@ -import optunahub - - -ref = optunahub.hub._get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main") -force_reload = optunahub.hub._get_global_variable_from_outer_scopes( - "OPTUNAHUB_FORCE_RELOAD", False -) diff --git a/tests/test_hub.py b/tests/test_hub.py index 63c066e..7829bf3 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -1,8 +1,6 @@ import os -import sys import optuna -import pytest import optunahub @@ -38,32 +36,3 @@ def objective(trial: optuna.Trial) -> float: sampler = m.RandomSampler() study = optuna.create_study(sampler=sampler) study.optimize(objective, n_trials=10) - - -@pytest.mark.parametrize( - ("ref", "force_reload", "expected_ref", "expected_force_reload"), - [ - (None, None, "main", False), - ("main", False, "main", False), - ("test", True, "test", True), - ], -) -def test_load_settings_propagation( - ref: str, - force_reload: bool, - expected_ref: str, - expected_force_reload: bool, -) -> None: - m = optunahub.load_local_module( - "package_for_test_hub", - registry_root=os.path.dirname(__file__), - ref=ref, - force_reload=force_reload, - ) - assert m.ref == expected_ref - assert m.force_reload == expected_force_reload - assert m.implementation.ref == expected_ref - assert m.implementation.force_reload == expected_force_reload - - del sys.modules["optunahub_registry.package.package_for_test_hub"] - del sys.modules["optunahub_registry.package.package_for_test_hub.implementation"] From c1b168b054542ada540f078ddba409c5f35007d9 Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 22 Aug 2024 15:17:00 +0900 Subject: [PATCH 2/6] Add test for SimpleBaseSampler --- tests/package_for_test_hub/__init__.py | 4 ++-- tests/package_for_test_hub/sampler.py | 25 +++++++++++++++++++++++++ tests/test_hub.py | 2 +- 3 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 tests/package_for_test_hub/sampler.py diff --git a/tests/package_for_test_hub/__init__.py b/tests/package_for_test_hub/__init__.py index e6aa2bf..a189448 100644 --- a/tests/package_for_test_hub/__init__.py +++ b/tests/package_for_test_hub/__init__.py @@ -1,4 +1,4 @@ -from optuna.samplers import RandomSampler +from .sampler import TestSampler -__all__ = ["RandomSampler"] +__all__ = ["TestSampler"] diff --git a/tests/package_for_test_hub/sampler.py b/tests/package_for_test_hub/sampler.py new file mode 100644 index 0000000..ab40c22 --- /dev/null +++ b/tests/package_for_test_hub/sampler.py @@ -0,0 +1,25 @@ +from typing import Any + +import numpy as np +from optuna import Study +from optuna.distributions import BaseDistribution +from optuna.trial import FrozenTrial + +import optunahub + + +class TestSampler(optunahub.samplers.SimpleBaseSampler): + def __init__(self, search_space: dict[str, BaseDistribution] | None = None) -> None: + super().__init__(search_space) + self._rng = np.random.RandomState() + + def sample_relative( + self, + study: Study, + trial: FrozenTrial, + search_space: dict[str, BaseDistribution], + ) -> dict[str, Any]: + params = {} + for n, d in search_space.items(): + params[n] = self._rng.uniform(d.low, d.high) + return params diff --git a/tests/test_hub.py b/tests/test_hub.py index 7829bf3..ab1941f 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -33,6 +33,6 @@ def objective(trial: optuna.Trial) -> float: assert m.__name__ == "optunahub_registry.package.package_for_test_hub" # Confirm no error occurs by running optimization - sampler = m.RandomSampler() + sampler = m.TestSampler() study = optuna.create_study(sampler=sampler) study.optimize(objective, n_trials=10) From ee8deb4f85501ccb6f709119da22f66a9320b143 Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 22 Aug 2024 15:20:12 +0900 Subject: [PATCH 3/6] Support Python 3.9 --- tests/package_for_test_hub/sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/package_for_test_hub/sampler.py b/tests/package_for_test_hub/sampler.py index ab40c22..b071753 100644 --- a/tests/package_for_test_hub/sampler.py +++ b/tests/package_for_test_hub/sampler.py @@ -1,4 +1,5 @@ from typing import Any +from typing import Optional import numpy as np from optuna import Study @@ -9,7 +10,7 @@ class TestSampler(optunahub.samplers.SimpleBaseSampler): - def __init__(self, search_space: dict[str, BaseDistribution] | None = None) -> None: + def __init__(self, search_space: Optional[dict[str, BaseDistribution]] = None) -> None: super().__init__(search_space) self._rng = np.random.RandomState() From 98f57b06479b0f190c8f6e2d8b2035ccb97a74a1 Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 22 Aug 2024 15:22:45 +0900 Subject: [PATCH 4/6] Support Python 3.8 --- tests/package_for_test_hub/sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/package_for_test_hub/sampler.py b/tests/package_for_test_hub/sampler.py index b071753..1f6474c 100644 --- a/tests/package_for_test_hub/sampler.py +++ b/tests/package_for_test_hub/sampler.py @@ -1,4 +1,5 @@ from typing import Any +from typing import Dict from typing import Optional import numpy as np @@ -10,7 +11,7 @@ class TestSampler(optunahub.samplers.SimpleBaseSampler): - def __init__(self, search_space: Optional[dict[str, BaseDistribution]] = None) -> None: + def __init__(self, search_space: Optional[Dict[str, BaseDistribution]] = None) -> None: super().__init__(search_space) self._rng = np.random.RandomState() From 6b180664558382389e561816cb921a7c0fffadc8 Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 22 Aug 2024 15:24:06 +0900 Subject: [PATCH 5/6] Support Python 3.8 --- tests/package_for_test_hub/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/package_for_test_hub/sampler.py b/tests/package_for_test_hub/sampler.py index 1f6474c..fe2ba2d 100644 --- a/tests/package_for_test_hub/sampler.py +++ b/tests/package_for_test_hub/sampler.py @@ -19,8 +19,8 @@ def sample_relative( self, study: Study, trial: FrozenTrial, - search_space: dict[str, BaseDistribution], - ) -> dict[str, Any]: + search_space: Dict[str, BaseDistribution], + ) -> Dict[str, Any]: params = {} for n, d in search_space.items(): params[n] = self._rng.uniform(d.low, d.high) From 59fedf3f67278ce1db1da82ecd45ec6beadf24b7 Mon Sep 17 00:00:00 2001 From: y0z Date: Thu, 22 Aug 2024 18:17:45 +0900 Subject: [PATCH 6/6] Add doc for SimpleBaseSampler. --- docs/source/optunahub.rst | 11 +++++++++++ docs/source/reference.rst | 9 +++++---- docs/source/samplers.rst | 10 ++++++++++ optunahub/samplers/_simple_base.py | 2 ++ 4 files changed, 28 insertions(+), 4 deletions(-) create mode 100644 docs/source/optunahub.rst create mode 100644 docs/source/samplers.rst diff --git a/docs/source/optunahub.rst b/docs/source/optunahub.rst new file mode 100644 index 0000000..2a58920 --- /dev/null +++ b/docs/source/optunahub.rst @@ -0,0 +1,11 @@ +.. module:: optunahub + +optunahub +========= + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + optunahub.load_module + optunahub.load_local_module \ No newline at end of file diff --git a/docs/source/reference.rst b/docs/source/reference.rst index ec8e82e..08a6ff9 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -1,7 +1,8 @@ Reference ========= -.. automodule:: optunahub - :members: - :undoc-members: - :show-inheritance: +.. toctree:: + :maxdepth: 1 + + optunahub + samplers \ No newline at end of file diff --git a/docs/source/samplers.rst b/docs/source/samplers.rst new file mode 100644 index 0000000..44ab4c4 --- /dev/null +++ b/docs/source/samplers.rst @@ -0,0 +1,10 @@ +.. module:: optunahub.samplers + +optunahub.samplers +================== + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + optunahub.samplers.SimpleBaseSampler \ No newline at end of file diff --git a/optunahub/samplers/_simple_base.py b/optunahub/samplers/_simple_base.py index 0e4f9cf..480d15f 100644 --- a/optunahub/samplers/_simple_base.py +++ b/optunahub/samplers/_simple_base.py @@ -12,6 +12,8 @@ class SimpleBaseSampler(BaseSampler, abc.ABC): + """A simple base class to implement user-defined samplers.""" + def __init__( self, search_space: dict[str, BaseDistribution] | None = None, seed: int | None = None ) -> None: