Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mask names #900

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 76 additions & 14 deletions tests/test_sgkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,25 +435,25 @@ class TestSgkitMask:
def test_sgkit_variant_mask(self, tmp_path, sites):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.zeros_like(ds["variant_position"], dtype=bool)
sites_mask = np.ones_like(ds["variant_position"], dtype=bool)
for i in sites:
sites_mask[i] = True
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
samples = tsinfer.SgkitSampleData(zarr_path)
sites_mask[i] = False
add_array_to_dataset("variant_mask_42", sites_mask, zarr_path)
samples = tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_42")
assert samples.num_sites == len(sites)
assert np.array_equal(samples.sites_mask, sites_mask)
assert np.array_equal(samples.sites_mask, ~sites_mask)
assert np.array_equal(
samples.sites_position, ts.tables.sites.position[sites_mask]
samples.sites_position, ts.tables.sites.position[~sites_mask]
)
inf_ts = tsinfer.infer(samples)
assert np.array_equal(
ts.genotype_matrix()[sites_mask], inf_ts.genotype_matrix()
ts.genotype_matrix()[~sites_mask], inf_ts.genotype_matrix()
)
assert np.array_equal(
ts.tables.sites.position[sites_mask], inf_ts.tables.sites.position
ts.tables.sites.position[~sites_mask], inf_ts.tables.sites.position
)
assert np.array_equal(
ts.tables.sites.ancestral_state[sites_mask],
ts.tables.sites.ancestral_state[~sites_mask],
inf_ts.tables.sites.ancestral_state,
)
# TODO - site metadata needs merging not replacing
Expand All @@ -464,26 +464,88 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
def test_sgkit_variant_bad_mask_length(self, tmp_path):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int)
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
add_array_to_dataset("variant_mask_foobar", sites_mask, zarr_path)
tsinfer.SgkitSampleData(zarr_path)
with pytest.raises(
ValueError,
match="Mask must be the same length as the number of unmasked sites",
):
tsinfer.SgkitSampleData(zarr_path)
tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_foobar")

def test_bad_mask_length_at_iterator(self, tmp_path):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int)
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
from tsinfer.formats import chunk_iterator

with pytest.raises(
ValueError, match="Mask must be the same length as the array"
):
for _ in chunk_iterator(ds.variant_position, mask=sites_mask):
for _ in chunk_iterator(ds.call_genotype, mask=sites_mask):
pass

@pytest.mark.parametrize("sample_list", [[1, 2, 3, 5, 9, 27], [0], []])
def test_sgkit_sample_mask(self, tmp_path, sample_list):
    """
    Check that a named samples mask stored in the dataset restricts the
    individuals, samples and genotypes reported by SgkitSampleData to the
    entries listed in ``sample_list``.
    """
    ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True)
    ds = sgkit.load_dataset(zarr_path)
    # The stored mask uses True for "masked out": start with everything
    # masked and clear the entries we want to keep.
    samples_mask = np.ones_like(ds["sample_id"], dtype=bool)
    for kept_index in sample_list:
        samples_mask[kept_index] = False
    add_array_to_dataset("samples_mask_69", samples_mask, zarr_path)
    samples = tsinfer.SgkitSampleData(
        zarr_path, sgkit_samples_mask_name="samples_mask_69"
    )
    num_kept = len(sample_list)
    assert samples.ploidy == 3
    assert samples.num_individuals == num_kept
    assert samples.num_samples == num_kept * samples.ploidy
    assert np.array_equal(samples.individuals_mask, ~samples_mask)
    assert np.array_equal(samples.samples_mask, np.repeat(~samples_mask, 3))
    # Every per-individual array must be the dataset array restricted to the
    # kept individuals.
    for field in (
        "individuals_time",
        "individuals_location",
        "individuals_population",
        "individuals_flags",
    ):
        assert np.array_equal(
            getattr(samples, field), getattr(ds, field).values[~samples_mask]
        )
    assert np.array_equal(
        samples.samples_individual, np.repeat(np.arange(num_kept), 3)
    )
    kept_calls = ds.call_genotype.values[:, ~samples_mask, :]
    expected_gt = kept_calls.reshape(samples.num_sites, num_kept * 3)
    assert np.array_equal(samples.sites_genotypes, expected_gt)
    for variant, site_genotypes in zip(samples.variants(), expected_gt):
        assert np.array_equal(variant.genotypes, site_genotypes)

    for column, (sample_index, haplotype) in enumerate(samples.haplotypes()):
        assert sample_index == column
        assert np.array_equal(haplotype, expected_gt[:, column])

