From 5a785b276417b9f168c4e5e153d564e74ab7b5a3 Mon Sep 17 00:00:00 2001 From: MaaikeG Date: Mon, 24 Jan 2022 11:51:44 +0100 Subject: [PATCH] TRAM fix artificial transition counts by negative state indices (#194) --- deeptime/markov/msm/tram/_tram_dataset.py | 33 +++++++++++++++++++---- tests/markov/msm/test_tram_datatset.py | 4 +-- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/deeptime/markov/msm/tram/_tram_dataset.py b/deeptime/markov/msm/tram/_tram_dataset.py index 20c623326..a8c1746b2 100644 --- a/deeptime/markov/msm/tram/_tram_dataset.py +++ b/deeptime/markov/msm/tram/_tram_dataset.py @@ -19,6 +19,18 @@ def _determine_n_therm_states(dtrajs, ttrajs): return _determine_n_states(ttrajs) +def _split_at_negative_state_indices(trajectory_fragment, negative_state_indices): + split_fragments = np.split(trajectory_fragment, negative_state_indices) + sub_fragments = [] + # now get rid of the negative state indices. + for frag in split_fragments: + frag = frag[frag >= 0] + # Only add to the list if there are any samples left in the fragments + if len(frag) > 0: + sub_fragments.append(frag) + return sub_fragments + + def transition_counts_from_count_models(n_therm_states, n_markov_states, count_models): transition_counts = np.zeros((n_therm_states, n_markov_states, n_markov_states), dtype=np.int32) @@ -454,13 +466,24 @@ def _find_trajectory_fragments(self): # get a mapping from trajectory segments to thermodynamic states fragment_indices = self._find_trajectory_fragment_mapping() - fragments = [] + fragments = [[] for _ in range(self.n_therm_states)] # for each them. state k, gather all trajectory fragments that were sampled at that state. for k in range(self.n_therm_states): - # take the fragments based on the list of indices. Exclude all values that are less than zero. They don't - # belong in the connected set. - fragments.append([self.dtrajs[traj_idx][start:stop][self.dtrajs[traj_idx][start:stop] >= 0] - for (traj_idx, start, stop) in fragment_indices[k]]) + # Select the fragments using the list of indices. + for (traj_idx, start, stop) in fragment_indices[k]: + fragment = self.dtrajs[traj_idx][start:stop] + + # Whenever state values are negative, those samples do not belong in the connected set and need to be + # excluded. We split trajectories where negative state indices occur. + # Example: [0, 0, 2, -1, 2, 1, 0], we want to exclude the sample with state index -1. + # Simply filtering out negative state indices would lead to [0, 0, 2, 2, 1, 0] which gives a transition + # 2 -> 2 which doesn't exist. Instead, split the trajectory at negative state indices to get + # [0, 0, 2], [2, 1, 0] + negative_state_indices = np.where(fragment < 0)[0] + if len(negative_state_indices) > 0: + fragments[k].extend(_split_at_negative_state_indices(fragment, negative_state_indices)) + else: + fragments[k].append(fragment) return fragments def _find_trajectory_fragment_mapping(self): diff --git a/tests/markov/msm/test_tram_datatset.py b/tests/markov/msm/test_tram_datatset.py index d2f1c0d43..f966f7c0e 100644 --- a/tests/markov/msm/test_tram_datatset.py +++ b/tests/markov/msm/test_tram_datatset.py @@ -227,10 +227,10 @@ def test_get_trajectory_fragments(dtrajs, ttrajs): bias_matrices = make_matching_bias_matrix(dtrajs) dataset = TRAMDataset(dtrajs=dtrajs, ttrajs=ttrajs, bias_matrices=bias_matrices) - # dtraj should be split into fragments [[[1, 3], [5, 6], [8, 9]], [[10, 11], [12, 13]]] due to replica exchanges + # dtraj should be split into fragments [[[1], [3], [5, 6], [8, 9]], [[10, 11], [12, 13]]] due to replica exchanges # found in ttrajs. This should lead having only 5 transitions in transition counts: np.testing.assert_equal(dataset.state_counts.sum(), 10) - np.testing.assert_equal(dataset.transition_counts.sum(), 5) + np.testing.assert_equal(dataset.transition_counts.sum(), 4) def test_unknown_connectivity():