diff --git a/emtools/__init__.py b/emtools/__init__.py index aed35e0..d4692f3 100644 --- a/emtools/__init__.py +++ b/emtools/__init__.py @@ -24,5 +24,5 @@ # * # ************************************************************************** -__version__ = '0.1.0' +__version__ = '0.1.1' diff --git a/emtools/jobs/__init__.py b/emtools/jobs/__init__.py index fff9a6d..d09e3e8 100644 --- a/emtools/jobs/__init__.py +++ b/emtools/jobs/__init__.py @@ -15,5 +15,6 @@ # ************************************************************************** from .pipeline import Pipeline +from .batch_manager import BatchManager -__all__ = ["Pipeline"] \ No newline at end of file +__all__ = ["Pipeline", "BatchManager"] \ No newline at end of file diff --git a/emtools/jobs/__main__.py b/emtools/jobs/__main__.py deleted file mode 100755 index e944a47..0000000 --- a/emtools/jobs/__main__.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python -# ************************************************************************** -# * -# * Authors: J.M. de la Rosa Trevin (delarosatrevin@gmail.com) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 3 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# ************************************************************************** - -from .motioncor import Main as mc_Main -import argparse - - -parser = argparse.ArgumentParser(prog='emtools.processing') -subparsers = parser.add_subparsers( - help="Utils' command (motioncor)", - dest='command') - -mc_Main.add_arguments(subparsers.add_parser("motioncor")) -parser.add_argument('--verbose', '-v', action='count') -args = parser.parse_args() -cmd = args.command - -if cmd == "motioncor": - mc_Main.run(args) -elif cmd: - raise Exception(f"Unknown option '{cmd}'") -else: - parser.print_help() diff --git a/emtools/jobs/batch_manager.py b/emtools/jobs/batch_manager.py new file mode 100644 index 0000000..706777d --- /dev/null +++ b/emtools/jobs/batch_manager.py @@ -0,0 +1,87 @@ +# ************************************************************************** +# * +# * Authors: J.M. de la Rosa Trevin (delarosatrevin@gmail.com) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# ************************************************************************** + +import os +from uuid import uuid4 +from datetime import datetime + +from emtools.utils import Process + + +class BatchManager: + """ Class used to generate and handle the creation of batches + from an input stream of items. + + This is used for streaming/parallel processing. Batches will have a folder + and a filename is extracted from each item and linked into the batch + folder. 
+ """ + def __init__(self, batchSize, inputItemsIterator, workingPath, + itemFileNameFunc=lambda item: item.getFileName()): + """ + Args: + batchSize: Number of items that will be grouped into one batch + inputItemsIterator: input items iterator + workingPath: path where the batches folder will be created + itemFileNameFunc: function to extract a filename from each item + (by default: lambda item: item.getFileName()) + """ + self._items = inputItemsIterator + self._batchSize = batchSize + self._batchCount = 0 + self._workingPath = workingPath + self._itemFileNameFunc = itemFileNameFunc + + def _createBatchId(self): + # We will use batchCount, before the batch is created + nowPrefix = datetime.now().strftime('%y%m%d-%H%M%S') + countStr = '%02d' % (self._batchCount + 1) + uuidSuffix = str(uuid4()).split('-')[0] + return f"{nowPrefix}_{countStr}_{uuidSuffix}" + + def _createBatch(self, items): + batch_id = self._createBatchId() + batch_path = os.path.join(self._workingPath, batch_id) + print(f"Creating batch: {batch_path}") + Process.system(f"rm -rf '{batch_path}'") + Process.system(f"mkdir '{batch_path}'") + + for item in items: + fn = self._itemFileNameFunc(item) + baseName = os.path.basename(fn) + os.symlink(os.path.abspath(fn), + os.path.join(batch_path, baseName)) + self._batchCount += 1 + return { + 'items': items, + 'id': batch_id, + 'path': batch_path, + 'index': self._batchCount + } + + def generate(self): + """ Generate batches based on the input items. """ + items = [] + + for item in self._items: + items.append(item) + + if len(items) == self._batchSize: + yield self._createBatch(items) + items = [] + + if items: + yield self._createBatch(items) diff --git a/emtools/jobs/motioncor.py b/emtools/jobs/motioncor.py deleted file mode 100755 index 938cc53..0000000 --- a/emtools/jobs/motioncor.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python - -import os -import argparse -import threading -from glob import glob -from uuid import uuid4 - -from emtools.utils import Timer, Pipeline -from emtools.metadata import StarFile - - -class McPipeline(Pipeline): - """ Pipeline specific to Motioncor processing. 
""" - def __init__(self, generateInputFunc, gpuList, outputDir, threads=1, **kwargs): - Pipeline.__init__(self, **kwargs) - self.mc = { - 'program': os.environ['MC_PROGRAM'], - 'args': os.environ['MC_ARGS'], - 'gain': os.environ['MC_GAIN'] - } - self.run_id = f'R-{uuid4()}' - self.outputDir = outputDir - self.scratchDir = kwargs.get('scratchDir', self.outputDir) - self.workingDir = os.path.join(self.scratchDir, self.run_id) - - g = self.addGenerator(generateInputFunc) - f1 = self.addProcessor(g.outputQueue, self.convert) - if threads > 1: - for i in range(1, threads): - self.addProcessor(g.outputQueue, self.convert, - outputQueue=f1.outputQueue) - moveQueue = f1.outputQueue - if gpuList: - mc1 = self.addProcessor(f1.outputQueue, - self.get_motioncor_proc(gpuList[0])) - for gpu in gpuList[1:]: - self.addProcessor(f1.outputQueue, self.get_motioncor_proc(gpu), - outputQueue=mc1.outputQueue) - moveQueue = mc1.outputQueue - - self.addProcessor(moveQueue, self.moveout) - - def convert(self, batch): - images = batch['images'] - thread_id = threading.get_ident() - batch_id = f'{str(uuid4()).split("-")[0]}' - batch_dir = os.path.join(self.workingDir, batch_id) - - print(f'T-{thread_id}: converting {len(images)} images...batch dir: {batch_dir}') - - batch = { - 'images': images, - 'batch_id': batch_id, - 'batch_dir': batch_dir - } - - script = f"rm -rf {batch_dir} && mkdir -p {batch_dir}\n" - - # FIXME: Just create links if input is already in .mrc format - cmd = 'tif2mrc' - - for fn in images: - base = os.path.basename(fn) - if cmd == 'tif2mrc': - base = base.replace('.tiff', '.mrc') - outFn = os.path.join(batch_dir, base) - script += f'{cmd} {fn} {outFn}\n' - - prefix = f'{thread_id}_{batch_id}_tif2mrc' - scriptFn = f'{prefix}_script.sh' - logFn = f'{thread_id}_tif2mrc_log.txt' - - with open(scriptFn, 'w') as f: - f.write(script) - - os.system(f'bash -x {scriptFn} &>> {logFn} && rm {scriptFn}') - - return batch - - def motioncor(self, gpu, batch): - mc = self.mc - batch_id = batch['batch_id'] - batch_dir = batch['batch_dir'] - script = f""" - cd {batch_dir} - {mc['program']} -InMrc ./FoilHole_ -InSuffix fractions.mrc -OutMrc aligned_ -Serial 1 -Gpu {gpu} -Gain {mc['gain']} -LogDir ./ {mc['args']} - """ - prefix = f'motioncor-gpu{gpu}' - print(f'{prefix}: running Motioncor for batch {batch_id} on GPU {gpu}') - - scriptFn = f'{prefix}_{batch_id}_script.sh' - logFn = f'{prefix}_motioncor_log.txt' - with open(scriptFn, 'w') as f: - f.write(script) - os.system(f'bash -x {scriptFn} &>> {logFn} && rm {scriptFn}') - return batch - - def get_motioncor_proc(self, gpu): - def _motioncor(batch): - return self.motioncor(gpu, batch) - - return _motioncor - - def moveout(self, batch): - batch_dir = batch['batch_dir'] - thread_id = threading.get_ident() - print(f'T-{thread_id}: moving output from batch dir: {batch_dir}') - # FIXME: Check what we want to move to output - os.system(f'mv {batch_dir}/* {self.outputDir}/ && rm -rf {batch_dir}') - return batch - - -class Main: - @staticmethod - def add_arguments(parser): - parser.add_argument('input_images', - help='Input images, can be a star file, a txt file or ' - 'a pattern.') - - parser.add_argument('--convert', choices=['default', 'tif2mrc', 'cp']) - - parser.add_argument('--nimages', '-n', nargs='?', type=int, - default=argparse.SUPPRESS) - parser.add_argument('--output', '-o', default='output') - parser.add_argument('-j', type=int, default=1, - help='Number of parallel threads') - parser.add_argument('--batch', '-b', type=int, default=0, - help='Batch size') - 
parser.add_argument('--gpu', default='', nargs='?', - help='Gpu list, separated by comma.' - 'E.g --gpu 0,1') - parser.add_argument('--scratch', default='', - help='Scratch directory to do intermediate I/O') - - @staticmethod - def run(args): - n = args.nimages - output = args.output - - if args.input_images.endswith('.star'): - input_star = args.input_images - - with StarFile(input_star) as sf: - # Read table in a different order as they appear in file - # also before the getTableNames() call that create the offsets - tableMovies = sf.getTable('movies') - - all_images = [row.rlnMicrographMovieName for row in tableMovies] - elif '*' in args.input_images: - all_images = glob(args.input_images) - else: - raise Exception('Please provide input as star file or files pattern (with * in it).') - - input_images = all_images[:n] - - run_id = f'R-{uuid4()}' - - print(f' run_id: {run_id}') - print(f' images: {len(input_images)}') - print(f' output: {output}') - print(f'threads: {args.j}') - print(f' gpus: {args.gpu}') - print(f' batch: {args.batch or len(input_images)}') - - wd = args.scratch if args.scratch else output - intermediate = os.path.join(wd, run_id) - - def generate(): - b = args.batch - if b: - n = len(input_images) // b - for i in range(n): - yield {'images': input_images[i*b:(i+1)*b]} - else: - yield {'images': input_images} - - os.system(f'rm -rf {output} && mkdir {output}') - os.system(f'rm -rf {intermediate} && mkdir {intermediate}') - - t = Timer() - - gpuList = args.gpu.split(',') - mc = McPipeline(generate, gpuList, output, threads=args.j, debug=False) - mc.run() - - os.system(f'rm -rf {intermediate}') - - t.toc() - - diff --git a/emtools/metadata/__init__.py b/emtools/metadata/__init__.py index e0f934f..0445682 100644 --- a/emtools/metadata/__init__.py +++ b/emtools/metadata/__init__.py @@ -15,10 +15,11 @@ # ************************************************************************** from .table import Column, ColumnList, Table -from .starfile import StarFile +from .starfile import StarFile, StarMonitor from .epu import EPU from .misc import Bins, TsBins, DataFiles, MovieFiles from .sqlite import SqliteFile -__all__ = ["Column", "ColumnList", "Table", "StarFile", "EPU", + +__all__ = ["Column", "ColumnList", "Table", "StarFile", "StarMonitor", "EPU", "Bins", "TsBins", "SqliteFile", "DataFiles", "MovieFiles"] diff --git a/emtools/metadata/misc.py b/emtools/metadata/misc.py index 69d857f..e570fc1 100644 --- a/emtools/metadata/misc.py +++ b/emtools/metadata/misc.py @@ -15,9 +15,10 @@ # ************************************************************************** import os + from datetime import datetime, timedelta -from emtools.utils import Path, Pretty +from emtools.utils import Path, Pretty, Process class Bins: diff --git a/emtools/metadata/starfile.py b/emtools/metadata/starfile.py index b5dd2af..baeb930 100644 --- a/emtools/metadata/starfile.py +++ b/emtools/metadata/starfile.py @@ -21,9 +21,13 @@ __author__ = 'Jose Miguel de la Rosa Trevin, Grigory Sharov' +import os import sys +import time import re from contextlib import AbstractContextManager +from collections import OrderedDict +from datetime import datetime, timedelta from .table import ColumnList, Table @@ -43,10 +47,10 @@ class StarFile(AbstractContextManager): @staticmethod def printTable(table, tableName=''): - w = StarFile(sys.stdout) + w = StarFile(sys.stdout, closeFile=False) w.writeTable(tableName, table, singleRow=len(table) <= 1) - def __init__(self, inputFile, mode='r'): + def __init__(self, inputFile, mode='r', 
**kwargs):
         """
         Args:
             inputFile: can be a str with the file path or a file object.
@@ -54,6 +58,7 @@ def __init__(self, inputFile, mode='r'):
             the mode will be ignored.
         """
         self._file = self.__loadFile(inputFile, mode)
+        self._closeFile = kwargs.get('closeFile', True)
 
         # While parsing the file, store the offsets for data_ blocks
         # for quick access when need to load data rows
@@ -301,7 +306,8 @@ def _iterRowLines(self):
 
     def close(self):
         if getattr(self, '_file', None):
-            self._file.close()
+            if self._closeFile:
+                self._file.close()
             self._file = None
 
     # ---------------------- Writer functions --------------------------------
@@ -332,7 +338,7 @@ def writeHeader(self, tableName, table):
         for col in self._columns:
             self._file.write("_%s \n" % col.getName())
 
-    def _writeRowValues(self, values):
+    def writeRowValues(self, values):
         """ Write to file a line for these row values.
         Order should be ensured that is the same of the expected columns.
         """
@@ -346,7 +352,7 @@ def writeRow(self, row):
         """ Write to file the line for this row.
         Row should be an instance of the expected Row class.
         """
-        self._writeRowValues(row._asdict().values())
+        self.writeRowValues(row._asdict().values())
 
     def _writeNewline(self):
         self._file.write('\n')
@@ -391,6 +397,68 @@ def writeTable(self, tableName, table, singleRow=False):
         self._writeTableName(tableName)
 
 
+class StarMonitor:
+    """
+    Monitor a STAR file for changes and return new items in a given table.
+
+    It keeps the key of every row already seen (computed with rowKeyFunc)
+    and the timestamp of the last check, so the STAR file is only re-read
+    when it has been modified since that check.
+    """
+    def __init__(self, fileName, tableName, rowKeyFunc, **kwargs):
+        self._seenItems = set()
+        self.fileName = fileName
+        self._tableName = tableName
+        self._rowKeyFunc = rowKeyFunc
+        self._wait = kwargs.get('wait', 10)
+        self._timeout = timedelta(seconds=kwargs.get('timeout', 300))
+        self.lastCheck = None  # Last timestamp when input was checked
+        self.lastUpdate = None  # Last timestamp when new items were found
+        self.inputCount = 0  # Count all input elements
+
+        # Blacklist some items to not be monitored again
+        # We are not interested in the items but just skip them from
+        # the processing
+        blacklist = kwargs.get('blacklist', None)
+        if blacklist:
+            for row in blacklist:
+                self._seenItems.add(self._rowKeyFunc(row))
+
+    def update(self):
+        newRows = []
+        now = datetime.now()
+        mTime = datetime.fromtimestamp(os.path.getmtime(self.fileName))
+
+        if self.lastCheck is None or mTime > self.lastCheck:
+            with StarFile(self.fileName) as sf:
+                for row in sf.iterTable(self._tableName):
+                    rowKey = self._rowKeyFunc(row)
+                    if rowKey not in self._seenItems:
+                        self.inputCount += 1
+                        self._seenItems.add(rowKey)
+                        newRows.append(row)
+
+        self.lastCheck = now
+        if newRows:
+            self.lastUpdate = now
+        return newRows
+
+    def timedOut(self):
+        """ Return True when there has been timeout seconds
+        since last new items were found. """
+        if self.lastCheck is None or self.lastUpdate is None:
+            return False
+        else:
+            return self.lastCheck - self.lastUpdate > self._timeout
+
+    def newItems(self, sleep=10):
+        """ Yield new items since last update until the stream is closed.
""" + while not self.timedOut(): + for row in self.update(): + yield row + time.sleep(self._wait) + + # --------- Helper functions ------------------------ def _formatValue(v): return '%0.6f' % v if isinstance(v, float) else str(v) diff --git a/emtools/pwx/__init__.py b/emtools/pwx/__init__.py index 68105c9..a49fe6b 100644 --- a/emtools/pwx/__init__.py +++ b/emtools/pwx/__init__.py @@ -16,8 +16,11 @@ # This emtools submodule need Scipion environment -from .monitors import ProtocolMonitor, SetMonitor, BatchManager +from .monitors import ProtocolMonitor, SetMonitor from .workflow import Workflow +# This is imported here for backward compatibility +from emtools.jobs import BatchManager + __all__ = ["ProtocolMonitor", "SetMonitor", "Workflow", "BatchManager"] diff --git a/emtools/pwx/monitors.py b/emtools/pwx/monitors.py index 28f1b43..4c6f97b 100644 --- a/emtools/pwx/monitors.py +++ b/emtools/pwx/monitors.py @@ -18,9 +18,7 @@ import time from datetime import datetime from collections import OrderedDict -from uuid import uuid4 -from emtools.utils import Pretty, Process import pyworkflow.protocol as pwprot @@ -177,49 +175,3 @@ def iterProtocolInput(self, prot, label, waitSecs=60): prot.info(f"No more {label}, stream closed. Total: {len(self)}") - -class BatchManager: - """ Class used to generate and handle creation of item batch - for streaming/parallel processing. - """ - def __init__(self, batchSize, inputItemsIterator, workingPath): - self._items = inputItemsIterator - self._batchSize = batchSize - self._batchCount = 0 - self._workingPath = workingPath - - def generate(self): - """ Generate batches based on the input items. """ - def _createBatch(items): - batch_id = str(uuid4()) - batch_path = os.path.join(self._workingPath, batch_id) - ts = Pretty.now() - - print(f"{ts}: Creating batch: {batch_path}") - Process.system(f"rm -rf '{batch_path}'") - Process.system(f"mkdir '{batch_path}'") - - for item in items: - fn = item.getFileName() - baseName = os.path.basename(fn) - os.symlink(os.path.abspath(fn), - os.path.join(batch_path, baseName)) - self._batchCount += 1 - return { - 'items': items, - 'id': batch_id, - 'path': batch_path, - 'index': self._batchCount - } - - items = [] - - for item in self._items: - items.append(item) - - if len(items) == self._batchSize: - yield _createBatch(items) - items = [] - - if items: - yield _createBatch(items) diff --git a/emtools/scripts/emt-scipion-otf.py b/emtools/scripts/emt-scipion-otf.py index 26cb784..6155cdf 100755 --- a/emtools/scripts/emt-scipion-otf.py +++ b/emtools/scripts/emt-scipion-otf.py @@ -267,11 +267,27 @@ def create_project(workingDir): def _path(*p): return os.path.join(workingDir, *p) + """ + {"acquisition": {"voltage": 200, "magnification": 79000, "pixel_size": 1.044, "dose": 1.063, "cs": 2.7}} + """ + scipionOptsFn = _path('scipion_otf_options.json') relionOptsFn = _path('relion_it_options.py') - with open(scipionOptsFn) as f: - opts = json.load(f) + if os.path.exists(scipionOptsFn): + with open(scipionOptsFn) as f: + opts = json.load(f) + + elif os.path.exists(relionOptsFn): + with open(_path('relion_it_options.py')) as f: + relionOpts = OrderedDict(ast.literal_eval(f.read())) + opts = {'acquisition': { + 'voltage': relionOpts['prep__importmovies__kV'], + 'pixel_size': relionOpts['prep__importmovies__angpix'], + 'cs': relionOpts['prep__importmovies__Cs'], + 'magnification': 130000, + 'dose': relionOpts['prep__motioncorr__dose_per_frame'] + }} acq = opts['acquisition'] picking = opts.get('picking', {}) @@ -377,11 +393,6 @@ def 
_path(*p):
 
     wf.launchProtocol(protCryolo, wait={OUT_COORD: 100})
 
-    skip_2d = not opts.get('2d', True)
-
-    if skip_2d:
-        return
-
     calculateBoxSize(protCryolo)
 
     protRelionExtract = wf.createProtocol(
@@ -465,7 +476,7 @@ def print_protocol(workingDir, protId):
     if protId == 'all':
         for prot in project.getRuns(iterate=True):
             clsName = prot.getClassName()
-            print(f"- {prot.getObjId():>6} {prot.getStatus():<10} {clsName:<30} - {prot.getRunName()}")
+            print(f"- {prot.getObjId():>8} {prot.getStatus():<10} {clsName}")
     else:
         prot = project.getProtocol(int(protId))
         if prot is None:
@@ -560,9 +571,6 @@ def write_coordinates(micStarFn, prot):
 
     for coord in coords.iterItems(orderBy='_micId', direction='ASC'):
         micId = coord.getMicId()
-        if micId not in micDict:
-            continue
-
         if micId not in micIds:
             micIds.add(micId)
             micFn = micDict[micId]
@@ -590,9 +598,7 @@ def print_prot(prot, label='Protocol'):
 
 
 def write_stars(workingDir, ids=None):
-    """ Write star files for Relion. Generates micrographs_ctf.star,
-    coordinates.star and Coordinates folder.
-    """
+    """ Write micrographs_ctf.star, coordinates.star and Coordinates folder for Relion. """
     print("ids", ids)
 
     def _get_keys(tokens):
@@ -610,12 +616,8 @@ def _get_keys(tokens):
     idsDict = {k: v for k, v in _get_keys(ids)}
     if 'ctfs' in idsDict:
         protCtf = project.getProtocol(idsDict['ctfs'])
-        if protCtf is None:
-            raise Exception(f"There is no CTF protocol with id {idsDict['ctfs']}")
     if 'picking' in idsDict:
         protPicking = project.getProtocol(idsDict['picking'])
-        if protPicking is None:
-            raise Exception(f"There is no CTF protocol with id {idsDict['picking']}")
     else:
         # Default option when running OTF that we export STAR files
         # from CTFFind and Cryolo runs
@@ -702,7 +704,7 @@ def main():
     p = argparse.ArgumentParser(prog='scipion-otf')
     g = p.add_mutually_exclusive_group()
 
-    g.add_argument('--create', metavar='Scipion project path',
+    g.add_argument('--create', action='store_true',
                    help="Create a new Scipion project in the working "
                         "directory. This will overwrite any existing "
                         "'scipion' folder there.")
@@ -719,7 +721,6 @@ def main():
     g.add_argument('--clean', action="store_true",
                    help="Clean Scipion project files/folders.")
    g.add_argument('--continue_2d', action="store_true")
-
     g.add_argument('--write_stars', default=argparse.SUPPRESS, nargs='*',
                    help="Generate STAR micrographs and particles STAR files."
                         "By default, it will get the first CTFfind protocol for ctfs"
@@ -736,7 +737,7 @@ def main():
     args = p.parse_args()
 
     if args.create:
-        create_project(args.create)
+        create_project(cwd)
     elif args.restart:
         restart(cwd, args.restart)
     elif args.restart_rankers:
diff --git a/emtools/tests/star_pipeline_tester.py b/emtools/tests/star_pipeline_tester.py
new file mode 100644
index 0000000..3331396
--- /dev/null
+++ b/emtools/tests/star_pipeline_tester.py
@@ -0,0 +1,91 @@
+# **************************************************************************
+# *
+# * Authors: J.M. de la Rosa Trevin (delarosatrevin@gmail.com)
+# *
+# * This program is free software; you can redistribute it and/or modify
+# * it under the terms of the GNU General Public License as published by
+# * the Free Software Foundation; either version 3 of the License, or
+# * (at your option) any later version.
+# *
+# * This program is distributed in the hope that it will be useful,
+# * but WITHOUT ANY WARRANTY; without even the implied warranty of
+# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# * GNU General Public License for more details.
+# * +# ************************************************************************** + +import os +import time + +from emtools.jobs import Pipeline, BatchManager +from emtools.metadata import StarFile, StarMonitor +from emtools.utils import Pretty + + +class StarPipelineTester(Pipeline): + """ Helper class to test Pipeline behaviour base on an input + STAR file generated in streaming. + """ + def __init__(self, inputStar, outputDir, **kwargs): + Pipeline.__init__(self) + + self.threads = kwargs.get('threads', 4) + self.batchSize = kwargs.get('batchSize', 128) + self.outputDir = outputDir + self.inputStar = inputStar + self.outputStar = os.path.join(outputDir, 'output.star') + self._sf = None + self._file = None + self.totalItems = 0 + + print(f">>> {Pretty.now()}: ----------------- " + f"Starting STAR Pipeline Tester ----------- ") + + monitor = StarMonitor(inputStar, 'particles', + lambda row: row.rlnImageId, + timeout=30) + + batchMgr = BatchManager(self.batchSize, monitor.newItems(), outputDir, + itemFileNameFunc=self._filename) + + g = self.addGenerator(batchMgr.generate) + outputQueue = None + print(f"Creating {self.threads} processing threads.") + for _ in range(self.threads): + p = self.addProcessor(g.outputQueue, self._process, + outputQueue=outputQueue) + outputQueue = p.outputQueue + + self.addProcessor(outputQueue, self._output) + + def _filename(self, row): + """ Helper to get unique name from a particle row. """ + pts, stack = row.rlnImageName.split('@') + return stack.replace('.mrcs', f'_p{pts}.mrcs') + + def _process(self, batch): + """ Dummy function to process an input batch. """ + time.sleep(5) + return batch + + def _output(self, batch): + """ Compile a batch that has been 'processed'. """ + if self._sf is None: + self._file = open(self.outputStar, 'w') + self._sf = StarFile(self._file) + with StarFile(self.inputStar) as sf: + self._sf.writeTable('optics', sf.getTable('optics')) + self._sf.writeHeader('particles', sf.getTableInfo('particles')) + + for row in batch['items']: + self._sf.writeRow(row) + self._file.flush() + + self.totalItems += len(batch['items']) + + def run(self): + Pipeline.run(self) + if self._sf is not None: + self._sf.close() + + diff --git a/emtools/tests/test_metadata.py b/emtools/tests/test_metadata.py index 84b41e4..cdf9ab4 100644 --- a/emtools/tests/test_metadata.py +++ b/emtools/tests/test_metadata.py @@ -16,12 +16,20 @@ import os import unittest import tempfile +import random +import time +import threading +import tempfile from pprint import pprint +from datetime import datetime -from emtools.utils import Timer, Color -from emtools.metadata import StarFile, SqliteFile, EPU +from emtools.utils import Timer, Color, Pretty +from emtools.metadata import StarFile, SqliteFile, EPU, StarMonitor +from emtools.jobs import BatchManager from emtools.tests import testpath +from .star_pipeline_tester import StarPipelineTester + # Try to load starfile library to launch some comparisons try: import starfile @@ -240,6 +248,131 @@ def _checkValues(t): os.unlink(ftmp.name) + def __test_star_streaming(self, monitorFunc, inputStreaming=True): + partStar = testpath('metadata', 'particles_1k.star') + if partStar is None: + return + + N = 1000 + + with StarFile(partStar) as sf: + ptable = sf.getTable('particles') + self.assertEqual(len(ptable), N) + otable = sf.getTable('optics') + self.assertEqual(len(otable), 1) + + ftmp = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.star') + print(f">>>> Using temporary file: {ftmp.name}") + + def 
_write_star_parts(): + with StarFile(ftmp) as sfOut: + sfOut.writeTable('optics', otable) + if inputStreaming: + sfOut.writeHeader('particles', ptable) + u = int(random.uniform(5, 10)) + s = u * 10 + w = 0 + for i, row in enumerate(ptable): + if i == s: + print(f"{w} rows written.") + ftmp.flush() + time.sleep(3) + u = int(random.uniform(5, 10)) + s = i + u * 10 + w = 0 + sfOut.writeRow(row) + w += 1 + else: + sfOut.writeTable('particles', ptable) + w = len(ptable) + + print(f"{w} rows written.") + + th = threading.Thread(target=_write_star_parts) + print(">>> Starting thread...") + th.start() + + monitor = StarMonitor(ftmp.name, 'particles', + lambda row: row.rlnImageId, + timeout=30) + + totalCount = monitorFunc(monitor) + self.assertEqual(totalCount, N) + + print("<<< Waiting for thread") + th.join() + + ftmp.close() + + # Check output is what we expect + with StarFile(ftmp.name) as sf: + ptable = sf.getTable('particles') + self.assertEqual(len(ptable), N) + otable = sf.getTable('optics') + self.assertEqual(len(otable), 1) + + os.unlink(ftmp.name) + + def test_star_monitor(self): + """ Basic test checking that we are able to monitor a streaming + generated star file. The final count of rows should be the + same as the input one. + """ + def _monitor(monitor): + totalRows = 0 + while not monitor.timedOut(): + newRows = monitor.update() + n = len(newRows) + totalRows += n + print(f"New rows: {n}") + print(f"Last update: {Pretty.datetime(monitor.lastUpdate)} " + f"Last check: {Pretty.datetime(monitor.lastCheck)} " + f"No activity: {Pretty.delta(monitor.lastCheck - monitor.lastUpdate)}") + time.sleep(5) + return totalRows + + self.__test_star_streaming(_monitor) + + def test_star_batchmanager(self): + """ Testing the creating of batches from an input star monitor + using different batch sizes. + """ + + def _filename(row): + """ Helper to get unique name from a particle row. """ + pts, stack = row.rlnImageName.split('@') + return stack.replace('.mrcs', f'_p{pts}.mrcs') + + def _batchmanager(monitor, batchSize): + totalFiles = 0 + + with tempfile.TemporaryDirectory() as tmp: + print(f"Using dir: {tmp}") + + batchMgr = BatchManager(batchSize, monitor.newItems(), tmp, + itemFileNameFunc=_filename) + + for batch in batchMgr.generate(): + files = len(os.listdir(batch['path'])) + print(f"Batch {batch['id']} -> {batch['path']}, files: {files}") + totalFiles += files + + return totalFiles + + self.__test_star_streaming(lambda m: _batchmanager(m, 128)) + self.__test_star_streaming(lambda m: _batchmanager(m, 200)) + + def test_star_pipeline(self): + def _pipeline(monitor): + with tempfile.TemporaryDirectory() as tmp: + print(f"Using dir: {tmp}") + p = StarPipelineTester(monitor.fileName, tmp) + p.run() + return p.totalItems + + #self.__test_star_streaming(_pipeline, inputStreaming=True) + self.__test_star_streaming(_pipeline, inputStreaming=False) + class TestEPU(unittest.TestCase): """ Tests for EPU class. """ @@ -266,6 +399,7 @@ def test_read_session_info(self): pprint(session) + class TestSqliteFile(unittest.TestCase): """ Tests for StarFile class. diff --git a/emtools/utils/path.py b/emtools/utils/path.py index b1049be..e256170 100644 --- a/emtools/utils/path.py +++ b/emtools/utils/path.py @@ -166,8 +166,8 @@ def _mkdir(d): @staticmethod def replaceExt(filename, newExt): """ Replace the current path extension(from last .) - with a new one. The new one should not contain the .""" - return Path.removeExt(filename) + '.' + newExt + with a new one. 
The new one should contain the .""" + return Path.removeExt(filename) + newExt @staticmethod def replaceBaseExt(filename, newExt): @@ -186,6 +186,11 @@ def removeExt(filename): """ Remove extension from basename """ return os.path.splitext(filename)[0] + @staticmethod + def getExt(filename): + """ Get filename extension """ + return os.path.splitext(filename)[1] + @staticmethod def exists(path): """ Just avoid empty or None path to raise exception
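
Usage note (not part of the patch): the sketch below shows how the new StarMonitor and BatchManager classes introduced above can be combined to stream rows from a growing STAR file into per-batch folders of linked files. The input path, working directory and batch size are placeholders, and the rlnImageId key and rlnImageName-based filename helper mirror the test code in this patch rather than a guaranteed STAR schema.

import os

from emtools.metadata import StarMonitor
from emtools.jobs import BatchManager


def item_filename(row):
    # Build a unique per-particle filename from rlnImageName ('NNN@stack.mrcs'),
    # as done by the _filename helper in StarPipelineTester above.
    index, stack = row.rlnImageName.split('@')
    return stack.replace('.mrcs', f'_p{index}.mrcs')


workingDir = 'batches'  # batch folders are created inside this directory
os.makedirs(workingDir, exist_ok=True)

# Re-read 'input.star' only when it is modified; give up after 60s without new rows
monitor = StarMonitor('input.star', 'particles',
                      lambda row: row.rlnImageId, timeout=60)

# Group the streamed rows into batches of 128 items, each linked into its own folder
batchMgr = BatchManager(128, monitor.newItems(), workingDir,
                        itemFileNameFunc=item_filename)

for batch in batchMgr.generate():
    print(f"Batch {batch['index']} ({batch['id']}): "
          f"{len(batch['items'])} items in {batch['path']}")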