def test_sgkit_missing_masks(self, tmp_path):
    """
    Check that the mask properties are readable when no mask names are
    given, and that naming a mask that is absent from the dataset raises
    ValueError.
    """
    ts, zarr_path = make_ts_and_zarr(tmp_path)
    # No mask names supplied: reading either mask must not raise.
    unmasked = tsinfer.SgkitSampleData(zarr_path)
    unmasked.individuals_mask
    unmasked.sites_mask

    sites_error = "The sites mask foobar was not found in the dataset."
    with pytest.raises(ValueError, match=sites_error):
        tsinfer.SgkitSampleData(zarr_path, sites_mask_name="foobar")

    samples_error = "The sgkit samples mask foobar2 was not found in the dataset."
    with pytest.raises(ValueError, match=samples_error):
        # The property read is kept inside the block so the error may come
        # from either the constructor or the first access to the mask.
        samples = tsinfer.SgkitSampleData(
            zarr_path, sgkit_samples_mask_name="foobar2"
        )
        samples.individuals_mask


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
def test_sgkit_ancestral_allele_same_ancestors(tmp_path):
Expand Down
107 changes: 73 additions & 34 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def zarr_summary(array):
return ret


def chunk_iterator(array, indexes=None, mask=None, dimension=0):
def chunk_iterator(array, indexes=None, mask=None, orthogonal_mask=None, dimension=0):
"""
Utility to iterate over closely spaced rows in the specified array efficiently
by accessing one chunk at a time (normally used as an iterator over each row)
Expand All @@ -314,6 +314,8 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0):
assert dimension < 2
if mask is None:
mask = np.ones(array.shape[dimension], dtype=bool)
if orthogonal_mask is None:
orthogonal_mask = np.ones(array.shape[int(not dimension)], dtype=bool)
if len(mask) != array.shape[dimension]:
raise ValueError("Mask must be the same length as the array")

Expand All @@ -339,14 +341,14 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0):
if chunk_id != prev_chunk_id:
chunk = array[chunk_id * chunk_size : (chunk_id + 1) * chunk_size][:]
prev_chunk_id = chunk_id
yield chunk[j % chunk_size]
yield chunk[j % chunk_size, orthogonal_mask]
elif dimension == 1:
for j in indexes:
chunk_id = j // chunk_size
if chunk_id != prev_chunk_id:
chunk = array[:, chunk_id * chunk_size : (chunk_id + 1) * chunk_size][:]
prev_chunk_id = chunk_id
yield chunk[:, j % chunk_size]
yield chunk[orthogonal_mask, j % chunk_size]


def merge_variants(sd1, sd2):
Expand Down Expand Up @@ -2293,13 +2295,15 @@ class SgkitSampleData(SampleData):
FORMAT_NAME = "tsinfer-sgkit-sample-data"
FORMAT_VERSION = (0, 1)

def __init__(self, path):
def __init__(self, path, sgkit_samples_mask_name=None, sites_mask_name=None):
self.path = path
self.data = zarr.open(path, mode="r")
self._sgkit_samples_mask_name = sgkit_samples_mask_name
self._sites_mask_name = sites_mask_name
genotypes_arr = self.data["call_genotype"]
_, self._num_individuals, self.ploidy = genotypes_arr.shape
_, self._num_unmasked_individuals, self.ploidy = genotypes_arr.shape
self._num_sites = np.sum(self.sites_mask)
self._num_samples = self._num_individuals * self.ploidy
self._num_unmasked_samples = self._num_unmasked_individuals * self.ploidy

assert self.ploidy == self.data["call_genotype"].chunks[2]
if self.ploidy > 1:
Expand Down Expand Up @@ -2333,6 +2337,26 @@ def sequence_length(self):
def num_sites(self):
return self._num_sites

@functools.cached_property
def individuals_mask(self):
    """
    Boolean array with one entry per individual in the dataset, where True
    means the individual is kept.

    If no samples mask name was supplied, every individual is kept.
    Otherwise the named array is read from the dataset and negated, because
    the stored mask uses True to mean "masked out".

    :raises ValueError: If the named mask is not present in the dataset, or
        if its length differs from the number of individuals in the dataset.
    """
    num_individuals = self._num_unmasked_individuals
    mask_name = self._sgkit_samples_mask_name
    if mask_name is None:
        return np.full(num_individuals, True, dtype=bool)
    # Only the lookup sits inside the try, so a KeyError raised anywhere else
    # is never misreported as a missing mask.
    try:
        stored_mask = self.data[mask_name]
    except KeyError as err:
        raise ValueError(
            f"The sgkit samples mask {mask_name} was not found in the dataset."
        ) from err
    # Reject a wrong-length mask up front: it would otherwise yield a wrong
    # individual count or a late indexing error.
    if stored_mask.shape[0] != num_individuals:
        raise ValueError(
            "Mask must be the same length as the number of unmasked individuals"
        )
    # We negate the mask as it is much easier in numpy to have True=keep
    return ~(stored_mask[:].astype(bool))

@functools.cached_property
def samples_mask(self):
    """
    Per-sample form of ``individuals_mask``: each individual's entry is
    repeated ``ploidy`` times, since what sgkit calls a sample is a tskit
    individual carrying ``ploidy`` tskit samples.
    """
    per_individual = self.individuals_mask
    copies_per_individual = self.ploidy
    return np.repeat(per_individual, repeats=copies_per_individual)

@functools.cached_property
def sites_metadata_schema(self):
try:
Expand Down Expand Up @@ -2370,19 +2394,25 @@ def sites_alleles(self):

@functools.cached_property
def sites_mask(self):
try:
if (
self.data["variant_mask"].shape[0]
!= self.data["variant_position"].shape[0]
):
if self._sites_mask_name is None:
return np.full(self.data["variant_position"].shape, True, dtype=bool)
else:
try:
if (
self.data[self._sites_mask_name].shape[0]
!= self.data["variant_position"].shape[0]
):
raise ValueError(
"Mask must be the same length as the number of unmasked sites"
)
# We negate the mask as it is much easier in numpy to have True=keep
return ~(self.data[self._sites_mask_name].astype(bool)[:])
except KeyError:
raise ValueError(
"Mask must be the same length as the number of unmasked sites"
f"The sites mask {self._sites_mask_name} was not found"
f" in the dataset."
)

return self.data["variant_mask"].astype(bool)
except KeyError:
return np.full(self.data["variant_position"].shape, True, dtype=bool)

@functools.cached_property
def sites_ancestral_allele(self):
unknown_alleles = collections.Counter()
Expand Down Expand Up @@ -2427,9 +2457,9 @@ def sites_genotypes(self):
gt = self.data["call_genotype"]
# This method is only used for test/debug so we retrieve and
# reshape the entire array.
return gt[...][self.sites_mask, :, :].reshape(
gt.shape[0], gt.shape[1] * gt.shape[2]
)
ret = gt[...][self.sites_mask, :, :]
ret = ret[:, self.individuals_mask, :]
return ret.reshape(ret.shape[0], ret.shape[1] * ret.shape[2])

@functools.cached_property
def provenances_timestamp(self):
Expand All @@ -2445,9 +2475,9 @@ def provenances_record(self):
except KeyError:
return np.array([], dtype=object)

@property
@functools.cached_property
def num_samples(self):
return self._num_samples
return np.sum(self.samples_mask)

@functools.cached_property
def samples_individual(self):
Expand Down Expand Up @@ -2500,12 +2530,12 @@ def populations_metadata_schema(self):

@property
def num_individuals(self):
return self._num_individuals
return np.sum(self.individuals_mask)

@functools.cached_property
def individuals_time(self):
try:
return self.data["individuals_time"]
return self.data["individuals_time"][:][self.individuals_mask]
except KeyError:
return np.full(self.num_individuals, tskit.UNKNOWN_TIME)

Expand All @@ -2524,11 +2554,14 @@ def individuals_metadata(self):
# We set the sample_id in the individual metadata as this is often useful,
# however we silently don't overwrite if the key exists
if "individuals_metadata" in self.data:
assert len(self.data["individuals_metadata"]) == self.num_individuals
assert self.num_individuals == len(self.data["sample_id"])
assert (
len(self.data["individuals_metadata"]) == self._num_unmasked_individuals
)
assert self._num_unmasked_individuals == len(self.data["sample_id"])
md_list = []
for sample_id, r in zip(
self.data["sample_id"], self.data["individuals_metadata"][:]
self.data["sample_id"][:][self.individuals_mask],
self.data["individuals_metadata"][:][self.individuals_mask],
):
md = schema.decode_row(r)
if "sgkit_sample_id" not in md:
Expand All @@ -2537,27 +2570,28 @@ def individuals_metadata(self):
return md_list
else:
return [
{"sgkit_sample_id": sample_id} for sample_id in self.data["sample_id"]
{"sgkit_sample_id": sample_id}
for sample_id in self.data["sample_id"][:][self.individuals_mask]
]

@functools.cached_property
def individuals_location(self):
try:
return self.data["individuals_location"]
return self.data["individuals_location"][:][self.individuals_mask]
except KeyError:
return np.array([[]] * self.num_individuals, dtype=float)

@functools.cached_property
def individuals_population(self):
try:
return self.data["individuals_population"]
return self.data["individuals_population"][:][self.individuals_mask]
except KeyError:
return np.full((self.num_individuals), tskit.NULL, dtype=np.int32)

@functools.cached_property
def individuals_flags(self):
try:
return self.data["individuals_flags"]
return self.data["individuals_flags"][:][self.individuals_mask]
except KeyError:
return np.full((self.num_individuals), 0, dtype=np.int32)

Expand Down Expand Up @@ -2585,7 +2619,10 @@ def variants(self, sites=None, recode_ancestral=None):
if recode_ancestral is None:
recode_ancestral = False
all_genotypes = chunk_iterator(
self.data["call_genotype"], indexes=sites, mask=self.sites_mask
self.data["call_genotype"],
indexes=sites,
mask=self.sites_mask,
orthogonal_mask=self.individuals_mask,
)
assert MISSING_DATA < 0 # required for geno_map to remap MISSING_DATA
for genos, site in zip(all_genotypes, self.sites(ids=sites)):
Expand Down Expand Up @@ -2627,9 +2664,11 @@ def _all_haplotypes(self, sites=None, recode_ancestral=None):
aa_index[aa_index == MISSING_DATA] = 0
gt = self.data["call_genotype"]
chunk_size = gt.chunks[1]
for j in range(self.num_individuals):
if j % chunk_size == 0:
chunk = gt[:, j : j + chunk_size, :]
current_chunk = None
for j in np.where(self.individuals_mask)[0]:
if j // chunk_size != current_chunk:
current_chunk = j // chunk_size
chunk = gt[:, j // chunk_size : (j // chunk_size) + chunk_size, :]
# Zarr doesn't support fancy indexing, so we have to do this after
chunk = chunk[self.sites_mask]
indiv_gt = chunk[:, j % chunk_size, :]
Expand Down
1 change: 1 addition & 0 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,7 @@ def group_by_linesweep(self):
epoch_end = np.hstack([breaks + 1, [self.num_ancestors]])
time_slices = np.vstack([epoch_start, epoch_end]).T
epoch_sizes = time_slices[:, 1] - time_slices[:, 0]

median_size = np.median(epoch_sizes)
cutoff = 500 * median_size
# Zero out the first half so that an initial large epoch doesn't
Expand Down
Loading