Skip to content

Commit

Permalink
fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Dec 11, 2024
1 parent e86ce0e commit 1dfc70d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ensemble_md/cli/run_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def main():
states = copy.deepcopy(states_)
weights = copy.deepcopy(weights_)
counts = copy.deepcopy(counts_)
swap_pattern, swap_list = REXEE.get_swapping_pattern(dhdl_files, states_) # swap_list will only be used for modify_coords # noqa: E501
swap_pattern, swap_list = REXEE.get_swapping_pattern(dhdl_files, states_, i) # swap_list will only be used for modify_coords # noqa: E501

# 3-3. Perform weight correction/weight combination
if wl_delta != [None for i in range(REXEE.n_sim)]: # weight-updating
Expand Down
17 changes: 9 additions & 8 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def extract_final_log_info(self, log_files):
return wl_delta, weights, counts

@staticmethod
def identify_swappable_pairs(self, states, state_ranges, neighbor_exchange, add_swappables=None, iteration=None):
def identify_swappable_pairs(states, state_ranges, neighbor_exchange, forced_swap, add_swappables=None, iteration=None):
"""
Identifies swappable pairs. By definition, a pair of simulation is considered swappable only if
their last sampled states are in the alchemical ranges of both simulations. This is required
Expand Down Expand Up @@ -841,15 +841,16 @@ def identify_swappable_pairs(self, states, state_ranges, neighbor_exchange, add_
"""
n_sim = len(states)
sim_idx = list(range(n_sim))
if self.proposal == 'forced_swap':
if forced_swap is True:
if iteration % 2 == 0: # Swap up for self.n_sim - 1 swaps
swappables = []
for n in np.arange(0, self.n_sim-2, 2):
for n in np.arange(0, n_sim-2, 2):
swappables.append([n, n+1])
else: # and then swap down for self.n_sim - 1 swaps and repeat
swappables = []
for n in np.arange(1, self.n_sim-2, 2):
for n in np.arange(1, n_sim-2, 2):
swappables.append([n, n+1])
print(f'swappables: {swappables}')
else:
all_pairs = list(combinations(sim_idx, 2))

Expand Down Expand Up @@ -954,7 +955,7 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):
swap_pattern = list(range(self.n_sim)) # Can be regarded as the indices of DHDL files/configurations
state_ranges = copy.deepcopy(self.state_ranges)
# states_copy = copy.deepcopy(states) # only for re-identifying swappable pairs given updated state_ranges --> was needed for the multiple exchange proposal scheme # noqa: E501
swappables = ReplicaExchangeEE.identify_swappable_pairs(states, state_ranges, self.proposal == 'neighboring', self.add_swappables, iteration) # noqa: E501
swappables = ReplicaExchangeEE.identify_swappable_pairs(states, state_ranges, self.proposal == 'neighboring', self.proposal == 'forced_swap', self.add_swappables, iteration) # noqa: E501

# Note that if there is only 1 swappable pair, then it will still be the only swappable pair
# after an attempted swap is accepted. Therefore, there is no need to perform multiple swaps or re-identify
Expand All @@ -963,7 +964,6 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):
if n_ex > 1:
n_ex = 1 # n_ex is set back to 1 since there is only 1 swappable pair.

print(f"Swappable pairs: {swappables}")
for i in range(n_ex):
# Update the list of swappable pairs starting from the 2nd attempted swap for the exhaustive swap method.
if (self.proposal == 'exhaustive' or self.proposal == 'forced_swap') and i >= 1:
Expand All @@ -982,7 +982,7 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):
if self.proposal == 'exhaustive':
n_ex_exhaustive += 1

if self.propoal == 'forced_swap':
if self.proposal == 'forced_swap':
swap = swappables[0]
else:
swap = ReplicaExchangeEE.propose_swap(swappables)
Expand All @@ -997,7 +997,7 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):
print(f'A swap ({i + 1}/{n_ex}) is proposed between the configurations of Simulation {swap[0]} (state {states[swap[0]]}) and Simulation {swap[1]} (state {states[swap[1]]}) ...') # noqa: E501

if self.proposal == 'forced_swap':
index = self._deter_swap_index(swap, dhdl_files, shifts, self.add_swappables)
index = self._deter_swap_index(swap, dhdl_files, self.add_swappables)
swap_index.append(index)

if self.modify_coords_fn is not None:
Expand Down Expand Up @@ -1181,6 +1181,7 @@ def _deter_swap_index(self, swap, dhdl_files, add_swappables):
A, B = swappable
swappable_global_states.append(A)
swappable_global_states.append(B)
print(swappable_global_states)

swap_index = []
for i in range(1):
Expand Down

0 comments on commit 1dfc70d

Please sign in to comment.