Skip to content

Commit

Permalink
Merge pull request numpy#24681 from seberg/nep50-bad-promotion
Browse files Browse the repository at this point in the history
BUG: Fix weak promotion with some mixed float/int dtypes
  • Loading branch information
mattip authored Sep 28, 2023
2 parents 5d4d05a + 3cd5d35 commit eabefa4
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 59 deletions.
30 changes: 0 additions & 30 deletions numpy/core/src/multiarray/abstractdtypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,6 @@ int_common_dtype(PyArray_DTypeMeta *NPY_UNUSED(cls), PyArray_DTypeMeta *other)
/* Use the default integer for bools: */
return PyArray_DTypeFromTypeNum(NPY_LONG);
}
else if (PyTypeNum_ISNUMBER(other->type_num) ||
other->type_num == NPY_TIMEDELTA) {
            /* All other numeric types (and timedelta) are preserved: */
Py_INCREF(other);
return other;
}
}
else if (NPY_DT_is_legacy(other)) {
/* This is a back-compat fallback to usually do the right thing... */
Expand Down Expand Up @@ -211,11 +205,6 @@ float_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
/* Use the default integer for bools and ints: */
return PyArray_DTypeFromTypeNum(NPY_DOUBLE);
}
else if (PyTypeNum_ISNUMBER(other->type_num)) {
/* All other numeric types (float+complex) are preserved: */
Py_INCREF(other);
return other;
}
}
else if (other == &PyArray_PyIntAbstractDType) {
Py_INCREF(cls);
Expand Down Expand Up @@ -255,25 +244,6 @@ complex_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
/* Use the default integer for bools and ints: */
return PyArray_DTypeFromTypeNum(NPY_CDOUBLE);
}
else if (PyTypeNum_ISFLOAT(other->type_num)) {
/*
* For floats we choose the equivalent precision complex, although
* there is no CHALF, so half also goes to CFLOAT.
*/
if (other->type_num == NPY_HALF || other->type_num == NPY_FLOAT) {
return PyArray_DTypeFromTypeNum(NPY_CFLOAT);
}
if (other->type_num == NPY_DOUBLE) {
return PyArray_DTypeFromTypeNum(NPY_CDOUBLE);
}
assert(other->type_num == NPY_LONGDOUBLE);
return PyArray_DTypeFromTypeNum(NPY_CLONGDOUBLE);
}
else if (PyTypeNum_ISCOMPLEX(other->type_num)) {
/* All other numeric types are preserved: */
Py_INCREF(other);
return other;
}
}
else if (NPY_DT_is_legacy(other)) {
/* This is a back-compat fallback to usually do the right thing... */
Expand Down
42 changes: 13 additions & 29 deletions numpy/core/src/multiarray/common_dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,26 +90,18 @@ PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2)
* NotImplemented (so `c` knows more). You may notice that the result
* `res = a.__common_dtype__(b)` is not important. We could try to use it
* to remove the whole branch if `res is c` or by checking if
* `c.__common_dtype(res) is c`.
* `c.__common_dtype__(res) is c`.
* Right now, we only clear initial elements in the most simple case where
* `a.__common_dtype(b) is a` (and thus `b` cannot alter the end-result).
* `a.__common_dtype__(b) is a` (and thus `b` cannot alter the end-result).
* Clearing means, we do not have to worry about them later.
*
* There is one further subtlety. If we have an abstract DType and a
* non-abstract one, we "prioritize" the non-abstract DType here.
* In this sense "prioritizing" means that we use:
* abstract.__common_dtype__(other)
* If both return NotImplemented (which is acceptable and even expected in
* this case, see later) then `other` will be considered to know more.
*
* The reason why this may be acceptable for abstract DTypes, is that
* the value-dependent abstract DTypes may provide default fall-backs.
* The priority inversion effectively means that abstract DTypes are ordered
* just below their concrete counterparts.
* (This fall-back is convenient but not perfect, it can lead to
* non-minimal promotions: e.g. `np.uint24 + 2**20 -> int32`. And such
* cases may also be possible in some mixed type scenarios; they can be
* avoided by defining the promotion explicitly in the user DType.)
 * Abstract dtypes are not handled specially here.  In a first
 * version they were, but that version also tried to support value-based
 * behavior.
 * There may be some advantage to special-casing the abstract ones (e.g.
 * so that the concrete ones do not have to deal with them), but this would
 * require more complex handling later on.  See the logic in
 * default_builtin_common_dtype.
