Skip to content

Commit

Permalink
adding array_tools
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Mar 4, 2024
1 parent 49e84e4 commit d4022b1
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 12 deletions.
13 changes: 11 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
#!/usr/bin/env python
"""Build script that compiles the ``qseek.ext.array_tools`` C extension."""
import numpy
from setuptools import Extension, setup

# setup() is invoked exactly once. The extension source includes
# numpy/arrayobject.h, so NumPy's header directory must be on the
# compiler's include path.
setup(
    ext_modules=[
        Extension(
            "qseek.ext.array_tools",
            sources=["src/qseek/ext/array_tools.c"],
            include_dirs=[numpy.get_include()],
        )
    ]
)
154 changes: 154 additions & 0 deletions src/qseek/ext/array_tools.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#define PY_SSIZE_T_CLEAN /* Make "s#" use Py_ssize_t rather than int. */
#include "numpy/arrayobject.h"
#include <Python.h>
#include <omp.h>

/*
 * fill_zero_bytes(array) -> None
 *
 * Overwrite the complete data buffer of a NumPy array with zero bytes,
 * in place, using a single memset.
 *
 * Parameters:
 *   array: a writeable NumPy array whose elements occupy one contiguous
 *          memory segment (C- or Fortran-contiguous).
 *
 * Raises ValueError if `array` is not a NumPy array, is read-only, or is not
 * stored in a single segment. The segment check matters because memset clears
 * PyArray_NBYTES consecutive bytes starting at the data pointer; for a strided
 * view that range covers elements outside the view (and, for negative
 * strides, memory outside the buffer).
 */
static PyObject *fill_zero_bytes(PyObject *module, PyObject *args,
                                 PyObject *kwds) {
  PyObject *array_obj;
  PyArrayObject *array;
  void *data;
  size_t nbytes;
  static char *kwlist[] = {"array", NULL};

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &array_obj))
    return NULL;

  if (!PyArray_Check(array_obj)) {
    PyErr_SetString(PyExc_ValueError, "array is not a NumPy array");
    return NULL;
  }
  array = (PyArrayObject *)array_obj;

  if (!PyArray_ISWRITEABLE(array)) {
    PyErr_SetString(PyExc_ValueError, "array is not writeable");
    return NULL;
  }
  if (!PyArray_ISONESEGMENT(array)) {
    PyErr_SetString(PyExc_ValueError, "array is not contiguous");
    return NULL;
  }

  /* Read everything needed from the array object while the GIL is held. */
  data = PyArray_DATA(array);
  nbytes = (size_t)PyArray_NBYTES(array);

  /* The memset touches no Python objects, so other threads may run. */
  Py_BEGIN_ALLOW_THREADS;
  memset(data, 0, nbytes);
  Py_END_ALLOW_THREADS;

  Py_RETURN_NONE;
}

/*
 * apply_cache(array, cache, mask, nthreads=4) -> None
 *
 * Copy cached rows into a 2D float32 array, in place.
 *
 * For every index i where mask[i] is true, row i of `array` is overwritten
 * with the next unused entry of `cache`: the first selected row receives
 * cache[0], the second cache[1], and so on. Rows whose mask entry is false
 * are left untouched.
 *
 * Parameters:
 *   array:    writeable, C-contiguous np.float32 array of shape (N, M).
 *   cache:    list of C-contiguous 1D np.float32 arrays, each of size M. It
 *             must hold at least as many entries as `mask` selects; surplus
 *             entries are validated but not used.
 *   mask:     1D boolean array of size N (any stride).
 *   nthreads: accepted for interface compatibility; not used by this
 *             implementation.
 *
 * Raises ValueError when an argument fails validation and MemoryError when
 * the temporary pointer tables cannot be allocated.
 */
