Skip to content

Commit

Permalink
API: Fix structured dtype cast-safety, promotion, and comparison
Browse files Browse the repository at this point in the history
This PR replaces the old numpygh-15509 implementing proper type promotion
for structured voids.  It further fixes the casting safety to consider
casts with equivalent field number and matching order as "safe"
and if the names, titles, and offsets match as "equiv".

The change perculates into the void comparison, and since it fixes
the order, it removes the current FutureWarning there as well.

This addresses liberfa/pyerfa#77
and replaces numpygh-15509 (the implementation has changed too much).

Fixes numpygh-15494  (and probably a few more)

Co-authored-by: Allan Haldane <[email protected]>
  • Loading branch information
seberg and ahaldane committed Feb 24, 2022
1 parent a33a10a commit 93e303c
Show file tree
Hide file tree
Showing 12 changed files with 327 additions and 159 deletions.
16 changes: 16 additions & 0 deletions numpy/core/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ def _struct_dict_str(dtype, includealignedflag):
return ret


def _aligned_offset(offset, alignment):
# round up offset:
return - (-offset // alignment) * alignment


def _is_packed(dtype):
"""
Checks whether the structured data type in 'dtype'
Expand All @@ -249,12 +254,23 @@ def _is_packed(dtype):
Duplicates the C `is_dtype_struct_simple_unaligned_layout` function.
"""
align = dtype.isalignedstruct
max_alignment = 1
total_offset = 0
for name in dtype.names:
fld_dtype, fld_offset, title = _unpack_field(*dtype.fields[name])

if align:
total_offset = _aligned_offset(total_offset, fld_dtype.alignment)
max_alignment = max(max_alignment, fld_dtype.alignment)

if fld_offset != total_offset:
return False
total_offset += fld_dtype.itemsize

if align:
total_offset = _aligned_offset(total_offset, max_alignment)

if total_offset != dtype.itemsize:
return False
return True
Expand Down
42 changes: 41 additions & 1 deletion numpy/core/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import platform
import warnings

from .multiarray import dtype, array, ndarray
from .multiarray import dtype, array, ndarray, promote_types
try:
import ctypes
except ImportError:
Expand Down Expand Up @@ -433,6 +433,46 @@ def _copy_fields(ary):
'formats': [dt.fields[name][0] for name in dt.names]}
return array(ary, dtype=copy_dtype, copy=True)

def _promote_fields(dt1, dt2):
""" Perform type promotion for two structured dtypes.
Parameters
----------
dt1 : structured dtype
First dtype.
dt2 : structured dtype
Second dtype.
Returns
-------
out : dtype
The promoted dtype
Notes
-----
If one of the inputs is aligned, the result will be. The titles of
both descriptors must match (point to the same field).
"""
# Both must be structured and have the same names in the same order
if (dt1.names is None or dt2.names is None) or dt1.names != dt2.names:
raise TypeError("invalid type promotion")

new_fields = []
for name in dt1.names:
field1 = dt1.fields[name]
field2 = dt2.fields[name]
new_descr = promote_types(field1[0], field2[0])
# Check that the titles match (if given):
if field1[2:] != field2[2:]:
raise TypeError("invalid type promotion")
if len(field1) == 2:
new_fields.append((name, new_descr))
else:
new_fields.append(((field1[2], name), new_descr))

return dtype(new_fields, align=dt1.isalignedstruct or dt2.isalignedstruct)


