Skip to content

Commit e5a0a89

Browse files
committed
Enable NUMPY_DISLIKE_SCALARS=1
I don't believe in this (I mean as an option, I suppose who cares, but that isn't actually useful if you can never transition, IMO). But, it should be at least tried a bit, so...
1 parent 4f1fcef commit e5a0a89

File tree

10 files changed

+74
-22
lines changed

10 files changed

+74
-22
lines changed

numpy/_core/arrayprint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ def recurser(index, hanging_indent, curr_width):
851851
axes_left = a.ndim - axis
852852

853853
if axes_left == 0:
854-
return format_function(a[index])
854+
return format_function(a[index + (...,)].to_scalar())
855855

856856
# when recursing, add a space to align with the [ added, and reduce the
857857
# length of the line by 1
@@ -1709,7 +1709,7 @@ def _array_str_implementation(
17091709
# obtain a scalar and call str on it, avoiding problems for subclasses
17101710
# for which indexing with () returns a 0d instead of a scalar by using
17111711
# ndarray's getindex. Also guard against recursive 0d object arrays.
1712-
return _guarded_repr_or_str(np.ndarray.__getitem__(a, ()))
1712+
return _guarded_repr_or_str(np.ndarray.to_scalar(a))
17131713

17141714
return array2string(a, max_line_width, precision, suppress_small, ' ', "")
17151715

numpy/_core/src/multiarray/arraywrap.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "arraywrap.h"
1616
#include "npy_static_data.h"
17+
#include "multiarraymodule.h"
1718

1819
/*
1920
* Find the array wrap or array prepare method that applies to the inputs.
@@ -139,6 +140,8 @@ npy_apply_wrap(
139140
PyArrayObject *arr = NULL;
140141
PyObject *err_type, *err_value, *traceback;
141142

143+
return_scalar = (return_scalar && !npy_thread_unsafe_state.dislike_scalars);
144+
142145
/* If provided, we prefer the actual out objects wrap: */
143146
if (original_out != NULL && original_out != Py_None) {
144147
/*

numpy/_core/src/multiarray/mapping.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
/* TODO: Only for `NpyIter_GetTransferFlags` until it is public */
3030
#define NPY_ITERATOR_IMPLEMENTATION_CODE
3131
#include "nditer_impl.h"
32+
#include "multiarraymodule.h"
3233

3334
#include "umathmodule.h"
3435

@@ -1302,7 +1303,7 @@ array_item_asarray(PyArrayObject *self, npy_intp i)
13021303
NPY_NO_EXPORT PyObject *
13031304
array_item(PyArrayObject *self, Py_ssize_t i)
13041305
{
1305-
if (PyArray_NDIM(self) == 1) {
1306+
if (PyArray_NDIM(self) == 1 && !npy_thread_unsafe_state.dislike_scalars) {
13061307
char *item;
13071308
npy_index_info index;
13081309

@@ -1485,7 +1486,7 @@ array_subscript(PyArrayObject *self, PyObject *op)
14851486
}
14861487

14871488
/* Full integer index */
1488-
else if (index_type == HAS_INTEGER) {
1489+
else if (index_type == HAS_INTEGER && !npy_thread_unsafe_state.dislike_scalars) {
14891490
char *item;
14901491
if (get_item_pointer(self, &item, indices, index_num) < 0) {
14911492
goto finish;

numpy/_core/src/multiarray/methods.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2797,6 +2797,27 @@ array_class_getitem(PyObject *cls, PyObject *args)
27972797
return Py_GenericAlias(cls, args);
27982798
}
27992799

2800+
static PyObject *
2801+
array_to_scalar(PyArrayObject *mp, PyObject *NPY_UNUSED(args))
2802+
{
2803+
/* TODO, just a silly copy of PyArray_Result, as I disabled that! */
2804+
Py_INCREF(mp);
2805+
if (!PyArray_Check(mp)) {
2806+
return (PyObject *)mp;
2807+
}
2808+
if (PyArray_NDIM(mp) == 0) {
2809+
PyObject *ret;
2810+
ret = PyArray_ToScalar(PyArray_DATA(mp), mp);
2811+
Py_DECREF(mp);
2812+
return ret;
2813+
}
2814+
else {
2815+
return (PyObject *)mp;
2816+
}
2817+
}
2818+
2819+
2820+
28002821
NPY_NO_EXPORT PyMethodDef array_methods[] = {
28012822

28022823
/* for subtypes */
@@ -3025,6 +3046,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
30253046
{"to_device",
30263047
(PyCFunction)array_to_device,
30273048
METH_VARARGS | METH_KEYWORDS, NULL},
3049+
{"to_scalar",
3050+
(PyCFunction)array_to_scalar,
3051+
METH_NOARGS, NULL},
30283052

30293053
{NULL, NULL, 0, NULL} /* sentinel */
30303054
};

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4770,6 +4770,14 @@ initialize_thread_unsafe_state(void) {
47704770
npy_thread_unsafe_state.warn_if_no_mem_policy = 0;
47714771
}
47724772

4773+
env = getenv("NUMPY_DISLIKE_SCALARS");
4774+
if ((env != NULL) && (strncmp(env, "1", 1) == 0)) {
4775+
npy_thread_unsafe_state.dislike_scalars = 1;
4776+
}
4777+
else {
4778+
npy_thread_unsafe_state.dislike_scalars = 0;
4779+
}
4780+
47734781
return 0;
47744782
}
47754783

numpy/_core/src/multiarray/multiarraymodule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ typedef struct npy_thread_unsafe_state_struct {
7676
* if there is no memory policy set
7777
*/
7878
int warn_if_no_mem_policy;
79+
int dislike_scalars;
7980

8081
} npy_thread_unsafe_state_struct;
8182

numpy/_core/src/multiarray/scalarapi.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "ctors.h"
1919
#include "descriptor.h"
2020
#include "dtypemeta.h"
21+
#include "multiarraymodule.h"
2122
#include "scalartypes.h"
2223

2324
#include "common.h"
@@ -631,7 +632,9 @@ PyArray_Scalar(void *data, PyArray_Descr *descr, PyObject *base)
631632
NPY_NO_EXPORT PyObject *
632633
PyArray_Return(PyArrayObject *mp)
633634
{
634-
635+
if (npy_thread_unsafe_state.dislike_scalars) {
636+
return (PyObject *)mp;
637+
}
635638
if (mp == NULL) {
636639
return NULL;
637640
}

numpy/_core/src/multiarray/scalartypes.c.src

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2591,6 +2591,12 @@ integer_is_integer(PyObject *self, PyObject *NPY_UNUSED(args)) {
25912591
Py_RETURN_TRUE;
25922592
}
25932593

2594+
static PyObject *
2595+
gentype_to_scalar(PyObject *self, PyObject *NPY_UNUSED(args)) {
2596+
Py_INCREF(self);
2597+
return self;
2598+
}
2599+
25942600
/*
25952601
* need to fill in doc-strings for these methods on import -- copy from
25962602
* array docstrings
@@ -2789,6 +2795,9 @@ static PyMethodDef gentype_methods[] = {
27892795
{"to_device",
27902796
(PyCFunction)array_to_device,
27912797
METH_VARARGS | METH_KEYWORDS, NULL},
2798+
{"to_scalar",
2799+
(PyCFunction)gentype_to_scalar,
2800+
METH_NOARGS, NULL},
27922801

27932802
{NULL, NULL, 0, NULL} /* sentinel */
27942803
};

numpy/_core/tests/test_array_coercion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def __init__(self, a):
8484
def scalar_instances(times=True, extended_precision=True, user_dtype=True):
8585
# Hard-coded list of scalar instances.
8686
# Floats:
87+
if type(np.array(1)[()]) is np.ndarray:
88+
return # whooops doesn't work at all
89+
8790
yield param(np.sqrt(np.float16(5)), id="float16")
8891
yield param(np.sqrt(np.float32(5)), id="float32")
8992
yield param(np.sqrt(np.float64(5)), id="float64")

numpy/_core/tests/test_nditer.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,7 @@ def test_iter_object_arrays_basic():
10591059
assert_equal(sys.getrefcount(obj), rc)
10601060

10611061
i = nditer(a, ['refs_ok'], ['readonly'])
1062-
vals = [x_[()] for x_ in i]
1062+
vals = [x_.to_scalar() for x_ in i]
10631063
assert_equal(np.array(vals, dtype='O'), a)
10641064
vals, i, x = [None] * 3
10651065
if HAS_REFCOUNT:
@@ -1068,7 +1068,7 @@ def test_iter_object_arrays_basic():
10681068
i = nditer(a.reshape(2, 2).T, ['refs_ok', 'buffered'],
10691069
['readonly'], order='C')
10701070
assert_(i.iterationneedsapi)
1071-
vals = [x_[()] for x_ in i]
1071+
vals = [x_.to_scalar() for x_ in i]
10721072
assert_equal(np.array(vals, dtype='O'), a.reshape(2, 2).ravel(order='F'))
10731073
vals, i, x = [None] * 3
10741074
if HAS_REFCOUNT:
@@ -1120,7 +1120,7 @@ def test_iter_object_arrays_conversions():
11201120
i = nditer(a, ['refs_ok', 'buffered'], ['readwrite'],
11211121
casting='unsafe', op_dtypes='O')
11221122
with i:
1123-
ob = i[0][()]
1123+
ob = i[0, ...].to_scalar()
11241124
if HAS_REFCOUNT:
11251125
rc = sys.getrefcount(ob)
11261126
for x in i:
@@ -1356,42 +1356,42 @@ def test_iter_copy():
13561356
# Simple iterator
13571357
i = nditer(a)
13581358
j = i.copy()
1359-
assert_equal([x[()] for x in i], [x[()] for x in j])
1359+
assert_equal([x.to_scalar() for x in i], [x.to_scalar() for x in j])
13601360

13611361
i.iterindex = 3
13621362
j = i.copy()
1363-
assert_equal([x[()] for x in i], [x[()] for x in j])
1363+
assert_equal([x.to_scalar() for x in i], [x.to_scalar() for x in j])
13641364

13651365
# Buffered iterator
13661366
i = nditer(a, ['buffered', 'ranged'], order='F', buffersize=3)
13671367
j = i.copy()
1368-
assert_equal([x[()] for x in i], [x[()] for x in j])
1368+
assert_equal([x.to_scalar() for x in i], [x.to_scalar() for x in j])
13691369

13701370
i.iterindex = 3
13711371
j = i.copy()
1372-
assert_equal([x[()] for x in i], [x[()] for x in j])
1372+
assert_equal([x.to_scalar() for x in i], [x.to_scalar() for x in j])
13731373

13741374
i.iterrange = (3, 9)
13751375
j = i.copy()
1376-
assert_equal([x[()] for x in i], [x[()] for x in j])
1376+
assert_equal([x.to_scalar() for x in i], [x.to_scalar() for x in j])
13771377

13781378
i.iterrange = (2, 18)
13791379
next(i)
13801380
next(i)
13811381
j = i.copy()
1382-
assert_equal([x[()] for x in i], [x[()] for x in j])
1382+
assert_equal([x.to_scalar() for x in i], [x.to_scalar() for x in j])
13831383

13841384
# Casting iterator
13851385
with nditer(a, ['buffered'], order='F', casting='unsafe',
13861386
op_dtypes='f8', buffersize=5) as i:
13871387
j = i.copy()
1388-
assert_equal([x[()] for x in j], a.ravel(order='F'))
1388+
assert_equal([x.to_scalar() for x in j], a.ravel(order='F'))
13891389

13901390
a = arange(24, dtype='<i4').reshape(2, 3, 4)
13911391
with nditer(a, ['buffered'], order='F', casting='unsafe',
13921392
op_dtypes='>f8', buffersize=5) as i:
13931393
j = i.copy()
1394-
assert_equal([x[()] for x in j], a.ravel(order='F'))
1394+
assert_equal([x.to_scalar() for x in j], a.ravel(order='F'))
13951395

13961396

13971397
@pytest.mark.parametrize("dtype", np.typecodes["All"])
@@ -1757,20 +1757,20 @@ def test_iter_iterrange():
17571757
i = nditer(a, ['ranged'], ['readonly'], order='F',
17581758
buffersize=buffersize)
17591759
assert_equal(i.iterrange, (0, 24))
1760-
assert_equal([x[()] for x in i], a_fort)
1760+
assert_equal([x.to_scalar() for x in i], a_fort)
17611761
for r in [(0, 24), (1, 2), (3, 24), (5, 5), (0, 20), (23, 24)]:
17621762
i.iterrange = r
17631763
assert_equal(i.iterrange, r)
1764-
assert_equal([x[()] for x in i], a_fort[r[0]:r[1]])
1764+
assert_equal([x.to_scalar() for x in i], a_fort[r[0]:r[1]])
17651765

17661766
i = nditer(a, ['ranged', 'buffered'], ['readonly'], order='F',
17671767
op_dtypes='f8', buffersize=buffersize)
17681768
assert_equal(i.iterrange, (0, 24))
1769-
assert_equal([x[()] for x in i], a_fort)
1769+
assert_equal([x.to_scalar() for x in i], a_fort)
17701770
for r in [(0, 24), (1, 2), (3, 24), (5, 5), (0, 20), (23, 24)]:
17711771
i.iterrange = r
17721772
assert_equal(i.iterrange, r)
1773-
assert_equal([x[()] for x in i], a_fort[r[0]:r[1]])
1773+
assert_equal([x.to_scalar() for x in i], a_fort[r[0]:r[1]])
17741774

17751775
def get_array(i):
17761776
val = np.array([], dtype='f8')
@@ -1862,7 +1862,7 @@ def assign_iter(i):
18621862
assert_equal(i[0], 0)
18631863
i[1] = 1
18641864
assert_equal(i[0:2], [0, 1])
1865-
assert_equal([[x[0][()], x[1][()]] for x in i], list(zip(range(6), [1] * 6)))
1865+
assert_equal([[x[0, ...].to_scalar(), x[1, ...].to_scalar()] for x in i], list(zip(range(6), [1] * 6)))
18661866

18671867
def test_iter_buffered_cast_simple():
18681868
# Test that buffering can handle a simple cast
@@ -2015,7 +2015,7 @@ def test_iter_buffered_cast_structured_type():
20152015
i = nditer(a, ['buffered', 'refs_ok'], ['readonly'],
20162016
casting='unsafe',
20172017
op_dtypes='i4')
2018-
assert_equal([x_[()] for x_ in i], [5, 8])
2018+
assert_equal([x_.to_scalar() for x_ in i], [5, 8])
20192019

20202020
# make sure multi-field struct type -> simple doesn't work
20212021
sdt = [('a', 'f4'), ('b', 'i8'), ('d', 'O')]

0 commit comments

Comments
 (0)