Skip to content

Commit 5961b23

Browse files
committed
Use numpy function to normalize axis argument
1 parent 670b5ca commit 5961b23

File tree

1 file changed

+7
-19
lines changed

1 file changed

+7
-19
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numba
88
import numpy as np
9+
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
910

1011
from pytensor import config
1112
from pytensor.graph.basic import Apply
@@ -164,18 +165,6 @@ def create_vectorize_func(
164165
return elemwise_fn
165166

166167

167-
def normalize_axis(axis, ndim):
168-
if axis is None:
169-
return axis
170-
171-
if axis < 0:
172-
axis = ndim + axis
173-
174-
if axis < 0 or axis >= ndim:
175-
raise np.AxisError(ndim=ndim, axis=axis)
176-
return axis
177-
178-
179168
def create_axis_reducer(
180169
scalar_op: Op,
181170
identity: Union[np.ndarray, Number],
@@ -230,7 +219,7 @@ def careduce_axis(x):
230219
231220
"""
232221

233-
axis = normalize_axis(axis, ndim)
222+
axis = normalize_axis_index(axis, ndim)
234223

235224
reduce_elemwise_fn_name = "careduce_axis"
236225

@@ -354,7 +343,7 @@ def careduce_maximum(input):
354343
if len(axes) == 1:
355344
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
356345

357-
axes = [normalize_axis(axis, ndim) for axis in axes]
346+
axes = normalize_axis_tuple(axes, ndim)
358347

359348
careduce_fn_name = f"careduce_{scalar_op}"
360349
global_env = {}
@@ -425,7 +414,7 @@ def jit_compile_reducer(node, fn, **kwds):
425414

426415

427416
def create_axis_apply_fn(fn, axis, ndim, dtype):
428-
axis = normalize_axis(axis, ndim)
417+
axis = normalize_axis_index(axis, ndim)
429418

430419
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
431420

@@ -627,9 +616,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
627616
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
628617
axis = op.axis
629618

630-
axis = normalize_axis(axis, x_at.ndim)
631-
632619
if axis is not None:
620+
axis = normalize_axis_index(axis, x_at.ndim)
633621
reduce_max_py = create_axis_reducer(
634622
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
635623
)
@@ -666,8 +654,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
666654
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
667655

668656
axis = op.axis
669-
axis = normalize_axis(axis, sm_at.ndim)
670657
if axis is not None:
658+
axis = normalize_axis_index(axis, sm_at.ndim)
671659
reduce_sum_py = create_axis_reducer(
672660
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
673661
)
@@ -697,9 +685,9 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
697685
x_dtype = x_at.type.numpy_dtype
698686
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
699687
axis = op.axis
700-
axis = normalize_axis(axis, x_at.ndim)
701688

702689
if axis is not None:
690+
axis = normalize_axis_index(axis, x_at.ndim)
703691
reduce_max_py = create_axis_reducer(
704692
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
705693
)

0 commit comments

Comments
 (0)