*
* @param length Number of DTypes
* @param dtypes
Expand All @@ -126,33 +118,25 @@ reduce_dtypes_to_most_knowledgeable(
for (npy_intp low = 0; low < half; low++) {
npy_intp high = length - 1 - low;
if (dtypes[high] == dtypes[low]) {
/* Fast path for identical dtypes: do not call common_dtype */
Py_INCREF(dtypes[low]);
Py_XSETREF(res, dtypes[low]);
}
else {
if (NPY_DT_is_abstract(dtypes[high])) {
/*
* Priority inversion, start with abstract, because if it
* returns `other`, we can let other pass instead.
*/
PyArray_DTypeMeta *tmp = dtypes[low];
dtypes[low] = dtypes[high];
dtypes[high] = tmp;
}

Py_XSETREF(res, NPY_DT_CALL_common_dtype(dtypes[low], dtypes[high]));
if (res == NULL) {
return NULL;
}
}

if (res == (PyArray_DTypeMeta *)Py_NotImplemented) {
            /* guess at other being more "knowledgeable" */
PyArray_DTypeMeta *tmp = dtypes[low];
dtypes[low] = dtypes[high];
dtypes[high] = tmp;
}
if (res == dtypes[low]) {
/* `dtypes[high]` cannot influence the final result, so clear: */
else if (res == dtypes[low]) {
/* `dtypes[high]` cannot influence result: clear */
dtypes[high] = NULL;
}
}
Expand Down
38 changes: 38 additions & 0 deletions numpy/core/src/multiarray/dtypemeta.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "npy_pycompat.h"
#include "npy_import.h"

#include "abstractdtypes.h"
#include "arraytypes.h"
#include "common.h"
#include "dtypemeta.h"
Expand Down Expand Up @@ -623,6 +624,43 @@ static PyArray_DTypeMeta *
default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
{
assert(cls->type_num < NPY_NTYPES);
if (NPY_UNLIKELY(NPY_DT_is_abstract(other))) {
/*
* The abstract complex has a lower priority than the concrete inexact
* types to ensure the correct promotion with integers.
*/
if (other == &PyArray_PyComplexAbstractDType) {
if (PyTypeNum_ISCOMPLEX(cls->type_num)) {
Py_INCREF(cls);
return cls;
}
else if (cls->type_num == NPY_HALF || cls->type_num == NPY_FLOAT) {
return PyArray_DTypeFromTypeNum(NPY_CFLOAT);
}
else if (cls->type_num == NPY_DOUBLE) {
return PyArray_DTypeFromTypeNum(NPY_CDOUBLE);
}
else if (cls->type_num == NPY_LONGDOUBLE) {
return PyArray_DTypeFromTypeNum(NPY_CLONGDOUBLE);
}
}
else if (other == &PyArray_PyFloatAbstractDType) {
if (PyTypeNum_ISCOMPLEX(cls->type_num)
|| PyTypeNum_ISFLOAT(cls->type_num)) {
Py_INCREF(cls);
return cls;
}
}
else if (other == &PyArray_PyIntAbstractDType) {
if (PyTypeNum_ISCOMPLEX(cls->type_num)
|| PyTypeNum_ISFLOAT(cls->type_num)
|| PyTypeNum_ISINTEGER(cls->type_num)
|| cls->type_num == NPY_TIMEDELTA) {
Py_INCREF(cls);
return cls;
}
}
}
if (!NPY_DT_is_legacy(other) || other->type_num > cls->type_num) {
/*
* Let the more generic (larger type number) DType handle this
Expand Down
27 changes: 27 additions & 0 deletions numpy/core/tests/test_nep50_promotions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import numpy as np

import pytest
import hypothesis
from hypothesis import strategies

from numpy.testing import IS_WASM


Expand Down Expand Up @@ -244,3 +247,27 @@ def test_nep50_in_concat_and_choose():
with pytest.warns(UserWarning, match="result dtype changed"):
res = np.choose(1, [np.float32(1), 1.])
assert res.dtype == "float32"


@pytest.mark.parametrize("expected,dtypes,optional_dtypes", [
        (np.float32, [np.float32],
         [np.float16, 0.0, np.uint16, np.int16, np.int8, 0]),
        (np.complex64, [np.float32, 0j],
         [np.float16, 0.0, np.uint16, np.int16, np.int8, 0]),
        (np.float32, [np.int16, np.uint16, np.float16],
         [np.int8, np.uint8, np.float32, 0., 0]),
        (np.int32, [np.int16, np.uint16],
         [np.int8, np.uint8, 0, np.bool_]),
        ])
@hypothesis.given(data=strategies.data())
def test_expected_promotion(expected, dtypes, optional_dtypes, data):
    """Weak promotion of a shuffled dtype/scalar mix yields ``expected``.

    The entries in ``dtypes`` always take part in the promotion; a
    hypothesis-drawn list (repeats allowed) of extra entries from
    ``dtypes + optional_dtypes`` is mixed in, and the whole collection is
    shuffled before being passed to ``np.result_type``.
    """
    np._set_promotion_state("weak")

    # Draw any number of extra participants, then prepend the mandatory
    # ones so they are always present; shuffle so order cannot matter.
    pool = strategies.sampled_from(dtypes + optional_dtypes)
    extras = data.draw(strategies.lists(pool))
    shuffled = data.draw(strategies.permutations(dtypes + extras))

    assert np.result_type(*shuffled) == expected

0 comments on commit eabefa4

Please sign in to comment.