Skip to content

Commit

Permalink
add additional test functions and psychophysics task and dataset from…
Browse files Browse the repository at this point in the history
… Letham et al. 2022

Summary: Additional high-dimensional test functions and real psychophysics task are added to problem.py for benchmarking performance of acquistions functions or GP models. The code and dataset are obtained from https://github.com/facebookresearch/bernoulli_lse/blob/main/problems.py.

Differential Revision: D57885175
  • Loading branch information
wenx-guo authored and facebook-github-bot committed May 29, 2024
1 parent 23ac907 commit 6f45705
Show file tree
Hide file tree
Showing 2 changed files with 1,130 additions and 2 deletions.
130 changes: 128 additions & 2 deletions aepsych/benchmark/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from functools import cached_property
from typing import Any, Dict, Union

import aepsych
import numpy as np
import torch
from scipy.stats import bernoulli, norm, pearsonr
from aepsych.strategy import SequentialStrategy, Strategy
from aepsych.utils import make_scaled_sobol
from scipy.stats import bernoulli, norm, pearsonr
from aepsych.benchmark.test_functions import (
modified_hartmann6,
discrim_highdim,
novel_discrimination_testfun,
)
from aepsych.models import GPClassificationModel



class Problem:
Expand Down Expand Up @@ -281,3 +288,122 @@ def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, floa
)

return metrics

"""
The LSEProblemWithEdgeLogging, DiscrimLowDim, DiscrimHighDim, ContrastSensitivity6d, and Hartmann6Binary classes
are copied from bernoulli_lse repository (https://github.com/facebookresearch/bernoulli_lse) by Letham et al. 2022.
"""
class LSEProblemWithEdgeLogging(LSEProblem):
eps = 0.05

def evaluate(self, strat):
metrics = super().evaluate(strat)

# add number of edge samples to the log

# get the trials selected by the final strat only
n_opt_trials = strat.strat_list[-1].n_trials

lb, ub = strat.lb, strat.ub
r = ub - lb
lb2 = lb + self.eps * r
ub2 = ub - self.eps * r

near_edge = (
np.logical_or(
(strat.x[-n_opt_trials:, :] <= lb2), (strat.x[-n_opt_trials:, :] >= ub2)
)
.any(axis=-1)
.double()
)

metrics["prop_edge_sampling_mean"] = near_edge.mean().item()
metrics["prop_edge_sampling_err"] = (
2 * near_edge.std() / np.sqrt(len(near_edge))
).item()
return metrics


class DiscrimLowDim(LSEProblemWithEdgeLogging):
name = "discrim_lowdim"
bounds = torch.tensor([[-1, 1], [-1, 1]], dtype=torch.double).T
threshold = 0.75

def f(self, x: torch.Tensor) -> torch.Tensor:
return torch.tensor(novel_discrimination_testfun(x), dtype=torch.double)


class DiscrimHighDim(LSEProblemWithEdgeLogging):
name = "discrim_highdim"
threshold = 0.75
bounds = torch.tensor(
[
[-1, 1],
[-1, 1],
[0.5, 1.5],
[0.05, 0.15],
[0.05, 0.2],
[0, 0.9],
[0, 3.14 / 2],
[0.5, 2],
],
dtype=torch.double,
).T

def f(self, x: torch.Tensor) -> torch.Tensor:
return torch.tensor(discrim_highdim(x), dtype=torch.double)


class Hartmann6Binary(LSEProblemWithEdgeLogging):
name = "hartmann6_binary"
threshold = 0.5
bounds = torch.stack(
(
torch.zeros(6, dtype=torch.double),
torch.ones(6, dtype=torch.double),
)
)

def f(self, X: torch.Tensor) -> torch.Tensor:
y = torch.tensor([modified_hartmann6(x) for x in X], dtype=torch.double)


class ContrastSensitivity6d(LSEProblemWithEdgeLogging):
"""
Uses a surrogate model fit to real data from a constrast sensitivity study.
"""

name = "contrast_sensitivity_6d"
threshold = 0.75
bounds = torch.tensor(
[[-1.5, 0], [-1.5, 0], [0, 20], [0.5, 7], [1, 10], [0, 10]],
dtype=torch.double,
).T

def __init__(self):

# Load the data
self.data = np.loadtxt(
os.path.join("..", "..", "dataset", "csf_dataset.csv"),
delimiter=",",
skiprows=1,
)
y = torch.LongTensor(self.data[:, 0])
x = torch.Tensor(self.data[:, 1:])

# Fit a model, with a large number of inducing points
self.m = GPClassificationModel(
lb=self.bounds[0],
ub=self.bounds[1],
inducing_size=100,
inducing_point_method="kmeans++",
)

self.m.fit(
x,
y,
)

def f(self, X: torch.Tensor) -> torch.Tensor:
# clamp f to 0 since we expect p(x) to be lower-bounded at 0.5
return torch.clamp(self.m.predict(torch.tensor(X))[0], min=0)
Loading

0 comments on commit 6f45705

Please sign in to comment.