Skip to content

Commit

Permalink
Merge pull request #150 from nzw0301/ruff
Browse files Browse the repository at this point in the history
Use `__future__.annotations`
  • Loading branch information
nabenabe0928 authored Aug 20, 2024
2 parents f9a71d3 + af3e5fb commit 605390c
Show file tree
Hide file tree
Showing 35 changed files with 220 additions and 193 deletions.
15 changes: 7 additions & 8 deletions optuna_integration/_imports.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import importlib
import types
from types import TracebackType
from typing import Any
from typing import Optional
from typing import Tuple
from typing import Type


class _DeferredImportExceptionContextManager:
Expand All @@ -16,7 +15,7 @@ class _DeferredImportExceptionContextManager:
"""

def __init__(self) -> None:
self._deferred: Optional[Tuple[Exception, str]] = None
self._deferred: tuple[Exception, str] | None = None

def __enter__(self) -> "_DeferredImportExceptionContextManager":
"""Enter the context manager.
Expand All @@ -29,10 +28,10 @@ def __enter__(self) -> "_DeferredImportExceptionContextManager":

def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
traceback: Optional[TracebackType],
) -> Optional[bool]:
exc_type: type[Exception] | None,
exc_value: Exception | None,
traceback: TracebackType | None,
) -> bool | None:
"""Exit the context manager.
Args:
Expand Down
2 changes: 2 additions & 0 deletions optuna_integration/_lightgbm_tuner/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any
import warnings

Expand Down
2 changes: 2 additions & 0 deletions optuna_integration/allennlp/_dump_best_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

import optuna
Expand Down
2 changes: 1 addition & 1 deletion optuna_integration/allennlp/_pruner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
import os
from typing import Any
from typing import Callable

from optuna import load_study
from optuna import pruners
Expand Down
2 changes: 2 additions & 0 deletions optuna_integration/allennlp/_variables.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import os
from typing import Any
Expand Down
84 changes: 41 additions & 43 deletions optuna_integration/botorch/botorch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Sequence
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Union
import warnings

import numpy
Expand Down Expand Up @@ -100,9 +97,9 @@ def _get_constraint_funcs(n_constraints: int) -> list[Callable[["torch.Tensor"],
def logei_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Log Expected Improvement (LogEI).
Expand Down Expand Up @@ -210,9 +207,9 @@ def logei_candidates_func(
def qei_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Quasi MC-based batch Expected Improvement (qEI).
Expand Down Expand Up @@ -315,9 +312,9 @@ def qei_candidates_func(
def qnei_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Quasi MC-based batch Noisy Expected Improvement (qNEI).
Expand Down Expand Up @@ -382,9 +379,9 @@ def qnei_candidates_func(
def qehvi_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Quasi MC-based batch Expected Hypervolume Improvement (qEHVI).
Expand Down Expand Up @@ -467,9 +464,9 @@ def qehvi_candidates_func(
def ehvi_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Expected Hypervolume Improvement (EHVI).
Expand Down Expand Up @@ -532,9 +529,9 @@ def ehvi_candidates_func(
def qnehvi_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Quasi MC-based batch Noisy Expected Hypervolume Improvement (qNEHVI).
Expand Down Expand Up @@ -616,9 +613,9 @@ def qnehvi_candidates_func(
def qparego_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Quasi MC-based extended ParEGO (qParEGO) for constrained multi-objective optimization.
Expand Down Expand Up @@ -688,9 +685,9 @@ def qparego_candidates_func(
def qkg_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Quasi MC-based batch Knowledge Gradient (qKG).
Expand Down Expand Up @@ -754,9 +751,9 @@ def qkg_candidates_func(
def qhvkg_candidates_func(
train_x: "torch.Tensor",
train_obj: "torch.Tensor",
train_con: Optional["torch.Tensor"],
train_con: "torch.Tensor" | None,
bounds: "torch.Tensor",
pending_x: Optional["torch.Tensor"],
pending_x: "torch.Tensor" | None,
) -> "torch.Tensor":
"""Quasi MC-based batch Hypervolume Knowledge Gradient (qHVKG).
Expand Down Expand Up @@ -842,9 +839,9 @@ def _get_default_candidates_func(
[
"torch.Tensor",
"torch.Tensor",
Optional["torch.Tensor"],
"torch.Tensor" | None,
"torch.Tensor",
Optional["torch.Tensor"],
"torch.Tensor" | None,
],
"torch.Tensor",
]:
Expand Down Expand Up @@ -936,24 +933,25 @@ class BoTorchSampler(BaseSampler):
def __init__(
self,
*,
candidates_func: Optional[
candidates_func: (
Callable[
[
"torch.Tensor",
"torch.Tensor",
Optional["torch.Tensor"],
"torch.Tensor" | None,
"torch.Tensor",
Optional["torch.Tensor"],
"torch.Tensor" | None,
],
"torch.Tensor",
]
] = None,
constraints_func: Optional[Callable[[FrozenTrial], Sequence[float]]] = None,
| None
) = None,
constraints_func: Callable[[FrozenTrial], Sequence[float]] | None = None,
n_startup_trials: int = 10,
consider_running_trials: bool = False,
independent_sampler: Optional[BaseSampler] = None,
seed: Optional[int] = None,
device: Optional["torch.device"] = None,
independent_sampler: BaseSampler | None = None,
seed: int | None = None,
device: "torch.device" | None = None,
):
_imports.check()

