|
6 | 6 |
|
7 | 7 | import numba
|
8 | 8 | import numpy as np
|
| 9 | +from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple |
9 | 10 |
|
10 | 11 | from pytensor import config
|
11 | 12 | from pytensor.graph.basic import Apply
|
@@ -164,18 +165,6 @@ def create_vectorize_func(
|
164 | 165 | return elemwise_fn
|
165 | 166 |
|
166 | 167 |
|
167 |
| -def normalize_axis(axis, ndim): |
168 |
| - if axis is None: |
169 |
| - return axis |
170 |
| - |
171 |
| - if axis < 0: |
172 |
| - axis = ndim + axis |
173 |
| - |
174 |
| - if axis < 0 or axis >= ndim: |
175 |
| - raise np.AxisError(ndim=ndim, axis=axis) |
176 |
| - return axis |
177 |
| - |
178 |
| - |
179 | 168 | def create_axis_reducer(
|
180 | 169 | scalar_op: Op,
|
181 | 170 | identity: Union[np.ndarray, Number],
|
@@ -230,7 +219,7 @@ def careduce_axis(x):
|
230 | 219 |
|
231 | 220 | """
|
232 | 221 |
|
233 |
| - axis = normalize_axis(axis, ndim) |
| 222 | + axis = normalize_axis_index(axis, ndim) |
234 | 223 |
|
235 | 224 | reduce_elemwise_fn_name = "careduce_axis"
|
236 | 225 |
|
@@ -354,7 +343,7 @@ def careduce_maximum(input):
|
354 | 343 | if len(axes) == 1:
|
355 | 344 | return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
|
356 | 345 |
|
357 |
| - axes = [normalize_axis(axis, ndim) for axis in axes] |
| 346 | + axes = normalize_axis_tuple(axes, ndim) |
358 | 347 |
|
359 | 348 | careduce_fn_name = f"careduce_{scalar_op}"
|
360 | 349 | global_env = {}
|
@@ -425,7 +414,7 @@ def jit_compile_reducer(node, fn, **kwds):
|
425 | 414 |
|
426 | 415 |
|
427 | 416 | def create_axis_apply_fn(fn, axis, ndim, dtype):
|
428 |
| - axis = normalize_axis(axis, ndim) |
| 417 | + axis = normalize_axis_index(axis, ndim) |
429 | 418 |
|
430 | 419 | reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
|
431 | 420 |
|
@@ -627,9 +616,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
|
627 | 616 | x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
|
628 | 617 | axis = op.axis
|
629 | 618 |
|
630 |
| - axis = normalize_axis(axis, x_at.ndim) |
631 |
| - |
632 | 619 | if axis is not None:
|
| 620 | + axis = normalize_axis_index(axis, x_at.ndim) |
633 | 621 | reduce_max_py = create_axis_reducer(
|
634 | 622 | scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
|
635 | 623 | )
|
@@ -666,8 +654,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
|
666 | 654 | sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
|
667 | 655 |
|
668 | 656 | axis = op.axis
|
669 |
| - axis = normalize_axis(axis, sm_at.ndim) |
670 | 657 | if axis is not None:
|
| 658 | + axis = normalize_axis_index(axis, sm_at.ndim) |
671 | 659 | reduce_sum_py = create_axis_reducer(
|
672 | 660 | add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
|
673 | 661 | )
|
@@ -697,9 +685,9 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
|
697 | 685 | x_dtype = x_at.type.numpy_dtype
|
698 | 686 | x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
|
699 | 687 | axis = op.axis
|
700 |
| - axis = normalize_axis(axis, x_at.ndim) |
701 | 688 |
|
702 | 689 | if axis is not None:
|
| 690 | + axis = normalize_axis_index(axis, x_at.ndim) |
703 | 691 | reduce_max_py = create_axis_reducer(
|
704 | 692 | scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
|
705 | 693 | )
|
|
0 commit comments