diff --git a/neurom/apps/morph_stats.py b/neurom/apps/morph_stats.py index 4f887903f..bc892a392 100644 --- a/neurom/apps/morph_stats.py +++ b/neurom/apps/morph_stats.py @@ -27,10 +27,14 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Core code for morph_stats application.""" +import os import logging from collections import defaultdict from itertools import product from pathlib import Path +import multiprocessing +from functools import partial +import warnings import numpy as np import pandas as pd @@ -82,11 +86,18 @@ def _stat_name(feat_name, stat_mode): return '%s_%s' % (stat_mode, feat_name) -def extract_dataframe(neurons, config): +def _run_extract_stats(nrn, config): + """The function to be called by multiprocessing.Pool.imap_unordered.""" + if not isinstance(nrn, FstNeuron): + nrn = nm.load_neuron(nrn) + return nrn.name, extract_stats(nrn, config) + + +def extract_dataframe(neurons, config, n_workers=1): """Extract stats grouped by neurite type from neurons. Arguments: - neurons: a neuron, population or neurite tree + neurons: a neuron, population, neurite tree or list of neuron paths config (dict): configuration dict. 
The keys are: - neurite_type: a list of neurite types for which features are extracted If not provided, all neurite_type will be used @@ -94,6 +105,7 @@ def extract_dataframe(neurons, config): - neurite_feature is a string from NEURITEFEATURES - mode is an aggregation operation provided as a string such as: ['min', 'max', 'median', 'mean', 'std', 'raw', 'total'] + n_workers (int): number of workers for multiprocessing (on collection of neurons) Returns: The extracted statistics @@ -112,7 +124,15 @@ def extract_dataframe(neurons, config): if 'neuron' in config: del config['neuron'] - stats = {nrn.name: extract_stats(nrn, config) for nrn in neurons} + func = partial(_run_extract_stats, config=config) + if n_workers == 1: + stats = dict(map(func, neurons)) + else: + if n_workers > os.cpu_count(): + warnings.warn(f'n_workers ({n_workers}) > os.cpu_count() ({os.cpu_count()})') + with multiprocessing.Pool(n_workers) as pool: + stats = dict(pool.imap_unordered(func, neurons)) + columns = list(next(iter(next(iter(stats.values())).values())).keys()) rows = [[name, neurite_type] + list(features.values()) @@ -126,7 +146,7 @@ def extract_stats(neurons, config): """Extract stats from neurons. Arguments: - neurons: a neuron, population or neurite tree + neurons: a neuron, population, neurite tree or list of neuron paths/str config (dict): configuration dict. 
The keys are: - neurite_type: a list of neurite types for which features are extracted If not provided, all neurite_type will be used @@ -162,15 +182,17 @@ def _fill_stats_dict(data, stat_name, stat): config.get('neurite_type', _NEURITE_MAP.keys())): neurite_type = _NEURITE_MAP[neurite_type] + feature = nm.get(feature_name, neurons, neurite_type=neurite_type) for mode in modes: stat_name = _stat_name(feature_name, mode) - stat = eval_stats(nm.get(feature_name, neurons, neurite_type=neurite_type), mode) + stat = eval_stats(feature, mode) _fill_stats_dict(stats[neurite_type.name], stat_name, stat) for feature_name, modes in config.get('neuron', {}).items(): + feature = nm.get(feature_name, neurons) for mode in modes: stat_name = _stat_name(feature_name, mode) - stat = eval_stats(nm.get(feature_name, neurons), mode) + stat = eval_stats(feature, mode) _fill_stats_dict(stats, stat_name, stat) return dict(stats) diff --git a/neurom/apps/tests/test_morph_stats.py b/neurom/apps/tests/test_morph_stats.py index 70c003278..7f1ec6708 100644 --- a/neurom/apps/tests/test_morph_stats.py +++ b/neurom/apps/tests/test_morph_stats.py @@ -26,7 +26,9 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import os from pathlib import Path +import warnings import numpy as np from nose.tools import (assert_almost_equal, assert_equal, @@ -176,6 +178,11 @@ def test_extract_dataframe(): actual = ms.extract_dataframe(nrns, config) assert_frame_equal(actual, expected) + # Test with a List[Path] argument + nrns = [Path(SWC_PATH, name) for name in ['Neuron.swc', 'simple.swc']] + actual = ms.extract_dataframe(nrns, config) + assert_frame_equal(actual, expected) + # Test without any neurite_type keys, it should pick the defaults config = {'neurite': {'total_length_per_neurite': ['total']}} actual = ms.extract_dataframe(nrns, config) @@ -192,6 +199,24 @@ def test_extract_dataframe(): assert_frame_equal(actual, expected) +def test_extract_dataframe_multiproc(): + nrns = nm.load_neurons([Path(SWC_PATH, name) + for name in ['Neuron.swc', 'simple.swc']]) + with warnings.catch_warnings(record=True) as w: + actual = ms.extract_dataframe(nrns, REF_CONFIG, n_workers=2) + expected = pd.read_csv(Path(DATA_PATH, 'extracted-stats.csv'), index_col=0) + + # Compare sorted DataFrame since Pool.imap_unordered disrupted the order + assert_frame_equal(actual.sort_values(by=['name']).reset_index(drop=True), + expected.sort_values(by=['name']).reset_index(drop=True)) + + with warnings.catch_warnings(record=True) as w: + actual = ms.extract_dataframe(nrns, REF_CONFIG, n_workers=os.cpu_count() + 1) + assert_equal(len(w), 1, "Warning not emitted") + assert_frame_equal(actual.sort_values(by=['name']).reset_index(drop=True), + expected.sort_values(by=['name']).reset_index(drop=True)) + + def test_get_header(): fake_results = {'fake_name0': REF_OUT, diff --git a/neurom/features/tests/test_get_features.py b/neurom/features/tests/test_get_features.py index 15cfdb2cc..d408a3330 100644 --- a/neurom/features/tests/test_get_features.py +++ b/neurom/features/tests/test_get_features.py @@ -436,7 +436,7 @@ def test_neurite_features_accept_single_tree(): for f in features: ret = get_feature(f, NRN.neurites[0]) 
nt.ok_(ret.dtype.kind in ('i', 'f')) - nt.ok_(len(ret) or len(ret) == 0) # make sure that len() resolves + nt.ok_(len(ret) > 0) def test_register_neurite_feature_nrns():