diff --git a/src/ert/gui/tools/plot/plot_api.py b/src/ert/gui/tools/plot/plot_api.py index e920121f8bf..c2695aa3868 100644 --- a/src/ert/gui/tools/plot/plot_api.py +++ b/src/ert/gui/tools/plot/plot_api.py @@ -42,9 +42,9 @@ def __init__(self) -> None: self._all_ensembles: Optional[List[EnsembleObject]] = None self._timeout = 120 - def _get_ensemble(self, name: str) -> Optional[EnsembleObject]: + def _get_ensemble_by_id(self, id: str) -> Optional[EnsembleObject]: for ensemble in self.get_all_ensembles(): - if ensemble.name == name: + if ensemble.id == id: return ensemble return None @@ -149,7 +149,7 @@ def all_data_type_keys(self) -> List[PlotApiKeyDefinition]: return list(all_keys.values()) - def data_for_key(self, ensemble_name: str, key: str) -> pd.DataFrame: + def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame: """Returns a pandas DataFrame with the datapoints for a given key for a given ensemble. The row index is the realization number, and the columns are an index over the indexes/dates""" @@ -157,7 +157,7 @@ def data_for_key(self, ensemble_name: str, key: str) -> pd.DataFrame: if key.startswith("LOG10_"): key = key[6:] - ensemble = self._get_ensemble(ensemble_name) + ensemble = self._get_ensemble_by_id(ensemble_id) if not ensemble: return pd.DataFrame() @@ -182,15 +182,15 @@ def data_for_key(self, ensemble_name: str, key: str) -> pd.DataFrame: except ValueError: return df - def observations_for_key(self, ensemble_names: List[str], key: str) -> pd.DataFrame: + def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFrame: """Returns a pandas DataFrame with the datapoints for a given observation key for a given ensembles. The row index is the realization number, and the column index is a multi-index with (obs_key, index/date, obs_index), where index/date is used to relate the observation to the data point it relates to, and obs_index is the index for the observation itself""" all_observations = pd.DataFrame() - for ensemble_name in ensemble_names: - ensemble = self._get_ensemble(ensemble_name) + for ensemble_id in ensemble_ids: + ensemble = self._get_ensemble_by_id(ensemble_id) if not ensemble: continue @@ -206,7 +206,7 @@ def observations_for_key(self, ensemble_names: List[str], key: str) -> pd.DataFr obs = response.json()[0] except (KeyError, IndexError, JSONDecodeError) as e: raise httpx.RequestError( - f"Observation schema might have changed key={key}, ensemble_name={ensemble_name}, e={e}" + f"Observation schema might have changed key={key}, ensemble_name={ensemble.name}, e={e}" ) from e try: int(obs["x_axis"][0]) @@ -226,19 +226,19 @@ def observations_for_key(self, ensemble_names: List[str], key: str) -> pd.DataFr return all_observations.T - def history_data(self, key: str, ensembles: Optional[List[str]]) -> pd.DataFrame: + def history_data(self, key: str, ensemble_ids: Optional[List[str]]) -> pd.DataFrame: """Returns a pandas DataFrame with the data points for the history for a given data key, if any. The row index is the index/date and the column index is the key.""" - if ensembles: - for ensemble in ensembles: + if ensemble_ids: + for ensemble_id in ensemble_ids: if ":" in key: head, tail = key.split(":", 2) history_key = f"{head}H:{tail}" else: history_key = f"{key}H" - df = self.data_for_key(ensemble, history_key) + df = self.data_for_key(ensemble_id, history_key) if not df.empty: df = df.T @@ -253,9 +253,9 @@ def history_data(self, key: str, ensembles: Optional[List[str]]) -> pd.DataFrame return pd.DataFrame() def std_dev_for_parameter( - self, key: str, ensemble_name: str, z: int + self, key: str, ensemble_id: str, z: int ) -> npt.NDArray[np.float32]: - ensemble = self._get_ensemble(ensemble_name) + ensemble = self._get_ensemble_by_id(ensemble_id) if not ensemble: return np.array([]) diff --git a/src/ert/gui/tools/plot/plot_window.py b/src/ert/gui/tools/plot/plot_window.py index 4537957f09e..50b1fd27959 100644 --- a/src/ert/gui/tools/plot/plot_window.py +++ b/src/ert/gui/tools/plot/plot_window.py @@ -188,7 +188,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None: for ensemble in selected_ensembles: try: ensemble_to_data_map[ensemble] = self._api.data_for_key( - ensemble.name, key + ensemble.id, key ) except (RequestError, TimeoutError) as e: logger.exception(e) @@ -198,7 +198,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None: if key_def.observations and selected_ensembles: try: observations = self._api.observations_for_key( - [ensembles.name for ensembles in selected_ensembles], key + [ensembles.id for ensembles in selected_ensembles], key ) except (RequestError, TimeoutError) as e: logger.exception(e) @@ -218,7 +218,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None: for ensemble in selected_ensembles: try: std_dev_images[ensemble.name] = self._api.std_dev_for_parameter( - key, ensemble.name, layer + key, ensemble.id, layer ) except (RequestError, TimeoutError) as e: logger.exception(e) @@ -237,7 +237,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None: try: plot_context.history_data = self._api.history_data( key, - [e.name for e in plot_context.ensembles()], + [e.id for e in plot_context.ensembles()], ) except (RequestError, TimeoutError) as e: diff --git a/tests/unit_tests/gui/tools/plot/test_plot_api.py b/tests/unit_tests/gui/tools/plot/test_plot_api.py index 864d07ecfbc..85f8b2b290e 100644 --- a/tests/unit_tests/gui/tools/plot/test_plot_api.py +++ b/tests/unit_tests/gui/tools/plot/test_plot_api.py @@ -74,14 +74,13 @@ def test_can_load_data_and_observations(api): "observations": True, }, } - - ensemble_name = "default_0" + ensemble = next(x for x in api.get_all_ensembles() if x.name == "default_0") for key, value in keys.items(): observations = value["observations"] if observations: - obs_data = api.observations_for_key([ensemble_name], key) + obs_data = api.observations_for_key([ensemble.id], key) assert not obs_data.empty - data = api.data_for_key(ensemble_name, key) + data = api.data_for_key(ensemble.id, key) assert not data.empty @@ -99,21 +98,32 @@ def test_all_data_type_keys(api): def test_load_history_data(api): - df = api.history_data(ensembles=["default_0"], key="FOPR") + ens_id = next(ens.id for ens in api.get_all_ensembles() if ens.name == "default_0") + df = api.history_data(ensemble_ids=[ens_id], key="FOPR") assert_frame_equal( df, pd.DataFrame({1: [0.2, 0.2, 1.2], 3: [1.0, 1.1, 1.2], 4: [1.0, 1.1, 1.3]}) ) def test_load_history_data_searches_until_history_found(api): - df = api.history_data(ensembles=["no-history", "default_0"], key="FOPR") + ensemble_ids = [ + ens.id + for ens in api.get_all_ensembles() + if ens.name in ["no-history", "default_0"] + ] + df = api.history_data(ensemble_ids=ensemble_ids, key="FOPR") assert_frame_equal( df, pd.DataFrame({1: [0.2, 0.2, 1.2], 3: [1.0, 1.1, 1.2], 4: [1.0, 1.1, 1.3]}) ) def test_load_history_data_returns_empty_frame_if_no_history(api): - df = api.history_data(ensembles=["no-history", "still-no-history"], key="FOPR") + ensemble_ids = [ + ens.id + for ens in api.get_all_ensembles() + if ens.name in ["no-history", "still-no-history"] + ] + df = api.history_data(ensemble_ids=ensemble_ids, key="FOPR") assert_frame_equal(df, pd.DataFrame()) @@ -129,9 +139,10 @@ def test_plot_api_request_errors_all_data_type_keys(api, mocker): def test_plot_api_request_errors(api): - ensemble_name = "default_0" + ensemble = next(x for x in api.get_all_ensembles() if x.name == "default_0") + with pytest.raises(httpx.RequestError): - api.observations_for_key([ensemble_name], "should_not_be_there") + api.observations_for_key([ensemble.id], "should_not_be_there") with pytest.raises(httpx.RequestError): - api.data_for_key(ensemble_name, "should_not_be_there") + api.data_for_key(ensemble.id, "should_not_be_there")