Skip to content

Commit 548103c

Browse files
authored
Merge pull request #10 from lucascolley/cov
2 parents be06b63 + 2edc2ce commit 548103c

File tree

4 files changed

+154
-5
lines changed

4 files changed

+154
-5
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
:toctree: generated
88
99
atleast_nd
10+
cov
1011
expand_dims
1112
kron
1213
```

src/array_api_extra/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, expand_dims, kron
3+
from ._funcs import atleast_nd, cov, expand_dims, kron
44

55
__version__ = "0.1.2.dev0"
66

7-
__all__ = ["__version__", "atleast_nd", "expand_dims", "kron"]
7+
__all__ = ["__version__", "atleast_nd", "cov", "expand_dims", "kron"]

src/array_api_extra/_funcs.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3+
import warnings
34
from typing import TYPE_CHECKING
45

56
if TYPE_CHECKING:
67
from ._typing import Array, ModuleType
78

8-
__all__ = ["atleast_nd", "expand_dims", "kron"]
9+
__all__ = ["atleast_nd", "cov", "expand_dims", "kron"]
910

1011

1112
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
@@ -48,6 +49,117 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
4849
return x
4950

5051

52+
def cov(m: Array, /, *, xp: ModuleType) -> Array:
53+
"""
54+
Estimate a covariance matrix.
55+
56+
Covariance indicates the level to which two variables vary together.
57+
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
58+
then the covariance matrix element :math:`C_{ij}` is the covariance of
59+
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
60+
of :math:`x_i`.
61+
62+
This provides a subset of the functionality of ``numpy.cov``.
63+
64+
Parameters
65+
----------
66+
m : array
67+
A 1-D or 2-D array containing multiple variables and observations.
68+
Each row of `m` represents a variable, and each column a single
69+
observation of all those variables.
70+
xp : array_namespace
71+
The standard-compatible namespace for `m`.
72+
73+
Returns
74+
-------
75+
res : array
76+
The covariance matrix of the variables.
77+
78+
Examples
79+
--------
80+
>>> import array_api_strict as xp
81+
>>> import array_api_extra as xpx
82+
83+
Consider two variables, :math:`x_0` and :math:`x_1`, which
84+
correlate perfectly, but in opposite directions:
85+
86+
>>> x = xp.asarray([[0, 2], [1, 1], [2, 0]]).T
87+
>>> x
88+
Array([[0, 1, 2],
89+
[2, 1, 0]], dtype=array_api_strict.int64)
90+
91+
Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance
92+
matrix shows this clearly:
93+
94+
>>> xpx.cov(x, xp=xp)
95+
Array([[ 1., -1.],
96+
[-1., 1.]], dtype=array_api_strict.float64)
97+
98+
99+
Note that element :math:`C_{0,1}`, which shows the correlation between
100+
:math:`x_0` and :math:`x_1`, is negative.
101+
102+
Further, note how `x` and `y` are combined:
103+
104+
>>> x = xp.asarray([-2.1, -1, 4.3])
105+
>>> y = xp.asarray([3, 1.1, 0.12])
106+
>>> X = xp.stack((x, y), axis=0)
107+
>>> xpx.cov(X, xp=xp)
108+
Array([[11.71 , -4.286 ],
109+
[-4.286 , 2.14413333]], dtype=array_api_strict.float64)
110+
111+
>>> xpx.cov(x, xp=xp)
112+
Array(11.71, dtype=array_api_strict.float64)
113+
114+
>>> xpx.cov(y, xp=xp)
115+
Array(2.14413333, dtype=array_api_strict.float64)
116+
117+
"""
118+
m = xp.asarray(m, copy=True)
119+
dtype = (
120+
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
121+
)
122+
123+
m = atleast_nd(m, ndim=2, xp=xp)
124+
m = xp.astype(m, dtype)
125+
126+
avg = _mean(m, axis=1, xp=xp)
127+
fact = m.shape[1] - 1
128+
129+
if fact <= 0:
130+
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
131+
fact = 0.0
132+
133+
m -= avg[:, None]
134+
m_transpose = m.T
135+
if xp.isdtype(m_transpose.dtype, "complex floating"):
136+
m_transpose = xp.conj(m_transpose)
137+
c = m @ m_transpose
138+
c /= fact
139+
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
140+
return xp.squeeze(c, axis=axes)
141+
142+
143+
def _mean(
144+
x: Array,
145+
/,
146+
*,
147+
axis: int | tuple[int, ...] | None = None,
148+
keepdims: bool = False,
149+
xp: ModuleType,
150+
) -> Array:
151+
"""
152+
Complex mean, https://github.com/data-apis/array-api/issues/846.
153+
"""
154+
if xp.isdtype(x.dtype, "complex floating"):
155+
x_real = xp.real(x)
156+
x_imag = xp.imag(x)
157+
mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
158+
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
159+
return mean_real + (mean_imag * xp.asarray(1j))
160+
return xp.mean(x, axis=axis, keepdims=keepdims)
161+
162+
51163
def expand_dims(
52164
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
53165
) -> Array:

tests/test_funcs.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from __future__ import annotations
22

33
import contextlib
4+
import warnings
45
from typing import TYPE_CHECKING, Any
56

67
# array-api-strict#6
78
import array_api_strict as xp # type: ignore[import-untyped]
89
import pytest
9-
from numpy.testing import assert_array_equal, assert_equal
10+
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
1011

11-
from array_api_extra import atleast_nd, expand_dims, kron
12+
from array_api_extra import atleast_nd, cov, expand_dims, kron
1213

1314
if TYPE_CHECKING:
1415
Array = Any # To be changed to a Protocol later (see array-api#589)
@@ -76,6 +77,41 @@ def test_5D(self):
7677
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))
7778

7879

80+
class TestCov:
81+
def test_basic(self):
82+
assert_allclose(
83+
cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T, xp=xp),
84+
xp.asarray([[1.0, -1.0], [-1.0, 1.0]]),
85+
)
86+
87+
def test_complex(self):
88+
x = xp.asarray([[1, 2, 3], [1j, 2j, 3j]])
89+
res = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]])
90+
assert_allclose(cov(x, xp=xp), res)
91+
92+
def test_empty(self):
93+
with warnings.catch_warnings(record=True):
94+
warnings.simplefilter("always", RuntimeWarning)
95+
assert_array_equal(cov(xp.asarray([]), xp=xp), xp.nan)
96+
assert_array_equal(
97+
cov(xp.reshape(xp.asarray([]), (0, 2)), xp=xp),
98+
xp.reshape(xp.asarray([]), (0, 0)),
99+
)
100+
assert_array_equal(
101+
cov(xp.reshape(xp.asarray([]), (2, 0)), xp=xp),
102+
xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]]),
103+
)
104+
105+
def test_combination(self):
106+
x = xp.asarray([-2.1, -1, 4.3])
107+
y = xp.asarray([3, 1.1, 0.12])
108+
X = xp.stack((x, y), axis=0)
109+
desired = xp.asarray([[11.71, -4.286], [-4.286, 2.144133]])
110+
assert_allclose(cov(X, xp=xp), desired, rtol=1e-6)
111+
assert_allclose(cov(x, xp=xp), xp.asarray(11.71))
112+
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)
113+
114+
79115
class TestKron:
80116
def test_basic(self):
81117
# Using 0-dimensional array

0 commit comments

Comments
 (0)