From e529ac01ed52ec61e92dbfa235ddb668875a7e9f Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 26 Jan 2024 12:24:56 +0000 Subject: [PATCH 1/3] Add sample mask --- tests/test_sgkit.py | 43 ++++++++++++++++++++++++++- tsinfer/formats.py | 70 +++++++++++++++++++++++++++++--------------- tsinfer/inference.py | 1 + 3 files changed, 90 insertions(+), 24 deletions(-) diff --git a/tests/test_sgkit.py b/tests/test_sgkit.py index b930263a..9a0bced4 100644 --- a/tests/test_sgkit.py +++ b/tests/test_sgkit.py @@ -481,9 +481,50 @@ def test_bad_mask_length_at_iterator(self, tmp_path): with pytest.raises( ValueError, match="Mask must be the same length as the array" ): - for _ in chunk_iterator(ds.variant_position, mask=sites_mask): + for _ in chunk_iterator(ds.call_genotype, mask=sites_mask): pass + @pytest.mark.parametrize("sample_list", [[1, 2, 3, 5, 9, 27], [0], []]) + def test_sgkit_sample_mask(self, tmp_path, sample_list): + ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True) + ds = sgkit.load_dataset(zarr_path) + samples_mask = np.zeros_like(ds["sample_id"], dtype=bool) + for i in sample_list: + samples_mask[i] = True + add_array_to_dataset("samples_mask", samples_mask, zarr_path) + samples = tsinfer.SgkitSampleData(zarr_path) + assert samples.ploidy == 3 + assert samples.num_individuals == len(sample_list) + assert samples.num_samples == len(sample_list) * samples.ploidy + assert np.array_equal(samples.individuals_mask, samples_mask) + assert np.array_equal(samples.samples_mask, np.repeat(samples_mask, 3)) + assert np.array_equal( + samples.individuals_time, ds.individuals_time.values[samples_mask] + ) + assert np.array_equal( + samples.individuals_location, ds.individuals_location.values[samples_mask] + ) + assert np.array_equal( + samples.individuals_population, + ds.individuals_population.values[samples_mask], + ) + assert np.array_equal( + samples.individuals_flags, ds.individuals_flags.values[samples_mask] + ) + assert np.array_equal( + 
samples.samples_individual, np.repeat(np.arange(len(sample_list)), 3) + ) + expected_gt = ds.call_genotype.values[:, samples_mask, :].reshape( + samples.num_sites, len(sample_list) * 3 + ) + assert np.array_equal(samples.sites_genotypes, expected_gt) + for v, gt in zip(samples.variants(), expected_gt): + assert np.array_equal(v.genotypes, gt) + + for i, (id, haplo) in enumerate(samples.haplotypes()): + assert id == i + assert np.array_equal(haplo, expected_gt[:, i]) + @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") def test_sgkit_ancestral_allele_same_ancestors(tmp_path): diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 5ae575b7..72f5d3c4 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -305,7 +305,7 @@ def zarr_summary(array): return ret -def chunk_iterator(array, indexes=None, mask=None, dimension=0): +def chunk_iterator(array, indexes=None, mask=None, orthogonal_mask=None, dimension=0): """ Utility to iterate over closely spaced rows in the specified array efficiently by accessing one chunk at a time (normally used as an iterator over each row) @@ -314,6 +314,8 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0): assert dimension < 2 if mask is None: mask = np.ones(array.shape[dimension], dtype=bool) + if orthogonal_mask is None: + orthogonal_mask = np.ones(array.shape[int(not dimension)], dtype=bool) if len(mask) != array.shape[dimension]: raise ValueError("Mask must be the same length as the array") @@ -339,14 +341,14 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0): if chunk_id != prev_chunk_id: chunk = array[chunk_id * chunk_size : (chunk_id + 1) * chunk_size][:] prev_chunk_id = chunk_id - yield chunk[j % chunk_size] + yield chunk[j % chunk_size, orthogonal_mask] elif dimension == 1: for j in indexes: chunk_id = j // chunk_size if chunk_id != prev_chunk_id: chunk = array[:, chunk_id * chunk_size : (chunk_id + 1) * chunk_size][:] prev_chunk_id = chunk_id - yield chunk[:, j % 
chunk_size] + yield chunk[orthogonal_mask, j % chunk_size] def merge_variants(sd1, sd2): @@ -2297,9 +2299,9 @@ def __init__(self, path): self.path = path self.data = zarr.open(path, mode="r") genotypes_arr = self.data["call_genotype"] - _, self._num_individuals, self.ploidy = genotypes_arr.shape + _, self._num_unmasked_individuals, self.ploidy = genotypes_arr.shape self._num_sites = np.sum(self.sites_mask) - self._num_samples = self._num_individuals * self.ploidy + self._num_unmasked_samples = self._num_unmasked_individuals * self.ploidy assert self.ploidy == self.data["call_genotype"].chunks[2] if self.ploidy > 1: @@ -2333,6 +2335,19 @@ def sequence_length(self): def num_sites(self): return self._num_sites + @functools.cached_property + def individuals_mask(self): + try: + return self.data["samples_mask"][:].astype(bool) + except KeyError: + return np.full(self._num_unmasked_individuals, True, dtype=bool) + + @functools.cached_property + def samples_mask(self): + # Samples in sgkit are individuals in tskit, so we need to expand + # the mask to cover all the samples for each individual. + return np.repeat(self.individuals_mask, self.ploidy) + @functools.cached_property def sites_metadata_schema(self): try: @@ -2427,9 +2442,9 @@ def sites_genotypes(self): gt = self.data["call_genotype"] # This method is only used for test/debug so we retrieve and # reshape the entire array. 
- return gt[...][self.sites_mask, :, :].reshape( - gt.shape[0], gt.shape[1] * gt.shape[2] - ) + ret = gt[...][self.sites_mask, :, :] + ret = ret[:, self.individuals_mask, :] + return ret.reshape(ret.shape[0], ret.shape[1] * ret.shape[2]) @functools.cached_property def provenances_timestamp(self): @@ -2445,9 +2460,9 @@ def provenances_record(self): except KeyError: return np.array([], dtype=object) - @property + @functools.cached_property def num_samples(self): - return self._num_samples + return np.sum(self.samples_mask) @functools.cached_property def samples_individual(self): @@ -2500,12 +2515,12 @@ def populations_metadata_schema(self): @property def num_individuals(self): - return self._num_individuals + return np.sum(self.individuals_mask) @functools.cached_property def individuals_time(self): try: - return self.data["individuals_time"] + return self.data["individuals_time"][:][self.individuals_mask] except KeyError: return np.full(self.num_individuals, tskit.UNKNOWN_TIME) @@ -2524,11 +2539,14 @@ def individuals_metadata(self): # We set the sample_id in the individual metadata as this is often useful, # however we silently don't overwrite if the key exists if "individuals_metadata" in self.data: - assert len(self.data["individuals_metadata"]) == self.num_individuals - assert self.num_individuals == len(self.data["sample_id"]) + assert ( + len(self.data["individuals_metadata"]) == self._num_unmasked_individuals + ) + assert self._num_unmasked_individuals == len(self.data["sample_id"]) md_list = [] for sample_id, r in zip( - self.data["sample_id"], self.data["individuals_metadata"][:] + self.data["sample_id"][:][self.individuals_mask], + self.data["individuals_metadata"][:][self.individuals_mask], ): md = schema.decode_row(r) if "sgkit_sample_id" not in md: @@ -2537,27 +2555,28 @@ def individuals_metadata(self): return md_list else: return [ - {"sgkit_sample_id": sample_id} for sample_id in self.data["sample_id"] + {"sgkit_sample_id": sample_id} + for sample_id 
in self.data["sample_id"][:][self.individuals_mask] ] @functools.cached_property def individuals_location(self): try: - return self.data["individuals_location"] + return self.data["individuals_location"][:][self.individuals_mask] except KeyError: return np.array([[]] * self.num_individuals, dtype=float) @functools.cached_property def individuals_population(self): try: - return self.data["individuals_population"] + return self.data["individuals_population"][:][self.individuals_mask] except KeyError: return np.full((self.num_individuals), tskit.NULL, dtype=np.int32) @functools.cached_property def individuals_flags(self): try: - return self.data["individuals_flags"] + return self.data["individuals_flags"][:][self.individuals_mask] except KeyError: return np.full((self.num_individuals), 0, dtype=np.int32) @@ -2585,7 +2604,10 @@ def variants(self, sites=None, recode_ancestral=None): if recode_ancestral is None: recode_ancestral = False all_genotypes = chunk_iterator( - self.data["call_genotype"], indexes=sites, mask=self.sites_mask + self.data["call_genotype"], + indexes=sites, + mask=self.sites_mask, + orthogonal_mask=self.individuals_mask, ) assert MISSING_DATA < 0 # required for geno_map to remap MISSING_DATA for genos, site in zip(all_genotypes, self.sites(ids=sites)): @@ -2627,9 +2649,11 @@ def _all_haplotypes(self, sites=None, recode_ancestral=None): aa_index[aa_index == MISSING_DATA] = 0 gt = self.data["call_genotype"] chunk_size = gt.chunks[1] - for j in range(self.num_individuals): - if j % chunk_size == 0: - chunk = gt[:, j : j + chunk_size, :] + current_chunk = None + for j in np.where(self.individuals_mask)[0]: + if j // chunk_size != current_chunk: + current_chunk = j // chunk_size + chunk = gt[:, current_chunk * chunk_size : (current_chunk + 1) * chunk_size, :] # Zarr doesn't support fancy indexing, so we have to do this after chunk = chunk[self.sites_mask] indiv_gt = chunk[:, j % chunk_size, :] diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 
d053053c..f4b365af 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1644,6 +1644,7 @@ def group_by_linesweep(self): epoch_end = np.hstack([breaks + 1, [self.num_ancestors]]) time_slices = np.vstack([epoch_start, epoch_end]).T epoch_sizes = time_slices[:, 1] - time_slices[:, 0] + median_size = np.median(epoch_sizes) cutoff = 500 * median_size # Zero out the first half so that an initial large epoch doesn't From 444e4fa6f39a7ee8bebd18d9d80a05efa9703ce2 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 6 Feb 2024 10:43:47 +0000 Subject: [PATCH 2/3] Flip sgkit mask polarity --- tests/test_sgkit.py | 36 ++++++++++++++++++------------------ tsinfer/formats.py | 7 ++++--- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/test_sgkit.py b/tests/test_sgkit.py index 9a0bced4..07e5757e 100644 --- a/tests/test_sgkit.py +++ b/tests/test_sgkit.py @@ -435,25 +435,25 @@ class TestSgkitMask: def test_sgkit_variant_mask(self, tmp_path, sites): ts, zarr_path = make_ts_and_zarr(tmp_path) ds = sgkit.load_dataset(zarr_path) - sites_mask = np.zeros_like(ds["variant_position"], dtype=bool) + sites_mask = np.ones_like(ds["variant_position"], dtype=bool) for i in sites: - sites_mask[i] = True + sites_mask[i] = False add_array_to_dataset("variant_mask", sites_mask, zarr_path) samples = tsinfer.SgkitSampleData(zarr_path) assert samples.num_sites == len(sites) - assert np.array_equal(samples.sites_mask, sites_mask) + assert np.array_equal(samples.sites_mask, ~sites_mask) assert np.array_equal( - samples.sites_position, ts.tables.sites.position[sites_mask] + samples.sites_position, ts.tables.sites.position[~sites_mask] ) inf_ts = tsinfer.infer(samples) assert np.array_equal( - ts.genotype_matrix()[sites_mask], inf_ts.genotype_matrix() + ts.genotype_matrix()[~sites_mask], inf_ts.genotype_matrix() ) assert np.array_equal( - ts.tables.sites.position[sites_mask], inf_ts.tables.sites.position + ts.tables.sites.position[~sites_mask], inf_ts.tables.sites.position 
) assert np.array_equal( - ts.tables.sites.ancestral_state[sites_mask], + ts.tables.sites.ancestral_state[~sites_mask], inf_ts.tables.sites.ancestral_state, ) # TODO - site metadata needs merging not replacing @@ -464,7 +464,7 @@ def test_sgkit_variant_mask(self, tmp_path, sites): def test_sgkit_variant_bad_mask_length(self, tmp_path): ts, zarr_path = make_ts_and_zarr(tmp_path) ds = sgkit.load_dataset(zarr_path) - sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int) + sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int) add_array_to_dataset("variant_mask", sites_mask, zarr_path) with pytest.raises( ValueError, @@ -475,7 +475,7 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path): def test_bad_mask_length_at_iterator(self, tmp_path): ts, zarr_path = make_ts_and_zarr(tmp_path) ds = sgkit.load_dataset(zarr_path) - sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int) + sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int) from tsinfer.formats import chunk_iterator with pytest.raises( @@ -488,33 +488,33 @@ def test_bad_mask_length_at_iterator(self, tmp_path): def test_sgkit_sample_mask(self, tmp_path, sample_list): ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True) ds = sgkit.load_dataset(zarr_path) - samples_mask = np.zeros_like(ds["sample_id"], dtype=bool) + samples_mask = np.ones_like(ds["sample_id"], dtype=bool) for i in sample_list: - samples_mask[i] = True + samples_mask[i] = False add_array_to_dataset("samples_mask", samples_mask, zarr_path) samples = tsinfer.SgkitSampleData(zarr_path) assert samples.ploidy == 3 assert samples.num_individuals == len(sample_list) assert samples.num_samples == len(sample_list) * samples.ploidy - assert np.array_equal(samples.individuals_mask, samples_mask) - assert np.array_equal(samples.samples_mask, np.repeat(samples_mask, 3)) + assert np.array_equal(samples.individuals_mask, ~samples_mask) + assert np.array_equal(samples.samples_mask, np.repeat(~samples_mask, 3)) assert np.array_equal( - 
samples.individuals_time, ds.individuals_time.values[samples_mask] + samples.individuals_time, ds.individuals_time.values[~samples_mask] ) assert np.array_equal( - samples.individuals_location, ds.individuals_location.values[samples_mask] + samples.individuals_location, ds.individuals_location.values[~samples_mask] ) assert np.array_equal( samples.individuals_population, - ds.individuals_population.values[samples_mask], + ds.individuals_population.values[~samples_mask], ) assert np.array_equal( - samples.individuals_flags, ds.individuals_flags.values[samples_mask] + samples.individuals_flags, ds.individuals_flags.values[~samples_mask] ) assert np.array_equal( samples.samples_individual, np.repeat(np.arange(len(sample_list)), 3) ) - expected_gt = ds.call_genotype.values[:, samples_mask, :].reshape( + expected_gt = ds.call_genotype.values[:, ~samples_mask, :].reshape( samples.num_sites, len(sample_list) * 3 ) assert np.array_equal(samples.sites_genotypes, expected_gt) diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 72f5d3c4..7bd8a7c2 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2338,7 +2338,8 @@ def num_sites(self): @functools.cached_property def individuals_mask(self): try: - return self.data["samples_mask"][:].astype(bool) + # We negate the mask as it is much easier in numpy to have True=keep + return ~(self.data["samples_mask"][:].astype(bool)) except KeyError: return np.full(self._num_unmasked_individuals, True, dtype=bool) @@ -2393,8 +2394,8 @@ def sites_mask(self): raise ValueError( "Mask must be the same length as the number of unmasked sites" ) - - return self.data["variant_mask"].astype(bool) + # We negate the mask as it is much easier in numpy to have True=keep + return ~(self.data["variant_mask"].astype(bool)[:]) except KeyError: return np.full(self.data["variant_position"].shape, True, dtype=bool) From 15267b7d83f7c80e302108c7bf08e6cd5f3b735e Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 7 Feb 2024 14:46:02 +0000 
Subject: [PATCH 3/3] Add mask names --- tests/test_sgkit.py | 33 +++++++++++++++++++++++++++------ tsinfer/formats.py | 44 +++++++++++++++++++++++++++++--------------- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/tests/test_sgkit.py b/tests/test_sgkit.py index 07e5757e..d3c842d3 100644 --- a/tests/test_sgkit.py +++ b/tests/test_sgkit.py @@ -438,8 +438,8 @@ def test_sgkit_variant_mask(self, tmp_path, sites): sites_mask = np.ones_like(ds["variant_position"], dtype=bool) for i in sites: sites_mask[i] = False - add_array_to_dataset("variant_mask", sites_mask, zarr_path) - samples = tsinfer.SgkitSampleData(zarr_path) + add_array_to_dataset("variant_mask_42", sites_mask, zarr_path) + samples = tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_42") assert samples.num_sites == len(sites) assert np.array_equal(samples.sites_mask, ~sites_mask) assert np.array_equal( @@ -465,12 +465,13 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path): ts, zarr_path = make_ts_and_zarr(tmp_path) ds = sgkit.load_dataset(zarr_path) sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int) - add_array_to_dataset("variant_mask", sites_mask, zarr_path) + add_array_to_dataset("variant_mask_foobar", sites_mask, zarr_path) + tsinfer.SgkitSampleData(zarr_path) with pytest.raises( ValueError, match="Mask must be the same length as the number of unmasked sites", ): - tsinfer.SgkitSampleData(zarr_path) + tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_foobar") def test_bad_mask_length_at_iterator(self, tmp_path): ts, zarr_path = make_ts_and_zarr(tmp_path) @@ -491,8 +492,10 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list): samples_mask = np.ones_like(ds["sample_id"], dtype=bool) for i in sample_list: samples_mask[i] = False - add_array_to_dataset("samples_mask", samples_mask, zarr_path) - samples = tsinfer.SgkitSampleData(zarr_path) + add_array_to_dataset("samples_mask_69", samples_mask, zarr_path) + samples = tsinfer.SgkitSampleData( + 
zarr_path, sgkit_samples_mask_name="samples_mask_69" + ) assert samples.ploidy == 3 assert samples.num_individuals == len(sample_list) assert samples.num_samples == len(sample_list) * samples.ploidy @@ -525,6 +528,24 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list): assert id == i assert np.array_equal(haplo, expected_gt[:, i]) + def test_sgkit_missing_masks(self, tmp_path): + ts, zarr_path = make_ts_and_zarr(tmp_path) + samples = tsinfer.SgkitSampleData(zarr_path) + samples.individuals_mask + samples.sites_mask + with pytest.raises( + ValueError, match="The sites mask foobar was not found in the dataset." + ): + tsinfer.SgkitSampleData(zarr_path, sites_mask_name="foobar") + with pytest.raises( + ValueError, + match="The sgkit samples mask foobar2 was not found in the dataset.", + ): + samples = tsinfer.SgkitSampleData( + zarr_path, sgkit_samples_mask_name="foobar2" + ) + samples.individuals_mask + @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") def test_sgkit_ancestral_allele_same_ancestors(tmp_path): diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 7bd8a7c2..54250b94 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2295,9 +2295,11 @@ class SgkitSampleData(SampleData): FORMAT_NAME = "tsinfer-sgkit-sample-data" FORMAT_VERSION = (0, 1) - def __init__(self, path): + def __init__(self, path, sgkit_samples_mask_name=None, sites_mask_name=None): self.path = path self.data = zarr.open(path, mode="r") + self._sgkit_samples_mask_name = sgkit_samples_mask_name + self._sites_mask_name = sites_mask_name genotypes_arr = self.data["call_genotype"] _, self._num_unmasked_individuals, self.ploidy = genotypes_arr.shape self._num_sites = np.sum(self.sites_mask) @@ -2337,11 +2339,17 @@ def num_sites(self): @functools.cached_property def individuals_mask(self): - try: - # We negate the mask as it is much easier in numpy to have True=keep - return ~(self.data["samples_mask"][:].astype(bool)) - except KeyError: + if 
self._sgkit_samples_mask_name is None: return np.full(self._num_unmasked_individuals, True, dtype=bool) + else: + try: + # We negate the mask as it is much easier in numpy to have True=keep + return ~(self.data[self._sgkit_samples_mask_name][:].astype(bool)) + except KeyError: + raise ValueError( + f"The sgkit samples mask {self._sgkit_samples_mask_name} was not" + f" found in the dataset." + ) @functools.cached_property def samples_mask(self): @@ -2386,18 +2394,24 @@ def sites_alleles(self): @functools.cached_property def sites_mask(self): - try: - if ( - self.data["variant_mask"].shape[0] - != self.data["variant_position"].shape[0] - ): + if self._sites_mask_name is None: + return np.full(self.data["variant_position"].shape, True, dtype=bool) + else: + try: + if ( + self.data[self._sites_mask_name].shape[0] + != self.data["variant_position"].shape[0] + ): + raise ValueError( + "Mask must be the same length as the number of unmasked sites" + ) + # We negate the mask as it is much easier in numpy to have True=keep + return ~(self.data[self._sites_mask_name].astype(bool)[:]) + except KeyError: raise ValueError( - "Mask must be the same length as the number of unmasked sites" + f"The sites mask {self._sites_mask_name} was not found" + f" in the dataset." ) - # We negate the mask as it is much easier in numpy to have True=keep - return ~(self.data["variant_mask"].astype(bool)[:]) - except KeyError: - return np.full(self.data["variant_position"].shape, True, dtype=bool) @functools.cached_property def sites_ancestral_allele(self):