Skip to content

Commit

Permalink
Test out example, improve process failure detection
Browse files Browse the repository at this point in the history
  • Loading branch information
wasade committed May 29, 2024
1 parent e933955 commit b18bfc8
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 102 deletions.
2 changes: 1 addition & 1 deletion examples/mx-bt2-dx.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ mxdx mux \
bowtie2 \
-p ${bt2_cores} \
-x ${DBIDX} \
-f - \
-q - \
--no-head \
--no-unal | \
mxdx demux \
Expand Down
6 changes: 3 additions & 3 deletions examples/submit.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/bin/bash

files=../usage-test-files.tsv
dbidx=/path/to/bt2-db
dbidx=/path/to/db
batchsize=15
batchmax=$(mxdx get-number-of-batches --file-map ${files})
batchmax=$(mxdx get-max-batch-number --file-map ${files} --batch-size ${batchsize})
output_base=$(pwd)
ext=fna.gz
ext=sam.xz

j=$(sbatch \
--parsable \
Expand Down
4 changes: 3 additions & 1 deletion mxdx/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
# Sentinel/tag string constants shared across the mxdx modules.
SEQUENTIAL = 'sequential'
MERGE = 'merge'
SEPARATE = 'separate'
THREAD_COMPLETE = 'thread-complete'  # NOTE(review): appears superseded by READ_COMPLETE — confirm no remaining users
READ_COMPLETE = 'read-complete'  # record-queue sentinel: the reader has finished producing records
R1 = 'r1'
R2 = 'r2'  # presumably paired-read orientation tags — verify against callers
PARTIAL = 'dx-partial'
DATA = 'data'
PATH = 'path'  # message-type tags; DATA is used for bulk blocks on the consolidate queue
ERROR = 'error'  # inter-process message: a worker failed, parent should terminate both children
COMPLETE = 'complete'  # inter-process message: the writer finished cleanly
66 changes: 18 additions & 48 deletions mxdx/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import lzma
import bz2
import mimetypes
from itertools import chain
from math import ceil

import polars as pl

Expand Down Expand Up @@ -116,6 +116,10 @@ def is_paired(self):
def cumsum(self):
return list(self._df[self._record_cumsum])

@property
def number_of_batches(self):
    """Total number of batches implied by the record counts and batch size."""
    total_records = self._df[self._record_count].sum()
    return ceil(total_records / self._batch_size)

def _init(self):
df = self._df.with_row_index(name=self._row_index, offset=1)
df = df.with_columns(pl.col(self._record_count)
Expand Down Expand Up @@ -272,48 +276,6 @@ class ParseError(Exception):
pass


class ChainIO:
    """Present two file-like streams as a single readable stream.

    Reads are satisfied from the first stream until it is exhausted,
    then transparently continue from the second. Note that a final
    unterminated line in the first stream is returned on its own by
    ``readline`` — it is not joined to the first line of the second
    stream.
    """

    def __init__(self, a, b):
        self._stream_a = a
        self._stream_b = b
        self._current = self._stream_a

    def _advance(self):
        """Switch from the first stream to the second; True if we switched."""
        if self._current is self._stream_a:
            self._current = self._stream_b
            return True
        return False

    def read(self, n=None):
        chunk = self._current.read(n)

        if n is None:
            # read-to-EOF: drain whichever stream we are on, then the other
            if self._advance():
                chunk += self._current.read()
            return chunk

        if len(chunk) < n and self._advance():
            # first stream came up short; top up from the second
            chunk += self._current.read(n - len(chunk))
        return chunk

    def readline(self):
        line = self._current.readline()
        if not line and self._advance():
            line = self._current.readline()
        return line

    def __iter__(self):
        line = self.readline()
        while line:
            yield line
            line = self.readline()


class IO:
@classmethod
def valid_interleave(cls):
Expand All @@ -322,13 +284,21 @@ def valid_interleave(cls):
@staticmethod
def io_from_stream(stream, n_lines=4):
    """Sniff a datastream which cannot be seeked.

    Inspects the head of ``stream`` to determine the record format and
    returns the matching reader/writer pair from ``IO.sniff``. Where the
    stream (or its underlying binary buffer) supports ``peek``, no data
    is consumed; otherwise the stream is read and rewound, which
    requires seekability.

    Parameters
    ----------
    stream : file-like
        The input stream, text or binary.
    n_lines : int, optional
        Unused; retained for interface compatibility.

    Returns
    -------
    tuple
        ``(stream, read_f, write_f)`` — the original stream plus the
        sniffed read and write functions.
    """
    if hasattr(stream, 'buffer'):
        # text wrapper over a buffered binary stream; peek does not consume
        peek = stream.buffer.peek(1024)
    elif hasattr(stream, 'peek'):
        peek = stream.peek(1024)
    else:
        # no peek support: consume then rewind (stream must be seekable here)
        peek = stream.read(1024)
        stream.seek(0)

    if isinstance(peek, bytes):
        peek = peek.decode('utf-8')

    buf = io.StringIO(peek)
    read_f, write_f = IO.sniff(buf)

    return stream, read_f, write_f

@staticmethod
def io_from_mx(mxfile):
Expand Down
94 changes: 82 additions & 12 deletions mxdx/_mxdx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import io
import glob
import re
import time

from ._io import IO, MuxFile
from ._constants import (INTERLEAVE, R1ONLY, R2ONLY, SEQUENTIAL,
THREAD_COMPLETE, R1, R2, MERGE, SEQUENTIAL,
SEPARATE, PARTIAL, PATH, DATA)
READ_COMPLETE, R1, R2, MERGE, SEQUENTIAL,
SEPARATE, PARTIAL, PATH, DATA, ERROR, COMPLETE)


class Multiplex:
Expand Down Expand Up @@ -87,7 +88,7 @@ def read(self):
self.buffered_queue.put(rec.tag(tag))

# signal that we are done reading
self.buffered_queue.put(THREAD_COMPLETE)
self._read_complete()

def write(self):
"""Write records from a queue to an output."""
Expand All @@ -107,27 +108,62 @@ def write(self):
recs = self.buffered_queue.get()

# if we are complete then terminate gracefully
if recs == THREAD_COMPLETE:
if recs == READ_COMPLETE:
break

# otherwise, write each record
for rec in recs:
output.write(rec.write())
try:
output.write(rec.write())
except BrokenPipeError:
# something bad happened downstream
self._terminate()
break

if self._output != '-':
output.close()

self._write_complete()

def _terminate(self):
    """Signal the monitoring parent, via the message queue, that a failure occurred."""
    self.msg_queue.put(ERROR)

def _write_complete(self):
    """Signal the monitoring parent that the writer finished cleanly."""
    self.msg_queue.put(COMPLETE)

def _read_complete(self):
    """Signal the writer, via the record queue, that no more records are coming."""
    self.buffered_queue.put(READ_COMPLETE)

def start(self):
    """Start the Multiplexing.

    Spawns the reader and writer as separate processes, then polls a
    message queue so a failure in either child can be detected and both
    children terminated, rather than deadlocking on the record queue.
    """
    ctx = mp.get_context('spawn')
    self.buffered_queue = BufferedQueue(ctx)
    self.msg_queue = ctx.Queue()

    reader = ctx.Process(target=self.read)
    writer = ctx.Process(target=self.write)

    reader.start()
    writer.start()

    # monitor reader and writer externally and kill if requested
    while True:
        time.sleep(0.1)  # avoid busy-waiting while the children work
        if self.msg_queue.empty():
            continue

        msg = self.msg_queue.get()

        if msg == ERROR:
            # a child failed (e.g. broken pipe downstream); stop both
            # children so we do not hang on the record queue
            print(f"Error received in {self.__class__}; terminating",
                  file=sys.stderr, flush=True)
            reader.terminate()
            writer.terminate()
            break
        elif msg == COMPLETE:
            # writer finished cleanly
            break

    reader.join()
    writer.join()

Expand Down Expand Up @@ -160,9 +196,9 @@ def _place_buf(self):
def put(self, item):
    """Add an item to the queue, buffering until the completion sentinel.

    Parameters
    ----------
    item : object
        A record (batch), or the READ_COMPLETE sentinel.
    """
    # if we receive a read-complete signal, make sure we drain our
    # buffer before passing on the completion message
    if item == READ_COMPLETE:
        self._place_buf()
        self._queue.put(READ_COMPLETE)
    else:
        self._buf.append(item)

Expand Down Expand Up @@ -206,13 +242,30 @@ def start(self):
"""Start the Demultiplexing."""
ctx = mp.get_context('spawn')
self.buffered_queue = BufferedQueue(ctx)
self.msg_queue = ctx.Queue()

reader = ctx.Process(target=self.read)
writer = ctx.Process(target=self.write)

reader.start()
writer.start()

while True:
time.sleep(0.1)
if self.msg_queue.empty():
continue

msg = self.msg_queue.get()

if msg == ERROR:
print(f"Error received in {self.__class__}; terminating",
file=sys.stderr, flush=True)
reader.terminate()
writer.terminate()
break
elif msg == COMPLETE:
break

reader.join()
writer.join()

Expand All @@ -228,13 +281,28 @@ def read(self):
else:
mux_input = open(self._mux_input, 'rt')

sniffed, read_f, _ = IO.io_from_stream(mux_input)
try:
sniffed, read_f, _ = IO.io_from_stream(mux_input)
except StopIteration:
# stream is empty, upstream program likely bailed
self._terminate()
return

mux_input = sniffed

for rec in read_f(mux_input):
self.buffered_queue.put(rec)

self.buffered_queue.put(THREAD_COMPLETE)
self._read_complete()

def _terminate(self):
    """Signal the monitoring parent, via the message queue, that a failure occurred."""
    self.msg_queue.put(ERROR)

def _read_complete(self):
    """Signal the writer, via the record queue, that no more records are coming."""
    self.buffered_queue.put(READ_COMPLETE)

def _write_complete(self):
    """Signal the monitoring parent that the writer finished cleanly."""
    self.msg_queue.put(COMPLETE)

@lru_cache()
def _get_output_path(self, mx, orientation):
Expand Down Expand Up @@ -278,7 +346,7 @@ def write(self):
recs = self.buffered_queue.get()

# if we are complete then terminate gracefully
if recs == THREAD_COMPLETE:
if recs == READ_COMPLETE:
break

# otherwise, write each record
Expand All @@ -287,6 +355,8 @@ def write(self):
mx = self._tag_lookup.get(tag, default)
self._write_rec(mx, rec)

self._write_complete()


class Consolidate:
BUFSIZE = 1024 * 1024 # 1MB
Expand Down Expand Up @@ -338,7 +408,7 @@ def read(self):
for block in self._bulk_read(self._open_f(fp, 'rb')):
self.queue.put((DATA, block))

self.queue.put(THREAD_COMPLETE)
self.queue.put(READ_COMPLETE)

def write(self):
current_handle = None
Expand All @@ -347,7 +417,7 @@ def write(self):
msg = self.queue.get()

# if we are complete then terminate gracefully
if msg == THREAD_COMPLETE:
if msg == READ_COMPLETE:
break
else:
dtype, data = msg
Expand Down
16 changes: 16 additions & 0 deletions mxdx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,21 @@ def consolidate_partials(output_base, extension):
cx.start()


@cli.command()
@click.option('--file-map', type=click.Path(exists=True), required=True,
              help="Files with record counts for processing")
@click.option('--batch-size', type=int, required=True,
              help="Number of records per batch")
@click.option('--is-one-based', is_flag=True, default=False,
              help="Whether indexing is zero or one based")
def get_max_batch_number(file_map, batch_size, is_one_based):
    """Report the largest batch number for a file map.

    Emits the total batch count when --is-one-based is set, otherwise
    the zero-based maximum batch index (count - 1).
    """
    fm = FileMap.from_tsv(file_map, batch_size)
    num_batches = fm.number_of_batches
    offset = 0 if is_one_based else -1
    click.echo(num_batches + offset)


if __name__ == '__main__':
cli()
Loading

0 comments on commit b18bfc8

Please sign in to comment.