|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import warnings |
3 | 4 | from typing import TYPE_CHECKING
|
4 | 5 |
|
5 | 6 | if TYPE_CHECKING:
|
6 | 7 | from ._typing import Array, ModuleType
|
7 | 8 |
|
8 |
| -__all__ = ["atleast_nd", "expand_dims", "kron"] |
| 9 | +__all__ = ["atleast_nd", "cov", "expand_dims", "kron"] |
9 | 10 |
|
10 | 11 |
|
11 | 12 | def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
|
@@ -48,6 +49,117 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
|
48 | 49 | return x
|
49 | 50 |
|
50 | 51 |
|
| 52 | +def cov(m: Array, /, *, xp: ModuleType) -> Array: |
| 53 | + """ |
| 54 | + Estimate a covariance matrix. |
| 55 | +
|
| 56 | + Covariance indicates the level to which two variables vary together. |
| 57 | + If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`, |
| 58 | + then the covariance matrix element :math:`C_{ij}` is the covariance of |
| 59 | + :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance |
| 60 | + of :math:`x_i`. |
| 61 | +
|
| 62 | + This provides a subset of the functionality of ``numpy.cov``. |
| 63 | +
|
| 64 | + Parameters |
| 65 | + ---------- |
| 66 | + m : array |
| 67 | + A 1-D or 2-D array containing multiple variables and observations. |
| 68 | + Each row of `m` represents a variable, and each column a single |
| 69 | + observation of all those variables. |
| 70 | + xp : array_namespace |
| 71 | + The standard-compatible namespace for `m`. |
| 72 | +
|
| 73 | + Returns |
| 74 | + ------- |
| 75 | + res : array |
| 76 | + The covariance matrix of the variables. |
| 77 | +
|
| 78 | + Examples |
| 79 | + -------- |
| 80 | + >>> import array_api_strict as xp |
| 81 | + >>> import array_api_extra as xpx |
| 82 | +
|
| 83 | + Consider two variables, :math:`x_0` and :math:`x_1`, which |
| 84 | + correlate perfectly, but in opposite directions: |
| 85 | +
|
| 86 | + >>> x = xp.asarray([[0, 2], [1, 1], [2, 0]]).T |
| 87 | + >>> x |
| 88 | + Array([[0, 1, 2], |
| 89 | + [2, 1, 0]], dtype=array_api_strict.int64) |
| 90 | +
|
| 91 | + Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance |
| 92 | + matrix shows this clearly: |
| 93 | +
|
| 94 | + >>> xpx.cov(x, xp=xp) |
| 95 | + Array([[ 1., -1.], |
| 96 | + [-1., 1.]], dtype=array_api_strict.float64) |
| 97 | +
|
| 98 | +
|
| 99 | + Note that element :math:`C_{0,1}`, which shows the correlation between |
| 100 | + :math:`x_0` and :math:`x_1`, is negative. |
| 101 | +
|
| 102 | + Further, note how `x` and `y` are combined: |
| 103 | +
|
| 104 | + >>> x = xp.asarray([-2.1, -1, 4.3]) |
| 105 | + >>> y = xp.asarray([3, 1.1, 0.12]) |
| 106 | + >>> X = xp.stack((x, y), axis=0) |
| 107 | + >>> xpx.cov(X, xp=xp) |
| 108 | + Array([[11.71 , -4.286 ], |
| 109 | + [-4.286 , 2.14413333]], dtype=array_api_strict.float64) |
| 110 | +
|
| 111 | + >>> xpx.cov(x, xp=xp) |
| 112 | + Array(11.71, dtype=array_api_strict.float64) |
| 113 | +
|
| 114 | + >>> xpx.cov(y, xp=xp) |
| 115 | + Array(2.14413333, dtype=array_api_strict.float64) |
| 116 | +
|
| 117 | + """ |
| 118 | + m = xp.asarray(m, copy=True) |
| 119 | + dtype = ( |
| 120 | + xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64) |
| 121 | + ) |
| 122 | + |
| 123 | + m = atleast_nd(m, ndim=2, xp=xp) |
| 124 | + m = xp.astype(m, dtype) |
| 125 | + |
| 126 | + avg = _mean(m, axis=1, xp=xp) |
| 127 | + fact = m.shape[1] - 1 |
| 128 | + |
| 129 | + if fact <= 0: |
| 130 | + warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) |
| 131 | + fact = 0.0 |
| 132 | + |
| 133 | + m -= avg[:, None] |
| 134 | + m_transpose = m.T |
| 135 | + if xp.isdtype(m_transpose.dtype, "complex floating"): |
| 136 | + m_transpose = xp.conj(m_transpose) |
| 137 | + c = m @ m_transpose |
| 138 | + c /= fact |
| 139 | + axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) |
| 140 | + return xp.squeeze(c, axis=axes) |
| 141 | + |
| 142 | + |
| 143 | +def _mean( |
| 144 | + x: Array, |
| 145 | + /, |
| 146 | + *, |
| 147 | + axis: int | tuple[int, ...] | None = None, |
| 148 | + keepdims: bool = False, |
| 149 | + xp: ModuleType, |
| 150 | +) -> Array: |
| 151 | + """ |
| 152 | + Complex mean, https://github.com/data-apis/array-api/issues/846. |
| 153 | + """ |
| 154 | + if xp.isdtype(x.dtype, "complex floating"): |
| 155 | + x_real = xp.real(x) |
| 156 | + x_imag = xp.imag(x) |
| 157 | + mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims) |
| 158 | + mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims) |
| 159 | + return mean_real + (mean_imag * xp.asarray(1j)) |
| 160 | + return xp.mean(x, axis=axis, keepdims=keepdims) |
| 161 | + |
| 162 | + |
51 | 163 | def expand_dims(
|
52 | 164 | a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
|
53 | 165 | ) -> Array:
|
|
0 commit comments