Skip to content

Commit

Permalink
Fix forced_random l state assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Jan 9, 2025
1 parent 163d4f4 commit 81f8482
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def identify_swappable_pairs(self, states, state_ranges, dhdl_files, iteration=N
"""
n_sim = len(states)
sim_idx = list(range(n_sim))
states_for_swap = {}
states_for_swap = []
if self.proposal == 'forced_swap':
potential_swappables = []
if iteration % 2 == 0: # Swap up for self.n_sim - 1 swaps
Expand All @@ -842,26 +842,24 @@ def identify_swappable_pairs(self, states, state_ranges, dhdl_files, iteration=N
else: # and then swap down for self.n_sim - 1 swaps and repeat
for n in np.arange(1, n_sim-1, 2):
potential_swappables.append([n, n+1])
swappables = []
swap_index = {}
swappables, swap_index = [], []
for swap in potential_swappables:
index, state = self._deter_swap_index(swap, dhdl_files, self.add_swappables)
if len(index) == 2:
swap_index[swap] = index
swap_index.append(index)
swappables.append(swap)
states_for_swap[swap] = state
states_for_swap.append(state)
elif self.proposal == 'forced_random':
potential_swappables = []
for n in np.arange(0, n_sim-1,1):
potential_swappables.append([n, n+1])
swappables = []
swap_index = {}
swappables, swap_index = [], []
for swap in potential_swappables:
index, state = self._deter_swap_index(swap, dhdl_files, self.add_swappables)
if len(index) == 2:
swap_index[swap] = index
swap_index.append(index)
swappables.append(swap)
states_for_swap[swap] = state
states_for_swap.append(state)
else:
all_pairs = list(combinations(sim_idx, 2))

Expand Down Expand Up @@ -972,6 +970,7 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):
state_ranges = copy.deepcopy(self.state_ranges)
# states_copy = copy.deepcopy(states) # only for re-identifying swappable pairs given updated state_ranges --> was needed for the multiple exchange proposal scheme # noqa: E501
swappables, swap_index, states_for_swap = self.identify_swappable_pairs(states, state_ranges, dhdl_files, iteration) # noqa: E501
all_swappables = swappables.copy()

# Note that if there is only 1 swappable pair, then it will still be the only swappable pair
# after an attempted swap is accepted. Therefore, there is no need to perform multiple swaps or re-identify
Expand Down Expand Up @@ -1037,11 +1036,13 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):

if swap_bool is True:
swap_list.append(swap)
if self.proposal == "forced_swap":
swap_index_accept.append(swap_index[swap])
states_i = states_for_swap[swap]
states_modified[swap[0]] = states_i[0]
states_modified[swap[1]] = states_i[1]
if self.proposal == "forced_swap" or self.proposal == "forced_random":
for p, p_swap in enumerate(all_swappables):
if p_swap == swap:
break
swap_index_accept.append(swap_index[p])
states_modified[swap[0]] = states_for_swap[p][0]
states_modified[swap[1]] = states_for_swap[p][1]
else:
swap_index_accept.append([-1, -1])
# Determine which
Expand Down

0 comments on commit 81f8482

Please sign in to comment.