diff --git a/pytensor/graph/rewriting/db.py b/pytensor/graph/rewriting/db.py index f303c1840e..645faf9911 100644 --- a/pytensor/graph/rewriting/db.py +++ b/pytensor/graph/rewriting/db.py @@ -427,7 +427,7 @@ def query( position_cutoff = tags[0].position_cutoff # The RewriteDatabaseQuery instance might contain extra rewrites which need - # to be added the the sequence of rewrites (don't alter the + # to be added to the sequence of rewrites (don't alter the # original dictionary) if len(tags[0].extra_rewrites) > 0: position_dict = position_dict.copy() diff --git a/pytensor/printing.py b/pytensor/printing.py index 8b24884944..1042a5897a 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -312,7 +312,11 @@ def debugprint( ): if hasattr(var.owner, "op"): - if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars: + if ( + isinstance(var.owner.op, HasInnerGraph) + or hasattr(var.owner.op, "scalar_op") + and isinstance(var.owner.op.scalar_op, HasInnerGraph) + ) and var not in inner_graph_vars: inner_graph_vars.append(var) if print_op_info: op_information.update(op_debug_information(var.owner.op, var.owner)) @@ -355,8 +359,12 @@ def debugprint( inner_inputs = inner_fn.maker.fgraph.inputs inner_outputs = inner_fn.maker.fgraph.outputs else: - inner_inputs = ig_var.owner.op.inner_inputs - inner_outputs = ig_var.owner.op.inner_outputs + if hasattr(ig_var.owner.op, "scalar_op"): + inner_inputs = ig_var.owner.op.scalar_op.inner_inputs + inner_outputs = ig_var.owner.op.scalar_op.inner_outputs + else: + inner_inputs = ig_var.owner.op.inner_inputs + inner_outputs = ig_var.owner.op.inner_outputs outer_inputs = ig_var.owner.inputs @@ -422,8 +430,9 @@ def debugprint( if ( isinstance(getattr(out.owner, "op", None), HasInnerGraph) - and out not in inner_graph_vars - ): + or hasattr(getattr(out.owner, "op", None), "scalar_op") + and isinstance(out.owner.op.scalar_op, HasInnerGraph) + ) and out not in inner_graph_vars: inner_graph_vars.append(out) _debugprint( @@ -664,8 +673,9 @@ def get_id_str( if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"): if ( isinstance(in_var.owner.op, HasInnerGraph) - and in_var not in inner_graph_ops - ): + or hasattr(in_var.owner.op, "scalar_op") + and isinstance(in_var.owner.op.scalar_op, HasInnerGraph) + ) and in_var not in inner_graph_ops: inner_graph_ops.append(in_var) _debugprint( diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index f428f2528b..fee02684fe 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4000,7 +4000,8 @@ class Composite(ScalarOp, HasInnerGraph): init_param: Tuple[str, ...] = ("inputs", "outputs") - def __init__(self, inputs, outputs): + def __init__(self, inputs, outputs, name="Composite"): + self.name = name # We need to clone the graph as sometimes its nodes already # contain a reference to an fgraph. As we want the Composite # to be pickable, we can't have reference to fgraph. @@ -4106,30 +4107,6 @@ def _perform(*inputs, outputs=[[None]]): self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) return self._py_perform_fn - @property - def name(self): - if hasattr(self, "_name"): - return self._name - - # TODO FIXME: Just implement pretty printing for the `Op`; don't do - # this redundant, outside work in the `Op` itself. 
-        for i, r in enumerate(self.fgraph.inputs):
-            r.name = f"i{int(i)}"
-        for i, r in enumerate(self.fgraph.outputs):
-            r.name = f"o{int(i)}"
-        io = set(self.fgraph.inputs + self.fgraph.outputs)
-        for i, r in enumerate(self.fgraph.variables):
-            if r not in io and len(self.fgraph.clients[r]) > 1:
-                r.name = f"t{int(i)}"
-        outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
-        rval = f"Composite{{{outputs_str}}}"
-        self._name = rval
-        return self._name
-
-    @name.setter
-    def name(self, name):
-        self._name = name
-
     @property
     def fgraph(self):
         if hasattr(self, "_fgraph"):
@@ -4146,6 +4123,21 @@ def fgraph(self):
                 "The fgraph to Composite must be exclusively"
                 " composed of ScalarOp instances."
             )
+
+        # Clone identical outputs that have been merged
+        if len(set(fgraph.outputs)) != len(self.outputs):
+            old_outputs = fgraph.outputs
+            new_outputs = []
+            for output in old_outputs:
+                if output not in new_outputs:
+                    new_outputs.append(output)
+                else:
+                    node = output.owner
+                    output_idx = node.outputs.index(output)
+                    new_output = node.clone().outputs[output_idx]
+                    new_outputs.append(new_output)
+            fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
+
         self._fgraph = fgraph
         return self._fgraph

diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index 6733c62a86..b24d9d0341 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -1638,9 +1638,9 @@ def compute_grad_2f1(a, b, c, z, wrt):

         return compute_grad_2f1(a, b, c, z, wrt=wrt)

-    def __call__(self, a, b, c, z, wrt):
+    def __call__(self, a, b, c, z, wrt, **kwargs):
         # This allows wrt to be a keyword argument
-        return super().__call__(a, b, c, z, wrt)
+        return super().__call__(a, b, c, z, wrt, **kwargs)

     def c_code(self, *args, **kwargs):
         raise NotImplementedError()
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 4ac1ffdd33..17d4c5776b 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -652,10 +652,10 @@ def transform(r):

     def prepare_node(self, node, storage_map, compute_map, impl):
         # Postpone the ufunc building to the last minutes due to:
-        # - NumPy ufunc support only up to 31 inputs.
+        # - NumPy ufuncs support only up to 32 operands (inputs and outputs).
         #   But our c code support more.
         # - nfunc is reused for scipy and scipy is optional
-        if len(node.inputs) > 32 and self.ufunc and impl == "py":
+        if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
             impl = "c"

         if getattr(self, "nfunc_spec", None) and impl != "c":
@@ -677,7 +677,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
                 self.nfunc = module

         if (
-            len(node.inputs) < 32
+            (len(node.inputs) + len(node.outputs)) <= 32
             and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
             and self.ufunc is None
             and impl == "py"
@@ -727,28 +727,18 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)

     def perform(self, node, inputs, output_storage):
-        if len(node.inputs) >= 32:
+        if (len(node.inputs) + len(node.outputs)) > 32:
             # Some versions of NumPy will segfault, other will raise a
-            # ValueError, if the number of inputs to a ufunc is 32 or more.
+            # ValueError, if the number of operands in a ufunc exceeds 32.
             # In that case, the C version should be used, or Elemwise fusion
             # should be disabled.
+            # FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage) for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))): if len(set(dim_shapes) - {1}) > 1: raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}") - # Determine the shape of outputs - out_shape = [] - for values in zip(*[input.shape for input in inputs]): - if any(v == 0 for v in values): - # All non-broadcasted dimensions should be zero - assert max(values) <= 1 - out_shape.append(0) - else: - out_shape.append(max(values)) - out_shape = tuple(out_shape) - ufunc_args = inputs ufunc_kwargs = {} # We supported in the past calling manually op.perform. diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e9952a3908..19c5eabd03 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -1,25 +1,27 @@ import sys -import time -from collections import defaultdict -from typing import Optional +from collections import defaultdict, deque +from functools import lru_cache +from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar from warnings import warn import pytensor import pytensor.scalar.basic as aes -from pytensor import compile +from pytensor import clone_replace, compile from pytensor.compile.mode import get_target_language from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Constant, io_toposort +from pytensor.graph import FunctionGraph +from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort from pytensor.graph.features import ReplaceValidate -from pytensor.graph.op import compute_test_value, get_test_value +from pytensor.graph.fg import ApplyOrOutput from pytensor.graph.rewriting.basic import ( + EquilibriumGraphRewriter, GraphRewriter, copy_stack_trace, in2out, node_rewriter, ) from pytensor.graph.rewriting.db import SequenceDB -from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError +from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError @@ -58,6 +60,14 @@ def print_profile(cls, stream, prof, level=0): for n in sorted(ndim.keys()): print(blanc, n, ndim[n], file=stream) + def candidate_input_idxs(self, node): + if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1: + # TODO: Implement specialized InplaceCompositeOptimizer with logic + # needed to correctly assign inplace for multi-output Composites + return [] + else: + return range(len(node.outputs)) + def apply(self, fgraph): r""" @@ -148,7 +158,7 @@ def apply(self, fgraph): baseline = op.inplace_pattern candidate_outputs = [ - i for i in range(len(node.outputs)) if i not in baseline + i for i in self.candidate_input_idxs(node) if i not in baseline ] # node inputs that are Constant, already destroyed, # or fgraph protected inputs and fgraph outputs can't be used as @@ -166,7 +176,7 @@ def apply(self, fgraph): ] else: baseline = [] - candidate_outputs = list(range(len(node.outputs))) + candidate_outputs = self.candidate_input_idxs(node) # node inputs that are Constant, already destroyed, # fgraph protected inputs and fgraph outputs can't be used as inplace # target. 
@@ -529,333 +539,492 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     return rval


-def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
-    r"""Create a recursive function that fuses `Elemwise` `Op`\s.
-
-    The basic idea is that we loop through an `Elemwise` node's inputs, find
-    other `Elemwise` nodes, determine the scalars input types for all of the
-    `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types
-    and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
-    new "fused" `Elemwise`.
+@node_rewriter([Elemwise])
+def local_add_mul_fusion(fgraph, node):
+    """Fuse consecutive add or mul into one such node with more inputs.

-    It's parameterized in order to work for `Elemwise` `Op`\s.
+    It is better to fuse add/mul that way than in a Composite node, as
+    this makes the inner graph of the Composite smaller. This allows
+    putting more computation in a Composite before hitting the max
+    recursion limit when pickling a Composite.

-    Parameters
-    ----------
-    op_class : type
-        `Elemwise` class (the one that we want to fuse)
-    max_input_fct : callable
-        A function that returns the maximum number of inputs that this `Elemwise`
-        can take.
-        On the CPU we limit to 32 input variables since that is the maximum
-        NumPy support.
+    This rewrite is almost useless after the AlgebraicCanonizer is used,
+    but it catches a few edge cases that are not canonicalized by it.
+    """
+    if not isinstance(node.op, Elemwise) or not isinstance(
+        node.op.scalar_op, (aes.Add, aes.Mul)
+    ):
+        return False

-    maker: callable
-        A function with the signature ``(node, *args)`` that constructs an
-        `op_class` instance (e.g. ``op_class(*args)``).
+    s_op = node.op.scalar_op.__class__
+    new_inp = []
+    fused = False
+    nb_inputs = len(node.inputs)
+    max_inputs = float("inf")
+    if hasattr(node.op, "max_inputs"):
+        max_inputs = node.op.max_inputs(node)
+    for inp in node.inputs:
+        if (
+            inp.owner
+            and isinstance(inp.owner.op, Elemwise)
+            and isinstance(inp.owner.op.scalar_op, s_op)
+            and
+            # Do not duplicate the operation.
+            len(fgraph.clients[inp]) == 1
+            and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
+        ):
+            new_inp.extend(inp.owner.inputs)
+            fused = True
+        else:
+            new_inp.append(inp)

-    """
-    if maker is None:
+    # We cannot simply compare the number of inputs, as Mul and Add could have
+    # 0 or 1 inputs in some corner cases.
+    if fused:
+        output = node.op(*new_inp)
+        copy_stack_trace(node.outputs[0], output)

-        def maker(node, scalar_op):
-            return op_class(scalar_op)
+        # Do the recursion here to help lower the number of
+        # FusionOptimizer iterations.
+        if output.owner:
+            output2 = local_add_mul_fusion.transform(fgraph, output.owner)
+            if output2:
+                return output2
+        return [output]

-    def local_fuse(fgraph, node):
-        r"""Fuse `Elemwise` `Op`\s in a node.

-        As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the
-        same shape.
+def elemwise_max_operands_fct(node) -> int:
+    # `Elemwise.perform` uses NumPy ufuncs, which are limited to 32 operands (inputs and outputs)
+    if not config.cxx:
+        return 32
+    return 1024

-        For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
-        compiler do the cast.
-
-        The number of dimensions is validated at call time by PyTensor itself.

+class FusionOptimizer(GraphRewriter):
+    """Graph optimizer that fuses consecutive Elemwise operations."""

-        """
-        # TODO: use broadcast flag?
+ def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) - # TODO: don't do this rewrite as a `NodeRewriter`. - # Analyze the graph in terms of elemwise subgraphs, and then - # replace each subgraph with a Composite version. + @staticmethod + def elemwise_to_scalar(inputs, outputs): + replace_inputs = [(inp, inp.clone()) for inp in inputs] + outputs = clone_replace(outputs, replace=replace_inputs) - # TODO: use malloc and copy to transfer arguments that don't - # fit within the parameter space of 256 bytes - # - # TODO: Merge with multiple output to merge when an inputs - # have multiple clients. This can't be done with a `NodeRewriter` - - # TODO: Related: Support composites with multiple outputs - - # TODO: Use Composite to combine Elemwise and Reduce - # operations. We have to loop over the data anyway... might - # as well sum it up while we're at it (this can be trickier - # than i'm making it seound here. The data-traversal should be - # done contiguously, and the summing-up might not be easy or - # worthwhile if the summation axis doesn't line up with a - # contiguous dimension) - - if type(node.op) is not op_class: - return False - - if len(node.outputs) > 1: - # We don't support fusion for nodes with multiple outputs. - return - - inputs = [] # inputs of the new Elemwise op. - s_inputs = [] # inputs of the new scalar op used by the Composite. - # Inputs of the new scalar op that represents the current node. - s_g = [] - - # There is a hard limit of 256 bytes for the formal argument list to a - # GPU kernel function. - max_nb_input = max_input_fct(node) - # The number of inputs to the new fused op if we do not fuse more - # inputs. - new_nb_input = len(node.inputs) - # Did we fuse something? - # Needed as we can fuse unary op that don't change the number of - # inputs. - # And there is a case where the inputs are the same as the current - # node. That won't change the number of inputs of the new op. - fused = False - - for i in node.inputs: - scalar_node: Optional[Apply] = None - # Will store inputs of the fused node that are not currently inputs - # of the node we want to create (to avoid duplicating inputs). - tmp_input = [] - # Same as tmp_input, but for scalars. - tmp_scalar = [] - - # We should not check the number of inputs here - # As fusing op don't always change the number of input. - # If a variable is used as multiple into to the same node, - # we still want to fusion. So we take the set. - if ( - i.owner - and isinstance(i.owner.op, op_class) - and len({n for n, idx in fgraph.clients[i]}) == 1 - and - # Do not merge elemwise that don't have the same - # broadcastable pattern to don't redo duplicate - # computation due to broadcast. - i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable - ): - try: - tmp_s_input = [] - # we should not put duplicate input into s_inputs and inputs - for ii in i.owner.inputs: - if ii in inputs: - tmp_s_input.append(s_inputs[inputs.index(ii)]) - elif ii in tmp_input: - tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) - else: - tmp = aes.get_scalar_type(ii.type.dtype).make_variable() - - try: - tv = get_test_value(ii) - # Sometimes the original inputs have - # zero-valued shapes in some dimensions, which - # implies that this whole scalar thing doesn't - # make sense (i.e. we're asking for the scalar - # value of an entry in a zero-dimensional - # array). 
- # This will eventually lead to an error in the - # `compute_test_value` call below when/if - # `config.compute_test_value_opt` is enabled - # (for debugging, more or less) - tmp.tag.test_value = tv.item() - except (TestValueError, ValueError): - pass - - tmp_s_input.append(tmp) - tmp_input.append(ii) - tmp_scalar.append(tmp_s_input[-1]) - - # Use the `Op.make_node` interface in case `Op.__call__` - # has been customized - scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input) - - if config.compute_test_value_opt != "off": - # This is required because `Op.make_node` won't do it - compute_test_value(scalar_node) - - # If the scalar_op doesn't have a C implementation, we skip - # its fusion to allow fusion of the other ops - i.owner.op.scalar_op.c_code( - scalar_node, - "test_presence_of_c_code", - ["x" for x in i.owner.inputs], - ["z" for z in i.owner.outputs], - {"fail": "%(fail)s"}, - ) + inputs = [inp for _, inp in replace_inputs] + fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) + middle_inputs = [] - except (NotImplementedError, MethodNotDefined): - warn( - "Rewrite warning: " - f"The Op {i.owner.op.scalar_op} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." + scalar_inputs = [ + aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + ] + middle_scalar_inputs = [] + + for node in fg.toposort(): + node_scalar_inputs = [] + for inp in node.inputs: + if inp in inputs: + node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) + elif inp in middle_inputs: + node_scalar_inputs.append( + middle_scalar_inputs[middle_inputs.index(inp)] ) - scalar_node = None - - # Compute the number of inputs in case we fuse this input. - # We subtract 1 because we replace the existing input with the new - # inputs from `tmp_input`. - new_nb_input_ = new_nb_input + len(tmp_input) - 1 - - # If the new input is already an input of the current node, it was - # already counted when `new_nb_input` was initialized to - # len(node.inputs). - # This can happen when a variable is used both by the Elemwise to - # fuse and the current node. - for x in tmp_input: - if x in node.inputs: - new_nb_input_ -= 1 - - if scalar_node and (new_nb_input_ <= max_nb_input): - fused = True - new_nb_input = new_nb_input_ - inputs.extend(tmp_input) - s_inputs.extend(tmp_scalar) - s_g.extend(scalar_node.outputs) - else: - # We must support the case where the same variable appears many - # times within the inputs - if inputs.count(i) == node.inputs.count(i): - s = s_inputs[inputs.index(i)] else: - s = aes.get_scalar_type(i.type.dtype).make_variable() - if config.compute_test_value_opt != "off": - try: - v = get_test_value(i) - # See the zero-dimensional test value situation - # described above. - s.tag.test_value = v.item() - except (TestValueError, ValueError): - pass - - inputs.append(i) - s_inputs.append(s) - s_g.append(s) - - if not fused: - return False - - if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): - # TODO FIXME: This shouldn't be a generic `Exception` - raise Exception( - "Something has gone wrong with the elemwise fusion rewrite; skipping." 
- ) - - s_new_out = node.op.scalar_op(*s_g, return_list=True) - try: - s_new_out[0].owner.op.c_code( - s_new_out[0].owner, - "test_presence_of_c_code", - ["x" for x in s_g], - ["z" for x in s_new_out], - {"fail": "%(fail)s"}, - ) - except (NotImplementedError, MethodNotDefined): - name = str(s_new_out[0].owner.op) - warn( - "Rewrite warning: " - f"The Op {name} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." - ) - return False - - # create the composite op. - composite_op = aes.Composite(s_inputs, s_new_out) - - # create the new node. - # Do not call make_node to have test_value - new_node = maker(node, composite_op)(*inputs).owner - - assert len(new_node.outputs) == 1 - assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype + new_scalar_input = aes.get_scalar_type( + inp.type.dtype + ).make_variable() + node_scalar_inputs.append(new_scalar_input) + middle_scalar_inputs.append(new_scalar_input) + middle_inputs.append(inp) + + new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) + middle_scalar_inputs.append(new_scalar_node.outputs[0]) + middle_inputs.append(node.outputs[0]) + + scalar_outputs = [ + middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs + ] + return scalar_inputs, scalar_outputs - if len(new_node.inputs) > max_nb_input: - warn( - "Loop fusion failed because the resulting node " - "would exceed the kernel argument limit." - ) - return False - - # we fuse as many that we can at the same time to make debug mode faster - # debug mode will be faster as it won't test all intermediate step. - while True: - ret = local_fuse(fgraph, new_node) - if ret is not False and ret is not None: - assert len(ret) == len(new_node.outputs) - assert len(ret) == 1 - new_node = ret[0].owner - else: - break + def apply(self, fgraph): + nb_replacement = 0 - return new_node.outputs + if fgraph.profile: + validate_before = fgraph.profile.validate_time + callbacks_before = fgraph.execute_callbacks_times.copy() + callback_before = fgraph.execute_callbacks_time - return local_fuse + max_operands = elemwise_max_operands_fct(None) + + def find_next_fuseable_subgraph( + fg: FunctionGraph, + ) -> Generator[Tuple[List[Variable], List[Variable]], None, None]: + """Find all subgraphs in a FunctionGraph that can be fused together + + Yields + ------- + List of inputs and outputs that determine subgraphs which can be fused. + This generator assumes that such subgraph is replaced by a single + Elemwise Composite before being accessed again in the next iteration. + """ + + FUSEABLE_MAPPING = DefaultDict[Variable, List[Apply]] + UNFUSEABLE_MAPPING = DefaultDict[Variable, Set[ApplyOrOutput]] + + def initialize_fuseable_mappings( + *, fg: FunctionGraph + ) -> Tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: + @lru_cache(maxsize=None) + def elemwise_scalar_op_has_c_code(node: Apply) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + else: + warn( + "Optimization Warning: " + f"The Op {node.op.scalar_op} does not provide a C implementation." + " As well as being potentially slow, this also disables " + "loop fusion." + ) + return False + + # Fuseable nodes have to be accessed in a deterministic manner + # to ensure the rewrite remains deterministic. + # This is not a problem from unfuseable ones, as they can never + # become part of the graph. 
+ fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) + unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) + for out, clients in fg.clients.items(): + out_maybe_fuseable = ( + out.owner + and isinstance(out.owner.op, Elemwise) + # and not isinstance(out.owner.op.scalar_op, aes.Composite) + and len(out.owner.outputs) == 1 + and elemwise_scalar_op_has_c_code(out.owner) + ) + for client, _ in clients: + if ( + out_maybe_fuseable + and not isinstance(client, str) # "output" + and isinstance(client.op, Elemwise) + # and not isinstance(client.op.scalar_op, aes.Composite) + and len(client.outputs) == 1 + and out.type.broadcastable + == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ): + if client not in fuseable_clients[out]: + fuseable_clients[out].append(client) + else: + unfuseable_clients[out].add(client) + + return fuseable_clients, unfuseable_clients + + def find_fuseable_subgraph( + *, + fg: FunctionGraph, + visited_nodes: Set[Apply], + fuseable_clients: FUSEABLE_MAPPING, + unfuseable_clients: UNFUSEABLE_MAPPING, + ) -> Tuple[List[Variable], List[Variable]]: + + KT = TypeVar("KT") + VT = TypeVar("VT", list, set) + + def shallow_clone_defaultdict( + d: DefaultDict[KT, VT] + ) -> DefaultDict[KT, VT]: + new_dict: DefaultDict[KT, VT] = defaultdict(d.default_factory) + new_dict.update({k: v.copy() for k, v in d.items()}) + return new_dict + + def variables_depend_on( + variables, depend_on, stop_search_at=None + ) -> bool: + return any( + a in depend_on + for a in ancestors(variables, blockers=stop_search_at) + ) + toposort = fg.toposort() + for starting_node in toposort: + if starting_node in visited_nodes: + continue -def elemwise_max_input_fct(node): - # `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. - if not config.cxx: - return 31 - return 1024 + starting_out = starting_node.outputs[0] + if not fuseable_clients.get(starting_out): + visited_nodes.add(starting_node) + continue + subgraph_inputs: List[Variable] = [] + subgraph_outputs: List[Variable] = [] + unfuseable_clients_subgraph: Set[Variable] = set() -local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) + # Shallow cloning of maps so that they can be manipulated in place + fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) + unfuseable_clients_clone = shallow_clone_defaultdict( + unfuseable_clients + ) + fuseable_nodes_to_visit = deque([starting_node]) + + # We now try to expand as much as possible towards the potentially + # fuseable clients and ancestors to detect the largest possible + # subgraph that can be Composed together into a single `Op`. The + # largest issue to watch out is for cyclical dependencies, where + # some inputs or clients may depend on other nodes of the same + # subgraph via a path that cannot be included in the Composite + # (unfuseable) + while fuseable_nodes_to_visit: + next_node = fuseable_nodes_to_visit.popleft() + visited_nodes.add(next_node) + next_out = next_node.outputs[0] + + # If the output variable of next_node has no fuseable clients + # or has unfuseable clients, then next_node must become an output + # if it is to be fused. + must_become_output = ( + next_out not in fuseable_clients_temp + or next_out in unfuseable_clients_clone + ) -class FusionOptimizer(GraphRewriter): - """Graph rewriter that simply runs node fusion operations. 
+                    # We have backtracked to this node, and it may no longer be a viable output,
+                    # so we remove it and check again as if we had never seen this node
+                    if must_become_output and next_out in subgraph_outputs:
+                        subgraph_outputs.remove(next_out)
+
+                    required_unfuseable_inputs = [
+                        inp
+                        for inp in next_node.inputs
+                        if next_node in unfuseable_clients_clone.get(inp, ())
+                    ]
+                    new_required_unfuseable_inputs = [
+                        inp
+                        for inp in required_unfuseable_inputs
+                        if inp not in subgraph_inputs
+                    ]
+
+                    must_backtrack = False
+                    if new_required_unfuseable_inputs and subgraph_outputs:
+                        # We need to check that any new inputs required by this node
+                        # do not depend on other outputs of the current subgraph,
+                        # via an unfuseable path.
+                        if variables_depend_on(
+                            [next_out],
+                            depend_on=unfuseable_clients_subgraph,
+                            stop_search_at=subgraph_outputs,
+                        ):
+                            must_backtrack = True
+
+                    if not must_backtrack:
+                        implied_unfuseable_clients = {
+                            c
+                            for client in unfuseable_clients_clone.get(next_out, ())
+                            if not isinstance(client, str)  # "output"
+                            for c in client.outputs
+                        }
+
+                        new_implied_unfuseable_clients = (
+                            implied_unfuseable_clients - unfuseable_clients_subgraph
+                        )

-    TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.

+                        if new_implied_unfuseable_clients and subgraph_inputs:
+                            # We need to check that any inputs of the current subgraph
+                            # do not depend on other clients of this node,
+                            # via an unfuseable path.
+                            if variables_depend_on(
+                                subgraph_inputs,
+                                depend_on=new_implied_unfuseable_clients,
+                            ):
+                                must_backtrack = True
+
+                    if must_backtrack:
+                        for inp in next_node.inputs:
+                            if (
+                                inp.owner in visited_nodes
+                                # next_node could have the same input repeated
+                                and next_node in fuseable_clients_temp[inp]
+                            ):
+                                fuseable_clients_temp[inp].remove(next_node)
+                                unfuseable_clients_clone[inp].add(next_node)
+                                # This input must become an output of the subgraph,
+                                # because it can't be merged with next_node.
+                                # We will revisit it to make sure this is safe.
+                                fuseable_nodes_to_visit.appendleft(inp.owner)
+
+                        for client in fuseable_clients_temp[next_out]:
+                            if client in visited_nodes:
+                                fuseable_clients_temp[next_out].remove(client)
+                                unfuseable_clients_clone[next_out].add(client)
+                                # next_out must become an input of the subgraph.
+                                # We will revisit any of its clients currently
+                                # in the subgraph to make sure this is safe.
+                                fuseable_nodes_to_visit.appendleft(client)
+
+                        # Revisit node at a later time
+                        visited_nodes.remove(next_node)
+                        continue
+
+                    # Adding next_node to the subgraph does not result in any
+                    # immediate dependency problems. Update the subgraph
+                    # mappings as if next_node were part of it.
+ # Useless inputs will be removed by the useless Composite rewrite + for inp in new_required_unfuseable_inputs: + if inp not in subgraph_inputs: + subgraph_inputs.append(inp) + + if must_become_output: + subgraph_outputs.append(next_out) + unfuseable_clients_subgraph.update( + new_implied_unfuseable_clients + ) - """ + # Expand through unvisited fuseable ancestors + for inp in sorted( + ( + inp + for inp in next_node.inputs + if ( + inp not in required_unfuseable_inputs + and inp.owner not in visited_nodes + ) + ), + key=lambda inp: toposort.index(inp.owner), + reverse=True, + ): + fuseable_nodes_to_visit.appendleft(inp.owner) + + # Expand through unvisited fuseable clients + for next_node in sorted( + ( + node + for node in fuseable_clients_temp.get(next_out, ()) + if node not in visited_nodes + ), + key=lambda node: toposort.index(node), + ): + fuseable_nodes_to_visit.append(next_node) + + # Don't return if final subgraph is just the original Elemwise + if len(subgraph_outputs) == 1 and set( + subgraph_outputs[0].owner.inputs + ) == set(subgraph_inputs): + # Update global fuseable mappings + # No input was actually fuseable + for inp in starting_node.inputs: + if starting_node in fuseable_clients.get(inp, ()): + fuseable_clients[inp].remove(starting_node) + unfuseable_clients[inp].add(starting_node) + # No client was actually fuseable + unfuseable_clients[starting_out].update( + fuseable_clients.pop(starting_out, ()) + ) + continue - def __init__(self, node_rewriter): - super().__init__() - self.node_rewriter = node_rewriter + return subgraph_inputs, subgraph_outputs + raise ValueError + + def update_fuseable_mappings_after_fg_replace( + *, + fg: FunctionGraph, + visited_nodes: Set[Apply], + fuseable_clients: FUSEABLE_MAPPING, + unfuseable_clients: UNFUSEABLE_MAPPING, + starting_nodes: Set[Apply], + ) -> None: + # Find new composite node and dropped intermediate nodes + # by comparing the current fg.apply nodes with the cached + # original nodes + next_nodes = fg.apply_nodes + (new_composite_node,) = next_nodes - starting_nodes + dropped_nodes = starting_nodes - next_nodes + + # Remove intermediate Composite nodes from mappings + for dropped_node in dropped_nodes: + (dropped_out,) = dropped_node.outputs + fuseable_clients.pop(dropped_out, None) + unfuseable_clients.pop(dropped_out, None) + visited_nodes.remove(dropped_node) + + # Update fuseable information for subgraph inputs + for inp in subgraph_inputs: + if inp in fuseable_clients: + new_fuseable_clients = [ + client + for client in fuseable_clients[inp] + if client not in dropped_nodes + ] + if new_fuseable_clients: + fuseable_clients[inp] = new_fuseable_clients + else: + fuseable_clients.pop(inp) + unfuseable_clients[inp] = ( + unfuseable_clients[inp] - dropped_nodes + ) | {new_composite_node} + + # Update fuseable information for subgraph outputs + for out in new_composite_node.outputs: + unfuseable_clients[out] = {client for client, _ in fg.clients[out]} + + visited_nodes.add(new_composite_node) + return + + # We start by creating two maps, 1) from each node to each potentially + # fuseable client (both nodes must be single output Elemwise with same + # broadcast type) and 2) from each node to each certainly unfuseable + # client (those that don't fit into 1)) + fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) + visited_nodes: Set[Apply] = set() + while True: + starting_nodes = fg.apply_nodes.copy() + try: + subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( + fg=fg, + 
visited_nodes=visited_nodes,
+                        fuseable_clients=fuseable_clients,
+                        unfuseable_clients=unfuseable_clients,
+                    )
+                except ValueError:
+                    return
+                else:
+                    # The caller is now expected to update fg in place,
+                    # by replacing the subgraph with a Composite Op
+                    yield subgraph_inputs, subgraph_outputs
+
+                    # This is where we avoid repeated work by using a stateful
+                    # generator. For large models (as in `TestFusion.test_big_fusion`)
+                    # this can provide huge speedups.
+                    update_fuseable_mappings_after_fg_replace(
+                        fg=fg,
+                        visited_nodes=visited_nodes,
+                        fuseable_clients=fuseable_clients,
+                        unfuseable_clients=unfuseable_clients,
+                        starting_nodes=starting_nodes,
+                    )

-    def add_requirements(self, fgraph):
-        fgraph.attach_feature(ReplaceValidate())
+        for inputs, outputs in find_next_fuseable_subgraph(fgraph):
+            if (len(inputs) + len(outputs)) > max_operands:
+                warn(
+                    "Loop fusion failed because the resulting node would exceed "
+                    "the kernel argument limit."
+                )
+                break

-    def apply(self, fgraph):
-        did_something = True
-        nb_iter = 0
-        nb_replacement = 0
-        nb_inconsistency_replace = 0
-        time_toposort = 0
-        if fgraph.profile:
-            validate_before = fgraph.profile.validate_time
-            callbacks_before = fgraph.execute_callbacks_times.copy()
-            callback_before = fgraph.execute_callbacks_time
-        while did_something:
-            t0 = time.perf_counter()
-            nodelist = list(fgraph.toposort())
-            time_toposort += time.perf_counter() - t0
-            nodelist.reverse()
-            did_something = False
-            for node in nodelist:
-                # Don't try to fuse node that have already been fused.
-                if node in fgraph.apply_nodes:
-                    new_outputs = self.node_rewriter(fgraph, node)
-                    if new_outputs:
-                        assert len(new_outputs) == len(node.outputs)
-                        try:
-                            fgraph.replace_all_validate(
-                                list(zip(node.outputs, new_outputs)),
-                                reason=self.__class__.__name__,
-                            )
-                            did_something = True
-                            nb_replacement += 1
-                        except InconsistencyError:
-                            nb_inconsistency_replace += 1
-            nb_iter += 1
+            scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
+            composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))(
+                *inputs
+            )
+            if not isinstance(composite_outputs, list):
+                composite_outputs = [composite_outputs]
+            for old_out, composite_out in zip(outputs, composite_outputs):
+                if old_out.name:
+                    composite_out.name = old_out.name
+
+            fgraph.replace_all_validate(
+                list(zip(outputs, composite_outputs)),
+                reason=self.__class__.__name__,
+            )
+            nb_replacement += 1

         if fgraph.profile:
             validate_time = fgraph.profile.validate_time - validate_before
@@ -870,21 +1039,22 @@ def apply(self, fgraph):
             validate_time = None
             callback_time = None
             callbacks_time = {}
+
         return (
             self,
-            nb_iter,
+            1,  # nb_iter
             nb_replacement,
-            nb_inconsistency_replace,
+            0,  # nb_inconsistency_replace
             validate_time,
             callback_time,
             callbacks_time,
-            time_toposort,
+            -1,  # toposort_time
         )

-    @classmethod
-    def print_profile(cls, stream, prof, level=0):
+    @staticmethod
+    def print_profile(stream, prof, level=0):
         blanc = "    " * level
-        print(blanc, cls.__name__, file=stream)
+        print(blanc, "FusionOptimizer", file=stream)
         print(blanc, "  nb_iter", prof[1], file=stream)
         print(blanc, "  nb_replacement", prof[2], file=stream)
         print(blanc, "  nb_inconsistency_replace", prof[3], file=stream)
@@ -901,9 +1071,16 @@ def print_profile(cls, stream, prof, level=0):

 if config.tensor__local_elemwise_fusion:
     # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
     fuse_seqopt = SequenceDB()
+    fuse_seqopt.register(
+        "local_add_mul_fusion",
+        EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion],
max_use_ratio=1000), + "fast_run", + "fusion", + position=0, + ) fuse_seqopt.register( "composite_elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), + FusionOptimizer(), "fast_run", "fusion", position=1, @@ -917,35 +1094,38 @@ def print_profile(cls, stream, prof, level=0): "FusionOptimizer", position=49, ) -else: - compile.optdb.register( # type: ignore - "elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), - "fusion", - "local_elemwise_fusion", - "FusionOptimizer", - position=49, - ) @register_canonicalize +@register_specialize @node_rewriter([Elemwise]) def local_useless_composite(fgraph, node): - """For elemwise Composite that have multiple outputs, remove the - outputs that are not used. - - """ + """Remove inputs and outputs of Composite Ops that are not used anywhere.""" if not isinstance(node.op, Elemwise) or not isinstance( node.op.scalar_op, aes.Composite ): return comp = node.op.scalar_op - idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] - if len(idx) < len(node.outputs): - new_outputs = [comp.outputs[i] for i in idx] - c = aes.Composite(inputs=comp.inputs, outputs=new_outputs) - e = Elemwise(scalar_op=c)(*node.inputs, return_list=True) - return dict(zip([node.outputs[i] for i in idx], e)) + used_outputs_idxs = [ + i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern] + ] + used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs] + comp_fgraph = FunctionGraph( + inputs=comp.inputs, outputs=used_inner_outputs, clone=False + ) + used_inputs_idxs = [ + i + for i, i_intern in enumerate(comp_fgraph.inputs) + if comp_fgraph.clients[i_intern] + ] + used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs] + if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len( + node.outputs + ): + used_inputs = [node.inputs[i] for i in used_inputs_idxs] + c = aes.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) + e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) + return dict(zip([node.outputs[i] for i in used_outputs_idxs], e)) @node_rewriter([CAReduce]) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 7cf495b56a..c2efda11d0 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -92,7 +92,6 @@ register_uncanonicalize, register_useless, ) -from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -2966,66 +2965,6 @@ def check_input(inputs): return [ret] -def local_add_mul_fusion(fgraph, node): - """Fuse consecutive add or mul in one such node with more inputs. - - It is better to fuse add/mul that way then in a Composite node as - this make the inner graph of the Composite smaller. This allow to - put more computation in a Composite before hitting the max - recursion limit when pickling Composite. - - """ - if not isinstance(node.op, Elemwise) or not isinstance( - node.op.scalar_op, (aes.Add, aes.Mul) - ): - return False - - s_op = node.op.scalar_op.__class__ - new_inp = [] - fused = False - nb_inputs = len(node.inputs) - max_inputs = float("inf") - if hasattr(node.op, "max_inputs"): - max_inputs = node.op.max_inputs(node) - for inp in node.inputs: - if ( - inp.owner - and isinstance(inp.owner.op, Elemwise) - and isinstance(inp.owner.op.scalar_op, s_op) - and - # Do not duplicate the operation. 
- len(fgraph.clients[inp]) == 1 - and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs - ): - new_inp.extend(inp.owner.inputs) - fused = True - else: - new_inp.append(inp) - - # We can not compare the number of inputs as Mul and Add could have - # 0 or 1 inputs in some corner cases. - if fused: - output = node.op(*new_inp) - copy_stack_trace(node.outputs[0], output) - - # Do the recursion here to help lower the number of - # FusionOptimizer iteration. - if output.owner: - output2 = local_add_mul_fusion(fgraph, output.owner) - if output2: - return output2 - return [output] - - -fuse_seqopt.register( - "local_add_mul_fusion", - FusionOptimizer(local_add_mul_fusion), - "fast_run", - "fusion", - position=0, -) - - def _skip_mul_1(r): if r.owner and r.owner.op == mul: not_is_1 = [i for i in r.owner.inputs if not _is_1(i)] diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 307dcb5572..8800294c74 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py pytensor/tensor/random/op.py pytensor/tensor/random/utils.py pytensor/tensor/rewriting/basic.py -pytensor/tensor/rewriting/elemwise.py pytensor/tensor/shape.py pytensor/tensor/slinalg.py pytensor/tensor/subtensor.py diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index c37ed77dea..d9c4548d24 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -2,7 +2,7 @@ import pytest import pytensor.tensor as at -from pytensor.compile import UnusedInputError +from pytensor.compile import UnusedInputError, get_mode from pytensor.compile.function import function, pfunc from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.io import In @@ -200,7 +200,12 @@ def test_shared_mutable(self): bval = np.arange(5) b.set_value(bval, borrow=True) bval = data_of(b) - f = pfunc([], [b_out], updates=[(b, (b_out + 3))], mode="FAST_RUN") + f = pfunc( + [], + [b_out], + updates=[(b, (b_out + 3))], + mode=get_mode("FAST_RUN").excluding("fusion"), + ) assert (f() == (np.arange(5) * 2)).all() # because of the update assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all() diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index d05b5ab95d..8fbf026e11 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -11,6 +11,7 @@ from pytensor import config, function from pytensor.compile.ops import deep_copy_op from pytensor.compile.sharedvalue import SharedVariable +from pytensor.gradient import grad from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.tensor import elemwise as at_elemwise @@ -548,10 +549,25 @@ def test_logsumexp_benchmark(size, axis, benchmark): rng = np.random.default_rng(23920) X_val = rng.normal(size=size) - X_lse_fn = pytensor.function([X], X_lse, mode="JAX") + X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA") # JIT compile first _ = X_lse_fn(X_val) res = benchmark(X_lse_fn, X_val) exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) np.testing.assert_array_almost_equal(res, exp_res) + + +def test_fused_elemwise_benchmark(benchmark): + rng = np.random.default_rng(123) + size = 100_000 + x = pytensor.shared(rng.normal(size=size), name="x") + mu = pytensor.shared(rng.normal(size=size), name="mu") + + logp = -((x - mu) ** 2) / 2 + grad_logp = grad(logp.sum(), x) + + func = pytensor.function([], [logp, grad_logp], mode="NUMBA") + # 
JIT compile first + func() + benchmark(func) diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 02dc2001f8..c27f220c06 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -156,6 +156,17 @@ def test_many_outputs(self): fn = make_function(DualLinker().accept(g)) assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] + def test_identical_outputs(self): + x, y, z = floats("xyz") + e0 = x + y + z + e1 = x + y + z + e2 = x / y + C = Composite([x, y, z], [e0, e1, e2]) + c = C.make_node(x, y, z) + g = FunctionGraph([x, y, z], c.outputs) + fn = make_function(DualLinker().accept(g)) + assert fn(1.0, 2.0, 3.0) == [6.0, 6.0, 0.5] + def test_composite_printing(self): x, y, z = floats("xyz") e0 = x + y + z @@ -172,12 +183,7 @@ def test_composite_printing(self): make_function(DualLinker().accept(g)) assert str(g) == ( - "FunctionGraph(*1 -> Composite{((i0 + i1) + i2)," - " (i0 + (i1 * i2)), (i0 / i1), " - "(i0 // 5), " - "(-i0), (i0 - i1), ((i0 ** i1) + (-i2))," - " (i0 % 3)}(x, y, z), " - "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" + "FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" ) def test_non_scalar_error(self): diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 2a8adbbe30..7373c78cb4 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -604,31 +604,40 @@ def no_shared_fn(n, x_tm1, M): out = pytensor.function([M], out, updates=updates, mode="FAST_RUN") expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0) - |TensorConstant{20000} [id B] (n_steps) - |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0) - |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0) - | |AllocEmpty{dtype='int64'} [id E] 0 - | | |TensorConstant{20000} [id B] - | |TensorConstant{(1,) of 0} [id F] - | |ScalarConstant{1} [id G] - | [id H] (outer_in_non_seqs-0) + |TensorConstant{20000} [id B] (n_steps) + |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0) + |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0) + | |AllocEmpty{dtype='int64'} [id E] 0 + | | |TensorConstant{20000} [id B] + | |TensorConstant{(1,) of 0} [id F] + | |ScalarConstant{1} [id G] + | [id H] (outer_in_non_seqs-0) Inner graphs: forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0) - >Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0) - > |TensorConstant{0} [id J] - > |Subtensor{int64, int64, uint8} [id K] - > | |*2- [id L] -> [id H] (inner_in_non_seqs-0) - > | |ScalarFromTensor [id M] - > | | |*0- [id N] -> [id C] (inner_in_seqs-0) - > | |ScalarFromTensor [id O] - > | | |*1- [id P] -> [id D] (inner_in_sit_sot-0) - > | |ScalarConstant{0} [id Q] - > |TensorConstant{1} [id R] + >Elemwise{Composite} [id I] (inner_out_sit_sot-0) + > |TensorConstant{0} [id J] + > |Subtensor{int64, int64, uint8} [id K] + > | |*2- [id L] -> [id H] (inner_in_non_seqs-0) + > | |ScalarFromTensor [id M] + > | | |*0- [id N] -> [id C] (inner_in_seqs-0) + > | |ScalarFromTensor [id O] + > | | |*1- [id P] -> [id D] (inner_in_sit_sot-0) + > | |ScalarConstant{0} [id Q] + > |TensorConstant{1} [id R] + + Elemwise{Composite} [id I] + >Switch [id S] + > |LT [id T] + > | | [id U] + > | | [id V] + > | [id W] + > | [id U] """ output_str = debugprint(out, file="str", print_op_info=True) + print(output_str) lines = output_str.split("\n") for truth, out in zip(expected_output.split("\n"), lines): diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py 
index f28c95037f..144fcf0eb9 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -16,7 +16,7 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph -from pytensor.printing import pprint +from pytensor.printing import debugprint, pprint from pytensor.raise_op import Assert, CheckAndRaise from pytensor.tensor.basic import ( Alloc, @@ -1105,7 +1105,7 @@ def test_elemwise_float_ops(self, op): s2 = at.switch(c, x, y) g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 @pytest.mark.parametrize( "op", @@ -1122,7 +1122,7 @@ def test_elemwise_int_ops(self, op): s1 = at.switch(c, a, b) s2 = at.switch(c, x, y) g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 @pytest.mark.parametrize("op", [add, mul]) def test_elemwise_multi_inputs(self, op): @@ -1134,7 +1134,7 @@ def test_elemwise_multi_inputs(self, op): u, v = matrices("uv") s3 = at.switch(c, u, v) g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 class TestLocalOptAlloc: diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index e8dce3e5ff..deaf92ef43 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1,26 +1,29 @@ -import contextlib - import numpy as np import pytest import pytensor -import pytensor.scalar as aes -import pytensor.tensor as at +from pytensor import In +from pytensor import scalar as aes from pytensor import shared +from pytensor import tensor as at from pytensor.compile.function import function from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config +from pytensor.gradient import grad from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.misc.safe_asarray import _asarray +from pytensor.raise_op import assert_op from pytensor.scalar.basic import Composite from pytensor.tensor.basic import MakeVector from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.math import abs as at_abs +from pytensor.tensor.math import add +from pytensor.tensor.math import all as at_all from pytensor.tensor.math import ( - add, bitwise_and, bitwise_or, cos, @@ -28,6 +31,7 @@ dot, eq, exp, + ge, int_div, invert, iround, @@ -44,7 +48,7 @@ from pytensor.tensor.math import sin, sinh, sqr, sqrt from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import tan, tanh, true_div, xor -from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift +from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape from pytensor.tensor.shape import reshape from pytensor.tensor.type import ( @@ -263,9 +267,8 @@ def test_local_useless_dimshuffle_in_reshape(): class TestFusion: rewrites = RewriteDatabaseQuery( include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", "canonicalize", + "fusion", 
"inplace", ], exclude=["cxx_only", "BlasOpt"], @@ -299,6 +302,29 @@ def my_init(dtype="float64", num=0): fwx = fw + fx ftanx = tan(fx) + def large_fuseable_graph(self, n): + factors = [] + sd = dscalar() + means = dvector() + + cst_05 = at.constant(0.5) + cst_m05 = at.constant(-0.5) + cst_2 = at.constant(2) + cst_m2 = at.constant(-2) + ones = at.constant(np.ones(10)) + + for i in range(n): + f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( + cst_05 * (sd**cst_m2) / np.pi + ) + factors.append(at_sum(f)) + + logp = add(*factors) + + vars = [sd, means] + dlogp = [pytensor.grad(logp, v) for v in vars] + return vars, dlogp + @pytest.mark.parametrize( "case", [ @@ -880,6 +906,7 @@ def my_init(dtype="float64", num=0): 1, fxv * np.tan(fxv) * np.tan(fxv) * fxv, "float32", + 1e-5, ), ( mul(ftanx, ftanx, fx + fy), @@ -888,6 +915,7 @@ def my_init(dtype="float64", num=0): 1, np.tan(fxv) * np.tan(fxv) * (fxv + fyv), "float32", + 1e-5, ), # 70 # Cases with different broadcast pattern. They should not # be merged as this would duplicate computation @@ -900,33 +928,115 @@ def my_init(dtype="float64", num=0): fxv * np.sin(fsv), "float32", ), + # Multiple output cases # 72 + ( + ( + # sum(logp) + at_sum(-((fx - fy) ** 2) / 2), + # grad(logp) + at.grad(at_sum(-((fx - fy) ** 2) / 2), wrt=fx), + ), + (fx, fy), + (fxv, fyv), + 3, + ( + np.sum(-((fxv - fyv) ** 2) / 2), + -(fxv - fyv), + ), + ("float32", "float32"), + ), + # Two Composite graphs that share the same input, but are split by + # a non-elemwise operation (Assert) + ( + ( + log( + ge( + assert_op( + at_abs(fx), + at_all(ge(at_abs(fx), 0)), + ), + 0, + ) + ), + ), + (fx,), + (fxv,), + 4, + (np.zeros_like(fxv),), + ("float32",), + ), + # Two subgraphs that share the same non-fuseable input, but are otherwise + # completely independent + ( + ( + true_div( + mul( + at_sum(fx + 5), # breaks fusion + exp(fx), + ), + (fx + 5), + ), + ), + (fx,), + (fxv,), + 4, + (np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),), + ("float32",), + ), + pytest.param( + ( + (sin(exp(fx)), exp(sin(fx))), + (fx,), + (fxv,), + 1, + (np.sin(np.exp(fxv)), np.exp(np.sin(fxv))), + ("float32", "float32"), + ), + marks=pytest.mark.xfail, # Not implemented yet + ), ], ) def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): """Verify that `Elemwise` fusion works.""" - g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case + if len(case) == 6: + g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case + atol = None + else: + g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype, atol = case if isinstance(out_dtype, dict): out_dtype = out_dtype[config.cast_policy] + if not isinstance(g, (tuple, list)): + g = (g,) + answer = (answer,) + out_dtype = (out_dtype,) + if self._shared is None: f = function(list(sym_inputs), g, mode=self.mode) for x in range(nb_repeat): out = f(*val_inputs) + if not isinstance(out, list): + out = (out,) else: - out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out") - assert out.dtype == g.dtype - f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode) + out = [ + self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out") + for g_, od in zip(g, out_dtype) + ] + assert all(o.dtype == g_.dtype for o, g_ in zip(out, g)) + f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode) for x in range(nb_repeat): f(*val_inputs) - out = out.get_value() + out = [o.get_value() for o in out] - atol = 1e-8 - if out_dtype == "float32": - atol = 1e-6 + if atol is None: + atol = 1e-8 + if any(o == 
"float32" for o in out_dtype): + atol = 1e-6 - assert np.allclose(out, answer * nb_repeat, atol=atol) + for o, a in zip(out, answer): + np.testing.assert_allclose(o, a * nb_repeat, atol=atol) topo = f.maker.fgraph.toposort() topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] @@ -939,13 +1049,15 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): # input of g, # check that the number of input to the Composite # Elemwise is ok - if len(set(g.owner.inputs)) == len(g.owner.inputs): - expected_len_sym_inputs = sum( - not isinstance(x, Constant) for x in topo_[0].inputs - ) - assert expected_len_sym_inputs == len(sym_inputs) + for g_ in g: + if len(set(g_.owner.inputs)) == len(g_.owner.inputs): + expected_len_sym_inputs = sum( + not isinstance(x, Constant) for x in topo_[0].inputs + ) + assert expected_len_sym_inputs == len(sym_inputs) - assert out_dtype == out.dtype + for od, o in zip(out_dtype, out): + assert od == o.dtype def test_fusion_35_inputs(self): r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" @@ -970,35 +1082,9 @@ def test_fusion_35_inputs(self): @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") def test_big_fusion(self): - # In the past, pickle of Composite generated in that case - # crashed with max recursion limit. So we were not able to - # generate C code in that case. - factors = [] - sd = dscalar() - means = dvector() - - cst_05 = at.constant(0.5) - cst_m05 = at.constant(-0.5) - cst_2 = at.constant(2) - cst_m2 = at.constant(-2) - ones = at.constant(np.ones(10)) - n = 85 - if config.mode in ["DebugMode", "DEBUG_MODE"]: - n = 10 - - for i in range(n): - f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( - cst_05 * (sd**cst_m2) / np.pi - ) - factors.append(at_sum(f)) - - logp = add(*factors) - - vars = [sd, means] - # Make sure that C compilation is used mode = Mode("cvm", self.rewrites) - dlogp = function(vars, [pytensor.grad(logp, v) for v in vars], mode=mode) + dlogp = function(*self.large_fuseable_graph(n=85), mode=mode) # Make sure something was fused assert any( @@ -1006,23 +1092,35 @@ def test_big_fusion(self): for node in dlogp.maker.fgraph.toposort() ) - def test_add_mul_fusion_inplace(self): - - rewrites = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, rewrites) + @pytest.mark.xfail(reason="Fails due to #1244") + def test_add_mul_fusion_precedence(self): + """Test that additions and multiplications are "fused together" before + a `Composite` `Op` is introduced. 
This fusion is done by canonicalization.
+        """
+        x, y, z = vectors("x", "y", "z")
+        out = log((x + y + z) / (x * y * z))
+        f = pytensor.function([x, y, z], out, mode=self.mode)
+        # There should be a single Composite Op
+        nodes = f.maker.fgraph.apply_nodes
+        assert len(nodes) == 1
+        (node,) = nodes
+        assert isinstance(node.op, Elemwise)
+        scalar_op = node.op.scalar_op
+        assert isinstance(scalar_op, Composite)
+        assert [node.op for node in scalar_op.fgraph.toposort()] == [
+            # There should be a single mul
+            aes.mul,
+            # There should be a single add
+            aes.add,
+            aes.true_div,
+            aes.log,
+        ]
+
     def test_add_mul_fusion_inplace(self):
         x, y, z = dmatrices("xyz")
         out = dot(x, y) + x + y + z
-        f = function([x, y, z], out, mode=mode)
+
+        f = function([x, y, z], out, mode=self.mode)
         topo = [n for n in f.maker.fgraph.toposort()]
         assert len(topo) == 2
         assert topo[-1].op.inplace_pattern
@@ -1037,6 +1135,34 @@ def test_add_mul_fusion_inplace(self):
             np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
         )
 
+    def test_fusion_multiout_inplace(self):
+        x = vector("x")
+
+        # Create a Composite where inplacing the first non-constant output would corrupt the second output
+        xs = aes.float64("xs")
+        outs = (
+            Elemwise(Composite([xs], [xs + 1, aes.cos(xs + 1) + xs]))
+            .make_node(x)
+            .outputs
+        )
+
+        f = pytensor.function(
+            [In(x, mutable=True)],
+            outs,
+            mode=self.mode.including("inplace"),
+        )
+        (composite_node,) = f.maker.fgraph.apply_nodes
+
+        # The destroy map must be empty, or reuse the input only for the last toposorted output
+        destroy_map = composite_node.op.destroy_map
+        assert (destroy_map == {}) or (
+            destroy_map == {1: [composite_node.inputs.index(x)]}
+        )
+
+        res = f([0, 1, 2])
+        assert np.allclose(res[0], [1, 2, 3])
+        assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2]))
+
     @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
     def test_no_c_code(self):
         r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
@@ -1050,8 +1176,7 @@ def impl(self, x):
 
         mode = Mode(linker="cvm")
         mode._optimizer = mode._optimizer.including(
-            "local_elemwise_fusion",
-            "composite_elemwise_fusion",
+            "fusion",
             "canonicalize",
             "inplace",
         )
@@ -1067,51 +1192,29 @@ def impl(self, x):
 
     @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
     def test_test_values(self, test_value):
-        """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
-
-        The test values we're talking about are the ones used when C implementations
-        are checked.
-
+        """Make sure that `local_elemwise_fusion_op` uses test values correctly
+        when they have zero dimensions.
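+        E.g., for `test_value = np.c_[[]]` (an empty array of shape `(1, 0)`),
+        the fused output's test value should be an equally empty array rather
+        than raising an error, as asserted below.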
""" - - rewrites = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, rewrites) - x, y, z = dmatrices("xyz") x.tag.test_value = test_value y.tag.test_value = test_value z.tag.test_value = test_value - if test_value.size == 0: - cm = pytest.raises(ValueError) - else: - cm = contextlib.suppress() - with config.change_flags( compute_test_value="raise", compute_test_value_opt="raise" ): out = x * y + z - with cm: - f = function([x, y, z], out, mode=mode) + f = function([x, y, z], out, mode=self.mode) - if test_value.size != 0: - # Confirm that the fusion happened - assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) - assert len(f.maker.fgraph.toposort()) == 1 + # Confirm that the fusion happened + assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) + assert len(f.maker.fgraph.toposort()) == 1 - x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs - assert np.array_equal( - f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] - ) + assert np.array_equal( + f.maker.fgraph.outputs[0].tag.test_value, + np.full_like(test_value, 2.0), + ) @pytest.mark.parametrize("linker", ["cvm", "py"]) @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)]) @@ -1193,6 +1296,81 @@ def test_CAReduce_multiple_inputs(self, linker, axis): assert out_val.shape == exp_res.shape assert np.allclose(out_val, exp_res) + def test_not_fusing_broadcasted_subgraphs(self): + """Test that broadcasted Elemwise subgraphs are not fused in a single Elemwise Composite Op. + + There are some cases in self.test_elemwise_fusion, but this test confirms that the + fused subgraphs are exactly the expected ones. 
+ """ + xs = vector("xm") + xm = matrix("xs") + + es = log(xs + 5) + em = exp(xm * 5) + esm = es - em + + f = pytensor.function([xs, xm], esm, mode=self.mode) + apply_nodes = f.maker.fgraph.toposort() + assert len(apply_nodes) == 3 + assert isinstance(apply_nodes[0].op, DimShuffle) + # Inner Vector output Composite + assert isinstance(apply_nodes[1].op.scalar_op, Composite) + assert {node.op for node in apply_nodes[1].op.scalar_op.fgraph.apply_nodes} == { + aes.add, + aes.log, + } + # Outer Matrix output Composite + assert isinstance(apply_nodes[2].op.scalar_op, Composite) + assert {node.op for node in apply_nodes[2].op.scalar_op.fgraph.apply_nodes} == { + aes.sub, + aes.exp, + aes.mul, + } + + def test_multiple_outputs_fused_root_elemwise(self): + """Test that a root elemwise output (single layer) is reused when + there is another fused output""" + + # By default, we do not introduce Composite for single layers of Elemwise + x = at.vector("x") + out1 = at.cos(x) + f = pytensor.function([x], out1, mode=self.mode) + nodes = tuple(f.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op.scalar_op, aes.Cos) + + # However, when it can be composed with another output, we should not + # compute that root Elemwise twice + out2 = at.log(out1) + f = pytensor.function([x], [out1, out2], mode=self.mode) + nodes = tuple(f.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op.scalar_op, Composite) + + def test_eval_benchmark(self, benchmark): + rng = np.random.default_rng(123) + size = 100_000 + x = pytensor.shared(rng.normal(size=size), name="x") + mu = pytensor.shared(rng.normal(size=size), name="mu") + + logp = -((x - mu) ** 2) / 2 + grad_logp = grad(logp.sum(), x) + + func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN") + benchmark(func) + + @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") + def test_rewrite_benchmark(self, benchmark): + inps, outs = self.large_fuseable_graph(n=25) + fg = FunctionGraph(inps, outs) + opt = FusionOptimizer() + + def rewrite_func(): + nb_replacement = opt.apply(fg.clone())[2] + return nb_replacement + + assert benchmark(rewrite_func) == 103 + class TimesN(aes.basic.UnaryScalarOp): """ @@ -1258,22 +1436,37 @@ def test_nested_composite(self): def test_local_useless_composite(self): x = aes.float32() - c = aes.Composite([x], [x + 1, x - 1]) - X = matrix() - o = Elemwise(scalar_op=c)(X) + y = aes.float32() + z = aes.float32() + c = aes.Composite([x, y, z], [x + 1, y - 1]) + X = matrix("X") + Y = matrix("Y") + Z = matrix("Z") + o1, o2 = Elemwise(scalar_op=c)(X, Y, Z) mode = get_default_mode().including("local_useless_composite") - f = function([X], o[0], mode=mode) + f = function([X, Y, Z], [o1, o2], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert len(topo[0].inputs) == 2 + assert len(topo[0].outputs) == 2 + res1, res2 = f([[1.0]], [[1.0]], [[np.nan]]) + utt.assert_allclose(res1, [[2.0]]) + utt.assert_allclose(res2, [[0.0]]) + + f = function([X, Y, Z], o1, mode=mode) topo = f.maker.fgraph.toposort() assert len(topo) == 1 + assert len(topo[0].inputs) == 1 assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[1.0]]), [[2.0]]) + utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]]) - f = function([X], o[1], mode=mode) + f = function([X, Y, Z], o2, mode=mode) topo = f.maker.fgraph.toposort() assert len(topo) == 1 + assert len(topo[0].inputs) == 1 assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[1.0]]), [[0.0]]) + 
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) def test_local_useless_dimshuffle_makevector(): diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 116c3e4ad0..a662b5b325 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -16,7 +16,7 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Constant, equal_computations +from pytensor.graph.basic import Apply, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( SequentialNodeRewriter, @@ -28,6 +28,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph from pytensor.misc.safe_asarray import _asarray +from pytensor.printing import debugprint from pytensor.tensor import inplace from pytensor.tensor.basic import Alloc, join, switch from pytensor.tensor.blas import Dot22, Gemv @@ -46,7 +47,6 @@ bitwise_or, bitwise_xor, conj, - cos, cosh, deg2rad, dot, @@ -59,14 +59,10 @@ ge, gt, int_div, - invert, - iround, le, log, log1mexp, log1p, - log2, - log10, lt, ) from pytensor.tensor.math import max as at_max @@ -74,11 +70,20 @@ from pytensor.tensor.math import min as at_min from pytensor.tensor.math import minimum, mul, neg, neq from pytensor.tensor.math import pow as at_pow -from pytensor.tensor.math import prod, rad2deg, reciprocal -from pytensor.tensor.math import round as at_round -from pytensor.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub +from pytensor.tensor.math import ( + prod, + rad2deg, + reciprocal, + sgn, + sigmoid, + sinh, + softplus, + sqr, + sqrt, + sub, +) from pytensor.tensor.math import sum as at_sum -from pytensor.tensor.math import tan, tanh, true_div, xor +from pytensor.tensor.math import tanh, true_div, xor from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.math import ( compute_mul, @@ -102,7 +107,6 @@ dvector, fmatrices, fmatrix, - fscalar, ftensor4, fvector, imatrices, @@ -1072,745 +1076,6 @@ def test_cast_in_mul_canonizer(): f([1], [1]) -class TestFusion: - rewrites = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - mode = Mode(get_default_mode().linker, rewrites) - _shared = staticmethod(shared) - topo_exclude = () - - def do(self, mode, shared_fn, shp, nb_repeat=1, assert_len_topo=True, slice=None): - """ - param shared_fn: if None, will use function - verify that the elemwise fusion work - Test with and without DimShuffle - """ - # TODO: disable the canonizer? 
- def my_init(shp, dtype="float64", num=0): - ret = np.zeros(shp, dtype=dtype) + num - return ret - - fw, fx, fy, fz = ( - tensor(dtype="float32", shape=(None,) * len(shp), name=n) for n in "wxyz" - ) - dw, dx, dy, dz = ( - tensor(dtype="float64", shape=(None,) * len(shp), name=n) for n in "wxyz" - ) - ix, iy, iz = ( - tensor(dtype="int32", shape=(None,) * len(shp), name=n) for n in "xyz" - ) - fv = fvector("v") - fs = fscalar("s") - - fwv = my_init(shp, "float32", 1) - fxv = my_init(shp, "float32", 2) - fyv = my_init(shp, "float32", 3) - fzv = my_init(shp, "float32", 4) - fvv = _asarray(np.random.random(shp[0]), dtype="float32") - fsv = np.asarray(np.random.random(), dtype="float32") - dwv = my_init(shp, "float64", 5) - ixv = _asarray(my_init(shp, num=60), dtype="int32") - iyv = _asarray(my_init(shp, num=70), dtype="int32") - izv = _asarray(my_init(shp, num=70), dtype="int32") - fwx = fw + fx - ftanx = tan(fx) - cases = [ - ( - fx + fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + fzv, - "float32", - ), # 0 - ( - fx * fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv * fzv, - "float32", - ), # 1 - ( - fx + fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv, - "float32", - ), # 2 - ( - fx * fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv, - "float32", - ), # 3 - ( - fw + fx + fy + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 5 - ( - ((fw + fx) + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy) + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - fw + (fx + (fy + fz)), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 10 - ( - fw * fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv * fxv * fyv * fzv, - "float32", - ), - ( - fw + fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv * fyv * fzv, - "float32", - ), - ( - fx + fy * fz * fx, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv * fxv, - "float32", - ), - ( - fx * fy + fz + fy, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv + fyv, - "float32", - ), - ( - fx * fy * fz * fw + fx + fy + fz + fw, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fxv * fyv * fzv * fwv + fxv + fyv + fzv + fwv, - "float32", - ), # 15 - # test with constant - ( - (fw + fx) + (fy + fz) + 2.0, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - ((fw + fx) + 2.0 + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + 2.0 + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + fy) + 2 + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - fw + (fx + (fy + fz) + 2.0), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), # 20 - ( - 2 + (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, 
fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - # mix float32 and float64 - ( - 2 + (dw + fx) + (fy + fz), - (dw, fx, fy, fz), - (dwv, fxv, fyv, fzv), - 1, - dwv + fxv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + dw) + (fy + fz), - (fw, dw, fy, fz), - (fwv, dwv, fyv, fzv), - 1, - fwv + dwv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (dw + fz), - (fw, fx, dw, fz), - (fwv, fxv, dwv, fzv), - 1, - fwv + fxv + dwv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (fy + dw), - (fw, fx, fy, dw), - (fwv, fxv, fyv, dwv), - 1, - fwv + fxv + fyv + dwv + 2, - "float64", - ), # 25 - # test when their is other op then elemwise. - ( - (fwx.sum()) + (fwx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 4, - (fwv + fxv).sum() + fwv + fxv + fyv + fzv, - "float32", - ), - # test other elemwise op - ( - fx + fy + cos(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cos(fzv), - "float32", - ), - ( - fx + fy + cosh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cosh(fzv), - "float32", - ), - ( - fx + fy + abs(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.absolute(fzv), - "float32", - ), - ( - ix + iy + abs(iz), - (ix, iy, iz), - (ixv, iyv, izv), - 1, - ixv + iyv + np.absolute(izv), - "int32", - ), # 30 - ( - fx + fy + log(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log(fzv), - "float32", - ), - ( - fx + fy + log2(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log2(fzv), - "float32", - ), - ( - fx + fy + log10(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log10(fzv), - "float32", - ), - ( - fx + fy**fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv**fzv, - "float32", - ), # pow - ( - fx + fy + exp(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.exp(fzv), - "float32", - ), # 35 - ( - fx - fy - fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv - fzv, - "float32", - ), - ( - fx - (fy / fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - true_div(fy, 2), - (fx, fy), - (fxv, fyv), - 1, - fxv - (fyv / 2), - "float32", - ), - ( - fx - true_div(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - int_div(ix * 100, iy * 1000), - (fx, ix, iy), - (fxv, ixv, iyv), - 1, - fxv - ((ixv * 100) // (iyv * 1000)), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 40 - (fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"), - ( - fx - (fy % fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv % fzv), - "float32", - ), - ( - fx - (fy > fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv > fzv), - "float32", - ), - ( - fx - (fy >= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv >= fzv), - "float32", - ), - ( - fx - (fy < fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv < fzv), - "float32", - ), # 45 - ( - fx - (fy <= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv <= fzv), - "float32", - ), - ( - fx - eq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv == fzv), - "float32", - ), - ( - fx - neq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv != fzv), - "float32", - ), - ( - fx - fy + tan(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tan(fzv), - "float32", - ), - ( - fx - fy + tanh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tanh(fzv), - "float32", - ), # 50 - ( - fx - fy + sin(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, 
- fxv - fyv + np.sin(fzv), - "float32", - ), - ( - fx - fy + sinh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sinh(fzv), - "float32", - ), - ( - fx - fy + sqr(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (fzv * fzv), - "float32", - ), - ( - fx - fy + sqrt(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sqrt(fzv), - "float32", - ), - ( - fx - fy + reciprocal(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (1 / fzv), - "float32", - ), # 55 - ( - fx - fy + neg(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (-fzv), - "float32", - ), - ( - fx - fy + at_round(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.round(fzv), - "float32", - ), - ( - ix - iy + iround(fz), - (ix, iy, fz), - (ixv, iyv, fzv), - 1, - ixv - iyv + np.round(fzv), - "int64", - ), - # Bit op - ( - fx - bitwise_or(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv | izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - xor(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv ^ izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 60 - ( - fx - bitwise_and(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv & izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - invert(iy), - (fx, iy), - (fxv, iyv), - 1, - fxv - (~iyv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - at.cast(fy, dtype="float64"), - (fx, fy), - (fxv, fyv), - 1, - fxv - np.asarray(fyv, "float64"), - "float64", - ), - ( - at_pow(fx * fy + fz, fx * fy), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - np.power(fxv * fyv + fzv, fxv * fyv), - "float32", - ), - ( - fv + fy**fz, - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv + fyv**fzv, - "float32", - ), # fused with a dimshuffle #65 - ( - fv - fy + tanh(fz), - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv - fyv + np.tanh(fzv), - "float32", - ), # fused with a dimshuffle - # Cases where the same input is reused many times. - ( - mul(fx, fx, fx, fx), - (fx,), - (fxv,), - 1, - fxv * fxv * fxv * fxv, - "float32", - ), - ( - mul(fx, ftanx, ftanx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv), - "float32", - ), - ( - mul(fx, ftanx, ftanx, fx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv) * fxv, - "float32", - ), - ( - mul(ftanx, ftanx, fx + fy), - (fx, fy), - (fxv, fyv), - 1, - np.tan(fxv) * np.tan(fxv) * (fxv + fyv), - "float32", - ), # 70 - # Cases with different broadcast pattern. 
They should not - # be merged as this would duplicate computation - # The graph should have 2 elemwise and 1 dimshuffle - ( - fx * sin(fs), - (fx, fs), - (fxv, fsv), - 3, - fxv * np.sin(fsv), - "float32", - ), - ] - if slice: - cases = cases[slice] - times = np.zeros(len(cases)) - fail1 = [] - fail2 = [] - fail3 = [] - fail4 = [] - for ( - id, - [g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype], - ) in enumerate(cases): - if isinstance(out_dtype, dict): - out_dtype = out_dtype[config.cast_policy] - - if shared_fn is None: - f = function(list(sym_inputs), g, mode=mode) - for x in range(nb_repeat): - out = f(*val_inputs) - t1 = time.perf_counter() - else: - out = shared_fn(np.zeros(shp, dtype=out_dtype), "out") - assert out.dtype == g.dtype - f = function(sym_inputs, [], updates=[(out, g)], mode=mode) - t0 = time.perf_counter() - for x in range(nb_repeat): - f(*val_inputs) - t1 = time.perf_counter() - out = out.get_value() - - times[id] = t1 - t0 - atol = 1e-8 - if out_dtype == "float32": - atol = 1e-6 - if not np.allclose(out, answer * nb_repeat, atol=atol): - fail1.append(id) - topo = f.maker.fgraph.toposort() - topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] - if assert_len_topo: - if len(topo_) != nb_elemwise: - fail3.append((id, topo_, nb_elemwise)) - if nb_elemwise == 1: - # if no variable appears multiple times in the - # input of g, - # check that the number of input to the Composite - # Elemwise is ok - if len(set(g.owner.inputs)) == len(g.owner.inputs): - expected_len_sym_inputs = sum( - not isinstance(x, Constant) for x in topo_[0].inputs - ) - assert expected_len_sym_inputs == len(sym_inputs) - - if out_dtype != out.dtype: - fail4.append((id, out_dtype, out.dtype)) - - assert len(fail1 + fail2 + fail3 + fail4) == 0 - - return times - - def test_add_mul_fusion_inplace(self): - - rewrites_query = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, rewrites_query) - - x, y, z = dmatrices("xyz") - out = dot(x, y) + x + y + z - f = function([x, y, z], out, mode=mode) - topo = [n for n in f.maker.fgraph.toposort()] - assert len(topo) == 2 - assert topo[-1].op.inplace_pattern - - new_out = f.maker.fgraph.outputs[0] - assert isinstance(new_out.owner.op, Elemwise) - assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add) - assert len(new_out.owner.inputs) == 4 - - # TODO: Do we really need to do this? 
- _ = f( - np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) - ) - - @utt.assertFailure_fast def test_log1p(): m = config.mode @@ -3152,7 +2417,7 @@ def test_elemwise(self): at_pow, ): g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 # integer Ops mats = imatrices("cabxy") c, a, b, x, y = mats @@ -3164,13 +2429,13 @@ def test_elemwise(self): bitwise_xor, ): g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 # add/mul with more than two inputs u, v = matrices("uv") s3 = at.switch(c, u, v) for op in (add, mul): g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 class TestLocalSumProd: diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 8395fb4a60..aa275a03a6 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -887,10 +887,9 @@ def test_basic_6(self): prog = f.maker.fgraph.toposort() assert isinstance(prog[0].op, DimShuffle) assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp} - assert prog[2].op == add or prog[3].op == add # first subtensor - assert isinstance(prog[2].op, Subtensor) or isinstance(prog[3].op, Subtensor) - assert len(prog) == 4 + assert isinstance(prog[2].op, Subtensor) + assert len(prog) == 3 f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something def test_basic_7(self): diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 1d0bfe9b21..431c236442 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -986,7 +986,7 @@ def test_adv_sub1_idx_broadcast(self): def test_shape_i_const(self): # Each axis is treated independently by shape_i/shape operators - mode_opt = self.mode.including("fast_run") + mode_opt = self.mode data = self.shared(np.array(np.arange(5), dtype=self.dtype)) for start in [None] + [-8, -5, -1, 0, 1, 5, 8]: outs = [] @@ -1004,7 +1004,7 @@ def test_shape_i_const(self): def test_shape_i_scalar(self): # Each axis is treated independently by shape_i/shape operators - mode_opt = self.mode.including("fast_run") + mode_opt = self.mode v_data = np.array(np.arange(5), dtype=self.dtype) t_data = self.shared(v_data) diff --git a/tests/test_printing.py b/tests/test_printing.py index f5a10c8aeb..d9592dd9af 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -273,8 +273,7 @@ def test_debugprint(): s = s.getvalue() exp_res = dedent( r""" - Elemwise{Composite{(i0 + (i1 - i2))}} 4 - |A + Elemwise{Composite} 4 |InplaceDimShuffle{x,0} v={0: [0]} 3 | |CGemv{inplace} d={0: [0]} 2 | |AllocEmpty{dtype='float64'} 1 @@ -285,6 +284,16 @@ def test_debugprint(): | | | |TensorConstant{0.0} |D + |A + + Inner graphs: + + Elemwise{Composite} + >add + > | + > |sub + > | + > | """ ).lstrip()