From b5d457d8992badf80590760aaa7538c8c9edefbe Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sun, 31 Dec 2023 13:10:41 +0530 Subject: [PATCH 1/2] refactor: load eq_param_index of BundleSolvers correctly --- neurodiffeq/solvers_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/neurodiffeq/solvers_utils.py b/neurodiffeq/solvers_utils.py index 36e7ba9..ca97712 100644 --- a/neurodiffeq/solvers_utils.py +++ b/neurodiffeq/solvers_utils.py @@ -530,7 +530,9 @@ def load(cls, t_min=t_min, t_max=t_max, theta_min=tuple(load_dict['solver'].r_min[1:]), - theta_max=tuple(load_dict['solver'].r_max[1:])) + theta_max=tuple(load_dict['solver'].r_max[1:]), + eq_param_index=load_dict['solver'].eq_param_index + ) if best_nets != None: solver.best_nets = best_nets From b90ab428a3cf24455e4947ce0c973a7b191886ed Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sun, 31 Dec 2023 13:17:26 +0530 Subject: [PATCH 2/2] fix: use self.eq_param_index inside _diff_eqs_wrapper --- neurodiffeq/solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neurodiffeq/solvers.py b/neurodiffeq/solvers.py index 160552d..e950655 100644 --- a/neurodiffeq/solvers.py +++ b/neurodiffeq/solvers.py @@ -1357,7 +1357,7 @@ def __init__(self, ode_system, conditions, t_min, t_max, def _diff_eqs_wrapper(*variables): funcs_and_coords = variables[:N_FUNCTIONS + N_COORDS] - eq_params = tuple(variables[idx] for idx in eq_param_index) + eq_params = tuple(variables[idx] for idx in self.eq_param_index) return ode_system(*funcs_and_coords, *eq_params) super(BundleSolver1D, self).__init__(