Skip to content

Commit 4a47062

Browse files
committed
Fix None in slice for numba boxing
1 parent 004281a commit 4a47062

File tree

1 file changed

+28
-5
lines changed

1 file changed

+28
-5
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
import sys
23
import warnings
34
from contextlib import contextmanager
45
from functools import singledispatch
@@ -10,7 +11,7 @@
1011
import numpy as np
1112
import scipy
1213
import scipy.special
13-
from llvmlite.ir import Type as llvm_Type
14+
from llvmlite import ir
1415
from numba import types
1516
from numba.core.errors import TypingError
1617
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
@@ -131,7 +132,7 @@ def create_numba_signature(
131132

132133

133134
def slice_new(self, start, stop, step):
134-
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
135+
fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
135136
fn = self._get_function(fnty, name="PySlice_New")
136137
return self.builder.call(fn, [start, stop, step])
137138

@@ -150,11 +151,33 @@ def box_slice(typ, val, c):
150151
This makes it possible to return an Numba's internal representation of a
151152
``slice`` object as a proper ``slice`` to Python.
152153
"""
154+
start = c.builder.extract_value(val, 0)
155+
stop = c.builder.extract_value(val, 1)
156+
157+
none_val = ir.Constant(ir.IntType(64), sys.maxsize)
158+
159+
start_is_none = c.builder.icmp_signed("==", start, none_val)
160+
start = c.builder.select(
161+
start_is_none,
162+
c.pyapi.get_null_object(),
163+
c.box(types.int64, start),
164+
)
165+
166+
stop_is_none = c.builder.icmp_signed("==", stop, none_val)
167+
stop = c.builder.select(
168+
stop_is_none,
169+
c.pyapi.get_null_object(),
170+
c.box(types.int64, stop),
171+
)
153172

154-
start = c.box(types.int64, c.builder.extract_value(val, 0))
155-
stop = c.box(types.int64, c.builder.extract_value(val, 1))
156173
if typ.has_step:
157-
step = c.box(types.int64, c.builder.extract_value(val, 2))
174+
step = c.builder.extract_value(val, 2)
175+
step_is_none = c.builder.icmp_signed("==", step, none_val)
176+
step = c.builder.select(
177+
step_is_none,
178+
c.pyapi.get_null_object(),
179+
c.box(types.int64, step),
180+
)
158181
else:
159182
step = c.pyapi.get_null_object()
160183

0 commit comments

Comments
 (0)