Skip to content

Commit

Permalink
Switch to file based caching
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed May 15, 2024
1 parent 7b38085 commit 8729437
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 210 deletions.
47 changes: 21 additions & 26 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import json
import logging
import os.path
import pickle
import random
import re
import string
Expand All @@ -33,7 +34,6 @@
import unittest
import unittest.mock as mock

import lmdb
import msprime
import numpy as np
import pytest
Expand Down Expand Up @@ -1358,43 +1358,38 @@ def test_equivalance(self):
assert ts1.equals(ts2, ignore_provenance=True)


@pytest.mark.skipif(IS_WINDOWS, reason="Not enough disk space as no sparse files")
class TestResume:
def count_keys(self, lmdb_file):
    """Return the number of keys stored in the LMDB file at ``lmdb_file``.

    The file is opened as a single-file environment (``subdir=False``) and
    a cursor over the default database is walked once inside a transaction
    obtained from ``begin()``.

    :param lmdb_file: Path of the LMDB file to inspect.
    :return: Number of entries yielded by the cursor.
    """
    # Same map size as the original expression: 100 * 1024 * 1024 * 1024.
    map_size = 100 * 1024 * 1024 * 1024
    with lmdb.open(lmdb_file, subdir=False, map_size=map_size) as env:
        with env.begin() as txn:
            # Only the number of entries matters, not their contents.
            return sum(1 for _ in txn.cursor())
def count_paths(self, match_data_dir):
    """Return the total number of stored match results in ``match_data_dir``.

    Each regular file directly inside the directory is unpickled and the
    length of its ``results`` attribute is added to the running total.
    Entries that are not regular files (for example sub-directories) are
    skipped instead of being handed to ``open``, which would raise.

    :param match_data_dir: Directory containing the pickled match data.
    :return: Sum of ``len(stored_data.results)`` over the files found.
    """
    path_count = 0
    for filename in os.listdir(match_data_dir):
        file_path = os.path.join(match_data_dir, filename)
        if not os.path.isfile(file_path):
            continue
        # NOTE: pickle executes arbitrary code on load; this helper is only
        # meant for files the test run itself has just written.
        with open(file_path, "rb") as f:
            stored_data = pickle.load(f)
        path_count += len(stored_data.results)
    return path_count

def test_equivalance(self, tmpdir):
lmdb_file = str(tmpdir / "LMDB")
ts = msprime.simulate(5, mutation_rate=2, recombination_rate=2, random_seed=2)
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
ancestor_data = tsinfer.generate_ancestors(sample_data)
ancestor_ts1 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
sample_data, ancestor_data, match_data_dir=tmpdir
)
assert self.count_keys(lmdb_file) == 4
assert self.count_paths(tmpdir) == 4
ancestor_ts2 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
sample_data, ancestor_data, match_data_dir=tmpdir
)
ancestor_ts1.tables.assert_equals(ancestor_ts2.tables, ignore_provenance=True)
final_ts1 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts1, match_data_dir=tmpdir
)
assert self.count_keys(lmdb_file) == 5
assert self.count_paths(tmpdir) == 9
final_ts2 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts1, match_data_dir=tmpdir
)
final_ts1.tables.assert_equals(final_ts2.tables, ignore_provenance=True)

def test_cache_used_by_timing(self, tmpdir):
lmdb_file = str(tmpdir / "LMDB")

ts = msprime.sim_ancestry(
100, recombination_rate=1, sequence_length=1000, random_seed=42
)
Expand All @@ -1405,27 +1400,27 @@ def test_cache_used_by_timing(self, tmpdir):
ancestor_data = tsinfer.generate_ancestors(sample_data)
t = time.time()
ancestor_ts1 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
sample_data, ancestor_data, match_data_dir=tmpdir
)
time1 = time.time() - t
assert self.count_keys(lmdb_file) >= 103
assert self.count_paths(tmpdir) == 1001
t = time.time()
ancestor_ts2 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
sample_data, ancestor_data, match_data_dir=tmpdir
)
ancestor_ts1.tables.assert_equals(ancestor_ts2.tables, ignore_provenance=True)
time2 = time.time() - t
assert time2 < time1 / 2

