Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Insert merged lineage #2300

Merged
merged 3 commits
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ clang-format:
clang-format -i lib/tests/* lib/*.[c,h]

tags:
ctags -f TAGS msprime/*.c lib/*.[c,h] msprime/*.py tests/*.py
ctags -f TAGS msprime/*.c lib/*.[c,h] msprime/*.py tests/*.py algorithms.py


clean:
Expand Down
152 changes: 84 additions & 68 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class Segment:
next: Segment = None # noqa: A003
lineage: Lineage = None

def __repr__(self):
def __str__(self):
return repr((self.left, self.right, self.node))

@staticmethod
Expand Down Expand Up @@ -1115,6 +1115,11 @@ def alloc_segment(

def alloc_lineage(self, head, population, *, label=0, tail=None):
lineage = Lineage(head, population=population, label=label, tail=tail)
assert tail is None
# If we're allocating a new lineage for a given head segment, then we
# have no choice but to iterate over the rest of the chain to update
# the lineage reference, and determine the tail. If head is None,
# this doesn't do anything.
lineage.reset_segments()
return lineage

Expand Down Expand Up @@ -1690,6 +1695,7 @@ def dtwf_climb_pedigree(self):
for ploid in range(ind.ploidy):
self.process_pedigree_common_ancestors(ind, ploid)

# TODO: change this to accept a lineage instead of a segment
def store_arg_edges(self, segment, u=-1):
if u == -1:
u = len(self.tables.nodes) - 1
Expand Down Expand Up @@ -1997,6 +2003,7 @@ def wiuf_gene_conversion_within_event(self, label):
elif head is not None:
new_individual_head = head
if new_individual_head is not None:
# FIXME: revisit this when doing the smc_k update
lineage.reset_segments()
new_lineage = self.alloc_lineage(new_individual_head, pop)
if self.model == "smc_k":
Expand Down Expand Up @@ -2254,13 +2261,12 @@ def store_additional_nodes_edges(self, flag, new_node_id, z):
return new_node_id

def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
pop = self.P[pop_id]
defrag_required = False
coalescence = False
pass_through = len(H) == 1
alpha = None
z = None
new_lineage = None
new_lineage = self.alloc_lineage(None, pop_id, label=label)

while len(H) > 0:
alpha = None
left = H[0][0]
Expand Down Expand Up @@ -2321,10 +2327,15 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):

# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(alpha, pop_id)
pop.add(new_lineage, label)
alpha.lineage = new_lineage
alpha.prev = new_lineage.tail
self.set_segment_mass(alpha)
if new_lineage.head is None:
new_lineage.head = alpha
assert new_lineage.tail is None
else:
new_lineage.tail.next = alpha
z = new_lineage.tail
if (coalescence and not self.coalescing_segments_only) or (
self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0
):
Expand All @@ -2333,38 +2344,18 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
defrag_required |= (
z.right == alpha.left and z.node == alpha.node
)
z.next = alpha
alpha.prev = z
alpha.lineage = new_lineage
self.set_segment_mass(alpha)
z = alpha
if coalescence:
if not self.coalescing_segments_only:
self.store_arg_edges(z, new_node_id)
else:
if not pass_through:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_CA_EVENT, new_node_id, z
)
else:
if self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH > 0:
assert new_node_id != -1
assert self.model == "fixed_pedigree"
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH, new_node_id, z
)

if defrag_required:
self.defrag_segment_chain(z)
if coalescence:
self.defrag_breakpoints()
if new_lineage is not None:
# FIXME do this more efficiently!
new_lineage.reset_segments()
return new_lineage
new_lineage.tail = alpha

return self.insert_merged_lineage(
new_lineage,
new_node_id,
coalescence=coalescence,
pass_through=pass_through,
defrag_required=defrag_required,
)

def defrag_segment_chain(self, z):
def defrag_segment_chain(self, lineage):
z = lineage.tail
y = z
while y.prev is not None:
x = y.prev
Expand All @@ -2374,6 +2365,9 @@ def defrag_segment_chain(self, z):
if y.next is not None:
y.next.prev = x
self.set_segment_mass(x)
if y == lineage.tail:
lineage.tail = x
assert y != lineage.head
self.free_segment(y)
y = x

Expand Down Expand Up @@ -2442,12 +2436,11 @@ def common_ancestor_event(self, population_index, label):
self.merge_two_ancestors(population_index, label, x, y)

def merge_two_ancestors(self, population_index, label, x, y, u=-1):
pop = self.P[population_index]
self.num_ca_events += 1
z = None
new_lineage = None
new_lineage = self.alloc_lineage(None, population_index, label=label)
coalescence = False
defrag_required = False

while x is not None or y is not None:
alpha = None
if x is None or y is None:
Expand Down Expand Up @@ -2476,8 +2469,7 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
if not coalescence:
coalescence = True
if u == -1:
self.store_node(population_index)
u = len(self.tables.nodes) - 1
u = self.store_node(population_index)
# Put in breakpoints for the outer edges of the coalesced
# segment
left = x.left
Expand All @@ -2501,7 +2493,6 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
left=left,
right=right,
node=u,
population=population_index,
)
if x.node != u: # required for dtwf and fixed_pedigree
self.store_edge(left, right, u, x.node)
Expand All @@ -2521,11 +2512,15 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):

# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(
alpha, population_index, label=label
)
alpha.lineage = new_lineage
alpha.prev = new_lineage.tail
self.set_segment_mass(alpha)
if new_lineage.head is None:
new_lineage.head = alpha
assert new_lineage.tail is None
else:
new_lineage.tail.next = alpha
z = new_lineage.tail
if (coalescence and not self.coalescing_segments_only) or (
self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0
):
Expand All @@ -2534,42 +2529,62 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
defrag_required |= (
z.right == alpha.left and z.node == alpha.node
)
z.next = alpha
alpha.prev = z
alpha.lineage = new_lineage
self.set_segment_mass(alpha)
z = alpha
new_lineage.tail = alpha

return self.insert_merged_lineage(
new_lineage, u, coalescence=coalescence, defrag_required=defrag_required
)

def insert_merged_lineage(
self, new_lineage, u, *, coalescence, defrag_required, pass_through=False
):
z = new_lineage.tail

if coalescence:
if not self.coalescing_segments_only:
self.store_arg_edges(z, u)
else:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
self.store_additional_nodes_edges(msprime.NODE_IS_CA_EVENT, u, z)
if not pass_through:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
u = self.store_additional_nodes_edges(
msprime.NODE_IS_CA_EVENT, u, z
)
else:
if self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH > 0:
assert u != -1
assert self.model == "fixed_pedigree"
u = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH, u, z
)

if defrag_required:
self.defrag_segment_chain(z)
self.defrag_segment_chain(new_lineage)
if coalescence:
self.defrag_breakpoints()

if new_lineage is not None:
x = new_lineage.head
# TODO do this more efficiently
if new_lineage.head is not None:
# Use up any uncoalesced segments at the end of the chain
x = new_lineage.tail.next
while x is not None:
x.lineage = new_lineage
new_lineage.tail = x
x = x.next
# tail = new_lineage.tail
# new_lineage.reset_segments()
# assert tail == new_lineage.tail
self.add_lineage(new_lineage)

if new_lineage is not None and self.model == "smc_k":
merged_head = new_lineage.head
assert merged_head.prev is None
hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage)
while merged_head is not None:
right = merged_head.right
merged_head = merged_head.next
hull.right = min(right + self.hull_offset, self.L)
pop.add_hull(label, hull)
if self.model == "smc_k":
merged_head = new_lineage.head
assert merged_head.prev is None
hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage)
while merged_head is not None:
right = merged_head.right
merged_head = merged_head.next
hull.right = min(right + self.hull_offset, self.L)
pop = self.P[new_lineage.population]
pop.add_hull(new_lineage.label, hull)
return new_lineage

def print_state(self, verify=False):
print("State @ time ", self.t)
Expand Down Expand Up @@ -2640,6 +2655,7 @@ def verify_segments(self):
for pop_index, pop in enumerate(self.P):
for label in range(self.num_labels):
for lineage in pop.iter_label(label):
# print("LIN", lineage)
assert isinstance(lineage, Lineage)
assert lineage.label == label
assert lineage.population == pop_index
Expand Down
Loading
Loading