implement pairwise benchmarks #390

Open
wants to merge 1 commit into base: main
87 changes: 82 additions & 5 deletions aepsych/benchmark/example_problems.py
@@ -1,15 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
import os

import numpy as np
import torch
from aepsych.models import GPClassificationModel
from aepsych.benchmark.problem import LSEProblemWithEdgeLogging
from aepsych.benchmark.test_functions import (
modified_hartmann6,
discrim_highdim,
modified_hartmann6,
novel_discrimination_testfun,
)
from aepsych.benchmark.problem import LSEProblemWithEdgeLogging

from aepsych.models import GPClassificationModel
from scipy.stats import norm

"""The DiscrimLowDim, DiscrimHighDim, ContrastSensitivity6d, and Hartmann6Binary classes
are copied from bernoulli_lse github repository (https://github.com/facebookresearch/bernoulli_lse)
@@ -84,7 +85,6 @@ class ContrastSensitivity6d(LSEProblemWithEdgeLogging):

def __init__(self, thresholds=None):
thresholds = 0.75 if thresholds is None else thresholds
super().__init__(thresholds=thresholds)

# Load the data
self.data = np.loadtxt(
@@ -108,6 +108,83 @@ def __init__(self, thresholds=None):
y,
)

super().__init__(thresholds=thresholds)

def f(self, X: torch.Tensor) -> torch.Tensor:
# clamp f to 0 since we expect p(x) to be lower-bounded at 0.5
return torch.clamp(self.m.predict(torch.tensor(X))[0], min=0)


class PairwiseDiscrimLowdim(LSEProblemWithEdgeLogging):
name = "pairwise_discrim_lowdim"
bounds = torch.tensor([[-1, 1], [-1, 1], [-1, 1], [-1, 1]], dtype=torch.double).T

def __init__(self, thresholds=None):
if thresholds is None:
jnds = np.arange(-4, 5)
thresholds = np.round(norm.cdf(jnds).tolist(), 3).tolist()
super().__init__(thresholds=thresholds)

def f(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == 4
f1 = novel_discrimination_testfun(x[..., :2])
f2 = novel_discrimination_testfun(x[..., 2:])
return (f1 - f2).to(torch.double)


class PairwiseDiscrimHighdim(LSEProblemWithEdgeLogging):
name = "pairwise_discrim_highdim"
bounds = torch.tensor(
[
[-1, 1],
[-1, 1],
[0.5, 1.5],
[0.05, 0.15],
[0.05, 0.2],
[0, 0.9],
[0, 3.14 / 2],
[0.5, 2],
]
* 2,
dtype=torch.double,
).T

def __init__(self, thresholds=None):
if thresholds is None:
jnds = np.arange(-4, 5)
thresholds = np.round(norm.cdf(jnds).tolist(), 3).tolist()
super().__init__(thresholds=thresholds)

def f(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == 16
f1 = discrim_highdim(x[..., :8])
f2 = discrim_highdim(x[..., 8:])
return torch.tensor(f1 - f2, dtype=torch.double)


class PairwiseHartmann6Binary(LSEProblemWithEdgeLogging):
name = "pairwise_hartmann6_binary"
bounds = torch.stack(
(
torch.zeros(12, dtype=torch.double),
torch.ones(12, dtype=torch.double),
)
)

def __init__(self, thresholds=None):
if thresholds is None:
jnds = np.arange(-4, 5)
thresholds = np.round(norm.cdf(jnds).tolist(), 3).tolist()
super().__init__(thresholds=thresholds)

def f(self, X: torch.Tensor) -> torch.Tensor:
assert X.shape[-1] == 12

def latent_f(X1):
y = torch.tensor([modified_hartmann6(x) for x in X1], dtype=torch.double)
f = 3 * y - 2.0
return f

f1 = latent_f(X[..., :6])
f2 = latent_f(X[..., 6:])
return (f1 - f2).to(torch.double)
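All three new pairwise problems share the same default threshold list: integer just-noticeable-difference (JND) steps from -4 to 4 on the latent probit scale, each mapped through the standard normal CDF. A minimal standalone sketch of that computation (numpy and scipy only, equivalent to the __init__ defaults above):

import numpy as np
from scipy.stats import norm

# Integer JND steps on the latent (probit) scale.
jnds = np.arange(-4, 5)

# Map each step through the standard normal CDF to get probability-space thresholds.
thresholds = np.round(norm.cdf(jnds), 3).tolist()
print(thresholds)
# [0.0, 0.001, 0.023, 0.159, 0.5, 0.841, 0.977, 0.999, 1.0]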
26 changes: 23 additions & 3 deletions aepsych/benchmark/problem.py
@@ -213,6 +213,10 @@ def __init__(self, thresholds: Union[float, List]):
thresholds = [thresholds] if isinstance(thresholds, float) else thresholds
self.thresholds = np.array(thresholds)

min_f = max(self.f_true.min(), -12)
max_f = min(self.f_true.max(), 12)
self.f_thresholds = np.arange(int(np.ceil(min_f)), int(np.floor(max_f))+1)

@property
def metadata(self) -> Dict[str, Any]:
"""A dictionary of metadata passed to the Benchmark to be logged. Each key will become a column in the
@@ -246,6 +250,11 @@ def true_below_threshold(self) -> np.ndarray:
return (
self.p(self.eval_grid).reshape(1, -1) <= self.thresholds.reshape(-1, 1)
).astype(float)

@cached_property
def true_f_below_threshold(self) -> np.ndarray:
return self.f_true.reshape(1, -1) < self.f_thresholds.reshape(-1, 1)


def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, float]:
"""Evaluate the model with respect to this problem.
@@ -286,14 +295,25 @@ def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, float]:

# Predict p(below threshold) at test points
brier_p_below_thresh = np.mean(2 * np.square(true_p_l - p_l), axis=1)
# Classification error
misclass_on_thresh = np.mean(
# Classification error in the probability space
misclass_on_p_thresh = np.mean(
p_l * (1 - true_p_l) + (1 - p_l) * true_p_l, axis=1
)

# classification error in units of latent space f(x) (or g(x1,x2))
f, _ = model.predict(self.eval_grid)
pred_f = f.numpy().reshape(1, -1) < self.f_thresholds.reshape(-1, 1)

misclass_on_f_thresh = np.mean(
pred_f * (1 - self.true_f_below_threshold) + (1 - pred_f) * self.true_f_below_threshold, axis=1
)

for i_threshold, threshold in enumerate(self.thresholds):
metrics[f"brier_p_below_{threshold}"] = brier_p_below_thresh[i_threshold]
metrics[f"misclass_on_thresh_{threshold}"] = misclass_on_thresh[i_threshold]
metrics[f"misclass_on_thresh_{threshold}"] = misclass_on_p_thresh[i_threshold]

for i_f_threshold, f_threshold in enumerate(self.f_thresholds):
metrics[f"misclass_on_f_thresh_{int(f_threshold)}"] = misclass_on_f_thresh[i_f_threshold]
return metrics


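The new latent-space metric above compares the model's predicted latent mean against integer thresholds spanning the range of the true latent (clamped to [-12, 12]) and reports one misclassification rate per threshold. A self-contained sketch of that computation on made-up arrays (numpy only; the f_true and f_pred values here are illustrative, not taken from any problem):

import numpy as np

# Illustrative ground-truth and predicted latent values on a small evaluation grid.
f_true = np.array([-2.3, -0.7, 0.1, 1.8, 3.4])
f_pred = np.array([-2.0, -1.1, 0.4, 1.5, 3.9])

# Integer thresholds spanning the true latent range, clamped to [-12, 12] as in the diff.
min_f = max(f_true.min(), -12)
max_f = min(f_true.max(), 12)
f_thresholds = np.arange(int(np.ceil(min_f)), int(np.floor(max_f)) + 1)  # [-2, -1, 0, 1, 2, 3]

# Indicator matrices: rows index thresholds, columns index grid points.
true_below = f_true.reshape(1, -1) < f_thresholds.reshape(-1, 1)
pred_below = f_pred.reshape(1, -1) < f_thresholds.reshape(-1, 1)

# Fraction of grid points where the below-threshold classification disagrees, one rate per threshold.
misclass_on_f_thresh = np.mean(pred_below * (1 - true_below) + (1 - pred_below) * true_below, axis=1)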
6 changes: 6 additions & 0 deletions tests/benchmark/__init__.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
File renamed without changes.
80 changes: 80 additions & 0 deletions tests/benchmark/test_pairwise.py
@@ -0,0 +1,80 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import logging
import unittest

import aepsych.utils_logging as utils_logging
import numpy as np
import torch
from aepsych.benchmark import Benchmark, example_problems
from aepsych.utils import make_scaled_sobol
from scipy.stats import norm

logger = utils_logging.getLogger(logging.ERROR)


class TestPairwise(unittest.TestCase):
def setUp(self) -> None:
self.problem_map = {
"PairwiseHartmann6Binary": "Hartmann6Binary",
"PairwiseDiscrimHighdim": "DiscrimHighDim",
"PairwiseDiscrimLowdim": "DiscrimLowDim",
}

self.problems = {}
for pairwise_problem, single_problem in self.problem_map.items():
self.problems[pairwise_problem] = getattr(
example_problems, pairwise_problem
)()
self.problems[single_problem] = getattr(example_problems, single_problem)()

def test_pairwise_probability(self) -> None:
for pairwise_problem, single_problem in self.problem_map.items():
pairwise_problem = self.problems[pairwise_problem]
single_problem = self.problems[single_problem]

x1, x2 = make_scaled_sobol(single_problem.lb, single_problem.ub, 2)
pairwise_x = torch.concat([x1, x2]).unsqueeze(0)

pairwise_p = pairwise_problem.p(pairwise_x)
f1 = single_problem.f(x1.unsqueeze(0))
f2 = single_problem.f(x2.unsqueeze(0))
single_p = norm.cdf(f1 - f2)
self.assertTrue(np.allclose(pairwise_p, single_p))

def pairwise_benchmark_smoketest(self) -> None:
"""a smoke test to make sure the models and benchmark are set up correctly"""
config = {
"common": {
"stimuli_per_trial": 2,
"outcome_types": "binary",
"strategy_names": "[init_strat, opt_strat]",
},
"init_strat": {"n_trials": 10, "generator": "SobolGenerator"},
"opt_strat": {
"model": "GPClassificationModel",
"generator": "SobolGenerator",
"n_trials": 10,
"refit_every": 10,
},
"GPClassificationModel": {
"inducing_size": 100,
"mean_covar_factory": "default_mean_covar_factory",
"inducing_point_method": "auto",
},
}

pairwise_problems = [self.problems[name] for name in self.problem_map.keys()]

bench = Benchmark(
problems=pairwise_problems,
configs=config,
n_reps=1,
)
bench.run_benchmarks()


if __name__ == "__main__":
unittest.main()
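The assertion in test_pairwise_probability rests on the probit link: the pairwise problems return the latent difference f(x1) - f(x2), and the response probability is the standard normal CDF of that difference, so swapping the two stimuli must yield the complementary probability. A tiny sketch of that identity with a made-up latent function (toy_latent is hypothetical and stands in for a single-stimulus test function such as novel_discrimination_testfun):

import numpy as np
from scipy.stats import norm


def toy_latent(x):
    # Hypothetical stand-in for a single-stimulus latent function.
    return 2.0 * x[0] - x[1]


x1 = np.array([0.3, -0.5])
x2 = np.array([-0.1, 0.4])

p_12 = norm.cdf(toy_latent(x1) - toy_latent(x2))  # pairwise probability for the ordering (x1, x2)
p_21 = norm.cdf(toy_latent(x2) - toy_latent(x1))  # pairwise probability for the swapped ordering

# Probit symmetry: the two orderings are complementary.
assert np.isclose(p_12 + p_21, 1.0)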