Skip to content

Commit

Permalink
Properly treat blank ancestral allele, and set "N" as the default "un…
Browse files Browse the repository at this point in the history
…known" state

Also document the class
  • Loading branch information
hyanwong committed Sep 9, 2024
1 parent 1d04fb8 commit 15c500a
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 79 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## [0.4.0a3] - ****-**-**

**Fixes**

- Properly account for "N" as an unknown ancestral state, and ban "" from being
set as an ancestral state ({pr}`963`, {user}`hyanwong`))

## [0.4.0a2] - 2024-09-06

2nd Alpha release of tsinfer 0.4.0
Expand Down
24 changes: 12 additions & 12 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ for sample in range(ds['call_genotype'].shape[1]):

We wish to infer a genealogy that could have given rise to this data set. To run _tsinfer_
we wrap the .vcz file in a `tsinfer.VariantData` object. This requires an
*ancestral allele* to be specified for each site; there are
*ancestral state* to be specified for each site; there are
many methods for calculating these: details are outside the scope of this manual, but we
have started a [discussion topic](https://github.com/tskit-dev/tsinfer/discussions/523)
on this issue to provide some recommendations.
Expand All @@ -76,11 +76,11 @@ and not used for inference (with a warning given).
import tsinfer
# For this example take the REF allele (index 0) as ancestral
ancestral_allele = ds['variant_allele'][:,0].astype(str)
ancestral_state = ds['variant_allele'][:,0].astype(str)
# This is just a numpy array, set the last site to an unknown value, for demo purposes
ancestral_allele[-1] = "."
ancestral_state[-1] = "."
vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_allele)
vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_state)
```

The `VariantData` object is a lightweight wrapper around the .vcz file.
Expand Down Expand Up @@ -127,7 +127,7 @@ site_mask[ds.variant_position[:] >= 6] = True
smaller_vdata = tsinfer.VariantData(
"_static/example_data.vcz",
ancestral_allele=ancestral_allele[site_mask == False],
ancestral_state=ancestral_state[site_mask == False],
site_mask=site_mask,
)
print(f"The `smaller_vdata` object returns data for only {smaller_vdata.num_sites} sites")
Expand Down Expand Up @@ -351,8 +351,8 @@ Once we have our `.vcz` file created, running the inference is straightforward.

```{code-cell} ipython3
# Infer & save a ts from the notebook simulation.
ancestral_alleles = np.load(f"{name}-AA.npy")
vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_alleles)
ancestral_states = np.load(f"{name}-AA.npy")
vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_states)
tsinfer.infer(vdata, progress_monitor=True, num_threads=4).dump(name + ".trees")
```

Expand Down Expand Up @@ -477,12 +477,12 @@ vcf_location = "_static/P_dom_chr24_phased.vcf.gz"
```

This creates the `sparrows.vcz` datastore, which we open using `tsinfer.VariantData`.
The original VCF had ancestral alleles specified in the `AA` INFO field, so we can
simply provide the string `"variant_AA"` as the ancestral_allele parameter.
The original VCF had the ancestral allelic state specified in the `AA` INFO field,
so we can simply provide the string `"variant_AA"` as the ancestral_state parameter.

