Skip to content

Commit

Permalink
Fixed bug in simulation_from_thermo where provided states were not be…
Browse files Browse the repository at this point in the history
…ing applied to Simulation objects
  • Loading branch information
timbernat committed May 17, 2024
1 parent 9cd8194 commit 863d2cc
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions polymerist/openmmtools/preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
LOGGER = logging.getLogger(__name__)

from typing import Optional
from typing import Optional, Union
from pathlib import Path

from openmm import State, System, XmlSerializer
Expand All @@ -14,37 +14,34 @@
from .serialization import SimulationPaths
from .thermo import EnsembleFactory
from .parameters import ThermoParameters, SimulationParameters
from .forcegroups import impose_unique_force_groups


def label_forces(system : System) -> None:
'''Designates each Force in a System with a unique force group and assigns helpful names by Force type'''
for i, force in enumerate(system.getForces()):
force.setForceGroup(i)

# TODO : add labelling (depends partially on Interchange's NonbondedForce separation)

def simulation_from_thermo(topology : Topology, system : System, thermo_params : ThermoParameters, time_step : Quantity, state : Optional[Path]=None) -> Simulation:
def simulation_from_thermo(topology : Topology, system : System, thermo_params : ThermoParameters, time_step : Quantity, state : Optional[Union[str, Path, State]]=None) -> Simulation:
'''Prepare an OpenMM simulation from a serialized thermodynamics parameter set'''
ens_fac = EnsembleFactory.from_thermo_params(thermo_params)
if (forces := ens_fac.forces()): # check if any extra forces are present
for force in forces:
if (extra_forces := ens_fac.forces()): # check if any extra forces are present
for force in extra_forces:
system.addForce(force) # add forces to System BEFORE creating Simulation to avoid having to reinitialize the Conext to preserve changes
LOGGER.info(f'Added {force.getName()} Force to System')
label_forces(system) # ensure all system forces (including any ensemble-specific ones) are labelled
impose_unique_force_groups(system)

if state is not None:
try:
with state.open('r') as statefile:
saved_state = XmlSerializer.deserialize(statefile.read())
except ValueError: # catch when a state file exists but is invalid (or indeed empty)
state = None
if isinstance(state, str):
state_path = state
elif isinstance(state, Path):
state_path = str(state)
else:
state_path = None

simulation = Simulation(
topology=topology,
system=system,
integrator=ens_fac.integrator(time_step),
state=state
state=state_path
)
if isinstance(state, State):
simulation.context.setState(state)
simulation.context.reinitialize(preserveState=True) # TOSELF : unclear whether this is necessary, redundant, or in fact harmful

return simulation

Expand Down

0 comments on commit 863d2cc

Please sign in to comment.