diff --git a/xarray/core/concat.py b/xarray/core/concat.py index e26c1464f2d..3145b9de71a 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import ( TYPE_CHECKING, Dict, @@ -12,6 +14,7 @@ ) import pandas as pd +from typing_extensions import Literal from . import dtypes, utils from .alignment import align @@ -24,14 +27,19 @@ from .dataarray import DataArray from .dataset import Dataset +compat_options = Literal[ + "identical", "equals", "broadcast_equals", "no_conflicts", "override" +] +concat_options = Literal["all", "minimal", "different"] + @overload def concat( objs: Iterable["Dataset"], - dim: Union[str, "DataArray", pd.Index], - data_vars: Union[str, List[str]] = "all", - coords: Union[str, List[str]] = "different", - compat: str = "equals", + dim: Hashable | "DataArray" | pd.Index, + data_vars: concat_options | List[Hashable] = "all", + coords: concat_options | List[Hashable] = "different", + compat: compat_options = "equals", positions: Optional[Iterable[int]] = None, fill_value: object = dtypes.NA, join: str = "outer", @@ -43,10 +51,10 @@ def concat( @overload def concat( objs: Iterable["DataArray"], - dim: Union[str, "DataArray", pd.Index], - data_vars: Union[str, List[str]] = "all", - coords: Union[str, List[str]] = "different", - compat: str = "equals", + dim: Hashable | "DataArray" | pd.Index, + data_vars: concat_options | List[Hashable] = "all", + coords: concat_options | List[Hashable] = "different", + compat: compat_options = "equals", positions: Optional[Iterable[int]] = None, fill_value: object = dtypes.NA, join: str = "outer", @@ -74,14 +82,14 @@ def concat( xarray objects to concatenate together. Each object is expected to consist of variables and coordinates with matching shapes except for along the concatenated dimension. - dim : str or DataArray or pandas.Index + dim : Hashable or DataArray or pandas.Index Name of the dimension to concatenate along. This can either be a new dimension name, in which case it is added along axis=0, or an existing dimension name, in which case the location of the dimension is unchanged. If dimension is provided as a DataArray or Index, its name is used as the dimension to concatenate along and the values are added as a coordinate. - data_vars : {"minimal", "different", "all"} or list of str, optional + data_vars : {"minimal", "different", "all"} or list of Hashable, optional These data variables will be concatenated together: * "minimal": Only data variables in which the dimension already appears are included. @@ -91,11 +99,11 @@ def concat( load the data payload of data variables into memory if they are not already loaded. * "all": All data variables will be concatenated. - * list of str: The listed data variables will be concatenated, in + * list of dims: The listed data variables will be concatenated, in addition to the "minimal" data variables. If objects are DataArrays, data_vars must be "all". - coords : {"minimal", "different", "all"} or list of str, optional + coords : {"minimal", "different", "all"} or list of Hashable, optional These coordinate variables will be concatenated together: * "minimal": Only coordinates in which the dimension already appears are included. @@ -106,7 +114,7 @@ def concat( loaded. * "all": All coordinate variables will be concatenated, except those corresponding to other dimensions. - * list of str: The listed coordinate variables will be concatenated, + * list of Hashable: The listed coordinate variables will be concatenated, in addition to the "minimal" coordinates. compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional String indicating how to compare non-concatenated variables of the same name for diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index e049f843bed..a8d06188844 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,4 +1,5 @@ from copy import deepcopy +from typing import List import numpy as np import pandas as pd @@ -6,6 +7,7 @@ from xarray import DataArray, Dataset, Variable, concat from xarray.core import dtypes, merge +from xarray.core.concat import compat_options, concat_options from . import ( InaccessibleArray, @@ -17,7 +19,7 @@ from .test_dataset import create_test_data -def test_concat_compat(): +def test_concat_compat() -> None: ds1 = Dataset( { "has_x_y": (("y", "x"), [[1, 2]]), @@ -50,10 +52,10 @@ def test_concat_compat(): class TestConcatDataset: @pytest.fixture - def data(self): + def data(self) -> Dataset: return create_test_data().drop_dims("dim3") - def rectify_dim_order(self, data, dataset): + def rectify_dim_order(self, data, dataset) -> Dataset: # return a new dataset with all variable dimensions transposed into # the order in which they are found in `data` return Dataset( @@ -64,11 +66,11 @@ def rectify_dim_order(self, data, dataset): @pytest.mark.parametrize("coords", ["different", "minimal"]) @pytest.mark.parametrize("dim", ["dim1", "dim2"]) - def test_concat_simple(self, data, dim, coords): + def test_concat_simple(self, data, dim, coords) -> None: datasets = [g for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim, coords=coords)) - def test_concat_merge_variables_present_in_some_datasets(self, data): + def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: # coordinates present in some datasets but not others ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) ds2 = Dataset(data_vars={"a": ("y", [0.2])}, coords={"z": 0.2}) @@ -84,7 +86,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data): expected = data.copy().assign(foo=data1.foo) assert_identical(expected, actual) - def test_concat_2(self, data): + def test_concat_2(self, data) -> None: dim = "dim2" datasets = [g for _, g in data.groupby(dim, squeeze=True)] concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] @@ -93,7 +95,7 @@ def test_concat_2(self, data): @pytest.mark.parametrize("coords", ["different", "minimal", "all"]) @pytest.mark.parametrize("dim", ["dim1", "dim2"]) - def test_concat_coords_kwarg(self, data, dim, coords): + def test_concat_coords_kwarg(self, data, dim, coords) -> None: data = data.copy(deep=True) # make sure the coords argument behaves as expected data.coords["extra"] = ("dim4", np.arange(3)) @@ -107,7 +109,7 @@ def test_concat_coords_kwarg(self, data, dim, coords): else: assert_equal(data["extra"], actual["extra"]) - def test_concat(self, data): + def test_concat(self, data) -> None: split_data = [ data.isel(dim1=slice(3)), data.isel(dim1=3), @@ -115,7 +117,7 @@ def test_concat(self, data): ] assert_identical(data, concat(split_data, "dim1")) - def test_concat_dim_precedence(self, data): + def test_concat_dim_precedence(self, data) -> None: # verify that the dim argument takes precedence over # concatenating dataset variables of the same name dim = (2 * data["dim1"]).rename("dim1") @@ -124,14 +126,23 @@ def test_concat_dim_precedence(self, data): expected["dim1"] = dim assert_identical(expected, concat(datasets, dim)) + def test_concat_data_vars_typing(self) -> None: + # Testing typing, can be removed if the next function works with annotations. + data = Dataset({"foo": ("x", np.random.randn(10))}) + objs: List[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] + actual = concat(objs, dim="x", data_vars="minimal") + assert_identical(data, actual) + def test_concat_data_vars(self): + # TODO: annotating this func fails data = Dataset({"foo": ("x", np.random.randn(10))}) - objs = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] + objs: List[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] for data_vars in ["minimal", "different", "all", [], ["foo"]]: actual = concat(objs, dim="x", data_vars=data_vars) assert_identical(data, actual) def test_concat_coords(self): + # TODO: annotating this func fails data = Dataset({"foo": ("x", np.random.randn(10))}) expected = data.assign_coords(c=("x", [0] * 5 + [1] * 5)) objs = [ @@ -146,6 +157,7 @@ def test_concat_coords(self): concat(objs, dim="x", coords=coords) def test_concat_constant_index(self): + # TODO: annotating this func fails # GH425 ds1 = Dataset({"foo": 1.5}, {"y": 1}) ds2 = Dataset({"foo": 2.5}, {"y": 1}) @@ -158,7 +170,7 @@ def test_concat_constant_index(self): # "foo" has dimension "y" so minimal should concatenate it? concat([ds1, ds2], "new_dim", data_vars="minimal") - def test_concat_size0(self): + def test_concat_size0(self) -> None: data = create_test_data() split_data = [data.isel(dim1=slice(0, 0)), data] actual = concat(split_data, "dim1") @@ -167,7 +179,7 @@ def test_concat_size0(self): actual = concat(split_data[::-1], "dim1") assert_identical(data, actual) - def test_concat_autoalign(self): + def test_concat_autoalign(self) -> None: ds1 = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 2])])}) ds2 = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 3])])}) actual = concat([ds1, ds2], "y") @@ -183,6 +195,7 @@ def test_concat_autoalign(self): assert_identical(expected, actual) def test_concat_errors(self): + # TODO: annotating this func fails data = create_test_data() split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] @@ -222,7 +235,7 @@ def test_concat_errors(self): ): concat([Dataset({"x": 0}), Dataset({}, {"x": 1})], dim="z") - def test_concat_join_kwarg(self): + def test_concat_join_kwarg(self) -> None: ds1 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]}) ds2 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]}) @@ -258,10 +271,10 @@ def test_concat_join_kwarg(self): actual = concat( [ds1.drop_vars("x"), ds2.drop_vars("x")], join="override", dim="y" ) - expected = Dataset( + expected2 = Dataset( {"a": (("x", "y"), np.array([0, 0], ndmin=2))}, coords={"y": [0, 0.0001]} ) - assert_identical(actual, expected) + assert_identical(actual, expected2) @pytest.mark.parametrize( "combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception", @@ -389,7 +402,7 @@ def test_concat_combine_attrs_kwarg_variables( assert_identical(actual, expected) - def test_concat_promote_shape(self): + def test_concat_promote_shape(self) -> None: # mixed dims within variables objs = [Dataset({}, {"x": 0}), Dataset({"x": [1]})] actual = concat(objs, "x") @@ -427,7 +440,7 @@ def test_concat_promote_shape(self): expected = Dataset({"z": (("x", "y"), [[-1], [1]])}, {"x": [0, 1], "y": [0]}) assert_identical(actual, expected) - def test_concat_do_not_promote(self): + def test_concat_do_not_promote(self) -> None: # GH438 objs = [ Dataset({"y": ("t", [1])}, {"x": 1, "t": [0]}), @@ -444,14 +457,14 @@ def test_concat_do_not_promote(self): with pytest.raises(ValueError): concat(objs, "t", coords="minimal") - def test_concat_dim_is_variable(self): + def test_concat_dim_is_variable(self) -> None: objs = [Dataset({"x": 0}), Dataset({"x": 1})] coord = Variable("y", [3, 4]) expected = Dataset({"x": ("y", [0, 1]), "y": [3, 4]}) actual = concat(objs, coord) assert_identical(actual, expected) - def test_concat_multiindex(self): + def test_concat_multiindex(self) -> None: x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) expected = Dataset({"x": x}) actual = concat( @@ -461,7 +474,7 @@ def test_concat_multiindex(self): assert isinstance(actual.x.to_index(), pd.MultiIndex) @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}]) - def test_concat_fill_value(self, fill_value): + def test_concat_fill_value(self, fill_value) -> None: datasets = [ Dataset({"a": ("x", [2, 3]), "b": ("x", [-2, 1]), "x": [1, 2]}), Dataset({"a": ("x", [1, 2]), "b": ("x", [3, -1]), "x": [0, 1]}), @@ -487,7 +500,7 @@ def test_concat_fill_value(self, fill_value): @pytest.mark.parametrize("dtype", [str, bytes]) @pytest.mark.parametrize("dim", ["x1", "x2"]) - def test_concat_str_dtype(self, dtype, dim): + def test_concat_str_dtype(self, dtype, dim) -> None: data = np.arange(4).reshape([2, 2]) @@ -511,7 +524,7 @@ def test_concat_str_dtype(self, dtype, dim): class TestConcatDataArray: - def test_concat(self): + def test_concat(self) -> None: ds = Dataset( { "foo": (["x", "y"], np.random.random((2, 3))), @@ -538,13 +551,13 @@ def test_concat(self): stacked = concat(grouped, pd.Index(ds["x"], name="x")) assert_identical(foo, stacked) - actual = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True) + actual2 = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True) expected = foo[:2].rename({"x": "concat_dim"}) - assert_identical(expected, actual) + assert_identical(expected, actual2) - actual = concat([foo[0], foo[1]], [0, 1]).reset_coords(drop=True) + actual3 = concat([foo[0], foo[1]], [0, 1]).reset_coords(drop=True) expected = foo[:2].rename({"x": "concat_dim"}) - assert_identical(expected, actual) + assert_identical(expected, actual3) with pytest.raises(ValueError, match=r"not identical"): concat([foo, bar], dim="w", compat="identical") @@ -552,7 +565,7 @@ def test_concat(self): with pytest.raises(ValueError, match=r"not a valid argument"): concat([foo, bar], dim="w", data_vars="minimal") - def test_concat_encoding(self): + def test_concat_encoding(self) -> None: # Regression test for GH1297 ds = Dataset( { @@ -568,7 +581,7 @@ def test_concat_encoding(self): assert concat([ds, ds], dim="x").encoding == ds.encoding @requires_dask - def test_concat_lazy(self): + def test_concat_lazy(self) -> None: import dask.array as da arrays = [ @@ -583,7 +596,7 @@ def test_concat_lazy(self): assert combined.dims == ("z", "x", "y") @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) - def test_concat_fill_value(self, fill_value): + def test_concat_fill_value(self, fill_value) -> None: foo = DataArray([1, 2], coords=[("x", [1, 2])]) bar = DataArray([1, 2], coords=[("x", [1, 3])]) if fill_value == dtypes.NA: @@ -598,7 +611,7 @@ def test_concat_fill_value(self, fill_value): actual = concat((foo, bar), dim="y", fill_value=fill_value) assert_identical(actual, expected) - def test_concat_join_kwarg(self): + def test_concat_join_kwarg(self) -> None: ds1 = Dataset( {"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]} ).to_array() @@ -634,7 +647,7 @@ def test_concat_join_kwarg(self): actual = concat([ds1, ds2], join=join, dim="x") assert_equal(actual, expected[join].to_array()) - def test_concat_combine_attrs_kwarg(self): + def test_concat_combine_attrs_kwarg(self) -> None: da1 = DataArray([0], coords=[("x", [0])], attrs={"b": 42}) da2 = DataArray([0], coords=[("x", [1])], attrs={"b": 42, "c": 43}) @@ -660,7 +673,7 @@ def test_concat_combine_attrs_kwarg(self): @pytest.mark.parametrize("dtype", [str, bytes]) @pytest.mark.parametrize("dim", ["x1", "x2"]) - def test_concat_str_dtype(self, dtype, dim): + def test_concat_str_dtype(self, dtype, dim) -> None: data = np.arange(4).reshape([2, 2]) @@ -678,7 +691,7 @@ def test_concat_str_dtype(self, dtype, dim): assert np.issubdtype(actual.x2.dtype, dtype) - def test_concat_coord_name(self): + def test_concat_coord_name(self) -> None: da = DataArray([0], dims="a") da_concat = concat([da, da], dim=DataArray([0, 1], dims="b")) @@ -690,7 +703,7 @@ def test_concat_coord_name(self): @pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {})) @pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {})) -def test_concat_attrs_first_variable(attr1, attr2): +def test_concat_attrs_first_variable(attr1, attr2) -> None: arrs = [ DataArray([[1], [2]], dims=["x", "y"], attrs=attr1), @@ -702,6 +715,7 @@ def test_concat_attrs_first_variable(attr1, attr2): def test_concat_merge_single_non_dim_coord(): + # TODO: annotating this func fails da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]}) @@ -722,7 +736,7 @@ def test_concat_merge_single_non_dim_coord(): concat([da1, da2, da3], dim="x") -def test_concat_preserve_coordinate_order(): +def test_concat_preserve_coordinate_order() -> None: x = np.arange(0, 5) y = np.arange(0, 10) time = np.arange(0, 4) @@ -755,7 +769,7 @@ def test_concat_preserve_coordinate_order(): assert_identical(actual.coords[act], expected.coords[exp]) -def test_concat_typing_check(): +def test_concat_typing_check() -> None: ds = Dataset({"foo": 1}, {"bar": 2}) da = Dataset({"foo": 3}, {"bar": 4}).to_array(dim="foo")