Skip to content

Commit

Permalink
Use new algo for truncating priors
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Aug 20, 2022
1 parent da58644 commit 02a9b67
Showing 1 changed file with 24 additions and 28 deletions.
52 changes: 24 additions & 28 deletions tsdate/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,14 +1027,10 @@ def shape_scale_from_mean_var(mean, var):

def _truncate_priors(ts, priors, progress=False):
"""
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
sequence
Truncate priors for all nonfixed nodes
so they conform to the age of fixed nodes in the tree sequence
"""
tables = ts.tables
truncate_nodes = priors.nonfixed_node_ids()
# ensure truncate_nodes is ordered by node time
truncate_nodes = truncate_nodes[np.argsort(tables.nodes.time[truncate_nodes])]

fixed_nodes = priors.fixed_node_ids()
fixed_times = tables.nodes.time[fixed_nodes]
Expand All @@ -1050,29 +1046,29 @@ def _truncate_priors(ts, priors, progress=False):
constrained_min_times = np.zeros_like(tables.nodes.time)
# Set the min times of fixed nodes to those in the tree sequence
constrained_min_times[fixed_nodes] = fixed_times
constrained_max_times = np.full_like(constrained_min_times, np.inf)

parents = tables.edges.parent
nd_children = tables.edges.child[np.argsort(parents)]
parents = sorted(parents)
parents_unique = np.unique(parents, return_index=True)
parent_indices = parents_unique[1][np.isin(parents_unique[0], truncate_nodes)]
for index, nd in tqdm(
enumerate(truncate_nodes), desc="Constrain Ages", disable=not progress

# Traverse through the ARG, ensuring children come before parents.
# This can be done by iterating over groups of edges with the same parent
new_parent_edge_idx = np.concatenate(
(
[0],
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
[tables.edges.num_rows],
)
)
for edges_start, edges_end in zip(
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
):
if index + 1 != len(truncate_nodes):
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
else:
children_index = np.arange(parent_indices[index], ts.num_edges)
children = nd_children[children_index]
time = np.max(constrained_min_times[children])
# The constrained time of the node should be the age of the oldest child
if constrained_min_times[nd] <= time:
constrained_min_times[nd] = time
nearest_time = np.argmin(np.abs(timepoints - time))
lookup_index = priors.row_lookup[int(nd)]
grid_data[lookup_index][:nearest_time] = zero_value
assert np.all(constrained_min_times < constrained_max_times)
parent = tables.edges.parent[edges_start]
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
oldest_child_time = np.max(constrained_min_times[child_ids])
if oldest_child_time > constrained_min_times[parent]:
constrained_min_times[parent] = oldest_child_time
if constrained_min_times[parent] > 0:
# What if the parent here is a fixed node?
nearest_time = np.argmin(np.abs(timepoints - constrained_min_times[parent]))
lookup_index = priors.row_lookup[parent]
grid_data[lookup_index][:nearest_time] = zero_value

rowmax = grid_data[:, 1:].max(axis=1)
if priors.probability_space == "linear":
Expand Down

0 comments on commit 02a9b67

Please sign in to comment.