Skip to content

Commit

Permalink
Merge pull request #55 from y0z/feature/simple-base-sampler
Browse files Browse the repository at this point in the history
Introduce SimpleBaseSampler
  • Loading branch information
toshihikoyanase committed Aug 23, 2024
2 parents 5cc1871 + 59fedf3 commit e55a73d
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 107 deletions.
11 changes: 11 additions & 0 deletions docs/source/optunahub.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. module:: optunahub

optunahub
=========

.. autosummary::
:toctree: generated/
:nosignatures:

optunahub.load_module
optunahub.load_local_module
9 changes: 5 additions & 4 deletions docs/source/reference.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
Reference
=========

.. automodule:: optunahub
:members:
:undoc-members:
:show-inheritance:
.. toctree::
:maxdepth: 1

optunahub
samplers
10 changes: 10 additions & 0 deletions docs/source/samplers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. module:: optunahub.samplers

optunahub.samplers
==================

.. autosummary::
:toctree: generated/
:nosignatures:

optunahub.samplers.SimpleBaseSampler
3 changes: 2 additions & 1 deletion optunahub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from optunahub import samplers
from optunahub.hub import load_local_module
from optunahub.hub import load_module
from optunahub.version import __version__


__all__ = ["load_module", "load_local_module", "__version__"]
__all__ = ["load_module", "load_local_module", "__version__", "samplers"]
53 changes: 2 additions & 51 deletions optunahub/hub.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

import importlib.util
import inspect
import logging
import os
import shutil
import sys
import types
from typing import Any
from urllib.parse import urlparse

from ga4mp import GtagMP # type: ignore
Expand All @@ -28,24 +26,6 @@
logging.getLogger("ga4mp.ga4mp").setLevel(logging.WARNING)


def _get_global_variable_from_outer_scopes(key: str, default: Any) -> Any:
    """Look up a module-level variable on the current call stack.

    Frames are examined from the innermost one (this function's own frame)
    towards the outermost caller. The value from the first frame whose globals
    define ``key`` is returned.

    Args:
        key:
            The name of the global variable to look up.
        default:
            The value returned when no frame on the stack defines ``key``.

    Returns:
        The value bound to ``key`` in the nearest frame's globals, or
        ``default`` if no frame defines it.
    """
    # Walk the frame chain directly instead of calling ``inspect.stack()``.
    # ``inspect.stack()`` builds a ``FrameInfo`` with source context for every
    # frame, which is needlessly expensive when only ``f_globals`` is read.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            frame_globals = frame.f_globals
            if key in frame_globals:
                return frame_globals[key]
            frame = frame.f_back
        return default
    finally:
        # Release the frame reference so no reference cycle is kept alive.
        del frame


