Skip to content

Commit

Permalink
ABFE workflow will not read dHdl file if no TI estimator is chosen (#…
Browse files Browse the repository at this point in the history
…231)

If not explicitly set, AMBER will not generate the dHdl file. (I don't know if the u_nk will be written by default, but I got a complaint that the workflow cannot be used to process a user's AMBER output, which doesn't have dHdl data.) Currently, the ABFE workflow will attempt to read both dHdl and u_nk. This PR will make it such that if no TI estimator is chosen, the workflow will not attempt to read the dHdl.

I have also removed the ignore_warnings flag from the file-reading stage, as maintaining it and making it clever enough to handle every possible user error seems to be too much.

Co-authored-by: Zhiyi Wu <[email protected]>
  • Loading branch information
xiki-tempula and xiki-tempula authored Sep 21, 2022
1 parent eb698c3 commit cb965ad
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 41 deletions.
4 changes: 4 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Changes
- AutoMBAR accepts the `method` argument (PR #114).
- Refactor the subsampling module to unify the behaviour of
equilibrium_detection and statistical_inefficiency (PR #218).
- Remove the ignore_warnings flag in ABFE workflow. (PR #231)


Enhancements
- Add u_nk2series and dhdl2series to convert u_nk and dHdl to series (PR
Expand All @@ -41,6 +43,8 @@ Fixes
at which the simulations were run (issue #225, PR #232)
- changed how int/float are read from AMBER files (issue #229, PR #235)
- substitute the any_none() function with a check "if None in" in the AMBER parser (issue #236, PR #237)
- For ABFE workflow, dHdl will only be read when a TI estimator is chosen.
Similarly, u_nk will only be read when FEP estimators are chosen. (PR #231)

07/22/2022 xiki-tempula, IAlibay, dotsdl, orbeckst, ptmerz

Expand Down
28 changes: 28 additions & 0 deletions src/alchemlyb/tests/test_workflow_ABFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,34 @@ def test_run_none(self, workflow):
workflow.run(uncorr=None, estimators=None, overlap=None, breakdown=None,
forwrev=None)

def test_run_single_estimator(self, workflow, monkeypatch):
    '''Run the workflow with ``estimators`` given as a single string.

    Both data lists on the workflow are replaced with empty lists before
    ``run`` is called with ``estimators='MBAR'`` and ``breakdown=True``.
    There is no explicit assertion: the test passes when ``run`` returns
    without raising.
    '''
    # Empty both lists on the fixture; monkeypatch undoes this afterwards.
    for attr_name in ('u_nk_list', 'dHdl_list'):
        monkeypatch.setattr(workflow, attr_name, [])

    run_options = dict(
        uncorr=None,
        estimators='MBAR',
        overlap=None,
        breakdown=True,
        forwrev=None,
    )
    workflow.run(**run_options)

def test_run_invalid_estimator(self, workflow):
    '''``run`` must raise ``ValueError`` for an unknown estimator name.

    The error message is matched against the pattern
    ``Estimator aaa is not supported.``.
    '''
    expected_message = r'Estimator aaa is not supported.'
    with pytest.raises(ValueError, match=expected_message):
        workflow.run(
            uncorr=None,
            estimators='aaa',
            overlap=None,
            breakdown=None,
            forwrev=None,
        )

@pytest.mark.parametrize('read_u_nk', [True, False])
@pytest.mark.parametrize('read_dHdl', [True, False])
def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl):
    '''Check that ``read`` only fills the lists it was asked to read.

    Both lists are emptied first, then ``read`` is called with every
    combination of the two flags. A requested list is expected to hold
    5 entries; a list that was not requested is expected to be empty.
    '''
    for attr_name in ('u_nk_list', 'dHdl_list'):
        monkeypatch.setattr(workflow, attr_name, [])

    workflow.read(read_u_nk, read_dHdl)

    expected_u_nk = 5 if read_u_nk else 0
    expected_dHdl = 5 if read_dHdl else 0
    assert len(workflow.u_nk_list) == expected_u_nk
    assert len(workflow.dHdl_list) == expected_dHdl

def test_read_invalid_u_nk(self, workflow, monkeypatch):
def extract_u_nk(self, T):
raise IOError('Error read u_nk.')
Expand Down
118 changes: 77 additions & 41 deletions src/alchemlyb/workflows/abfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ class ABFE(WorkflowBase):
outdirectory : str
Directory in which the output files produced by this script will be
stored. Default: os.path.curdir.
ignore_warnings : bool
Turn all errors into warnings.
Attributes
----------
Expand All @@ -60,11 +58,9 @@ class ABFE(WorkflowBase):
'''
def __init__(self, T, units='kT', software='GROMACS', dir=os.path.curdir,
prefix='dhdl', suffix='xvg',
outdirectory=os.path.curdir,
ignore_warnings=False):
outdirectory=os.path.curdir):

super().__init__(units, software, T, outdirectory)
self.ignore_warnings = ignore_warnings
self.logger = logging.getLogger('alchemlyb.workflows.ABFE')
self.logger.info('Initialise Alchemlyb ABFE Workflow')
self.logger.info(f'Alchemlyb Version: f{__version__}')
Expand Down Expand Up @@ -93,61 +89,83 @@ def __init__(self, T, units='kT', software='GROMACS', dir=os.path.curdir,
else:
raise NotImplementedError(f'{software} parser not found.')

def read(self):
def read(self, read_u_nk=True, read_dHdl=True):
'''Read the u_nk and dHdL data from the
:attr:`~alchemlyb.workflows.ABFE.file_list`
Parameters
----------
read_u_nk : bool
Whether to read the u_nk.
read_dHdl : bool
Whether to read the dHdl.
Attributes
----------
u_nk_list : list
A list of :class:`pandas.DataFrame` of u_nk.
dHdl_list : list
A list of :class:`pandas.DataFrame` of dHdl.
'''
self.u_nk_sample_list = None
self.dHdl_sample_list = None

u_nk_list = []
dHdl_list = []
for file in self.file_list:
try:
u_nk = self._extract_u_nk(file, T=self.T)
self.logger.info(
f'Reading {len(u_nk)} lines of u_nk from {file}')
u_nk_list.append(u_nk)
except Exception as exc:
msg = f'Error reading u_nk from {file}.'
if self.ignore_warnings:
self.logger.exception(msg + f'\n{exc}\n' +
'This exception is being ignored because ignore_warnings=True.')
else:
if read_u_nk:
try:
u_nk = self._extract_u_nk(file, T=self.T)
self.logger.info(
f'Reading {len(u_nk)} lines of u_nk from {file}')
u_nk_list.append(u_nk)
except Exception as exc:
msg = f'Error reading u_nk from {file}.'
self.logger.error(msg)
raise OSError(msg) from exc

try:
dhdl = self._extract_dHdl(file, T=self.T)
self.logger.info(
f'Reading {len(dhdl)} lines of dhdl from {file}')
dHdl_list.append(dhdl)
except Exception as exc:
msg = f'Error reading dHdl from {file}.'
if self.ignore_warnings:
self.logger.exception(msg + f'\n{exc}\n' +
'This exception is being ignored because ignore_warnings=True.')
else:
if read_dHdl:
try:
dhdl = self._extract_dHdl(file, T=self.T)
self.logger.info(
f'Reading {len(dhdl)} lines of dhdl from {file}')
dHdl_list.append(dhdl)
except Exception as exc:
msg = f'Error reading dHdl from {file}.'
self.logger.error(msg)
raise OSError(msg) from exc

# Sort the files according to the state
self.logger.info('Sort files according to the u_nk.')
column_names = u_nk_list[0].columns.values.tolist()
index_list = sorted(range(len(self.file_list)),
key=lambda x:column_names.index(
u_nk_list[x].reset_index('time').index.values[0]))
if read_u_nk:
self.logger.info('Sort files according to the u_nk.')
column_names = u_nk_list[0].columns.values.tolist()
index_list = sorted(range(len(self.file_list)),
key=lambda x: column_names.index(
u_nk_list[x].reset_index(
'time').index.values[0]))
elif read_dHdl:
self.logger.info('Sort files according to the dHdl.')
index_list = sorted(range(len(self.file_list)),
key=lambda x:
dHdl_list[x].reset_index(
'time').index.values[0])
else:
self.u_nk_list = []
self.dHdl_list = []
return

self.file_list = [self.file_list[i] for i in index_list]
self.logger.info("Sorted file list: \n%s", '\n'.join(self.file_list))
self.u_nk_list = [u_nk_list[i] for i in index_list]
self.dHdl_list = [dHdl_list[i] for i in index_list]
self.u_nk_sample_list = None
self.dHdl_sample_list = None
if read_u_nk:
self.u_nk_list = [u_nk_list[i] for i in index_list]
else:
self.u_nk_list = []

if read_dHdl:
self.dHdl_list = [dHdl_list[i] for i in index_list]
else:
self.dHdl_list = []


def run(self, skiptime=0, uncorr='dhdl', threshold=50,
estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf',
Expand Down Expand Up @@ -196,7 +214,24 @@ def run(self, skiptime=0, uncorr='dhdl', threshold=50,
:func:`~alchemlyb.convergence.forward_backward_convergence` for
further explanation.
'''
self.read()
use_FEP = False
use_TI = False

if estimators is not None:
if isinstance(estimators, str):
estimators = [estimators, ]
for estimator in estimators:
if estimator in FEP_ESTIMATORS:
use_FEP = True
elif estimator in TI_ESTIMATORS:
use_TI = True
else:
msg = f"Estimator {estimator} is not supported. Choose one from " \
f"{FEP_ESTIMATORS + TI_ESTIMATORS}."
self.logger.error(msg)
raise ValueError(msg)

self.read(use_FEP, use_TI)

if uncorr is not None:
self.preprocess(skiptime=skiptime, uncorr=uncorr,
Expand All @@ -205,13 +240,14 @@ def run(self, skiptime=0, uncorr='dhdl', threshold=50,
self.estimate(estimators)
self.generate_result()

if overlap is not None:
if overlap is not None and use_FEP:
ax = self.plot_overlap_matrix(overlap)
plt.close(ax.figure)

if breakdown:
ax = self.plot_ti_dhdl()
plt.close(ax.figure)
if use_TI:
ax = self.plot_ti_dhdl()
plt.close(ax.figure)
fig = self.plot_dF_state()
plt.close(fig)
fig = self.plot_dF_state(dF_state='dF_state_long.pdf',
Expand Down

0 comments on commit cb965ad

Please sign in to comment.