Skip to content

Commit

Permalink
Incorporate some more changes for numpy 2
Browse files Browse the repository at this point in the history
Co-authored-by: Peter Hawkins <[email protected]>
  • Loading branch information
moble and hawkinsp committed Aug 17, 2024
1 parent 8286759 commit 5aec0ee
Showing 1 changed file with 25 additions and 31 deletions.
56 changes: 25 additions & 31 deletions src/numpy_quaternion.c
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// Copyright (c) 2024, Michael Boyle
// See LICENSE file for details: <https://github.com/moble/quaternion/blob/main/LICENSE>

#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL NumpyQuaternion
// #define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

#include <Python.h>
#include <numpy/arrayobject.h>
#include <numpy/npy_math.h>
#include <numpy/ndarrayobject.h>
#include <numpy/ufuncobject.h>
#include "structmember.h"

#include "quaternion.h"

// Provide compatibility with numpy 1 and 2
Expand Down Expand Up @@ -62,7 +62,8 @@ static PyTypeObject PyQuaternion_Type;
// built-in numpy data type. We will describe its features below.
PyArray_DescrProto* quaternion_descr;


PyArray_DescrProto quaternion_proto = {PyObject_HEAD_INIT(NULL)};

static NPY_INLINE int
PyQuaternion_Check(PyObject* object) {
return PyObject_IsInstance(object,(PyObject*)&PyQuaternion_Type);
Expand Down Expand Up @@ -263,7 +264,6 @@ pyquaternion_##fake_name##_array_operator(PyObject* a, PyObject* b) { \
} \
iternext = NpyIter_GetIterNext(iter, NULL); \
innerstride = NpyIter_GetInnerStrideArray(iter)[0]; \
/*itemsize = NpyIter_GetDescrArray(iter)[1]->elsize;*/ \
itemsize = PyDataType_ELSIZE(NpyIter_GetDescrArray(iter)[1]); \
innersizeptr = NpyIter_GetInnerLoopSizePtr(iter); \
dataptrarray = NpyIter_GetDataPtrArray(iter); \
Expand Down Expand Up @@ -1445,13 +1445,8 @@ PyMODINIT_FUNC initnumpy_quaternion(void) {
}

// Initialize numpy
import_array();
if (PyErr_Occurred()) {
INITERROR;
}
import_umath();
if (PyErr_Occurred()) {
INITERROR;
if (PyArray_ImportNumPyAPI() < 0) {
return NULL;
}
numpy = PyImport_ImportModule("numpy");
if (!numpy) {
Expand Down Expand Up @@ -1485,30 +1480,29 @@ PyMODINIT_FUNC initnumpy_quaternion(void) {
_PyQuaternion_ArrFuncs.fillwithscalar = (PyArray_FillWithScalarFunc*)QUATERNION_fillwithscalar;

// The quaternion array descr
quaternion_descr = PyObject_New(PyArray_DescrProto, &PyArrayDescr_Type);
quaternion_descr->typeobj = &PyQuaternion_Type;
quaternion_descr->kind = 'V';
quaternion_descr->type = 'q';
quaternion_descr->byteorder = '=';
quaternion_descr->flags = NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM;
quaternion_descr->type_num = 0; // assigned at registration
// quaternion_descr->elsize = quaternion_elsize;
PyDataType_SET_ELSIZE(quaternion_descr, quaternion_elsize);
quaternion_descr->alignment = quaternion_alignment;
quaternion_descr->subarray = NULL;
quaternion_descr->fields = NULL;
quaternion_descr->names = NULL;
// quaternion_descr->f = &_PyQuaternion_ArrFuncs;
quaternion_descr->f = &_PyQuaternion_ArrFuncs;
quaternion_descr->metadata = NULL;
quaternion_descr->c_metadata = NULL;
Py_SET_TYPE(&quaternion_proto, &PyArrayDescr_Type);
quaternion_proto.typeobj = &PyQuaternion_Type;
quaternion_proto.kind = 'V';
quaternion_proto.type = 'q';
quaternion_proto.byteorder = '=';
quaternion_proto.flags = NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM;
quaternion_proto.type_num = 0; // assigned at registration
quaternion_proto.elsize = quaternion_elsize;
quaternion_proto.alignment = quaternion_alignment;
quaternion_proto.subarray = NULL;
quaternion_proto.fields = NULL;
quaternion_proto.names = NULL;
quaternion_proto.f = &_PyQuaternion_ArrFuncs;
quaternion_proto.metadata = NULL;
quaternion_proto.c_metadata = NULL;

Py_INCREF(&PyQuaternion_Type);
quaternionNum = PyArray_RegisterDataType(quaternion_descr);
quaternionNum = PyArray_RegisterDataType(&quaternion_proto);

if (quaternionNum < 0) {
INITERROR;
}
quaternion_descr = PyArray_DescrFromType(quaternionNum);

register_cast_function(NPY_BOOL, quaternionNum, (PyArray_VectorUnaryFunc*)BOOL_to_quaternion);
register_cast_function(NPY_BYTE, quaternionNum, (PyArray_VectorUnaryFunc*)BYTE_to_quaternion);
Expand Down

0 comments on commit 5aec0ee

Please sign in to comment.