Skip to content

Commit

Permalink
Backport: Get ensembles by id in dark storage
Browse files Browse the repository at this point in the history
Backport of e33136f
  • Loading branch information
yngve-sk authored Sep 6, 2024
1 parent 4d108af commit c15ab0b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 28 deletions.
28 changes: 14 additions & 14 deletions src/ert/gui/tools/plot/plot_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def __init__(self) -> None:
self._all_ensembles: Optional[List[EnsembleObject]] = None
self._timeout = 120

def _get_ensemble(self, name: str) -> Optional[EnsembleObject]:
def _get_ensemble_by_id(self, id: str) -> Optional[EnsembleObject]:
for ensemble in self.get_all_ensembles():
if ensemble.name == name:
if ensemble.id == id:
return ensemble
return None

Expand Down Expand Up @@ -149,15 +149,15 @@ def all_data_type_keys(self) -> List[PlotApiKeyDefinition]:

return list(all_keys.values())

def data_for_key(self, ensemble_name: str, key: str) -> pd.DataFrame:
def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame:
"""Returns a pandas DataFrame with the datapoints for a given key for a given
ensemble. The row index is the realization number, and the columns are an index
over the indexes/dates"""

if key.startswith("LOG10_"):
key = key[6:]

ensemble = self._get_ensemble(ensemble_name)
ensemble = self._get_ensemble_by_id(ensemble_id)
if not ensemble:
return pd.DataFrame()

Expand All @@ -182,15 +182,15 @@ def data_for_key(self, ensemble_name: str, key: str) -> pd.DataFrame:
except ValueError:
return df

def observations_for_key(self, ensemble_names: List[str], key: str) -> pd.DataFrame:
def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFrame:
"""Returns a pandas DataFrame with the datapoints for a given observation key
for a given ensembles. The row index is the realization number, and the column index
is a multi-index with (obs_key, index/date, obs_index), where index/date is
used to relate the observation to the data point it relates to, and obs_index
is the index for the observation itself"""
all_observations = pd.DataFrame()
for ensemble_name in ensemble_names:
ensemble = self._get_ensemble(ensemble_name)
for ensemble_id in ensemble_ids:
ensemble = self._get_ensemble_by_id(ensemble_id)
if not ensemble:
continue

Expand All @@ -206,7 +206,7 @@ def observations_for_key(self, ensemble_names: List[str], key: str) -> pd.DataFr
obs = response.json()[0]
except (KeyError, IndexError, JSONDecodeError) as e:
raise httpx.RequestError(
f"Observation schema might have changed key={key}, ensemble_name={ensemble_name}, e={e}"
f"Observation schema might have changed key={key}, ensemble_name={ensemble.name}, e={e}"
) from e
try:
int(obs["x_axis"][0])
Expand All @@ -226,19 +226,19 @@ def observations_for_key(self, ensemble_names: List[str], key: str) -> pd.DataFr

return all_observations.T

def history_data(self, key: str, ensembles: Optional[List[str]]) -> pd.DataFrame:
def history_data(self, key: str, ensemble_ids: Optional[List[str]]) -> pd.DataFrame:
"""Returns a pandas DataFrame with the data points for the history for a
given data key, if any. The row index is the index/date and the column
index is the key."""
if ensembles:
for ensemble in ensembles:
if ensemble_ids:
for ensemble_id in ensemble_ids:
if ":" in key:
head, tail = key.split(":", 2)
history_key = f"{head}H:{tail}"
else:
history_key = f"{key}H"

df = self.data_for_key(ensemble, history_key)
df = self.data_for_key(ensemble_id, history_key)

if not df.empty:
df = df.T
Expand All @@ -253,9 +253,9 @@ def history_data(self, key: str, ensembles: Optional[List[str]]) -> pd.DataFrame
return pd.DataFrame()

def std_dev_for_parameter(
self, key: str, ensemble_name: str, z: int
self, key: str, ensemble_id: str, z: int
) -> npt.NDArray[np.float32]:
ensemble = self._get_ensemble(ensemble_name)
ensemble = self._get_ensemble_by_id(ensemble_id)
if not ensemble:
return np.array([])

