From 63992dd02e4468cee75939604acf0396eba94046 Mon Sep 17 00:00:00 2001 From: maffettone Date: Fri, 12 Apr 2024 18:56:10 -0700 Subject: [PATCH] fix: return report dict --- pdf_agents/scientific_value.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pdf_agents/scientific_value.py b/pdf_agents/scientific_value.py index d73da03..64e669c 100644 --- a/pdf_agents/scientific_value.py +++ b/pdf_agents/scientific_value.py @@ -72,7 +72,7 @@ def __init__( bounds: torch.Tensor, device: torch.device = None, num_restarts: int = 10, - raw_samples: int = 20, + raw_samples: int = 128, observable_distance_function: Optional[Callable] = None, ucb_beta=1.0, **kwargs @@ -122,7 +122,7 @@ def tell(self, x, y): def report(self): value = self._value_function(np.array(self.independent_cache), np.array(self.observable_cache)) - dict(latest_data=self.tell_cache[-1], cache_len=len(self.independent_cache), latest_value=value[-1]) + return dict(latest_data=self.tell_cache[-1], cache_len=len(self.independent_cache), latest_value=value[-1]) def ask(self, batch_size: int = 1): value = self._value_function(np.array(self.independent_cache), np.array(self.observable_cache))