Skip to content

Commit

Permalink
Merge pull request #147 from nabenabe0928/code-fix/followup-botorch-u…
Browse files Browse the repository at this point in the history
…pdate

Refactor botorch
  • Loading branch information
HideakiImamura authored Aug 19, 2024
2 parents 16fd88b + 0e815a3 commit f9a71d3
Showing 1 changed file with 27 additions and 29 deletions.
56 changes: 27 additions & 29 deletions optuna_integration/botorch/botorch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any
from typing import Callable
from typing import Dict
Expand Down Expand Up @@ -81,6 +83,19 @@ def _get_sobol_qmc_normal_sampler(num_samples: int) -> SobolQMCNormalSampler:
)


def _validate_botorch_version_for_constrained_opt(func_name: str) -> None:
if version.parse(botorch.version.version) < version.parse("0.9.0"):
raise ImportError(
f"{func_name} requires botorch>=0.9.0 for constrained problems, but got "
f"botorch={botorch.version.version}.\n"
"Please run ``pip install botorch --upgrade``."
)


def _get_constraint_funcs(n_constraints: int) -> list[Callable[["torch.Tensor"], "torch.Tensor"]]:
return [lambda Z: Z[..., -n_constraints + i] for i in range(n_constraints)]


@experimental_func("3.3.0")
def logei_candidates_func(
train_x: "torch.Tensor",
Expand Down Expand Up @@ -235,11 +250,7 @@ def qei_candidates_func(
if train_obj.size(-1) != 1:
raise ValueError("Objective may only contain single values with qEI.")
if train_con is not None:
if version.parse(botorch.version.version) < version.parse("0.9.0"):
raise ImportError(
"qei_candidates_func requires botorch >=0.9.0. for constrained problems."
"Please upgrade botorch"
)
_validate_botorch_version_for_constrained_opt("qei_candidates_func")
train_y = torch.cat([train_obj, train_con], dim=-1)

is_feas = (train_con <= 0).all(dim=-1)
Expand All @@ -257,7 +268,7 @@ def qei_candidates_func(
n_constraints = train_con.size(1)
additonal_qei_kwargs = {
"objective": GenericMCObjective(lambda Z, X: Z[..., 0]),
"constraints": [lambda Z: Z[..., -n_constraints + i] for i in range(n_constraints)],
"constraints": _get_constraint_funcs(n_constraints),
}
else:
train_y = train_obj
Expand Down Expand Up @@ -320,17 +331,13 @@ def qnei_candidates_func(
if train_obj.size(-1) != 1:
raise ValueError("Objective may only contain single values with qNEI.")
if train_con is not None:
if version.parse(botorch.version.version) < version.parse("0.9.0"):
raise ImportError(
"qnei_candidates_func requires botorch >=0.9.0. for constrained problems."
"Please upgrade botorch"
)
_validate_botorch_version_for_constrained_opt("qnei_candidates_func")
train_y = torch.cat([train_obj, train_con], dim=-1)

n_constraints = train_con.size(1)
additional_qnei_kwargs = {
"objective": GenericMCObjective(lambda Z, X: Z[..., 0]),
"constraints": [lambda Z: Z[..., -n_constraints + i] for i in range(n_constraints)],
"constraints": _get_constraint_funcs(n_constraints),
}
else:
train_y = train_obj
Expand Down Expand Up @@ -400,7 +407,7 @@ def qehvi_candidates_func(
n_constraints = train_con.size(1)
additional_qehvi_kwargs = {
"objective": IdentityMCMultiOutputObjective(outcomes=list(range(n_objectives))),
"constraints": [lambda Z: Z[..., -n_constraints + i] for i in range(n_constraints)],
"constraints": _get_constraint_funcs(n_constraints),
}
else:
train_y = train_obj
Expand Down Expand Up @@ -548,9 +555,7 @@ def qnehvi_candidates_func(
n_constraints = train_con.size(1)
additional_qnehvi_kwargs = {
"objective": IdentityMCMultiOutputObjective(outcomes=list(range(n_objectives))),
"constraints": [
(lambda Z, i=i: Z[..., -n_constraints + i]) for i in range(n_constraints)
],
"constraints": _get_constraint_funcs(n_constraints),
}
else:
train_y = train_obj
Expand Down Expand Up @@ -631,23 +636,18 @@ def qparego_candidates_func(
scalarization = get_chebyshev_scalarization(weights=weights, Y=train_obj)

if train_con is not None:
if version.parse(botorch.version.version) < version.parse("0.9.0"):
raise ImportError(
"qparego_candidates_func requires botorch >=0.9.0. for constrained problems."
"Please upgrade botorch"
)

_validate_botorch_version_for_constrained_opt("qparego_candidates_func")
train_y = torch.cat([train_obj, train_con], dim=-1)
n_constraints = train_con.size(1)
objective = GenericMCObjective(lambda Z, X: scalarization(Z[..., :n_objectives]))
additional_kwargs = {
"constraints": [lambda Z: Z[..., -n_constraints + i] for i in range(n_constraints)],
additional_qei_kwargs = {
"constraints": _get_constraint_funcs(n_constraints),
}
else:
train_y = train_obj

objective = GenericMCObjective(scalarization)
additional_kwargs = {}
additional_qei_kwargs = {}

train_x = normalize(train_x, bounds=bounds)
if pending_x is not None:
Expand All @@ -663,7 +663,7 @@ def qparego_candidates_func(
sampler=_get_sobol_qmc_normal_sampler(256),
objective=objective,
X_pending=pending_x,
**additional_kwargs,
**additional_qei_kwargs,
)

standard_bounds = torch.zeros_like(bounds)
Expand Down Expand Up @@ -711,9 +711,7 @@ def qkg_candidates_func(
n_constraints = train_con.size(1)
objective = ConstrainedMCObjective(
objective=lambda Z, X: Z[..., 0],
constraints=[
(lambda Z, i=i: Z[..., -n_constraints + i]) for i in range(n_constraints)
],
constraints=_get_constraint_funcs(n_constraints),
)
else:
train_y = train_obj
Expand Down

0 comments on commit f9a71d3

Please sign in to comment.