Skip to content

Commit 03b6ba1

Browse files
authored
fix concat with variable or dataarray as dim (#6387)
Propagate attrs.
1 parent 073512e commit 03b6ba1

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

xarray/core/concat.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def _dataset_concat(
429429
"""
430430
Concatenate a sequence of datasets along a new or existing dimension
431431
"""
432+
from .dataarray import DataArray
432433
from .dataset import Dataset
433434

434435
datasets = list(datasets)
@@ -438,6 +439,13 @@ def _dataset_concat(
438439
"The elements in the input list need to be either all 'Dataset's or all 'DataArray's"
439440
)
440441

442+
if isinstance(dim, DataArray):
443+
dim_var = dim.variable
444+
elif isinstance(dim, Variable):
445+
dim_var = dim
446+
else:
447+
dim_var = None
448+
441449
dim, index = _calc_concat_dim_index(dim)
442450

443451
# Make sure we're working on a copy (we'll be loading variables)
@@ -582,7 +590,11 @@ def get_indexes(name):
582590

583591
if index is not None:
584592
# add concat index / coordinate last to ensure that its in the final Dataset
585-
result[dim] = index.create_variables()[dim]
593+
if dim_var is not None:
594+
index_vars = index.create_variables({dim: dim_var})
595+
else:
596+
index_vars = index.create_variables()
597+
result[dim] = index_vars[dim]
586598
result_indexes[dim] = index
587599

588600
# TODO: add indexes at Dataset creation (when it is supported)

xarray/tests/test_concat.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,15 @@ def test_concat_do_not_promote(self) -> None:
459459

460460
def test_concat_dim_is_variable(self) -> None:
461461
objs = [Dataset({"x": 0}), Dataset({"x": 1})]
462-
coord = Variable("y", [3, 4])
463-
expected = Dataset({"x": ("y", [0, 1]), "y": [3, 4]})
462+
coord = Variable("y", [3, 4], attrs={"foo": "bar"})
463+
expected = Dataset({"x": ("y", [0, 1]), "y": coord})
464+
actual = concat(objs, coord)
465+
assert_identical(actual, expected)
466+
467+
def test_concat_dim_is_dataarray(self) -> None:
468+
objs = [Dataset({"x": 0}), Dataset({"x": 1})]
469+
coord = DataArray([3, 4], dims="y", attrs={"foo": "bar"})
470+
expected = Dataset({"x": ("y", [0, 1]), "y": coord})
464471
actual = concat(objs, coord)
465472
assert_identical(actual, expected)
466473

0 commit comments

Comments
 (0)