-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add support for cross product #5365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 58 commits
1490c16
03db734
c824e36
7ce39c7
916e661
654ad60
a6ac578
e0c1fac
7aebae7
b85e236
2b54a42
be7b2c2
4448006
af8b09c
a135e05
6f17b9b
1fadb5f
57239a4
265ef82
dd60562
a20cb86
7ce9315
d5a0ea8
ef94fa4
1a85147
2ce3dbe
53c84c2
dded720
7058166
cb57a55
e69ca81
4b2fc72
afe572d
e137350
1a26324
531a98b
2146406
0940472
a7cc565
1d1f205
9af7091
14decb3
6f73c32
72330ce
bce2f3e
1636d25
b5b97a0
f77780f
02364ca
e842c75
ed44400
4fe9737
ec05780
36c5956
cbf289c
4cfd5be
658a59f
ab5ae20
d65ca41
20eef03
274af32
f352303
0a773cb
d8da29f
54a76c1
0a2dc2e
b3592f3
06772da
cfd11f7
8451a9e
90553ed
6eed96e
d3648e5
c639aa3
4c636f5
3bea936
4fc7fcb
19e8f93
f71a6f1
d4070ab
12da913
ea062e6
ebd89e6
3c7122b
9af1198
27262e6
cc91e7c
629df59
972c7dc
3c4ace0
49967d4
6ab7d19
20a6cb6
ba3fa9c
8b192f2
a27965c
5ec65d2
b058084
f007ed5
e88ae9d
9aaee2b
5d6ecba
71fc9c1
a98b2e3
c95817b
408eb39
316b935
3b5b030
f9c5404
34b300d
cf13bf9
f2167a6
570a806
6f57ed6
52a986b
fa78e74
f2d98b6
7449cd7
70d2a4b
e6020e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,6 +41,7 @@ | |
| from .dataset import Dataset | ||
|
|
||
| T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) | ||
| T_DSorDAorVar = TypeVar("T_DSorDAorVar", Dataset, DataArray, Variable) | ||
|
|
||
| _NO_FILL_VALUE = utils.ReprObject("<no-fill-value>") | ||
| _DEFAULT_NAME = utils.ReprObject("<default-name>") | ||
|
|
@@ -1396,6 +1397,209 @@ def _get_valid_values(da, other): | |
| return corr | ||
|
|
||
|
|
||
| def cross( | ||
| a: "T_DSorDAorVar", | ||
| b: "T_DSorDAorVar", | ||
| dim: str, | ||
| ) -> "T_DSorDAorVar": | ||
| """ | ||
| Return the cross product of two (arrays of) vectors. | ||
|
|
||
| The cross product of `a` and `b` in :math:`R^3` is a vector | ||
| perpendicular to both `a` and `b`. If `a` and `b` are arrays of | ||
| vectors, and these axes can have dimensions 2 or 3. Where the | ||
Illviljan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| dimension of either `a` or `b` is 2, the third component of the | ||
| input vector is assumed to be zero and the cross product calculated | ||
| accordingly. In cases where both input vectors have dimension 2, | ||
| the z-component of the cross product is returned. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| a, b : DataArray, Dataset or Variable | ||
| Components of the first and second vector(s). | ||
| dim : hashable | ||
| The dimension along which the cross product will be computed. | ||
| Must be available in both vectors. | ||
|
|
||
| Examples | ||
| -------- | ||
| Vector cross-product with 3 dimensions. | ||
|
|
||
| >>> a = xr.DataArray([1, 2, 3]) | ||
| >>> b = xr.DataArray([4, 5, 6]) | ||
| >>> xr.cross(a, b, "dim_0") | ||
Illviljan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| <xarray.DataArray (dim_0: 3)> | ||
| array([-3, 6, -3]) | ||
| Dimensions without coordinates: dim_0 | ||
|
|
||
| Vector cross-product with 2 dimensions, returns in the perpendicular | ||
| direction: | ||
|
|
||
| >>> a = xr.DataArray([1, 2]) | ||
| >>> b = xr.DataArray([4, 5]) | ||
| >>> xr.cross(a, b, "dim_0") | ||
| <xarray.DataArray ()> | ||
| array(-3) | ||
|
|
||
| Vector cross-product with 3 dimensions but zeros at the last axis | ||
| yields the same results as with 2 dimensions: | ||
|
|
||
| >>> a = xr.DataArray([1, 2, 0]) | ||
| >>> b = xr.DataArray([4, 5, 0]) | ||
| >>> xr.cross(a, b, "dim_0") | ||
| <xarray.DataArray (dim_0: 3)> | ||
| array([ 0, 0, -3]) | ||
| Dimensions without coordinates: dim_0 | ||
|
|
||
| One vector with dimension 2. | ||
|
|
||
| >>> a = xr.DataArray( | ||
| ... [1, 2], | ||
| ... dims=["cartesian"], | ||
| ... coords=dict(cartesian=(["cartesian"], ["x", "y"])), | ||
| ... ) | ||
| >>> b = xr.DataArray( | ||
| ... [4, 5, 6], | ||
| ... dims=["cartesian"], | ||
| ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), | ||
| ... ) | ||
| >>> xr.cross(a, b, "cartesian") | ||
| <xarray.DataArray (cartesian: 3)> | ||
| array([12, -6, -3]) | ||
| Coordinates: | ||
| * cartesian (cartesian) object 'x' 'y' 'z' | ||
|
|
||
| One vector with dimension 2 but coords in other positions. | ||
|
|
||
| >>> a = xr.DataArray( | ||
| ... [1, 2], | ||
| ... dims=["cartesian"], | ||
| ... coords=dict(cartesian=(["cartesian"], ["x", "z"])), | ||
| ... ) | ||
| >>> b = xr.DataArray( | ||
| ... [4, 5, 6], | ||
| ... dims=["cartesian"], | ||
| ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), | ||
| ... ) | ||
| >>> xr.cross(a, b, "cartesian") | ||
| <xarray.DataArray (cartesian: 3)> | ||
| array([-10, 2, 5]) | ||
| Coordinates: | ||
| * cartesian (cartesian) object 'x' 'y' 'z' | ||
|
|
||
| Multiple vector cross-products. Note that the direction of the | ||
| cross product vector is defined by the right-hand rule. | ||
|
|
||
| >>> a = xr.DataArray( | ||
| ... [[1, 2, 3], [4, 5, 6]], | ||
| ... dims=("time", "cartesian"), | ||
| ... coords=dict( | ||
| ... time=(["time"], [0, 1]), | ||
| ... cartesian=(["cartesian"], ["x", "y", "z"]), | ||
| ... ), | ||
| ... ) | ||
| >>> b = xr.DataArray( | ||
| ... [[4, 5, 6], [1, 2, 3]], | ||
| ... dims=("time", "cartesian"), | ||
| ... coords=dict( | ||
| ... time=(["time"], [0, 1]), | ||
| ... cartesian=(["cartesian"], ["x", "y", "z"]), | ||
| ... ), | ||
| ... ) | ||
| >>> xr.cross(a, b, "cartesian") | ||
| <xarray.DataArray (time: 2, cartesian: 3)> | ||
| array([[-3, 6, -3], | ||
| [ 3, -6, 3]]) | ||
| Coordinates: | ||
| * time (time) int64 0 1 | ||
| * cartesian (cartesian) <U1 'x' 'y' 'z' | ||
|
|
||
| See Also | ||
| -------- | ||
| numpy.cross : Corresponding numpy function | ||
| """ | ||
| from .dataarray import DataArray | ||
| from .dataset import Dataset | ||
|
|
||
| all_dims: List[Hashable] = [] | ||
| arrays: List["T_DSorDAorVar"] = [a, b] | ||
| for i, arr in enumerate(arrays): | ||
| if isinstance(arr, Dataset): | ||
| # Turn the dataset to a stacked dataarray to follow the | ||
| # normal code path. Then at the end turn it back to a | ||
| # dataset. | ||
| is_dataset = True | ||
|
||
| arrays[i] = arr = arr.to_stacked_array( | ||
| variable_dim=dim, new_dim="variable", sample_dims=arr.dims | ||
Illviljan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ).unstack("variable") | ||
| if is_duck_dask_array(arr.data): | ||
| arrays[i] = arr = arr.chunk({dim: -1}) | ||
| elif isinstance(arr, (DataArray, Variable)): | ||
| is_dataset = False | ||
| else: | ||
| raise TypeError( | ||
| "Only xr.DataArray, xr.Dataset and xr.Variable are supported, " | ||
| f"got {type(arr)}." | ||
| ) | ||
|
|
||
| # TODO: Find spatial dim default by looking for unique | ||
| # (3 or 2)-valued dim? | ||
| if dim not in arr.dims: | ||
| raise ValueError(f"Dimension {dim} not in {arr}.") | ||
|
|
||
| if not 1 <= arr.sizes[dim] <= 3: | ||
| raise ValueError( | ||
| "Incompatible dimensions for cross product,\n" | ||
| "dimension with coords must be 1, 2 or 3." | ||
| ) | ||
|
|
||
| all_dims += [d for d in arr.dims if d not in all_dims] | ||
|
|
||
| if arrays[0].sizes[dim] != arrays[1].sizes[dim]: | ||
| # Arrays have different sizes. Append zeros where the smaller | ||
| # array is missing a value, zeros will not affect np.cross: | ||
| i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 | ||
| array_small, array_large = arrays[i], arrays[1 - i] | ||
|
|
||
| if getattr(array_large, "coords", False) and getattr( | ||
| array_small, "coords", False | ||
| ): | ||
| # if all([getattr(arr, "coords", False) for arr in arrays]): | ||
| # If the arrays have coords we know which indexes to fill | ||
| # with zeros: | ||
| arrays[i] = array_small.reindex_like(array_large, fill_value=0) | ||
| elif array_small.sizes[dim] == 2: | ||
| # If the array doesn't have coords we can only infer | ||
| # that it is composite values if the size is 2: | ||
| arrays[i] = array_small.pad({dim: (0, 1)}, constant_values=0) | ||
| if is_duck_dask_array(arrays[i].data): | ||
| arrays[i] = arrays[i].chunk({dim: -1}) | ||
| else: | ||
| # Size is 1, then we do not know if the array is a constant or | ||
| # composite value: | ||
| raise ValueError( | ||
| "Incompatible dimensions for cross product,\n" | ||
| "dimension without coords must be 2 or 3." | ||
| ) | ||
|
|
||
| c = apply_ufunc( | ||
| np.cross, | ||
| *arrays, | ||
| input_core_dims=[[dim], [dim]], | ||
| output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else []], | ||
| dask="parallelized", | ||
| output_dtypes=[np.result_type(*arrays)], | ||
| ) | ||
| c = c.transpose(*[d for d in all_dims if d in c.dims]) | ||
Illviljan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if is_dataset: | ||
| c = c.stack(variable=[dim]).to_unstacked_dataset("variable") | ||
| c = c.expand_dims( | ||
| list({d: s for ds in arrays for d, s in ds.sizes.items() if s == 1}) | ||
| ) | ||
|
|
||
| return c | ||
|
|
||
|
|
||
| def dot(*arrays, dims=None, **kwargs): | ||
| """Generalized dot product for xarray objects. Like np.einsum, but | ||
| provides a simpler interface based on array dimensions. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.