Skip to content

Commit

Permalink
Use newer sgkit/zarr calling conventions
Browse files Browse the repository at this point in the history
This removes a couple of deprecation warnings
  • Loading branch information
hyanwong authored and mergify[bot] committed Sep 5, 2024
1 parent 89b4ca0 commit f9de549
Showing 1 changed file with 9 additions and 25 deletions.
34 changes: 9 additions & 25 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def test_sgkit_dataset_roundtrip(tmp_path):
inf_ts = tsinfer.infer(samples)
ds = sgkit.load_dataset(zarr_path)

assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"]
assert ts.num_individuals == inf_ts.num_individuals == ds.sizes["samples"]
for ts_ind, sample_id in zip(inf_ts.individuals(), ds["sample_id"].values):
assert ts_ind.metadata["variant_data_sample_id"] == sample_id

assert (
ts.num_samples == inf_ts.num_samples == ds.dims["samples"] * ds.dims["ploidy"]
ts.num_samples == inf_ts.num_samples == ds.sizes["samples"] * ds.sizes["ploidy"]
)
assert ts.num_sites == inf_ts.num_sites == ds.dims["variants"]
assert ts.num_sites == inf_ts.num_sites == ds.sizes["variants"]
assert ts.sequence_length == inf_ts.sequence_length == ds.attrs["contig_lengths"][0]
for (
v,
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path):
inf_ts = tsinfer.infer(samples)
ds = sgkit.load_dataset(zarr_path)

assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"]
assert ts.num_individuals == inf_ts.num_individuals == ds.sizes["samples"]
for i, (ts_ind, sample_id) in enumerate(
zip(inf_ts.individuals(), ds["sample_id"].values)
):
Expand Down Expand Up @@ -694,23 +694,15 @@ def test_phased(self, tmp_path):
ds["call_genotype"].dims,
np.ones(ds["call_genotype"].shape, dtype=bool),
)
ds["variant_ancestral_allele"] = (
ds["variant_position"].dims,
np.array(["A", "C", "G"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
tsinfer.VariantData(path, "variant_ancestral_allele")
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))

def test_ploidy1_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
# Ploidy==1 is always ok
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_ancestral_allele"] = (
ds["variant_position"].dims,
np.array(["A", "C", "G"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
tsinfer.VariantData(path, "variant_ancestral_allele")
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))

def test_ploidy1_unphased(self, tmp_path):
path = tmp_path / "data.zarr"
Expand All @@ -719,12 +711,8 @@ def test_ploidy1_unphased(self, tmp_path):
ds["call_genotype"].dims,
np.zeros(ds["call_genotype"].shape, dtype=bool),
)
ds["variant_ancestral_allele"] = (
ds["variant_position"].dims,
np.array(["A", "C", "G"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
tsinfer.VariantData(path, "variant_ancestral_allele")
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))

def test_duplicate_positions(self, tmp_path):
path = tmp_path / "data.zarr"
Expand All @@ -749,14 +737,10 @@ def test_empty_alleles_not_at_end(self, tmp_path):
ds["variant_allele"].dims,
np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
)
ds["variant_ancestral_allele"] = (
["variants"],
np.array(["C", "A", "A"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
samples = tsinfer.VariantData(path, "variant_ancestral_allele")
vdata = tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
tsinfer.infer(samples)
tsinfer.infer(vdata)

def test_unimplemented_from_tree_sequence(self):
# NB we should reimplement something like this functionality.
Expand Down

0 comments on commit f9de549

Please sign in to comment.