
Batch job ancestor matching. #921

Open
benjeffery opened this issue May 22, 2024 · 3 comments

@benjeffery (Member)

After some hairy snakemake deliberation, here is the strawman pipeline for ancestor matching.

It requires the following tsinfer methods (a rough signature sketch follows the list):
match_ancestors_batch_init - creates a folder with metadata
match_ancestors_batch_group - matches a group locally, writes ts to folder
match_ancestors_batch_group_init - creates a folder and writes metadata on partitions for a group
match_ancestors_batch_group_partition - matches a partition of ancestors for a group
match_ancestors_batch_group_finalise - uses the partitions to write a ts for the group
match_ancestors_batch_finalise - writes final ts.
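
As a very rough sketch, the signatures might look something like the following (the names come from the list above; the argument names and semantics are assumptions inferred from the pipeline below, not a settled API):

# Hypothetical signatures for the proposed batch-matching API.
# Names come from the list above; arguments are assumptions.

def match_ancestors_batch_init(ancestors_path, working_dir):
    """Create working_dir and write metadata describing the ancestor groups."""

def match_ancestors_batch_group(working_dir, group):
    """Match one group of ancestors locally, writing its ts to working_dir."""

def match_ancestors_batch_group_init(working_dir, group):
    """Create a per-group folder and write metadata on its partitions."""

def match_ancestors_batch_group_partition(working_dir, group, partition):
    """Match a single partition of ancestors within a group."""

def match_ancestors_batch_group_finalise(working_dir, group):
    """Use the partition results to write a ts for the group."""

def match_ancestors_batch_finalise(working_dir):
    """Write the final ancestors ts."""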

@jeromekelleher

from pathlib import Path
import json
import re

# The number of ancestors needed in a group to trigger partitioning
BATCH_THRESHOLD = 100
# The number of groups to process in one job when doing local matching
NUM_GROUPS_ONE_JOB = 10

rule all:
    input: 'ancestors.ts'

checkpoint match_ancestors_init:
    input: 'ancestors.zarr'
    output: 'match_wip/metadata.json'
    run:
        #tsinfer.match_ancestors_batch_init(input[0], working_dir="match_wip")
        
        # Dump some dummy groupings
        groupings = []
        for start in range(0, 1000, 10):
            groupings.append(list(range(start, start + 10)))
        for start in range(1000, 5000, 1000):
            groupings.append(list(range(start, start + 1000)))
        md = {'ancestor_groups':groupings}
        with open(output[0], 'w') as f:
            json.dump(md, f)


# Load ancestor groupings from metadata for snakemake
def ancestor_groupings(wildcards):
    checkpoint_output = checkpoints.match_ancestors_init.get(**wildcards)
    with open('match_wip/metadata.json') as f:
        md = json.load(f)
    return md['ancestor_groups']

# Load the number of partitions for a group
def num_partitions(wildcards):
    checkpoint_output = checkpoints.match_ancestors_large_group_init.get(**wildcards)
    with open(f'match_wip/batch_{wildcards.group}_wip/metadata.json') as f:
        metadata = json.load(f)
    return metadata["num_partitions"]

# This function decides if a group should be processed in a single job or partitioned
def match_ancestor_group_input(wildcards):
    groupings = ancestor_groupings(wildcards)
    group_index = int(wildcards.group)
    # If the group is a large one then the inputs will be the partitions
    if len(groupings[group_index]) > BATCH_THRESHOLD:
        return expand(
            'match_wip/batch_{group}_wip/partition-{partition}.json',
            group=wildcards.group,
            partition=range(num_partitions(wildcards)),
        )
    # This group is small enough to do locally.
    # Search back until we find a group that is large enough to require
    # partitioning, or we reach the start, or we have enough groups.
    for i in range(group_index, max(group_index - NUM_GROUPS_ONE_JOB, 0), -1):
        if len(groupings[i]) > BATCH_THRESHOLD:
            return f'match_wip/ancestors_{i}.ts'
    if group_index - NUM_GROUPS_ONE_JOB > 0:
        return f'match_wip/ancestors_{group_index - NUM_GROUPS_ONE_JOB}.ts'
    else:
        return 'match_wip/metadata.json'

rule match_ancestors_group:
    input:
        match_ancestor_group_input
    output:
        'match_wip/ancestors_{group}.ts'
    run:
        # Use the input to determine if we are processing a set of groups or finalising
        if "partition" in input[0]:
            #tsinfer.match_ancestors_batch_group_finalise("match_wip", group=wildcards.group)
            print(f"Finalise group {wildcards.group}")
        else:
            output_group = int(wildcards.group)
            print(input[0])
            if "metadata" in input[0]:
                input_group = -1
            else:
                input_group = int(re.match(r'match_wip/ancestors_(\d+).ts', input[0]).group(1))
            for group in range(input_group+1, output_group+1):
                #tsinfer.match_ancestors_batch_group("match_wip", group=group)
                print(f"Local Match group {group}")
        Path(output[0]).touch()


checkpoint match_ancestors_large_group_init:
    input: 
        lambda wildcards: f'match_wip/ancestors_{int(wildcards.group)-1}.ts'
    output:
        'match_wip/batch_{group}_wip/metadata.json'
    run:
        #tsinfer.match_ancestors_batch_group_init("match_wip", group=wildcards.group)
        print(f"Init large group {wildcards.group}")
        # Write some dummy data
        with open(output[0], 'w') as f:
            json.dump({"num_partitions": 10}, f)

rule match_ancestors_large_group_partition:
    input: 
        'match_wip/batch_{group}_wip/metadata.json'
    output:
        'match_wip/batch_{group}_wip/partition-{partition}.json'
    run:
        #tsinfer.match_ancestors_batch_group_partition("match_wip", group=wildcards.group, partition=wildcards.partition)
        print(f"Match group {wildcards.group} partition {wildcards.partition}")
        Path(output[0]).touch()


def last_ancestor_group(wildcards):
    groupings = ancestor_groupings(wildcards)
    return len(groupings)-1
    
rule match_ancestors_final:
    input:
        lambda wildcards: f'match_wip/ancestors_{last_ancestor_group(wildcards)}.ts'
    output:
        'ancestors.ts'
    run:
        #tsinfer.match_ancestors_batch_finalise("match_wip")
        print("Finalise")
        Path(output[0]).touch()
@jeromekelleher (Member)

Can't say I follow all the snakemaking here, but it looks quite logical and the tsinfer operations should be simple enough.

I guess it's worth stress testing this with something large to make sure snakemake can handle the type of task graph we'll be creating?
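
For example (a sketch with arbitrary group sizes, not from the thread), one could dump an oversized dummy grouping into the checkpoint's metadata file and time a dry run (snakemake -n) against it:

# Sketch: generate a large dummy ancestor grouping to stress-test
# DAG construction. The group counts and sizes here are arbitrary.
import json
from pathlib import Path

def write_dummy_groupings(path, n_small=10_000, n_large=50, large_size=1_000):
    # Many small (unpartitioned) groups followed by a few large ones.
    groupings = [[i] for i in range(n_small)]
    start = n_small
    for _ in range(n_large):
        groupings.append(list(range(start, start + large_size)))
        start += large_size
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump({"ancestor_groups": groupings}, f)

write_dummy_groupings("match_wip/metadata.json")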

@benjeffery (Member, Author)

I've been playing around with this: with a large number of partitioned groups, snakemake takes a long time rebuilding the DAG between groups. I'm copying the code over to GeL to do some testing on a real ancestor grouping and to see if the DAG rebuilding is problematic with the file system.

@benjeffery (Member, Author)

Ah, looks like I can refactor to avoid the checkpoints on init, but it means moving all the decision-making about the number of partitions into the initial match_ancestors_batch_init and removing the match_ancestors_batch_group_init call completely.
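
For concreteness, a sketch of that refactor (assuming match_ancestors_batch_init records a per-group partition count in the single top-level metadata file; the "group_num_partitions" key is a hypothetical layout, not a final format):

# With partition counts decided up front by match_ancestors_batch_init,
# num_partitions becomes a lookup into the one checkpoint's metadata,
# and match_ancestors_large_group_init no longer needs to be a checkpoint.
def num_partitions(wildcards):
    # Registering the dependency on the single init checkpoint.
    checkpoints.match_ancestors_init.get(**wildcards)
    with open('match_wip/metadata.json') as f:
        md = json.load(f)
    return md['group_num_partitions'][int(wildcards.group)]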
