diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py new file mode 100644 index 00000000..ead30be1 --- /dev/null +++ b/tests/test_evaluation.py @@ -0,0 +1,147 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# Copyright (c) 2020-21 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test tools for mapping between node sets of different tree sequences +""" +from collections import defaultdict +from itertools import combinations + +import msprime +import numpy as np +import pytest +import scipy.sparse +import tsinfer + +from tsdate import evaluation + +# --- simulate test case --- +demo = msprime.Demography.isolated_model([1e4]) +for t in np.linspace(500, 10000, 20): + demo.add_census(time=t) +true_unary = msprime.sim_ancestry( + samples=10, + sequence_length=1e6, + demography=demo, + recombination_rate=1e-8, + random_seed=1024, +) +true_unary = msprime.sim_mutations(true_unary, rate=2e-8, random_seed=1024) +assert true_unary.num_trees > 1 +true_simpl = true_unary.simplify(filter_sites=False) +sample_dat = tsinfer.SampleData.from_tree_sequence(true_simpl) +infr_unary = tsinfer.infer(sample_dat) +infr_simpl = infr_unary.simplify(filter_sites=False) + + +def naive_shared_node_spans(ts, other): + """ + Inefficient but transparent function to get span where nodes from two tree + sequences subtend the same sample set + """ + + def _clade_dict(tree): + clade_to_node = defaultdict(set) + for node in tree.nodes(): + clade = frozenset(tree.samples(node)) + clade_to_node[clade].add(node) + return clade_to_node + + assert ts.sequence_length == other.sequence_length + assert ts.num_samples == other.num_samples + out = np.zeros((ts.num_nodes, other.num_nodes)) + for (interval, query_tree, target_tree) in ts.coiterate(other): + query = _clade_dict(query_tree) + target = _clade_dict(target_tree) + span = interval.right - interval.left + for clade, nodes in query.items(): + if clade in target: + for i in nodes: + for j in target[clade]: + out[i, j] += span + return scipy.sparse.csr_matrix(out) + + +@pytest.mark.parametrize("ts", [true_unary, infr_unary, true_simpl, infr_simpl]) +class TestCladeMap: + def test_map(self, ts): + """ + test that clade map has correct nodes, clades + """ + clade_map = evaluation.CladeMap(ts) + for tree in ts.trees(): + for node in tree.nodes(): + clade = frozenset(tree.samples(node)) + assert node in clade_map._nodes[clade] + assert clade_map._clades[node] == clade + clade_map.next() + + def test_diff(self, ts): + """ + test difference in clades between adjacent trees + """ + clade_map = evaluation.CladeMap(ts) + tree_1 = ts.first() + tree_2 = ts.first() + while True: + tree_2.next() + diff = clade_map.next() + diff_test = {} + for n in set(tree_1.nodes()) | set(tree_2.nodes()): + prev = frozenset(tree_1.samples(n)) + curr = frozenset(tree_2.samples(n)) + if prev != curr: + diff_test[n] = (prev, curr) + for node in diff_test.keys() | diff.keys(): + assert diff_test[node][0] == diff[node][0] + assert diff_test[node][1] == diff[node][1] + if tree_2.index == ts.num_trees - 1: + break + tree_1.next() + + +class TestNodeMatching: + @pytest.mark.parametrize( + "pair", combinations([infr_simpl, true_simpl, infr_unary, true_unary], 2) + ) + def test_shared_spans(self, pair): + """ + Check that efficient implementation returns same answer as naive + implementation + """ + check = naive_shared_node_spans(pair[0], pair[1]) + test = evaluation.shared_node_spans(pair[0], pair[1]) + assert check.shape == test.shape + assert check.nnz == test.nnz + assert np.allclose(check.data, test.data) + + @pytest.mark.parametrize("ts", [infr_simpl, true_simpl]) + def test_match_self(self, ts): + """ + Check that matching against self returns node ids + + TODO: this'll only work reliably when there's not unary nodes. + """ + time, _, hit = evaluation.match_node_ages(ts, ts) + assert np.allclose(time, ts.nodes_time) + assert np.array_equal(hit, np.arange(ts.num_nodes)) diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py new file mode 100644 index 00000000..896f1708 --- /dev/null +++ b/tsdate/evaluation.py @@ -0,0 +1,254 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Tools for comparing node times between tree sequences with different node sets +""" +import copy +from collections import defaultdict +from itertools import product + +import numpy as np +import scipy.sparse +import tskit + + +class CladeMap: + """ + An iterator across trees that maintains a mapping from a clade (a `frozenset` of + sample IDs) to a `set` of nodes. When there are unary nodes, there may be multiple + nodes associated with each clade. + """ + + def __init__(self, ts): + self._nil = frozenset() + self._nodes = defaultdict(set) # nodes[clade] = {node ids} + self._clades = defaultdict(frozenset) # clades[node] = {sample ids} + self.tree_sequence = ts + self.tree = ts.first(sample_lists=True) + for node in self.tree.nodes(): + clade = frozenset(self.tree.samples(node)) + self._nodes[clade].add(node) + self._clades[node] = clade + self._prev = copy.deepcopy(self._clades) + self._diff = ts.edge_diffs() + next(self._diff) + + def _propagate(self, edge, downdate=False): + """ + Traverse path from `edge.parent` to root, either adding or removing the + state (clade) associated with `edge.child` from the state of each + visited node. Return a set with the node ids encountered during + traversal. + """ + nodes = set() + node = edge.parent + clade = self._clades[edge.child] + while node != tskit.NULL: + last = self._clades[node] + self._clades[node] = last - clade if downdate else last | clade + if len(last): + self._nodes[last].remove(node) + if len(self._nodes[last]) == 0: + del self._nodes[last] + self._nodes[self._clades[node]].add(node) + nodes.add(node) + node = self.tree.parent(node) + return nodes + + def next(self): # noqa: A003 + """ + Advance to the next tree, returning the difference between trees as a + dictionary of the form `node : (last_clade, next_clade)` + """ + nodes = set() # nodes with potentially altered clades + diff = {} # diff[node] = (prev_clade, curr_clade) + + if self.tree.index + 1 == self.tree_sequence.num_trees: + return None + + # Subtract clades subtended by outgoing edges + edge_diff = next(self._diff) + for eo in edge_diff.edges_out: + nodes |= self._propagate(eo, downdate=True) + + # Prune nodes that are no longer in tree + for node in self._nodes[self._nil]: + diff[node] = (self._prev[node], self._nil) + del self._clades[node] + nodes -= self._nodes[self._nil] + self._nodes[self._nil].clear() + + # Add clades subtended by incoming edges + self.tree.next() + for ei in edge_diff.edges_in: + nodes |= self._propagate(ei, downdate=False) + + # Find difference in clades between adjacent trees + for node in nodes: + diff[node] = (self._prev[node], self._clades[node]) + if self._prev[node] == self._clades[node]: + del diff[node] + + # Sync previous and current states + for node, (_, curr) in diff.items(): + if curr == self._nil: + del self._prev[node] + else: + self._prev[node] = curr + + return diff + + @property + def interval(self): + """ + Return interval spanned by tree + """ + return self.tree.interval + + def clades(self): + """ + Return set of clades in tree + """ + return self._nodes.keys() - self._nil + + def __getitem__(self, clade): + """ + Return set of nodes associated with a given clade. + """ + return frozenset(self._nodes[clade]) if frozenset(clade) in self else self._nil + + def __contains__(self, clade): + """ + Check if a clade is present in the tree + """ + return clade in self._nodes + + +def shared_node_spans(ts, other): + """ + Calculate the spans over which pairs of nodes in two tree sequences are + ancestral to indentical sets of samples. + + Returns a sparse matrix where rows correspond to nodes in `ts` and columns + correspond to nodes in `other`. + """ + + if ts.sequence_length != other.sequence_length: + raise ValueError("Tree sequences must be of equal sequence length.") + + if ts.num_samples != other.num_samples: + raise ValueError("Tree sequences must have the same numbers of samples.") + + nil = frozenset() + + # Initialize clade iterators + query = CladeMap(ts) + target = CladeMap(other) + + # Initialize buffer[clade] = (query_nodes, target_nodes, left_coord) + modified = query.clades() | target.clades() + buffer = {} + + # Build sparse matrix of matches in triplet format + query_node = [] + target_node = [] + shared_span = [] + right = 0 + while True: + left = right + right = min(query.interval[1], target.interval[1]) + + # Flush pairs of nodes that no longer have matching clades + for clade in modified: # flush: + if clade in buffer: + n_i, n_j, start = buffer.pop(clade) + span = left - start + for i, j in product(n_i, n_j): + query_node.append(i) + target_node.append(j) + shared_span.append(span) + + # Add new pairs of nodes with matching clades + for clade in modified: + assert clade not in buffer + if clade in query and clade in target: + n_i, n_j = query[clade], target[clade] + buffer[clade] = (n_i, n_j, left) + + if right == ts.sequence_length: + break + + # Find difference in clades with advance to next tree + modified.clear() + for clade_map in (query, target): + if clade_map.interval[1] == right: + clade_diff = clade_map.next() + for (prev, curr) in clade_diff.values(): + if prev != nil: + modified.add(prev) + if curr != nil: + modified.add(curr) + + # Flush final tree + for clade in buffer: + n_i, n_j, start = buffer[clade] + span = right - start + for i, j in product(n_i, n_j): + query_node.append(i) + target_node.append(j) + shared_span.append(span) + + numer = scipy.sparse.coo_matrix( + (shared_span, (query_node, target_node)), + shape=(ts.num_nodes, other.num_nodes), + ).tocsr() + + return numer + + +def match_node_ages(ts, other): + """ + For each node in `ts`, return the age of a matched node from `other`. Node + matching is accomplished by calculating the intervals over which pairs of + nodes (one from `ts`, one from `other`) subtend the same set of samples. + + Returns three vectors of length `ts.num_nodes`: the age of the best + matching node in `other` (e.g. with the longest shared span); the + proportion of the node span in `ts` that is covered by the best match; and + the id of the best match in `other`. + + If either tree sequence contains unary nodes, then there may be multiple + matches with the same span for a single node. In this case, the returned + match is the node with the smallest integer id. + """ + + shared_spans = shared_node_spans(ts, other) + matched_span = shared_spans.max(axis=1).todense().A1 + best_match = shared_spans.argmax(axis=1).A1 + # NB: if there are multiple nodes with the largest span in a row, + # argmax returns the node with the smallest integer id + matched_time = other.nodes_time[best_match] + + best_match[matched_span == 0] = tskit.NULL + matched_time[matched_span == 0] = np.nan + + return matched_time, matched_span, best_match