diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 008bf48add..73d13cfa6b 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -432,6 +432,12 @@ def c_literal(self, data): return None if self.dtype == "bool": return "1" if data else "0" + if data == np.inf: + return "INFINITY" + if data == -np.inf: + return "-INFINITY" + if np.isnan(data): + return "NAN" return str(data) def c_declare(self, name, sub, check_input=True): diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 217693d0ae..f3fde2995e 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -69,7 +69,7 @@ get_slice_elements, set_subtensor, ) -from pytensor.tensor.var import TensorConstant, get_unique_value +from pytensor.tensor.var import TensorConstant, get_unique_constant_value list_opt_slice = [ @@ -136,7 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): node_inp = node.inputs[idx + 1] if ( isinstance(node_inp, TensorConstant) - and get_unique_value(node_inp) is not None + and get_unique_constant_value(node_inp) is not None ): try: # This works if input is a constant that has all entries diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index d62a81b946..8090d6d6a8 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -62,7 +62,11 @@ uint_dtypes, values_eq_approx_always_true, ) -from pytensor.tensor.var import TensorConstant, TensorVariable, get_unique_value +from pytensor.tensor.var import ( + TensorConstant, + TensorVariable, + get_unique_constant_value, +) if TYPE_CHECKING: @@ -323,7 +327,7 @@ def get_underlying_scalar_constant_value( raise NotScalarConstantError() if isinstance(v, Constant): - unique_value = get_unique_value(v) + unique_value = get_unique_constant_value(v) if unique_value is not None: data = unique_value else: diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e42543f9e6..64a37cb340 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -33,9 +33,13 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import exp -from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize +from pytensor.tensor.rewriting.basic import ( + broadcast_like, + register_canonicalize, + register_specialize, +) from pytensor.tensor.shape import shape_padleft -from pytensor.tensor.var import TensorConstant +from pytensor.tensor.var import TensorConstant, get_unique_constant_value class InplaceElemwiseOptimizer(GraphRewriter): @@ -1203,6 +1207,49 @@ def local_careduce_fusion(fgraph, node): return [new_car_op(*elm_inputs)] +@node_rewriter([Elemwise]) +def local_inline_composite_constants(fgraph, node): + """Inline scalar constants in Composite graphs.""" + composite_op = node.op.scalar_op + + if not isinstance(composite_op, aes.Composite): + return None + + new_outer_inputs = [] + new_inner_inputs = [] + inner_replacements = {} + for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs): + # Complex variables don't have a `c_literal` that can be inlined + if "complex" not in outer_inp.type.dtype: + unique_value = get_unique_constant_value(outer_inp) + if unique_value is not None: + inner_replacements[inner_inp] = aes.constant( + unique_value, dtype=inner_inp.dtype + ) + continue + new_outer_inputs.append(outer_inp) + new_inner_inputs.append(inner_inp) + + if not 
inner_replacements: + return None + + new_inner_outs = clone_replace( + composite_op.fgraph.outputs, replace=inner_replacements + ) + new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs) + new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs + + # Some of the inlined constants were broadcasting the output shape + if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable: + new_outputs = [ + broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph) + for new_out in new_outputs + ] + + copy_stack_trace(node.outputs, new_outputs) + return new_outputs + + # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) fuse_seqopt = SequenceDB() compile.optdb.register( @@ -1243,6 +1290,13 @@ def local_careduce_fusion(fgraph, node): "fusion", position=10, ) +fuse_seqopt.register( + "local_inline_composite_constants", + in2out(local_inline_composite_constants), + "fast_run", + "fusion", + position=20, +) def _rebuild_partial_2f1grad_loop(node, wrt): diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 584b470c94..2f79bf20fa 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -101,7 +101,7 @@ values_eq_approx_remove_inf_nan, values_eq_approx_remove_nan, ) -from pytensor.tensor.var import TensorConstant, get_unique_value +from pytensor.tensor.var import TensorConstant, get_unique_constant_value def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): @@ -133,7 +133,7 @@ def get_constant(v): """ if isinstance(v, Constant): - unique_value = get_unique_value(v) + unique_value = get_unique_constant_value(v) if unique_value is not None: data = unique_value else: @@ -1135,10 +1135,12 @@ def same(x, y): if new.type.dtype != out.type.dtype: new = cast(new, out.type.dtype) - if new.type != out.type: + if new.type.broadcastable != out.type.broadcastable: new = fill_chain(new, node.inputs)[0] - if new.type == out.type: + if (new.type.dtype == out.type.dtype) and ( + new.type.broadcastable == out.type.broadcastable + ): new.tag.values_eq_approx = values_eq_approx_remove_inf_nan copy_stack_trace(out, new) return [new] diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 7b9c80c96c..817c6d0d63 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -1026,6 +1026,52 @@ def local_Shape_of_SpecifyShape(fgraph, node): return [stack(shape).astype(np.int64)] +@register_canonicalize +@register_specialize +@node_rewriter([SpecifyShape]) +def local_specify_shape_lift(fgraph, node): + """Lift SpecifyShape of Elemwise towards the inputs.""" + inp, *shape = node.inputs + if inp.owner and isinstance(inp.owner.op, Elemwise): + if len(inp.owner.outputs) != 1: + return None + + elem_inps = inp.owner.inputs + if len(elem_inps) == 1: + new_elem_inps = [specify_shape(elem_inps[0], shape)] + else: + # Rewrite does not support case where specify_shape provides new broadcastable information, + # As that may require a specify_shape for each input + out_broadcastable = node.outputs[0].type.broadcastable + if out_broadcastable != inp.type.broadcastable: + return None + + # All non-broadcastable dimensions of inputs must match the non-broadcastbale specify_shape dims + # We look for a sufficient input to assign all the specify_shape dims + # We could consider distributing the SpecifyShape across multiple inputs, when none is sufficient + + nonbcast_dims = { + i + for i, (dim, bcast) in 
enumerate(zip(shape, out_broadcastable)) + if (not bcast and not NoneConst.equals(dim)) + } + new_elem_inps = elem_inps.copy() + for i, elem_inp in enumerate(elem_inps): + if all( + bcast_dim is False + for dim, bcast_dim in enumerate(elem_inp.type.broadcastable) + if dim in nonbcast_dims + ): + new_elem_inps[i] = specify_shape(elem_inp, shape) + break + else: # no-break, no sufficient candidate found + return None + + new_out = inp.owner.op.make_node(*new_elem_inps).outputs + copy_stack_trace(node.outputs, new_out) + return new_out + + @register_useless @register_canonicalize @node_rewriter([Shape_i]) diff --git a/pytensor/tensor/var.py b/pytensor/tensor/var.py index 61f8402d90..ead65047b1 100644 --- a/pytensor/tensor/var.py +++ b/pytensor/tensor/var.py @@ -986,7 +986,7 @@ def no_nan(self): return self._no_nan -def get_unique_value(x: TensorVariable) -> Optional[Number]: +def get_unique_constant_value(x: TensorVariable) -> Optional[Number]: """Return the unique value of a tensor, if there is one""" if isinstance(x, Constant): data = x.data diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index a2f1f81a12..f14cbdd067 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -128,21 +128,33 @@ def test_flatten(self): # We don't flatten that case. assert isinstance(CC.outputs[0].owner.op, Composite) - def test_with_constants(self): + @pytest.mark.parametrize("literal_value", (70.0, -np.inf, np.float32("nan"))) + def test_with_constants(self, literal_value): x, y, z = floats("xyz") - e = mul(add(70.0, y), true_div(x, y)) + e = mul(add(literal_value, y), true_div(x, y)) comp_op = Composite([x, y], [e]) comp_node = comp_op.make_node(x, y) c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0)) - assert "70.0" in c_code + assert constant(literal_value).type.c_literal(literal_value) in c_code # Make sure caching of the c_code template works assert hasattr(comp_node.op, "_c_code") g = FunctionGraph([x, y], [comp_node.out]) - fn = make_function(DualLinker().accept(g)) - assert fn(1.0, 2.0) == 36.0 + + # Default checker does not allow `nan` + def checker(x, y): + np.testing.assert_equal(x, y) + + fn = make_function(DualLinker(checker=checker).accept(g)) + + test_x = 1.0 + test_y = 2.0 + np.testing.assert_equal( + fn(test_x, test_y), + (literal_value + test_y) * (test_x / test_y), + ) def test_many_outputs(self): x, y, z = floats("xyz") diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index ead6575eb3..ac4d293f16 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs(): utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) +@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)]) +@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)]) +def test_local_inline_composite_constants(op, np_op, const_shape): + const = np.full(shape=const_shape, fill_value=2.5).astype(config.floatX) + x = vector("x") + y = vector("y") + out = at.exp(op(x, const)) + y + + fn = pytensor.function( + [x, y], out, mode=get_default_mode().including("specialize", "fusion") + ) + # There should be a single Composite after optimization + [node] = [ + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise) + ] + assert isinstance(node.op.scalar_op, Composite) + assert len(node.inputs) == 2 # x and y, but not const + + x_test_value = 
np.arange(5).astype(config.floatX) + y_test_value = np.ones(5).astype(config.floatX) + np.testing.assert_allclose( + fn(x_test_value, y_test_value), + np.exp(np_op(x_test_value, const)) + y_test_value, + ) + + def test_local_useless_dimshuffle_makevector(): a = scalar() x = MakeVector(config.floatX)(a) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 185641e4fb..e2ed504e79 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -30,7 +30,7 @@ from pytensor.misc.safe_asarray import _asarray from pytensor.printing import debugprint from pytensor.tensor import inplace -from pytensor.tensor.basic import Alloc, join, switch +from pytensor.tensor.basic import Alloc, join, second, switch from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas_c import CGemv from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -96,7 +96,7 @@ perform_sigm_times_exp, simplify_mul, ) -from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape +from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape from pytensor.tensor.type import ( TensorType, cmatrix, @@ -979,6 +979,28 @@ def test_mismatching_types(self): # No rewrite was applied assert z_rewritten is z + def test_shape_specified_by_constant(self): + x = vector("x") + const = np.full(shape=(5,), fill_value=2.0).astype(config.floatX) + out = x * const + + new_out = rewrite_graph( + out, custom_rewrite=in2out(local_mul_canonizer, name="test") + ) + expected_out = np.array([2.0]).astype(config.floatX) * specify_shape(x, (5,)) + assert equal_computations([new_out], [expected_out]) + + def test_broadcasted_by_constant(self): + x = vector("x") + const = np.full(shape=(3, 5), fill_value=2.0).astype(config.floatX) + out = x * const + + new_out = rewrite_graph( + out, custom_rewrite=in2out(local_mul_canonizer, name="test") + ) + expected_out = second(const, np.array([[2.0]], dtype=config.floatX) * x) + assert equal_computations([new_out], [expected_out]) + def test_local_merge_abs(): x, y, z = matrices("xyz") diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index 8c8c8a7baa..0bc6638a7a 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -491,6 +491,14 @@ def test_local_Shape_of_SpecifyShape_partial(s1): assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes) +def test_local_specify_shape_lift(): + x = vector("x") + out = specify_shape([1.0] + x, shape=(5,)) + + new_out = rewrite_graph(out) + assert equal_computations([new_out], [[1.0] + specify_shape(x, shape=(5,))]) + + def test_local_Shape_i_ground(): x = tensor(dtype=np.float64, shape=(None, 2)) s = Shape_i(1)(x)
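
A few self-contained sketches of the behavior this diff introduces, for reviewers who want to exercise it locally. All of them assume a PyTensor checkout with this branch applied and mirror the tests added above; variable names and shapes are illustrative.

First, the `c_literal` change: non-finite float constants embedded in a `Composite` should now render as the C99 `INFINITY`/`NAN` macros instead of the invalid literals produced by `str(data)`. This follows the parametrized `test_with_constants` case:

```python
import numpy as np
from pytensor.scalar.basic import Composite, add, floats, mul, true_div

x, y = floats("xy")
# Embed a NaN constant in the Composite graph ...
e = mul(add(np.float32("nan"), y), true_div(x, y))
comp_op = Composite([x, y], [e])
node = comp_op.make_node(x, y)

# ... and check that the generated C code uses the NAN macro rather than "nan".
c_code = node.op.c_code(node, "dummy", ["x", "y"], ["z"], dict(id=0))
assert "NAN" in c_code
```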
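Second, `local_inline_composite_constants`: a constant that merely broadcasts against the other `Elemwise` inputs should be folded into the fused `Composite` instead of being passed as an outer input (a sketch along the lines of `test_local_inline_composite_constants`):

```python
import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.compile.mode import get_default_mode
from pytensor.scalar.basic import Composite
from pytensor.tensor.elemwise import Elemwise

x = at.vector("x")
y = at.vector("y")
const = np.full((5,), 2.5).astype(pytensor.config.floatX)
out = at.exp(x + const) + y

# Compile with fusion enabled so the rewrite registered in `fuse_seqopt` runs.
fn = pytensor.function(
    [x, y], out, mode=get_default_mode().including("specialize", "fusion")
)

# The constant should now live inside the Composite (inlined via c_literal)
# rather than being a third outer input of the fused node.
[node] = [
    n
    for n in fn.maker.fgraph.apply_nodes
    if isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, Composite)
]
assert len(node.inputs) == 2  # x and y, but not const

x_val = np.arange(5).astype(pytensor.config.floatX)
y_val = np.ones(5).astype(pytensor.config.floatX)
np.testing.assert_allclose(fn(x_val, y_val), np.exp(x_val + const) + y_val, rtol=1e-6)
```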
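Finally, `local_specify_shape_lift`: a `SpecifyShape` sitting on top of a single-output `Elemwise` should be lifted towards the inputs, so the static shape information becomes available earlier in the graph (same idea as `test_local_specify_shape_lift`, just with a single-input `exp`):

```python
import pytensor.tensor as at
from pytensor.graph.basic import equal_computations
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.shape import specify_shape

x = at.vector("x")
# SpecifyShape on top of an Elemwise ...
out = specify_shape(at.exp(x), (5,))

# ... is expected to be lifted onto the Elemwise input by canonicalization.
new_out = rewrite_graph(out)
assert equal_computations([new_out], [at.exp(specify_shape(x, (5,)))])
```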