Skip to content

Commit

Permalink
Use blocks to iterate haplotypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Jeffery authored and benjeffery committed Apr 30, 2024
1 parent 2b37081 commit 5938d9a
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 34 deletions.
5 changes: 4 additions & 1 deletion tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def make_ts_and_zarr(path, add_optional=False, shuffle_alleles=True):
sgkit.io.vcf.vcf_to_zarr(
path / "data.vcf",
path / "data.zarr",
chunk_length=10,
chunk_width=12,
ploidy=3,
max_alt_alleles=4, # tests tsinfer's ability to handle empty string alleles
)
Expand Down Expand Up @@ -442,7 +444,8 @@ def make_materialized_and_masked_sampledata(tmp_path, tmpdir):

# Create a new sgkit dataset with the subset baked in
mat_ds = ds.isel(variants=~variant_mask, samples=~samples_mask)
sgkit.save_dataset(mat_ds, tmpdir / "subset.zarr")
mat_ds = mat_ds.unify_chunks()
sgkit.save_dataset(mat_ds, tmpdir / "subset.zarr", auto_rechunk=True)

mat_sd = tsinfer.SgkitSampleData(tmpdir / "subset.zarr")
mask_sd = tsinfer.SgkitSampleData(
Expand Down
130 changes: 97 additions & 33 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,59 @@ def __init__(self, path, sgkit_samples_mask_name=None, sites_mask_name=None):
" sgkit dataset, indicating that all the genotypes are"
" unphased"
)
# Create zarr arrays for convenience when iterating over chunks
self.z_sites_mask = zarr.array(
self.sites_mask, chunks=self.data["call_genotype"].chunks[0], dtype=bool
)
self.z_individuals_mask = zarr.array(
self.individuals_mask,
chunks=self.data["call_genotype"].chunks[1],
dtype=bool,
)

# Find the first chunk from the left and right that contains an unmasked site
self.sites_first_chunk = None
self.sites_last_chunk = 0
for sites_chunk in range(self.z_sites_mask.cdata_shape[0]):
if np.sum(self.z_sites_mask.blocks[sites_chunk]) > 0:
if self.sites_first_chunk is None:
self.sites_first_chunk = sites_chunk
self.sites_last_chunk = sites_chunk + 1
self.sites_first_chunk = (
0 if self.sites_first_chunk is None else self.sites_first_chunk
)

# Same for individuals
self.individuals_first_chunk = None
self.individuals_last_chunk = 0
for individuals_chunk in range(self.z_individuals_mask.cdata_shape[0]):
if np.sum(self.z_individuals_mask.blocks[individuals_chunk]) > 0:
if self.individuals_first_chunk is None:
self.individuals_first_chunk = individuals_chunk
self.individuals_last_chunk = individuals_chunk + 1

logging.info(f"Number of sites after applying mask: {self.num_sites}")
logging.info(
f"Sites chunk range: {self.sites_first_chunk} - {self.sites_last_chunk}"
f"of {self.z_sites_mask.cdata_shape[0]}"
)
logging.info(
f"Number of individuals after applying mask: {self.num_individuals}"
)
logging.info(
f"Individuals chunk range: {self.individuals_first_chunk} - "
f"{self.individuals_last_chunk} of {self.z_individuals_mask.chunks[0]}"
)
haplo_mem_use = (
(self.sites_last_chunk - self.sites_first_chunk)
* self.z_sites_mask.chunks[0]
* self.z_individuals_mask.chunks[0]
* self.ploidy
)
logging.info(
f"Memory use for loading haplotypes (match_samples): "
f"{humanize.naturalsize(haplo_mem_use)}bytes"
)

@functools.cached_property
def format_name(self):
Expand Down Expand Up @@ -2661,39 +2714,50 @@ def _all_haplotypes(self, sites=None, recode_ancestral=None, samples_slice=None)
samples_slice = (0, self.num_samples)
if samples_slice[0] % self.ploidy != 0 or samples_slice[1] % self.ploidy != 0:
raise ValueError("Samples slice must be a multiple of ploidy")
start, stop = samples_slice
indiv_start = start // self.ploidy
indiv_stop = (stop + self.ploidy - 1) // self.ploidy # round up the division
if recode_ancestral is None:
recode_ancestral = False
aa_index = self.sites_ancestral_allele[:]
# If ancestral allele is missing, keep the order unchanged (aa_index of zero)
aa_index[aa_index == MISSING_DATA] = 0
gt = self.data["call_genotype"]
chunk_size = gt.chunks[1]
current_chunk = None
for i, j in zip(
range(indiv_start, indiv_stop),
np.where(self.individuals_mask)[0][indiv_start:indiv_stop],
):
if j // chunk_size != current_chunk:
current_chunk = j // chunk_size
chunk = gt[:, j // chunk_size : (j // chunk_size) + chunk_size, :]
# Zarr doesn't support fancy indexing, so we have to do this after
chunk = chunk[self.sites_mask]
indiv_gt = chunk[:, j % chunk_size, :]
for k in range(self.ploidy):
a = indiv_gt[:, k].T
if recode_ancestral:
# Remap the genotypes at all sites, depending on the aa_index
a = np.where(
a == aa_index,
0,
np.where(
np.logical_and(a != MISSING_DATA, a < aa_index), a + 1, a
),
)
yield (i * self.ploidy) + k, a if sites is None else a[sites]
# Make an individuals mask that respects the samples slice
ind_indexes = np.where(self.individuals_mask)[0]
ind_mask = zarr.zeros(
self.individuals_mask.shape,
chunks=self.data["call_genotype"].chunks[1],
dtype=bool,
)
ind_mask[
ind_indexes[
samples_slice[0] // self.ploidy : samples_slice[1] // self.ploidy
]
] = True

sample_index = samples_slice[0]
for ind_chunk in range(0, self.data["call_genotype"].cdata_shape[1]):
ind_mask_chunk = ind_mask.blocks[ind_chunk]
if np.sum(ind_mask_chunk) == 0:
continue
gt_chunk = self.data["call_genotype"].blocks[
self.sites_first_chunk : self.sites_last_chunk, ind_chunk, :
]
gt_chunk = gt_chunk[
self.z_sites_mask.blocks[self.sites_first_chunk : self.sites_last_chunk]
]
gt_chunk = gt_chunk[:, ind_mask_chunk]
for s in range(gt_chunk.shape[1]):
for p in range(self.ploidy):
a = gt_chunk[:, s, p]
if recode_ancestral:
# Remap the genotypes at all sites, depending on the aa_index
a = np.where(
a == self.sites_ancestral_allele,
0,
np.where(
np.logical_and(
a != MISSING_DATA, a < self.sites_ancestral_allele
),
a + 1,
a,
),
)
yield sample_index, a if sites is None else a[sites]
sample_index += 1
assert sample_index == samples_slice[1]


@attr.s(order=False, eq=False)
Expand Down

0 comments on commit 5938d9a

Please sign in to comment.