diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 02936ccdb1..fe66207d73 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -204,7 +204,7 @@ def in_seq_empty_tuple(x, y): def to_scalar(x): - raise NotImplementedError() + return np.asarray(x).item() @numba.extending.overload(to_scalar) @@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}): {index_prologue} {indices_creation_src} {index_body} - return z + return np.asarray(z) """ return subtensor_def_src @@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs): @numba_njit def shape_i(x): - return np.shape(x)[i] + return np.asarray(np.shape(x)[i]) return shape_i @@ -698,7 +698,7 @@ def numba_funcify_Reshape(op, **kwargs): @numba_njit def reshape(x, shape): - return x.item() + return np.asarray(x.item()) else: diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 0595191da0..aad741c67d 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,4 +1,5 @@ -import inspect +import base64 +import pickle from functools import singledispatch from numbers import Number from textwrap import indent @@ -6,23 +7,25 @@ import numba import numpy as np +from numba import TypingError, types +from numba.core import cgutils +from numba.core.extending import overload +from numba.np import arrayobj from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from pytensor import config from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch import elemwise_codegen from pytensor.link.numba.dispatch.basic import ( create_numba_signature, create_tuple_creator, numba_funcify, + numba_njit, use_optimized_cheap_pass, ) -from pytensor.link.utils import ( - compile_function_src, - get_name_for_object, - unique_name_generator, -) +from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.scalar.basic import ( AND, OR, @@ -40,7 +43,7 @@ from pytensor.scalar.basic import add as add_as from pytensor.scalar.basic import scalar_maximum from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros +from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.type import scalar @@ -172,6 +175,7 @@ def create_axis_reducer( ndim: int, dtype: numba.types.Type, keepdims: bool = False, + return_scalar=False, ) -> numba.core.dispatcher.Dispatcher: r"""Create Python function that performs a NumPy-like reduction on a given axis. @@ -282,6 +286,8 @@ def {reduce_elemwise_fn_name}(x): inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2) return_expr = "res" if keepdims else "res.item()" + if not return_scalar: + return_expr = f"np.asarray({return_expr})" reduce_elemwise_def_src = f""" def {reduce_elemwise_fn_name}(x): @@ -303,7 +309,13 @@ def {reduce_elemwise_fn_name}(x): def create_multiaxis_reducer( - scalar_op, identity, axes, ndim, dtype, input_name="input" + scalar_op, + identity, + axes, + ndim, + dtype, + input_name="input", + return_scalar=False, ): r"""Construct a function that reduces multiple axes. @@ -334,6 +346,8 @@ def careduce_maximum(input): The number of dimensions of the result. dtype: The data type of the result. 
+ return_scalar: + If True, return a scalar, otherwise an array. Returns ======= @@ -368,10 +382,17 @@ def careduce_maximum(input): ) careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) + if not return_scalar: + pre_result = "np.asarray" + post_result = "" + else: + pre_result = "np.asarray" + post_result = ".item()" + careduce_def_src = f""" def {careduce_fn_name}({input_name}): {careduce_assign_lines} - return {var_name} + return {pre_result}({var_name}){post_result} """ careduce_fn = compile_function_src( @@ -381,7 +402,7 @@ def {careduce_fn_name}({input_name}): return careduce_fn -def jit_compile_reducer(node, fn, **kwds): +def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds): """Compile Python source for reduction loops using additional optimizations. Parameters @@ -398,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds): A :func:`numba.njit`-compiled function. """ - signature = create_numba_signature(node, reduce_to_scalar=True) + signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar) # Eagerly compile the function using increased optimizations. This should # help improve nested loop reductions. @@ -431,6 +452,162 @@ def axis_apply_fn(x): return axis_apply_fn +_jit_options = { + "fastmath": { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # TODO Do we want this one? + } +} + + +@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) +def _vectorized( + typingctx, + scalar_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + inputs, +): + arg_types = [ + scalar_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + inputs, + ] + + if not isinstance(input_bc_patterns, types.Literal): + raise TypingError("input_bc_patterns must be literal.") + input_bc_patterns = input_bc_patterns.literal_value + input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode())) + + if not isinstance(output_bc_patterns, types.Literal): + raise TypeError("output_bc_patterns must be literal.") + output_bc_patterns = output_bc_patterns.literal_value + output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode())) + + if not isinstance(output_dtypes, types.Literal): + raise TypeError("output_dtypes must be literal.") + output_dtypes = output_dtypes.literal_value + output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode())) + + if not isinstance(inplace_pattern, types.Literal): + raise TypeError("inplace_pattern must be literal.") + inplace_pattern = inplace_pattern.literal_value + inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) + + n_outputs = len(output_bc_patterns) + + if not len(inputs) > 0: + raise TypingError("Empty argument list to elemwise op.") + + if not n_outputs > 0: + raise TypingError("Empty list of outputs for elemwise op.") + + if not all(isinstance(input, types.Array) for input in inputs): + raise TypingError("Inputs to elemwise must be arrays.") + ndim = inputs[0].ndim + + if not all(input.ndim == ndim for input in inputs): + raise TypingError("Inputs to elemwise must have the same rank.") + + if not all(len(pattern) == ndim for pattern in output_bc_patterns): + raise TypingError("Invalid output broadcasting pattern.") + + scalar_signature = typingctx.resolve_function_type( + scalar_func, [in_type.dtype for in_type in inputs], {} + ) + + # So we can access the constant values in codegen... 
+    input_bc_patterns_val = input_bc_patterns
+    output_bc_patterns_val = output_bc_patterns
+    output_dtypes_val = output_dtypes
+    inplace_pattern_val = inplace_pattern
+    input_types = inputs
+
+    def codegen(
+        ctx,
+        builder,
+        sig,
+        args,
+    ):
+
+        [_, _, _, _, _, inputs] = args
+        inputs = cgutils.unpack_tuple(builder, inputs)
+        inputs = [
+            arrayobj.make_array(ty)(ctx, builder, val)
+            for ty, val in zip(input_types, inputs)
+        ]
+        in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
+
+        iter_shape = elemwise_codegen.compute_itershape(
+            ctx,
+            builder,
+            in_shapes,
+            input_bc_patterns_val,
+        )
+
+        outputs, output_types = elemwise_codegen.make_outputs(
+            ctx,
+            builder,
+            iter_shape,
+            output_bc_patterns_val,
+            output_dtypes_val,
+            inplace_pattern_val,
+            inputs,
+            input_types,
+        )
+
+        elemwise_codegen.make_loop_call(
+            typingctx,
+            ctx,
+            builder,
+            scalar_func,
+            scalar_signature,
+            iter_shape,
+            inputs,
+            outputs,
+            input_bc_patterns_val,
+            output_bc_patterns_val,
+            input_types,
+            output_types,
+        )
+
+        if len(outputs) == 1:
+            if inplace_pattern:
+                assert inplace_pattern[0][0] == 0
+                ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
+            return outputs[0]._getvalue()
+
+        for inplace_idx in dict(inplace_pattern):
+            ctx.nrt.incref(
+                builder,
+                sig.return_type.types[inplace_idx],
+                outputs[inplace_idx]._getvalue(),
+            )
+        return ctx.make_tuple(
+            builder, sig.return_type, [out._getvalue() for out in outputs]
+        )
+
+    ret_type = types.Tuple(
+        [
+            types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
+            for dtype in output_dtypes
+        ]
+    )
+    if len(output_dtypes) == 1:
+        ret_type = ret_type.types[0]
+    sig = ret_type(*arg_types)
+
+    return sig, codegen
+
+
 @numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
     # Creating a new scalar node is more involved and unnecessary
@@ -441,55 +618,114 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
     scalar_node = op.scalar_op.make_node(*scalar_inputs)

+    flags = {
+        "arcp",  # Allow Reciprocal
+        "contract",  # Allow floating-point contraction
+        "afn",  # Approximate functions
+        "reassoc",
+        "nsz",  # TODO Do we want this one?
+    }
+
     scalar_op_fn = numba_funcify(
-        op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
+        op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
     )
-    elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
-    elemwise_fn_name = elemwise_fn.__name__

-    if op.inplace_pattern:
-        input_idx = op.inplace_pattern[0]
-        sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
-        input_names = list(sign_obj.parameters.keys())
+    ndim = node.outputs[0].ndim
+    output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
+    input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
+    output_dtypes = tuple(variable.dtype for variable in node.outputs)
+    inplace_pattern = tuple(op.inplace_pattern.items())
+
+    # numba doesn't support nested literals right now...
+ input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() + output_bc_patterns_enc = base64.encodebytes( + pickle.dumps(output_bc_patterns) + ).decode() + output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode() + inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode() + + def elemwise_wrapper(*inputs): + return _vectorized( + scalar_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + inputs, + ) - unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_") - input_names = [unique_names(i, force_unique=True) for i in input_names] + # Pure python implementation, that will be used in tests + def elemwise(*inputs): + inputs = [np.asarray(input) for input in inputs] + inputs_bc = np.broadcast_arrays(*inputs) + shape = inputs[0].shape + for input, bc in zip(inputs, input_bc_patterns): + for length, allow_bc, iter_length in zip(input.shape, bc, shape): + if length == 1 and shape and iter_length != 1 and not allow_bc: + raise ValueError("Broadcast not allowed.") + + outputs = [] + for dtype in output_dtypes: + outputs.append(np.empty(shape, dtype=dtype)) + + for idx in np.ndindex(shape): + vals = [input[idx] for input in inputs_bc] + outs = scalar_op_fn(*vals) + if not isinstance(outs, tuple): + outs = (outs,) + for out, out_val in zip(outputs, outs): + out[idx] = out_val + + outputs_summed = [] + for output, bc in zip(outputs, output_bc_patterns): + axes = tuple(np.nonzero(bc)[0]) + outputs_summed.append(output.sum(axes, keepdims=True)) + if len(outputs_summed) != 1: + return tuple(outputs_summed) + return outputs_summed[0] + + @overload(elemwise) + def ov_elemwise(*inputs): + return elemwise_wrapper + + return elemwise + + +@numba_funcify.register(Sum) +def numba_funcify_Sum(op, node, **kwargs): + axes = op.axis + if axes is None: + axes = list(range(node.inputs[0].ndim)) - updated_input_name = input_names[input_idx] + axes = tuple(axes) - inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np} + ndim_input = node.inputs[0].ndim - inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace" + if hasattr(op, "acc_dtype") and op.acc_dtype is not None: + acc_dtype = op.acc_dtype + else: + acc_dtype = node.outputs[0].type.dtype - input_signature_str = ", ".join(input_names) + np_acc_dtype = np.dtype(acc_dtype) - if node.inputs[input_idx].ndim > 0: - inplace_elemwise_src = f""" -def {inplace_elemwise_fn_name}({input_signature_str}): - return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}) - """ - else: - # We can't perform in-place updates on Numba scalars, so we need to - # convert them to NumPy scalars. - # TODO: We should really prevent the rewrites from creating - # in-place updates on scalars when the Numba mode is selected (or - # in general?). 
- inplace_elemwise_src = f""" -def {inplace_elemwise_fn_name}({input_signature_str}): - {updated_input_name}_scalar = np.asarray({updated_input_name}) - return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item() - """ - - inplace_elemwise_fn = compile_function_src( - inplace_elemwise_src, - inplace_elemwise_fn_name, - {**globals(), **inplace_global_env}, - ) - return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)( - inplace_elemwise_fn - ) + out_dtype = np.dtype(node.outputs[0].dtype) - return elemwise_fn + if ndim_input == len(axes): + + @numba_njit(fastmath=True) + def impl_sum(array): + return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) + + elif len(axes) == 0: + + @numba_njit(fastmath=True) + def impl_sum(array): + return np.asarray(array, dtype=out_dtype) + + else: + impl_sum = numba_funcify_CAReduce(op, node, **kwargs) + + return impl_sum @numba_funcify.register(CAReduce) @@ -526,7 +762,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): input_name=input_name, ) - careduce_fn = jit_compile_reducer(node, careduce_py_fn) + careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False) return careduce_fn @@ -709,7 +945,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) reduce_max_py = create_axis_reducer( - scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True + scalar_maximum, + -np.inf, + axis, + x_at.ndim, + x_dtype, + keepdims=True, ) reduce_sum_py = create_axis_reducer( add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True @@ -756,10 +997,17 @@ def maxandargmax(x): keep_axes = tuple(i for i in range(x_ndim) if i not in axes) reduce_max_py_fn = create_multiaxis_reducer( - scalar_maximum, -np.inf, axes, x_ndim, x_dtype + scalar_maximum, + -np.inf, + axes, + x_ndim, + x_dtype, + return_scalar=False, ) reduce_max = jit_compile_reducer( - Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn + Apply(node.op, node.inputs, [node.outputs[0].clone()]), + reduce_max_py_fn, + reduce_to_scalar=False, ) reduced_x_ndim = x_ndim - len(axes) + 1 diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py new file mode 100644 index 0000000000..d3d1ff1df1 --- /dev/null +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -0,0 +1,240 @@ +from typing import Any, List, Optional, Tuple + +import numba +import numpy as np +from llvmlite import ir +from numba import types +from numba.core import cgutils +from numba.core.base import BaseContext +from numba.np import arrayobj + + +def compute_itershape( + ctx: BaseContext, + builder: ir.IRBuilder, + in_shapes: Tuple[ir.Instruction, ...], + broadcast_pattern: Tuple[Tuple[bool, ...], ...], +): + one = ir.IntType(64)(1) + ndim = len(in_shapes[0]) + shape = [None] * ndim + for i in range(ndim): + for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + length = in_shape[i] + if bc[i]: + with builder.if_then( + builder.icmp_unsigned("!=", length, one), likely=False + ): + msg = ( + f"Input {j} to elemwise is expected to have shape 1 in axis {i}" + ) + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) + elif shape[i] is not None: + with builder.if_then( + builder.icmp_unsigned("!=", length, shape[i]), likely=False + ): + with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( + then, + otherwise, + ): + with then: + msg = ( + f"Incompatible shapes for input {j} and axis {i} of " 
+                            f"elemwise. Input {j} has length 1 in this axis at runtime, but "
+                            "it is not statically broadcastable, so it cannot be broadcast."
+                        )
+                        ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
+                    with otherwise:
+                        msg = (
+                            f"Input {j} to elemwise has an incompatible "
+                            f"shape in axis {i}."
+                        )
+                        ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
+        else:
+            shape[i] = length
+    for i in range(ndim):
+        if shape[i] is None:
+            shape[i] = one
+    return shape
+
+
+def make_outputs(
+    ctx: numba.core.base.BaseContext,
+    builder: ir.IRBuilder,
+    iter_shape: Tuple[ir.Instruction, ...],
+    out_bc: Tuple[Tuple[bool, ...], ...],
+    dtypes: Tuple[Any, ...],
+    inplace: Tuple[Tuple[int, int], ...],
+    inputs: Tuple[Any, ...],
+    input_types: Tuple[Any, ...],
+):
+    arrays = []
+    ar_types: List[types.Array] = []
+    one = ir.IntType(64)(1)
+    inplace_dict = dict(inplace)
+    for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)):
+        if i in inplace_dict:
+            arrays.append(inputs[inplace_dict[i]])
+            ar_types.append(input_types[inplace_dict[i]])
+            # We need to incref once we return the inplace objects
+            continue
+        dtype = numba.from_dtype(np.dtype(dtype))
+        arrtype = types.Array(dtype, len(iter_shape), "C")
+        ar_types.append(arrtype)
+        # This is actually an internal numba function, I guess we could
+        # call `numba.nd.unsafe.ndarray` instead?
+        shape = [
+            length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc)
+        ]
+        array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape)
+        arrays.append(array)
+
+    # If there is no inplace operation, we know that all output arrays
+    # don't alias. Informing llvm can make it easier to vectorize.
+    if not inplace:
+        # The first argument is the output pointer
+        arg = builder.function.args[0]
+        arg.add_attribute("noalias")
+    return arrays, ar_types
+
+
+def make_loop_call(
+    typingctx,
+    context: numba.core.base.BaseContext,
+    builder: ir.IRBuilder,
+    scalar_func: Any,
+    scalar_signature: types.FunctionType,
+    iter_shape: Tuple[ir.Instruction, ...],
+    inputs: Tuple[ir.Instruction, ...],
+    outputs: Tuple[ir.Instruction, ...],
+    input_bc: Tuple[Tuple[bool, ...], ...],
+    output_bc: Tuple[Tuple[bool, ...], ...],
+    input_types: Tuple[Any, ...],
+    output_types: Tuple[Any, ...],
+):
+    safe = (False, False)
+    n_outputs = len(outputs)
+
+    # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
+
+    # Extract shape and stride information from the array.
+ # For later use in the loop body to do the indexing + def extract_array(aryty, obj): + shape = cgutils.unpack_tuple(builder, obj.shape) + strides = cgutils.unpack_tuple(builder, obj.strides) + data = obj.data + layout = aryty.layout + return (data, shape, strides, layout) + + # TODO I think this is better than the noalias attribute + # for the input, but self_ref isn't supported in a released + # llvmlite version yet + # mod = builder.module + # domain = mod.add_metadata([], self_ref=True) + # input_scope = mod.add_metadata([domain], self_ref=True) + # output_scope = mod.add_metadata([domain], self_ref=True) + # input_scope_set = mod.add_metadata([input_scope, output_scope]) + # output_scope_set = mod.add_metadata([input_scope, output_scope]) + + inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs)) + + outputs = tuple( + extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs) + ) + + zero = ir.Constant(ir.IntType(64), 0) + + # Setup loops and initialize accumulators for outputs + # This part corresponds to opening the loops + loop_stack = [] + loops = [] + output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [ + (None, None) + ] * n_outputs + for dim, length in enumerate(iter_shape): + # Find outputs that only have accumulations left + for output in range(n_outputs): + if output_accumulator[output][0] is not None: + continue + if all(output_bc[output][dim:]): + value = outputs[output][0].type.pointee(0) + accu = cgutils.alloca_once_value(builder, value) + output_accumulator[output] = (accu, dim) + + loop = cgutils.for_range(builder, length) + loop_stack.append(loop) + loops.append(loop.__enter__()) + + # Code in the inner most loop... + idxs = [loopval.index for loopval in loops] + + # Load values from input arrays + input_vals = [] + for array_info, bc in zip(inputs, input_bc): + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe) + val = builder.load(ptr) + # val.set_metadata("alias.scope", input_scope_set) + # val.set_metadata("noalias", output_scope_set) + input_vals.append(val) + + inner_codegen = context.get_function(scalar_func, scalar_signature) + + if isinstance( + scalar_signature.args[0], (types.StarArgTuple, types.StarArgUniTuple) + ): + input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)] + output_values = inner_codegen(builder, input_vals) + + if isinstance(scalar_signature.return_type, (types.Tuple, types.UniTuple)): + output_values = cgutils.unpack_tuple(builder, output_values) + func_output_types = scalar_signature.return_type.types + else: + output_values = [output_values] + func_output_types = [scalar_signature.return_type] + + # Update output value or accumulators respectively + for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)): + if accu is not None: + load = builder.load(accu) + # load.set_metadata("alias.scope", output_scope_set) + # load.set_metadata("noalias", input_scope_set) + new_value = builder.fadd(load, value) + builder.store(new_value, accu) + # TODO belongs to noalias scope + # store.set_metadata("alias.scope", output_scope_set) + # store.set_metadata("noalias", input_scope_set) + else: + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])] + ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc) + # store = builder.store(value, ptr) + value = context.cast( + builder, value, func_output_types[i], output_types[i].dtype + 
) + arrayobj.store_item(context, builder, output_types[i], value, ptr) + # store.set_metadata("alias.scope", output_scope_set) + # store.set_metadata("noalias", input_scope_set) + + # Close the loops and write accumulator values to the output arrays + for depth, loop in enumerate(loop_stack[::-1]): + for output, (accu, accu_depth) in enumerate(output_accumulator): + if accu_depth == depth: + idxs_bc = [ + zero if bc else idx for idx, bc in zip(idxs, output_bc[output]) + ] + ptr = cgutils.get_item_pointer2( + context, builder, *outputs[output], idxs_bc + ) + load = builder.load(accu) + # load.set_metadata("alias.scope", output_scope_set) + # load.set_metadata("noalias", input_scope_set) + # store = builder.store(load, ptr) + load = context.cast( + builder, load, func_output_types[output], output_types[output].dtype + ) + arrayobj.store_item(context, builder, output_types[output], load, ptr) + # store.set_metadata("alias.scope", output_scope_set) + # store.set_metadata("noalias", input_scope_set) + loop.__exit__(None, None, None) + + return diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 9871584454..33fac601a5 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs): lambda _: 0, len(node.inputs) - 1 ) + # TODO broadcastable checks @numba_basic.numba_njit def broadcast_to(x, *shape): scalars_shape = create_zeros_tuple() diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index d72277b2f5..8cd57c6765 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs): # TODO: Do we need to cache these functions so that we don't end up # compiling the same Numba function over and over again? + if not hasattr(op, "nfunc_spec"): + return generate_fallback_impl(op, node, **kwargs) + scalar_func_path = op.nfunc_spec[0] scalar_func_numba = None diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c26cd9aa6c..a307d29c5e 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -17,7 +17,11 @@ def idx_to_str( - array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i" + array_name: str, + offset: int, + size: Optional[str] = None, + idx_symbol: str = "i", + allow_scalar=False, ) -> str: if offset < 0: indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}" @@ -32,7 +36,10 @@ def idx_to_str( # compensate for this poor `Op`/rewrite design and implementation. 
indices = f"({indices}) % {size}" - return f"{array_name}[{indices}]" + if allow_scalar: + return f"{array_name}[{indices}]" + else: + return f"np.asarray({array_name}[{indices}])" @overload(range) @@ -115,7 +122,9 @@ def add_inner_in_expr( indexed_inner_in_str = ( storage_name if tap_offset is None - else idx_to_str(storage_name, tap_offset, size=storage_size_var) + else idx_to_str( + storage_name, tap_offset, size=storage_size_var, allow_scalar=False + ) ) inner_in_exprs.append(indexed_inner_in_str) @@ -232,7 +241,12 @@ def add_output_storage_post_proc_stmt( ) for out_tap in output_taps: inner_out_to_outer_in_stmts.append( - idx_to_str(storage_name, out_tap, size=storage_size_name) + idx_to_str( + storage_name, + out_tap, + size=storage_size_name, + allow_scalar=True, + ) ) add_output_storage_post_proc_stmt( @@ -269,7 +283,7 @@ def add_output_storage_post_proc_stmt( storage_size_name = f"{outer_in_name}_len" inner_out_to_outer_in_stmts.append( - idx_to_str(storage_name, 0, size=storage_size_name) + idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True) ) add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name) @@ -337,8 +351,8 @@ def scan({", ".join(outer_in_names)}): {indent(input_storage_block, " " * 4)} i = 0 - cond = False - while i < n_steps and not cond: + cond = np.array(False) + while i < n_steps and not cond.item(): {inner_outputs} = scan_inner_func({inner_in_args}) {indent(inner_out_post_processing_block, " " * 8)} {indent(inner_out_to_outer_out_stmts, " " * 8)} diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 7cddedbc58..3f0e35543f 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -27,9 +27,9 @@ def fgraph_convert(self, fgraph, **kwargs): return numba_funcify(fgraph, **kwargs) def jit_compile(self, fn): - import numba + from pytensor.link.numba.dispatch.basic import numba_njit - jitted_fn = numba.njit(fn) + jitted_fn = numba_njit(fn) return jitted_fn def create_thunk_inputs(self, storage_map): diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 1dbf416f24..5686951c1b 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import numba_typify from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op +from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.tensor import blas from pytensor.tensor import subtensor as at_subtensor from pytensor.tensor.elemwise import Elemwise @@ -63,6 +64,33 @@ def perform(self, node, inputs, outputs): outputs[0][0] = res +class ScalarMyMultiOut(ScalarOp): + nin = 2 + nout = 2 + + @staticmethod + def impl(a, b): + res1 = 2 * a + res2 = 2 * b + return [res1, res2] + + def make_node(self, a, b): + a = as_scalar(a) + b = as_scalar(b) + return Apply(self, [a, b], [a.type(), b.type()]) + + def perform(self, node, inputs, outputs): + res1, res2 = self.impl(inputs[0], inputs[1]) + outputs[0][0] = res1 + outputs[1][0] = res2 + + +scalar_my_multi_out = Elemwise(ScalarMyMultiOut()) +scalar_my_multi_out.ufunc = ScalarMyMultiOut.impl +scalar_my_multi_out.ufunc.nin = 2 +scalar_my_multi_out.ufunc.nout = 2 + + class MyMultiOut(Op): nin = 2 nout = 2 @@ -86,7 +114,6 @@ def perform(self, node, inputs, outputs): my_multi_out.ufunc = MyMultiOut.impl my_multi_out.ufunc.nin = 2 my_multi_out.ufunc.nout = 2 - opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) numba_mode = Mode(NumbaLinker(), opts) 
 py_mode = Mode("py", opts)


@@ -988,8 +1015,8 @@ def test_config_options_parallel():
     x = at.dvector()

     with config.change_flags(numba__vectorize_target="parallel"):
-        pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
+        pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
+        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
         assert numba_mul_fn.targetoptions["parallel"] is True


@@ -997,8 +1024,8 @@ def test_config_options_fastmath():
     x = at.dvector()

     with config.change_flags(numba__fastmath=True):
-        pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
+        pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
+        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
         assert numba_mul_fn.targetoptions["fastmath"] is True


@@ -1006,16 +1033,14 @@ def test_config_options_cached():
     x = at.dvector()

     with config.change_flags(numba__cache=True):
-        pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
-        assert not isinstance(
-            numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
-        )
+        pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
+        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
+        assert not isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)

     with config.change_flags(numba__cache=False):
-        pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
-        assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
+        pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
+        numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
+        assert isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)


 def test_scalar_return_value_conversion():
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 4d4186e4b6..0958e90034 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -6,7 +6,7 @@
 import pytensor.tensor as at
 import pytensor.tensor.inplace as ati
 import pytensor.tensor.math as aem
-from pytensor import config
+from pytensor import config, function
 from pytensor.compile.ops import deep_copy_op
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
@@ -16,7 +16,7 @@
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
-    my_multi_out,
+    scalar_my_multi_out,
     set_test_value,
 )
@@ -99,8 +99,8 @@
                 rng.standard_normal(100).astype(config.floatX),
                 rng.standard_normal(100).astype(config.floatX),
             ],
-            lambda x, y: my_multi_out(x, y),
-            NotImplementedError,
+            lambda x, y: scalar_my_multi_out(x, y),
+            None,
         ),
     ],
 )
@@ -117,6 +117,25 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
     compare_numba_and_py(out_fg, input_vals)

+
+def test_elemwise_speed(benchmark):
+    x = at.dmatrix("y")
+    y = at.dvector("z")
+
+    out = np.exp(2 * x * y + y)
+
+    rng = np.random.default_rng(42)
+
+    x_val = rng.normal(size=(200, 500))
+    y_val = rng.normal(size=500)
+
+    func = function([x, y], out, mode="NUMBA")
+    func = func.vm.jit_fn
+    (out,) = func(x_val, y_val)
+    np.testing.assert_allclose(np.exp(2 * x_val * y_val + 
y_val), out) + + benchmark(func, x_val, y_val) + + @pytest.mark.parametrize( "v, new_order", [ diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py index 8cf1fdc6bd..0570ef2996 100644 --- a/tests/link/numba/test_extra_ops.py +++ b/tests/link/numba/test_extra_ops.py @@ -32,6 +32,7 @@ def test_Bartlett(val): for i in g_fg.inputs if not isinstance(i, (SharedVariable, Constant)) ], + assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15), ) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index dde20f5f19..7676b1bf40 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -97,7 +97,7 @@ def test_Clip(v, min, max): ], ) def test_Composite(inputs, input_values, scalar_fn): - composite_inputs = [aes.float64(i.name) for i in inputs] + composite_inputs = [aes.ScalarType(config.floatX)(name=i.name) for i in inputs] comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)])) out_fg = FunctionGraph(inputs, [comp_op(*inputs)]) compare_numba_and_py(out_fg, input_values)
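
Note (illustrative sketch, not part of the patch above): Numba cannot pass nested tuples as literal arguments to an intrinsic, so `numba_funcify_Elemwise` pickles the broadcast patterns, output dtypes, and inplace pattern and base64-encodes them into plain strings, which `_vectorized` decodes again during type resolution. A minimal stand-alone sketch of that encode/decode round-trip, using made-up example values:

    import base64
    import pickle

    # Example per-input broadcastable flags, in the shape numba_funcify_Elemwise
    # collects from node.inputs (the values here are invented for illustration).
    input_bc_patterns = ((False, True), (False, False))

    # Encode side: pickle the tuple and base64 it, so the result is an ordinary
    # str that Numba can treat as a literal argument to the intrinsic.
    encoded = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()

    # Decode side: what _vectorized does with the literal string during typing.
    decoded = pickle.loads(base64.decodebytes(encoded.encode()))
    assert decoded == input_bc_patterns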