def _getfield_is_safe(oldtype, newtype, offset):
""" Checks safety of getfield for object arrays.
Expand Down
143 changes: 85 additions & 58 deletions numpy/core/src/multiarray/arrayobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -1033,31 +1033,83 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
"Void-arrays can only be compared for equality.");
return NULL;
}
if (PyArray_HASFIELDS(self)) {
PyObject *res = NULL, *temp, *a, *b;
PyObject *key, *value, *temp2;
PyObject *op;
Py_ssize_t pos = 0;
if (PyArray_TYPE(other) != NPY_VOID) {
PyErr_SetString(PyExc_ValueError,
"Cannot compare structured or void to non-void arrays. "
"(This may return array of False in the future.)");
return NULL;
}
if (PyArray_HASFIELDS(self) && PyArray_HASFIELDS(other)) {
PyArray_Descr *self_descr = PyArray_DESCR(self);
PyArray_Descr *other_descr = PyArray_DESCR(other);

/* Use promotion to decide whether the comparison is valid */
PyArray_Descr *promoted = PyArray_PromoteTypes(self_descr, other_descr);
if (promoted == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot compare structured arrays unless they have a "
"common dtype. I.e. `np.result_type(arr1, arr2)` must "
"be defined.\n"
"(This may return array of False in the future.)");
return NULL;
}
Py_DECREF(promoted);

npy_intp result_ndim = PyArray_NDIM(self) > PyArray_NDIM(other) ?
PyArray_NDIM(self) : PyArray_NDIM(other);

op = (cmp_op == Py_EQ ? n_ops.logical_and : n_ops.logical_or);
while (PyDict_Next(PyArray_DESCR(self)->fields, &pos, &key, &value)) {
if (NPY_TITLE_KEY(key, value)) {
continue;
}
a = array_subscript_asarray(self, key);
int field_count = PyTuple_GET_SIZE(self_descr->names);
if (field_count != PyTuple_GET_SIZE(other_descr->names)) {
PyErr_SetString(PyExc_TypeError,
"Cannot compare structured dtypes with different number of "
"fields. (unreachable error please report to NumPy devs)");
return NULL;
}

PyObject *op = (cmp_op == Py_EQ ? n_ops.logical_and : n_ops.logical_or);
PyObject *res = NULL;
for (int i = 0; i < field_count; ++i) {
PyObject *fieldname, *temp, *temp2;

fieldname = PyTuple_GET_ITEM(self_descr->names, i);
PyArrayObject *a = (PyArrayObject *)array_subscript_asarray(
self, fieldname);
if (a == NULL) {
Py_XDECREF(res);
return NULL;
}
b = array_subscript_asarray(other, key);
fieldname = PyTuple_GET_ITEM(other_descr->names, i);
PyArrayObject *b = (PyArrayObject *)array_subscript_asarray(
other, fieldname);
if (b == NULL) {
Py_XDECREF(res);
Py_DECREF(a);
return NULL;
}
temp = array_richcompare((PyArrayObject *)a,b,cmp_op);
/*
* If the fields were subarrays, the dimensions may have changed.
* In that case, the new shape (subarray part) must match exactly.
* (If this is 0, there is no subarray.)
*/
int field_dims_a = PyArray_NDIM(a) - PyArray_NDIM(self);
int field_dims_b = PyArray_NDIM(b) - PyArray_NDIM(other);
if (field_dims_a != field_dims_b || (
field_dims_a != 0 && /* neither is subarray */
/* Compare only the added (subarray) dimensions: */
!PyArray_CompareLists(
PyArray_DIMS(a) + PyArray_NDIM(self),
PyArray_DIMS(b) + PyArray_NDIM(other),
field_dims_a))) {
PyErr_SetString(PyExc_TypeError,
"Cannot compare subarrays with different shapes. "
"(unreachable error, please report to NumPy devs.)");
Py_DECREF(a);
Py_DECREF(b);
Py_XDECREF(res);
return NULL;
}

temp = array_richcompare(a, (PyObject *)b, cmp_op);
Py_DECREF(a);
Py_DECREF(b);
if (temp == NULL) {
Expand Down Expand Up @@ -1142,7 +1194,24 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
}
return res;
}
else if (PyArray_HASFIELDS(self) || PyArray_HASFIELDS(other)) {
PyErr_SetString(PyExc_TypeError,
"Cannot compare structured with unstructured void. "
"(This may return array of False in the future.)");
return NULL;
}
else {
/*
* Since arrays absorb subarray descriptors, this path can only be
* reached when both arrays have unstructured voids "V<len>" dtypes.
*/
if (PyArray_ITEMSIZE(self) != PyArray_ITEMSIZE(other)) {
PyErr_SetString(PyExc_TypeError,
"cannot compare unstructured voids of different length. "
"Use bytes to compare. "
"(This may return array of False in the future.)");
return NULL;
}
/* compare as a string. Assumes self and other have same descr->type */
return _strings_richcompare(self, other, cmp_op, 0);
}
Expand Down Expand Up @@ -1345,28 +1414,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return Py_NotImplemented;
}

_res = PyArray_CheckCastSafety(
NPY_EQUIV_CASTING,
PyArray_DESCR(self), PyArray_DESCR(array_other), NULL);
if (_res < 0) {
PyErr_Clear();
_res = 0;
}
if (_res == 0) {
/* 2015-05-07, 1.10 */
Py_DECREF(array_other);
if (DEPRECATE_FUTUREWARNING(
"elementwise == comparison failed and returning scalar "
"instead; this will raise an error or perform "
"elementwise comparison in the future.") < 0) {
return NULL;
}
Py_INCREF(Py_False);
return Py_False;
}
else {
result = _void_compare(self, array_other, cmp_op);
}
result = _void_compare(self, array_other, cmp_op);
Py_DECREF(array_other);
return result;
}
Expand Down Expand Up @@ -1400,29 +1448,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return Py_NotImplemented;
}

_res = PyArray_CheckCastSafety(
NPY_EQUIV_CASTING,
PyArray_DESCR(self), PyArray_DESCR(array_other), NULL);
if (_res < 0) {
PyErr_Clear();
_res = 0;
}
if (_res == 0) {
/* 2015-05-07, 1.10 */
Py_DECREF(array_other);
if (DEPRECATE_FUTUREWARNING(
"elementwise != comparison failed and returning scalar "
"instead; this will raise an error or perform "
"elementwise comparison in the future.") < 0) {
return NULL;
}
Py_INCREF(Py_True);
return Py_True;
}
else {
result = _void_compare(self, array_other, cmp_op);
Py_DECREF(array_other);
}
result = _void_compare(self, array_other, cmp_op);
Py_DECREF(array_other);
return result;
}

Expand Down
77 changes: 43 additions & 34 deletions numpy/core/src/multiarray/convert_datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,7 @@ PyArray_FindConcatenationDescriptor(
npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype)
{
if (requested_dtype == NULL) {
return PyArray_LegacyResultType(n, arrays, 0, NULL);
return PyArray_ResultType(n, arrays, 0, NULL);
}

PyArray_DTypeMeta *common_dtype;
Expand Down Expand Up @@ -3281,8 +3281,7 @@ can_cast_fields_safety(
{
Py_ssize_t field_count = PyTuple_Size(from->names);
if (field_count != PyTuple_Size(to->names)) {
/* TODO: This should be rejected! */
return NPY_UNSAFE_CASTING;
return -1;
}

NPY_CASTING casting = NPY_NO_CASTING;
Expand All @@ -3294,18 +3293,41 @@ can_cast_fields_safety(
if (from_tup == NULL) {
return give_bad_field_error(from_key);
}
PyArray_Descr *from_base = (PyArray_Descr*)PyTuple_GET_ITEM(from_tup, 0);
PyArray_Descr *from_base = (PyArray_Descr *) PyTuple_GET_ITEM(from_tup, 0);

/*
* TODO: This should use to_key (order), compare gh-15509 by
* by Allan Haldane. And raise an error on failure.
* (Fixing that may also requires fixing/changing promotion.)
*/
PyObject *to_tup = PyDict_GetItem(to->fields, from_key);
/* Check whether the field names match */
PyObject *to_key = PyTuple_GET_ITEM(to->names, i);
PyObject *to_tup = PyDict_GetItem(to->fields, to_key);
if (to_tup == NULL) {
return NPY_UNSAFE_CASTING;
return give_bad_field_error(from_key);
}
PyArray_Descr *to_base = (PyArray_Descr *) PyTuple_GET_ITEM(to_tup, 0);

int cmp = PyUnicode_Compare(from_key, to_key);
if (error_converting(cmp)) {
return -1;
}
if (cmp != 0) {
/* Field name mismatch, consider this at most SAFE. */
casting = PyArray_MinCastSafety(casting, NPY_SAFE_CASTING);
}

/* Also check the title (denote mismatch as SAFE only) */
PyObject *from_title = from_key;
PyObject *to_title = to_key;
if (PyTuple_GET_SIZE(from_tup) > 2) {
from_title = PyTuple_GET_ITEM(from_tup, 2);
}
if (PyTuple_GET_SIZE(to_tup) > 2) {
to_title = PyTuple_GET_ITEM(to_tup, 2);
}
cmp = PyObject_RichCompareBool(from_title, to_title, Py_EQ);
if (error_converting(cmp)) {
return -1;
}
if (!cmp) {
casting = PyArray_MinCastSafety(casting, NPY_SAFE_CASTING);
}
PyArray_Descr *to_base = (PyArray_Descr*)PyTuple_GET_ITEM(to_tup, 0);

NPY_CASTING field_casting = PyArray_GetCastInfo(
from_base, to_base, NULL, &field_view_off);
Expand Down Expand Up @@ -3338,39 +3360,26 @@ can_cast_fields_safety(
*view_offset = NPY_MIN_INTP;
}
}
if (*view_offset != 0) {
/* If the calculated `view_offset` is not 0, it can only be "equiv" */
return PyArray_MinCastSafety(casting, NPY_EQUIV_CASTING);
}

/*
* If the itemsize (includes padding at the end), fields, or names
* do not match, this cannot be a view and also not a "no" cast
* (identical dtypes).
* It may be possible that this can be relaxed in some cases.
* If the itemsize (includes padding at the end), does not match,
* this is not a "no" cast (identical dtypes) and may not be viewable.
*/
if (from->elsize != to->elsize) {
/*
* The itemsize may mismatch even if all fields and formats match
* (due to additional padding).
*/
return PyArray_MinCastSafety(casting, NPY_EQUIV_CASTING);
}

int cmp = PyObject_RichCompareBool(from->fields, to->fields, Py_EQ);
if (cmp != 1) {
if (cmp == -1) {
PyErr_Clear();
casting = PyArray_MinCastSafety(casting, NPY_EQUIV_CASTING);
if (from->elsize < to->elsize) {
*view_offset = NPY_MIN_INTP;
}
return PyArray_MinCastSafety(casting, NPY_EQUIV_CASTING);
}
cmp = PyObject_RichCompareBool(from->names, to->names, Py_EQ);
if (cmp != 1) {
if (cmp == -1) {
PyErr_Clear();
}
return PyArray_MinCastSafety(casting, NPY_EQUIV_CASTING);
else if (*view_offset != 0) {
/* If the calculated `view_offset` is not 0, it can only be "equiv" */
casting = PyArray_MinCastSafety(casting, NPY_EQUIV_CASTING);
}

return casting;
}

Expand Down
Loading

0 comments on commit 93e303c

Please sign in to comment.