Commit b79b93b

Add initial linear algebra function specifications (#20)
* Add cross signature
* Add det specification
* Add diagonal specification
* Add inv specification
* Add norm specification
* Add outer product specification
* Add outer specification
* Add trace specification
* Add transpose
* Update index
* Fix type annotation
* Update norm behavior for multi-dimensional arrays
* Support all of NumPy's norms

  Further support for supporting all of NumPy's norms comes from pending updates to Torch (see pytorch/pytorch#42749).
* Switch order
* Split into separate tables
* Escape markup
* Escape markup
* Add matrix_transpose interface

  Interface inspired by TensorFlow (see https://www.tensorflow.org/api_docs/python/tf/linalg/matrix_transpose) and Torch (see https://pytorch.org/docs/stable/generated/torch.transpose.html). Allows transposing a stack of matrices.
* Rename parameters
* Remove matrix_transpose signature
1 parent 84a4804 commit b79b93b

3 files changed

+275
-0
lines changed

spec/API_specification/index.rst

+1
@@ -14,3 +14,4 @@ API specification
1414
out_keyword
1515
elementwise_functions
1616
statistical_functions
17+
linear_algebra_functions
@@ -0,0 +1,262 @@
1+
# Linear Algebra Functions
2+
3+
> Array API specification for linear algebra functions.
4+
5+
A conforming implementation of the array API standard must provide and support the following functions adhering to the following conventions.
6+
7+
- Positional parameters must be [positional-only](https://www.python.org/dev/peps/pep-0570/) parameters. Positional-only parameters have no externally-usable name. When a function accepting positional-only parameters is called, positional arguments are mapped to these parameters based solely on their order.
8+
- Optional parameters must be [keyword-only](https://www.python.org/dev/peps/pep-3102/) arguments.
9+
- Broadcasting semantics must follow the semantics defined in :ref:`broadcasting`.
10+
- Unless stated otherwise, functions must support the data types defined in :ref:`data-types`.
11+
- Unless stated otherwise, functions must adhere to the type promotion rules defined in :ref:`type-promotion`.
12+
- Unless stated otherwise, floating-point operations must adhere to IEEE 754-2019.
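
The calling conventions above can be illustrated with a minimal Python sketch (Python 3.8 or later). The function body is a hypothetical stub that only echoes its arguments; only the signature, which mirrors `cross` as specified in this document, is of interest.

```python
def cross(x1, x2, /, *, axis=-1):
    """Hypothetical stub: shows the calling convention, not a cross product."""
    return (x1, x2, axis)


# Positional arguments are matched purely by order; `axis` is passed by keyword.
result = cross([1, 0, 0], [0, 1, 0], axis=0)

# Naming a positional-only parameter raises TypeError.
try:
    cross(x1=[1, 0, 0], x2=[0, 1, 0])
except TypeError:
    positional_by_name_rejected = True
else:
    positional_by_name_rejected = False

# Passing the keyword-only parameter positionally also raises TypeError.
try:
    cross([1, 0, 0], [0, 1, 0], 0)
except TypeError:
    keyword_by_position_rejected = True
else:
    keyword_by_position_rejected = False
```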
13+
14+
<!-- NOTE: please keep the functions in alphabetical order -->
15+
16+
### <a name="cross" href="#cross">#</a> cross(x1, x2, /, *, axis=-1)
17+
18+
Returns the cross product of 3-element vectors. If `x1` and `x2` are multi-dimensional arrays (i.e., both have a rank greater than `1`), then the cross product of each pair of corresponding 3-element vectors is independently computed.
19+
20+
#### Parameters
21+
22+
- **x1**: _&lt;array&gt;_
23+
24+
- first input array.
25+
26+
- **x2**: _&lt;array&gt;_
27+
28+
- second input array. Must have the same shape as `x1`.
29+
30+
- **axis**: _int_
31+
32+
- the axis (dimension) of `x1` and `x2` containing the vectors for which to compute the cross product. If set to `-1`, the function computes the cross product for vectors defined by the last axis (dimension). Default: `-1`.
33+
34+
#### Returns
35+
36+
- **out**: _&lt;array&gt;_
37+
38+
- an array containing the cross products.
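
As an illustration of the behavior described above, the following plain-Python sketch computes cross products over the last axis of two `(N, 3)` inputs. Nested lists stand in for arrays, and the helper names `cross3` and `cross_last_axis` are hypothetical; this is not a conforming implementation.

```python
def cross3(u, v):
    """Cross product of two 3-element vectors given as sequences."""
    return [
        u[1] * v[2] - u[2] * v[1],
        u[2] * v[0] - u[0] * v[2],
        u[0] * v[1] - u[1] * v[0],
    ]


def cross_last_axis(x1, x2):
    """Cross product of corresponding rows of two (N, 3) nested lists."""
    return [cross3(u, v) for u, v in zip(x1, x2)]


x1 = [[1, 0, 0], [0, 1, 0]]
x2 = [[0, 1, 0], [0, 0, 1]]

# Each pair of corresponding 3-element vectors is handled independently.
out = cross_last_axis(x1, x2)
```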
39+
40+
### <a name="det" href="#det">#</a> det(x, /)
41+
42+
Returns the determinant of a square matrix (or stack of square matrices) `x`.
43+
44+
#### Parameters
45+
46+
- **x**: _&lt;array&gt;_
47+
48+
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices.
49+
50+
#### Returns
51+
52+
- **out**: _&lt;array&gt;_
53+
54+
- if `x` is a two-dimensional array, a zero-dimensional array containing the determinant; otherwise, an array with one or more dimensions containing the determinant for each square matrix.
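
The stacking behavior can be sketched in plain Python, assuming nested lists in place of arrays. The helper `det` below is a hypothetical reference routine using Laplace expansion, chosen for brevity rather than efficiency.

```python
def det(m):
    """Determinant of a square nested list via Laplace expansion along row 0."""
    n = len(m)
    if n == 1:
        return m[0][0]
    total = 0
    for j in range(n):
        # Minor: drop row 0 and column j.
        minor = [row[:j] + row[j + 1:] for row in m[1:]]
        total += (-1) ** j * m[0][j] * det(minor)
    return total


# A stack of two 2x2 matrices, i.e. shape (2, 2, 2).
stack = [
    [[1, 2], [3, 4]],
    [[2, 0], [0, 2]],
]

# One determinant per matrix in the stack, i.e. shape (2,).
dets = [det(m) for m in stack]
```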
55+
56+
### <a name="diagonal" href="#diagonal">#</a> diagonal(x, /, *, axis1=0, axis2=1, offset=0)
57+
58+
Returns the specified diagonals. If `x` has more than two dimensions, then the axes (dimensions) specified by `axis1` and `axis2` are used to determine the two-dimensional sub-arrays from which to return diagonals.
59+
60+
#### Parameters
61+
62+
- **x**: _&lt;array&gt;_
63+
64+
- input array. Must have at least `2` dimensions.
65+
66+
- **axis1**: _int_
67+
68+
- first axis (dimension) with respect to which to take diagonal. Default: `0`.
69+
70+
- **axis2**: _int_
71+
72+
- second axis (dimension) with respect to which to take diagonal. Default: `1`.
73+
74+
- **offset**: _int_
75+
76+
- offset specifying the off-diagonal relative to the main diagonal.
77+
78+
- `offset = 0`: the main diagonal.
79+
- `offset > 0`: off-diagonal above the main diagonal.
80+
- `offset < 0`: off-diagonal below the main diagonal.
81+
82+
Default: `0`.
83+
84+
#### Returns
85+
86+
- **out**: _&lt;array&gt;_
87+
88+
- if `x` is a two-dimensional array, a one-dimensional array containing the diagonal; otherwise, a multi-dimensional array containing the diagonals and whose shape is determined by removing `axis1` and `axis2` and appending a dimension equal to the size of the resulting diagonals. Must have the same data type as `x`.
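
The meaning of `offset` for a two-dimensional input can be sketched as follows. This is a plain-Python illustration over nested lists; the helper is hypothetical and ignores `axis1` and `axis2`.

```python
def diagonal(m, offset=0):
    """Diagonal of a 2-D nested list.

    offset > 0 selects a diagonal above the main diagonal,
    offset < 0 selects a diagonal below it.
    """
    rows, cols = len(m), len(m[0])
    return [m[i][i + offset] for i in range(rows) if 0 <= i + offset < cols]


m = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]

main = diagonal(m)
above = diagonal(m, offset=1)
below = diagonal(m, offset=-1)
```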
89+
90+
### <a name="inv" href="#inv">#</a> inv(x, /)
91+
92+
Computes the multiplicative inverse of a square matrix (or stack of square matrices) `x`.
93+
94+
#### Parameters
95+
96+
- **x**: _&lt;array&gt;_
97+
98+
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices.
99+
100+
#### Returns
101+
102+
- **out**: _&lt;array&gt;_
103+
104+
- an array containing the multiplicative inverses. Must have the same data type and shape as `x`.
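
A minimal sketch of inverting a stack of matrices, restricted to the 2x2 case so that the closed-form inverse can be used. Nested lists stand in for arrays, `inv2x2` is a hypothetical helper, and singular inputs are not handled.

```python
def inv2x2(m):
    """Inverse of a 2x2 matrix via the adjugate; assumes a nonzero determinant."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


# A stack of two 2x2 matrices, i.e. shape (2, 2, 2).
stack = [
    [[2.0, 0.0], [0.0, 4.0]],
    [[1.0, 2.0], [3.0, 4.0]],
]

# The result has the same shape as the input.
inverses = [inv2x2(m) for m in stack]
```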
105+
106+
### <a name="norm" href="#norm">#</a> norm(x, /, *, axis=None, keepdims=False, ord=None)
107+
108+
Computes the matrix or vector norm of `x`.
109+
110+
#### Parameters
111+
112+
- **x**: _&lt;array&gt;_
113+
114+
- input array.
115+
116+
- **axis**: _Optional\[ Union\[ int, Tuple\[ int, int ] ] ]_
117+
118+
- If an integer, `axis` specifies the axis (dimension) along which to compute vector norms.
119+
120+
If a 2-tuple, `axis` specifies the axes (dimensions) defining two-dimensional matrices for which to compute matrix norms.
121+
122+
If `None`,
123+
124+
- if `x` is one-dimensional, the function computes the vector norm.
125+
- if `x` is two-dimensional, the function computes the matrix norm.
126+
- if `x` has more than two dimensions, the function computes the vector norm over all array values (i.e., equivalent to computing the vector norm of a flattened array).
127+
128+
Negative indices must be supported. Default: `None`.
129+
130+
- **keepdims**: _bool_
131+
132+
- If `True`, the axes (dimensions) specified by `axis` must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if `False`, the axes (dimensions) specified by `axis` must not be included in the result. Default: `False`.
133+
134+
- **ord**: _Optional\[ Union\[ int, float, Literal\[ inf, -inf, 'fro', 'nuc' ] ] ]_
135+
136+
- order of the norm. The following mathematical norms must be supported:
137+
138+
| ord | matrix | vector |
139+
| ---------------- | ------------------------------- | -------------------------- |
140+
| 'fro' | 'fro' | - |
141+
| 'nuc' | 'nuc' | - |
142+
| 1 | max(sum(abs(x), axis=0)) | L1-norm (Manhattan) |
143+
| 2 | largest singular value | L2-norm (Euclidean) |
144+
| inf | max(sum(abs(x), axis=1)) | infinity norm |
145+
| (int,float >= 1) | - | p-norm |
146+
147+
The following non-mathematical "norms" must be supported:
148+
149+
| ord | matrix | vector |
150+
| ---------------- | ------------------------------- | ------------------------------ |
151+
| 0                | -                               | sum(x != 0)                    |
152+
| -1               | min(sum(abs(x), axis=0))        | 1./sum(1./abs(x))              |
153+
| -2               | smallest singular value         | 1./sqrt(sum(1./abs(x)\*\*2))   |
154+
| -inf             | min(sum(abs(x), axis=1))        | min(abs(x))                    |
155+
| (int,float < 1)  | -                               | sum(abs(x)\*\*ord)\*\*(1./ord) |
156+
157+
When `ord` is `None`, the following norms must be the default norms:
158+
159+
| ord | matrix | vector |
160+
| ---------------- | ------------------------------- | -------------------------- |
161+
| None | 'fro' | L2-norm (Euclidean) |
162+
163+
where `fro` corresponds to the **Frobenius norm**, `nuc` corresponds to the **nuclear norm**, and `-` indicates that the norm is **not** supported.
164+
165+
For matrices,
166+
167+
- if `ord=1`, the norm corresponds to the induced matrix norm where `p=1` (i.e., the maximum absolute value column sum).
168+
- if `ord=2`, the norm corresponds to the induced matrix norm where `p=2` (i.e., the largest singular value).
169+
- if `ord=inf`, the norm corresponds to the induced matrix norm where `p=inf` (i.e., the maximum absolute value row sum).
170+
171+
If `None`,
172+
173+
- if matrix (or matrices), the function computes the Frobenius norm.
174+
- if vector (or vectors), the function computes the L2-norm (Euclidean norm).
175+
176+
Default: `None`.
177+
178+
#### Returns
179+
180+
- **out**: _&lt;array&gt;_
181+
182+
- an array containing the norms. Must have the same data type as `x`. If `axis` is `None`, the output array is a zero-dimensional array containing the norm. If `axis` is an `int` and `keepdims` is `False`, the output array has a rank which is one less than the rank of `x`. If `axis` is a 2-tuple and `keepdims` is `False`, the output array has a rank which is two less than the rank of `x`.
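
Several of the formulas in the `ord` tables can be written out directly. The sketch below uses plain Python over flat and nested lists; `vector_norm` and `matrix_norm` are hypothetical helpers, and the matrix version covers only `'fro'`, `1`, and `inf` (the singular-value and nuclear norms are omitted).

```python
import math


def vector_norm(v, ord=2):
    """Vector norm of a flat list, following the vector column of the tables."""
    if ord == math.inf:
        return max(abs(a) for a in v)
    if ord == -math.inf:
        return min(abs(a) for a in v)
    if ord == 0:
        return sum(1 for a in v if a != 0)
    return sum(abs(a) ** ord for a in v) ** (1.0 / ord)


def matrix_norm(m, ord="fro"):
    """Matrix norm of a 2-D nested list for ord in {'fro', 1, inf}."""
    if ord == "fro":
        return math.sqrt(sum(abs(a) ** 2 for row in m for a in row))
    if ord == 1:
        # Maximum absolute column sum.
        return max(sum(abs(a) for a in col) for col in zip(*m))
    if ord == math.inf:
        # Maximum absolute row sum.
        return max(sum(abs(a) for a in row) for row in m)
    raise ValueError("ord not covered by this sketch")


v = [3.0, -4.0]
m = [[1.0, -2.0], [3.0, 4.0]]

l2 = vector_norm(v)             # Euclidean norm
l1 = vector_norm(v, 1)          # Manhattan norm
col_sum = matrix_norm(m, 1)
row_sum = matrix_norm(m, math.inf)
fro = matrix_norm(m)
```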
183+
184+
### <a name="outer" href="#outer">#</a> outer(x1, x2, /)
185+
186+
Computes the outer product of two vectors `x1` and `x2`.
187+
188+
#### Parameters
189+
190+
- **x1**: _&lt;array&gt;_
191+
192+
- first one-dimensional input array of size `N`.
193+
194+
- **x2**: _&lt;array&gt;_
195+
196+
- second one-dimensional input array of size `M`.
197+
198+
#### Returns
199+
200+
- **out**: _&lt;array&gt;_
201+
202+
- a two-dimensional array containing the outer product and whose shape is `(N, M)`.
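
For illustration, the outer product of a vector of size `N` with a vector of size `M` can be sketched in plain Python. Lists stand in for arrays and the helper is hypothetical.

```python
def outer(x1, x2):
    """Outer product of two flat lists: out[i][j] = x1[i] * x2[j]."""
    return [[a * b for b in x2] for a in x1]


# N = 3 and M = 2, so the result has 3 rows and 2 columns.
out = outer([1, 2, 3], [10, 20])
shape = (len(out), len(out[0]))
```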
203+
204+
### <a name="trace" href="#trace">#</a> trace(x, /, *, axis1=0, axis2=1, offset=0)
205+
206+
Returns the sum along the specified diagonals. If `x` has more than two dimensions, then the axes (dimensions) specified by `axis1` and `axis2` are used to determine the two-dimensional sub-arrays for which to compute the trace.
207+
208+
#### Parameters
209+
210+
- **x**: _&lt;array&gt;_
211+
212+
- input array. Must have at least `2` dimensions.
213+
214+
- **axis1**: _int_
215+
216+
- first axis (dimension) with respect to which to compute the trace. Default: `0`.
217+
218+
- **axis2**: _int_
219+
220+
- second axis (dimension) with respect to which to compute the trace. Default: `1`.
221+
222+
- **offset**: _int_
223+
224+
- offset specifying the off-diagonal relative to the main diagonal.
225+
226+
- `offset = 0`: the main diagonal.
227+
- `offset > 0`: off-diagonal above the main diagonal.
228+
- `offset < 0`: off-diagonal below the main diagonal.
229+
230+
Default: `0`.
231+
232+
#### Returns
233+
234+
- **out**: _&lt;array&gt;_
235+
236+
- if `x` is a two-dimensional array, a zero-dimensional array containing the trace; otherwise, a multi-dimensional array containing the traces.
237+
238+
The shape of a multi-dimensional output array is determined by removing `axis1` and `axis2` and storing the traces in the last array dimension. For example, if `x` has rank `k` and shape `(I, J, K, ..., L, M, N)` and `axis1=-2` and `axis2=-1`, then a multi-dimensional output array has rank `k-2` and shape `(I, J, K, ..., L)` where
239+
240+
```text
241+
out[i, j, k, ..., l] = trace(x[i, j, k, ..., l, :, :])
242+
```
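
The relation above can be sketched in plain Python for a rank-3 input whose last two dimensions hold the matrices (i.e., the `axis1=-2`, `axis2=-1` case). Nested lists stand in for arrays and the helper is hypothetical.

```python
def trace(m, offset=0):
    """Sum along the diagonal of a 2-D nested list selected by `offset`."""
    rows, cols = len(m), len(m[0])
    return sum(m[i][i + offset] for i in range(rows) if 0 <= i + offset < cols)


# Shape (2, 2, 2): a stack of two 2x2 matrices.
x = [
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
]

# out[i] = trace(x[i, :, :]), so the output has shape (2,).
out = [trace(m) for m in x]
out_above = [trace(m, offset=1) for m in x]
```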
243+
244+
### <a name="transpose" href="#transpose">#</a> transpose(x, /, *, axes=None)
245+
246+
Transposes (or permutes the axes (dimensions) of) an array `x`.
247+
248+
#### Parameters
249+
250+
- **x**: _&lt;array&gt;_
251+
252+
- input array.
253+
254+
- **axes**: _Optional\[ Tuple\[ int, ... ] ]_
255+
256+
- tuple containing a permutation of `(0, 1, ..., N-1)` where `N` is the number of axes (dimensions) of `x`. If `None`, the axes (dimensions) are permuted in reverse order (i.e., equivalent to setting `axes=(N-1, ..., 1, 0)`). Default: `None`.
257+
258+
#### Returns
259+
260+
- **out**: _&lt;array&gt;_
261+
262+
- an array containing the transpose. Must have the same data type as `x`.
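
The effect of `axes` on the output shape can be sketched as below. The text above does not state how a permutation maps input axes to output axes, so this sketch assumes NumPy's convention, in which output axis `i` corresponds to input axis `axes[i]`. The helper `transposed_shape` is hypothetical.

```python
def transposed_shape(shape, axes=None):
    """Shape of the transposed array (assumes output axis i = input axis axes[i])."""
    n = len(shape)
    if axes is None:
        # Default: reverse the order of the axes.
        axes = tuple(reversed(range(n)))
    if sorted(axes) != list(range(n)):
        raise ValueError("axes must be a permutation of (0, 1, ..., N-1)")
    return tuple(shape[a] for a in axes)


default_shape = transposed_shape((2, 3, 4))
swapped_shape = transposed_shape((2, 3, 4), axes=(1, 0, 2))
```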

spec/purpose_and_scope.md

+12
@@ -50,6 +50,10 @@ For the purposes of this specification, the following terms and definitions appl
5050

5151
a (usually fixed-size) multidimensional container of items of the same type and size.
5252

53+
### axis
54+
55+
an array dimension.
56+
5357
### broadcast
5458

5559
automatic (implicit) expansion of array dimensions to be of equal sizes without copying array data for the purpose of making arrays with different shapes have compatible shapes for element-wise operations.
@@ -62,6 +66,10 @@ two arrays whose dimensions are compatible (i.e., where the size of each dimensi
6266

6367
an operation performed element-by-element, in which individual array elements are considered in isolation and independently of other elements within the same array.
6468

69+
### matrix
70+
71+
a two-dimensional array.
72+
6573
### rank
6674

6775
number of array dimensions (not to be confused with the number of linearly independent columns of a matrix).
@@ -74,6 +82,10 @@ a tuple of `N` non-negative integers that specify the sizes of each dimension an
7482

7583
a dimension whose size is one.
7684

85+
### vector
86+
87+
a one-dimensional array.
88+
7789
* * *
7890

7991
## Normative References
