diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index cd202fe3ed..da16acf1a4 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -34,6 +34,7 @@
     MatrixPinv,
     SLogDet,
     det,
+    eig,
     inv,
     kron,
     pinv,
@@ -1013,3 +1014,103 @@ def slogdet_specialization(fgraph, node):
         k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
     }
     return replacements
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([eig])
+def rewrite_eig_eye(fgraph, node):
+    """
+    This rewrite uses the fact that for an identity matrix all the eigenvalues are 1 and the eigenvectors are the standard basis.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check whether the input to Eig is Eye and the 1's are on the main diagonal
+    potential_eye = node.inputs[0]
+    if not (
+        potential_eye.owner
+        and isinstance(potential_eye.owner.op, Eye)
+        and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0
+    ):
+        return None
+
+    eigval_rewritten = pt.ones(potential_eye.shape[-1])
+    eigvec_rewritten = pt.eye(potential_eye.shape[-1])
+
+    return [eigval_rewritten, eigvec_rewritten]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([eig])
+def rewrite_eig_diag(fgraph, node):
+    """
+    This rewrite uses the fact that for a diagonal matrix the eigenvalues are the diagonal elements and the eigenvectors are the standard basis.
+
+    The diagonal matrix is detected by inspecting the graph, either as a direct use of pt.diag on a vector or as an
+    elementwise multiplication with an identity matrix. In the multiplication case, the eigenvalues are extracted
+    differently depending on whether the other operand is a scalar, a vector or a matrix, so that no unnecessary
+    computation is introduced.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    inputs = node.inputs[0]
+
+    # Check for use of pt.diag first
+    if (
+        inputs.owner
+        and isinstance(inputs.owner.op, AllocDiag)
+        and AllocDiag.is_offset_zero(inputs.owner)
+    ):
+        eigval_rewritten = inputs.owner.inputs[0]
+        eigvec_rewritten = pt.eye(inputs.shape[-1])
+        return [eigval_rewritten, eigvec_rewritten]
+
+    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
+    inputs_or_none = _find_diag_from_eye_mul(inputs)
+    if inputs_or_none is None:
+        return None
+
+    eye_input, non_eye_inputs = inputs_or_none
+
+    # Dealing with only one other input
+    if len(non_eye_inputs) != 1:
+        return None
+
+    non_eye_input = non_eye_inputs[0]
+    eigvec_rewritten = eye_input
+
+    # Checking if original x was scalar/vector/matrix
+    if non_eye_input.type.broadcastable[-2:] == (True, True):
+        # For scalar
+        eigval_rewritten = pt.full(
+            (eye_input.shape[0],), non_eye_input.squeeze(axis=(-1, -2))
+        )
+    elif non_eye_input.type.broadcastable[-2:] == (False, False):
+        # For matrix
+        eigval_rewritten = pt.diag(non_eye_input)
+    else:
+        # For vector
+        eigval_rewritten = non_eye_input.squeeze()
+
+    return [eigval_rewritten, eigvec_rewritten]
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index c9b9afff19..fa9c5f84e6 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -18,6 +18,7 @@
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
+    Eig,
     KroneckerProduct,
     MatrixInverse,
     MatrixPinv,
@@ -996,3 +997,105 @@ def test_slogdet_specialization():
     f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
     nodes = f.maker.fgraph.apply_nodes
     assert not any(isinstance(node.op, SLogDet) for node in nodes)
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [(), (7,), (1, 7), (7, 1), (7, 7)],
+    ids=["scalar", "vector", "row_vec", "col_vec", "matrix"],
+)
+def test_eig_diag_from_eye_mul(shape):
+    # Initializing x based on scalar/vector/matrix
+    x = pt.tensor("x", shape=shape)
+    y = pt.eye(7) * x
+
+    # Calculating eigval and eigvec using pt.linalg.eig
+    eigval, eigvec = pt.linalg.eig(y)
+
+    # REWRITE TEST
+    f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+
+    assert not any(
+        isinstance(node.op, Eig) or isinstance(getattr(node.op, "core_op", None), Eig)
+        for node in nodes
+    )
+
+    # NUMERIC VALUE TEST
+    if len(shape) == 0:
+        x_test = np.array(np.random.rand()).astype(config.floatX)
+    else:
+        x_test = np.random.rand(*shape).astype(config.floatX)
+
+    x_test_matrix = np.eye(7) * x_test
+    eigval, eigvec = np.linalg.eig(x_test_matrix)
+    rewritten_eigval, rewritten_eigvec = f_rewritten(x_test)
+
+    assert_allclose(
+        eigval,
+        rewritten_eigval,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+    assert_allclose(
+        eigvec,
+        rewritten_eigvec,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_eig_eye():
+    x = pt.eye(10)
+    eigval, eigvec = pt.linalg.eig(x)
+
+    # REWRITE TEST
+    f_rewritten = function([], [eigval, eigvec], mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, Eig) for node in nodes)
+
+    # NUMERIC VALUE TEST
+    x_test = np.eye(10)
+    eigval, eigvec = np.linalg.eig(x_test)
+    rewritten_eigval, rewritten_eigvec = f_rewritten()
+    assert_allclose(
+        eigval,
+        rewritten_eigval,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+    assert_allclose(
+        eigvec,
+        rewritten_eigvec,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_eig_diag():
+    x = pt.tensor("x", shape=(None,))
+    x_diag = pt.diag(x)
+    eigval, eigvec = pt.linalg.eig(x_diag)
+
+    # REWRITE TEST
+    f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, Eig) for node in nodes)
+
+    # NUMERIC VALUE TEST
+    x_test = np.random.rand(7).astype(config.floatX)
+    x_test_matrix = np.eye(7) * x_test
+    eigval, eigvec = np.linalg.eig(x_test_matrix)
+    rewritten_eigval, rewritten_eigvec = f_rewritten(x_test)
+    assert_allclose(
+        eigval,
+        rewritten_eigval,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+    assert_allclose(
+        eigvec,
+        rewritten_eigvec,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
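For reference, a minimal usage sketch (not part of the patch) of how the new rewrites are expected to fire under FAST_RUN, mirroring the tests above. It only uses calls that already appear in the diff (pt.eye, pt.diag, pt.linalg.eig, pytensor.function, the Eig op); the variable names and the size 5 are illustrative.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import Eig

# eig(eye(n)): rewrite_eig_eye should replace the Eig node with ones(n) and eye(n)
eigval, eigvec = pt.linalg.eig(pt.eye(5))
f = pytensor.function([], [eigval, eigvec], mode="FAST_RUN")
assert not any(isinstance(node.op, Eig) for node in f.maker.fgraph.apply_nodes)

# eig(diag(x)): rewrite_eig_diag should return x itself as the eigenvalues and eye(n) as the eigenvectors
x = pt.vector("x")
eigval, eigvec = pt.linalg.eig(pt.diag(x))
f = pytensor.function([x], [eigval, eigvec], mode="FAST_RUN")
assert not any(isinstance(node.op, Eig) for node in f.maker.fgraph.apply_nodes)

x_test = np.random.rand(5).astype(x.dtype)
w, v = f(x_test)
np.testing.assert_allclose(w, x_test)    # eigenvalues are the diagonal entries
np.testing.assert_allclose(v, np.eye(5))  # eigenvectors are the standard basis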