Skip to content

Commit 04c4dc8

Browse files
hawkinspThe ml_dtypes Authors
authored andcommitted
Fix test cases that test for equality of NaNs of ml_dtypes floats under NumPy 2.4.3.
NumPy 2.4.3 changed its criterion for when numpy.testing.assert_array_equal will use "equal NaN" semantics. Currently ml_dtypes floats don't appear to be numeric types to NumPy, so NaNs no longer compare as equal. This is actually a helpful change in the future since once ml_dtypes migrates to NumPy 2's user dtype APIs we will be able to declare our types as such. For now, just cast types to float32 before testing for equality in tests. PiperOrigin-RevId: 882541951
1 parent c45a3dd commit 04c4dc8

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

ml_dtypes/tests/custom_float_test.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -900,20 +900,28 @@ def testConformNumpyComplex(self, float_type):
900900
numpy_assert_allclose(z_np, z_tf, atol=2e-2, float_type=float_type)
901901

902902
def testArange(self, float_type):
903+
# TODO(phawkins): remove the casts to float32 once NumPy considers our float
904+
# types to be numeric types.
903905
np.testing.assert_equal(
904-
np.arange(1, 100, dtype=np.float32).astype(float_type),
905-
np.arange(1, 100, dtype=float_type),
906+
np.arange(1, 100, dtype=np.float32)
907+
.astype(float_type)
908+
.astype(np.float32),
909+
np.arange(1, 100, dtype=float_type).astype(np.float32),
906910
)
907911
if float_type == float8_e8m0fnu:
908912
raise self.skipTest("Skip negative ranges for E8M0.")
909913

910914
np.testing.assert_equal(
911-
np.arange(-6, 6, 2, dtype=np.float32).astype(float_type),
912-
np.arange(-6, 6, 2, dtype=float_type),
915+
np.arange(-6, 6, 2, dtype=np.float32)
916+
.astype(float_type)
917+
.astype(np.float32),
918+
np.arange(-6, 6, 2, dtype=float_type).astype(np.float32),
913919
)
914920
np.testing.assert_equal(
915-
np.arange(-0.0, -2.0, -0.5, dtype=np.float32).astype(float_type),
916-
np.arange(-0.0, -2.0, -0.5, dtype=float_type),
921+
np.arange(-0.0, -2.0, -0.5, dtype=np.float32)
922+
.astype(float_type)
923+
.astype(np.float32),
924+
np.arange(-0.0, -2.0, -0.5, dtype=float_type).astype(np.float32),
917925
)
918926

919927
@ignore_warning(category=RuntimeWarning, message="invalid value encountered")

0 commit comments

Comments
 (0)