Skip to content

Commit

Permalink
Add match_samples_to_disk function
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery authored and mergify[bot] committed Nov 20, 2023
1 parent 93e386e commit f15abee
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 73 deletions.
57 changes: 57 additions & 0 deletions tests/test_sgkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
Tests for the data files.
"""
import json
import os
import pickle
import sys
import tempfile

Expand Down Expand Up @@ -615,3 +617,58 @@ def test_empty_alleles_not_at_end(self, tmp_path):
samples = tsinfer.SgkitSampleData(path)
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
tsinfer.infer(samples)


class TestSgkitMatchSamplesToDisk:
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
@pytest.mark.parametrize("slice", [(0, 5), (0, 0), (0, 1), (10, 15)])
def test_match_samples_to_disk_write(
self, slice, small_sd_fixture, tmp_path, tmpdir
):
ts, zarr_path = make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
tsinfer.match_samples_slice_to_disk(
samples, anc_ts, slice, tmpdir / "test.path"
)
file_slice, matches = pickle.load(open(tmpdir / "test.path", "rb"))
assert slice == file_slice
assert len(matches) == slice[1] - slice[0]
for m in matches:
assert isinstance(m, tsinfer.inference.MatchResult)

@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_match_samples_to_disk_full(self, small_sd_fixture, tmp_path, tmpdir):
ts, zarr_path = make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
ts = tsinfer.match_samples(samples, anc_ts)
start_index = 0
while start_index < ts.num_samples:
end_index = min(start_index + 5, ts.num_samples)
tsinfer.match_samples_slice_to_disk(
samples,
anc_ts,
(start_index, end_index),
tmpdir / f"test-{start_index}.path",
)
start_index = end_index
batch_ts = tsinfer.match_samples(
samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
)
ts.tables.assert_equals(batch_ts.tables, ignore_provenance=True)

tmpdir.join("test-5.path").copy(tmpdir.join("test-5-copy.path"))
with pytest.raises(ValueError, match="Duplicate sample index 5"):
tsinfer.match_samples(
samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
)

os.remove(tmpdir / "test-5.path")
os.remove(tmpdir / "test-5-copy.path")
with pytest.raises(ValueError, match="index 5 not found"):
tsinfer.match_samples(
samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
)
Loading

0 comments on commit f15abee

Please sign in to comment.