Skip to content

Commit

Permalink
Fix indexing bug with infeasible experiments for IDAKLUSolver (#4541)
Browse files Browse the repository at this point in the history
* fix interp indexing

Co-Authored-By: Pip Liggins <[email protected]>

* simplify indexing

---------

Co-authored-by: Pip Liggins <[email protected]>
  • Loading branch information
MarcBerliner and pipliggins authored Oct 23, 2024
1 parent 164f71e commit 9560875
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 21 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))

## Bug Fixes

- Fixed bug in post-processing solutions with infeasible experiments using the (`IDAKLUSolver`). ([#4541](https://github.com/pybamm-team/PyBaMM/pull/4541))
- Disabled IREE on MacOS due to compatibility issues and added the CasADI
path to the environment to resolve issues on MacOS and Linux. Windows
users may still experience issues with interpolation. ([#4528](https://github.com/pybamm-team/PyBaMM/pull/4528))
Expand Down
9 changes: 9 additions & 0 deletions src/pybamm/solvers/c_solvers/idaklu/observe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class TimeSeriesInterpolator {
) {
for (size_t i = 0; i < ts_data_np.size(); i++) {
const auto& t_data = ts_data_np[i].unchecked<1>();
// Continue if there is no data
if (t_data.size() == 0) {
continue;
}

const realtype t_data_final = t_data(t_data.size() - 1);
realtype t_interp_next = t_interp(i_interp);
// Continue if the next interpolation point is beyond the final data point
Expand Down Expand Up @@ -227,6 +232,10 @@ class TimeSeriesProcessor {
int i_entries = 0;
for (size_t i = 0; i < ts.size(); i++) {
const auto& t = ts[i].unchecked<1>();
// Continue if there is no data
if (t.size() == 0) {
continue;
}
const auto& y = ys[i].unchecked<2>();
const auto input = inputs[i].data();
const auto func = *funcs[i];
Expand Down
31 changes: 15 additions & 16 deletions src/pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,22 @@ def _setup_cpp_inputs(self, t, full_range):
ys = self.all_ys
yps = self.all_yps
inputs = self.all_inputs_casadi
# Find the indices of the time points to observe
if full_range:
idxs = range(len(ts))
else:
idxs = _find_ts_indices(ts, t)

if isinstance(idxs, list):
# Extract the time points and inputs
ts = [ts[idx] for idx in idxs]
ys = [ys[idx] for idx in idxs]
if self.hermite_interpolation:
yps = [yps[idx] for idx in idxs]
inputs = [self.all_inputs_casadi[idx] for idx in idxs]
# Remove all empty ts
idxs = np.where([ti.size > 0 for ti in ts])[0]

# Find the indices of the time points to observe
if not full_range:
ts_nonempty = [ts[idx] for idx in idxs]
idxs_subset = _find_ts_indices(ts_nonempty, t)
idxs = idxs[idxs_subset]

# Extract the time points and inputs
ts = [ts[idx] for idx in idxs]
ys = [ys[idx] for idx in idxs]
if self.hermite_interpolation:
yps = [yps[idx] for idx in idxs]
inputs = [self.all_inputs_casadi[idx] for idx in idxs]

is_f_contiguous = _is_f_contiguous(ys)

Expand Down Expand Up @@ -977,8 +980,4 @@ def _find_ts_indices(ts, t):
if (t[-1] > ts[-1][-1]) and (len(indices) == 0 or indices[-1] != len(ts) - 1):
indices.append(len(ts) - 1)

if len(indices) == len(ts):
# All indices are included
return range(len(ts))

return indices
12 changes: 8 additions & 4 deletions src/pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,11 +580,15 @@ def _update_variable(self, variable):
# Iterate through all models, some may be in the list several times and
# therefore only get set up once
vars_casadi = []
for i, (model, ys, inputs, var_pybamm) in enumerate(
zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm)
for i, (model, ts, ys, inputs, var_pybamm) in enumerate(
zip(self.all_models, self.all_ts, self.all_ys, self.all_inputs, vars_pybamm)
):
if ys.size == 0 and var_pybamm.has_symbol_of_classes(
pybamm.expression_tree.state_vector.StateVector
if (
ys.size == 0
and var_pybamm.has_symbol_of_classes(
pybamm.expression_tree.state_vector.StateVector
)
and not ts.size == 0
):
raise KeyError(
f"Cannot process variable '{variable}' as it was not part of the "
Expand Down

0 comments on commit 9560875

Please sign in to comment.