diff --git a/tests/test_sgkit.py b/tests/test_sgkit.py
index 96eb7ddd..0940e8b6 100644
--- a/tests/test_sgkit.py
+++ b/tests/test_sgkit.py
@@ -20,6 +20,8 @@
 Tests for the data files.
 """
 import json
+import os
+import pickle
 import sys
 import tempfile
@@ -615,3 +617,58 @@ def test_empty_alleles_not_at_end(self, tmp_path):
         samples = tsinfer.SgkitSampleData(path)
         with pytest.raises(ValueError, match="Empty alleles must be at the end"):
             tsinfer.infer(samples)
+
+
+class TestSgkitMatchSamplesToDisk:
+    @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
+    @pytest.mark.parametrize("slice", [(0, 5), (0, 0), (0, 1), (10, 15)])
+    def test_match_samples_to_disk_write(
+        self, slice, small_sd_fixture, tmp_path, tmpdir
+    ):
+        ts, zarr_path = make_ts_and_zarr(tmp_path)
+        samples = tsinfer.SgkitSampleData(zarr_path)
+        ancestors = tsinfer.generate_ancestors(samples)
+        anc_ts = tsinfer.match_ancestors(samples, ancestors)
+        tsinfer.match_samples_slice_to_disk(
+            samples, anc_ts, slice, tmpdir / "test.path"
+        )
+        file_slice, matches = pickle.load(open(tmpdir / "test.path", "rb"))
+        assert slice == file_slice
+        assert len(matches) == slice[1] - slice[0]
+        for m in matches:
+            assert isinstance(m, tsinfer.inference.MatchResult)
+
+    @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
+    def test_match_samples_to_disk_full(self, small_sd_fixture, tmp_path, tmpdir):
+        ts, zarr_path = make_ts_and_zarr(tmp_path)
+        samples = tsinfer.SgkitSampleData(zarr_path)
+        ancestors = tsinfer.generate_ancestors(samples)
+        anc_ts = tsinfer.match_ancestors(samples, ancestors)
+        ts = tsinfer.match_samples(samples, anc_ts)
+        start_index = 0
+        while start_index < ts.num_samples:
+            end_index = min(start_index + 5, ts.num_samples)
+            tsinfer.match_samples_slice_to_disk(
+                samples,
+                anc_ts,
+                (start_index, end_index),
+                tmpdir / f"test-{start_index}.path",
+            )
+            start_index = end_index
+        batch_ts = tsinfer.match_samples(
+            samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
+        )
+        ts.tables.assert_equals(batch_ts.tables, ignore_provenance=True)
+
+        tmpdir.join("test-5.path").copy(tmpdir.join("test-5-copy.path"))
+        with pytest.raises(ValueError, match="Duplicate sample index 5"):
+            tsinfer.match_samples(
+                samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
+            )
+
+        os.remove(tmpdir / "test-5.path")
+        os.remove(tmpdir / "test-5-copy.path")
+        with pytest.raises(ValueError, match="index 5 not found"):
+            tsinfer.match_samples(
+                samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
+            )
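
The tests above exercise the new batch workflow end to end: each `match_samples_slice_to_disk` call writes one self-contained file of match results, and a final `match_samples` call with `match_file_pattern` reassembles them into a tree sequence. A minimal sketch of that workflow, assuming the sample-data object exposes `num_samples` as `SampleData` does; the zarr path, chunk size, and file names are placeholders, not values from this patch:

```python
import tsinfer

samples = tsinfer.SgkitSampleData("samples.zarr")  # hypothetical input path
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)

# Each slice is independent, so these calls can run as separate processes
# or cluster jobs; the output files are the only shared state.
chunk = 1000  # hypothetical batch size
for start in range(0, samples.num_samples, chunk):
    end = min(start + chunk, samples.num_samples)
    tsinfer.match_samples_slice_to_disk(
        samples, anc_ts, (start, end), f"match-{start}.path"
    )

# The glob pattern must cover every sample index exactly once: duplicated
# or missing indexes raise ValueError, as the tests above verify.
ts = tsinfer.match_samples(samples, anc_ts, match_file_pattern="match-*.path")
```
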
diff --git a/tsinfer/inference.py b/tsinfer/inference.py
index db35dcfb..df8842dc 100644
--- a/tsinfer/inference.py
+++ b/tsinfer/inference.py
@@ -24,6 +24,7 @@
 import collections
 import copy
 import dataclasses
+import glob
 import heapq
 import json
 import logging
@@ -697,6 +698,7 @@ def match_samples(
     resume_lmdb_file=None,
     use_dask=False,
     map_additional_sites=None,
+    match_file_pattern=None,
 ):
     """
     match_samples(sample_data, ancestors_ts, *, recombination_rate=None,\
@@ -785,6 +787,7 @@
         progress_monitor=progress_monitor,
         resume_lmdb_file=resume_lmdb_file,
         use_dask=use_dask,
+        match_file_pattern=match_file_pattern,
     )
     sample_indexes = check_sample_indexes(sample_data, indexes)
     sample_times = np.zeros(
@@ -819,6 +822,51 @@
     return ts
 
 
+def match_samples_slice_to_disk(
+    sample_data,
+    ancestors_ts,
+    samples_slice,
+    output_path,
+    *,
+    recombination_rate=None,
+    mismatch_ratio=None,
+    path_compression=True,
+    indexes=None,
+    # Deliberately undocumented parameters below
+    recombination=None,  # See :class:`Matcher`
+    mismatch=None,  # See :class:`Matcher`
+    precision=None,
+    extended_checks=False,
+    engine=constants.C_ENGINE,
+):
+    sample_data._check_finalised()
+
+    manager = SampleMatcher(
+        sample_data,
+        ancestors_ts,
+        recombination_rate=recombination_rate,
+        mismatch_ratio=mismatch_ratio,
+        recombination=recombination,
+        mismatch=mismatch,
+        path_compression=path_compression,
+        num_threads=0,
+        precision=precision,
+        extended_checks=extended_checks,
+        engine=engine,
+        progress_monitor=None,
+        resume_lmdb_file=None,
+        use_dask=False,
+    )
+    sample_indexes = check_sample_indexes(sample_data, indexes)
+    sample_times = np.zeros(
+        len(sample_indexes), dtype=sample_data.individuals_time.dtype
+    )
+    builder = manager.tree_sequence_builder
+    for j, t in zip(sample_indexes, sample_times):
+        manager.sample_id_map[j] = builder.add_node(t)
+    manager.find_path_to_disk(samples_slice=samples_slice, output_path=output_path)
 
 
 def insert_missing_sites(
     sample_data, tree_sequence, *, sample_id_map=None, progress_monitor=None
 ):
@@ -1248,6 +1298,7 @@ def __init__(
         allow_multiallele=False,
         resume_lmdb_file=None,
         use_dask=False,
+        match_file_pattern=None,
     ):
         self.sample_data = sample_data
         self.num_threads = num_threads
@@ -1261,6 +1312,7 @@
         self.match_progress = None  # Allocated by subclass
         self.extended_checks = extended_checks
         self.use_dask = use_dask
+        self.match_file_pattern = match_file_pattern
 
         all_sites = self.sample_data.sites_position[:]
         index = np.searchsorted(all_sites, inference_site_position)
@@ -1910,41 +1962,79 @@ def dask_find_path(
         site_indexes=None,
         sample_id_map=None,
     ):
-        t = time.time()
+        result = SampleMatcher.inner_find_path(
+            samples_slice,
+            tree_sequence_builder_wrapper.tsb,
+            data_path,
+            engine,
+            recombination,
+            mismatch,
+            precision,
+            extended_checks,
+            site_indexes,
+            sample_id_map,
+        )
+        # Pickle here rather than let dask deal with it so we can log sizes
+        return pickle.dumps(result)
+
+    def find_path_to_disk(self, samples_slice, output_path):
+        result = self.inner_find_path(
+            samples_slice=samples_slice,
+            tsb=self.tree_sequence_builder,
+            sampledata_or_path=self.sample_data,
+            engine=self.engine,
+            recombination=self.recombination,
+            mismatch=self.mismatch,
+            precision=self.precision,
+            extended_checks=self.extended_checks,
+            site_indexes=self.inference_site_id,
+            sample_id_map=self.sample_id_map,
+        )
+        result = pickle.dumps((samples_slice, result))
+        with open(output_path, "wb") as f:
+            f.write(result)
+
+    @staticmethod
+    def inner_find_path(
+        samples_slice,
+        tsb,
+        sampledata_or_path,
+        engine,
+        recombination,
+        mismatch,
+        precision,
+        extended_checks,
+        site_indexes,
+        sample_id_map,
+    ):
         ancestor_matcher_class = (
             _tsinfer.AncestorMatcher
             if engine == constants.C_ENGINE
             else algorithm.AncestorMatcher
         )
         matcher = ancestor_matcher_class(
-            tree_sequence_builder_wrapper.tsb,
+            tsb,
             recombination=recombination,
             mismatch=mismatch,
             precision=precision,
             extended_checks=extended_checks,
         )
-        logging.info(
-            f"Loading sample data from {data_path} at {time.time() - t:.2f} seconds"
-        )
-        sample_data = formats.SgkitSampleData(data_path)
-        logging.info(f"Loaded sample data at {time.time() - t:.2f} seconds")
+        if isinstance(sampledata_or_path, formats.SampleData):
+            sample_data = sampledata_or_path
+        else:
+            sample_data = formats.SgkitSampleData(sampledata_or_path)
+
         haplotypes = sample_data._slice_haplotypes(
             sites=site_indexes, recode_ancestral=True, samples_slice=samples_slice
         )
-        logging.info(f"Init haplotypes at {time.time() - t:.2f} seconds")
-        # Pickle here rather than let dask deal with it so we can log sizes
         results = []
         for sample_id, haplotype in haplotypes:
-            logging.info(
-                f"Finding path for {sample_id} at {time.time() - t:.2f} seconds"
-            )
             results.append(
                 AncestorMatcher.find_path(
                     matcher, sample_id_map[sample_id], haplotype, 0, len(site_indexes)
                 )
            )
-            logging.info(f"Found path for {sample_id} at {time.time() - t:.2f} seconds")
-        return pickle.dumps(results)
+        return results
 
     def restore_tree_sequence_builder(self):
         tables = self.ancestors_ts_tables
@@ -2023,6 +2113,7 @@ def thread_worker_function(j_haplotype):
                     start=0,
                     end=self.num_sites,
                 )
+                self.match_progress.update()
                 logger.info(
                     f"{time.time()}Thread {threading.get_ident()} finished haplotype {j}"
                 )
@@ -2084,30 +2175,27 @@ def match_with_dask(
             flat_results.extend(result)
         return flat_results
 
-    def _match_samples(self, sample_indexes):
-        num_samples = len(sample_indexes)
-        builder = self.tree_sequence_builder
-        _, times = builder.dump_nodes()
-        logger.info(f"Started matching for {num_samples} samples")
-        if self.num_sites == 0:
-            return
-        self.match_progress = self.progress_monitor.get("ms_match", num_samples)
-        if self.use_dask:
-            logger.info("Using dask for sample matching")
-            delayed_recombination = dask.delayed(self.recombination)
-            delayed_mismatch = dask.delayed(self.mismatch)
-            delayed_sites = dask.delayed(self.inference_site_id)
-            delayed_sample_id_map = dask.delayed(self.sample_id_map)
-            tree_sequence_builder_wrapper = TSBWrapper(
-                self.engine,
-                self.tree_sequence_builder,
-                self.num_alleles,
-                self.max_nodes,
-                self.max_edges,
-            )
-            tree_sequence_builder_wrapper_delayed = dask.delayed(
-                tree_sequence_builder_wrapper, pure=True
-            ).persist()
+    def _dask_args(self):
+        logger.info("Using dask for sample matching")
+        return {
+            "tree_sequence_builder_wrapper_delayed": dask.delayed(
+                TSBWrapper(
+                    self.engine,
+                    self.tree_sequence_builder,
+                    self.num_alleles,
+                    self.max_nodes,
+                    self.max_edges,
+                ),
+                pure=True,
+            ).persist(),
+            "delayed_recombination": dask.delayed(self.recombination),
+            "delayed_mismatch": dask.delayed(self.mismatch),
+            "delayed_sites": dask.delayed(self.inference_site_id),
+            "delayed_sample_id_map": dask.delayed(self.sample_id_map),
+        }
+
+    def _process_samples_in_batches(self, sample_indexes):
+        dask_args = self._dask_args() if self.use_dask else {}
         with LMDBCache(self.resume_lmdb) as cache:
             start_index = 0
             results = []
@@ -2120,13 +2208,10 @@
                 if batch_results is None:
                     if self.use_dask:
                         batch_results = self.match_with_dask(
-                            (start_index, end_index),
-                            tree_sequence_builder_wrapper_delayed,
-                            delayed_recombination,
-                            delayed_mismatch,
-                            delayed_sites,
-                            delayed_sample_id_map,
+                            (start_index, end_index), **dask_args
                         )
+                        for _ in batch_results:
+                            self.match_progress.update()
                     else:
                         batch_results = self.match_locally(
                             sample_indexes[start_index:end_index]
@@ -2134,7 +2219,6 @@
                     batch_results_list = []
                     for result in batch_results:
                         batch_results_list.append(result)
-                        self.match_progress.update()
                     batch_results = batch_results_list
                     cache.put(key, batch_results)
                     logger.info(
@@ -2142,42 +2226,84 @@
                         f"{time.time() - t:.2f} seconds"
                     )
                 else:
+                    self.match_progress.update(end_index - start_index)
                     logger.info(
                         f"Found cached results for samples "
                         f"{start_index}-{end_index} in {time.time() - t:.2f} seconds"
                     )
                 results.extend(batch_results)
                 start_index = end_index
-        self.match_progress.close()
-        logger.info(
-            "Inserting sample paths: {} edges in total".format(
-                sum(len(r.path.left) for r in results)
+        return results
+
+    def results_from_disk(self, sample_indexes):
+        # Read in all the files that match the pattern
+        files = glob.glob(self.match_file_pattern)
+        results = {}
+        for file in files:
+            with open(file, "rb") as f:
+                (start, end), batch_results = pickle.load(f)
+            for j, result in zip(range(start, end), batch_results):
+                if j in results:
+                    raise ValueError(f"Duplicate sample index {j} found in {file}")
+                results[j] = result
+                self.match_progress.update()
+        sorted_results = []
+        for i, sample_index in enumerate(sample_indexes):
+            try:
+                result = results[i]
+            except KeyError:
+                raise ValueError(f"Sample index {i} not found in results files")
+            node_id = int(self.sample_id_map[sample_index])
+            if node_id != result.node:
+                raise ValueError(
+                    f"Sample {i} in results file has node {result.node} but was"
+                    f" expecting {node_id}"
                 )
+            sorted_results.append(result)
+        return sorted_results
+
+    def _match_samples(self, sample_indexes):
+        num_samples = len(sample_indexes)
+        builder = self.tree_sequence_builder
+        _, times = builder.dump_nodes()
+        logger.info(f"Started matching for {num_samples} samples")
+        if self.num_sites == 0:
+            return
+        self.match_progress = self.progress_monitor.get("ms_match", num_samples)
+        if self.match_file_pattern:
+            results = self.results_from_disk(sample_indexes)
+        else:
+            results = self._process_samples_in_batches(sample_indexes)
+        self.match_progress.close()
+        logger.info(
+            "Inserting sample paths: {} edges in total".format(
+                sum(len(r.path.left) for r in results)
            )
-        progress_monitor = self.progress_monitor.get("ms_paths", num_samples)
-        for j, result in zip(sample_indexes, results):
-            node_id = int(self.sample_id_map[j])
-            assert node_id == result.node
-            if np.any(times[node_id] > times[result.path.parent]):
-                p = result.path.parent[np.argmin(times[result.path.parent])]
-                raise ValueError(
-                    f"Failed to put sample {j} (node {node_id}) at time "
-                    f"{times[node_id]} as it has a younger parent (node {p})."
-                )
-            builder.add_path(
-                result.node,
-                result.path.left,
-                result.path.right,
-                result.path.parent,
-                compress=self.path_compression,
-            )
-            builder.add_mutations(
-                result.node,
-                result.mutations_site,
-                result.mutations_derived_state,
+        )
+        progress_monitor = self.progress_monitor.get("ms_paths", num_samples)
+        for j, result in zip(sample_indexes, results):
+            node_id = int(self.sample_id_map[j])
+            assert node_id == result.node
+            if np.any(times[node_id] > times[result.path.parent]):
+                p = result.path.parent[np.argmin(times[result.path.parent])]
+                raise ValueError(
+                    f"Failed to put sample {j} (node {node_id}) at time "
+                    f"{times[node_id]} as it has a younger parent (node {p})."
                )
-            progress_monitor.update()
-        progress_monitor.close()
+            builder.add_path(
+                result.node,
+                result.path.left,
+                result.path.right,
+                result.path.parent,
+                compress=self.path_compression,
+            )
+            builder.add_mutations(
+                result.node,
+                result.mutations_site,
+                result.mutations_derived_state,
+            )
+            progress_monitor.update()
+        progress_monitor.close()
 
     def match_samples(self, sample_indexes, sample_times=None):
         if sample_times is None:
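
Each slice file written by `find_path_to_disk` is a single pickle of the tuple `((start, end), results)`, which is exactly what `results_from_disk` loads back and what the tests unpack directly. A short sketch of inspecting one such file; the file name is a placeholder:

```python
import pickle

# One slice file, as written by match_samples_slice_to_disk.
with open("match-0.path", "rb") as f:  # hypothetical file name
    (start, end), match_results = pickle.load(f)

assert len(match_results) == end - start
for result in match_results:
    # Each MatchResult records the matched node plus the copying path and
    # mutations that _match_samples later replays into the tree sequence
    # builder via add_path() and add_mutations().
    print(result.node, result.path.parent, result.mutations_site)
```
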
""" - def update(self): + def update(self, n=None): pass def close(self):