```{code-cell} ipython3
# Do the inference: this VCF has ancestral alleles in the AA field
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA")
# Do the inference: this VCF has ancestral states in the AA field
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA")
ts = tsinfer.infer(vdata)
print(
"Inferred tree sequence: {} trees over {} Mb ({} edges)".format(
Expand Down Expand Up @@ -534,7 +534,7 @@ Now when we carry out the inference, we get a tree sequence in which the nodes a
correctly assigned to named populations

```{code-cell} ipython3
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA")
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA")
sparrow_ts = tsinfer.infer(vdata)
for sample_node_id in sparrow_ts.samples():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
mat_wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working_mat",
sample_data_path=mat_sd.path,
ancestral_allele="variant_ancestral_allele",
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mat_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
Expand All @@ -1547,7 +1547,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
mask_wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working_mask",
sample_data_path=mask_sd.path,
ancestral_allele="variant_ancestral_allele",
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mask_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
Expand Down
117 changes: 100 additions & 17 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
Tests for the data files.
"""
import json
import logging
import sys
import tempfile
import warnings

import msprime
import numcodecs
Expand Down Expand Up @@ -627,14 +629,12 @@ def test_missing_ancestral_allele(tmp_path):


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
def test_ancestral_missingness(tmp_path):
def test_deliberate_ancestral_missingness(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
ancestral_allele = ds.variant_ancestral_allele.values
ancestral_allele[0] = "N"
ancestral_allele[11] = "-"
ancestral_allele[12] = "💩"
ancestral_allele[15] = "💩"
ancestral_allele[1] = "n"
ds = ds.drop_vars(["variant_ancestral_allele"])
sgkit.save_dataset(ds, str(zarr_path) + ".tmp")
tsutil.add_array_to_dataset(
Expand All @@ -644,15 +644,56 @@ def test_ancestral_missingness(tmp_path):
["variants"],
)
ds = sgkit.load_dataset(str(zarr_path) + ".tmp")
with warnings.catch_warnings():
warnings.simplefilter("error") # No warning raised if AA deliberately missing
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
inf_ts = tsinfer.infer(sd)
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
if i in [0, 1]:
assert inf_var.site.metadata == {"inference_type": "parsimony"}
else:
assert inf_var.site.ancestral_state == var.site.ancestral_state


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
def test_ancestral_missing_warning(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
anc_state = ds.variant_ancestral_allele.values
anc_state[0] = "N"
anc_state[11] = "-"
anc_state[12] = "💩"
anc_state[15] = "💩"
with pytest.warns(
UserWarning,
match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2",
):
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
inf_ts = tsinfer.infer(sd)
vdata = tsinfer.VariantData(zarr_path, anc_state)
inf_ts = tsinfer.infer(vdata)
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
if i in [0, 11, 12, 15]:
assert inf_var.site.metadata == {"inference_type": "parsimony"}
assert inf_var.site.ancestral_state in var.site.alleles
else:
assert inf_var.site.ancestral_state == var.site.ancestral_state


def test_ancestral_missing_info(tmp_path, caplog):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
anc_state = ds.variant_ancestral_allele.values
anc_state[0] = "N"
anc_state[11] = "N"
anc_state[12] = "n"
anc_state[15] = "n"
with caplog.at_level(logging.INFO):
vdata = tsinfer.VariantData(zarr_path, anc_state)
assert f"4 sites ({4/ts.num_sites * 100 :.2f}%) were deliberately " in caplog.text
inf_ts = tsinfer.infer(vdata)
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
if i in [0, 11, 12, 15]:
assert inf_var.site.metadata == {"inference_type": "parsimony"}
assert inf_var.site.ancestral_state in var.site.alleles
else:
assert inf_var.site.ancestral_state == var.site.ancestral_state

Expand All @@ -670,6 +711,25 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):


class TestVariantDataErrors:
@staticmethod
def simulate_genotype_call_dataset(*args, **kwargs):
# roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
# duplicate alleles are created. Doesn't need to be efficient: just for testing
if "seed" not in kwargs:
kwargs["seed"] = 123
ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs)
variant_alleles = ds["variant_allele"].values
allowed_alleles = np.array(
["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype
)
for row in range(len(variant_alleles)):
alleles = variant_alleles[row]
if len(set(alleles)) != len(alleles):
# Just use a set that we know is unique
variant_alleles[row] = allowed_alleles[0 : len(alleles)]
ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles
return ds

def test_bad_zarr_spec(self):
ds = zarr.group()
ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
Expand All @@ -680,7 +740,7 @@ def test_bad_zarr_spec(self):

def test_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
sgkit.save_dataset(ds, path)
with pytest.raises(
ValueError, match="The call_genotype_phased array is missing"
Expand All @@ -689,7 +749,7 @@ def test_missing_phase(self, tmp_path):

def test_phased(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
ds["call_genotype_phased"] = (
ds["call_genotype"].dims,
np.ones(ds["call_genotype"].shape, dtype=bool),
Expand All @@ -700,13 +760,13 @@ def test_phased(self, tmp_path):
def test_ploidy1_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
# Ploidy==1 is always ok
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
sgkit.save_dataset(ds, path)
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))

def test_ploidy1_unphased(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["call_genotype_phased"] = (
ds["call_genotype"].dims,
np.zeros(ds["call_genotype"].shape, dtype=bool),
Expand All @@ -716,31 +776,54 @@ def test_ploidy1_unphased(self, tmp_path):

def test_duplicate_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][2] = ds["variant_position"][1]
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_bad_order_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][0] = ds["variant_position"][2] - 0.5
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_bad_ancestral_state(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
ancestral_state[1] = ""
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="cannot contain empty strings"):
tsinfer.VariantData(path, ancestral_state)

def test_empty_alleles_not_at_end(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_allele"] = (
ds["variant_allele"].dims,
np.array([["A", "", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
)
sgkit.save_dataset(ds, path)
with pytest.raises(
ValueError, match='Bad alleles: fill value "" in middle of list'
):
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))

def test_unique_alleles(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_allele"] = (
ds["variant_allele"].dims,
np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
np.array([["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"),
)
sgkit.save_dataset(ds, path)
vdata = tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
tsinfer.infer(vdata)
with pytest.raises(
ValueError, match="Duplicate allele values provided at site 2"
):
tsinfer.VariantData(path, np.array(["A", "A", "A"], dtype="S1"))

def test_unimplemented_from_tree_sequence(self):
# NB we should reimplement something like this functionality.
Expand Down
Loading

0 comments on commit 15c500a

Please sign in to comment.