Skip to content

Commit 2736b28

Browse files
Use math instead of numpy in some places (#6765)
* Use math when possible. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * switch to math.prod Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4c8dd10 commit 2736b28

File tree

8 files changed

+19
-15
lines changed

8 files changed

+19
-15
lines changed

xarray/coding/cftimeindex.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4141
from __future__ import annotations
4242

43+
import math
4344
import re
4445
import warnings
4546
from datetime import timedelta
@@ -249,7 +250,7 @@ def format_times(
249250
):
250251
"""Format values of cftimeindex as pd.Index."""
251252
n_per_row = max(max_width // (CFTIME_REPR_LENGTH + len(separator)), 1)
252-
n_rows = int(np.ceil(len(index) / n_per_row))
253+
n_rows = math.ceil(len(index) / n_per_row)
253254

254255
representation = ""
255256
for row in range(n_rows):

xarray/core/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import datetime
55
import inspect
66
import itertools
7+
import math
78
import sys
89
import warnings
910
from collections import defaultdict
@@ -5187,7 +5188,7 @@ def dropna(
51875188
if dim in array.dims:
51885189
dims = [d for d in array.dims if d != dim]
51895190
count += np.asarray(array.count(dims)) # type: ignore[attr-defined]
5190-
size += np.prod([self.dims[d] for d in dims])
5191+
size += math.prod([self.dims[d] for d in dims])
51915192

51925193
if thresh is not None:
51935194
mask = count >= thresh
@@ -5945,7 +5946,7 @@ def _set_numpy_data_from_dataframe(
59455946
# We already verified that the MultiIndex has all unique values, so
59465947
# there are missing values if and only if the size of output arrays is
59475948
# larger that the index.
5948-
missing_values = np.prod(shape) > idx.shape[0]
5949+
missing_values = math.prod(shape) > idx.shape[0]
59495950

59505951
for name, values in arrays:
59515952
# NumPy indexing is much faster than using DataFrame.reindex() to

xarray/core/formatting.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import contextlib
66
import functools
7+
import math
78
from collections import defaultdict
89
from datetime import datetime, timedelta
910
from itertools import chain, zip_longest
@@ -44,10 +45,10 @@ def wrap_indent(text, start="", length=None):
4445

4546

4647
def _get_indexer_at_least_n_items(shape, n_desired, from_end):
47-
assert 0 < n_desired <= np.prod(shape)
48+
assert 0 < n_desired <= math.prod(shape)
4849
cum_items = np.cumprod(shape[::-1])
4950
n_steps = np.argmax(cum_items >= n_desired)
50-
stop = int(np.ceil(float(n_desired) / np.r_[1, cum_items][n_steps]))
51+
stop = math.ceil(float(n_desired) / np.r_[1, cum_items][n_steps])
5152
indexer = (
5253
((-1 if from_end else 0),) * (len(shape) - 1 - n_steps)
5354
+ ((slice(-stop, None) if from_end else slice(stop)),)
@@ -185,9 +186,7 @@ def format_array_flat(array, max_width: int):
185186
"""
186187
# every item will take up at least two characters, but we always want to
187188
# print at least first and last items
188-
max_possibly_relevant = min(
189-
max(array.size, 1), max(int(np.ceil(max_width / 2.0)), 2)
190-
)
189+
max_possibly_relevant = min(max(array.size, 1), max(math.ceil(max_width / 2.0), 2))
191190
relevant_front_items = format_items(
192191
first_n_items(array, (max_possibly_relevant + 1) // 2)
193192
)

xarray/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import functools
66
import io
77
import itertools
8+
import math
89
import os
910
import re
1011
import sys
@@ -555,8 +556,7 @@ def ndim(self: Any) -> int:
555556

556557
@property
557558
def size(self: Any) -> int:
558-
# cast to int so that shape = () gives size = 1
559-
return int(np.prod(self.shape))
559+
return math.prod(self.shape)
560560

561561
def __len__(self: Any) -> int:
562562
try:

xarray/core/variable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import copy
44
import itertools
5+
import math
56
import numbers
67
import warnings
78
from datetime import timedelta
@@ -1644,7 +1645,7 @@ def _unstack_once_full(
16441645
"name as an existing dimension"
16451646
)
16461647

1647-
if np.prod(new_dim_sizes) != self.sizes[old_dim]:
1648+
if math.prod(new_dim_sizes) != self.sizes[old_dim]:
16481649
raise ValueError(
16491650
"the product of the new dimension sizes must "
16501651
"equal the size of the old dimension"
@@ -1684,7 +1685,7 @@ def _unstack_once(
16841685
new_dims = reordered.dims[: len(other_dims)] + new_dim_names
16851686

16861687
if fill_value is dtypes.NA:
1687-
is_missing_values = np.prod(new_shape) > np.prod(self.shape)
1688+
is_missing_values = math.prod(new_shape) > math.prod(self.shape)
16881689
if is_missing_values:
16891690
dtype, fill_value = dtypes.maybe_promote(self.dtype)
16901691
else:

xarray/testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _format_message(x, y, err_msg, verbose):
179179
abs_diff = max(abs(diff))
180180
rel_diff = "not implemented"
181181

182-
n_diff = int(np.count_nonzero(diff))
182+
n_diff = np.count_nonzero(diff)
183183
n_total = diff.size
184184

185185
fraction = f"{n_diff} / {n_total}"

xarray/tests/test_plot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import inspect
5+
import math
56
from copy import copy
67
from datetime import datetime
78
from typing import Any
@@ -117,7 +118,7 @@ def easy_array(shape, start=0, stop=1):
117118
118119
shape is a tuple like (2, 3)
119120
"""
120-
a = np.linspace(start, stop, num=np.prod(shape))
121+
a = np.linspace(start, stop, num=math.prod(shape))
121122
return a.reshape(shape)
122123

123124

xarray/tests/test_sparse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import math
34
import pickle
45
from textwrap import dedent
56

@@ -28,7 +29,7 @@ def assert_sparse_equal(a, b):
2829

2930

3031
def make_ndarray(shape):
31-
return np.arange(np.prod(shape)).reshape(shape)
32+
return np.arange(math.prod(shape)).reshape(shape)
3233

3334

3435
def make_sparray(shape):

0 commit comments

Comments
 (0)