From 865cc3cd3f5b1fdda7aaeb0c729b9d7b7cb825ba Mon Sep 17 00:00:00 2001
From: Don Freed
Date: Wed, 20 Nov 2024 23:11:18 -0800
Subject: [PATCH] Add more flexible job scheduling

---
 sentieon_cli/command_strings.py |  19 +++
 sentieon_cli/dag.py             |  75 +++++++++
 sentieon_cli/dnascope.py        | 261 ++++++++++++++++++++++----------
 sentieon_cli/exceptions.py      |   5 +
 sentieon_cli/executor.py        | 124 +++++++++++++++
 sentieon_cli/job.py             |  58 +++++++
 sentieon_cli/scheduler.py       |  39 +++++
 7 files changed, 504 insertions(+), 77 deletions(-)
 create mode 100644 sentieon_cli/dag.py
 create mode 100644 sentieon_cli/exceptions.py
 create mode 100644 sentieon_cli/executor.py
 create mode 100644 sentieon_cli/job.py
 create mode 100644 sentieon_cli/scheduler.py

diff --git a/sentieon_cli/command_strings.py b/sentieon_cli/command_strings.py
index 8c9e1cc..bf725b5 100644
--- a/sentieon_cli/command_strings.py
+++ b/sentieon_cli/command_strings.py
@@ -250,6 +250,25 @@ def get_rg_lines(
     return rg_lines
 
 
+def rehead_wgsmetrics(
+    orig_metrics: pathlib.Path, tmp_metrics: pathlib.Path
+) -> str:
+    """Rehead Sentieon WGS metrics so the file is recognized by MultiQC"""
+    cmd1 = shlex.join(["mv", str(orig_metrics), str(tmp_metrics)])
+    cmd2 = (
+        shlex.join(["echo", "## METRICS CLASS WgsMetrics"])
+        + ">"
+        + shlex.quote(str(orig_metrics))
+    )
+    cmd3 = (
+        shlex.join(["tail", "-n", "+2", str(tmp_metrics)])
+        + ">>"
+        + shlex.quote(str(orig_metrics))
+    )
+    cmd4 = shlex.join(["rm", str(tmp_metrics)])
+    return "; ".join((cmd1, cmd2, cmd3, cmd4))
+
+
 def cmd_samtools_fastq_minimap2(
     out_aln: pathlib.Path,
     input_aln: pathlib.Path,
diff --git a/sentieon_cli/dag.py b/sentieon_cli/dag.py
new file mode 100644
index 0000000..37f1328
--- /dev/null
+++ b/sentieon_cli/dag.py
@@ -0,0 +1,75 @@
+"""
+A directed acyclic graph of jobs to execute
+"""
+
+from __future__ import annotations
+
+from typing import Dict, Generator, List, Optional, Set
+
+from .exceptions import DagExecutionError
+from .job import Job
+from .logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class DAG:
+    """A directed acyclic graph of jobs"""
+
+    def __init__(self) -> None:
+        # waiting_jobs = {job: {dependencies}}
+        self.waiting_jobs: Dict[Job, Set[Job]] = {}
+        # map each dependency to the jobs waiting on it
+        # planned_jobs = {dependency: [downstream_jobs]}
+        self.planned_jobs: Dict[Job, List[Job]] = {}
+        self.ready_jobs: Set[Job] = set()
+        self.finished_jobs: List[Job] = []
+
+    def add_job(self, job: Job, dependencies: Optional[Set[Job]] = None):
+        """Add a job to the DAG"""
+        if dependencies:
+            for dependency in dependencies:
+                assert (
+                    dependency in self.waiting_jobs
+                    or dependency in self.ready_jobs
+                )
+
+        if isinstance(dependencies, set) and len(dependencies) > 0:
+            self.waiting_jobs[job] = dependencies
+            for dependency in dependencies:
+                downstream = self.planned_jobs.setdefault(dependency, [])
+                downstream.append(job)
+        else:
+            self.ready_jobs.add(job)
+
+    def update_dag(
+        self,
+    ) -> Generator[Set[Job], Optional[Job], None]:
+        """Update the DAG with a finished job"""
+        finished_job = yield self.ready_jobs.copy()
+
+        while True:
+            logger.debug("Waiting jobs: %s", self.waiting_jobs)
+            logger.debug("Planned jobs: %s", self.planned_jobs)
+            logger.debug("Ready jobs: %s", self.ready_jobs)
+            logger.debug("Finished jobs: %s", self.finished_jobs)
+            logger.debug("Newly finished: %s", finished_job)
+
+            new_ready_jobs: Set[Job] = set()
+            if isinstance(finished_job, Job):
+                if finished_job not in self.ready_jobs:
+                    raise DagExecutionError(
+                        f"Finished job '{finished_job}' was not ready for "
+                        "execution"
+                    )
+
+                self.ready_jobs.remove(finished_job)
+                self.finished_jobs.append(finished_job)
+                for downstream_job in self.planned_jobs.get(finished_job, []):
+                    remaining_deps = self.waiting_jobs[downstream_job]
+                    remaining_deps.remove(finished_job)
+                    if len(remaining_deps) < 1:
+                        self.ready_jobs.add(downstream_job)
+                        new_ready_jobs.add(downstream_job)
+                        del self.waiting_jobs[downstream_job]
+            finished_job = yield new_ready_jobs
diff --git a/sentieon_cli/dnascope.py b/sentieon_cli/dnascope.py
index 325c2aa..21635a0 100644
--- a/sentieon_cli/dnascope.py
+++ b/sentieon_cli/dnascope.py
@@ -9,13 +9,14 @@
 import shlex
 import shutil
 import sys
-from typing import Any, Callable, List, Optional
+from typing import Any, List, Optional, Set, Tuple
 
 import packaging.version
 
-from argh import arg
+from argh import arg, CommandError
 
 from . import command_strings as cmds
+from .dag import DAG
 from .driver import (
     AlignmentStat,
     BaseDistributionByCycle,
@@ -34,7 +35,10 @@
     SVSolver,
     WgsMetricsAlgo,
 )
+from .executor import DryRunExecutor, LocalExecutor
+from .job import Job
 from .logging import get_logger
+from .scheduler import ThreadScheduler
 from .util import (
     __version__,
     check_version,
@@ -65,7 +69,6 @@
 
 
 def align_inputs(
-    run: Callable[[str], None],
     tmp_dir: pathlib.Path,
     output_vcf: pathlib.Path,
     reference: pathlib.Path,
@@ -80,7 +83,7 @@ def align_inputs(
     util_sort_args: str = "--cram_write_options version=3.0,compressor=rans",
     input_ref: Optional[pathlib.Path] = None,
     **_kwargs: Any,
-) -> List[pathlib.Path]:
+) -> Tuple[List[pathlib.Path], Set[Job], Job]:
     """Align input BAM/CRAM/uBAM/uCRAM files with bwa"""
     if not skip_version_check:
         for cmd, min_version in ALN_MIN_VERSIONS.items():
@@ -89,6 +92,8 @@ def align_inputs(
 
     res: List[pathlib.Path] = []
     suffix = "bam" if bam_format else "cram"
+    jobs: Set[Job] = set()
+    align_outputs: List[pathlib.Path] = []
     for i, input_aln in enumerate(sample_input):
         out_aln = pathlib.Path(
             str(output_vcf).replace(".vcf.gz", f"_bwa_sorted_{i}.{suffix}")
@@ -98,8 +103,7 @@ def align_inputs(
         with open(rg_header, "w", encoding="utf-8") as rg_fh:
             for line in rg_lines:
                 print(line, file=rg_fh)
-
-        run(
+        job = Job(
             cmds.cmd_samtools_fastq_bwa(
                 out_aln,
                 input_aln,
@@ -111,14 +115,29 @@ def align_inputs(
                 collate=collate_align,
                 bwa_args=bwa_args,
                 util_sort_args=util_sort_args,
-            )
+            ),
+            f"bam-align-{i}",
+            cores,
         )
         res.append(out_aln)
-    return res
+        jobs.add(job)
+        align_outputs.append(out_aln)
+        align_outputs.append(pathlib.Path(str(out_aln) + ".bai"))
+        if not bam_format:
+            align_outputs.append(pathlib.Path(str(out_aln) + ".crai"))
+
+    # Create a job (added to the DAG later) to remove the aligned inputs
+    rm_job = Job(
+        shlex.join(["rm"] + [str(x) for x in align_outputs]),
+        "rm-bam-aln",
+        0,
+        True,
+    )
+
+    return (res, jobs, rm_job)
 
 
 def align_fastq(
-    run: Callable[[str], None],
     output_vcf: pathlib.Path,
     reference: pathlib.Path,
     model_bundle: pathlib.Path,
@@ -131,11 +150,12 @@ def align_fastq(
     bwa_args: str = "-K 100000000",
     util_sort_args: str = "--cram_write_options version=3.0,compressor=rans",
     **_kwargs: Any,
-) -> List[pathlib.Path]:
+) -> Tuple[List[pathlib.Path], Set[Job], Optional[Job]]:
     """Align fastq files to the reference genome using bwa"""
     res: List[pathlib.Path] = []
+    jobs: Set[Job] = set()
     if r1_fastq is None and readgroups is None:
-        return res
+        return (res, jobs, None)
     if (not r1_fastq or not readgroups) or (len(r1_fastq) != len(readgroups)):
         logger.error(
             "The number of readgroups does not equal the number of fastq files"
         )
@@ -157,13 +177,14 @@
     suffix = "bam" if bam_format else "cram"
bam_format else "cram" r2_fastq = [] if r2_fastq is None else r2_fastq + align_outputs: List[pathlib.Path] = [] for i, (r1, r2, rg) in enumerate( itertools.zip_longest(r1_fastq, r2_fastq, readgroups) ): out_aln = pathlib.Path( str(output_vcf).replace(".vcf.gz", f"_bwa_sorted_fq_{i}.{suffix}") ) - run( + job = Job( cmds.cmd_fastq_bwa( out_aln, r1, @@ -175,14 +196,29 @@ def align_fastq( unzip, bwa_args, util_sort_args, - ) + ), + f"bam-align-{i}", + cores, ) res.append(out_aln) - return res + jobs.add(job) + align_outputs.append(out_aln) + align_outputs.append(pathlib.Path(str(out_aln) + ".bai")) + if not bam_format: + align_outputs.append(pathlib.Path(str(out_aln) + ".crai")) + + # Create an unscheduled job to remove the aligned inputs + rm_job = Job( + shlex.join(["rm"] + [str(x) for x in align_outputs]), + "rm-fq-aln", + 0, + True, + ) + + return (res, jobs, rm_job) def dedup_and_metrics( - run: Callable[[str], None], output_vcf: pathlib.Path, reference: pathlib.Path, sample_input: List[pathlib.Path], @@ -195,7 +231,9 @@ def dedup_and_metrics( bam_format: bool = False, cram_write_options: str = "version=3.0,compressor=rans", **_kwargs: Any, -) -> List[pathlib.Path]: +) -> Tuple[ + List[pathlib.Path], Job, Optional[Job], Optional[Job], Optional[Job] +]: """Perform dedup and metrics collection""" suffix = "bam" if bam_format else "cram" @@ -244,7 +282,7 @@ def dedup_and_metrics( ) # Prefer to run InsertSizeMetricAlgo after duplicate marking - if assay == "WES" and not bed: + if (assay == "WES" and not bed) or duplicate_marking == "none": driver.add_algo(InsertSizeMetricAlgo(is_metrics)) driver.add_algo(MeanQualityByCycle(mqbc_metrics)) @@ -254,10 +292,10 @@ def dedup_and_metrics( if assay == "WGS": driver.add_algo(GCBias(gc_metrics, summary=gc_summary)) - run(shlex.join(driver.build_cmd())) + lc_job = Job(shlex.join(driver.build_cmd()), "locuscollector", cores) if duplicate_marking == "none": - return sample_input + return (sample_input, lc_job, None, None, None) # Dedup deduped = pathlib.Path( @@ -278,9 +316,11 @@ def dedup_and_metrics( rmdup=(duplicate_marking == "rmdup"), ) ) - run(shlex.join(driver.build_cmd())) + dedup_job = Job(shlex.join(driver.build_cmd()), "dedup", cores) # Run HsMetricAlgo after duplicate marking to account for duplicate reads + metrics_job = None + rehead_job = None driver = Driver( reference=reference, thread_count=cores, @@ -290,35 +330,33 @@ def dedup_and_metrics( if assay == "WES" and bed: driver.add_algo(HsMetricAlgo(hs_metrics, bed, bed)) driver.add_algo(InsertSizeMetricAlgo(is_metrics)) - run(shlex.join(driver.build_cmd())) + metrics_job = Job( + shlex.join(driver.build_cmd()), "metrics", 0 + ) # Run metrics in the background # Run WgsMetricsAlgo after duplicate marking to account for duplicate reads if assay == "WGS": driver.add_algo(InsertSizeMetricAlgo(is_metrics)) driver.add_algo(WgsMetricsAlgo(wgs_metrics, include_unpaired="true")) driver.add_algo(CoverageMetrics(coverage_metrics)) - run(shlex.join(driver.build_cmd())) + metrics_job = Job( + shlex.join(driver.build_cmd()), "metrics", 0 + ) # Run metrics in the background # Rehead WGS metrics so they are recognized by MultiQC - wgs_metrics.rename(wgs_metrics_tmp) - with open(wgs_metrics, "w", encoding="utf-8") as fho, open( - wgs_metrics_tmp, encoding="utf-8" - ) as fhi: - print("## METRICS CLASS WgsMetrics", file=fho) - _ = fhi.readline() # remove the Sentieon header - for line in fhi: - line = line.rstrip() - print(line, file=fho) - wgs_metrics_tmp.unlink() - return [deduped] + rehead_job = Job( + 
+            cmds.rehead_wgsmetrics(wgs_metrics, wgs_metrics_tmp),
+            "rehead-metrics",
+            0,
+        )
+    return ([deduped], lc_job, dedup_job, metrics_job, rehead_job)
 
 
 def multiqc(
-    run: Callable[[str], None],
     output_vcf: pathlib.Path,
     skip_version_check: bool = False,
     **_kwargs: Any,
-) -> int:
+) -> Optional[Job]:
     """Run MultiQC on the metrics files"""
 
     if not skip_version_check:
@@ -332,21 +370,22 @@
                 "Skipping MultiQC. MultiQC version %s or later not found",
                 MULTIQC_MIN_VERSION["multiqc"],
             )
-            return 1
+            return None
 
     metrics_dir = pathlib.Path(str(output_vcf).replace(".vcf.gz", "_metrics"))
-    run(
+    multiqc_job = Job(
         cmds.cmd_multiqc(
             metrics_dir,
             metrics_dir,
             f"Generated by the Sentieon-CLI version {__version__}",
-        )
+        ),
+        "multiqc",
+        0,
     )
-    return 0
+    return multiqc_job
 
 
 def call_variants(
-    run: Callable[[str], None],
     output_vcf: pathlib.Path,
     reference: pathlib.Path,
     deduped: List[pathlib.Path],
@@ -359,9 +398,8 @@
     gvcf: bool = False,
     skip_svs: bool = False,
     skip_version_check: bool = False,
-    dry_run: bool = False,
     **_kwargs: Any,
-) -> int:
+) -> Tuple[Job, Job, Job, Optional[Job], Optional[Job], Optional[Job]]:
     """Call SNVs, indels, and SVs using DNAscope"""
     if not skip_version_check:
         for cmd, min_version in VARIANTS_MIN_VERSIONS.items():
@@ -411,7 +449,7 @@
             var_type="BND",
         )
     )
-    run(shlex.join(driver.build_cmd()))
+    call_job = Job(shlex.join(driver.build_cmd()), "variant-calling", cores)
 
     # Genotyping and filtering with DNAModelApply
     driver = Driver(
@@ -426,15 +464,14 @@
             ds_out,
         )
     )
-    run(shlex.join(driver.build_cmd()))
+    apply_job = Job(shlex.join(driver.build_cmd()), "model-apply", cores)
 
     # Remove the tmp_vcf
-    tmp_vcf_idx = pathlib.Path(str(tmp_vcf) + ".tbi")
-    if not dry_run:
-        tmp_vcf_idx.unlink(missing_ok=True)
-        tmp_vcf.unlink()
+    rm_cmd = ["rm", str(tmp_vcf), str(tmp_vcf) + ".tbi"]
+    rm_job = Job(shlex.join(rm_cmd), "rm-tmp-vcf", 0, True)
 
     # Genotype gVCFs
+    gvcftyper_job = None
     if gvcf:
         driver = Driver(
             reference=reference,
@@ -447,9 +484,11 @@
                 vcf=out_gvcf,
             )
         )
