diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index d8e75234bd7..2c17dea0a03 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -624,6 +624,10 @@ Enhancements
 Bug fixes
 ~~~~~~~~~
 
+- Fixed unnecessary conversion of a dimension coordinate's dtype to
+  ``numpy.object`` during concat operations. By
+  `Maciek Swat `_.
+
 - Attributes were being retained by default for some resampling
   operations when they should not. With the ``keep_attrs=False`` option, they
   will no longer be retained by default. This may be backwards-incompatible
diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index 5770ec2cd50..f3df0929605 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -349,6 +349,14 @@ def any_not_full_slices(indexers):
     def var_indexers(var, indexers):
         return tuple(indexers.get(d, slice(None)) for d in var.dims)
 
+    def is_composite_dtype(dtype):
+        """
+        Detect a composite dtype, e.g. the dtype of a structured numpy array.
+        :param dtype: numpy.dtype
+        :return: bool
+        """
+        return dtype.fields is not None
+
     # create variables for the new dataset
     reindexed = OrderedDict()
 
@@ -358,7 +366,15 @@ def var_indexers(var, indexers):
             args = (var.attrs, var.encoding)
         else:
             args = ()
-        reindexed[dim] = IndexVariable((dim,), indexers[dim], *args)
+
+        idx_var = IndexVariable((dim,), indexers[dim], *args)
+
+        # GH1434
+        # ensure that the dtype of numpy structured arrays is preserved
+        if len(args) and is_composite_dtype(var.dtype) \
+                and idx_var.dtype != var.dtype:
+            idx_var.data = idx_var.data.astype(var.dtype)
+        reindexed[dim] = idx_var
 
     for name, var in iteritems(variables):
         if name not in indexers:
@@ -402,7 +418,13 @@ def var_indexers(var, indexers):
                 # we neither created a new ndarray nor used fancy indexing
                 new_var = var.copy(deep=copy)
 
+            # GH1434
+            # ensure that the dtype of numpy structured arrays is preserved
+            if is_composite_dtype(var.dtype) \
+                    and new_var.dtype != indexes._variables[name].dtype:
+                new_var = new_var.astype(indexes._variables[name].dtype)
             reindexed[name] = new_var
+
     return reindexed
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index d139151064b..6477867a952 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -207,8 +207,8 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
 
     dim, coord = _calc_concat_dim_coord(dim)
     datasets = [as_dataset(ds) for ds in datasets]
-    datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
 
+    datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
     concat_over = _calc_concat_over(datasets, dim, data_vars, coords)
 
     def insert_result_variable(k, v):
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 34b86275374..b5b3833b8de 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1211,11 +1211,41 @@ def __getitem__(self, key):
     def __setitem__(self, key, value):
         raise TypeError('%s values cannot be modified' % type(self).__name__)
 
+    @classmethod
+    def _concat_pandas(cls, variables, positions=None):
+        """
+        Concatenate variables via their underlying pandas indexes. This is
+        the generic path for cases where numpy.concatenate will not work.
+
+        Parameters
+        ----------
+        variables : list of IndexVariable objects to concatenate
+        positions : list of integer arrays, optional
+
+        Returns
+        -------
+        The concatenated data, as a pandas.Index (or an empty list)
+        """
+        indexes = [v._data.array for v in variables]
+
+        if not indexes:
+            data = []
+        else:
+            data = indexes[0].append(indexes[1:])
+
+            if positions is not None:
+                indices = nputils.inverse_permutation(
+                    np.concatenate(positions))
+                data = data.take(indices)
+
+        return data
+
     @classmethod
     def concat(cls, variables, dim='concat_dim', positions=None,
                shortcut=False):
-        """Specialized version of Variable.concat for IndexVariable objects.
-
+        """Specialized version of Variable.concat for
+        IndexVariable objects.
         This exists because we want to avoid converting Index objects to NumPy
         arrays, if possible.
         """
@@ -1223,23 +1253,20 @@ def concat(cls, variables, dim='concat_dim', positions=None,
             dim, = dim.dims
 
         variables = list(variables)
+        first_var = variables[0]
 
         if any(not isinstance(v, cls) for v in variables):
             raise TypeError('IndexVariable.concat requires that all input '
                             'variables be IndexVariable objects')
 
-        indexes = [v._data.array for v in variables]
-
-        if not indexes:
-            data = []
+        # GH1434
+        # Fixes bug: "xr.concat loses coordinate dtype
+        # information with recarrays in 0.9"
+        if any(var.dtype == np.object for var in variables):
+            data = cls._concat_pandas(variables, positions)
         else:
-            data = indexes[0].append(indexes[1:])
-
-            if positions is not None:
-                indices = nputils.inverse_permutation(
-                    np.concatenate(positions))
-                data = data.take(indices)
+            data = np.concatenate([v.data for v in variables])
 
         attrs = OrderedDict(first_var.attrs)
         if not shortcut:
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 89b0badee4d..5f3bca0c836 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -101,7 +101,7 @@
 try:
     _SKIP_FLAKY = not pytest.config.getoption("--run-flaky")
     _SKIP_NETWORK_TESTS = not pytest.config.getoption("--run-network-tests")
-except ValueError:
+except (ValueError, AttributeError):
     # Can't get config from pytest, e.g., because xarray is installed instead
     # of being run from a development version (and hence conftests.py is not
     # available). Don't run flaky tests.
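Aside (not part of the patch): a minimal standalone sketch of the dtype.fields
check that the new is_composite_dtype helper in alignment.py relies on. NumPy
exposes dtype.fields as None for plain dtypes and as a field mapping for
structured dtypes, which is the predicate used above to decide when a
reindexed coordinate must be cast back to the original structured dtype.

    import numpy as np

    # A structured ("composite") dtype carries a field mapping; a plain dtype
    # does not. The field names used here are illustrative only.
    structured = np.dtype([('name', 'S256'), ('height', np.int64)])
    plain = np.dtype('float64')

    print(structured.fields is not None)  # True  -> treated as composite
    print(plain.fields is not None)       # False -> left alone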
diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py
index 7813378277a..ed79a796eed 100644
--- a/xarray/tests/test_combine.py
+++ b/xarray/tests/test_combine.py
@@ -75,6 +75,50 @@ def rectify_dim_order(dataset):
         expected['dim1'] = dim
         self.assertDatasetIdentical(expected, concat(datasets, dim))
 
+    def test_concat_dtype_preservation(self):
+        """
+        Check that concatenation of two DataArrays along a dimension
+        whose coordinate is a numpy structured array preserves the
+        dtype of that structured array.
+        """
+
+        p1 = np.array([('A', 180), ('B', 150), ('C', 200)],
+                      dtype=[('name', '|S256'), ('height', int)])
+        p2 = np.array([('D', 170), ('E', 250), ('F', 150)],
+                      dtype=[('name', '|S256'), ('height', int)])
+
+        data = np.arange(50, 80, 1, dtype=np.float)
+
+        dims = ['measurement', 'participant']
+
+        da1 = DataArray(
+            data.reshape(10, 3),
+            coords={
+                'measurement': np.arange(10),
+                'participant': p1,
+            },
+            dims=dims
+        )
+
+        da2 = DataArray(
+            data.reshape(10, 3),
+            coords={
+                'measurement': np.arange(10),
+                'participant': p2,
+            },
+            dims=dims
+        )
+
+        combined_1 = concat([da1, da2], dim='participant')
+
+        assert combined_1.participant.dtype == da1.participant.dtype
+        assert combined_1.measurement.dtype == da1.measurement.dtype
+
+        combined_2 = concat([da1, da2], dim='measurement')
+
+        assert combined_2.participant.dtype == da1.participant.dtype
+        assert combined_2.measurement.dtype == da1.measurement.dtype
+
     def test_concat_data_vars(self):
         data = Dataset({'foo': ('x', np.random.randn(10))})
         objs = [data.isel(x=slice(5)), data.isel(x=slice(5, None))]
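For reference, a short usage sketch of the behaviour the new test exercises,
assuming the patch above is applied (array contents, coordinate names and
shapes are illustrative only, mirroring the test): per GH1434, concatenating
along a dimension whose coordinate is a structured array previously degraded
that coordinate's dtype to object; with this change the structured dtype is
kept.

    import numpy as np
    import xarray as xr

    # Two groups of participants, each described by a structured array.
    p1 = np.array([('A', 180), ('B', 150)],
                  dtype=[('name', 'S256'), ('height', int)])
    p2 = np.array([('C', 170), ('D', 250)],
                  dtype=[('name', 'S256'), ('height', int)])

    da1 = xr.DataArray(np.zeros((4, 2)), dims=['measurement', 'participant'],
                       coords={'measurement': np.arange(4), 'participant': p1})
    da2 = xr.DataArray(np.ones((4, 2)), dims=['measurement', 'participant'],
                       coords={'measurement': np.arange(4), 'participant': p2})

    combined = xr.concat([da1, da2], dim='participant')

    # With the fix, the structured dtype of the concatenation coordinate
    # survives instead of being converted to numpy.object.
    assert combined.participant.dtype == p1.dtype
    assert combined.measurement.dtype == da1.measurement.dtype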