Rewire tsdate to allow nonfixed sample nodes #1

Open
awohns opened this issue Aug 18, 2022 · 10 comments

@awohns
Collaborator

awohns commented Aug 18, 2022

Currently, tsdate only allows sample nodes that have a known date. We want to rewire tsdate so that sample nodes can have an unknown date, allowing for "molecular sampling".

@hyanwong
Owner

We can pass the sample node to date as a "datable node" here:

https://github.com/tskit-dev/tsdate/blob/53add5684443388f06e743d39edf05a5e3f33149/tsdate/prior.py#L979

but we will need to specify a mean and variance for the distribution somehow, or modify the contents of the returned NodeGridValues object.

I think the best thing would be to set any nodes as "datable" if they have a non-zero variance (rather than if they are not samples). Then we simply need to figure out how to give samples a non-zero variance (we can take the mean for the prior as the time of the node in the tree sequence).

It looks like base_priors.get_mixture_prior_params returns an array of alpha and beta params, which are used in

https://github.com/tskit-dev/tsdate/blob/53add5684443388f06e743d39edf05a5e3f33149/tsdate/prior.py#L956

This array contains nan for fixed nodes. I think we should modify this to provide alpha and beta values which specify a mean of the sample time and a variance of 0 for fixed nodes. I think this is only possible for the lognormal distribution, however: there does not exist a gamma distribution that has an arbitrary mean but zero variance (although we could approximate it with a minuscule variance, e.g. 1e-20).
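
To make the lognormal-versus-gamma point concrete, here is a minimal sketch of moment matching for both families. The helper names are made up for illustration and are not tsdate functions, and the sketch assumes a strictly positive mean (a sample at time 0 would need separate handling).

```python
import math

def lognormal_from_mean_var(mean, var):
    """Return (mu, sigma2) of the underlying normal for a lognormal
    with the given mean and variance. Assumes mean > 0."""
    sigma2 = math.log(var / mean**2 + 1.0)
    mu = math.log(mean) - 0.5 * sigma2
    return mu, sigma2

def gamma_from_mean_var(mean, var):
    """Return (shape, rate) of a gamma with the given mean and variance.
    Undefined for var == 0, because both parameters diverge."""
    if var <= 0:
        raise ValueError("a gamma distribution needs var > 0")
    return mean**2 / var, mean / var

# Zero variance still gives finite lognormal parameters (the degenerate limit).
mu, sigma2 = lognormal_from_mean_var(100.0, 0.0)
print(mu, sigma2)

# The gamma can only approximate this, with a tiny variance and huge parameters.
shape, rate = gamma_from_mean_var(100.0, 1e-20)
print(shape, rate)
```

Whether downstream code tolerates sigma2 == 0 or gamma parameters of this magnitude is a separate question that this sketch does not address.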

@hyanwong
Owner

hyanwong commented Aug 19, 2022

I have made some progress with tskit-dev/tsdate@786cf12. Here is some code to test:

import tsinfer
import msprime
import tskit
import tsdate

import numpy as np
import matplotlib.pyplot as plt

Ne = 10000
samples = [
    msprime.SampleSet(2),
    msprime.SampleSet(1, time=100),
]
mutated_ts = msprime.sim_ancestry(
    samples=samples,
    population_size=Ne,
    sequence_length=2e4,
    recombination_rate=0, # For testing, just have a single tree
    random_seed=1,
)
mutated_ts = msprime.mutate(mutated_ts, rate=1e-8, random_seed=1)

def create_sampledata_with_individual_times(ts):
    """
    The tsinfer.SampleData.from_tree_sequence function doesn't allow different time
    units for sites and individuals. This function adds individual times by hand
    """
    # sampledata file with times-as-frequencies
    sd = tsinfer.SampleData.from_tree_sequence(ts)
    # Set individual times separately - warning: this mixes time units
    # so that sites have TIME_UNCALIBRATED but individuals have meaningful times
    individual_time = np.full(sd.num_individuals, -1.0)  # float, so non-integer times are not truncated
    for sample, node_id in zip(sd.samples(), ts.samples()):
        if individual_time[sample.individual] >= 0:
            assert individual_time[sample.individual] == ts.node(node_id).time
        individual_time[sample.individual] = ts.node(node_id).time
    assert np.all(individual_time >= 0)
    sd = sd.copy()
    sd.individuals_time[:] = individual_time
    sd.finalise()
    return sd

def set_times_for_historical_samples(ts):
    """
    Use the times stored in the individuals metadata of an inferred tree sequence
    to constrain the times.
    """
    tables = ts.dump_tables()
    tables.individuals.metadata_schema = tskit.MetadataSchema.permissive_json()
    ts = tables.tree_sequence()
    times = np.zeros(ts.num_nodes)
    # set sample node times of historic samples
    for node_id in ts.samples():
        individual_id = ts.node(node_id).individual
        if individual_id != tskit.NULL:
            times[node_id] = ts.individual(individual_id).metadata.get("sample_data_time", 0)
    constrained_times = tsdate.core.constrain_ages_topo(ts, times, eps=1e-1)
    tables.nodes.time = constrained_times
    tables.mutations.time = np.full(ts.num_mutations, tskit.UNKNOWN_TIME)
    tables.sort()
    return tables.tree_sequence()

sampledata = create_sampledata_with_individual_times(mutated_ts)
inferred_ts = tsinfer.infer(sampledata)
inferred_ts_w_times = set_times_for_historical_samples(inferred_ts).simplify()

print(inferred_ts_w_times.node(5))

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=True, node_var_override={5:1000})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8)

This fails when truncating priors, however:

/usr/local/lib/python3.9/site-packages/tsdate/prior.py in _truncate_priors(ts, priors, progress)
   1062     ):
   1063         if index + 1 != len(truncate_nodes):
-> 1064             children_index = np.arange(parent_indices[index], parent_indices[index + 1])
   1065         else:
   1066             children_index = np.arange(parent_indices[index], ts.num_edges)

IndexError: index 3 is out of bounds for axis 0 with size 3

I can't quite figure out the logic in that function. Perhaps @awohns can talk me through it and we can see what is not working. It should be perfectly possible to truncate on the basis of a few fixed sample nodes.
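
One way this kind of IndexError can arise, sketched with toy arrays rather than tsdate's actual code: assume parent_indices holds one start offset per distinct parent in a parent-sorted edge list, while the loop runs over a node list that also contains a leaf. A leaf never appears as a parent, so the offsets array ends up shorter than the node list. The hypothetical helper below looks children up by node id instead of by position, so childless nodes simply get an empty list.

```python
from collections import defaultdict

def children_by_parent(edge_parent, edge_child, nodes):
    """Map each node in `nodes` to its list of children; leaves map to []."""
    children = defaultdict(list)
    for parent, child in zip(edge_parent, edge_child):
        children[parent].append(child)
    # .get() avoids creating entries for nodes that never act as a parent
    return {node: children.get(node, []) for node in nodes}

# Toy tree: node 6 is the root, nodes 4 and 5 are internal, node 3 is a leaf
edge_parent = [4, 4, 5, 5, 6, 6]
edge_child = [0, 1, 2, 3, 4, 5]

# A positional scheme would have one offset per distinct parent...
distinct_parents = sorted(set(edge_parent))
print(len(distinct_parents))

# ...but there is one more node to visit once the leaf (node 3) is included
nodes_to_visit = [3, 4, 5, 6]
print(children_by_parent(edge_parent, edge_child, nodes_to_visit))
```

If the real function follows a similar positional scheme, filtering the node list down to nodes that actually have children, or switching to a lookup like this, would be two possible fixes.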

@hyanwong
Owner

hyanwong commented Aug 19, 2022

We can test the pathway without truncation using the code above via

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:10})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8)

With tskit-dev/tsdate@786cf12 this now complains about dangling nodes on the inside pass, which is correct, as the node corresponding to the sample-to-date will appear as if it is dangling.

tsdate/core.py in inside_pass(self, normalize, cache_inside, progress)
    682                         # Child appears fixed, or we have not visited it. Either our
    683                         # edge order is wrong (bug) or we have hit a dangling node
--> 684                         raise ValueError(
    685                             "The input tree sequence includes "
    686                             "dangling nodes: please simplify it"

ValueError: The input tree sequence includes dangling nodes: please simplify it

The inside[edge.child] array is full of np.nan in the case of an undated sample node. I presume that we simply need to fill it with the appropriate values from the prior?

The key line is here, where we fill the inside either with np.nan for a node of unknown date, or with the identity value (i.e. prob=1) for the fixed nodes:

        inside = self.priors.clone_with_new_data(  # store inside matrix values
            grid_data=np.nan, fixed_data=self.lik.identity_constant
        )
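
If the presumption above holds, seeding the inside values could look roughly like the following. This is a plain-Python sketch, not tsdate's NodeGridValues API: `seed_inside`, `leaf_priors` and the dict layout are all hypothetical.

```python
import math

def seed_inside(node_ids, fixed_nodes, leaf_priors, grid_size, identity=1.0):
    """
    Build a starting 'inside' table:
      - fixed nodes get the identity value (probability 1),
      - undated leaves get a copy of their prior over the time grid,
      - all other nodes get NaN placeholders, to be filled in by the pass.
    """
    inside = {}
    for node in node_ids:
        if node in fixed_nodes:
            inside[node] = identity
        elif node in leaf_priors:
            inside[node] = list(leaf_priors[node])
        else:
            inside[node] = [math.nan] * grid_size
    return inside

# Nodes 0 and 1 are fixed samples, node 2 is an undated sample, node 3 is internal
inside = seed_inside(
    node_ids=[0, 1, 2, 3],
    fixed_nodes={0, 1},
    leaf_priors={2: [0.0, 0.2, 0.5, 0.3]},
    grid_size=4,
)
print(inside[2])  # the prior values for the undated leaf, with no NaN entries
```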

@hyanwong
Owner

hyanwong commented Aug 19, 2022

Note that my changes simply create a lognormal distribution (with a user-specified variance) for the prior on an undated sample node. If a more complicated prior is needed, I guess it can be created by hand. We can show an example of this in the docs.

@hyanwong
Owner

Wow, with tskit-dev/tsdate@979f55c it's almost working with the outside_maximization method. The only issue now is setting the times so that they are topologically constrained:

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:1000})
dated_ts = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, method="maximization")

tsdate/core.py in constrain_ages_topo(ts, post_mn, eps, nodes_to_date, progress)
    943     ):
    944         if index + 1 != len(nodes_to_date):
--> 945             children_index = np.arange(parent_indices[index], parent_indices[index + 1])
    946         else:
    947             children_index = np.arange(parent_indices[index], ts.num_edges)

IndexError: index 3 is out of bounds for axis 0 with size 3

@hyanwong
Owner

> The only issue now is setting the times so that they are topologically constrained:

Fixed with tskit-dev/tsdate@da58644 and tskit-dev/tsdate@02a9b67

@hyanwong
Owner

hyanwong commented Aug 20, 2022

The current PR tskit-dev/tsdate#214 works, but only with the outside maximisation method, which won't return posteriors.

Here's what we get when trying the inside-outside:

import tsinfer
import msprime
import tskit
import tsdate

import numpy as np

Ne = 10000
samples = [
    msprime.SampleSet(2),
    msprime.SampleSet(1, time=100),
]
mutated_ts = msprime.sim_ancestry(
    samples=samples,
    population_size=Ne,
    sequence_length=2e4,
    recombination_rate=0, # For testing, just have a single tree
    random_seed=1,
)
mutated_ts = msprime.mutate(mutated_ts, rate=1e-8, random_seed=1)

def create_sampledata_with_individual_times(ts):
    """
    The tsinfer.SampleData.from_tree_sequence function doesn't allow different time
    units for sites and individuals. This function adds individual times by hand
    """
    # sampledata file with times-as-frequencies
    sd = tsinfer.SampleData.from_tree_sequence(ts)
    # Set individual times separately - warning: this mixes time units
    # so that sites have TIME_UNCALIBRATED but individuals have meaningful times
    individual_time = np.full(sd.num_individuals, -1.0)  # float, so non-integer times are not truncated
    for sample, node_id in zip(sd.samples(), ts.samples()):
        if individual_time[sample.individual] >= 0:
            assert individual_time[sample.individual] == ts.node(node_id).time
        individual_time[sample.individual] = ts.node(node_id).time
    assert np.all(individual_time >= 0)
    sd = sd.copy()
    sd.individuals_time[:] = individual_time
    sd.finalise()
    return sd

def set_times_for_historical_samples(ts):
    """
    Use the times stored in the individuals metadata of an inferred tree sequence
    to constrain the times.
    """
    tables = ts.dump_tables()
    tables.individuals.metadata_schema = tskit.MetadataSchema.permissive_json()
    ts = tables.tree_sequence()
    times = np.zeros(ts.num_nodes)
    # set sample node times of historic samples
    for node_id in ts.samples():
        individual_id = ts.node(node_id).individual
        if individual_id != tskit.NULL:
            times[node_id] = ts.individual(individual_id).metadata.get("sample_data_time", 0)
    # Just need to make the ts consistent
    constrained_times = tsdate.core.constrain_ages_topo(ts, times, eps=1e-1)
    tables.nodes.time = constrained_times
    tables.mutations.time = np.full(ts.num_mutations, tskit.UNKNOWN_TIME)
    tables.sort()
    return tables.tree_sequence()

sampledata = create_sampledata_with_individual_times(mutated_ts)
inferred_ts = tsinfer.infer(sampledata)
inferred_ts_w_times = set_times_for_historical_samples(inferred_ts).simplify()

prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:1000})
dated_ts, posteriors = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, method="maximization", return_posteriors=True)  # WORKS!
dated_ts, posteriors = tsdate.date(inferred_ts_w_times, priors=prior, mutation_rate=1e-8, return_posteriors=True)  # FAILS

/usr/local/lib/python3.9/site-packages/tsdate/core.py in outside_pass(self, normalize, ignore_oldest_root, progress, probability_space_returned)
    792 
    793             # vv[0] = 0  # Seems a hack: internal nodes should be allowed at time 0
--> 794             assert self.norm[edge.child] > self.lik.null_constant
    795             outside[child] = self.lik.reduce(val, self.norm[child])
    796             if normalize:

AssertionError: 

It's failing because self.norm[edge.child] is nan in this case. If we can fix this, I think we should have a working computational molecular dating method. Any ideas on how to get the outside pass working, @awohns? Can we simply set the normalisation constant to 1 here?

@hyanwong
Owner

> Can we simply set the normalisation constant to 1 here?

tskit-dev/tsdate@974038d sets the normalization constant to unity for non-fixed leaf nodes.

However, I'm having second thoughts about the sum_to_unity function. Since we have different width time bins, I suspect that we want the cumulative sum to be one, right? We can't simply sum up all the probabilities for the grid slices.
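
The distinction can be checked numerically. In this sketch (helper names are illustrative and the lognormal parameters are arbitrary), values obtained as CDF differences are probability masses and can be summed directly, whereas density values on an unequal grid have to be weighted by bin width before summing.

```python
import math

def lognorm_cdf(x, mu, sigma):
    """CDF of a lognormal whose underlying normal has mean mu and sd sigma."""
    if x <= 0:
        return 0.0
    return 0.5 * (1.0 + math.erf((math.log(x) - mu) / (sigma * math.sqrt(2.0))))

# Unequal-width time grid (arbitrary values for illustration)
timepoints = [0.0, 100.0, 300.0, 1000.0, 5000.0, 50000.0]
mu, sigma = math.log(1000.0), 1.0

cdf = [lognorm_cdf(t, mu, sigma) for t in timepoints]
widths = [b - a for a, b in zip(timepoints, timepoints[1:])]

# Per-bin probability mass: sums to the CDF at the last timepoint
mass = [b - a for a, b in zip(cdf, cdf[1:])]
total_mass = sum(mass)

# Per-bin average density: must be multiplied by width to recover the mass
density = [m / w for m, w in zip(mass, widths)]
total_from_density = sum(d * w for d, w in zip(density, widths))

print(total_mass, total_from_density)  # equal up to rounding
print(sum(density))                    # not a probability: ignores bin widths
```

So whether a plain sum is appropriate depends on whether the grid stores per-bin masses or density values, which is the thing to confirm in sum_to_unity.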

@hyanwong
Owner

It's technically working but there's a bug, I think. I reckon the following should give a relatively flat prior for node 5:

import tsdate
import matplotlib.pyplot as plt
variance = 1e8  # a big number
prior = tsdate.build_prior_grid(inferred_ts_w_times, Ne=10000, allow_historical_samples=True, truncate_priors=False, node_var_override={5:variance})
prior.force_probability_space("linear")
print(prior[5])
plt.stairs(prior[5][:-1], prior.timepoints)

It doesn't for me. The variance logic must be wrong, I think.
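
A quick check of the lognormal moment matching suggests one reason the prior may not look flat: holding the mean fixed while increasing the variance pushes the median towards zero rather than spreading mass evenly. The sketch below assumes node 5 is the historical sample with a prior mean of 100 (its time in the simulation above) and a first non-zero grid point of roughly 422; both values are assumptions, and the helper names are illustrative.

```python
import math

def lognorm_params(mean, var):
    """Moment-match a lognormal: return (mu, sigma) of the underlying normal."""
    sigma2 = math.log(var / mean**2 + 1.0)
    return math.log(mean) - 0.5 * sigma2, math.sqrt(sigma2)

def lognorm_cdf(x, mu, sigma):
    """Lognormal CDF for x > 0, via the error function."""
    return 0.5 * (1.0 + math.erf((math.log(x) - mu) / (sigma * math.sqrt(2.0))))

mean, first_timepoint = 100.0, 422.27  # both assumed, see the text above
for var in (1e2, 1e4, 1e8):
    mu, sigma = lognorm_params(mean, var)
    median = math.exp(mu)  # equals mean / sqrt(1 + var / mean**2)
    below = lognorm_cdf(first_timepoint, mu, sigma)
    print(f"var={var:g}: median={median:.3g}, P(t < {first_timepoint}) = {below:.3f}")
```

Under these assumptions, a variance of 1e8 puts the median at about 1 and the large majority of the mass below the first non-zero timepoint, so a lognormal may simply be the wrong shape for a "flat" prior, independently of any bug in the variance logic.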

@hyanwong
Owner

> I reckon the following should give a relatively flat prior for node 5:

Here's some code to discuss:

import scipy.stats
import numpy as np

def lognorm_approx(mean, var):
    """
    alpha is mean of underlying normal distribution
    beta is variance of underlying normal distribution
    """
    beta = np.log(var / (mean ** 2) + 1)
    alpha = np.log(mean) - 0.5 * beta
    return alpha, beta

def shape_scale_from_mean_var(mean, var):
    # scipy.stats.lognorm takes s = sd of the underlying normal and scale = exp(its mean)
    a, b = lognorm_approx(mean, var)
    return np.sqrt(b), np.exp(a)

timepoints = np.array(
    [    0.        ,   422.2655105 ,   596.44205246,   752.38907357,
         904.41565448,  1058.56279063,  1218.65683564,  1387.84069925,
        1569.19515032,  1766.1152649 ,  1982.64596477,  2223.88220297,
        2496.53641723,  2809.82966009,  3177.00084615,  3618.05796915,
        4165.24283193,  4875.1600594 ,  5860.22741522,  6725.01537605,
        7392.4550719 ,  8583.20477957, 10441.70306644, 11706.4971667 ,
       13503.29402783, 14535.1094196 , 15645.51778133, 17737.96380629,
       20558.5674799 , 23651.27728082, 27051.19988133, 29023.72411745,
       31155.88653047, 33497.40292709, 36116.46074264, 39112.19711119,
       42638.72817761, 46956.92775847, 52564.36877884, 60607.17771601,
       74895.98339996])
#timepoints = np.arange(16) * 5000
print(timepoints)

shape, scale = shape_scale_from_mean_var(10000, 1e8)
cdf_func = scipy.stats.lognorm.cdf
prior_node = cdf_func(timepoints, shape, scale=scale)
print("cdf", prior_node)
#prior_node = np.divide(prior_node, np.max(prior_node))
p = np.concatenate([np.array([0]), np.diff(prior_node)])
print("pdf (prior)", p)

import matplotlib.pyplot as plt
plt.stairs(p[:-1]/np.diff(timepoints), timepoints)
plt.xscale("log")

[image: plot produced by the code above]
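
For comparison, here is a stdlib-only sketch of the same calculation in which each bin's mass is divided by that bin's own width, i.e. (cdf[i+1] - cdf[i]) / (timepoints[i+1] - timepoints[i]). The grid is a shortened, rounded stand-in for the timepoints above, and the function names are illustrative.

```python
import math

def lognorm_cdf(x, mu, sigma):
    """CDF of a lognormal whose underlying normal has mean mu and sd sigma."""
    if x <= 0:
        return 0.0
    return 0.5 * (1.0 + math.erf((math.log(x) - mu) / (sigma * math.sqrt(2.0))))

def bin_densities(timepoints, mu, sigma):
    """Average density in each interval [timepoints[i], timepoints[i+1])."""
    cdf = [lognorm_cdf(t, mu, sigma) for t in timepoints]
    return [
        (c1 - c0) / (t1 - t0)
        for c0, c1, t0, t1 in zip(cdf, cdf[1:], timepoints, timepoints[1:])
    ]

mean, var = 10000.0, 1e8
sigma2 = math.log(var / mean**2 + 1.0)  # log(2), because var equals mean**2 here
mu = math.log(mean) - 0.5 * sigma2
sigma = math.sqrt(sigma2)

# Shortened, rounded stand-in for the grid in the comment above
timepoints = [0.0, 422.0, 1059.0, 2224.0, 4875.0, 10442.0, 23651.0, 74896.0]
dens = bin_densities(timepoints, mu, sigma)
for t0, t1, d in zip(timepoints, timepoints[1:], dens):
    print(f"[{t0:>7.0f}, {t1:>7.0f}): {d:.3e}")
```

With a mean of 10000 and a variance of 1e8, the variance equals the squared mean, so sigma is sqrt(log 2), about 0.83, and the median is mean / sqrt(2), about 7071. A lognormal with these parameters is peaked rather than flat, whatever the binning.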