-        run(shlex.join(driver.build_cmd()))
+        gvcftyper_job = Job(shlex.join(driver.build_cmd()), "gvcftyper", cores)
 
     # Call SVs
+    svsolver_job = None
+    sv_rm_job = None
     if not skip_svs:
         driver = Driver(
             reference=reference,
@@ -462,13 +501,22 @@
                 vcf=out_svs_tmp,
             )
         )
-        run(shlex.join(driver.build_cmd()))
-        out_svs_tmp_idx = pathlib.Path(str(out_svs_tmp) + ".tbi")
-        if not dry_run:
-            out_svs_tmp.unlink()
-            out_svs_tmp_idx.unlink(missing_ok=True)
+        svsolver_job = Job(shlex.join(driver.build_cmd()), "svsolver", cores)
+        sv_rm_job = Job(
+            shlex.join(["rm", str(out_svs_tmp), str(out_svs_tmp) + ".tbi"]),
+            "rm-tmp-sv",
+            0,
+            True,
+        )
 
-    return 0
+    return (
+        call_job,
+        apply_job,
+        rm_job,
+        gvcftyper_job,
+        svsolver_job,
+        sv_rm_job,
+    )
 
 
 @arg(
@@ -628,7 +676,7 @@ def dnascope(
     dbsnp: Optional[pathlib.Path] = None,  # pylint: disable=W0613
     bed: Optional[pathlib.Path] = None,
     interval_padding: int = 0,  # pylint: disable=W0613
-    cores: int = mp.cpu_count(),  # pylint: disable=W0613
+    cores: int = mp.cpu_count(),
     pcr_free: bool = False,  # pylint: disable=W0613
     gvcf: bool = False,  # pylint: disable=W0613
     duplicate_marking: str = "markdup",
@@ -681,33 +729,92 @@
     tmp_dir_str = tmp()
     tmp_dir = pathlib.Path(tmp_dir_str)  # type: ignore  # pylint: disable=W0641  # noqa: E501
 
-    if dry_run:
-        run = print  # type: ignore  # pylint: disable=W0641
-    else:
-        from .runner import run  # type: ignore[assignment]  # noqa: F401
+    logger.info("Building the DAG")
DAG") + dag = DAG() + align_jobs: Set[Job] = set() sample_input = sample_input if sample_input else [] + bam_rm_job = None if align or collate_align: - sample_input = align_inputs(**locals()) - sample_input.extend(align_fastq(**locals())) - - deduped = dedup_and_metrics(**locals()) # pylint: disable=W0641 - - # Remove the bwa output before duplicate marking - if duplicate_marking != "none" and not dry_run: - for aln in sample_input: - for idx_suffix in (".bai", ".crai"): - idx = pathlib.Path(str(aln) + idx_suffix) - idx.unlink(missing_ok=True) - aln.unlink() + sample_input, align_jobs, bam_rm_job = align_inputs(**locals()) + for job in align_jobs: + dag.add_job(job) + aligned_fastq, align_fastq_jobs, fq_rm_job = align_fastq(**locals()) + for job in align_fastq_jobs: + dag.add_job(job) + sample_input.extend(aligned_fastq) + + deduped, lc_job, dedup_job, metrics_job, rehead_job = dedup_and_metrics( + **locals() + ) # pylint: disable=W0641 + dag.add_job(lc_job, align_jobs.union(align_fastq_jobs)) + if dedup_job: + dag.add_job(dedup_job, {lc_job}) + if metrics_job: + dag.add_job(metrics_job, {dedup_job}) + if rehead_job: + dag.add_job(rehead_job, {metrics_job}) + if bam_rm_job: + dag.add_job(bam_rm_job, {dedup_job}) + if fq_rm_job: + dag.add_job(fq_rm_job, {dedup_job}) if not skip_small_variants: - res = call_variants(**locals()) - if res != 0: - logger.error("Variant calling failed") - return + ( + call_job, + apply_job, + rm_job, + gvcftyper_job, + svsolver_job, + sv_rm_job, + ) = call_variants(**locals()) + call_dependencies: Set[Job] = set() + if dedup_job: + call_dependencies.add(dedup_job) + else: + call_dependencies.update(align_jobs) + call_dependencies.update(align_fastq_jobs) + dag.add_job(call_job, call_dependencies) + dag.add_job(apply_job, {call_job}) + dag.add_job(rm_job, {apply_job}) + if gvcftyper_job: + dag.add_job(gvcftyper_job, {apply_job}) + if svsolver_job and sv_rm_job: + dag.add_job(svsolver_job, {call_job}) + dag.add_job(sv_rm_job, {svsolver_job}) if not skip_multiqc: - _res = multiqc(**locals()) + multiqc_job = multiqc(**locals()) + multiqc_dependencies: Set[Job] = set() + multiqc_dependencies.add(lc_job) + if metrics_job: + multiqc_dependencies.add(metrics_job) + if rehead_job: + multiqc_dependencies.add(rehead_job) + + if multiqc_job: + dag.add_job(multiqc_job, multiqc_dependencies) + + logger.debug("Creating the scheduler") + scheduler = ThreadScheduler( + dag, + cores, + ) + + logger.debug("Creating the executor") + Executor = DryRunExecutor if dry_run else LocalExecutor + executor = Executor(scheduler) + logger.info("Starting execution") + executor.execute() shutil.rmtree(tmp_dir_str) + + if executor.jobs_with_errors: + raise CommandError("Execution failed") + + if len(dag.waiting_jobs) > 0 or len(dag.ready_jobs) > 0: + raise CommandError( + "The DAG has some unexecuted jobs\n" + f"Waiting jobs: {dag.waiting_jobs}\n" + f"Ready jobs: {dag.ready_jobs}\n" + ) diff --git a/sentieon_cli/exceptions.py b/sentieon_cli/exceptions.py new file mode 100644 index 0000000..ebf0976 --- /dev/null +++ b/sentieon_cli/exceptions.py @@ -0,0 +1,5 @@ +"""Exceptions raised by the program""" + + +class DagExecutionError(Exception): + """An error class during DAG execution""" diff --git a/sentieon_cli/executor.py b/sentieon_cli/executor.py new file mode 100644 index 0000000..160e831 --- /dev/null +++ b/sentieon_cli/executor.py @@ -0,0 +1,124 @@ +"""Execute jobs""" + +import asyncio +import asyncio.subprocess +import sys +from typing import Any, List, Tuple + +from .job import Job +from 
.logging import get_logger +from .scheduler import ThreadScheduler + +logger = get_logger(__name__) + + +class BaseExecutor: + """Execute jobs""" + + def __init__(self, scheduler: ThreadScheduler): + self.scheduler = scheduler + self.jobs_with_errors: List[Job] = [] + + def execute(self) -> None: + """Execute jobs from the DAG""" + raise NotImplementedError + + +class DryRunExecutor(BaseExecutor): + """Dry-run execution""" + + def run_job(self, job: Job) -> None: + """Dry-run a job""" + print(job.shell) + + def execute(self) -> None: + scheduler_gen = self.scheduler.schedule() + ready_jobs = scheduler_gen.send(None) + for job in ready_jobs: + self.run_job(job) + + while ready_jobs: + finished_jobs = ready_jobs.copy() + ready_jobs = { + new_job + for completed_job in finished_jobs + for new_job in scheduler_gen.send(completed_job) + } + for job in ready_jobs: + self.run_job(job) + + +class LocalExecutor(BaseExecutor): + """Run jobs locally""" + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.running: List[ + Tuple[ + Job, + asyncio.subprocess.Process, + asyncio.Task[int], + ] + ] = [] + + async def run_job(self, job: Job) -> None: + """Run a job""" + cmd = job.shell + logger.info("Running: %s", cmd) + proc = await asyncio.create_subprocess_shell( + cmd, + stdout=sys.stdout, + stderr=sys.stderr, + executable="/bin/bash", + ) + self.running.append( + ( + job, + proc, + asyncio.create_task(proc.wait()), + ) + ) + + def execute(self) -> None: + """Execute jobs from the DAG""" + asyncio.run(self._execute()) + + async def _execute(self) -> None: + """Execute jobs from the DAG""" + self.jobs_with_errors: List[Job] = [] + scheduler_gen = self.scheduler.schedule() + ready_jobs = scheduler_gen.send(None) + for job in ready_jobs: + await self.run_job(job) + + while self.running: + done, _running = await asyncio.wait( + [job[2] for job in self.running], + return_when=asyncio.FIRST_COMPLETED, + ) + + finished_jobs = [ + self.running.pop(i) + for i in reversed(range(len(self.running))) + if self.running[i][2] in done + ] + + # Check job execution + for job, proc, _task in finished_jobs: + if proc.returncode != 0 and not job.fail_ok: + logger.error("Error running command, '%s'", job.shell) + self.jobs_with_errors.append(job) + + if self.jobs_with_errors: + # Don't start new jobs + continue + + ready_jobs = { + new_job + for completed_job in finished_jobs + for new_job in scheduler_gen.send(completed_job[0]) + } + + # Run the ready jobs + for job in ready_jobs: + await self.run_job(job) diff --git a/sentieon_cli/job.py b/sentieon_cli/job.py new file mode 100644 index 0000000..d7e2e8e --- /dev/null +++ b/sentieon_cli/job.py @@ -0,0 +1,58 @@ +""" +Job objects +""" + +import subprocess as sp +import sys +import time + +from .logging import get_logger + +logger = get_logger(__name__) + + +class Job: + """A job for execution""" + + def __init__( + self, shell: str, name: str, threads: int = 1, fail_ok: bool = False + ): + self.shell = shell + self.name = name + self.threads = threads + self.fail_ok = fail_ok + + def __hash__(self): + return hash(self.shell) + + def __eq__(self, other: object): + if isinstance(other, Job): + return self.shell == other.shell + return False + + def __ne__(self, other: object): + return not self == other + + def __repr__(self): + return f"Job({self.name})" + + def __str__(self): + return f"Job({self.name})" + + def run(self, dry_run: bool = False): + """Run a command""" + if dry_run: + print(self.shell) + return + + logger.info("running 
command: %s", self.shell) + t0 = time.time() + sp.run( + self.shell, + shell=True, + check=False, + stdout=sys.stdout, + stderr=sys.stderr, + executable="/bin/bash", + ) + logger.info("finished in: %s seconds", f"{time.time() - t0:.1f}") diff --git a/sentieon_cli/scheduler.py b/sentieon_cli/scheduler.py new file mode 100644 index 0000000..e8f377f --- /dev/null +++ b/sentieon_cli/scheduler.py @@ -0,0 +1,39 @@ +"""Schedule jobs""" + +from typing import Generator, Optional, Set + +from .dag import DAG +from .job import Job +from .logging import get_logger + +logger = get_logger(__name__) + + +class ThreadScheduler: + """Schedule jobs as threads are available""" + + def __init__(self, dag: DAG, threads: int = 1): + self.dag = dag + self.threads = threads + self.available_threads = threads + + def schedule(self) -> Generator[Set[Job], Optional[Job], None]: + """Schedule a job for execution""" + dag_gen = self.dag.update_dag() + ready_jobs = dag_gen.send(None) + + while True: + logger.debug("Ready jobs: %s", ready_jobs) + + scheduled_jobs: Set[Job] = set() + for ready_job in ready_jobs: + if self.available_threads - ready_job.threads >= 0: + self.available_threads -= ready_job.threads + scheduled_jobs.add(ready_job) + + ready_jobs -= scheduled_jobs + + finished_job = yield scheduled_jobs + if isinstance(finished_job, Job): + self.available_threads += finished_job.threads + ready_jobs.update(dag_gen.send(finished_job))