diff --git a/numpy/core/src/multiarray/abstractdtypes.c b/numpy/core/src/multiarray/abstractdtypes.c index da44fd01f11f..80099c360127 100644 --- a/numpy/core/src/multiarray/abstractdtypes.c +++ b/numpy/core/src/multiarray/abstractdtypes.c @@ -155,12 +155,6 @@ int_common_dtype(PyArray_DTypeMeta *NPY_UNUSED(cls), PyArray_DTypeMeta *other) /* Use the default integer for bools: */ return PyArray_DTypeFromTypeNum(NPY_LONG); } - else if (PyTypeNum_ISNUMBER(other->type_num) || - other->type_num == NPY_TIMEDELTA) { - /* All other numeric types (ant timedelta) are preserved: */ - Py_INCREF(other); - return other; - } } else if (NPY_DT_is_legacy(other)) { /* This is a back-compat fallback to usually do the right thing... */ @@ -211,11 +205,6 @@ float_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) /* Use the default integer for bools and ints: */ return PyArray_DTypeFromTypeNum(NPY_DOUBLE); } - else if (PyTypeNum_ISNUMBER(other->type_num)) { - /* All other numeric types (float+complex) are preserved: */ - Py_INCREF(other); - return other; - } } else if (other == &PyArray_PyIntAbstractDType) { Py_INCREF(cls); @@ -255,25 +244,6 @@ complex_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) /* Use the default integer for bools and ints: */ return PyArray_DTypeFromTypeNum(NPY_CDOUBLE); } - else if (PyTypeNum_ISFLOAT(other->type_num)) { - /* - * For floats we choose the equivalent precision complex, although - * there is no CHALF, so half also goes to CFLOAT. - */ - if (other->type_num == NPY_HALF || other->type_num == NPY_FLOAT) { - return PyArray_DTypeFromTypeNum(NPY_CFLOAT); - } - if (other->type_num == NPY_DOUBLE) { - return PyArray_DTypeFromTypeNum(NPY_CDOUBLE); - } - assert(other->type_num == NPY_LONGDOUBLE); - return PyArray_DTypeFromTypeNum(NPY_CLONGDOUBLE); - } - else if (PyTypeNum_ISCOMPLEX(other->type_num)) { - /* All other numeric types are preserved: */ - Py_INCREF(other); - return other; - } } else if (NPY_DT_is_legacy(other)) { /* This is a back-compat fallback to usually do the right thing... */ diff --git a/numpy/core/src/multiarray/common_dtype.c b/numpy/core/src/multiarray/common_dtype.c index 38a130221403..c9045c8d96e4 100644 --- a/numpy/core/src/multiarray/common_dtype.c +++ b/numpy/core/src/multiarray/common_dtype.c @@ -90,26 +90,18 @@ PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2) * NotImplemented (so `c` knows more). You may notice that the result * `res = a.__common_dtype__(b)` is not important. We could try to use it * to remove the whole branch if `res is c` or by checking if - * `c.__common_dtype(res) is c`. + * `c.__common_dtype__(res) is c`. * Right now, we only clear initial elements in the most simple case where - * `a.__common_dtype(b) is a` (and thus `b` cannot alter the end-result). + * `a.__common_dtype__(b) is a` (and thus `b` cannot alter the end-result). * Clearing means, we do not have to worry about them later. * - * There is one further subtlety. If we have an abstract DType and a - * non-abstract one, we "prioritize" the non-abstract DType here. - * In this sense "prioritizing" means that we use: - * abstract.__common_dtype__(other) - * If both return NotImplemented (which is acceptable and even expected in - * this case, see later) then `other` will be considered to know more. - * - * The reason why this may be acceptable for abstract DTypes, is that - * the value-dependent abstract DTypes may provide default fall-backs. - * The priority inversion effectively means that abstract DTypes are ordered - * just below their concrete counterparts. - * (This fall-back is convenient but not perfect, it can lead to - * non-minimal promotions: e.g. `np.uint24 + 2**20 -> int32`. And such - * cases may also be possible in some mixed type scenarios; they can be - * avoided by defining the promotion explicitly in the user DType.) + * Abstract dtypes are not handled specially here. In a first + * version they were but this version also tried to be able to do value-based + * behavior. + * There may be some advantage to special casing the abstract ones (e.g. + * so that the concrete ones do not have to deal with it), but this would + * require more complex handling later on. See the logic in + * default_builtin_common_dtype * * @param length Number of DTypes * @param dtypes @@ -126,20 +118,11 @@ reduce_dtypes_to_most_knowledgeable( for (npy_intp low = 0; low < half; low++) { npy_intp high = length - 1 - low; if (dtypes[high] == dtypes[low]) { + /* Fast path for identical dtypes: do not call common_dtype */ Py_INCREF(dtypes[low]); Py_XSETREF(res, dtypes[low]); } else { - if (NPY_DT_is_abstract(dtypes[high])) { - /* - * Priority inversion, start with abstract, because if it - * returns `other`, we can let other pass instead. - */ - PyArray_DTypeMeta *tmp = dtypes[low]; - dtypes[low] = dtypes[high]; - dtypes[high] = tmp; - } - Py_XSETREF(res, NPY_DT_CALL_common_dtype(dtypes[low], dtypes[high])); if (res == NULL) { return NULL; @@ -147,12 +130,13 @@ reduce_dtypes_to_most_knowledgeable( } if (res == (PyArray_DTypeMeta *)Py_NotImplemented) { + /* guess at other being more "knowledgable" */ PyArray_DTypeMeta *tmp = dtypes[low]; dtypes[low] = dtypes[high]; dtypes[high] = tmp; } - if (res == dtypes[low]) { - /* `dtypes[high]` cannot influence the final result, so clear: */ + else if (res == dtypes[low]) { + /* `dtypes[high]` cannot influence result: clear */ dtypes[high] = NULL; } } diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index eef8d2f996e6..26b673aded16 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -12,6 +12,7 @@ #include "npy_pycompat.h" #include "npy_import.h" +#include "abstractdtypes.h" #include "arraytypes.h" #include "common.h" #include "dtypemeta.h" @@ -623,6 +624,43 @@ static PyArray_DTypeMeta * default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) { assert(cls->type_num < NPY_NTYPES); + if (NPY_UNLIKELY(NPY_DT_is_abstract(other))) { + /* + * The abstract complex has a lower priority than the concrete inexact + * types to ensure the correct promotion with integers. + */ + if (other == &PyArray_PyComplexAbstractDType) { + if (PyTypeNum_ISCOMPLEX(cls->type_num)) { + Py_INCREF(cls); + return cls; + } + else if (cls->type_num == NPY_HALF || cls->type_num == NPY_FLOAT) { + return PyArray_DTypeFromTypeNum(NPY_CFLOAT); + } + else if (cls->type_num == NPY_DOUBLE) { + return PyArray_DTypeFromTypeNum(NPY_CDOUBLE); + } + else if (cls->type_num == NPY_LONGDOUBLE) { + return PyArray_DTypeFromTypeNum(NPY_CLONGDOUBLE); + } + } + else if (other == &PyArray_PyFloatAbstractDType) { + if (PyTypeNum_ISCOMPLEX(cls->type_num) + || PyTypeNum_ISFLOAT(cls->type_num)) { + Py_INCREF(cls); + return cls; + } + } + else if (other == &PyArray_PyIntAbstractDType) { + if (PyTypeNum_ISCOMPLEX(cls->type_num) + || PyTypeNum_ISFLOAT(cls->type_num) + || PyTypeNum_ISINTEGER(cls->type_num) + || cls->type_num == NPY_TIMEDELTA) { + Py_INCREF(cls); + return cls; + } + } + } if (!NPY_DT_is_legacy(other) || other->type_num > cls->type_num) { /* * Let the more generic (larger type number) DType handle this diff --git a/numpy/core/tests/test_nep50_promotions.py b/numpy/core/tests/test_nep50_promotions.py index 74a18a8dda48..5e2068762eeb 100644 --- a/numpy/core/tests/test_nep50_promotions.py +++ b/numpy/core/tests/test_nep50_promotions.py @@ -9,6 +9,9 @@ import numpy as np import pytest +import hypothesis +from hypothesis import strategies + from numpy.testing import IS_WASM @@ -244,3 +247,27 @@ def test_nep50_in_concat_and_choose(): with pytest.warns(UserWarning, match="result dtype changed"): res = np.choose(1, [np.float32(1), 1.]) assert res.dtype == "float32" + + +@pytest.mark.parametrize("expected,dtypes,optional_dtypes", [ + (np.float32, [np.float32], + [np.float16, 0.0, np.uint16, np.int16, np.int8, 0]), + (np.complex64, [np.float32, 0j], + [np.float16, 0.0, np.uint16, np.int16, np.int8, 0]), + (np.float32, [np.int16, np.uint16, np.float16], + [np.int8, np.uint8, np.float32, 0., 0]), + (np.int32, [np.int16, np.uint16], + [np.int8, np.uint8, 0, np.bool_]), + ]) +@hypothesis.given(data=strategies.data()) +def test_expected_promotion(expected, dtypes, optional_dtypes, data): + np._set_promotion_state("weak") + + # Sample randomly while ensuring "dtypes" is always present: + optional = data.draw(strategies.lists( + strategies.sampled_from(dtypes + optional_dtypes))) + all_dtypes = dtypes + optional + dtypes_sample = data.draw(strategies.permutations(all_dtypes)) + + res = np.result_type(*dtypes_sample) + assert res == expected