Skip to content

Commit

Permalink
Fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Jan 9, 2025
1 parent 81f8482 commit db5f1e0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions ensemble_md/cli/run_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def main():
try:
if rank == 0:
for j in range(len(swap_list)):
if not os.path.exists(f'{REXEE.working_dir}/sim_{swap_list[j][0]}/iteration_{i-1}/confout_backup.gro') and not os.path.exists(f'{REXEE.working_dir}/sim_{swap_list[j][0]}/iteration_{i-1}/confout_backup.gro'):
if not os.path.exists(f'{REXEE.working_dir}/sim_{swap_list[j][0]}/iteration_{i-1}/confout_backup.gro') and not os.path.exists(f'{REXEE.working_dir}/sim_{swap_list[j][0]}/iteration_{i-1}/confout_backup.gro'): # noqa: E501
print('\nModifying the coordinates of the following output GRO files ...')
# gro_1 and gro_2 are the simlation outputs (that we want to back up) and the inputs to modify_coords # noqa: E501
gro_1 = f'{REXEE.working_dir}/sim_{swap_list[j][0]}/iteration_{i-1}/confout.gro'
Expand All @@ -319,7 +319,7 @@ def main():
os.rename(gro_2, gro_2_backup)

# Here we input gro_1_backup and gro_2_backup and modify_coords_fn will save the modified gro files as gro_1 and gro_2 # noqa: E501
REXEE.modify_coords_fn(gro_1_backup, gro_2_backup, swap_index[j]) # the order should not matter
REXEE.modify_coords_fn(gro_1_backup, gro_2_backup, swap_index[j]) # the order should not matter # noqa: E501
except Exception:
print('\n--------------------------------------------------------------------------\n')
print(f'\nAn error occurred on rank 0:\n{traceback.format_exc()}')
Expand Down
19 changes: 9 additions & 10 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ def identify_swappable_pairs(self, states, state_ranges, dhdl_files, iteration=N
states_for_swap.append(state)
elif self.proposal == 'forced_random':
potential_swappables = []
for n in np.arange(0, n_sim-1,1):
for n in np.arange(0, n_sim-1, 1):
potential_swappables.append([n, n+1])
swappables, swap_index = [], []
for swap in potential_swappables:
Expand Down Expand Up @@ -978,18 +978,18 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):
if len(swappables) == 1:
if n_ex > 1:
n_ex = 1 # n_ex is set back to 1 since there is only 1 swappable pair.

states_modified = states
swap_index_accept = []
for i in range(n_ex):
# Update the list of swappable pairs starting from the 2nd attempted swap for the exhaustive swap method.
if (self.proposal == 'exhaustive' or self.proposal == 'forced_swap' or self.proposal == 'forced_random') and i >= 1:
if (self.proposal == 'exhaustive' or self.proposal == 'forced_swap' or self.proposal == 'forced_random') and i >= 1: # noqa: E501
# Note that this should be done regardless of the acceptance/rejection of the previously drawn pairs.
# Also note that at this point, swap is still the last attempted swap.
swappables = [i for i in swappables if set(i).intersection(set(swap)) == set()] # noqa: F821
print(f'\nRemaining swappable pairs: {swappables}')

if len(swappables) == 0 and (self.proposal == 'exhaustive' or self.proposal == 'forced_swap' or self.proposal == 'forced_random'):
if len(swappables) == 0 and (self.proposal == 'exhaustive' or self.proposal == 'forced_swap' or self.proposal == 'forced_random'): # noqa: E501
# This should only happen when the method of exhaustive swaps is used.
if i == 0:
self.n_empty_swappable += 1
Expand All @@ -1005,18 +1005,17 @@ def get_swapping_pattern(self, dhdl_files, states, iteration=None):
swap = swappables[0]
else:
swap = ReplicaExchangeEE.propose_swap(swappables)

print(f'\nProposed swap: {swap}')
if swap == []: # the same as len(swappables) == 0, self.proposal must not be exhaustive if this line is reached. # noqa: E501
self.n_empty_swappable += 1
print('No swap is proposed because there is no swappable pair at all.')
break # no need to re-identify swappable pairs and draw new samples
else:
self.n_swap_attempts += 1
if self.verbose is True and self.proposal != 'exhaustive' and self.proposal != 'forced_swap' and self.proposal != 'forced_random':
if self.verbose is True and self.proposal != 'exhaustive' and self.proposal != 'forced_swap' and self.proposal != 'forced_random': # noqa: E501
print(f'A swap ({i + 1}/{n_ex}) is proposed between the configurations of Simulation {swap[0]} (state {states[swap[0]]}) and Simulation {swap[1]} (state {states[swap[1]]}) ...') # noqa: E501


if self.modify_coords_fn is not None:
swap_bool = True # always accept the move
else:
Expand Down Expand Up @@ -1206,7 +1205,7 @@ def _deter_swap_index(self, swap, dhdl_files, add_swappables):
swappable_global_states, swap_state = [], []
for swappable in add_swappables:
A, B = swappable
if (A >= (swap[0] * self.s) and A < ((swap[0] + 1) * self.s) and B >= (swap[1] * self.s) and B < ((swap[1] + 1) * self.s)) or (B >= (swap[0] * self.s) and B < ((swap[0] + 1) * self.s) and A >= (swap[1] * self.s) and A < ((swap[1] + 1) * self.s)):
if (A >= (swap[0] * self.s) and A < ((swap[0] + 1) * self.s) and B >= (swap[1] * self.s) and B < ((swap[1] + 1) * self.s)) or (B >= (swap[0] * self.s) and B < ((swap[0] + 1) * self.s) and A >= (swap[1] * self.s) and A < ((swap[1] + 1) * self.s)): # noqa: E501
swappable_global_states.append(A)
swappable_global_states.append(B)
convert_to_frames = int(self.template["nstxout"]/self.template["nstdhdl"])
Expand All @@ -1224,7 +1223,7 @@ def _deter_swap_index(self, swap, dhdl_files, add_swappables):
state_global_list.append(state_global)
# Select a random frame which is in the last 50% of the trajectory to have the swap occur
potential_swap_index = np.array(potential_swap_index)
potential_swap_index = potential_swap_index[potential_swap_index > (len(state_local)/(2*convert_to_frames))]
potential_swap_index = potential_swap_index[potential_swap_index > (len(state_local)/(2*convert_to_frames))] # noqa: E501
if len(potential_swap_index) != 0:
index = np.random.choice(potential_swap_index)
swap_state.append(state_global_list[np.where(potential_swap_index == index)[0][0]])
Expand Down Expand Up @@ -1644,7 +1643,7 @@ def default_coords_fn(self, molA_file_name, molB_file_name, swap_index):
molB_dir = molB_file_name.rsplit('/', 1)[0] + '/'

# Load trajectory trr for higher precison coordinates
molA = md.load_trr(f'{molA_dir}/traj.trr', top=molA_file_name).slice(swap_index[0]) # Load last frame of trr trajectory
molA = md.load_trr(f'{molA_dir}/traj.trr', top=molA_file_name).slice(swap_index[0]) # Load last frame of trr trajectory # noqa: E501
molB = md.load_trr(f'{molB_dir}/traj.trr', top=molB_file_name).slice(swap_index[1])

# Load the coordinate swapping map
Expand Down

0 comments on commit db5f1e0

Please sign in to comment.