def _report_stats(
package: str,
ref: str | None,
Expand Down Expand Up @@ -89,9 +69,9 @@ def load_module(
repo_owner: str = "optuna",
repo_name: str = "optunahub-registry",
registry_root: str = "package",
ref: str | None = None,
ref: str = "main",
base_url: str = "https://api.github.com",
force_reload: bool | None = None,
force_reload: bool = False,
auth: Auth.Auth | None = None,
) -> types.ModuleType:
"""Import a package from the OptunaHub registry.
Expand All @@ -111,29 +91,19 @@ def load_module(
The default is "package".
ref:
The Git reference (branch, tag, or commit SHA) for the package.
This setting will be inherited to the inner `load`-like function.
If :obj:`None`, the setting is inherited from the outer `load`-like function.
For the outermost call, the default is "main".
base_url:
The base URL for the GitHub API.
force_reload:
If :obj:`True`, the package will be downloaded from the repository.
If :obj:`False`, the package cached in the local directory will be
loaded if available.
If :obj:`None`, the setting is inherited from the outer `load`-like function.
For the outermost call, the default is `False`.
auth:
`The authentication object <https://pygithub.readthedocs.io/en/latest/examples/Authentication.html>`__ for the GitHub API.
It is required to access private/internal repositories.
Returns:
The module object of the package.
"""
ref = ref or _get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main")
force_reload = force_reload or _get_global_variable_from_outer_scopes(
"OPTUNAHUB_FORCE_RELOAD", False
)

dir_path = f"{registry_root}/{package}" if registry_root else package
hostname = urlparse(base_url).hostname
if hostname is None:
Expand Down Expand Up @@ -174,8 +144,6 @@ def load_module(
module = load_local_module(
package=package,
registry_root=local_registry_root,
ref=ref,
force_reload=force_reload,
)

# Statistics are collected only for the official registry.
Expand All @@ -194,8 +162,6 @@ def load_local_module(
package: str,
*,
registry_root: str = os.sep,
ref: str | None = None,
force_reload: bool | None = None,
) -> types.ModuleType:
"""Import a package from the local registry.
The imported package name is set to ``optunahub_registry.package.<package>``.
Expand All @@ -207,24 +173,11 @@ def load_local_module(
The root directory of the registry.
The default is the root directory of the file system,
e.g., "/" for UNIX-like systems.
ref:
This setting will be inherited to the inner `load`-like function.
If :obj:`None`, the setting is inherited from the outer `load`-like function.
For the outermost call, the default is "main".
force_reload:
This setting will be inherited to the inner `load`-like function.
If :obj:`None`, the setting is inherited from the outer `load`-like function.
For the outermost call, the default is :obj:`False`.
Returns:
The module object of the package.
"""

ref = ref or _get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main")
force_reload = force_reload or _get_global_variable_from_outer_scopes(
"OPTUNAHUB_FORCE_RELOAD", False
)

module_path = os.path.join(registry_root, package)
module_name = f"optunahub_registry.package.{package.replace('/', '.')}"
spec = importlib.util.spec_from_file_location(
Expand All @@ -235,8 +188,6 @@ def load_local_module(
module = importlib.util.module_from_spec(spec)
if module is None:
raise ImportError(f"Module {module_name} not found in {module_path}")
setattr(module, "OPTUNAHUB_REF", ref)
setattr(module, "OPTUNAHUB_FORCE_RELOAD", force_reload)
sys.modules[module_name] = module
spec.loader.exec_module(module)

Expand Down
4 changes: 4 additions & 0 deletions optunahub/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from optunahub.samplers._simple_base import SimpleBaseSampler


__all__ = ["SimpleBaseSampler"]
94 changes: 94 additions & 0 deletions optunahub/samplers/_simple_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from __future__ import annotations

import abc
from typing import Any

from optuna import Study
from optuna.distributions import BaseDistribution
from optuna.samplers import BaseSampler
from optuna.samplers import RandomSampler
from optuna.search_space import IntersectionSearchSpace
from optuna.trial import FrozenTrial


class SimpleBaseSampler(BaseSampler, abc.ABC):
    """A simple base class to implement user-defined samplers.

    Subclasses only have to implement :meth:`sample_relative`. The other
    sampler hooks delegate to the ``_default_*`` helpers defined on this class.
    """

    def __init__(
        self, search_space: dict[str, BaseDistribution] | None = None, seed: int | None = None
    ) -> None:
        """Store the optional fixed search space and seed, then build the default helpers.

        Args:
            search_space:
                If given, it is returned as-is by :meth:`infer_relative_search_space`.
            seed:
                Seed passed to the internal ``RandomSampler``.
        """
        self.search_space = search_space
        self._seed = seed
        self._init_defaults()

    def infer_relative_search_space(
        self,
        study: Study,
        trial: FrozenTrial,
    ) -> dict[str, BaseDistribution]:
        """Return the search space handed to :meth:`sample_relative`.

        Overriding this method is optional. Override it if you want to
        optimize the function with the eager search space.
        """
        # Without a user-specified search space, fall back to the default inference.
        if self.search_space is None:
            return self._default_infer_relative_search_space(study, trial)
        return self.search_space

    @abc.abstractmethod
    def sample_relative(
        self,
        study: Study,
        trial: FrozenTrial,
        search_space: dict[str, BaseDistribution],
    ) -> dict[str, Any]:
        """Sample parameters for ``search_space``.

        This method is required. Optuna calls it at the beginning of each
        trial to sample parameters.
        """
        raise NotImplementedError

    def sample_independent(
        self,
        study: Study,
        trial: FrozenTrial,
        param_name: str,
        param_distribution: BaseDistribution,
    ) -> Any:
        """Sample a single parameter value.

        Overriding this method is optional. By default, the value is sampled
        by ``optuna.samplers.RandomSampler``.
        """
        return self._default_sample_independent(study, trial, param_name, param_distribution)

    def reseed_rng(self) -> None:
        """Reseed the random number generator via the default implementation."""
        self._default_reseed_rng()

    def _init_defaults(self) -> None:
        """Create the helper objects used by the ``_default_*`` methods."""
        self._intersection_search_space = IntersectionSearchSpace()
        self._random_sampler = RandomSampler(seed=self._seed)

    def _default_infer_relative_search_space(
        self, study: Study, trial: FrozenTrial
    ) -> dict[str, BaseDistribution]:
        """Return the intersection search space of ``study`` without single-valued distributions."""
        candidates = self._intersection_search_space.calculate(study)
        # Single value objects are not sampled with the `sample_relative` method,
        # but with the `sample_independent` method, so they are filtered out here.
        return {
            param_name: dist for param_name, dist in candidates.items() if not dist.single()
        }

    def _default_sample_independent(
        self,
        study: Study,
        trial: FrozenTrial,
        param_name: str,
        param_distribution: BaseDistribution,
    ) -> Any:
        """Delegate independent sampling to the internal ``RandomSampler``.

        Following parameters are randomly sampled here:

        1. A parameter in the initial population/first generation.
        2. A parameter to mutate.
        3. A parameter excluded from the intersection search space.
        """
        random_sampler = self._random_sampler
        return random_sampler.sample_independent(study, trial, param_name, param_distribution)

    def _default_reseed_rng(self) -> None:
        """Reseed the internal ``RandomSampler``."""
        self._random_sampler.reseed_rng()
14 changes: 2 additions & 12 deletions tests/package_for_test_hub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
from optuna.samplers import RandomSampler
from .sampler import TestSampler

import optunahub

from . import implementation


ref = optunahub.hub._get_global_variable_from_outer_scopes("OPTUNAHUB_REF", "main")
force_reload = optunahub.hub._get_global_variable_from_outer_scopes(
"OPTUNAHUB_FORCE_RELOAD", False
)


__all__ = ["RandomSampler", "implementation", "ref", "force_reload"]
__all__ = ["TestSampler"]
7 changes: 0 additions & 7 deletions tests/package_for_test_hub/implementation.py

This file was deleted.

27 changes: 27 additions & 0 deletions tests/package_for_test_hub/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any
from typing import Dict
from typing import Optional

import numpy as np
from optuna import Study
from optuna.distributions import BaseDistribution
from optuna.trial import FrozenTrial

import optunahub


class TestSampler(optunahub.samplers.SimpleBaseSampler):
    """Minimal sampler used by the hub tests; draws every parameter uniformly."""

    def __init__(self, search_space: Optional[Dict[str, BaseDistribution]] = None) -> None:
        """Forward ``search_space`` to the base class and create an unseeded RNG."""
        super().__init__(search_space)
        self._rng = np.random.RandomState()

    def sample_relative(
        self,
        study: Study,
        trial: FrozenTrial,
        search_space: Dict[str, BaseDistribution],
    ) -> Dict[str, Any]:
        """Return one uniform draw from ``[low, high]`` for each entry of ``search_space``.

        Every distribution is accessed through its ``low`` and ``high`` attributes.
        """
        draw = self._rng.uniform
        return {name: draw(dist.low, dist.high) for name, dist in search_space.items()}
33 changes: 1 addition & 32 deletions tests/test_hub.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import sys

import optuna
import pytest

import optunahub

Expand Down Expand Up @@ -35,35 +33,6 @@ def objective(trial: optuna.Trial) -> float:
assert m.__name__ == "optunahub_registry.package.package_for_test_hub"

# Confirm no error occurs by running optimization
sampler = m.RandomSampler()
sampler = m.TestSampler()
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=10)


@pytest.mark.parametrize(
    ("ref", "force_reload", "expected_ref", "expected_force_reload"),
    [
        (None, None, "main", False),
        ("main", False, "main", False),
        ("test", True, "test", True),
    ],
)
def test_load_settings_propagation(
    ref: str,
    force_reload: bool,
    expected_ref: str,
    expected_force_reload: bool,
) -> None:
    """Check that ``ref`` and ``force_reload`` reach the loaded package and its submodule.

    The test package exposes the values it resolved as ``ref`` and
    ``force_reload``, both on the package itself and on its
    ``implementation`` submodule.
    """
    package_module = "optunahub_registry.package.package_for_test_hub"
    try:
        m = optunahub.load_local_module(
            "package_for_test_hub",
            registry_root=os.path.dirname(__file__),
            ref=ref,
            force_reload=force_reload,
        )
        assert m.ref == expected_ref
        assert m.force_reload == expected_force_reload
        assert m.implementation.ref == expected_ref
        assert m.implementation.force_reload == expected_force_reload
    finally:
        # Always drop the cached modules, even when an assertion fails. Otherwise
        # the next parametrized case would reuse the stale ``implementation``
        # submodule from ``sys.modules``. ``pop`` with a default keeps cleanup from
        # masking the original failure when loading itself did not succeed.
        sys.modules.pop(package_module, None)
        sys.modules.pop(f"{package_module}.implementation", None)

0 comments on commit e55a73d

Please sign in to comment.