diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py
index fddded525a..f732848afc 100644
--- a/pytensor/link/pytorch/dispatch/__init__.py
+++ b/pytensor/link/pytorch/dispatch/__init__.py
@@ -11,4 +11,5 @@
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
 import pytensor.link.pytorch.dispatch.subtensor
+import pytensor.link.pytorch.dispatch.slinalg
 # isort: on
diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py
new file mode 100644
index 0000000000..af3ac8efaf
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/slinalg.py
@@ -0,0 +1,88 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.slinalg import (
+    BlockDiagonal,
+    Cholesky,
+    Eigvalsh,
+    Solve,
+    SolveTriangular,
+)
+
+
+@pytorch_funcify.register(Eigvalsh)
+def pytorch_funcify_Eigvalsh(op, **kwargs):
+    if op.lower:
+        UPLO = "L"
+    else:
+        UPLO = "U"
+
+    def eigvalsh(a, b):
+        if b is not None:
+            raise NotImplementedError(
+                "torch.linalg.eigvalsh does not support generalized eigenvalue problems (b != None)"
+            )
+        return torch.linalg.eigvalsh(a, UPLO=UPLO)
+
+    return eigvalsh
+
+
+@pytorch_funcify.register(Cholesky)
+def pytorch_funcify_Cholesky(op, **kwargs):
+    upper = not op.lower
+
+    def cholesky(a):
+        return torch.linalg.cholesky(a, upper=upper)
+
+    return cholesky
+
+
+@pytorch_funcify.register(Solve)
+def pytorch_funcify_Solve(op, **kwargs):
+    def solve(a, b):
+        return torch.linalg.solve(a, b)
+
+    return solve
+
+
+@pytorch_funcify.register(SolveTriangular)
+def pytorch_funcify_SolveTriangular(op, **kwargs):
+    if op.check_finite:
+        raise NotImplementedError(
+            "Option check_finite is not implemented in torch.linalg.solve_triangular"
+        )
+
+    upper = not op.lower
+    unit_diagonal = op.unit_diagonal
+    trans = op.trans
+
+    def solve_triangular(A, b):
+        if trans in [1, "T"]:
+            A_p = A.T
+        elif trans in [2, "C"]:
+            A_p = A.conj().T
+        else:
+            A_p = A
+
+        b_p = b
+        if b.ndim == 1:
+            b_p = b[:, None]
+
+        res = torch.linalg.solve_triangular(
+            A_p, b_p, upper=upper, unitriangular=unit_diagonal
+        )
+
+        if b.ndim == 1 and res.shape[1] == 1:
+            return res.flatten()
+
+        return res
+
+    return solve_triangular
+
+
+@pytorch_funcify.register(BlockDiagonal)
+def pytorch_funcify_BlockDiagonalMatrix(op, **kwargs):
+    def block_diag(*inputs):
+        return torch.block_diag(*inputs)
+
+    return block_diag
diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py
new file mode 100644
index 0000000000..3055bf809d
--- /dev/null
+++ b/tests/link/pytorch/test_slinalg.py
@@ -0,0 +1,129 @@
+import numpy as np
+import pytest
+
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.tensor import slinalg as pt_slinalg
+from pytensor.tensor.type import matrix, vector
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+@pytest.mark.parametrize("lower", [False, True])
+def test_pytorch_eigvalsh(lower):
+    A = matrix("A")
+    B = matrix("B")
+
+    out = pt_slinalg.eigvalsh(A, B, lower=lower)
+    out_fg = FunctionGraph([A, B], [out])
+
+    with pytest.raises(NotImplementedError):
+        compare_pytorch_and_py(
+            out_fg,
+            [
+                np.array(
+                    [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
+                ).astype(config.floatX),
+                np.array(
+                    [[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]]
+                ).astype(config.floatX),
+            ],
+        )
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
+                config.floatX
+            ),
+            None,
+        ],
+    )
+
+
+def test_pytorch_cholesky():
+    rng = np.random.default_rng(28494)
+
+    x = matrix("x")
+
+    out = pt_slinalg.cholesky(x)
+    out_fg = FunctionGraph([x], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
+                config.floatX
+            )
+        ],
+    )
+
+    out = pt_slinalg.cholesky(x, lower=False)
+    out_fg = FunctionGraph([x], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
+                config.floatX
+            )
+        ],
+    )
+
+
+def test_pytorch_solve():
+    x = matrix("x")
+    b = vector("b")
+
+    out = pt_slinalg.solve(x, b)
+    out_fg = FunctionGraph([x, b], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.eye(10).astype(config.floatX),
+            np.arange(10).astype(config.floatX),
+        ],
+    )
+
+
+@pytest.mark.parametrize(
+    "check_finite",
+    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
+)
+@pytest.mark.parametrize("lower", [False, True])
+@pytest.mark.parametrize("trans", [0, 1, 2, "S"])
+def test_pytorch_SolveTriangular(trans, lower, check_finite):
+    x = matrix("x")
+    b = vector("b")
+
+    out = pt_slinalg.solve_triangular(
+        x,
+        b,
+        trans=trans,
+        lower=lower,
+        check_finite=check_finite,
+    )
+    out_fg = FunctionGraph([x, b], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.eye(10).astype(config.floatX),
+            np.arange(10).astype(config.floatX),
+        ],
+    )
+
+
+def test_pytorch_block_diag():
+    A = matrix("A")
+    B = matrix("B")
+    C = matrix("C")
+    D = matrix("D")
+
+    out = pt_slinalg.block_diag(A, B, C, D)
+    out_fg = FunctionGraph([A, B, C, D], [out])
+
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5)).astype(config.floatX),
+            np.random.normal(size=(3, 3)).astype(config.floatX),
+            np.random.normal(size=(2, 2)).astype(config.floatX),
+            np.random.normal(size=(4, 4)).astype(config.floatX),
+        ],
+    )
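Beyond compare_pytorch_and_py, the new dispatches can also be exercised end to end through a compiled PyTensor function. The following is a minimal sketch, not part of the diff: it assumes torch is installed and that passing mode="PYTORCH" to pytensor.function selects the PyTorch linker, as in the other pytorch backend tests.

# Sketch (not part of the diff): run Cholesky and Solve through the PyTorch backend.
# Assumes mode="PYTORCH" selects the pytorch linker.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor import slinalg as pt_slinalg

A = pt.matrix("A")
b = pt.vector("b")

chol = pt_slinalg.cholesky(A)  # should dispatch to torch.linalg.cholesky
x = pt_slinalg.solve(A, b)     # should dispatch to torch.linalg.solve

f = pytensor.function([A, b], [chol, x], mode="PYTORCH")

A_val = (np.eye(4) * 2.0).astype(pytensor.config.floatX)
b_val = np.arange(4).astype(pytensor.config.floatX)
chol_val, x_val = f(A_val, b_val)
print(chol_val)
print(x_val)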