static PyObject *apply_cache(PyObject *module, PyObject *args, PyObject *kwds) {
  PyObject *array_obj, *cache, *mask_obj, *item_obj;
  PyArrayObject *array, *mask, *item;
  int nthreads = 4;
  npy_intp n_nodes, n_samples, n_selected, i_node, i_copy;
  Py_ssize_t n_items, i_item;
  size_t row_nbytes;
  void **targets = NULL;
  void **sources = NULL;

  static char *kwlist[] = {"array", "cache", "mask", "nthreads", NULL};

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|i", kwlist, &array_obj,
                                   &cache, &mask_obj, &nthreads))
    return NULL;

  /* --- destination array ------------------------------------------------ */
  if (!PyArray_Check(array_obj)) {
    PyErr_SetString(PyExc_ValueError, "array is not a NumPy array");
    return NULL;
  }
  array = (PyArrayObject *)array_obj;
  if (PyArray_NDIM(array) != 2) {
    PyErr_SetString(PyExc_ValueError, "array is not a 2D NumPy array");
    return NULL;
  }
  if (!PyArray_IS_C_CONTIGUOUS(array)) {
    PyErr_SetString(PyExc_ValueError, "array is not C contiguous");
    return NULL;
  }
  if (PyArray_TYPE(array) != NPY_FLOAT) {
    PyErr_SetString(PyExc_ValueError, "array is not of type np.float32");
    return NULL;
  }
  if (!PyArray_ISWRITEABLE(array)) {
    PyErr_SetString(PyExc_ValueError, "array is not writeable");
    return NULL;
  }
  n_nodes = PyArray_DIM(array, 0);
  n_samples = PyArray_DIM(array, 1);
  row_nbytes = (size_t)n_samples * sizeof(npy_float32);

  /* --- mask ------------------------------------------------------------- */
  if (!PyArray_Check(mask_obj)) {
    PyErr_SetString(PyExc_ValueError, "mask is not a NumPy array");
    return NULL;
  }
  mask = (PyArrayObject *)mask_obj;
  if (PyArray_NDIM(mask) != 1) {
    PyErr_SetString(PyExc_ValueError, "mask is not a 1D NumPy array");
    return NULL;
  }
  /* The mask bytes are read as npy_bool below, so any other dtype would be
   * misinterpreted (and, for smaller item sizes, read out of bounds). */
  if (PyArray_TYPE(mask) != NPY_BOOL) {
    PyErr_SetString(PyExc_ValueError, "mask is not of type np.bool_");
    return NULL;
  }
  if (PyArray_SIZE(mask) != n_nodes) {
    PyErr_SetString(PyExc_ValueError, "mask size does not match array");
    return NULL;
  }

  /* --- cache ------------------------------------------------------------ */
  if (!PyList_Check(cache)) {
    PyErr_SetString(PyExc_ValueError, "cache is not a list");
    return NULL;
  }
  n_items = PyList_GET_SIZE(cache);
  for (i_item = 0; i_item < n_items; i_item++) {
    item_obj = PyList_GET_ITEM(cache, i_item); /* borrowed reference */
    if (!PyArray_Check(item_obj)) {
      PyErr_SetString(PyExc_ValueError, "cache item is not a NumPy array");
      return NULL;
    }
    item = (PyArrayObject *)item_obj;
    if (PyArray_TYPE(item) != NPY_FLOAT) {
      PyErr_SetString(PyExc_ValueError, "cache item is not of type np.float32");
      return NULL;
    }
    if (PyArray_NDIM(item) != 1) {
      PyErr_SetString(PyExc_ValueError, "cache item is not a 1D NumPy array");
      return NULL;
    }
    if (!PyArray_IS_C_CONTIGUOUS(item)) {
      PyErr_SetString(PyExc_ValueError, "cache item is not C contiguous");
      return NULL;
    }
    if (PyArray_SIZE(item) != n_samples) {
      PyErr_SetString(PyExc_ValueError, "cache item size does not match array");
      return NULL;
    }
  }

  /* Count the selected rows so the cache length can be checked before any
   * list element is dereferenced. PyArray_GETPTR1 honours the mask stride. */
  n_selected = 0;
  for (i_node = 0; i_node < n_nodes; i_node++) {
    if (*(npy_bool *)PyArray_GETPTR1(mask, i_node)) {
      n_selected++;
    }
  }
  if ((npy_intp)n_items < n_selected) {
    PyErr_SetString(PyExc_ValueError,
                    "cache holds fewer arrays than the mask selects");
    return NULL;
  }
  if (n_selected == 0 || row_nbytes == 0) {
    Py_RETURN_NONE; /* nothing to copy */
  }

  /* Resolve every (destination row, source buffer) pair while the GIL is
   * held, so that no Python object is touched once it has been released. */
  targets = (void **)malloc((size_t)n_selected * sizeof(*targets));
  sources = (void **)malloc((size_t)n_selected * sizeof(*sources));
  if (targets == NULL || sources == NULL) {
    free(targets);
    free(sources);
    return PyErr_NoMemory();
  }
  i_copy = 0;
  for (i_node = 0; i_node < n_nodes; i_node++) {
    if (!*(npy_bool *)PyArray_GETPTR1(mask, i_node)) {
      continue;
    }
    item = (PyArrayObject *)PyList_GET_ITEM(cache, (Py_ssize_t)i_copy);
    targets[i_copy] = PyArray_GETPTR2(array, i_node, (npy_intp)0);
    sources[i_copy] = PyArray_DATA(item);
    i_copy++;
  }

  /* Pure memory copies: safe to run without the GIL. */
  Py_BEGIN_ALLOW_THREADS;
  for (i_copy = 0; i_copy < n_selected; i_copy++) {
    memcpy(targets[i_copy], sources[i_copy], row_nbytes);
  }
  Py_END_ALLOW_THREADS;

  free(targets);
  free(sources);
  Py_RETURN_NONE;
}

