From 0ac2d9bffffe90eb7be8bb62f529327a533033d7 Mon Sep 17 00:00:00 2001 From: miili Date: Wed, 6 Mar 2024 09:57:09 +0100 Subject: [PATCH] upd --- setup.py | 2 + src/qseek/ext/array_tools.c | 71 ++++++++++++++++++++++------------- src/qseek/ext/array_tools.pyi | 8 +++- src/qseek/models/semblance.py | 8 +++- 4 files changed, 60 insertions(+), 29 deletions(-) diff --git a/setup.py b/setup.py index 215aeb5c..91a968c6 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,8 @@ "qseek.ext.array_tools", sources=["src/qseek/ext/array_tools.c"], include_dirs=[numpy.get_include()], + extra_compile_args=["-fopenmp"], + extra_link_args=["-lgomp"], ) ] ) diff --git a/src/qseek/ext/array_tools.c b/src/qseek/ext/array_tools.c index f823c84c..a832557a 100644 --- a/src/qseek/ext/array_tools.c +++ b/src/qseek/ext/array_tools.c @@ -1,3 +1,4 @@ +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_SSIZE_T_CLEAN /* Make "s#" use Py_ssize_t rather than int. */ #include "numpy/arrayobject.h" #include @@ -23,25 +24,28 @@ static PyObject *fill_zero_bytes(PyObject *module, PyObject *args, } static PyObject *apply_cache(PyObject *module, PyObject *args, PyObject *kwds) { - PyObject *array, *cache, *mask; + PyObject *obj, *cache, *mask; + PyArrayObject *array, *mask_array, *cached_row; npy_intp *array_shape; npy_intp n_nodes, n_samples; + int n_threads = 1; + uint sum_mask = 0; npy_int *cumsum_mask, mask_value; npy_int idx_sum = 0; npy_bool *mask_data; - PyArrayObject *cached_array; - static char *kwlist[] = {"array", "cache", "mask", NULL}; + static char *kwlist[] = {"array", "cache", "mask", "nthreads", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO", kwlist, &array, &cache, - &mask)) + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|i", kwlist, &obj, &cache, + &mask, &n_threads)) return NULL; - if (!PyArray_Check(array)) { + if (!PyArray_Check(obj)) { PyErr_SetString(PyExc_ValueError, "array is not a NumPy array"); return NULL; } + array = (PyArrayObject *)obj; if (PyArray_NDIM(array) != 2) { PyErr_SetString(PyExc_ValueError, "array is not a 2D NumPy array"); return NULL; @@ -63,11 +67,12 @@ static PyObject *apply_cache(PyObject *module, PyObject *args, PyObject *kwds) { PyErr_SetString(PyExc_ValueError, "mask is not a NumPy array"); return NULL; } - if (PyArray_NDIM(mask) != 1) { + mask_array = (PyArrayObject *)mask; + if (PyArray_NDIM(mask_array) != 1) { PyErr_SetString(PyExc_ValueError, "mask is not a 2D NumPy array"); return NULL; } - if (PyArray_SIZE(mask) != n_nodes) { + if (PyArray_SIZE(mask_array) != n_nodes) { PyErr_SetString(PyExc_ValueError, "mask size does not match array"); return NULL; } @@ -77,53 +82,65 @@ static PyObject *apply_cache(PyObject *module, PyObject *args, PyObject *kwds) { return NULL; } + cumsum_mask = (npy_int *)malloc(n_nodes * sizeof(npy_int)); + mask_data = PyArray_DATA(mask_array); + for (int i_node = 0; i_node < n_nodes; i_node++) { + mask_value = mask_data[i_node]; + if (!mask_value) { + cumsum_mask[i_node] = -1; + } else { + cumsum_mask[i_node] = idx_sum; + idx_sum += 1; + sum_mask += 1; + } + } + + if (PyList_Size(cache) != sum_mask) { + PyErr_SetString(PyExc_ValueError, "cache elements does not match mask"); + return NULL; + } + for (int i_node = 0; i_node < PyList_Size(cache); i_node++) { PyObject *item = PyList_GetItem(cache, i_node); if (!PyArray_Check(item)) { PyErr_SetString(PyExc_ValueError, "cache item is not a NumPy array"); return NULL; } - if (PyArray_TYPE(item) != NPY_FLOAT) { + cached_row = (PyArrayObject *)item; + if (PyArray_TYPE(cached_row) != NPY_FLOAT) { PyErr_SetString(PyExc_ValueError, "cache item is not of type np.float32"); return NULL; } - if (PyArray_NDIM(item) != 1) { + if (PyArray_NDIM(cached_row) != 1) { PyErr_SetString(PyExc_ValueError, "cache item is not a 1D NumPy array"); return NULL; } - if (!PyArray_IS_C_CONTIGUOUS(item)) { + if (!PyArray_IS_C_CONTIGUOUS(cached_row)) { PyErr_SetString(PyExc_ValueError, "cache item is not C contiguous"); return NULL; } - if (PyArray_SIZE(item) != n_samples) { - PyErr_SetString(PyExc_ValueError, "cache item size does not match array"); + if (PyArray_SIZE(cached_row) != n_samples) { + PyErr_SetString(PyExc_ValueError, + "cache item size does not match array nsamples"); return NULL; } } // cumsum mask - cumsum_mask = (npy_int *)malloc(n_nodes * sizeof(npy_int)); - mask_data = PyArray_DATA((PyArrayObject *)mask); - for (int i_node = 0; i_node < n_nodes; i_node++) { - mask_value = mask_data[i_node]; - if (!mask_value) { - cumsum_mask[i_node] = -1; - } else { - cumsum_mask[i_node] = idx_sum; - idx_sum += 1; - } - } Py_BEGIN_ALLOW_THREADS; +#pragma omp parallel for num_threads(n_threads) \ + schedule(dynamic) private(cached_row) for (int i_node = 0; i_node < n_nodes; i_node++) { if (cumsum_mask[i_node] == -1) { continue; } - cached_array = PyList_GET_ITEM(cache, (Py_ssize_t)cumsum_mask[i_node]); + cached_row = (PyArrayObject *)PyList_GET_ITEM( + cache, (Py_ssize_t)cumsum_mask[i_node]); memcpy( PyArray_GETPTR2((PyArrayObject *)array, (npy_intp)i_node, (npy_intp)0), - PyArray_DATA((PyArrayObject *)cached_array), - n_samples * sizeof(npy_float32)); + PyArray_DATA((PyArrayObject *)cached_row), + (size_t)n_samples * sizeof(npy_float32)); } Py_END_ALLOW_THREADS; diff --git a/src/qseek/ext/array_tools.pyi b/src/qseek/ext/array_tools.pyi index 102e60b6..57f3d671 100644 --- a/src/qseek/ext/array_tools.pyi +++ b/src/qseek/ext/array_tools.pyi @@ -3,11 +3,17 @@ import numpy as np def fill_zero_bytes(array: np.ndarray) -> None: """Fill the zero bytes of the array with zeros.""" -def apply_cache(data: np.ndarray, cache: list[np.ndarray], mask: np.ndarray) -> None: +def apply_cache( + data: np.ndarray, + cache: list[np.ndarray], + mask: np.ndarray, + nthreads: int = 1, +) -> None: """Apply the cache to the data array. Args: data: The data array, ndim=2 with NxM shape. cache: List of arrays with ndim=1 and M shape. mask: The mask array, ndim=1 with N shape of np.bool type. + nthreads: The number of threads to use. """ diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py index 0be1ce82..9b6aabbb 100644 --- a/src/qseek/models/semblance.py +++ b/src/qseek/models/semblance.py @@ -204,7 +204,13 @@ async def apply_cache(self, cache: dict[bytes, np.ndarray]) -> None: # for idx, copy in enumerate(mask): # if copy: # memoryview(self.semblance_unpadded[idx])[:] = memoryview(data.pop(0)) - await asyncio.to_thread(apply_cache, self.semblance_unpadded, data, mask) + await asyncio.to_thread( + apply_cache, + self.semblance_unpadded, + data, + mask, + nthreads=4, + ) def maximum_node_semblance(self) -> np.ndarray: semblance = self.semblance.max(axis=1)