t = time.time()
final_ts1 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts1, match_data_dir=tmpdir
)
time1 = time.time() - t
assert self.count_keys(lmdb_file) == 104
assert self.count_paths(tmpdir) == 1201
t = time.time()
final_ts2 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts1, match_data_dir=tmpdir
)
time2 = time.time() - t
assert time2 < time1 / 1.25
Expand Down
44 changes: 24 additions & 20 deletions tests/test_sgkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,14 @@ def test_match_samples_to_disk_write(self, slice, tmp_path, tmpdir):
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
tsinfer.match_samples_slice_to_disk(
samples, anc_ts, slice, tmpdir / "test.path"
samples, anc_ts, slice, tmpdir / "samples.pkl"
)
file_slice, matches = pickle.load(open(tmpdir / "test.path", "rb"))
assert slice == file_slice
assert len(matches) == slice[1] - slice[0]
for m in matches:
stored = pickle.load(open(tmpdir / "samples.pkl", "rb"))
assert stored.group_id == "samples"
assert stored.num_sites == 86 # Num inferred sites
assert len(stored.results) == slice[1] - slice[0]
for i, (s, m) in enumerate(stored.results.items()):
assert s == slice[0] + i
assert isinstance(m, tsinfer.inference.MatchResult)

def test_match_samples_to_disk_slice_error(self, tmp_path, tmpdir):
Expand All @@ -622,6 +624,8 @@ def test_match_samples_to_disk_slice_error(self, tmp_path, tmpdir):
)

def test_match_samples_to_disk_full(self, tmp_path, tmpdir):
match_data_dir = tmpdir / "match_data"
os.mkdir(match_data_dir)
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
ancestors = tsinfer.generate_ancestors(samples)
Expand All @@ -634,31 +638,31 @@ def test_match_samples_to_disk_full(self, tmp_path, tmpdir):
samples,
anc_ts,
(start_index, end_index),
tmpdir / f"test-{start_index}.path",
match_data_dir / f"test-{start_index}.pkl",
)
start_index = end_index
batch_ts = tsinfer.match_samples(
samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
samples, anc_ts, match_data_dir=str(match_data_dir)
)
ts.tables.assert_equals(batch_ts.tables, ignore_provenance=True)

tmpdir.join("test-6.path").copy(tmpdir.join("test-6-copy.path"))
(match_data_dir / "test-6.pkl").copy(match_data_dir / "test-6-copy.pkl")
with pytest.raises(ValueError, match="Duplicate sample index 6"):
tsinfer.match_samples(
samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
)
tsinfer.match_samples(samples, anc_ts, match_data_dir=str(match_data_dir))

os.remove(tmpdir / "test-6.path")
os.remove(tmpdir / "test-6-copy.path")
os.remove(match_data_dir / "test-6.pkl")
os.remove(match_data_dir / "test-6-copy.pkl")
with pytest.raises(ValueError, match="index 6 not found"):
tsinfer.match_samples(
samples, anc_ts, match_file_pattern=str(tmpdir / "*.path")
)
tsinfer.match_samples(samples, anc_ts, match_data_dir=str(match_data_dir))

def test_match_samples_to_disk_with_mask(self, tmp_path, tmpdir):
mat_sd, mask_sd, _, _ = tsutil.make_materialized_and_masked_sampledata(
tmp_path, tmpdir
)
mat_data_dir = tmpdir / "mat_data"
os.mkdir(mat_data_dir)
mask_data_dir = tmpdir / "mask_data"
os.mkdir(mask_data_dir)
mat_ancestors = tsinfer.generate_ancestors(mat_sd)
mask_ancestors = tsinfer.generate_ancestors(mask_sd)
mat_anc_ts = tsinfer.match_ancestors(mat_sd, mat_ancestors)
Expand All @@ -670,12 +674,12 @@ def test_match_samples_to_disk_with_mask(self, tmp_path, tmpdir):
mat_sd,
mat_anc_ts,
(start_index, end_index),
tmpdir / f"test-mat-{start_index}.path",
mat_data_dir / f"test-mat-{start_index}.path",
)
start_index = end_index

mat_ts_disk = tsinfer.match_samples(
mat_sd, mat_anc_ts, match_file_pattern=str(tmpdir / "test-mat-*.path")
mat_sd, mat_anc_ts, match_data_dir=str(mat_data_dir)
)

start_index = 0
Expand All @@ -685,11 +689,11 @@ def test_match_samples_to_disk_with_mask(self, tmp_path, tmpdir):
mask_sd,
mask_anc_ts,
(start_index, end_index),
tmpdir / f"test-mask-{start_index}.path",
mask_data_dir / f"test-mask-{start_index}.path",
)
start_index = end_index
mask_ts_disk = tsinfer.match_samples(
mask_sd, mask_anc_ts, match_file_pattern=str(tmpdir / "test-mask-*.path")
mask_sd, mask_anc_ts, match_data_dir=str(mask_data_dir)
)

mask_ts = tsinfer.match_samples(mask_sd, mask_anc_ts)
Expand Down
Loading

0 comments on commit 8729437

Please sign in to comment.