Skip to content

Commit

Permalink
Restore scalar pysplinetable_searchcenters (#40)
Browse files Browse the repository at this point in the history
* ENH: accept os.PathLike in ctor

Enabling more laziness

* Restore scalar searchcenters

Fixes #39.

* Bump version to 2.2.1

* Work with python < 3.10

in which Py_XDECREF is a macro

* Clip stray import

* tests: add local photospline to pythonpath

* tests: set PYTHONPATH for ctest

* Add a test for evaluate

and adjust annotations to cover the vector case
  • Loading branch information
jvansanten authored Oct 19, 2023
1 parent 99e7021 commit 5b4ff89
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 25 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required (VERSION 3.1.0 FATAL_ERROR)
cmake_policy(VERSION 3.1.0)

project (photospline VERSION 2.2.0 LANGUAGES C CXX)
project (photospline VERSION 2.2.1 LANGUAGES C CXX)

SET(CMAKE_CXX_STANDARD 11)
SET(CMAKE_C_STANDARD 99)
Expand Down Expand Up @@ -369,7 +369,11 @@ LIST (APPEND ALL_TESTS

if(PYTHON_FOUND AND NUMPY_FOUND)
ADD_TEST(photospline-test-pystack ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/test/test_stack.py)
set_property(TEST photospline-test-pystack PROPERTY ENVIRONMENT PYTHONPATH=${PROJECT_BINARY_DIR})
LIST (APPEND ALL_TESTS photospline-test-pystack)
ADD_TEST(photospline-test-pyeval ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/test/test_eval.py)
set_property(TEST photospline-test-pyeval PROPERTY ENVIRONMENT PYTHONPATH=${PROJECT_BINARY_DIR})
LIST (APPEND ALL_TESTS photospline-test-pyeval)
endif()

if(BUILD_SPGLAM)
Expand Down
39 changes: 29 additions & 10 deletions src/python/photosplinemodule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ pysplinetable_dealloc(pysplinetable* self){
delete(self->table);
}

typedef std::unique_ptr<PyObject, void(*)(PyObject*)> handle;

static inline handle new_reference(PyObject *ptr)
{
auto deleter = [](PyObject* ptr){ Py_XDECREF(ptr); };
return handle(ptr, deleter);
}

static PyObject*
pysplinetable_new(PyTypeObject* type, PyObject* args, PyObject* kwds){
pysplinetable* self;
Expand All @@ -334,13 +342,14 @@ pysplinetable_new(PyTypeObject* type, PyObject* args, PyObject* kwds){
static int
pysplinetable_init(pysplinetable* self, PyObject* args, PyObject* kwds){
static const char* kwlist[] = {"path", NULL};
char* path=NULL;
PyObject *path=NULL;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", (char**)kwlist, &path))
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&", (char**)kwlist, &PyUnicode_FSConverter, &path))
return -1;
auto handle = new_reference(path);

try{
self->table=new photospline::splinetable<>(path);
self->table=new photospline::splinetable<>(PyBytes_AsString(handle.get()));
}catch(std::exception& ex){
PyErr_SetString(PyExc_Exception,
(std::string("Unable to allocate spline table: ")+ex.what()).c_str());
Expand Down Expand Up @@ -567,19 +576,28 @@ pysplinetable_searchcenters(pysplinetable* self, PyObject* args, PyObject* kwds)
return(NULL);
}

bool scalar = true;
for(unsigned int i=0; i!=ndim; i++){
if (!(scalar = scalar && PyFloat_Check(new_reference(PySequence_GetItem(pyx,i)).get())))
break;
}
#ifndef HAVE_NUMPY
if (!scalar) {
PyErr_SetString(PyExc_ValueError, "x must be a sequence of floats");
return(NULL);
}
#else
if (scalar)
#endif
//unpack x
//assume these are arbitrary sequences, not numpy arrays
#ifndef HAVE_NUMPY
{
//a small amount of evil
double x[ndim];
int centers[ndim];

for(unsigned int i=0; i!=ndim; i++){
PyObject* xi=PySequence_GetItem(pyx,i);
x[i]=PyFloat_AsDouble(xi);
Py_DECREF(xi); //done with this
//printf("x[%u]=%lf\n",i,x[i]);
x[i]=PyFloat_AsDouble(new_reference(PySequence_GetItem(pyx,i)).get());
}

if(!self->table->searchcenters(x,centers)){
Expand All @@ -592,7 +610,8 @@ pysplinetable_searchcenters(pysplinetable* self, PyObject* args, PyObject* kwds)
PyTuple_SetItem(result,i,Py_BuildValue("i",centers[i]));
return(result);
}
#else
#if defined(HAVE_NUMPY)
else
//optimized case for numpy arrays (or things that can be converted to them)
{
PyArrayObject* arrays[2*ndim];
Expand Down Expand Up @@ -622,7 +641,7 @@ pysplinetable_searchcenters(pysplinetable* self, PyObject* args, PyObject* kwds)

for(unsigned int i=0; i!=ndim; i++){
// get a pointer to the data of the row
double* row_data = (double*)PyArray_GETPTR1(array_out, i);
void* row_data = PyArray_GETPTR1(array_out, i);

// create a new 1D array that shares data with the row
npy_intp dims[1] = {PyArray_DIM(array_out, 1)}; // length of the row
Expand Down
37 changes: 37 additions & 0 deletions test/test_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env python

import unittest
import photospline

import numpy as np

from pathlib import Path


class TestEvaluation(unittest.TestCase):
def setUp(self):
self.testdata = Path(__file__).parent / "test_data"
self.spline = photospline.SplineTable(self.testdata / "test_spline_4d.fits")
extents = np.array(self.spline.extents)
loc = extents[:, :1]
scale = np.diff(extents, axis=1)
self.x = (np.random.uniform(0, 1, size=(self.spline.ndim, 10)) + loc) * scale

def test_vector(self):
centers = self.spline.search_centers(self.x)
self.assertIsInstance(centers, np.ndarray)
self.assertEqual(centers.shape, (self.spline.ndim, self.x.shape[1]))
v = self.spline.evaluate(self.x, centers=centers)
self.assertIsInstance(v, np.ndarray)
self.assertEqual(v.shape, (self.x.shape[1],))

def test_scalar(self):
centers = self.spline.search_centers([x[0] for x in self.x])
self.assertIsInstance(centers, tuple)
self.assertEqual(len(centers), self.spline.ndim)
v = self.spline.evaluate([x[0] for x in self.x], centers=centers)
self.assertIsInstance(v, float)


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions test/test_stack.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
#!/usr/bin/env python

import sys, os

sys.path.append(os.getcwd())
import numpy
import photospline

Expand Down
45 changes: 35 additions & 10 deletions typings/photospline-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence
from typing import Sequence, Union, Any, overload
from os import PathLike
import numpy as np
import numpy.typing as npt

Expand All @@ -9,8 +10,7 @@ class SplineTable:
ndim: int
order: tuple[int, ...]
def aux_value(self, key: str) -> str: ...

def __init__(self, path: str): ...
def __init__(self, path: PathLike): ...
@classmethod
def stack(
cls,
Expand All @@ -20,18 +20,43 @@ class SplineTable:
) -> SplineTable:
""" """
...

def convolve(self, dim: int, knots: Sequence[float]) -> None: ...
def permute_dimensions(self, permutation: Sequence[int]) -> None: ...
def write(self, path: str) -> None: ...

def search_centers(self, x: Sequence[float]) -> Sequence[int]: ...
def deriv(self, x: Sequence[float], centers: Sequence[int], derivatives: Sequence[int]) -> float: ...
def evaluate(self, x: Sequence[float], centers: Sequence[int], derivatives: Sequence[int]) -> float: ...
def evaluate_gradient(self, x: Sequence[float], centers: Sequence[int]) -> float: ...
@overload
def search_centers(self, x: Sequence[float]) -> tuple[int]: ...
@overload
def search_centers(
self, x: npt.NDArray[np.floating[Any]]
) -> npt.NDArray[np.longlong]: ...
def search_centers(
self, x: Union[npt.NDArray[np.floating[Any]], Sequence[float]]
) -> Union[npt.NDArray[np.longlong], tuple[int]]: ...
def deriv(
self, x: Sequence[float], centers: Sequence[int], derivatives: Sequence[int]
) -> float: ...
@overload
def evaluate(
self, x: Sequence[float], centers: Sequence[int], derivatives: int = 0
) -> float: ...
@overload
def evaluate(
self,
x: npt.NDArray[np.floating[Any]],
centers: npt.NDArray[np.longlong],
derivatives: int = 0,
) -> npt.NDArray[np.floating[Any]]: ...
def evaluate(
self,
x: Union[npt.NDArray[np.floating[Any]], Sequence[float]],
centers: Union[npt.NDArray[np.longlong], Sequence[int]],
derivatives: int = 0,
) -> Union[npt.NDArray[np.floating[Any]], float]: ...
def evaluate_gradient(
self, x: Sequence[float], centers: Sequence[int]
) -> float: ...
def evaluate_simple(self, x: Sequence[float]) -> float: ...
def __call__(self, x: Sequence[float]) -> float: ...

def grideval(self, coords: Sequence[npt.ArrayLike]) -> npt.NDArray[np.float64]: ...

class ndsparse:
Expand Down

0 comments on commit 5b4ff89

Please sign in to comment.