Skip to content

Commit a1d7edc

Browse files
authored
Add complex number support to linalg.slogdet (#567)
1 parent dff1997 commit a1d7edc

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

spec/API_specification/array_api/linalg.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -386,31 +386,51 @@ def qr(x: array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> Tupl
386386
"""
387387

388388
def slogdet(x: array, /) -> Tuple[array, array]:
389-
"""
389+
r"""
390390
Returns the sign and the natural logarithm of the absolute value of the determinant of a square matrix (or a stack of square matrices) ``x``.
391391
392392
.. note::
393393
The purpose of this function is to calculate the determinant more accurately when the determinant is either very small or very large, as calling ``det`` may overflow or underflow.
394394
395+
The sign of the determinant is given by
396+
397+
.. math::
398+
\operatorname{sign}(\det x) = \begin{cases}
399+
0 & \textrm{if } \det x = 0 \\
400+
\frac{\det x}{|\det x|} & \textrm{otherwise}
401+
\end{cases}
402+
403+
where :math:`|\det x|` is the absolute value of the determinant of ``x``.
404+
405+
When ``x`` is a stack of matrices, the function must compute the sign and natural logarithm of the absolute value of the determinant for each matrix in the stack.
406+
407+
**Special Cases**
408+
409+
For real-valued floating-point operands,
410+
411+
- If the determinant is zero, the ``sign`` should be ``0`` and ``logabsdet`` should be ``-infinity``.
412+
413+
For complex floating-point operands,
414+
415+
- If the determinant is ``0 + 0j``, the ``sign`` should be ``0 + 0j`` and ``logabsdet`` should be ``-infinity + 0j``.
416+
417+
.. note::
418+
Depending on the underlying algorithm, when the determinant is zero, the returned result may differ from ``-infinity`` (or ``-infinity + 0j``). In all cases, the determinant should be equal to ``sign * exp(logabsdet)`` (although, again, the result may be subject to numerical precision errors).
419+
395420
Parameters
396421
----------
397422
x: array
398-
input array having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Should have a real-valued floating-point data type.
423+
input array having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Should have a floating-point data type.
399424
400425
Returns
401426
-------
402427
out: Tuple[array, array]
403428
a namedtuple (``sign``, ``logabsdet``) whose
404429
405-
- first element must have the field name ``sign`` and must be an array containing a number representing the sign of the determinant for each square matrix.
406-
- second element must have the field name ``logabsdet`` and must be an array containing the determinant for each square matrix.
407-
408-
For a real matrix, the sign of the determinant must be either ``1``, ``0``, or ``-1``.
430+
- first element must have the field name ``sign`` and must be an array containing a number representing the sign of the determinant for each square matrix. Must have the same data type as ``x``.
431+
- second element must have the field name ``logabsdet`` and must be an array containing the natural logarithm of the absolute value of the determinant for each square matrix. If ``x`` is real-valued, the returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. If ``x`` is complex, the returned array must have a real-valued floating-point data type having the same precision as ``x`` (e.g., if ``x`` is ``complex64``, ``logabsdet`` must have a ``float32`` data type).
409432
410-
Each returned array must have shape ``shape(x)[:-2]`` and a real-valued floating-point data type determined by :ref:`type-promotion`.
411-
412-
.. note::
413-
If a determinant is zero, then the corresponding ``sign`` should be ``0`` and ``logabsdet`` should be ``-infinity``; however, depending on the underlying algorithm, the returned result may differ. In all cases, the determinant should be equal to ``sign * exp(logsabsdet)`` (although, again, the result may be subject to numerical precision errors).
433+
Each returned array must have shape ``shape(x)[:-2]``.
414434
"""
415435

416436
def solve(x1: array, x2: array, /) -> array:

0 commit comments

Comments
 (0)