Expand Down
8 changes: 4 additions & 4 deletions src/ert/gui/tools/plot/plot_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None:
for ensemble in selected_ensembles:
try:
ensemble_to_data_map[ensemble] = self._api.data_for_key(
ensemble.name, key
ensemble.id, key
)
except (RequestError, TimeoutError) as e:
logger.exception(e)
Expand All @@ -198,7 +198,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None:
if key_def.observations and selected_ensembles:
try:
observations = self._api.observations_for_key(
[ensembles.name for ensembles in selected_ensembles], key
[ensembles.id for ensembles in selected_ensembles], key
)
except (RequestError, TimeoutError) as e:
logger.exception(e)
Expand All @@ -218,7 +218,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None:
for ensemble in selected_ensembles:
try:
std_dev_images[ensemble.name] = self._api.std_dev_for_parameter(
key, ensemble.name, layer
key, ensemble.id, layer
)
except (RequestError, TimeoutError) as e:
logger.exception(e)
Expand All @@ -237,7 +237,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None:
try:
plot_context.history_data = self._api.history_data(
key,
[e.name for e in plot_context.ensembles()],
[e.id for e in plot_context.ensembles()],
)

except (RequestError, TimeoutError) as e:
Expand Down
31 changes: 21 additions & 10 deletions tests/unit_tests/gui/tools/plot/test_plot_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,13 @@ def test_can_load_data_and_observations(api):
"observations": True,
},
}

ensemble_name = "default_0"
ensemble = next(x for x in api.get_all_ensembles() if x.name == "default_0")
for key, value in keys.items():
observations = value["observations"]
if observations:
obs_data = api.observations_for_key([ensemble_name], key)
obs_data = api.observations_for_key([ensemble.id], key)
assert not obs_data.empty
data = api.data_for_key(ensemble_name, key)
data = api.data_for_key(ensemble.id, key)
assert not data.empty


Expand All @@ -99,21 +98,32 @@ def test_all_data_type_keys(api):


def test_load_history_data(api):
df = api.history_data(ensembles=["default_0"], key="FOPR")
ens_id = next(ens.id for ens in api.get_all_ensembles() if ens.name == "default_0")
df = api.history_data(ensemble_ids=[ens_id], key="FOPR")
assert_frame_equal(
df, pd.DataFrame({1: [0.2, 0.2, 1.2], 3: [1.0, 1.1, 1.2], 4: [1.0, 1.1, 1.3]})
)


def test_load_history_data_searches_until_history_found(api):
df = api.history_data(ensembles=["no-history", "default_0"], key="FOPR")
ensemble_ids = [
ens.id
for ens in api.get_all_ensembles()
if ens.name in ["no-history", "default_0"]
]
df = api.history_data(ensemble_ids=ensemble_ids, key="FOPR")
assert_frame_equal(
df, pd.DataFrame({1: [0.2, 0.2, 1.2], 3: [1.0, 1.1, 1.2], 4: [1.0, 1.1, 1.3]})
)


def test_load_history_data_returns_empty_frame_if_no_history(api):
df = api.history_data(ensembles=["no-history", "still-no-history"], key="FOPR")
ensemble_ids = [
ens.id
for ens in api.get_all_ensembles()
if ens.name in ["no-history", "still-no-history"]
]
df = api.history_data(ensemble_ids=ensemble_ids, key="FOPR")
assert_frame_equal(df, pd.DataFrame())


Expand All @@ -129,9 +139,10 @@ def test_plot_api_request_errors_all_data_type_keys(api, mocker):


def test_plot_api_request_errors(api):
ensemble_name = "default_0"
ensemble = next(x for x in api.get_all_ensembles() if x.name == "default_0")

with pytest.raises(httpx.RequestError):
api.observations_for_key([ensemble_name], "should_not_be_there")
api.observations_for_key([ensemble.id], "should_not_be_there")

with pytest.raises(httpx.RequestError):
api.data_for_key(ensemble_name, "should_not_be_there")
api.data_for_key(ensemble.id, "should_not_be_there")

0 comments on commit c15ab0b

Please sign in to comment.