diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index b0123a4d3c..37e0e6723e 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -3,7 +3,6 @@ """ -import logging from copy import copy from typing import Optional @@ -16,11 +15,6 @@ from pytensor.graph.fg import FunctionGraph -_logger = logging.getLogger("pytensor.compile.function.pfunc") - -__docformat__ = "restructuredtext en" - - def rebuild_collect_shared( outputs, inputs=None, diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 273aeb3dc5..3b14d5841c 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: + from pytensor.compile.mode import Mode from pytensor.link.vm import VM @@ -1391,9 +1392,16 @@ def check_unused_inputs(inputs, outputs, on_unused_input): @staticmethod def prepare_fgraph( - inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile + inputs, + outputs, + additional_outputs, + fgraph: FunctionGraph, + mode: "Mode", + profile, ): + rewriter = mode.optimizer + try: start_rewriter = time.perf_counter() @@ -1401,6 +1409,7 @@ def prepare_fgraph( rewrite_time = None with config.change_flags( + mode=mode, compute_test_value=config.compute_test_value_opt, traceback__limit=config.traceback__compile_limit, ): @@ -1440,7 +1449,7 @@ def prepare_fgraph( stacklevel=3, ) - if not hasattr(linker, "accept"): + if not hasattr(mode.linker, "accept"): raise ValueError( "'linker' parameter of FunctionMaker should be " f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}" @@ -1511,12 +1520,8 @@ def __init__( self.fgraph = fgraph - rewriter, linker = mode.optimizer, copy.copy(mode.linker) - if not no_fgraph_prep: - self.prepare_fgraph( - inputs, outputs, found_updates, fgraph, rewriter, linker, profile - ) + self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile) assert len(fgraph.outputs) == len(outputs + found_updates) @@ -1528,6 +1533,8 @@ def __init__( if not spec.borrow ] + linker = copy.copy(mode.linker) + if no_borrow: self.linker = linker.accept( fgraph, diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index d1d9f6a0b5..b740a2679b 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -7,6 +7,8 @@ import warnings from typing import Optional, Tuple, Union +from typing_extensions import Literal + from pytensor.compile.function.types import Supervisor from pytensor.configdefaults import config from pytensor.graph.destroyhandler import DestroyHandler @@ -530,3 +532,26 @@ def register_mode(name, mode): if name in predefined_modes: raise ValueError(f"Mode name already taken: {name}") predefined_modes[name] = mode + + +def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"], ...]: + """Get the compilation target language.""" + + if mode is None: + mode = get_default_mode() + + linker = mode.linker + + if isinstance(linker, NumbaLinker): + return ("numba",) + if isinstance(linker, JAXLinker): + return ("jax",) + if isinstance(linker, PerformLinker): + return ("py",) + if isinstance(linker, CLinker): + return ("c",) + + if isinstance(linker, (VMLinker, OpWiseCLinker)): + return ("c", "py") if config.cxx else ("py",) + + raise Exception(f"Unsupported Linker: {linker}") diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index dde4b4c4b1..d7362ed372 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ 
-27,6 +27,7 @@ from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import HasInnerGraph from pytensor.graph.rewriting.basic import MergeOptimizer from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.utils import MetaObject, MethodNotDefined @@ -3987,7 +3988,7 @@ def c_code(self, *args, **kwargs): complex_from_polar = ComplexFromPolar(name="complex_from_polar") -class Composite(ScalarOp): +class Composite(ScalarOp, HasInnerGraph): """ Composite is an Op that takes a graph of scalar operations and produces c code for the whole graph. Its purpose is to implement loop @@ -3999,174 +4000,6 @@ class Composite(ScalarOp): init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs") - def __str__(self): - if self.name is None: - self.init_name() - return self.name - - def make_new_inplace(self, output_types_preference=None, name=None): - """ - This op.__init__ fct don't have the same parameter as other scalar op. - This break the insert_inplace_optimizer optimization. - This fct allow fix patch this. - - """ - d = {k: getattr(self, k) for k in self.init_param} - out = self.__class__(**d) - if name: - out.name = name - else: - name = out.name - super(Composite, out).__init__(output_types_preference, name) - return out - - def init_c_code(self): - """ - Assemble the C code for this Composite Op. - - The result is assigned to `self._c_code`. - """ - from pytensor.link.c.interface import CLinkerType - - # It was already called - if hasattr(self, "_c_code"): - return - subd = dict( - chain( - ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)), - ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)), - ) - ) - - for var in self.fgraph.variables: - if var.owner is None: - if var not in self.fgraph.inputs: - # This is an orphan - if isinstance(var, Constant) and isinstance(var.type, CLinkerType): - subd[var] = var.type.c_literal(var.data) - else: - raise ValueError( - "All orphans in the fgraph to Composite must" - " be Constant, CLinkerType instances." - ) - elif any(i.dtype == "float16" for i in var.owner.inputs) or any( - o.dtype == "float16" for o in var.owner.outputs - ): - # flag for elemwise ops to check. - self.inner_float16 = True - - _c_code = "{\n" - self.nodenames = [ - f"%(nodename)s_subnode{int(j)}" - for j, n in enumerate(self.fgraph.toposort()) - ] - - i = 0 - for j, node in enumerate(self.fgraph.toposort()): - for output in node.outputs: - if output not in subd: - i += 1 - name = f"V%(id)s_tmp{int(i)}" - subd[output] = name - _c_code += f"{output.type.dtype_specs()[1]} {name};\n" - s = node.op.c_code( - node, - self.nodenames[j], - [subd[input] for input in node.inputs], - [subd[output] for output in node.outputs], - dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"), - ) - _c_code += s - _c_code += "\n" - _c_code += "}\n" - self._c_code = _c_code - - def init_py_impls(self): - """ - Return a list of functions that compute each output of self. - - """ - # In the case where the graph is a dag, but not a tree like: - # add(*1 -> mul(x, y), *1) - - # We have an efficient way to build the executable (we build - # and traverse each node only once). - - # But we don't have an efficient execution. We will execute - # like a tree, so nodes that have more then 1 client will be - # executed as many times as there number of clients. In the - # example above, it will calculate *1 twice. 
Doing otherwise - # imply making a complicated execution engine. - - # We need the fast creation of the executor as we always do it - # even if we will use the c code. The Python implementation is - # already slow, so it is not as much important to have a fast - # execution there. - - memo = {} - - def compose_impl(r): - if r in memo: - return memo[r] - if r in self.fgraph.inputs: - idx = self.fgraph.inputs.index(r) - - def f(inputs): - return inputs[idx] - - memo[r] = f - return f - elif r.owner is None: # in fgraph.orphans: - - def f(inputs): - return r.data - - memo[r] = f - return f - node = r.owner - producers = [compose_impl(input) for input in node.inputs] - - def f(inputs): - return node.op.impl(*[p(inputs) for p in producers]) - - memo[r] = f - return f - - self._impls = [compose_impl(r) for r in self.fgraph.outputs] - - def init_name(self): - """ - Return a readable string representation of self.fgraph. - - """ - rval = self.name - if rval is None: - for i, r in enumerate(self.fgraph.inputs): - r.name = f"i{int(i)}" - for i, r in enumerate(self.fgraph.outputs): - r.name = f"o{int(i)}" - io = set(self.fgraph.inputs + self.fgraph.outputs) - for i, r in enumerate(self.fgraph.variables): - if r not in io and len(self.fgraph.clients[r]) > 1: - r.name = f"t{int(i)}" - outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs]) - rval = f"Composite{{{outputs_str}}}" - self.name = rval - - def init_fgraph(self): - # The clone done by FunctionGraph is needed as we don't want - # the fgraph to be set to the variable as we need to pickle - # them for the cache of c module to work. - fgraph = FunctionGraph(self.inputs, self.outputs) - MergeOptimizer().rewrite(fgraph) - for node in fgraph.apply_nodes: - if not isinstance(node.op, ScalarOp): - raise ValueError( - "The fgraph to Composite must be exclusively" - " composed of ScalarOp instances." - ) - self.fgraph = fgraph - def __init__(self, inputs, outputs): # We need to clone the graph as sometimes its nodes already # contain a reference to an fgraph. As we want the Composite @@ -4179,6 +4012,7 @@ def __init__(self, inputs, outputs): # only 1 new Composite each time at the output. for i in inputs: assert i not in outputs # This isn't supported, use identity + if len(outputs) > 1 or not any( isinstance(var.owner.op, Composite) for var in outputs ): @@ -4210,15 +4044,112 @@ def __init__(self, inputs, outputs): self.outputs_type = tuple([output.type for output in outputs]) self.nin = len(inputs) self.nout = len(outputs) - self.init_fgraph() # self.fgraph - # Postpone the creation in case it isn't needed. - # self.init_name() # self.name - self.name = None self.prepare_node_called = set() + @property + def fn(self): + return None + + @property + def inner_inputs(self): + return self.fgraph.inputs + + @property + def inner_outputs(self): + return self.fgraph.outputs + + def __str__(self): + return self.name + + def make_new_inplace(self, output_types_preference=None, name=None): + """ + This op.__init__ fct don't have the same parameter as other scalar op. + This break the insert_inplace_optimizer optimization. + This fct allow fix patch this. 
+ + """ + d = {k: getattr(self, k) for k in self.init_param} + out = self.__class__(**d) + if name: + out.name = name + else: + name = out.name + super(Composite, out).__init__(output_types_preference, name) + return out + + @property + def py_perform(self): + if hasattr(self, "_py_perform_fn"): + return self._py_perform_fn + + from pytensor.link.utils import fgraph_to_python + + def python_convert(op, node=None, **kwargs): + assert node is not None + + n_outs = len(node.outputs) + + if n_outs > 1: + + def _perform(*inputs, outputs=[[None]] * n_outs): + op.perform(node, inputs, outputs) + return tuple(o[0] for o in outputs) + + else: + + def _perform(*inputs, outputs=[[None]]): + op.perform(node, inputs, outputs) + return outputs[0][0] + + return _perform + + self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) + return self._py_perform_fn + + @property + def name(self): + if hasattr(self, "_name"): + return self._name + + # TODO FIXME: Just implement pretty printing for the `Op`; don't do + # this redundant, outside work in the `Op` itself. + for i, r in enumerate(self.fgraph.inputs): + r.name = f"i{int(i)}" + for i, r in enumerate(self.fgraph.outputs): + r.name = f"o{int(i)}" + io = set(self.fgraph.inputs + self.fgraph.outputs) + for i, r in enumerate(self.fgraph.variables): + if r not in io and len(self.fgraph.clients[r]) > 1: + r.name = f"t{int(i)}" + outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs]) + rval = f"Composite{{{outputs_str}}}" + self._name = rval + return self._name + + @name.setter + def name(self, name): + self._name = name + + @property + def fgraph(self): + if hasattr(self, "_fgraph"): + return self._fgraph + + # The clone done by FunctionGraph is needed as we don't want + # the fgraph to be set to the variable as we need to pickle + # them for the cache of c module to work. + fgraph = FunctionGraph(self.inputs, self.outputs) + MergeOptimizer().rewrite(fgraph) + for node in fgraph.apply_nodes: + if not isinstance(node.op, ScalarOp): + raise TypeError( + "The fgraph to Composite must be exclusively" + " composed of ScalarOp instances." + ) + self._fgraph = fgraph + return self._fgraph + def prepare_node(self, node, storage_map, compute_map, impl): - if impl == "py": - self.init_py_impls() # self._impls if impl not in self.prepare_node_called: for n in list_of_nodes(self.inputs, self.outputs): n.op.prepare_node(n, None, None, impl) @@ -4229,7 +4160,13 @@ def clone_float32(self): new_ins, new_outs = composite_f32.apply(self.fgraph) return Composite(new_ins, new_outs) + def clone(self): + new_ins, new_outs = composite_f32.apply(self.fgraph) + return Composite(new_ins, new_outs) + def output_types(self, input_types): + # TODO FIXME: What's the intended purpose/use of this method, and why + # does it even need to be a method? if tuple(input_types) != self.inputs_type: raise TypeError( f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}." 
@@ -4256,8 +4193,9 @@ def make_node(self, *inputs): return node def perform(self, node, inputs, output_storage): - for storage, impl in zip(output_storage, self._impls): - storage[0] = impl(inputs) + outputs = self.py_perform(*inputs) + for storage, out_val in zip(output_storage, outputs): + storage[0] = out_val def impl(self, *inputs): output_storage = [[None] for i in range(self.nout)] @@ -4270,8 +4208,110 @@ def impl(self, *inputs): def grad(self, inputs, output_grads): raise NotImplementedError("grad is not implemented for Composite") + def __eq__(self, other): + if self is other: + return True + if ( + type(self) != type(other) + or self.nin != other.nin + or self.nout != other.nout + ): + return False + + # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this + # object to generate the same `_c_code`? + return self.c_code_template == other.c_code_template + + def __hash__(self): + # Note that in general, the configparser settings at the time + # of code generation (__init__) affect the semantics of this Op. + # This function assumes that all relevant info about the configparser + # is embodied in _c_code. So the _c_code, rather than self.fgraph, + # is the signature of the semantics of this Op. + # _c_code is preserved through unpickling, so the Op will not change + # semantics when it is reloaded with different configparser + # settings. + # + # TODO FIXME: Doesn't the above just mean that we should be including + # the relevant "configparser settings" here? Also, why should we even + # care about the exact form of the generated C code when comparing + # `Op`s? All this smells of leaky concerns and interfaces. + return hash((type(self), self.nin, self.nout, self.c_code_template)) + + def __getstate__(self): + rval = dict(self.__dict__) + rval.pop("_c_code", None) + rval.pop("_py_perform_fn", None) + rval.pop("_fgraph", None) + rval.pop("prepare_node_called", None) + return rval + + def __setstate__(self, d): + self.__dict__.update(d) + self.prepare_node_called = set() + + @property + def c_code_template(self): + from pytensor.link.c.interface import CLinkerType + + if hasattr(self, "_c_code"): + return self._c_code + + subd = dict( + chain( + ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)), + ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)), + ) + ) + + for var in self.fgraph.variables: + if var.owner is None: + if var not in self.fgraph.inputs: + # This is an orphan + if isinstance(var, Constant) and isinstance(var.type, CLinkerType): + subd[var] = var.type.c_literal(var.data) + else: + raise ValueError( + "All orphans in the fgraph to Composite must" + " be Constant, CLinkerType instances." + ) + elif any(i.dtype == "float16" for i in var.owner.inputs) or any( + o.dtype == "float16" for o in var.owner.outputs + ): + # flag for elemwise ops to check. 
+ self.inner_float16 = True + + _c_code = "{\n" + self.nodenames = [ + f"%(nodename)s_subnode{int(j)}" + for j, n in enumerate(self.fgraph.toposort()) + ] + + i = 0 + for j, node in enumerate(self.fgraph.toposort()): + for output in node.outputs: + if output not in subd: + i += 1 + name = f"V%(id)s_tmp{int(i)}" + subd[output] = name + _c_code += f"{output.type.dtype_specs()[1]} {name};\n" + s = node.op.c_code( + node, + self.nodenames[j], + [subd[input] for input in node.inputs], + [subd[output] for output in node.outputs], + dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"), + ) + _c_code += s + _c_code += "\n" + + _c_code += "}\n" + + self._c_code = _c_code + + return self._c_code + def c_code(self, node, nodename, inames, onames, sub): - self.init_c_code() d = dict( chain( @@ -4286,7 +4326,7 @@ def c_code(self, node, nodename, inames, onames, sub): # It won't generate conflicting variable name. d["id"] = "_DUMMY_ID_" - return self._c_code % d + return self.c_code_template % d def c_code_cache_version(self): rval = [3] @@ -4314,7 +4354,6 @@ def c_support_code(self, **kwargs): return "\n".join(sorted(rval)) def c_support_code_apply(self, node, name): - self.init_c_code() rval = [] for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): subnode_support_code = subnode.op.c_support_code_apply( @@ -4328,49 +4367,6 @@ def c_support_code_apply(self, node, name): # c_support_code instead of c_support_code_apply. return "\n".join(rval) - def __eq__(self, other): - if self is other: - return True - if ( - type(self) != type(other) - or self.nin != other.nin - or self.nout != other.nout - ): - return False - # see __hash__ for comment on why there is no mention of fgraph - # or module cache key here. - self.init_c_code() # self._c_code and self.nodenames - other.init_c_code() - return self._c_code == other._c_code - - def __hash__(self): - self.init_c_code() # self._c_code and self.nodenames - rval = hash((type(self), self.nin, self.nout, self._c_code)) - # Note that in general, the configparser settings at the time - # of code generation (__init__) affect the semantics of this Op. - # This function assumes that all relevant info about the configparser - # is embodied in _c_code. So the _c_code, rather than self.fgraph, - # is the signature of the semantics of this Op. - # _c_code is preserved through unpickling, so the Op will not change - # semantics when it is reloaded with different configparser - # settings. - return rval - - def __getstate__(self): - rval = dict(self.__dict__) - rval.pop("_impls", None) - rval.pop("prepare_node_called", None) - del rval["fgraph"] - return rval - - def __setstate__(self, d): - self.__dict__.update(d) - # We must call init to set fgraph and _impls again, as otherwise - # self.perform will not work. 
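
Note: with the new `__getstate__`/`__setstate__` above, the cached `_c_code`, `_py_perform_fn`, and `_fgraph` are stripped on pickling and rebuilt lazily after unpickling, so the old explicit `init_fgraph`/`init_py_impls` calls in `__setstate__` are no longer needed. A hypothetical round-trip sketch (it assumes the `Composite` and its variables pickle cleanly, as the C-module cache relies on):

    import pickle

    from pytensor.scalar.basic import Composite, add, floats

    x, y = floats("xy")
    comp = Composite([x, y], [add(x, y)])
    comp.fgraph  # populate the cache

    comp2 = pickle.loads(pickle.dumps(comp))
    assert not hasattr(comp2, "_fgraph")  # cache dropped by __getstate__
    comp2.fgraph  # transparently rebuilt on first access
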
- self.prepare_node_called = set() - self.init_fgraph() - self.init_py_impls() - class Compositef32: # This is a dict of scalar op classes that need special handling diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 3ee04bb3d9..764d67f5d8 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1,5 +1,5 @@ from copy import copy -from typing import List, Tuple, Union +from typing import List, Tuple import numpy as np @@ -1257,33 +1257,61 @@ class CAReduce(COp): """ - __props__: Union[ - Tuple[str], Tuple[str, str], Tuple[str, str, str], Tuple[str, str, str, str] - ] = ("scalar_op", "axis") + __props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "upcast_discrete_output") - def __init__(self, scalar_op, axis=None): + def __init__( + self, + scalar_op, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=False, + ): """ Parameters ---------- scalar_op - A binary scalar `Op` with only one output. It must be commutative - and associative. + A binary scalar `Op` with only one output. + It must be commutative and associative. axis - - The dimension along which we want to reduce - - List of dimensions that we want to reduce - - If ``None``, all dimensions are reduced + - the dimension along which we want to reduce + - list of dimensions that we want to reduce + - if ``None``, all dimensions are reduced + dtype + The dtype of the returned tensor. If ``None``, then we use the default + dtype which is the same as the input array's dtype except when + `upcast_discrete_output` is ``True`` and the following holds: + + - the input dtype is a signed integer of precision < 64 bit, in which + case we use int64 + - the input dtype is an unsigned integer of precision < 64 bit, in + which case we use uint64 + + This default dtype does _not_ depend on the value of `acc_dtype`. + This behavior is similar in spirit to that of NumPy, except that + NumPy uses the default machine integer while we always use 64 bit + integers to avoid platform-dependent behavior. + acc_dtype + The dtype of the internal accumulator. + If ``None`` (default), we use the dtype in the list below, + or the input dtype if its precision is higher: + + - for int dtypes, we use at least int64; + - for uint dtypes, we use at least uint64; + - for float dtypes, we use at least float64; + - for complex dtypes, we use at least complex128. + upcast_discrete_output + See """ if scalar_op.nin not in (-1, 2) or scalar_op.nout != 1: raise NotImplementedError( - "CAReduce only supports binary functions with a single " "output." + "CAReduce only supports binary functions with a single output." 
) self.axis = None - self.ufunc_is_vectorized = False self.scalar_op = scalar_op - self.set_ufunc(scalar_op) if axis is not None: if isinstance(axis, (int, np.integer)) or ( @@ -1293,64 +1321,179 @@ def __init__(self, scalar_op, axis=None): else: self.axis = tuple(axis) - def set_ufunc(self, scalar_op): - if hasattr(scalar_op, "nfunc_spec") and hasattr(np, scalar_op.nfunc_spec[0]): - self.ufunc = getattr(np, scalar_op.nfunc_spec[0]) + self.dtype = dtype + self.acc_dtype = acc_dtype + self.upcast_discrete_output = upcast_discrete_output + + @property + def ufunc(self): + if hasattr(self, "_ufunc"): + return self._ufunc + + if hasattr(self.scalar_op, "nfunc_spec") and hasattr( + np, self.scalar_op.nfunc_spec[0] + ): + self._ufunc = getattr(np, self.scalar_op.nfunc_spec[0]) + else: + self._ufunc = np.frompyfunc( + self.scalar_op.impl, 2, 1, identity=self.scalar_op.identity + ) + + return self._ufunc + + def _output_dtype(self, idtype): + + if not self.upcast_discrete_output: + return idtype + + dtype = self.dtype + + if dtype == "OLD": + return dict( + int8="int32", + int16="int32", + int32="int64", + uint8="uint32", + uint16="uint32", + uint32="uint64", + ).get(idtype, idtype) + elif dtype is None: + # If input has a discrete dtype, upcast it to 64 + return dict( + bool="int64", + int8="int64", + int16="int64", + int32="int64", + uint8="uint64", + uint16="uint64", + uint32="uint64", + ).get(idtype, idtype) else: - self.ufunc = np.frompyfunc(scalar_op.impl, 2, 1) - self.ufunc_is_vectorized = True + # The important is that the accumulator dtype does not + # lose precision. Then, the result can be downcasted. + return dtype - def _output_dtype(self, input_dtype): - return input_dtype + def _acc_dtype(self, idtype): + acc_dtype = self.acc_dtype + if acc_dtype is None: + return dict( + bool="int64", + int8="int64", + int16="int64", + int32="int64", + uint8="uint64", + uint16="uint64", + uint32="uint64", + float16="float32", + float32="float64", + complex64="complex128", + ).get(idtype, idtype) + elif acc_dtype in continuous_dtypes and idtype in discrete_dtypes: + # Specifying a continuous accumulator for discrete input is OK + return acc_dtype + else: + # The conversion has to be considered an upcast. + upcasted_dtype = upcast(idtype, acc_dtype) + if acc_dtype != upcasted_dtype: + raise TypeError( + f"Cannot build {self} node with input dtype {idtype} " + f"and acc_dtype {acc_dtype}, as precision would be lost. " + "To correct this error, you can:\n" + " - not specify acc_dtype, or\n" + f" - use an acc_dtype at least as precise as {upcasted_dtype}.\n" + ' - specify "dtype" instead of "acc_dtype", so ' + "the reduction will be precise, but the result will " + 'be casted into "dtype" at the end.\n' + "If you are expecting the precision loss, you can " + f'use tensor.cast(..., dtype="{acc_dtype}"), on your input.' + ) + return acc_dtype def make_node(self, input): input = as_tensor_variable(input) inp_dims = input.type.ndim + inp_dtype = input.type.dtype + + # We need to redefine make_node so that, if self.dtype is None, + # we can infer what dtype should be, and create a node from an Op + # of the appropriate dtype. 
+        dtype = self._output_dtype(inp_dtype)
+        acc_dtype = self._acc_dtype(inp_dtype)
+
+        assert dtype is not None
+        assert acc_dtype is not None
 
         axis = self.axis
-        if axis is None:
-            axis = list(range(inp_dims))
 
-        copy_op = any(a < 0 for a in axis)
         # scalar inputs are treated as 1D regarding axis in this `Op`
-        try:
-            axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
-        except np.AxisError:
-            raise np.AxisError(axis, ndim=inp_dims)
-
-        # We can't call self.__class__() as there is a class that
-        # inherits from CAReduce that doesn't have the same signature
-        if copy_op:
-            op = copy(self)
-            op.set_ufunc(op.scalar_op)
-            assert len(axis) == len(self.axis)
-            op.axis = tuple(axis)
+        if axis is not None:
+            try:
+                axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
+            except np.AxisError:
+                raise np.AxisError(axis, ndim=inp_dims)
+
+            out_shape = tuple(
+                s for i, s in enumerate(input.type.shape) if i not in axis
+            )
         else:
-            op = self
+            out_shape = ()
 
-        shape = [x for i, x in enumerate(input.type.shape) if i not in axis]
+        # `axis` has already been normalized above, so the negativity check
+        # must use the original `self.axis` when deciding whether a fresh
+        # `Op` carrying the normalized axis is needed.
+        if (
+            (self.axis is not None and any(a < 0 for a in self.axis))
+            or dtype != self.dtype
+            or acc_dtype != self.acc_dtype
+        ):
+            op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
+        else:
+            op = self
 
-        output = TensorType(
-            dtype=self._output_dtype(input.type.dtype),
-            shape=shape,
-        )()
+        output = TensorType(dtype=dtype, shape=out_shape)()
 
         return Apply(op, [input], [output])
 
-    def __getstate__(self):
-        d = copy(self.__dict__)
-        d.pop("ufunc", None)
-        return d
+    def clone(
+        self,
+        axis=None,
+        dtype=None,
+        acc_dtype=None,
+        upcast_discrete_output=None,
+        **kwargs,
+    ):
+        if axis is None:
+            axis = self.axis
+        if dtype is None:
+            dtype = self.dtype
+        if acc_dtype is None:
+            acc_dtype = self.acc_dtype
+        if upcast_discrete_output is None:
+            upcast_discrete_output = self.upcast_discrete_output
 
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        self.set_ufunc(self.scalar_op)
+        res = type(self)(
+            self.scalar_op,
+            axis=axis,
+            dtype=dtype,
+            acc_dtype=acc_dtype,
+            # Forward the resolved flag; passing ``None`` here would silently
+            # drop `upcast_discrete_output` on every clone.
+            upcast_discrete_output=upcast_discrete_output,
+            **kwargs,
+        )
+
+        return res
 
     def __str__(self):
         prefix = f"{type(self).__name__}{{{self.scalar_op}}}"
+        extra_params = []
+
         if self.axis is not None:
-            axes_str = ", ".join(str(x) for x in self.axis)
-            return f"{prefix}{{{axes_str}}}"
+            axis = ", ".join(str(x) for x in self.axis)
+            extra_params.append(f"axis=[{axis}]")
+
+        if self.acc_dtype:
+            extra_params.append(f"acc_dtype={self.acc_dtype}")
+
+        extra_params_str = ", ".join(extra_params)
+
+        if extra_params_str:
+            return f"{prefix}{{{extra_params_str}}}"
         else:
             return f"{prefix}"
 
@@ -1358,31 +1501,21 @@ def perform(self, node, inp, out):
         (input,) = inp
         (output,) = out
         axis = self.axis
-        if axis is None:
-            axis = list(range(input.ndim))
 
-        if hasattr(self, "acc_dtype") and self.acc_dtype is not None:
+        out_dtype = node.outputs[0].type.dtype
+
+        if self.acc_dtype is not None:
             acc_dtype = self.acc_dtype
         else:
-            acc_dtype = node.outputs[0].type.dtype
-
-        variable = np.array(input, dtype=acc_dtype)
-
-        if axis:
-            # Reducing functions built using np.frompyfunc() do not
-            # support reduction along multiple axes. Hence loop through
-            # each, otherwise numpy's inbuilt reduction functions
-            # support reduction along multiple axes directly.
- if self.ufunc_is_vectorized: - to_reduce = reversed(sorted(axis)) - for dimension in to_reduce: - variable = self.ufunc.reduce(variable, dimension, dtype=acc_dtype) - else: - variable = self.ufunc.reduce(variable, axis=tuple(axis)) - output[0] = _asarray(variable, dtype=node.outputs[0].type.dtype) - else: - # Force a copy - output[0] = np.array(variable, copy=True, dtype=node.outputs[0].type.dtype) + acc_dtype = out_dtype + + # out_dtype = self.dtype if self.dtype and self.dtype != "OLD" else out_dtype + + input = np.array(input, dtype=acc_dtype) + + out = self.ufunc.reduce(input, axis=axis, dtype=acc_dtype) + + output[0] = _asarray(out, dtype=out_dtype) def infer_shape(self, fgraph, node, shapes): (ishape,) = shapes @@ -1588,176 +1721,6 @@ def c_code_cache_version_apply(self, node): return () -class CAReduceDtype(CAReduce): - """A subclass of `CAReduce` that accepts an additional output "dtype" parameter. - - It also accepts an optional `acc_dtype`, which specifies the dtype that - will be used for the accumulation. The accumulation will be done using an - array of dtype `acc_dtype`, then it will be cast into `dtype` and returned. - - If no `dtype` is provided, one will be inferred so as not to lose - too much precision. - - """ - - __props__: Union[Tuple[str, str, str], Tuple[str, str, str, str]] = ( - "scalar_op", - "axis", - "dtype", - "acc_dtype", - ) - - def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None): - """ - - Parameters - ---------- - scalar_op - A binary scalar `Op` with only one output. - It must be commutative and associative. - axis - * the dimension along which we want to reduce - * list of dimensions that we want to reduce - * if ``None``, all dimensions are reduced - dtype - The dtype of the returned tensor. If ``None``, then we use the default - dtype which is the same as the input array's dtype except when: - - * the input dtype is a signed integer of precision < 64 bit, in which - case we use int64 - * the input dtype is an unsigned integer of precision < 64 bit, in - which case we use uint64 - - This default dtype does _not_ depend on the value of `acc_dtype`. - This behavior is similar in spirit to that of NumPy, except that - NumPy uses the default machine integer while we always use 64 bit - integers to avoid platform-dependent behavior. - acc_dtype - The dtype of the internal accumulator. - If ``None`` (default), we use the dtype in the list below, - or the input dtype if its precision is higher: - - * for int dtypes, we use at least int64; - * for uint dtypes, we use at least uint64; - * for float dtypes, we use at least float64; - * for complex dtypes, we use at least complex128. - - """ - super().__init__(scalar_op, axis=axis) - self.dtype = dtype - self.acc_dtype = acc_dtype - - def __setstate__(self, d): - super().__setstate__(d) - if not hasattr(self, "dtype"): - # This is needed as old pickled will crash otherwise. - # We need to keep the old dtype behavior as the op - # could be in an apply node with a specified dtype. - self.dtype = "OLD" - - if not hasattr(self, "acc_dtype"): - # acc_dtype is not used by any external Op, so we do not - # need to keep the previous behaviour here. 
- self.acc_dtype = None - - def _output_dtype(self, idtype): - dtype = self.dtype - if dtype == "OLD": - return dict( - int8="int32", - int16="int32", - int32="int64", - uint8="uint32", - uint16="uint32", - uint32="uint64", - ).get(idtype, idtype) - if dtype is None: - # If input has a discrete dtype, upcast it to 64 - return dict( - bool="int64", - int8="int64", - int16="int64", - int32="int64", - uint8="uint64", - uint16="uint64", - uint32="uint64", - ).get(idtype, idtype) - else: - # The important is that the accumulator dtype does not - # lose precision. Then, the result can be downcasted. - return dtype - - def _acc_dtype(self, idtype): - acc_dtype = self.acc_dtype - if acc_dtype is None: - return dict( - bool="int64", - int8="int64", - int16="int64", - int32="int64", - uint8="uint64", - uint16="uint64", - uint32="uint64", - float16="float32", - float32="float64", - complex64="complex128", - ).get(idtype, idtype) - elif acc_dtype in continuous_dtypes and idtype in discrete_dtypes: - # Specifying a continuous accumulator for discrete input is OK - return acc_dtype - else: - # The conversion has to be considered an upcast. - upcasted_dtype = upcast(idtype, acc_dtype) - if acc_dtype != upcasted_dtype: - raise TypeError( - f"Cannot build {self} node with input dtype {idtype} " - f"and acc_dtype {acc_dtype}, as precision would be lost. " - "To correct this error, you can:\n" - " - not specify acc_dtype, or\n" - f" - use an acc_dtype at least as precise as {upcasted_dtype}.\n" - ' - specify "dtype" instead of "acc_dtype", so ' - "the reduction will be precise, but the result will " - 'be casted into "dtype" at the end.\n' - "If you are expecting the precision loss, you can " - f'use tensor.cast(..., dtype="{acc_dtype}"), on your input.' - ) - return acc_dtype - - def make_node(self, input): - # We need to redefine make_node so that, if self.dtype is None, - # we can infer what dtype should be, and create a node from an Op - # of the appropriate dtype. - input = as_tensor_variable(input) - dtype = self._output_dtype(input.dtype) - acc_dtype = self._acc_dtype(input.dtype) - - assert dtype is not None - assert acc_dtype is not None - - if dtype == self.dtype and acc_dtype == self.acc_dtype: - # Don't build another instance - op = self - else: - op = copy(self) - op.set_ufunc(self.scalar_op) - op.dtype = dtype - op.acc_dtype = acc_dtype - - assert op.acc_dtype is not None - - # TODO: Why doesn't `make_node` just take these - # automatically-determined values as arguments? - return super(CAReduceDtype, op).make_node(input) - - def __str__(self): - prefix = f"{type(self).__name__}{{{self.scalar_op}}}" - if self.axis is not None: - axis = ", ".join(str(x) for x in self.axis) - return f"{prefix}{{axis=[{axis}], acc_dtype={self.acc_dtype}}}" - else: - return f"{prefix}{{acc_dtype={self.acc_dtype}}}" - - def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): """Replace a symbol definition with an `Elemwise`-wrapped version of the corresponding scalar `Op`. 
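
Note: the removal of `CAReduceDtype` above folds its dtype logic into `CAReduce` itself, with the old upcasting default now opt-in via `upcast_discrete_output`. A small sketch of the resulting dtype rules, assuming the `_output_dtype` behavior shown above (`vector` is the usual PyTensor type constructor):

    import pytensor.scalar.basic as aes
    from pytensor.tensor.elemwise import CAReduce
    from pytensor.tensor.type import vector

    x = vector("x", dtype="int8")

    # By default the output keeps the input dtype...
    assert CAReduce(aes.add)(x).type.dtype == "int8"

    # ...while `upcast_discrete_output=True` restores the old
    # `CAReduceDtype` default of upcasting discrete inputs to 64 bit.
    assert CAReduce(aes.add, upcast_discrete_output=True)(x).type.dtype == "int64"

    # An explicit `dtype` still wins; accumulation happens in `acc_dtype`.
    assert (
        CAReduce(aes.add, dtype="int32", upcast_discrete_output=True)(x).type.dtype
        == "int32"
    )
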
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 69101a4450..eb609851e7 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -25,13 +25,7 @@ stack, switch, ) -from pytensor.tensor.elemwise import ( - CAReduce, - CAReduceDtype, - DimShuffle, - Elemwise, - scalar_elemwise, -) +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.shape import shape, specify_broadcastable from pytensor.tensor.type import ( DenseTensorType, @@ -633,6 +627,10 @@ class Max(NonZeroCAReduce): def __init__(self, axis): super().__init__(aes.scalar_maximum, axis) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + return type(self)(axis=axis) + class Min(NonZeroCAReduce): nfunc_spec = ("min", 1, 1) @@ -640,6 +638,10 @@ class Min(NonZeroCAReduce): def __init__(self, axis): super().__init__(aes.scalar_minimum, axis) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + return type(self)(axis=axis) + def max(x, axis=None, keepdims=False): """ @@ -1530,6 +1532,10 @@ def c_code(self, node, name, inames, onames, sub): """ ) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + return type(self)(axis=axis) + # TODO: implement the grad. When done and tested, you can make this the default # version. @@ -2350,7 +2356,6 @@ class All(CAReduce): """ - __props__ = ("axis",) nfunc_spec = ("all", 1, 1) def __init__(self, axis=None): @@ -2376,6 +2381,10 @@ def grad(self, inp, grads): (x,) = inp return [x.zeros_like(config.floatX)] + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + return type(self)(axis=axis) + class Any(CAReduce): """Applies `bitwise or` to all the values of a tensor along the @@ -2383,7 +2392,6 @@ class Any(CAReduce): """ - __props__ = ("axis",) nfunc_spec = ("any", 1, 1) def __init__(self, axis=None): @@ -2409,48 +2417,31 @@ def grad(self, inp, grads): (x,) = inp return [x.zeros_like(config.floatX)] + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + return type(self)(axis=axis) -class Sum(CAReduceDtype): + +class Sum(CAReduce): """ Sums all the values of a tensor along the specified axis(es). - Equivalent to `CAReduceDtype(scalar.add, axis=axis, dtype=dtype)`, + Equivalent to `CAReduce(scalar.add, axis=axis, dtype=dtype)`, with the difference that this defines the gradient of sum wrt its tensor input. - Parameters - ---------- - axis - Axis(es) along which the tensor should be summed - (use None to sum over all axes, and a list or tuple to sum along more - than one axis). - - dtype - The dtype of the internal accumulator and returned - tensor. If None, then we use the default dtype which is the same as the - input tensor's dtype except when: - - the input dtype is a signed integer of precision < 64 bit, in - which case we use int64 - - the input dtype is an unsigned integer of precision < 64 bit, in - which case we use uint64 - This value does not depend on the value of "acc_dtype". - - acc_dtype - The dtype of the internal accumulator. - If None (default), we use the dtype in the list below, - or the input dtype if its precision is higher: - - for int dtypes, we use at least int64; - - for uint dtypes, we use at least uint64; - - for float dtypes, we use at least float64; - - for complex dtypes, we use at least complex128. 
- """ - __props__ = ("axis", "dtype", "acc_dtype") nfunc_spec = ("sum", 1, 1) def __init__(self, axis=None, dtype=None, acc_dtype=None): - super().__init__(aes.add, axis=axis, dtype=dtype, acc_dtype=acc_dtype) + super().__init__( + aes.add, + axis=axis, + dtype=dtype, + acc_dtype=acc_dtype, + upcast_discrete_output=True, + ) def __str__(self): name = self.__class__.__name__ @@ -2492,6 +2483,12 @@ def R_op(self, inputs, eval_points): return [None] return self(*eval_points, return_list=True) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + dtype = kwargs.get("dtype", self.dtype) + acc_dtype = kwargs.get("acc_dtype", self.acc_dtype) + return type(self)(axis=axis, dtype=dtype, acc_dtype=acc_dtype) + def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): """ @@ -2523,7 +2520,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"])) -class Prod(CAReduceDtype): +class Prod(CAReduce): """ Multiplies all the values of a tensor along the specified axis(es). @@ -2533,19 +2530,20 @@ class Prod(CAReduceDtype): """ - __props__ = ("axis", "dtype", "acc_dtype") + __props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input") + nfunc_spec = ("prod", 1, 1) def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False): - super().__init__(aes.mul, axis=axis, dtype=dtype, acc_dtype=acc_dtype) + super().__init__( + aes.mul, + axis=axis, + dtype=dtype, + acc_dtype=acc_dtype, + upcast_discrete_output=True, + ) self.no_zeros_in_input = no_zeros_in_input - def __setstate__(self, dct): - super().__setstate__(dct) - # Add default value to be able to reload old pickled objects. - if "no_zeros_in_input" not in dct: - self.no_zeros_in_input = False - def L_op(self, inp, out, grads): """ The grad of this Op could be very easy, if it is was not for the case @@ -2668,6 +2666,18 @@ def L_op(self, inp, out, grads): def c_code_cache_version(self): return (1,) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + dtype = kwargs.get("dtype", self.dtype) + acc_dtype = kwargs.get("acc_dtype", self.acc_dtype) + no_zeros_in_input = kwargs.get("no_zeros_in_input", self.no_zeros_in_input) + return type(self)( + axis=axis, + dtype=dtype, + acc_dtype=acc_dtype, + no_zeros_in_input=no_zeros_in_input, + ) + def prod( input, @@ -2736,12 +2746,15 @@ def c_code_cache_version(self): mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros") -class ProdWithoutZeros(CAReduceDtype): - - __props__ = ("axis", "dtype", "acc_dtype") - +class ProdWithoutZeros(CAReduce): def __init__(self, axis=None, dtype=None, acc_dtype=None): - super().__init__(mul_without_zeros, axis=axis, dtype=dtype, acc_dtype=acc_dtype) + super().__init__( + mul_without_zeros, + axis=axis, + dtype=dtype, + acc_dtype=acc_dtype, + upcast_discrete_output=True, + ) def grad(self, inp, grads): from pytensor.gradient import grad_not_implemented @@ -2757,6 +2770,12 @@ def grad(self, inp, grads): ) return [a_grad] + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + dtype = kwargs.get("dtype", self.dtype) + acc_dtype = kwargs.get("acc_dtype", self.acc_dtype) + return type(self)(axis=axis, dtype=dtype, acc_dtype=acc_dtype) + def any(x, axis=None, keepdims=False): out = Any(axis)(x) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index a5a86905ec..127530bf42 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ 
b/pytensor/tensor/rewriting/elemwise.py @@ -7,6 +7,7 @@ import pytensor import pytensor.scalar.basic as aes from pytensor import compile +from pytensor.compile.mode import get_target_language from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Constant, io_toposort from pytensor.graph.features import ReplaceValidate @@ -14,12 +15,13 @@ from pytensor.graph.rewriting.basic import ( GraphRewriter, copy_stack_trace, + in2out, node_rewriter, ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value -from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.shape import shape_padleft @@ -948,3 +950,82 @@ def local_useless_composite(fgraph, node): c = aes.Composite(inputs=comp.inputs, outputs=new_outputs) e = Elemwise(scalar_op=c)(*node.inputs, return_list=True) return dict(zip([node.outputs[i] for i in idx], e)) + + +@node_rewriter([CAReduce]) +def local_careduce_fusion(fgraph, node): + """Fuse a `CAReduce` applied to an `Elemwise`.""" + + (car_input,) = node.inputs + elm_node = car_input.owner + + if elm_node is None or not isinstance(elm_node.op, Elemwise): + return False + + elm_inputs = elm_node.inputs + elm_outputs = elm_node.outputs + + if len(elm_inputs) > 1 or len(elm_outputs) > 1: + # TODO: Implement the multiple inputs case + return False + + if len(fgraph.clients[elm_outputs[0]]) > 1: + return False + + # Don't form the fusion when the target language is Python + elm_scalar_op = elm_node.op.scalar_op + car_scalar_op = node.op.scalar_op + + if get_target_language() == ("py",): + return False + + try: + elm_scalar_op.c_code( + elm_node, + "test_presence_of_c_code", + ["x" for x in elm_inputs], + ["z" for z in elm_outputs], + {"fail": "%(fail)s"}, + ) + + car_scalar_op.c_code( + node, + "test_presence_of_c_code", + ["x" for x in node.inputs], + ["z" for z in node.outputs], + {"fail": "%(fail)s"}, + ) + except (NotImplementedError, MethodNotDefined): + return False + + car_axis = node.op.axis + + scalar_elm_inputs = [ + aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs + ] + elm_output = elm_scalar_op(*scalar_elm_inputs) + # This input represents the previous value in the `CAReduce` binary reduction + carried_car_input = elm_output.type() + scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)] + + fused_scalar_op = aes.Composite( + inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs + ) + + # The fused `Op` needs to look and behave like a `BinaryScalarOp` + # TODO: Generate a new `type` and make this relationship official? 
+ fused_scalar_op.identity = car_scalar_op.identity + fused_scalar_op.nin = 2 + fused_scalar_op.nout = 1 + + new_car_op = CAReduce(fused_scalar_op, car_axis) + + return [new_car_op(*elm_inputs)] + + +compile.optdb.register( # type: ignore + "local_careduce_fusion", + in2out(local_careduce_fusion), + "fusion", + position=49, +) diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index 772d864394..c965087ea2 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -1,9 +1,20 @@ +import copy + +import pytest + from pytensor.compile.function import function -from pytensor.compile.mode import AddFeatureOptimizer, Mode +from pytensor.compile.mode import ( + AddFeatureOptimizer, + Mode, + get_default_mode, + get_target_language, +) +from pytensor.configdefaults import config from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB +from pytensor.link.basic import LocalLinker from pytensor.tensor.math import dot, tanh -from pytensor.tensor.type import matrix +from pytensor.tensor.type import matrix, vector def test_Mode_basic(): @@ -48,3 +59,86 @@ def test_including(): new_mode = mode.including("fast_compile") assert set(new_mode._optimizer.include) == {"merge", "fast_compile"} + + +class TestBunchOfModes: + def test_modes(self): + # this is a quick test after the LazyLinker branch merge + # to check that all the current modes can still be used. + linker_classes_involved = [] + + predef_modes = ["FAST_COMPILE", "FAST_RUN", "DEBUG_MODE"] + + # Linkers to use with regular Mode + if config.cxx: + linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "cvm", "cvm_nogc"] + else: + linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc"] + modes = predef_modes + [Mode(linker, "fast_run") for linker in linkers] + + for mode in modes: + x = matrix() + y = vector() + f = function([x, y], x + y, mode=mode) + # test that it runs something + f([[1, 2], [3, 4]], [5, 6]) + linker_classes_involved.append(f.maker.mode.linker.__class__) + # print 'MODE:', mode, f.maker.mode.linker, 'stop' + + # regression check: + # there should be + # - `VMLinker` + # - OpWiseCLinker (FAST_RUN) + # - PerformLinker (FAST_COMPILE) + # - DebugMode's Linker (DEBUG_MODE) + assert 4 == len(set(linker_classes_involved)) + + +class TestOldModesProblem: + def test_modes(self): + # Then, build a mode with the same linker, and a modified optimizer + default_mode = get_default_mode() + modified_mode = default_mode.including("specialize") + + # The following line used to fail, with Python 2.4, in July 2012, + # because an fgraph was associated to the default linker + copy.deepcopy(modified_mode) + + # More straightforward test + linker = get_default_mode().linker + assert not hasattr(linker, "fgraph") or linker.fgraph is None + + +def test_get_target_language(): + with config.change_flags(mode=Mode(linker="py")): + res = get_target_language() + assert res == ("py",) + + res = get_target_language(Mode(linker="py")) + assert res == ("py",) + + res = get_target_language(Mode(linker="c")) + assert res == ("c",) + + res = get_target_language(Mode(linker="c|py")) + assert res == ("c", "py") + + res = get_target_language(Mode(linker="vm")) + assert res == ("c", "py") + + with config.change_flags(cxx=""): + res = get_target_language(Mode(linker="vm")) + assert res == ("py",) + + res = get_target_language(Mode(linker="jax")) + assert res == ("jax",) + + res = get_target_language(Mode(linker="numba")) + assert res == ("numba",) + + class 
MyLinker(LocalLinker): + pass + + test_mode = Mode(linker=MyLinker()) + with pytest.raises(Exception): + get_target_language(test_mode) diff --git a/tests/compile/test_modes.py b/tests/compile/test_modes.py deleted file mode 100644 index 843e6f5536..0000000000 --- a/tests/compile/test_modes.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Test compilation modes -""" - -import copy - -from pytensor.compile.function import function -from pytensor.compile.mode import Mode, get_default_mode -from pytensor.configdefaults import config -from pytensor.tensor.type import matrix, vector - - -class TestBunchOfModes: - def test_modes(self): - # this is a quick test after the LazyLinker branch merge - # to check that all the current modes can still be used. - linker_classes_involved = [] - - predef_modes = ["FAST_COMPILE", "FAST_RUN", "DEBUG_MODE"] - - # Linkers to use with regular Mode - if config.cxx: - linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "cvm", "cvm_nogc"] - else: - linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc"] - modes = predef_modes + [Mode(linker, "fast_run") for linker in linkers] - - for mode in modes: - x = matrix() - y = vector() - f = function([x, y], x + y, mode=mode) - # test that it runs something - f([[1, 2], [3, 4]], [5, 6]) - linker_classes_involved.append(f.maker.mode.linker.__class__) - # print 'MODE:', mode, f.maker.mode.linker, 'stop' - - # regression check: - # there should be - # - `VMLinker` - # - OpWiseCLinker (FAST_RUN) - # - PerformLinker (FAST_COMPILE) - # - DebugMode's Linker (DEBUG_MODE) - assert 4 == len(set(linker_classes_involved)) - - -class TestOldModesProblem: - def test_modes(self): - # Then, build a mode with the same linker, and a modified optimizer - default_mode = get_default_mode() - modified_mode = default_mode.including("specialize") - - # The following line used to fail, with Python 2.4, in July 2012, - # because an fgraph was associated to the default linker - copy.deepcopy(modified_mode) - - # More straightforward test - linker = get_default_mode().linker - assert not hasattr(linker, "fgraph") or linker.fgraph is None diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index d6e27d9784..02dc2001f8 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -2,6 +2,7 @@ import pytest import pytensor +import pytensor.tensor as at import tests.unittest_tools as utt from pytensor.compile.mode import Mode from pytensor.graph.fg import FunctionGraph @@ -130,11 +131,16 @@ def test_flatten(self): def test_with_constants(self): x, y, z = floats("xyz") e = mul(add(70.0, y), true_div(x, y)) - C = Composite([x, y], [e]) - c = C.make_node(x, y) - assert "70.0" in c.op.c_code(c, "dummy", ["x", "y"], ["z"], dict(id=0)) - # print c.c_code(['x', 'y'], ['z'], dict(id = 0)) - g = FunctionGraph([x, y], [c.out]) + comp_op = Composite([x, y], [e]) + comp_node = comp_op.make_node(x, y) + + c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0)) + assert "70.0" in c_code + + # Make sure caching of the c_code template works + assert hasattr(comp_node.op, "_c_code") + + g = FunctionGraph([x, y], [comp_node.out]) fn = make_function(DualLinker().accept(g)) assert fn(1.0, 2.0) == 36.0 @@ -174,24 +180,35 @@ def test_composite_printing(self): "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" ) - def test_make_node_continue_graph(self): - # This is a test for a bug (now fixed) that disabled the - # local_gpu_elemwise_0 optimization and printed an - # optimization warning on the terminal. 
- - # We test that Composite.make_node accept as inputs Variable - # some that represent existing computation. - - si0 = pytensor.scalar.int8() - si1 = pytensor.scalar.int8() - si2 = pytensor.scalar.float32() - sout = (si0 * si1) / si2 - sop = pytensor.scalar.Composite([si0, si1, si2], [sout]) - si0 = pytensor.scalar.int8() - si1 = pytensor.scalar.int8() - si2 = pytensor.scalar.float32() - si3 = pytensor.scalar.float32() - sop.make_node(si0 * si3, si1, si2) + def test_non_scalar_error(self): + x = float32("x") + comp_op = Composite([x], [(at.zeros((2,)) + x).sum()]) + + with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"): + comp_op.fgraph + + def test_multi_out_perform(self): + from pytensor.graph.basic import Apply + from pytensor.scalar.basic import ScalarOp + + class MultiOutOp(ScalarOp): + def make_node(self, x): + return Apply(self, [x], [x.type(), x.type()]) + + def perform(self, node, inputs, outputs): + outputs[1][0] = outputs[0][0] = inputs[0] + + def c_code(self, *args): + return "dummy" + + x = float32("x") + comp_op = Composite([x], MultiOutOp()(x)) + + y, z = comp_op(x) + + fn = pytensor.function([x], [y, z], mode=Mode("py", None)) + + assert fn(1.0) == [1.0, 1.0] class TestLogical: diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 60b78b9716..4c02cf0956 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1113,6 +1113,86 @@ def test_test_values(self, test_value): f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] ) + @pytest.mark.parametrize("linker", ["cvm", "py"]) + @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)]) + def test_CAReduce_single_input(self, linker, axis): + """Make sure that `CAReduce` and `Elemwise` fusions work with a single input.""" + + mode = Mode(linker=linker) + mode._optimizer = mode._optimizer.including( + "local_careduce_fusion", + "canonicalize", + "inplace", + ) + + x = tensor("floatX", shape=(None, None, None), name="x") + out = exp(x).sum(axis=axis) + + out_fn = function([x], out, mode=mode) + + if linker != "py": + (out_node,) = out_fn.maker.fgraph.toposort() + assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite) + + rng = np.random.default_rng(2320) + x_val = rng.random((4, 3, 2), dtype=config.floatX) + + exp_res = np.exp(x_val).sum(axis=axis) + + out_val = out_fn(x_val) + assert out_val.shape == exp_res.shape + assert np.allclose(out_val, exp_res) + else: + out_nodes = out_fn.maker.fgraph.toposort() + assert not any( + isinstance(out_node.op.scalar_op, aes.basic.Composite) + for out_node in out_nodes + if hasattr(out_node.op, "scalar_op") + ) + + # `Elemwise`s with more than one client shouldn't be rewritten + x = tensor("floatX", shape=(None, None, None), name="x") + exp_x = exp(x) + out = exp_x.sum(axis=axis) + exp(x) + + out_fn = function([x], out, mode=mode) + out_nodes = out_fn.maker.fgraph.toposort() + assert not any( + isinstance(out_node.op.scalar_op, aes.basic.Composite) + for out_node in out_nodes + if hasattr(out_node.op, "scalar_op") + ) + + @pytest.mark.xfail(reason="Not implemented") + @pytest.mark.parametrize("linker", ["cvm", "py"]) + @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)]) + def test_CAReduce_multiple_inputs(self, linker, axis): + """Make sure that `CAReduce` and `Elemwise` fusions work with multiple inputs.""" + + mode = Mode(linker=linker) + mode._optimizer = mode._optimizer.including( + "local_careduce_fusion", + "canonicalize", + "inplace", + ) + 
+ x = tensor("floatX", shape=(None, None, None), name="x") + y = tensor("floatX", shape=(None, None, None), name="y") + out = (x + y).sum(axis=axis) + + out_fn = function([x, y], out, mode=mode) + (out_node,) = out_fn.maker.fgraph.toposort() + + assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite) + + rng = np.random.default_rng(2320) + x_val = rng.random((4, 3, 2), dtype=config.floatX) + y_val = rng.random((4, 3, 2), dtype=config.floatX) + exp_res = (x_val + y_val).sum(axis=axis) + out_val = out_fn(x_val, y_val) + assert out_val.shape == exp_res.shape + assert np.allclose(out_val, exp_res) + class TimesN(aes.basic.UnaryScalarOp): """ diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index c727b44f63..7b13ba682a 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -17,7 +17,7 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import second -from pytensor.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import ShapeError from pytensor.tensor.math import all as at_all from pytensor.tensor.math import any as at_any @@ -537,24 +537,16 @@ def with_mode( for axis in reversed(sorted(tosum)): zv = np.bitwise_xor.reduce(zv, axis) else: - raise Exception( + raise NotImplementedError( f"Test for CAReduce with scalar_op {scalar_op} not implemented" ) if test_nan: - try: - assert self.type.values_eq(f(xv), zv), (f(xv), zv) - except NotImplementedError: - # GpuCAReduce don't implement all cases when size is 0 - assert xv.size == 0 + assert self.type.values_eq(f(xv), zv), (f(xv), zv) else: - try: - f_xv = f(xv) - assert f_xv.shape == zv.shape, (f_xv, zv) - utt.assert_allclose(zv, f_xv) - except NotImplementedError: - # GpuCAReduce don't implement all cases when size is 0 - assert xv.size == 0 + f_xv = f(xv) + assert f_xv.shape == zv.shape, (f_xv, zv) + utt.assert_allclose(zv, f_xv) x = self.type( dtype, shape=tuple(entry if entry == 1 else None for entry in xsh) @@ -570,11 +562,7 @@ def with_mode( scalar_op in [aes.scalar_maximum, aes.scalar_minimum] and (xsh == () or np.prod(xsh) == 0) ): - try: - assert all(f(xv) == zv.shape) - except NotImplementedError: - # GpuCAReduce don't implement all cases when size is 0 - assert xv.size == 0 + assert all(f(xv) == zv.shape) def test_perform_noopt(self): self.with_mode(Mode(linker="py", optimizer=None), aes.add, dtype="floatX") @@ -691,12 +679,12 @@ def test_str(self): op = CAReduce(aes.add, axis=None) assert str(op) == "CAReduce{add}" op = CAReduce(aes.add, axis=(1,)) - assert str(op) == "CAReduce{add}{1}" + assert str(op) == "CAReduce{add}{axis=[1]}" - op = CAReduceDtype(aes.add, axis=None, acc_dtype="float64") - assert str(op) == "CAReduceDtype{add}{acc_dtype=float64}" - op = CAReduceDtype(aes.add, axis=(1,), acc_dtype="float64") - assert str(op) == "CAReduceDtype{add}{axis=[1], acc_dtype=float64}" + op = CAReduce(aes.add, axis=None, acc_dtype="float64") + assert str(op) == "CAReduce{add}{acc_dtype=float64}" + op = CAReduce(aes.add, axis=(1,), acc_dtype="float64") + assert str(op) == "CAReduce{add}{axis=[1], acc_dtype=float64}" def test_repeated_axis(self): x = vector("x")