From 8e875d774bf0e9b9acf3bcbb658efbb783678e9e Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 02:31:16 +0100
Subject: [PATCH 01/16] Copy `block_diag` and support functions from
 `pymc.math`

---
 pytensor/tensor/basic.py     |  19 ++++
 pytensor/tensor/slinalg.py   | 112 ++++++++++++++++++++++++++++++++++-
 tests/tensor/test_slinalg.py |  12 ++++
 3 files changed, 142 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 207fd4909a..4b043a6471 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -4269,6 +4269,25 @@ def take_along_axis(arr, indices, axis=0):
     return arr[_make_along_axis_idx(arr.shape, indices, axis)]
 
 
+def ix_(*args):
+    """
+    PyTensor np.ix_ analog
+
+    See numpy.lib.index_tricks.ix_ for reference
+    """
+    out = []
+    nd = len(args)
+    for k, new in enumerate(args):
+        if new is None:
+            out.append(slice(None))
+        new = as_tensor(new)
+        if new.ndim != 1:
+            raise ValueError("Cross index must be 1 dimensional")
+        new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
+        out.append(new)
+    return tuple(out)
+
+
 __all__ = [
     "take_along_axis",
     "expand_dims",
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index f96dec5a35..7f724c0363 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1,3 +1,4 @@
+import functools as ft
 import logging
 import typing
 import warnings
@@ -23,7 +24,6 @@
 if TYPE_CHECKING:
     from pytensor.tensor import TensorLike
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -908,6 +908,115 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
     )
 
 
+def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
+    return ft.reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
+
+
+class BlockDiagonalMatrix(Op):
+    __props__ = ("sparse", "format")
+
+    def __init__(self, sparse=False, format="csr"):
+        if format not in ("csr", "csc"):
+            raise ValueError(f"format must be one of: 'csr', 'csc', got {format}")
+        self.sparse = sparse
+        self.format = format
+
+    def make_node(self, *matrices):
+        if not matrices:
+            raise ValueError("no matrices to allocate")
+        dtype = largest_common_dtype(matrices)
+        matrices = list(map(pt.as_tensor, matrices))
+
+        if any(mat.type.ndim != 2 for mat in matrices):
+            raise TypeError("all data arguments must be matrices")
+        if self.sparse:
+            out_type = pytensor.sparse.matrix(self.format, dtype=dtype)
+        else:
+            out_type = pytensor.tensor.matrix(dtype=dtype)
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        dtype = largest_common_dtype(inputs)
+        if self.sparse:
+            output_storage[0][0] = scipy.sparse.block_diag(inputs, self.format, dtype)
+        else:
+            output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
+
+    def grad(self, inputs, gout):
+        shapes = pt.stack([i.shape for i in inputs])
+        index_end = shapes.cumsum(0)
+        index_begin = index_end - shapes
+        slices = [
+            ptb.ix_(
+                pt.arange(index_begin[i, 0], index_end[i, 0]),
+                pt.arange(index_begin[i, 1], index_end[i, 1]),
+            )
+            for i in range(len(inputs))
+        ]
+        return [gout[0][slc] for slc in slices]
+
+    def infer_shape(self, fgraph, nodes, shapes):
+        first, second = zip(*shapes)
+        return [(pt.add(*first), pt.add(*second))]
+
+
+def block_diagonal(
+    matrices: typing.Sequence[TensorVariable],
+    sparse: bool = False,
+    format: Literal["csr", "csc"] = "csr",
+):
+    """
+    Construct a block diagonal matrix from a sequence of input matrices.
+
+    Parameters
+    ----------
+    matrices: sequence of tensors
+        Input matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions, and the
+        block diagonal matrix will be formed along the first axis of the matrices.
+    sparse : bool, optional
+        If True, the function returns a sparse matrix in the specified format. Default is False.
+    format: str, optional
+        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False.
+
+    Returns
+    -------
+    out: tensor or sparse matrix tensor
+        The block diagonal matrix formed from the input matrices. If `sparse` is True, the output will be a symbolic
+        sparse matrix in the specified format. Otherwise, a symbolic tensor will be returned.
+
+    Examples
+    --------
+    Create a block diagonal matrix from two 2x2 matrices:
+
+    .. code-block:: python
+
+        import numpy as np
+        from pytensor.tensor.slinalg import block_diagonal
+
+        matrices = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])]
+        matrices = [pt.as_tensor_variable(mat) for mat in matrices]
+        result = block_diagonal(matrices)
+
+        print(result)
+        >>> Out: array([[1, 2, 0, 0],
+        >>>             [3, 4, 0, 0],
+        >>>             [0, 0, 5, 6],
+        >>>             [0, 0, 7, 8]])
+
+    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
+
+    .. code-block:: python
+
+        matrices_sparse = [csr_matrix([[1, 2], [3, 4]]), csr_matrix([[5, 6], [7, 8]])]
+        result_sparse = block_diagonal(matrices_sparse, sparse=True)
+
+    The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
+    """
+    if len(matrices) == 1:  # graph optimization
+        return matrices[0]
+    return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices)
+
+
 __all__ = [
     "cholesky",
     "solve",
@@ -918,4 +1027,5 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
     "solve_continuous_lyapunov",
     "solve_discrete_are",
     "solve_triangular",
+    "block_diagonal",
 ]
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 504d848140..e8510eb751 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -15,6 +15,7 @@
     Solve,
     SolveBase,
     SolveTriangular,
+    block_diagonal,
     cho_solve,
     cholesky,
     eigvalsh,
@@ -661,3 +662,14 @@ def test_solve_discrete_are_grad():
         rng=rng,
         abs_tol=atol,
     )
+
+
+def test_block_diagonal():
+    matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
+    result = block_diagonal(matrices)
+    np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(*matrices))
+
+    result = block_diagonal(matrices, format="csr", sparse=True)
+    sp_result = scipy.sparse.block_diag(matrices, format="csr")
+    assert type(result.eval()) == type(sp_result)
+    np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())
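
Note on the `grad` method introduced above: it recovers each block's cotangent by building `ix_`-style open-grid indices into the output gradient, which is also why patch 01 copies an `ix_` helper into `pytensor.tensor.basic`. A minimal NumPy sketch of the same bookkeeping, with made-up values, purely for intuition:

    import numpy as np

    A = np.ones((2, 2))
    B = np.ones((3, 3))
    g = np.arange(25.0).reshape(5, 5)   # stand-in for dL/d(block_diag(A, B))

    shapes = np.array([A.shape, B.shape])
    index_end = shapes.cumsum(0)        # row/col end index of each block
    index_begin = index_end - shapes    # row/col start index of each block

    grads = [
        g[np.ix_(np.arange(b[0], e[0]), np.arange(b[1], e[1]))]
        for b, e in zip(index_begin, index_end)
    ]
    assert grads[0].shape == A.shape and grads[1].shape == B.shape

This mirrors the symbolic `pt.stack` / `cumsum` / `ptb.ix_` logic in the Op's `grad`.
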
From 77b733da5306430288ac4a8cd057f3fe09dbf850 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
Date: Sat, 6 Jan 2024 12:19:22 +0100
Subject: [PATCH 02/16] Evaluate output in sphinx code example

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
---
 pytensor/tensor/slinalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 7f724c0363..4ef01592cb 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -997,7 +997,7 @@ def block_diagonal(
         matrices = [pt.as_tensor_variable(mat) for mat in matrices]
         result = block_diagonal(matrices)
 
-        print(result)
+        print(result.eval())
         >>> Out: array([[1, 2, 0, 0],
         >>>             [3, 4, 0, 0],
         >>>             [0, 0, 5, 6],
         >>>             [0, 0, 7, 8]])

From a1dba8e054ae33791aae8b9ae0d6333e1ded5fee Mon Sep 17 00:00:00 2001
From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
Date: Sat, 6 Jan 2024 12:19:49 +0100
Subject: [PATCH 03/16] Test type equivalence with `isinstance` instead of
 `==`

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
---
 tests/tensor/test_slinalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index e8510eb751..658c39d787 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -671,5 +671,5 @@ def test_block_diagonal():
 
     result = block_diagonal(matrices, format="csr", sparse=True)
     sp_result = scipy.sparse.block_diag(matrices, format="csr")
-    assert type(result.eval()) == type(sp_result)
+    assert isinstance(result.eval()), type(sp_result))
     np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())

From d809f1ce6981f870cf0daeac0f80cf192fac664e Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 12:35:14 +0100
Subject: [PATCH 04/16] Typo in test function

---
 tests/tensor/test_slinalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 658c39d787..e8a5686a78 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -671,5 +671,5 @@ def test_block_diagonal():
 
     result = block_diagonal(matrices, format="csr", sparse=True)
     sp_result = scipy.sparse.block_diag(matrices, format="csr")
-    assert isinstance(result.eval()), type(sp_result))
+    assert isinstance(result.eval(), type(sp_result))
    np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())
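
The `isinstance` change in patches 02-04 is not just style: `type(x) == type(y)` fails when one object's type is a subclass of the other's, which matters for scipy sparse classes. A toy illustration, with classes invented for the example:

    class Base:
        pass

    class Sub(Base):
        pass

    a, b = Sub(), Base()

    assert isinstance(a, type(b))       # passes: Sub is a subclass of Base
    assert not type(a) == type(b)       # strict type equality fails
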
From fd26b746c5078c54ea24bcbe7367e2834716fa8b Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 19:12:14 +0100
Subject: [PATCH 05/16] Split `block_diag` into sparse and dense version

Closely follow scipy function signature for `block_diag`
---
 pytensor/sparse/basic.py     |  87 ++++++++++++++++++++++++---
 pytensor/tensor/slinalg.py   | 110 +++++++++++++++--------------------
 tests/sparse/test_basic.py   |  13 +++++
 tests/tensor/test_slinalg.py |   9 +--
 4 files changed, 141 insertions(+), 78 deletions(-)

diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index 363400416f..0abe9b1f7a 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -7,6 +7,7 @@
 TODO: Automatic methods for determining best sparse format?
 
 """
+from typing import Literal
 from warnings import warn
 
 import numpy as np
@@ -47,6 +48,7 @@
     trunc,
 )
 from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.slinalg import BaseBlockDiagonal, largest_common_dtype
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@
 
 sparse_formats = ["csc", "csr"]
 
-
 """
 Types of sparse matrices to use for testing.
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
 
 as_sparse = as_sparse_variable
 
-
 as_sparse_or_tensor_variable = as_symbolic
 
 
@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
         return r
 
     def __str__(self):
-        return f"{self.__class__.__name__ }{{axis={self.axis}}}"
+        return f"{self.__class__.__name__}{{axis={self.axis}}}"
 
 
 def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):
 
 greater_equal_s_d = GreaterEqualSD()
 
-
 eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
 
-
 neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
 
-
 lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
 
-
 gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
 
-
 le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
 
 ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
         l = []
         if self.inplace:
             l.append("inplace")
-        return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
+        return f"{self.__class__.__name__}{{{', '.join(l)}}}"
 
     def make_node(self, x):
         """
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
     # Simplify code by splitting into DotSS and DotSD.
 
     __props__ = ()
+
     # The grad_preserves_dense attribute doesn't change the
     # execution behavior. To let the optimizer merge nodes with
     # different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,76 @@ def grad(self, inputs, grads):
 
 
 construct_sparse_from_list = ConstructSparseFromList()
+
+
+class SparseBlockDiagonalMatrix(BaseBlockDiagonal):
+    def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None):
+        if not matrices:
+            raise ValueError("no matrices to allocate")
+        dtype = largest_common_dtype(matrices)
+        matrices = list(map(pytensor.tensor.as_tensor, matrices))
+
+        if any(mat.type.ndim != 2 for mat in matrices):
+            raise TypeError("all data arguments must be matrices")
+
+        out_type = matrix(format=format, dtype=dtype, name=name)
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        format = node.outputs[0].type.format
+        dtype = largest_common_dtype(inputs)
+        output_storage[0][0] = scipy.sparse.block_diag(inputs, format=format).astype(
+            dtype
+        )
+
+
+_sparse_block_diagonal = SparseBlockDiagonalMatrix()
+
+
+def block_diag(
+    *matrices: TensorVariable, format: Literal["csc", "csr"] = "csc", name=None
+):
+    r"""
+    Construct a block diagonal matrix from a sequence of input matrices.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
+
+    Parameters
+    ----------
+    A, B, C ... : tensors
+        Input sparse matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions,
+        and the block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
+    format: str, optional
+        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
+    name: str, optional
+        Name of the output tensor.
+
+    Returns
+    -------
+    out: sparse matrix tensor
+        Symbolic sparse matrix in the specified format.
+
+    Examples
+    --------
+    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
+
+    .. code-block:: python
+
+        import numpy as np
+        from pytensor.sparse import block_diag
+        from scipy.sparse import csr_matrix
+
+        A = csr_matrix([[1, 2], [3, 4]])
+        B = csr_matrix([[5, 6], [7, 8]])
+        result_sparse = block_diag(A, B, format='csr', name='X')
+        print(result_sparse.eval())
+
+    The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
+    """
+    if len(matrices) == 1:
+        return matrices
+
+    return _sparse_block_diagonal(*matrices, format=format, name=name)
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 4ef01592cb..9379d01ea9 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -912,16 +912,31 @@ def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
     return ft.reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
 
 
-class BlockDiagonalMatrix(Op):
-    __props__ = ("sparse", "format")
+def block_diag_grad(inputs, gout):
+    shapes = pt.stack([i.shape for i in inputs])
+    index_end = shapes.cumsum(0)
+    index_begin = index_end - shapes
+    slices = [
+        ptb.ix_(
+            pt.arange(index_begin[i, 0], index_end[i, 0]),
+            pt.arange(index_begin[i, 1], index_end[i, 1]),
+        )
+        for i in range(len(inputs))
+    ]
+    return [gout[0][slc] for slc in slices]
+
+
+class BaseBlockDiagonal(Op):
+    def grad(self, inputs, gout):
+        return block_diag_grad(inputs, gout)
+
+    def infer_shape(self, fgraph, nodes, shapes):
+        first, second = zip(*shapes)
+        return [(pt.add(*first), pt.add(*second))]
 
-    def __init__(self, sparse=False, format="csr"):
-        if format not in ("csr", "csc"):
-            raise ValueError(f"format must be one of: 'csr', 'csc', got {format}")
-        self.sparse = sparse
-        self.format = format
 
-    def make_node(self, *matrices):
+class BlockDiagonalMatrix(BaseBlockDiagonal):
+    def make_node(self, *matrices, name=None):
         if not matrices:
             raise ValueError("no matrices to allocate")
         dtype = largest_common_dtype(matrices)
@@ -929,60 +944,40 @@ def make_node(self, *matrices):
 
         if any(mat.type.ndim != 2 for mat in matrices):
             raise TypeError("all data arguments must be matrices")
-        if self.sparse:
-            out_type = pytensor.sparse.matrix(self.format, dtype=dtype)
-        else:
-            out_type = pytensor.tensor.matrix(dtype=dtype)
+
+        out_type = pytensor.tensor.matrix(dtype=dtype, name=name)
         return Apply(self, matrices, [out_type])
 
     def perform(self, node, inputs, output_storage, params=None):
         dtype = largest_common_dtype(inputs)
-        if self.sparse:
-            output_storage[0][0] = scipy.sparse.block_diag(inputs, self.format, dtype)
-        else:
-            output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
+        output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
 
-    def grad(self, inputs, gout):
-        shapes = pt.stack([i.shape for i in inputs])
-        index_end = shapes.cumsum(0)
-        index_begin = index_end - shapes
-        slices = [
-            ptb.ix_(
-                pt.arange(index_begin[i, 0], index_end[i, 0]),
-                pt.arange(index_begin[i, 1], index_end[i, 1]),
-            )
-            for i in range(len(inputs))
-        ]
-        return [gout[0][slc] for slc in slices]
 
-    def infer_shape(self, fgraph, nodes, shapes):
-        first, second = zip(*shapes)
-        return [(pt.add(*first), pt.add(*second))]
+_block_diagonal_matrix = BlockDiagonalMatrix()
 
 
-def block_diagonal(
-    matrices: typing.Sequence[TensorVariable],
-    sparse: bool = False,
-    format: Literal["csr", "csc"] = "csr",
-):
+def block_diag(*matrices: TensorVariable, name=None):
     """
-    Construct a block diagonal matrix from a sequence of input matrices.
+    Construct a block diagonal matrix from a sequence of input tensors.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
 
     Parameters
     ----------
-    matrices: sequence of tensors
+    A, B, C ... : tensors
         Input matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions, and the
-        block diagonal matrix will be formed along the first axis of the matrices.
-    sparse : bool, optional
-        If True, the function returns a sparse matrix in the specified format. Default is False.
-    format: str, optional
-        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False.
+        block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
+    name: str, optional
+        Name of the block diagonal matrix.
 
     Returns
     -------
-    out: tensor or sparse matrix tensor
-        The block diagonal matrix formed from the input matrices. If `sparse` is True, the output will be a symbolic
-        sparse matrix in the specified format. Otherwise, a symbolic tensor will be returned.
+    out: tensor
+        The block diagonal matrix formed from the input matrices.
 
     Examples
     --------
@@ -991,30 +986,21 @@ def block_diagonal(
     .. code-block:: python
 
         import numpy as np
-        from pytensor.tensor.slinalg import block_diagonal
+        from pytensor.tensor.slinalg import block_diag
 
-        matrices = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])]
-        matrices = [pt.as_tensor_variable(mat) for mat in matrices]
-        result = block_diagonal(matrices)
+        A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
+        B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))
+        result = block_diag(A, B, name='X')
 
         print(result.eval())
         >>> Out: array([[1, 2, 0, 0],
         >>>             [3, 4, 0, 0],
        >>>             [0, 0, 5, 6],
         >>>             [0, 0, 7, 8]])
-
-    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
-
-    .. code-block:: python
-
-        matrices_sparse = [csr_matrix([[1, 2], [3, 4]]), csr_matrix([[5, 6], [7, 8]])]
-        result_sparse = block_diagonal(matrices_sparse, sparse=True)
-
-    The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
""" if len(matrices) == 1: # graph optimization - return matrices[0] - return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices) + return matrices + return _block_diagonal_matrix(*matrices, name=name) __all__ = [ @@ -1027,5 +1013,5 @@ def block_diagonal( "solve_continuous_lyapunov", "solve_discrete_are", "solve_triangular", - "block_diagonal", + "block_diag", ] diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 16fd5fef04..0b4f25987c 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -51,6 +51,7 @@ add_s_s_data, as_sparse_or_tensor_variable, as_sparse_variable, + block_diag, cast, clean, construct_sparse_from_list, @@ -3389,3 +3390,15 @@ def _helper(x, y): ) class TestSharedOptions: pass + + +@pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"]) +def test_block_diagonal(format): + from scipy.sparse import block_diag as scipy_block_diag + + matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])] + result = block_diag(*matrices, format=format, name="X") + sp_result = scipy_block_diag(matrices, format=format) + + assert isinstance(result.eval(), type(sp_result)) + np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray()) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index e8a5686a78..674ecf541f 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -15,7 +15,7 @@ Solve, SolveBase, SolveTriangular, - block_diagonal, + block_diag, cho_solve, cholesky, eigvalsh, @@ -666,10 +666,5 @@ def test_solve_discrete_are_grad(): def test_block_diagonal(): matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])] - result = block_diagonal(matrices) + result = block_diag(*matrices) np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(*matrices)) - - result = block_diagonal(matrices, format="csr", sparse=True) - sp_result = scipy.sparse.block_diag(matrices, format="csr") - assert isinstance(result.eval(), type(sp_result)) - np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray()) From 62bab9d2e6ab26e1cb50514924588dde0b111b38 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 6 Jan 2024 19:26:28 +0100 Subject: [PATCH 06/16] Use `as_sparse_or_tensor_variable` in `SparseBlockDiagonalMatrix` to allow sparse matrix inputs to `pytensor.sparse.block_diag` --- pytensor/sparse/basic.py | 21 +++++++++++++++------ tests/sparse/test_basic.py | 9 +++++---- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 0abe9b1f7a..68eeec8096 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -4263,7 +4263,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None) if not matrices: raise ValueError("no matrices to allocate") dtype = largest_common_dtype(matrices) - matrices = list(map(pytensor.tensor.as_tensor, matrices)) + matrices = list(map(as_sparse_or_tensor_variable, matrices)) if any(mat.type.ndim != 2 for mat in matrices): raise TypeError("all data arguments must be matrices") @@ -4273,7 +4273,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None) def perform(self, node, inputs, output_storage, params=None): format = node.outputs[0].type.format - dtype = largest_common_dtype(inputs) + dtype = node.outputs[0].type.dtype output_storage[0][0] = scipy.sparse.block_diag(inputs, format=format).astype( dtype ) @@ -4296,9 +4296,12 @@ def block_diag( Parameters 
From 62bab9d2e6ab26e1cb50514924588dde0b111b38 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 19:26:28 +0100
Subject: [PATCH 06/16] Use `as_sparse_or_tensor_variable` in
 `SparseBlockDiagonalMatrix` to allow sparse matrix inputs to
 `pytensor.sparse.block_diag`

---
 pytensor/sparse/basic.py   | 21 +++++++++++++-------
 tests/sparse/test_basic.py |  9 +++++----
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index 0abe9b1f7a..68eeec8096 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -4263,7 +4263,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None):
         if not matrices:
             raise ValueError("no matrices to allocate")
         dtype = largest_common_dtype(matrices)
-        matrices = list(map(pytensor.tensor.as_tensor, matrices))
+        matrices = list(map(as_sparse_or_tensor_variable, matrices))
 
         if any(mat.type.ndim != 2 for mat in matrices):
             raise TypeError("all data arguments must be matrices")
@@ -4273,7 +4273,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None):
 
     def perform(self, node, inputs, output_storage, params=None):
         format = node.outputs[0].type.format
-        dtype = largest_common_dtype(inputs)
+        dtype = node.outputs[0].type.dtype
         output_storage[0][0] = scipy.sparse.block_diag(inputs, format=format).astype(
             dtype
         )
@@ -4296,9 +4296,12 @@ def block_diag(
 
     Parameters
     ----------
-    A, B, C ... : tensors
-        Input sparse matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions,
+    A, B, C ... : tensors or array-like
+        Inputs to form the block diagonal matrix. Each input should have the same number of dimensions,
         and the block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
+
+        Note that the input matrices need not be sparse themselves, and will be automatically converted to the
+        requested format if they are not.
     format: str, optional
         The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
     name: str, optional
@@ -4321,9 +4324,15 @@ def block_diag(
         A = csr_matrix([[1, 2], [3, 4]])
         B = csr_matrix([[5, 6], [7, 8]])
         result_sparse = block_diag(A, B, format='csr', name='X')
-        print(result_sparse.eval())
 
-    The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
+        print(result_sparse)
+        >>> SparseVariable{csr,int32}
+
+        print(result_sparse.toarray().eval())
+        >>> array([[1, 2, 0, 0],
+        >>>        [3, 4, 0, 0],
+        >>>        [0, 0, 5, 6],
+        >>>        [0, 0, 7, 8]])
     """
     if len(matrices) == 1:
         return matrices
diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py
index 0b4f25987c..cb94b328ef 100644
--- a/tests/sparse/test_basic.py
+++ b/tests/sparse/test_basic.py
@@ -3394,11 +3394,12 @@ class TestSharedOptions:
 
 @pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"])
 def test_block_diagonal(format):
-    from scipy.sparse import block_diag as scipy_block_diag
+    from scipy import sparse as sp_sparse
 
-    matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
-    result = block_diag(*matrices, format=format, name="X")
-    sp_result = scipy_block_diag(matrices, format=format)
+    A = sp_sparse.csr_matrix([[1, 2], [3, 4]])
+    B = sp_sparse.csr_matrix([[5, 6], [7, 8]])
+    result = block_diag(A, B, format=format, name="X")
+    sp_result = sp_sparse.block_diag([A, B], format=format)
 
     assert isinstance(result.eval(), type(sp_result))
     np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())

From 1de23db16f96f2963bf3fbf8b403332afb8f1e43 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 19:29:38 +0100
Subject: [PATCH 07/16] Test sparse and dense inputs to
 `pytensor.sparse.block_diag`

---
 tests/sparse/test_basic.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py
index cb94b328ef..2a5bce9ecb 100644
--- a/tests/sparse/test_basic.py
+++ b/tests/sparse/test_basic.py
@@ -3393,11 +3393,14 @@ class TestSharedOptions:
 
 @pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"])
-def test_block_diagonal(format):
+@pytest.mark.parametrize("sparse_input", [True, False], ids=["sparse", "dense"])
+def test_block_diagonal(format, sparse_input):
     from scipy import sparse as sp_sparse
 
-    A = sp_sparse.csr_matrix([[1, 2], [3, 4]])
-    B = sp_sparse.csr_matrix([[5, 6], [7, 8]])
+    f_array = sp_sparse.csr_matrix if sparse_input else np.array
+    A = f_array([[1, 2], [3, 4]]).astype(config.floatX)
+    B = f_array([[5, 6], [7, 8]]).astype(config.floatX)
+
     result = block_diag(A, B, format=format, name="X")
     sp_result = sp_sparse.block_diag([A, B], format=format)
 
     assert isinstance(result.eval(), type(sp_result))
From 382c50b1cc1f96ae19eebe4c4f7ff19a6a5776bf Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 20:29:55 +0100
Subject: [PATCH 08/16] Add numba overload for
 `pytensor.tensor.slinalg.block_diag`

---
 pytensor/link/numba/dispatch/slinalg.py | 24 +++++++++++++++++++-
 pytensor/tensor/slinalg.py              |  2 +-
 tests/link/numba/test_slinalg.py        | 29 ++++++++++++++++++++++++-
 tests/tensor/test_slinalg.py            |  7 +++---
 4 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index ad8065defd..14bae08bb6 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -9,7 +9,7 @@
 
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import numba_funcify
-from pytensor.tensor.slinalg import SolveTriangular
+from pytensor.tensor.slinalg import BlockDiagonalMatrix, SolveTriangular
 
 
 _PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def solve_triangular(a, b):
         return res
 
     return solve_triangular
+
+
+@numba_funcify.register(BlockDiagonalMatrix)
+def numba_funcify_BlockDiagonalMatrix(op, node, **kwargs):
+    dtype = node.outputs[0].dtype
+
+    # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
+    @numba_basic.numba_njit(inline="never")
+    def block_diag(*arrs):
+        shapes = np.array([a.shape for a in arrs], dtype=dtype)
+        out_shape = [int(s) for s in np.sum(shapes, axis=0)]
+        out = np.zeros((out_shape[0], out_shape[1]))
+
+        r, c = 0, 0
+        for arr, shape in zip(arrs, shapes):
+            rr, cc = shape
+            out[r : r + rr, c : c + cc] = arr
+            r += rr
+            c += cc
+        return out
+
+    return block_diag
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 9379d01ea9..a55d15789b 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -949,7 +949,7 @@ def make_node(self, *matrices, name=None):
         return Apply(self, matrices, [out_type])
 
     def perform(self, node, inputs, output_storage, params=None):
-        dtype = largest_common_dtype(inputs)
+        dtype = node.outputs[0].type.dtype
         output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
 
 
diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py
index 75e016f1e0..d3757a1454 100644
--- a/tests/link/numba/test_slinalg.py
+++ b/tests/link/numba/test_slinalg.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 import pytest
+from scipy import linalg
 
 import pytensor
 import pytensor.tensor as pt
@@ -10,7 +11,6 @@
 
 numba = pytest.importorskip("numba")
 
-
 ATOL = 0 if config.floatX.endswith("64") else 1e-6
 RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6
 rng = np.random.default_rng(42849)
@@ -102,3 +102,30 @@ def test_solve_triangular_raises_on_nan_inf(value):
         ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
     ):
         f(A_tri, b)
+
+
+def test_block_diag():
+    A = pt.matrix("A")
+    B = pt.matrix("B")
+    C = pt.matrix("C")
+    D = pt.matrix("D")
+    X = pt.linalg.block_diag(A, B, C, D)
+    f = pytensor.function([A, B, C, D], X, mode="NUMBA")
+
+    A_val = np.random.normal(size=(5, 5))
+    B_val = np.random.normal(size=(3, 3))
+    C_val = np.random.normal(size=(2, 2))
+    D_val = np.random.normal(size=(4, 4))
+
+    X_val = f(A_val, B_val, C_val, D_val)
+    np.testing.assert_allclose(
+        np.block([[A_val, np.zeros((5, 3))], [np.zeros((3, 5)), B_val]]), X_val[:8, :8]
+    )
+    np.testing.assert_allclose(
+        np.block([[C_val, np.zeros((2, 4))], [np.zeros((4, 2)), D_val]]), X_val[8:, 8:]
+    )
+    np.testing.assert_allclose(np.zeros((8, 6)), X_val[:8, 8:])
+    np.testing.assert_allclose(np.zeros((6, 8)), X_val[8:, :8])
+
+    X_sp = linalg.block_diag(A_val, B_val, C_val, D_val)
+    np.testing.assert_allclose(X_val, X_sp)
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 674ecf541f..7fc1dadd8f 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -665,6 +665,7 @@ def test_solve_discrete_are_grad():
 
 
 def test_block_diagonal():
-    matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
-    result = block_diag(*matrices)
-    np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(*matrices))
+    A = np.array([[1.0, 2.0], [3.0, 4.0]])
+    B = np.array([[5.0, 6.0], [7.0, 8.0]])
+    result = block_diag(A, B, name="X")
+    np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
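
The numba overload above builds the output with an explicit write loop rather than calling scipy. The same algorithm in plain NumPy, as a reference sketch of what the jitted kernel computes (not the dispatch code itself):

    import numpy as np

    def block_diag_ref(*arrs):
        rows = sum(a.shape[0] for a in arrs)
        cols = sum(a.shape[1] for a in arrs)
        out = np.zeros((rows, cols), dtype=np.result_type(*arrs))

        r = c = 0
        for a in arrs:
            rr, cc = a.shape
            out[r : r + rr, c : c + cc] = a   # paste each block on the diagonal
            r += rr
            c += cc
        return out

    assert np.allclose(
        block_diag_ref(np.eye(2), 3 * np.eye(3)),
        np.diag([1.0, 1.0, 3.0, 3.0, 3.0]),
    )

Note that the `shapes = np.array([a.shape for a in arrs], dtype=dtype)` line in the patch reuses the output dtype for an index array; patch 12 below corrects it to an integer dtype.
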
From 491111b6a79f2b7452dba4999e9de48943ac2b11 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sat, 6 Jan 2024 21:02:50 +0100
Subject: [PATCH 09/16] add jax overload for
 `pytensor.tensor.slinalg.block_diag`

---
 pytensor/link/jax/dispatch/slinalg.py | 15 ++++++++++++++-
 tests/link/jax/test_slinalg.py        | 20 ++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py
index 4481e442f9..4d2df2d27a 100644
--- a/pytensor/link/jax/dispatch/slinalg.py
+++ b/pytensor/link/jax/dispatch/slinalg.py
@@ -1,7 +1,12 @@
 import jax
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
+from pytensor.tensor.slinalg import (
+    BlockDiagonalMatrix,
+    Cholesky,
+    Solve,
+    SolveTriangular,
+)
 
 
 @jax_funcify.register(Cholesky)
@@ -45,3 +50,11 @@ def solve_triangular(A, b):
         )
 
     return solve_triangular
+
+
+@jax_funcify.register(BlockDiagonalMatrix)
+def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
+    def block_diag(*inputs):
+        return jax.scipy.linalg.block_diag(*inputs)
+
+    return block_diag
diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py
index 4ae9531f9b..169ebef2a8 100644
--- a/tests/link/jax/test_slinalg.py
+++ b/tests/link/jax/test_slinalg.py
@@ -129,3 +129,23 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
             np.arange(10).astype(config.floatX),
         ],
     )
+
+
+def test_jax_block_diag():
+    A = matrix("A")
+    B = matrix("B")
+    C = matrix("C")
+    D = matrix("D")
+
+    out = pt_slinalg.block_diag(A, B, C, D)
+    out_fg = FunctionGraph([A, B, C, D], [out])
+    compare_jax_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5)).astype(config.floatX),
+            np.random.normal(size=(3, 3)).astype(config.floatX),
+            np.random.normal(size=(2, 2)).astype(config.floatX),
+            np.random.normal(size=(4, 4)).astype(config.floatX),
+        ],
+    )
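
The JAX overload can defer entirely to `jax.scipy.linalg.block_diag`, which already implements the dense semantics. A quick check of that building block (requires jax installed; values illustrative):

    import numpy as np
    from jax.scipy.linalg import block_diag as jax_block_diag

    A = np.array([[1.0, 2.0], [3.0, 4.0]])
    B = np.array([[5.0]])

    X = jax_block_diag(A, B)
    assert X.shape == (3, 3) and float(X[2, 2]) == 5.0
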
From bb2bd363b01a862ca48e232013a8f4a74fe3470b Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 22:39:46 +0100
Subject: [PATCH 10/16] Move stand-alone `block_diag_grad` function into
 `grad` method

---
 pytensor/tensor/slinalg.py | 35 +++++++++++++++++------------------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index a55d15789b..9783b81718 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -912,23 +912,19 @@ def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
     return ft.reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
 
 
-def block_diag_grad(inputs, gout):
-    shapes = pt.stack([i.shape for i in inputs])
-    index_end = shapes.cumsum(0)
-    index_begin = index_end - shapes
-    slices = [
-        ptb.ix_(
-            pt.arange(index_begin[i, 0], index_end[i, 0]),
-            pt.arange(index_begin[i, 1], index_end[i, 1]),
-        )
-        for i in range(len(inputs))
-    ]
-    return [gout[0][slc] for slc in slices]
-
-
 class BaseBlockDiagonal(Op):
     def grad(self, inputs, gout):
-        return block_diag_grad(inputs, gout)
+        shapes = pt.stack([i.shape for i in inputs])
+        index_end = shapes.cumsum(0)
+        index_begin = index_end - shapes
+        slices = [
+            ptb.ix_(
+                pt.arange(index_begin[i, 0], index_end[i, 0]),
+                pt.arange(index_begin[i, 1], index_end[i, 1]),
+            )
+            for i in range(len(inputs))
+        ]
+        return [gout[0][slc] for slc in slices]
 
     def infer_shape(self, fgraph, nodes, shapes):
         first, second = zip(*shapes)
@@ -936,6 +932,10 @@ def infer_shape(self, fgraph, nodes, shapes):
 
 
 class BlockDiagonalMatrix(BaseBlockDiagonal):
+    def __init__(self, n_inputs):
+        input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
+        self.gufunc_signature = f"{input_sig}->(m,n)"
+
     def make_node(self, *matrices, name=None):
         if not matrices:
             raise ValueError("no matrices to allocate")
@@ -953,9 +953,6 @@ def perform(self, node, inputs, output_storage, params=None):
         output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
 
 
-_block_diagonal_matrix = BlockDiagonalMatrix()
-
-
 def block_diag(*matrices: TensorVariable, name=None):
     """
     Construct a block diagonal matrix from a sequence of input tensors.
@@ -997,6 +994,8 @@ def block_diag(*matrices: TensorVariable, name=None):
     """
     if len(matrices) == 1:  # graph optimization
         return matrices
+
+    _block_diagonal_matrix = BlockDiagonalMatrix(n_inputs=len(matrices))
     return _block_diagonal_matrix(*matrices, name=name)

From 26bf96dd87a27e8a80199121cbf23e93f81b86eb Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 22:43:37 +0100
Subject: [PATCH 11/16] Add `format` prop to `SparseBlockDiagonalMatrix`

---
 pytensor/sparse/basic.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index 68eeec8096..1cd36c0b8c 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -4259,7 +4259,12 @@ def grad(self, inputs, grads):
 
 
 class SparseBlockDiagonalMatrix(BaseBlockDiagonal):
-    def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None):
+    __props__ = ("format",)
+
+    def __init__(self, format: Literal["csc", "csr"] = "csc"):
+        self.format = format
+
+    def make_node(self, *matrices, name=None):
         if not matrices:
             raise ValueError("no matrices to allocate")
         dtype = largest_common_dtype(matrices)
@@ -4268,18 +4273,14 @@ def make_node(self, *matrices, name=None):
         if any(mat.type.ndim != 2 for mat in matrices):
             raise TypeError("all data arguments must be matrices")
 
-        out_type = matrix(format=format, dtype=dtype, name=name)
+        out_type = matrix(format=self.format, dtype=dtype, name=name)
         return Apply(self, matrices, [out_type])
 
     def perform(self, node, inputs, output_storage, params=None):
-        format = node.outputs[0].type.format
         dtype = node.outputs[0].type.dtype
-        output_storage[0][0] = scipy.sparse.block_diag(inputs, format=format).astype(
-            dtype
-        )
-
-
-_sparse_block_diagonal = SparseBlockDiagonalMatrix()
+        output_storage[0][0] = scipy.sparse.block_diag(
+            inputs, format=self.format
+        ).astype(dtype)
 
 
 def block_diag(
@@ -4337,4 +4338,5 @@ def block_diag(
     if len(matrices) == 1:
         return matrices
 
-    return _sparse_block_diagonal(*matrices, format=format, name=name)
+    _sparse_block_diagonal = SparseBlockDiagonalMatrix(format=format)
+    return _sparse_block_diagonal(*matrices, name=name)
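
Moving `format` into `__props__` matters because PyTensor derives `Op.__eq__` and `Op.__hash__` from `__props__`, and ops that compare equal can be merged by the rewriter, so a csc-producing op must not compare equal to a csr-producing one. A toy model of the rule, with the equality written out by hand (not PyTensor's actual machinery):

    class ToyOp:
        __props__ = ("format",)

        def __init__(self, format):
            self.format = format

        def __eq__(self, other):
            return type(self) is type(other) and self.format == other.format

        def __hash__(self):
            return hash((type(self), self.format))

    assert ToyOp("csr") == ToyOp("csr")   # identical props: mergeable
    assert ToyOp("csr") != ToyOp("csc")   # different props: kept distinct
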
From dd70db92f5caa7b371551466ab60d780e28ee5df Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 22:57:40 +0100
Subject: [PATCH 12/16] Use `compare_numba_and_py` in
 `numba/test_slinalg.py::test_block_diag`

---
 pytensor/link/numba/dispatch/slinalg.py |  4 ++--
 tests/link/numba/test_slinalg.py        | 18 +++---------------
 2 files changed, 5 insertions(+), 17 deletions(-)

diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index 14bae08bb6..79d1558ac3 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -282,9 +282,9 @@ def numba_funcify_BlockDiagonalMatrix(op, node, **kwargs):
     # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
     @numba_basic.numba_njit(inline="never")
     def block_diag(*arrs):
-        shapes = np.array([a.shape for a in arrs], dtype=dtype)
+        shapes = np.array([a.shape for a in arrs], dtype="int")
         out_shape = [int(s) for s in np.sum(shapes, axis=0)]
-        out = np.zeros((out_shape[0], out_shape[1]))
+        out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)
 
         r, c = 0, 0
         for arr, shape in zip(arrs, shapes):
diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py
index d3757a1454..33ec1a529c 100644
--- a/tests/link/numba/test_slinalg.py
+++ b/tests/link/numba/test_slinalg.py
@@ -2,11 +2,11 @@
 
 import numpy as np
 import pytest
-from scipy import linalg
 
 import pytensor
 import pytensor.tensor as pt
 from pytensor import config
+from tests.link.numba.test_basic import compare_numba_and_py
 
 numba = pytest.importorskip("numba")
 
@@ -110,22 +110,10 @@ def test_block_diag():
     C = pt.matrix("C")
     D = pt.matrix("D")
     X = pt.linalg.block_diag(A, B, C, D)
-    f = pytensor.function([A, B, C, D], X, mode="NUMBA")
 
     A_val = np.random.normal(size=(5, 5))
     B_val = np.random.normal(size=(3, 3))
     C_val = np.random.normal(size=(2, 2))
     D_val = np.random.normal(size=(4, 4))
-
-    X_val = f(A_val, B_val, C_val, D_val)
-    np.testing.assert_allclose(
-        np.block([[A_val, np.zeros((5, 3))], [np.zeros((3, 5)), B_val]]), X_val[:8, :8]
-    )
-    np.testing.assert_allclose(
-        np.block([[C_val, np.zeros((2, 4))], [np.zeros((4, 2)), D_val]]), X_val[8:, 8:]
-    )
-    np.testing.assert_allclose(np.zeros((8, 6)), X_val[:8, 8:])
-    np.testing.assert_allclose(np.zeros((6, 8)), X_val[8:, :8])
-
-    X_sp = linalg.block_diag(A_val, B_val, C_val, D_val)
-    np.testing.assert_allclose(X_val, X_sp)
+    out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X])
+    compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val])
From d32bf9f89d27c9ab951c5e3d48e9f5c76933eb31 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 23:47:42 +0100
Subject: [PATCH 13/16] Add support for Blockwise to `slinalg.block_diag`

---
 pytensor/sparse/basic.py       | 13 ++++++++-----
 pytensor/tensor/slinalg.py     | 18 ++++++++++--------
 tests/link/jax/test_slinalg.py | 16 +++++++++++++++-
 tests/tensor/test_slinalg.py   | 17 ++++++++++++++++-
 4 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index 1cd36c0b8c..700be281c8 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -4261,10 +4261,11 @@ def grad(self, inputs, grads):
 class SparseBlockDiagonalMatrix(BaseBlockDiagonal):
     __props__ = ("format",)
 
-    def __init__(self, format: Literal["csc", "csr"] = "csc"):
+    def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
+        super().__init__(n_inputs)
         self.format = format
 
-    def make_node(self, *matrices, name=None):
+    def make_node(self, *matrices):
         if not matrices:
             raise ValueError("no matrices to allocate")
         dtype = largest_common_dtype(matrices)
@@ -4273,7 +4274,7 @@ def make_node(self, *matrices, name=None):
         if any(mat.type.ndim != 2 for mat in matrices):
             raise TypeError("all data arguments must be matrices")
 
-        out_type = matrix(format=self.format, dtype=dtype, name=name)
+        out_type = matrix(format=self.format, dtype=dtype)
         return Apply(self, matrices, [out_type])
 
     def perform(self, node, inputs, output_storage, params=None):
@@ -4338,5 +4339,7 @@ def block_diag(
     if len(matrices) == 1:
         return matrices
 
-    _sparse_block_diagonal = SparseBlockDiagonalMatrix(format=format)
-    return _sparse_block_diagonal(*matrices, name=name)
+    _sparse_block_diagonal = SparseBlockDiagonalMatrix(
+        n_inputs=len(matrices), format=format
+    )
+    return _sparse_block_diagonal(*matrices)
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 9783b81718..d31fee58d8 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -913,6 +913,12 @@ def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
 
 
 class BaseBlockDiagonal(Op):
+    __props__ = ("gufunc_signature",)
+
+    def __init__(self, n_inputs):
+        input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
+        self.gufunc_signature = f"{input_sig}->(m,n)"
+
     def grad(self, inputs, gout):
         shapes = pt.stack([i.shape for i in inputs])
         index_end = shapes.cumsum(0)
@@ -936,11 +942,7 @@ def infer_shape(self, fgraph, nodes, shapes):
         first, second = zip(*shapes)
         return [(pt.add(*first), pt.add(*second))]
 
-
 class BlockDiagonalMatrix(BaseBlockDiagonal):
-    def __init__(self, n_inputs):
-        input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
-        self.gufunc_signature = f"{input_sig}->(m,n)"
-
     def make_node(self, *matrices, name=None):
         if not matrices:
             raise ValueError("no matrices to allocate")
@@ -945,7 +947,7 @@ def make_node(self, *matrices, name=None):
         if any(mat.type.ndim != 2 for mat in matrices):
             raise TypeError("all data arguments must be matrices")
 
-        out_type = pytensor.tensor.matrix(dtype=dtype, name=name)
+        out_type = pytensor.tensor.matrix(dtype=dtype)
         return Apply(self, matrices, [out_type])
 
@@ -998,8 +1000,8 @@ def block_diag(*matrices: TensorVariable, name=None):
     """
     if len(matrices) == 1:  # graph optimization
         return matrices
 
-    _block_diagonal_matrix = BlockDiagonalMatrix(n_inputs=len(matrices))
-    return _block_diagonal_matrix(*matrices, name=name)
+    _block_diagonal_matrix = Blockwise(BlockDiagonalMatrix(n_inputs=len(matrices)))
+    return _block_diagonal_matrix(*matrices)
diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py
index 169ebef2a8..53e154facc 100644
--- a/tests/link/jax/test_slinalg.py
+++ b/tests/link/jax/test_slinalg.py
@@ -138,8 +138,8 @@ def test_jax_block_diag():
     D = matrix("D")
 
     out = pt_slinalg.block_diag(A, B, C, D)
-    out_fg = FunctionGraph([A, B, C, D], [out])
 
+    out_fg = FunctionGraph([A, B, C, D], [out])
     compare_jax_and_py(
         out_fg,
         [
@@ -149,3 +149,17 @@ def test_jax_block_diag():
             np.random.normal(size=(4, 4)).astype(config.floatX),
         ],
     )
+
+
+def test_jax_block_diag_blockwise():
+    A = pt.tensor3("A")
+    B = pt.tensor3("B")
+    out = pt_slinalg.block_diag(A, B)
+    out_fg = FunctionGraph([A, B], [out])
+    compare_jax_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5, 5)).astype(config.floatX),
+            np.random.normal(size=(5, 3, 3)).astype(config.floatX),
+        ],
+    )
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 7fc1dadd8f..d13bb24741 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -667,5 +667,20 @@ def test_solve_discrete_are_grad():
 def test_block_diagonal():
     A = np.array([[1.0, 2.0], [3.0, 4.0]])
     B = np.array([[5.0, 6.0], [7.0, 8.0]])
-    result = block_diag(A, B, name="X")
+    result = block_diag(A, B)
     np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
+
+
+def test_block_diagonal_blockwise():
+    batch_size = 5
+    A = np.random.normal(size=(batch_size, 2, 2)).astype(config.floatX)
+    B = np.random.normal(size=(batch_size, 4, 4)).astype(config.floatX)
+    result = block_diag(A, B).eval()
+    assert result.shape == (batch_size, 6, 6)
+    for i in range(batch_size):
+        np.testing.assert_allclose(
+            result[i],
+            scipy.linalg.block_diag(A[i], B[i]),
+            atol=1e-4 if config.floatX == "float32" else 1e-8,
+            rtol=1e-4 if config.floatX == "float32" else 1e-8,
+        )
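
Wrapping the core op in `Blockwise` turns `block_diag` into a gufunc-style operation: with two inputs the signature is `(m0,n0),(m1,n1)->(m,n)`, and any leading dimensions are broadcast as batch dimensions. A usage sketch of the batched behaviour these tests exercise:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    A = pt.tensor3("A")   # batch of 2x2 blocks
    B = pt.tensor3("B")   # batch of 3x3 blocks
    X = pt.linalg.block_diag(A, B)

    f = pytensor.function([A, B], X)
    out = f(np.ones((4, 2, 2)), np.ones((4, 3, 3)))
    assert out.shape == (4, 5, 5)   # one 5x5 block-diagonal matrix per batch element
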
From 747ed1d2a4bf9e9d9d6c59380895aa0364e25ec4 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sun, 7 Jan 2024 01:10:08 +0100
Subject: [PATCH 14/16] Add gradient test

Remove `Matrix` from `BlockDiagonal` and `SparseBlockDiagonal` `Op` names

Correct errors in docstrings

Move input validation to a shared class method
---
 pytensor/link/jax/dispatch/slinalg.py   |  9 +----
 pytensor/link/numba/dispatch/slinalg.py |  6 +--
 pytensor/sparse/basic.py                | 34 +++++++----------
 pytensor/tensor/slinalg.py              | 49 +++++++++++++------------
 tests/sparse/test_basic.py              |  2 +-
 tests/tensor/test_slinalg.py            | 13 +++++++
 6 files changed, 58 insertions(+), 55 deletions(-)

diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py
index 4d2df2d27a..73ddadc2a0 100644
--- a/pytensor/link/jax/dispatch/slinalg.py
+++ b/pytensor/link/jax/dispatch/slinalg.py
@@ -1,12 +1,7 @@
 import jax
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.slinalg import (
-    BlockDiagonalMatrix,
-    Cholesky,
-    Solve,
-    SolveTriangular,
-)
+from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
 
 
 @jax_funcify.register(Cholesky)
@@ -45,7 +40,7 @@ def solve_triangular(A, b):
     return solve_triangular
 
 
-@jax_funcify.register(BlockDiagonalMatrix)
+@jax_funcify.register(BlockDiagonal)
 def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
     def block_diag(*inputs):
         return jax.scipy.linalg.block_diag(*inputs)
diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index 79d1558ac3..a5ac0c6348 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -9,7 +9,7 @@
 
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import numba_funcify
-from pytensor.tensor.slinalg import BlockDiagonalMatrix, SolveTriangular
+from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
 
 
 _PTR = ctypes.POINTER
@@ -275,8 +275,8 @@ def solve_triangular(a, b):
     return solve_triangular
 
 
-@numba_funcify.register(BlockDiagonalMatrix)
-def numba_funcify_BlockDiagonalMatrix(op, node, **kwargs):
+@numba_funcify.register(BlockDiagonal)
+def numba_funcify_BlockDiagonal(op, node, **kwargs):
     dtype = node.outputs[0].dtype
 
     # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index 700be281c8..94cbdd22f4 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -48,7 +48,7 @@
     trunc,
 )
 from pytensor.tensor.shape import shape, specify_broadcastable
-from pytensor.tensor.slinalg import BaseBlockDiagonal, largest_common_dtype
+from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -4258,7 +4258,7 @@ def grad(self, inputs, grads):
 construct_sparse_from_list = ConstructSparseFromList()
 
 
-class SparseBlockDiagonalMatrix(BaseBlockDiagonal):
+class SparseBlockDiagonal(BaseBlockDiagonal):
     __props__ = ("format",)
 
     def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
@@ -4266,15 +4266,12 @@ def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
         self.format = format
 
     def make_node(self, *matrices):
-        if not matrices:
-            raise ValueError("no matrices to allocate")
-        dtype = largest_common_dtype(matrices)
-        matrices = list(map(as_sparse_or_tensor_variable, matrices))
-
-        if any(mat.type.ndim != 2 for mat in matrices):
-            raise TypeError("all data arguments must be matrices")
-
+        matrices = self._validate_and_prepare_inputs(
+            matrices, as_sparse_or_tensor_variable
+        )
+        dtype = _largest_common_dtype(matrices)
         out_type = matrix(format=self.format, dtype=dtype)
+
         return Apply(self, matrices, [out_type])
 
     def perform(self, node, inputs, output_storage, params=None):
@@ -4284,9 +4281,7 @@ def perform(self, node, inputs, output_storage, params=None):
         ).astype(dtype)
 
 
-def block_diag(
-    *matrices: TensorVariable, format: Literal["csc", "csr"] = "csc", name=None
-):
+def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
     r"""
     Construct a block diagonal matrix from a sequence of input matrices.
 
@@ -4298,16 +4293,15 @@ def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
     Parameters
     ----------
-    A, B, C ... : tensors or array-like
-        Inputs to form the block diagonal matrix. Each input should have the same number of dimensions,
-        and the block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
+    A, B, C ... : tensors
+        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
+        inputs should have at least 2 dimensions.
 
         Note that the input matrices need not be sparse themselves, and will be automatically converted to the
         requested format if they are not.
+
     format: str, optional
-        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
-    name: str, optional
-        Name of the output tensor.
+        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
 
     Returns
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index d31fee58d8..bfbc523f0d 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1,7 +1,7 @@
-import functools as ft
 import logging
 import typing
 import warnings
+from functools import reduce
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
@@ -908,17 +908,21 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
     )
 
 
-def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
-    return ft.reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
+def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
+    return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
 
 
 class BaseBlockDiagonal(Op):
-    __props__ = ("gufunc_signature",)
+    __props__ = ("gufunc_signature", "n_inputs")
 
     def __init__(self, n_inputs):
         input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
         self.gufunc_signature = f"{input_sig}->(m,n)"
 
+        if n_inputs == 0:
+            raise ValueError("n_inputs must be greater than 0")
+        self.n_inputs = n_inputs
+
     def grad(self, inputs, gout):
         shapes = pt.stack([i.shape for i in inputs])
         index_end = shapes.cumsum(0)
@@ -936,17 +940,21 @@ def infer_shape(self, fgraph, nodes, shapes):
         first, second = zip(*shapes)
         return [(pt.add(*first), pt.add(*second))]
 
-
-class BlockDiagonalMatrix(BaseBlockDiagonal):
-    def make_node(self, *matrices, name=None):
-        if not matrices:
-            raise ValueError("no matrices to allocate")
-        dtype = largest_common_dtype(matrices)
-        matrices = list(map(pt.as_tensor, matrices))
-
+    def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
+        if len(matrices) != self.n_inputs:
+            raise ValueError(
+                f"Expected {self.n_inputs} matri{'ces' if self.n_inputs > 1 else 'x'}, got {len(matrices)}"
+            )
+        matrices = list(map(as_tensor_func, matrices))
         if any(mat.type.ndim != 2 for mat in matrices):
-            raise TypeError("all data arguments must be matrices")
+            raise TypeError("All inputs must have dimension 2")
+        return matrices
+
 
+class BlockDiagonal(BaseBlockDiagonal):
+    def make_node(self, *matrices):
+        matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
+        dtype = _largest_common_dtype(matrices)
         out_type = pytensor.tensor.matrix(dtype=dtype)
         return Apply(self, matrices, [out_type])
 
@@ -955,7 +963,7 @@ def perform(self, node, inputs, output_storage, params=None):
         output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
 
 
-def block_diag(*matrices: TensorVariable, name=None):
+def block_diag(*matrices: TensorVariable):
     """
     Construct a block diagonal matrix from a sequence of input tensors.
 
@@ -968,10 +976,8 @@ def block_diag(*matrices: TensorVariable, name=None):
     Parameters
     ----------
     A, B, C ... : tensors
-        Input matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions, and the
-        block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
-    name: str, optional
-        Name of the block diagonal matrix.
+        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
+        inputs should have at least 2 dimensions.
 
     Returns
@@ -985,7 +991,7 @@ def block_diag(*matrices: TensorVariable, name=None):
     .. code-block:: python
 
         import numpy as np
-        from pytensor.tensor.slinalg import block_diag
+        from pytensor.tensor.linalg import block_diag
 
         A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
         B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))
         result = block_diag(A, B, name='X')
 
         print(result.eval())
         >>> Out: array([[1, 2, 0, 0],
         >>>             [3, 4, 0, 0],
         >>>             [0, 0, 5, 6],
         >>>             [0, 0, 7, 8]])
     """
-    if len(matrices) == 1:  # graph optimization
-        return matrices
-
-    _block_diagonal_matrix = Blockwise(BlockDiagonalMatrix(n_inputs=len(matrices)))
+    _block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
     return _block_diagonal_matrix(*matrices)
diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py
index 2a5bce9ecb..faea6d0bf9 100644
--- a/tests/sparse/test_basic.py
+++ b/tests/sparse/test_basic.py
@@ -3401,7 +3401,7 @@ def test_block_diagonal(format, sparse_input):
     A = f_array([[1, 2], [3, 4]]).astype(config.floatX)
     B = f_array([[5, 6], [7, 8]]).astype(config.floatX)
 
-    result = block_diag(A, B, format=format, name="X")
+    result = block_diag(A, B, format=format)
     sp_result = sp_sparse.block_diag([A, B], format=format)
 
     assert isinstance(result.eval(), type(sp_result))
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index d13bb24741..232a991c7d 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -671,6 +671,13 @@ def test_block_diagonal():
     np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
 
 
+def test_block_diagonal_grad():
+    A = np.array([[1.0, 2.0], [3.0, 4.0]])
+    B = np.array([[5.0, 6.0], [7.0, 8.0]])
+
+    utt.verify_grad(block_diag, pt=[A, B], rng=np.random.default_rng())
+
+
 def test_block_diagonal_blockwise():
     batch_size = 5
     A = np.random.normal(size=(batch_size, 2, 2)).astype(config.floatX)
@@ -684,3 +691,9 @@ def test_block_diagonal_blockwise():
             atol=1e-4 if config.floatX == "float32" else 1e-8,
             rtol=1e-4 if config.floatX == "float32" else 1e-8,
         )
+
+    # Test broadcasting
+    A = np.random.normal(size=(10, batch_size, 2, 2)).astype(config.floatX)
+    B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
+    result = block_diag(A, B).eval()
+    assert result.shape == (10, batch_size, 6, 6)

From 2daca2b44b87b75e8f2f20550e44c118b98c1849 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
Date: Sun, 7 Jan 2024 01:24:11 +0100
Subject: [PATCH 15/16] Remove `gufunc_signature` from `__props__`

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
---
 pytensor/tensor/slinalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index bfbc523f0d..8c8490ab63 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -913,7 +913,7 @@ def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
 
 
 class BaseBlockDiagonal(Op):
-    __props__ = ("gufunc_signature", "n_inputs")
+    __props__ = ("n_inputs",)
 
     def __init__(self, n_inputs):
         input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
From a9893b820d782ec67e7662702fb13365cbef2bbf Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sun, 7 Jan 2024 01:31:58 +0100
Subject: [PATCH 16/16] Implement correct `__props__` for subclasses of
 `BaseBlockMatrix`

---
 pytensor/sparse/basic.py     | 5 ++++-
 pytensor/tensor/slinalg.py   | 2 ++
 tests/sparse/test_basic.py   | 2 ++
 tests/tensor/test_slinalg.py | 2 ++
 4 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index 94cbdd22f4..96105adc5c 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -4259,7 +4259,10 @@ def grad(self, inputs, grads):
 
 
 class SparseBlockDiagonal(BaseBlockDiagonal):
-    __props__ = ("format",)
+    __props__ = (
+        "n_inputs",
+        "format",
+    )
 
     def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
         super().__init__(n_inputs)
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 8c8490ab63..aae80fb578 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -952,6 +952,8 @@ def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
 
 
 class BlockDiagonal(BaseBlockDiagonal):
+    __props__ = ("n_inputs",)
+
     def make_node(self, *matrices):
         matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
         dtype = _largest_common_dtype(matrices)
diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py
index faea6d0bf9..590c76e008 100644
--- a/tests/sparse/test_basic.py
+++ b/tests/sparse/test_basic.py
@@ -3402,6 +3402,8 @@ def test_block_diagonal(format, sparse_input):
     B = f_array([[5, 6], [7, 8]]).astype(config.floatX)
 
     result = block_diag(A, B, format=format)
+    assert result.owner.op._props_dict() == {"n_inputs": 2, "format": format}
+
     sp_result = sp_sparse.block_diag([A, B], format=format)
 
     assert isinstance(result.eval(), type(sp_result))
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 232a991c7d..a2cc3c52e8 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -668,6 +668,8 @@ def test_block_diagonal():
     A = np.array([[1.0, 2.0], [3.0, 4.0]])
     B = np.array([[5.0, 6.0], [7.0, 8.0]])
     result = block_diag(A, B)
+    assert result.owner.op.core_op._props_dict() == {"n_inputs": 2}
+
     np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
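
Taken together, the series leaves a dense `BlockDiagonal` core op, wrapped in `Blockwise` and exposed as `block_diag`, plus a `SparseBlockDiagonal` counterpart; both are differentiable. An end-to-end sketch of the dense API in its final form (values illustrative):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    A = pt.matrix("A")
    B = pt.matrix("B")
    X = pt.linalg.block_diag(A, B)

    # The grad defined on BaseBlockDiagonal slices each block's cotangent
    # back out of the output gradient.
    gA = pytensor.grad(X.sum(), A)
    f = pytensor.function([A, B], [X, gA])

    x_val, gA_val = f(np.eye(2), np.eye(3))
    assert x_val.shape == (5, 5)
    assert np.allclose(gA_val, np.ones((2, 2)))   # d(sum)/dA is all ones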