From 0e7ba99a67b39c3ce90d59411840c3de91430c76 Mon Sep 17 00:00:00 2001
From: Craig Sanders
Date: Wed, 25 Sep 2024 15:37:35 -0700
Subject: [PATCH] implement pairwise benchmarks

Summary: Implements LSE benchmark problems compatible with the pairwise kernel
(PairwiseDiscrimLowdim, PairwiseDiscrimHighdim, PairwiseHartmann6Binary), adds
latent-space threshold misclassification metrics to LSEProblemWithEdgeLogging,
and adds tests for the new problems.

Differential Revision: D63406907
---
 aepsych/benchmark/example_problems.py   | 91 ++++++++++++++++++++++++--
 aepsych/benchmark/problem.py            | 33 ++++++++++-
 tests/benchmark/__init__.py             |  6 ++
 tests/{ => benchmark}/test_benchmark.py |  0
 tests/benchmark/test_pairwise.py        | 80 +++++++++++++++++++++++
 5 files changed, 202 insertions(+), 8 deletions(-)
 create mode 100644 tests/benchmark/__init__.py
 rename tests/{ => benchmark}/test_benchmark.py (100%)
 create mode 100644 tests/benchmark/test_pairwise.py

diff --git a/aepsych/benchmark/example_problems.py b/aepsych/benchmark/example_problems.py
index 9cc64435d..d31726ed8 100644
--- a/aepsych/benchmark/example_problems.py
+++ b/aepsych/benchmark/example_problems.py
@@ -1,15 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
 import os
+
 import numpy as np
 import torch
-from aepsych.models import GPClassificationModel
+from aepsych.benchmark.problem import LSEProblemWithEdgeLogging
 from aepsych.benchmark.test_functions import (
-    modified_hartmann6,
     discrim_highdim,
+    modified_hartmann6,
     novel_discrimination_testfun,
 )
-from aepsych.benchmark.problem import LSEProblemWithEdgeLogging
-
+from aepsych.models import GPClassificationModel
+from scipy.stats import norm
 """The DiscrimLowDim, DiscrimHighDim, ContrastSensitivity6d, and Hartmann6Binary
 classes are copied from bernoulli_lse github repository
 (https://github.com/facebookresearch/bernoulli_lse)
@@ -84,7 +85,6 @@ class ContrastSensitivity6d(LSEProblemWithEdgeLogging):
 
     def __init__(self, thresholds=None):
         thresholds = 0.75 if thresholds is None else thresholds
-        super().__init__(thresholds=thresholds)
 
         # Load the data
         self.data = np.loadtxt(
@@ -108,6 +108,87 @@ def __init__(self, thresholds=None):
             y,
         )
 
+        super().__init__(thresholds=thresholds)
+
     def f(self, X: torch.Tensor) -> torch.Tensor:
         # clamp f to 0 since we expect p(x) to be lower-bounded at 0.5
         return torch.clamp(self.m.predict(torch.tensor(X))[0], min=0)
+
+
+class PairwiseDiscrimLowdim(LSEProblemWithEdgeLogging):
+    name = "pairwise_discrim_lowdim"
+    bounds = torch.tensor([[-1, 1], [-1, 1], [-1, 1], [-1, 1]], dtype=torch.double).T
+
+    def __init__(self, thresholds=None):
+        if thresholds is None:
+            jnds = np.arange(-4, 5)
+            thresholds = np.round(norm.cdf(jnds), 3).tolist()
+        super().__init__(thresholds=thresholds)
+
+    def f(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == 4
+        f1 = novel_discrimination_testfun(x[..., :2])
+        f2 = novel_discrimination_testfun(x[..., 2:])
+        return (f1 - f2).to(torch.double)
+
+
+class PairwiseDiscrimHighdim(LSEProblemWithEdgeLogging):
+    name = "pairwise_discrim_highdim"
+    bounds = torch.tensor(
+        [
+            [-1, 1],
+            [-1, 1],
+            [0.5, 1.5],
+            [0.05, 0.15],
+            [0.05, 0.2],
+            [0, 0.9],
+            [0, 3.14 / 2],
+            [0.5, 2],
+        ]
+        * 2,
+        dtype=torch.double,
+    ).T
+
+    def __init__(self, thresholds=None):
+        if thresholds is None:
+            jnds = np.arange(-4, 5)
+            thresholds = np.round(norm.cdf(jnds), 3).tolist()
+        super().__init__(thresholds=thresholds)
+
+    def f(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == 16
+        f1 = discrim_highdim(x[..., :8])
+        f2 = discrim_highdim(x[..., 8:])
+        return torch.tensor(f1 - f2, dtype=torch.double)
+
+
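+# NOTE: Each pairwise problem concatenates two stimuli x1 and x2 along the last
+# input dimension and uses the latent difference f(x1) - f(x2), so the response
+# probability is p(x1, x2) = Phi(f(x1) - f(x2)) and equals 0.5 whenever the two
+# stimuli are identical (this is verified in tests/benchmark/test_pairwise.py).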
+class PairwiseHartmann6Binary(LSEProblemWithEdgeLogging):
+    name = "pairwise_hartmann6_binary"
+    bounds = torch.stack(
+        (
+            torch.zeros(12, dtype=torch.double),
+            torch.ones(12, dtype=torch.double),
+        )
+    )
+
+    def __init__(self, thresholds=None):
+        if thresholds is None:
+            jnds = np.arange(-4, 5)
+            thresholds = np.round(norm.cdf(jnds), 3).tolist()
+        super().__init__(thresholds=thresholds)
+
+    def f(self, X: torch.Tensor) -> torch.Tensor:
+        assert X.shape[-1] == 12
+
+        def latent_f(X1):
+            y = torch.tensor([modified_hartmann6(x) for x in X1], dtype=torch.double)
+            f = 3 * y - 2.0
+            return f
+
+        f1 = latent_f(X[..., :6])
+        f2 = latent_f(X[..., 6:])
+        return (f1 - f2).to(torch.double)
diff --git a/aepsych/benchmark/problem.py b/aepsych/benchmark/problem.py
index 299a56290..5510924f3 100644
--- a/aepsych/benchmark/problem.py
+++ b/aepsych/benchmark/problem.py
@@ -213,6 +213,12 @@ def __init__(self, thresholds: Union[float, List]):
         thresholds = [thresholds] if isinstance(thresholds, float) else thresholds
         self.thresholds = np.array(thresholds)
 
+        # Integer thresholds in latent (f) space, spanning the observed range
+        # of f_true and clamped to [-12, 12].
+        min_f = max(self.f_true.min(), -12)
+        max_f = min(self.f_true.max(), 12)
+        self.f_thresholds = np.arange(int(np.ceil(min_f)), int(np.floor(max_f)) + 1)
+
     @property
     def metadata(self) -> Dict[str, Any]:
         """A dictionary of metadata passed to the Benchmark to be logged. Each key will become a column in the
@@ -246,6 +252,12 @@ def true_below_threshold(self) -> np.ndarray:
         return (
             self.p(self.eval_grid).reshape(1, -1) <= self.thresholds.reshape(-1, 1)
         ).astype(float)
+
+    @cached_property
+    def true_f_below_threshold(self) -> np.ndarray:
+        """Whether the true latent f is below each integer latent threshold."""
+        return self.f_true.reshape(1, -1) < self.f_thresholds.reshape(-1, 1)
+
     def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, float]:
         """Evaluate the model with respect to this problem.
 
@@ -286,14 +298,29 @@ def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, floa
         # Predict p(below threshold) at test points
         brier_p_below_thresh = np.mean(2 * np.square(true_p_l - p_l), axis=1)
 
-        # Classification error
-        misclass_on_thresh = np.mean(
+        # Classification error in probability space
+        misclass_on_p_thresh = np.mean(
             p_l * (1 - true_p_l) + (1 - p_l) * true_p_l, axis=1
         )
 
+        # Classification error in units of the latent space f(x) (or g(x1, x2))
+        f, _ = model.predict(self.eval_grid)
+        pred_f = f.numpy().reshape(1, -1) < self.f_thresholds.reshape(-1, 1)
+
+        misclass_on_f_thresh = np.mean(
+            pred_f * (1 - self.true_f_below_threshold)
+            + (1 - pred_f) * self.true_f_below_threshold,
+            axis=1,
+        )
+
         for i_threshold, threshold in enumerate(self.thresholds):
             metrics[f"brier_p_below_{threshold}"] = brier_p_below_thresh[i_threshold]
-            metrics[f"misclass_on_thresh_{threshold}"] = misclass_on_thresh[i_threshold]
+            metrics[f"misclass_on_thresh_{threshold}"] = misclass_on_p_thresh[i_threshold]
+
+        for i_f_threshold, f_threshold in enumerate(self.f_thresholds):
+            metrics[f"misclass_on_f_thresh_{int(f_threshold)}"] = misclass_on_f_thresh[
+                i_f_threshold
+            ]
 
         return metrics
 
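For intuition, here is a minimal sketch of the latent-space misclassification
metric added above, using toy arrays in place of self.f_true and the model's
posterior mean on the evaluation grid (values are illustrative only):

    import numpy as np

    f_true = np.array([-1.3, -0.2, 0.4, 2.1])  # stand-in for true f on the grid
    f_pred = np.array([-1.0, 0.1, 0.3, 1.8])   # stand-in for the posterior mean

    # Integer latent thresholds within the observed range, clamped to [-12, 12]
    min_f = max(f_true.min(), -12)
    max_f = min(f_true.max(), 12)
    f_thresholds = np.arange(int(np.ceil(min_f)), int(np.floor(max_f)) + 1)  # [-1, 0, 1, 2]

    true_below = f_true.reshape(1, -1) < f_thresholds.reshape(-1, 1)
    pred_below = f_pred.reshape(1, -1) < f_thresholds.reshape(-1, 1)

    # Per-threshold misclassification rate: fraction of grid points where the
    # model's below/above-threshold call disagrees with the ground truth.
    misclass = np.mean(
        pred_below * (1 - true_below) + (1 - pred_below) * true_below, axis=1
    )
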
diff --git a/tests/benchmark/__init__.py b/tests/benchmark/__init__.py
new file mode 100644
index 000000000..8b2df349c
--- /dev/null
+++ b/tests/benchmark/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/test_benchmark.py b/tests/benchmark/test_benchmark.py
similarity index 100%
rename from tests/test_benchmark.py
rename to tests/benchmark/test_benchmark.py
diff --git a/tests/benchmark/test_pairwise.py b/tests/benchmark/test_pairwise.py
new file mode 100644
index 000000000..75d8541ff
--- /dev/null
+++ b/tests/benchmark/test_pairwise.py
@@ -0,0 +1,80 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+import logging
+import unittest
+
+import aepsych.utils_logging as utils_logging
+import numpy as np
+import torch
+from aepsych.benchmark import Benchmark, example_problems
+from aepsych.utils import make_scaled_sobol
+from scipy.stats import norm
+
+logger = utils_logging.getLogger(logging.ERROR)
+
+
+class TestPairwise(unittest.TestCase):
+    def setUp(self) -> None:
+        self.problem_map = {
+            "PairwiseHartmann6Binary": "Hartmann6Binary",
+            "PairwiseDiscrimHighdim": "DiscrimHighDim",
+            "PairwiseDiscrimLowdim": "DiscrimLowDim",
+        }
+
+        self.problems = {}
+        for pairwise_name, single_name in self.problem_map.items():
+            self.problems[pairwise_name] = getattr(
+                example_problems, pairwise_name
+            )()
+            self.problems[single_name] = getattr(example_problems, single_name)()
+
+    def test_pairwise_probability(self) -> None:
+        for pairwise_name, single_name in self.problem_map.items():
+            pairwise_problem = self.problems[pairwise_name]
+            single_problem = self.problems[single_name]
+
+            x1, x2 = make_scaled_sobol(single_problem.lb, single_problem.ub, 2)
+            pairwise_x = torch.concat([x1, x2]).unsqueeze(0)
+
+            pairwise_p = pairwise_problem.p(pairwise_x)
+            f1 = single_problem.f(x1.unsqueeze(0))
+            f2 = single_problem.f(x2.unsqueeze(0))
+            single_p = norm.cdf(f1 - f2)
+            self.assertTrue(np.allclose(pairwise_p, single_p))
+
+    def test_pairwise_benchmark_smoketest(self) -> None:
+        """A smoke test to make sure the models and benchmark are set up correctly."""
+        config = {
+            "common": {
+                "stimuli_per_trial": 2,
+                "outcome_types": "binary",
+                "strategy_names": "[init_strat, opt_strat]",
+            },
+            "init_strat": {"n_trials": 10, "generator": "SobolGenerator"},
+            "opt_strat": {
+                "model": "GPClassificationModel",
+                "generator": "SobolGenerator",
+                "n_trials": 10,
+                "refit_every": 10,
+            },
+            "GPClassificationModel": {
+                "inducing_size": 100,
+                "mean_covar_factory": "default_mean_covar_factory",
+                "inducing_point_method": "auto",
+            },
+        }
+
+        pairwise_problems = [self.problems[name] for name in self.problem_map]
+
+        bench = Benchmark(
+            problems=pairwise_problems,
+            configs=config,
+            n_reps=1,
+        )
+        bench.run_benchmarks()
+
+
+if __name__ == "__main__":
+    unittest.main()
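For reference, the default thresholds shared by the three new pairwise problem
classes are response probabilities at integer just-noticeable differences
(JNDs) in latent space; a minimal sketch mirroring the __init__ logic above:

    import numpy as np
    from scipy.stats import norm

    jnds = np.arange(-4, 5)
    thresholds = np.round(norm.cdf(jnds), 3).tolist()
    # -> [0.0, 0.001, 0.023, 0.159, 0.5, 0.841, 0.977, 0.999, 1.0]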