diff --git a/ensemble_md/replica_exchange_EE.py b/ensemble_md/replica_exchange_EE.py index 46817c8..65606fd 100644 --- a/ensemble_md/replica_exchange_EE.py +++ b/ensemble_md/replica_exchange_EE.py @@ -833,7 +833,7 @@ def identify_swappable_pairs(self, states, state_ranges, dhdl_files, iteration=N """ n_sim = len(states) sim_idx = list(range(n_sim)) - states_for_swap = [] + states_for_swap = {} if self.proposal == 'forced_swap': potential_swappables = [] if iteration % 2 == 0: # Swap up for self.n_sim - 1 swaps @@ -842,24 +842,26 @@ def identify_swappable_pairs(self, states, state_ranges, dhdl_files, iteration=N else: # and then swap down for self.n_sim - 1 swaps and repeat for n in np.arange(1, n_sim-1, 2): potential_swappables.append([n, n+1]) - swap_index, swappables = [], [] + swappables = [] + swap_index = {} for swap in potential_swappables: index, state = self._deter_swap_index(swap, dhdl_files, self.add_swappables) if len(index) == 2: - swap_index.append(index) + swap_index[swap] = index swappables.append(swap) - states_for_swap.append(state) + states_for_swap[swap] = state elif self.proposal == 'forced_random': potential_swappables = [] for n in np.arange(0, n_sim-1,1): potential_swappables.append([n, n+1]) - swap_index, swappables = [], [] + swappables = [] + swap_index = {} for swap in potential_swappables: index, state = self._deter_swap_index(swap, dhdl_files, self.add_swappables) if len(index) == 2: - swap_index.append(index) + swap_index[swap] = index swappables.append(swap) - states_for_swap.append(state) + states_for_swap[swap] = state else: all_pairs = list(combinations(sim_idx, 2)) @@ -1036,9 +1038,10 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None): if swap_bool is True: swap_list.append(swap) if self.proposal == "forced_swap": - swap_index_accept.append(swap_index[i]) - states_modified[swap[0]] = states_for_swap[i][0] - states_modified[swap[1]] = states_for_swap[i][1] + swap_index_accept.append(swap_index[swap]) + states_i = states_for_swap[swap] + states_modified[swap[0]] = states_i[0] + states_modified[swap[1]] = states_i[1] else: swap_index_accept.append([-1, -1]) # Determine which