diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 30d9084449..3c98834c94 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -5,13 +5,15 @@
 from pytensor import Variable
 from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import (
+    PatternNodeRewriter,
     copy_stack_trace,
     node_rewriter,
 )
-from pytensor.tensor.basic import TensorVariable, diagonal
+from pytensor.scalar.basic import Mul
+from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
 from pytensor.tensor.nlinalg import (
     SVD,
@@ -39,6 +41,7 @@
     solve,
     solve_triangular,
 )
+from pytensor.tensor.subtensor import advanced_set_subtensor
 
 
 logger = logging.getLogger(__name__)
@@ -384,6 +387,104 @@ def local_lift_through_linalg(
         raise NotImplementedError  # pragma: no cover
 
 
+def _find_diag_from_eye_mul(potential_mul_input):
+    # Check if the op is Elemwise and mul
+    if not (
+        potential_mul_input.owner is not None
+        and isinstance(potential_mul_input.owner.op, Elemwise)
+        and isinstance(potential_mul_input.owner.op.scalar_op, Mul)
+    ):
+        return None
+
+    # Find whether any of the inputs to mul is Eye
+    inputs_to_mul = potential_mul_input.owner.inputs
+    eye_input = [
+        mul_input
+        for mul_input in inputs_to_mul
+        if mul_input.owner and isinstance(mul_input.owner.op, Eye)
+    ]
+
+    # Check if 1's are being put on the main diagonal only (k = 0)
+    if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
+        return None
+
+    # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus don't need to apply the rewrite
+    if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
+        return None
+
+    # Get all non-Eye inputs (scalars/matrices/vectors)
+    non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
+    return eye_input, non_eye_inputs
+
+
+@register_canonicalize("shape_unsafe")
+@register_stabilize("shape_unsafe")
+@node_rewriter([det])
+def rewrite_det_diag_from_eye_mul(fgraph, node):
+    """
+    This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.
+
+    The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, a vector, or a matrix.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    potential_mul_input = node.inputs[0]
+    eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
+    if eye_non_eye_inputs is None:
+        return None
+    eye_input, non_eye_inputs = eye_non_eye_inputs
+
+    # Dealing with only one other input
+    if len(non_eye_inputs) != 1:
+        return None
+
+    useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
+
+    # Checking if original x was scalar/vector/matrix
+    if useful_non_eye.type.broadcastable[-2:] == (True, True):
+        # For scalar
+        det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
+    elif useful_non_eye.type.broadcastable[-2:] == (False, False):
+        # For matrix
+        det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
+    else:
+        # For vector
+        det_val = useful_non_eye.prod(axis=(-1, -2))
+    det_val = det_val.astype(node.outputs[0].type.dtype)
+    return [det_val]
+
+
+arange = ARange("int64")
+det_diag_from_diag = PatternNodeRewriter(
+    (
+        det,
+        (
+            advanced_set_subtensor,
+            (alloc, 0, "sh1", "sh2"),
+            "x",
+            (arange, 0, "stop", 1),
+            (arange, 0, "stop", 1),
+        ),
+    ),
+    (prod, "x"),
+    name="det_diag_from_diag",
+    allow_multiple_clients=True,
+)
+register_canonicalize(det_diag_from_diag)
+register_stabilize(det_diag_from_diag)
+register_specialize(det_diag_from_diag)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 523742e356..d59e3cc88f 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -394,6 +394,95 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
     np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
 
 
+@pytest.mark.parametrize(
+    "shape",
+    [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
+    ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
+)
+def test_det_diag_from_eye_mul(shape):
+    # Initializing x based on scalar/vector/matrix
+    x = pt.tensor("x", shape=shape)
+    y = pt.eye(7) * x
+    # Calculating determinant value using pt.linalg.det
+    z_det = pt.linalg.det(y)
+
+    # REWRITE TEST
+    f_rewritten = function([x], z_det, mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, Det) for node in nodes)
+
+    # NUMERIC VALUE TEST
+    if len(shape) == 0:
+        x_test = np.array(np.random.rand()).astype(config.floatX)
+    elif len(shape) == 1:
+        x_test = np.random.rand(*shape).astype(config.floatX)
+    else:
+        x_test = np.random.rand(*shape).astype(config.floatX)
+    x_test_matrix = np.eye(7) * x_test
+    det_val = np.linalg.det(x_test_matrix)
+    rewritten_val = f_rewritten(x_test)
+
+    assert_allclose(
+        det_val,
+        rewritten_val,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_det_diag_from_diag():
+    x = pt.tensor("x", shape=(None,))
+    x_diag = pt.diag(x)
+    y = pt.linalg.det(x_diag)
+
+    # REWRITE TEST
+    f_rewritten = function([x], y, mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, Det) for node in nodes)
+
+    # NUMERIC VALUE TEST
+    x_test = np.random.rand(7).astype(config.floatX)
+    x_test_matrix = np.eye(7) * x_test
+    det_val = np.linalg.det(x_test_matrix)
+    rewritten_val = f_rewritten(x_test)
+
+    assert_allclose(
+        det_val,
+        rewritten_val,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_dont_apply_det_diag_rewrite_for_1_1():
+    x = pt.matrix("x")
+    x_diag = pt.eye(1, 1) * x
+    y = pt.linalg.det(x_diag)
+    f_rewritten = function([x], y, mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+
+    assert any(isinstance(node.op, Det) for node in nodes)
+
+    # Numeric value test
+    x_test = np.random.normal(size=(3, 3)).astype(config.floatX)
+    x_test_matrix = np.eye(1, 1) * x_test
+    det_val = np.linalg.det(x_test_matrix)
+    rewritten_val = f_rewritten(x_test)
+    assert_allclose(
+        det_val,
+        rewritten_val,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_det_diag_incorrect_for_rectangle_eye():
+    x = pt.matrix("x")
+    x_diag = pt.eye(7, 5) * x
+    with pytest.raises(ValueError, match="Determinant not defined"):
+        pt.linalg.det(x_diag)
+
+
 def test_svd_uv_merge():
     a = matrix("a")
     s_1 = svd(a, full_matrices=False, compute_uv=False)
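For reviewers who want to see the new rewrite fire outside the test suite, here is a minimal sketch (not part of the diff) that mirrors `test_det_diag_from_eye_mul`. It assumes the default `FAST_RUN` mode, which includes the canonicalize/stabilize passes the rewrite is registered under; the exact rewritten graph may vary by backend.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import Det

# det(eye(7) * x) with a vector x is the determinant of a diagonal matrix,
# so the rewrite should replace the Det node with a product over the
# diagonal entries (here: prod(x)).
x = pt.tensor("x", shape=(7,))
y = pt.linalg.det(pt.eye(7) * x)

f = pytensor.function([x], y, mode="FAST_RUN")

# No Det node should remain in the compiled graph once the rewrite applies.
assert not any(
    isinstance(node.op, Det) for node in f.maker.fgraph.apply_nodes
)

# Numerical sanity check against NumPy.
x_test = np.random.rand(7).astype(pytensor.config.floatX)
np.testing.assert_allclose(
    f(x_test), np.linalg.det(np.eye(7) * x_test), rtol=1e-3
)
```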