Skip to content

Commit

Permalink
Make all generators know the dim (facebookresearch#513)
Browse files Browse the repository at this point in the history
Summary:

Some of the generators know the dimensionality of the search space but others don't. Unified the API such that every generator has a dim attribute. This allows other classes working with generators to be able to always access the dimensionality (even if the bounds are not available).

Differential Revision: D67997208
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Jan 10, 2025
1 parent 3945448 commit 4ef66c6
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 1 deletion.
1 change: 1 addition & 0 deletions aepsych/generators/acqf_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.stimuli_per_trial = stimuli_per_trial
self.lb = lb
self.ub = ub
self.dim = len(lb)

def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction:
"""Instantiate the acquisition function with the model and any extra arguments.
Expand Down
1 change: 1 addition & 0 deletions aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]):

acqf: AcquisitionFunction
acqf_kwargs: Dict[str, Any]
dim: int

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions aepsych/generators/epsilon_greedy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
self.epsilon = epsilon
self.lb = lb
self.ub = ub
self.dim = len(lb)

@classmethod
def from_config(cls, config: Config) -> "EpsilonGreedyGenerator":
Expand Down
2 changes: 1 addition & 1 deletion aepsych/generators/monotonic_rejection_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
self.acqf_kwargs = acqf_kwargs
self.model_gen_options = model_gen_options
self.explore_features = explore_features
self.lb, self.ub, _ = _process_bounds(lb, ub, None)
self.lb, self.ub, self.dim = _process_bounds(lb, ub, None)
self.bounds = torch.stack((self.lb, self.ub))

def _instantiate_acquisition_fn(
Expand Down
3 changes: 3 additions & 0 deletions aepsych/generators/monotonic_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
num_ts_points: int,
target_value: float,
objective: MCAcquisitionObjective,
dim: int,
explore_features: Optional[List[Type[int]]] = None,
) -> None:
"""Initialize MonotonicMCAcquisition
Expand All @@ -41,6 +42,7 @@ def __init__(
target_value (float): target value that is being looked for
objective (MCAcquisitionObjective): Objective transform of the GP output
before evaluating the acquisition. Defaults to identity transform.
dim (int): Dimensionality of the model.
explore_features (List[Type[int]], optional): List of features that will be selected randomly and then
fixed for acquisition fn optimization. Defaults to None.
"""
Expand All @@ -50,6 +52,7 @@ def __init__(
self.target_value = target_value
self.objective = objective()
self.explore_features = explore_features
self.dim = dim

def gen(
self,
Expand Down
1 change: 1 addition & 0 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
self.stimuli_per_trial = stimuli_per_trial
self.lb = lb
self.ub = ub
self.dim = len(lb)

def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction:
"""
Expand Down

0 comments on commit 4ef66c6

Please sign in to comment.