Skip to content

Commit

Permalink
TRAM fix artificial transition counts by negative state indices (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaaikeG authored Jan 24, 2022
1 parent 0b76177 commit 5a785b2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
33 changes: 28 additions & 5 deletions deeptime/markov/msm/tram/_tram_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ def _determine_n_therm_states(dtrajs, ttrajs):
return _determine_n_states(ttrajs)


def _split_at_negative_state_indices(trajectory_fragment, negative_state_indices):
split_fragments = np.split(trajectory_fragment, negative_state_indices)
sub_fragments = []
# now get rid of the negative state indices.
for frag in split_fragments:
frag = frag[frag >= 0]
# Only add to the list if there are any samples left in the fragments
if len(frag) > 0:
sub_fragments.append(frag)
return sub_fragments


def transition_counts_from_count_models(n_therm_states, n_markov_states, count_models):
transition_counts = np.zeros((n_therm_states, n_markov_states, n_markov_states), dtype=np.int32)

Expand Down Expand Up @@ -454,13 +466,24 @@ def _find_trajectory_fragments(self):
# get a mapping from trajectory segments to thermodynamic states
fragment_indices = self._find_trajectory_fragment_mapping()

fragments = []
fragments = [[] for _ in range(self.n_therm_states)]
# for each them. state k, gather all trajectory fragments that were sampled at that state.
for k in range(self.n_therm_states):
# take the fragments based on the list of indices. Exclude all values that are less than zero. They don't
# belong in the connected set.
fragments.append([self.dtrajs[traj_idx][start:stop][self.dtrajs[traj_idx][start:stop] >= 0]
for (traj_idx, start, stop) in fragment_indices[k]])
# Select the fragments using the list of indices.
for (traj_idx, start, stop) in fragment_indices[k]:
fragment = self.dtrajs[traj_idx][start:stop]

# Whenever state values are negative, those samples do not belong in the connected set and need to be
# excluded. We split trajectories where negative state indices occur.
# Example: [0, 0, 2, -1, 2, 1, 0], we want to exclude the sample with state index -1.
# Simply filtering out negative state indices would lead to [0, 0, 2, 2, 1, 0] which gives a transition
# 2 -> 2 which doesn't exist. Instead, split the trajectory at negative state indices to get
# [0, 0, 2], [2, 1, 0]
negative_state_indices = np.where(fragment < 0)[0]
if len(negative_state_indices) > 0:
fragments[k].extend(_split_at_negative_state_indices(fragment, negative_state_indices))
else:
fragments[k].append(fragment)
return fragments

def _find_trajectory_fragment_mapping(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/markov/msm/test_tram_datatset.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ def test_get_trajectory_fragments(dtrajs, ttrajs):
bias_matrices = make_matching_bias_matrix(dtrajs)
dataset = TRAMDataset(dtrajs=dtrajs, ttrajs=ttrajs, bias_matrices=bias_matrices)

# dtraj should be split into fragments [[[1, 3], [5, 6], [8, 9]], [[10, 11], [12, 13]]] due to replica exchanges
# dtraj should be split into fragments [[[1], [3], [5, 6], [8, 9]], [[10, 11], [12, 13]]] due to replica exchanges
# found in ttrajs. This should lead having only 5 transitions in transition counts:
np.testing.assert_equal(dataset.state_counts.sum(), 10)
np.testing.assert_equal(dataset.transition_counts.sum(), 5)
np.testing.assert_equal(dataset.transition_counts.sum(), 4)


def test_unknown_connectivity():
Expand Down

0 comments on commit 5a785b2

Please sign in to comment.