static PyMethodDef methods[] = {
/* The cast of the function is necessary since PyCFunction values
* only take two PyObject* parameters, and fill_zero_bytes() takes
* three.
*/
{"fill_zero_bytes", (PyCFunction)(void (*)(void))fill_zero_bytes,
METH_VARARGS | METH_KEYWORDS, "Fill a numpy array with zero bytes."},
{"apply_cache", (PyCFunction)(void (*)(void))apply_cache,
METH_VARARGS | METH_KEYWORDS,
"Apply a cache to a 2D numpy array of type float32."},
{NULL, NULL, 0, NULL} /* sentinel */
};

static struct PyModuleDef module = {
PyModuleDef_HEAD_INIT, "array_tools", NULL, -1, methods,
};

PyMODINIT_FUNC PyInit_array_tools(void) {
import_array();
return PyModule_Create(&module);
}
13 changes: 13 additions & 0 deletions src/qseek/ext/array_tools.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np

def fill_zero_bytes(array: np.ndarray) -> None:
    """Overwrite the array's data buffer with zero bytes, in place.

    The buffer is cleared as one run of ``array.nbytes`` bytes starting at
    the array's data pointer.

    Args:
        array: The NumPy array to clear.

    Raises:
        ValueError: If ``array`` is not a NumPy array.
    """

def apply_cache(
    array: np.ndarray,
    cache: list[np.ndarray],
    mask: np.ndarray,
    nthreads: int = 4,
) -> None:
    """Copy cached rows into the array, in place.

    Rows of ``array`` whose ``mask`` entry is true are overwritten with
    consecutive entries of ``cache``: the first selected row receives
    ``cache[0]``, the second ``cache[1]`` and so on. Other rows are left
    untouched.

    The parameter names match the keyword names of the C implementation.

    Args:
        array: The data array, C-contiguous ``np.float32`` with ndim=2 and
            NxM shape.
        cache: List of C-contiguous ``np.float32`` arrays with ndim=1 and
            M shape.
        mask: The mask array, ndim=1 with N shape of np.bool type.
        nthreads: Accepted by the C implementation; the visible
            implementation does not use it.

    Raises:
        ValueError: If an argument fails the type, shape or layout checks.
    """
10 changes: 6 additions & 4 deletions src/qseek/models/semblance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from rich.table import Table
from scipy import signal, stats

from qseek.ext.array_tools import apply_cache, fill_zero_bytes
from qseek.stats import Stats
from qseek.utils import datetime_now, get_cpu_count, human_readable_bytes

Expand Down Expand Up @@ -109,7 +110,7 @@ def __init__(
n_samples,
):
logger.debug("recycling semblance memory")
self._cached_semblance.fill(0.0)
fill_zero_bytes(self._cached_semblance)
self.semblance_unpadded = self._cached_semblance
else:
logger.debug("re-allocating semblance memory")
Expand Down Expand Up @@ -213,9 +214,10 @@ def apply_cache(self, cache: dict[bytes, np.ndarray]) -> None:

# This is faster than
# self.semblance_unpadded[mask, :] = data
for idx, copy in enumerate(mask):
if copy:
memoryview(self.semblance_unpadded[idx])[:] = memoryview(data.pop(0))
# for idx, copy in enumerate(mask):
# if copy:
# memoryview(self.semblance_unpadded[idx])[:] = memoryview(data.pop(0))
apply_cache(self.semblance_unpadded, data, mask)

def maximum_node_semblance(self) -> np.ndarray:
semblance = self.semblance.max(axis=1)
Expand Down
11 changes: 5 additions & 6 deletions src/qseek/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import asyncio
import logging
from collections import deque

# from cProfile import Profile
from cProfile import Profile
from datetime import datetime, timedelta, timezone
from itertools import chain
from pathlib import Path
Expand Down Expand Up @@ -57,7 +56,7 @@
logger = logging.getLogger(__name__)

SamplingRate = Literal[10, 20, 25, 50, 100, 200]
# p = Profile()
p = Profile()


class SearchStats(Stats):
Expand Down Expand Up @@ -721,7 +720,7 @@ async def search(
tuple[list[EventDetection], Trace]: The event detections and the
semblance traces used for the search.
"""
# p.enable()
p.enable()
parent = self.parent
sampling_rate = parent.semblance_sampling_rate

Expand Down Expand Up @@ -886,6 +885,6 @@ async def search(

detections.append(detection)

# p.disable()
# p.dump_stats("search.prof")
p.disable()
p.dump_stats("search.prof")
return detections, semblance.get_trace()

0 comments on commit d4022b1

Please sign in to comment.