Skip to content

Commit

Permalink
Experiment with making float/int types inherit from np.floating/np.in…
Browse files Browse the repository at this point in the history
…teger.

PiperOrigin-RevId: 719426592
  • Loading branch information
Jake VanderPlas authored and The ml_dtypes Authors committed Jan 24, 2025
1 parent f1439a9 commit 98c9e29
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ bool RegisterFloatDtype(PyObject* numpy) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
// the base type directly when dropping Python 3.9 support.
Safe_PyObjectPtr bases(
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyGenericArrType_Type)));
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyFloatingArrType_Type)));
PyObject* type =
PyType_FromSpecWithBases(&CustomFloatType<T>::type_spec, bases.get());
if (!type) {
Expand Down
8 changes: 4 additions & 4 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,10 @@ bool Initialize() {
return false;
}

if (!RegisterIntNDtype<int2>(numpy.get()) ||
!RegisterIntNDtype<uint2>(numpy.get()) ||
!RegisterIntNDtype<int4>(numpy.get()) ||
!RegisterIntNDtype<uint4>(numpy.get())) {
if (!RegisterIntNDtype<int2>(numpy.get(), /* is_signed= */ true) ||
!RegisterIntNDtype<uint2>(numpy.get(), /* is_signed= */ false) ||
!RegisterIntNDtype<int4>(numpy.get(), /* is_signed= */ true) ||
!RegisterIntNDtype<uint4>(numpy.get(), /* is_signed= */ false)) {
return false;
}

Expand Down
8 changes: 5 additions & 3 deletions ml_dtypes/_src/intn_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -765,11 +765,13 @@ bool RegisterIntNUFuncs(PyObject* numpy) {
}

template <typename T>
bool RegisterIntNDtype(PyObject* numpy) {
bool RegisterIntNDtype(PyObject* numpy, bool is_signed) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
// the base type directly when dropping Python 3.9 support.
Safe_PyObjectPtr bases(
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyGenericArrType_Type)));
Safe_PyObjectPtr bases(PyTuple_Pack(
1, is_signed
? (reinterpret_cast<PyObject*>(&PySignedIntegerArrType_Type))
: (reinterpret_cast<PyObject*>(&PyUnsignedIntegerArrType_Type))));
PyObject* type =
PyType_FromSpecWithBases(&IntNTypeDescriptor<T>::type_spec, bases.get());
if (!type) {
Expand Down
6 changes: 5 additions & 1 deletion ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,10 @@ def testArgminOnPositiveInfinity(self, float_type):
def testDtypeFromString(self, float_type):
assert np.dtype(float_type.__name__) == np.dtype(float_type)

def testIssubdtype(self, float_type):
self.assertTrue(np.issubdtype(float_type, np.floating))
self.assertTrue(np.issubdtype(np.dtype(float_type), np.floating))


BinaryOp = collections.namedtuple("BinaryOp", ["op"])

Expand Down Expand Up @@ -713,7 +717,7 @@ def testDeepCopyDoesNotAlterHash(self, float_type):
def testArray(self, float_type):
x = np.array([[1, 2, 4]], dtype=float_type)
self.assertEqual(float_type, x.dtype)
self.assertEqual("[[1 2 4]]", str(x))
self.assertEqual("[[1. 2. 4.]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x, float_type=float_type)
self.assertTrue((x == x).all())
Expand Down
7 changes: 6 additions & 1 deletion ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ def testCanCast(self, a, b):
((a, b) in allowed_casts), np.can_cast(a, b, casting="safe")
)

@parameterized.product(scalar_type=INTN_TYPES)
def testIssubdtype(self, scalar_type):
self.assertTrue(np.issubdtype(scalar_type, np.integer))
self.assertTrue(np.issubdtype(np.dtype(scalar_type), np.integer))


# Tests for the Python scalar type
@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
Expand Down Expand Up @@ -274,7 +279,7 @@ def testDeepCopyDoesNotAlterHash(self, scalar_type):
def testArray(self, scalar_type):
if scalar_type == int2:
x = np.array([[-2, 1, 0, 1]], dtype=scalar_type)
self.assertEqual("[[-2 1 0 1]]", str(x))
self.assertEqual("[[-2 1 0 1]]", str(x))
else:
x = np.array([[1, 2, 3]], dtype=scalar_type)
self.assertEqual("[[1 2 3]]", str(x))
Expand Down

0 comments on commit 98c9e29

Please sign in to comment.