Expand All @@ -964,23 +962,23 @@ def __init__(
self._n_startup_trials = n_startup_trials
self._seed = seed

self._study_id: Optional[int] = None
self._study_id: int | None = None
self._search_space = IntersectionSearchSpace()
self._device = device or torch.device("cpu")

def infer_relative_search_space(
self,
study: Study,
trial: FrozenTrial,
) -> Dict[str, BaseDistribution]:
) -> dict[str, BaseDistribution]:
if self._study_id is None:
self._study_id = study._study_id
if self._study_id != study._study_id:
# Note that the check below is meaningless when `InMemoryStorage` is used
# because `InMemoryStorage.create_new_study` always returns the same study ID.
raise RuntimeError("BoTorchSampler cannot handle multiple studies.")

search_space: Dict[str, BaseDistribution] = {}
search_space: dict[str, BaseDistribution] = {}
for name, distribution in self._search_space.calculate(study).items():
if distribution.single():
# built-in `candidates_func` cannot handle distributions that contain just a
Expand All @@ -995,8 +993,8 @@ def sample_relative(
self,
study: Study,
trial: FrozenTrial,
search_space: Dict[str, BaseDistribution],
) -> Dict[str, Any]:
search_space: dict[str, BaseDistribution],
) -> dict[str, Any]:
assert isinstance(search_space, dict)

if len(search_space) == 0:
Expand All @@ -1015,12 +1013,12 @@ def sample_relative(

trans = _SearchSpaceTransform(search_space)
n_objectives = len(study.directions)
values: Union[numpy.ndarray, torch.Tensor] = numpy.empty(
values: numpy.ndarray | torch.Tensor = numpy.empty(
(n_trials, n_objectives), dtype=numpy.float64
)
params: Union[numpy.ndarray, torch.Tensor]
con: Optional[Union[numpy.ndarray, torch.Tensor]] = None
bounds: Union[numpy.ndarray, torch.Tensor] = trans.bounds
params: numpy.ndarray | torch.Tensor
con: numpy.ndarray | torch.Tensor | None = None
bounds: numpy.ndarray | torch.Tensor = trans.bounds
params = numpy.empty((n_trials, trans.bounds.shape[0]), dtype=numpy.float64)
for trial_idx, trial in enumerate(trials):
if trial.state == TrialState.COMPLETE:
Expand Down Expand Up @@ -1151,7 +1149,7 @@ def after_trial(
study: Study,
trial: FrozenTrial,
state: TrialState,
values: Optional[Sequence[float]],
values: Sequence[float] | None,
) -> None:
if self._constraints_func is not None:
_process_constraints_after_trial(self._constraints_func, study, trial, state)
Expand Down
4 changes: 2 additions & 2 deletions optuna_integration/chainermn/chainermn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import Callable
from typing import overload
from typing import Sequence
import warnings

from optuna import TrialPruned
Expand Down
Loading

0 comments on commit 605390c

Please sign in to comment.