From 2b6fa678a2925e709ca6ef823cb269ec4b05c546 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Mon, 12 Dec 2022 11:55:42 -0800 Subject: [PATCH 1/2] Add complex number support to `linalg.solve` --- spec/API_specification/array_api/linalg.py | 23 ++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index b3595e1fa..9ae950f39 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -373,23 +373,34 @@ def slogdet(x: array, /) -> Tuple[array, array]: """ def solve(x1: array, x2: array, /) -> array: - """ - Returns the solution to the system of linear equations represented by the well-determined (i.e., full rank) linear matrix equation ``AX = B``. + r""" + Returns the solution of a square system of linear equations with a unique solution. + + Let ``x1`` equal :math:`A` and ``x2`` equal :math:`B`. If the promoted data type of ``x1`` and ``x2`` is real-valued, let :math:`\mathbb{K}` be the set of real numbers :math:`\mathbb{R}`, and, if the promoted data type of ``x1`` and ``x2`` is complex-valued, let :math:`\mathbb{K}` be the set of complex numbers :math:`\mathbb{C}`. + + This function computes the solution :math:`X \in\ \mathbb{K}^{m \times k}` of the **linear system** associated to :math:`A \in\ \mathbb{K}^{m \times m}` and :math:`B \in\ \mathbb{K}^{m \times k}` and is defined as + + .. math:: + AX = B + + This system of linear equations has a unique solution if and only if :math:`A` is invertible. .. note:: - Whether an array library explicitly checks whether an input array is full rank is implementation-defined. + Whether an array library explicitly checks whether ``x1`` is invertible is implementation-defined. + + When ``x`` is a stack of matrices, the function must compute a solution for each matrix in the stack. Parameters ---------- x1: array - coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a real-valued floating-point data type. + coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a floating-point data type. x2: array - ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. + ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a floating-point data type. Returns ------- out: array - an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a real-valued floating-point data type determined by :ref:`type-promotion`. + an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`. """ def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]: From dfa0dc6435fed06b5505022c35ea72ed868e2c71 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Mon, 12 Dec 2022 11:58:16 -0800 Subject: [PATCH 2/2] Update copy --- spec/API_specification/array_api/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index 9ae950f39..ecea0f4e7 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -388,7 +388,7 @@ def solve(x1: array, x2: array, /) -> array: .. note:: Whether an array library explicitly checks whether ``x1`` is invertible is implementation-defined. - When ``x`` is a stack of matrices, the function must compute a solution for each matrix in the stack. + When ``x1`` and/or ``x2`` is a stack of matrices, the function must compute a solution for each matrix in the stack. Parameters ----------