diff --git a/.gitattributes b/.gitattributes index a52f4ca283a..daa5b82874e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ # reduce the number of merge conflicts doc/whats-new.rst merge=union +xarray/_version.py export-subst diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md deleted file mode 100644 index ce8c4d00c3f..00000000000 --- a/.github/CONTRIBUTING.md +++ /dev/null @@ -1,42 +0,0 @@ -# Contributing to xarray - -## Usage questions - -The best places to submit questions about how to use xarray are -[Stack Overflow](https://stackoverflow.com/questions/tagged/python-xarray) and -the [xarray Google group](https://groups.google.com/forum/#!forum/xarray). - -## Reporting issues - -When reporting issues please include as much detail as possible about your -operating system, xarray version and python version. Whenever possible, please -also include a brief, self-contained code example that demonstrates the problem. - -## Contributing code - -Thanks for your interest in contributing code to xarray! - -- If you are new to Git or Github, please take a minute to read through a few tutorials - on [Git](https://git-scm.com/docs/gittutorial) and [GitHub](https://guides.github.com/). -- The basic workflow for contributing to xarray is: - 1. [Fork](https://help.github.com/articles/fork-a-repo/) the xarray repository - 2. [Clone](https://help.github.com/articles/cloning-a-repository/) the xarray repository to create a local copy on your computer: - ``` - git clone git@github.com:${user}/xarray.git - cd xarray - ``` - 3. Create a branch for your changes - ``` - git checkout -b name-of-your-branch - ``` - 4. Make change to your local copy of the xarray repository - 5. Commit those changes - ``` - git add file1 file2 file3 - git commit -m 'a descriptive commit message' - ``` - 6. Push your updated branch to your fork - ``` - git push origin name-of-your-branch - ``` - 7. [Open a pull request](https://help.github.com/articles/creating-a-pull-request/) to the pydata/xarray repository. 
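The `export-subst` attribute added to `.gitattributes` above supports the versioneer-based versioning this patch wires up elsewhere (see the `MANIFEST.in` entries for `versioneer.py` and `xarray/_version.py` later in the diff): it tells `git archive` to expand `$Format:...$` placeholders, so a tarball that lacks a `.git` directory can still report a version. A rough Python sketch of the mechanism (illustrative only, not xarray's actual `_version.py`):

```python
# Illustrative sketch of versioneer's keyword trick (not the real _version.py).
# With "export-subst" set on xarray/_version.py, `git archive` rewrites these
# "$Format:...$" placeholders, so an exported tarball still knows its commit.
def get_keywords():
    git_refnames = "$Format:%d$"  # expanded to the refs pointing at the archived commit
    git_full = "$Format:%H$"      # expanded to the full commit hash
    return {"refnames": git_refnames, "full": git_full}
```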
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index d10e857c4ed..c7236b8159a 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,5 +1,8 @@ #### Code Sample, a copy-pastable example if possible +A "Minimal, Complete and Verifiable Example" will make it much easier for maintainers to help you: +http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports + ```python # Your code here diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5e9aa06f507..d1c79953a9b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,4 +1,3 @@ - [ ] Closes #xxxx (remove if there is no corresponding issue, which should only be the case for minor changes) - [ ] Tests added (for all bug fixes or enhancements) - - [ ] Tests passed (for all non-documentation changes) - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later) diff --git a/.gitignore b/.gitignore index 490eb49f9d4..2a016bb9228 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ *.py[cod] __pycache__ +# example caches from Hypothesis +.hypothesis/ + # temp files from docs build doc/auto_gallery doc/example.nc @@ -33,6 +36,11 @@ pip-log.txt nosetests.xml .cache .ropeproject/ +.tags* +.testmon* +.tmontmp/ +.pytest_cache +dask-worker-space/ # asv environments .asv @@ -45,10 +53,11 @@ nosetests.xml .project .pydevproject -# PyCharm and Vim +# IDEs .idea *.swp .DS_Store +.vscode/ # xarray specific doc/_build diff --git a/.pep8speaks.yml b/.pep8speaks.yml new file mode 100644 index 00000000000..aedce6e44eb --- /dev/null +++ b/.pep8speaks.yml @@ -0,0 +1,12 @@ +# File : .pep8speaks.yml + +scanner: + diff_only: True # If True, errors caused by only the patch are shown + +pycodestyle: + max-line-length: 79 + ignore: # Errors and warnings to ignore + - E402, # module level import not at top of file + - E731, # do not assign a lambda expression, use a def + - W503 # line break before binary operator + - W504 # line break after binary operator diff --git a/.travis.yml b/.travis.yml index 068ea3cc788..defb37ec8aa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ # Based on http://conda.pydata.org/docs/travis.html -language: python +language: minimal sudo: false # use container based build notifications: email: false @@ -10,74 +10,48 @@ branches: matrix: fast_finish: true include: - - python: 2.7 - env: CONDA_ENV=py27-min - - python: 2.7 - env: CONDA_ENV=py27-cdat+iris+pynio - - python: 3.4 - env: CONDA_ENV=py34 - - python: 3.5 - env: CONDA_ENV=py35 - - python: 3.6 - env: CONDA_ENV=py36 - - python: 3.6 - env: + - env: CONDA_ENV=py27-min + - env: CONDA_ENV=py27-cdat+iris+pynio + - env: CONDA_ENV=py35 + - env: CONDA_ENV=py36 + - env: CONDA_ENV=py37 + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev - netcdf-bin - libnetcdf-dev - - python: 3.6 - env: CONDA_ENV=py36-dask-dev - - python: 3.6 - env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-rasterio1.0alpha - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev + - env: 
CONDA_ENV=py36-dask-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-rasterio-0.36 + - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=docs + - env: CONDA_ENV=py36-hypothesis + allow_failures: - - python: 3.6 - env: + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev - netcdf-bin - libnetcdf-dev - - python: 3.6 - env: CONDA_ENV=py36-dask-dev - - python: 3.6 - env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-rasterio1.0alpha - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-zarr-dev before_install: - - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then - wget http://repo.continuum.io/miniconda/Miniconda-3.16.0-Linux-x86_64.sh -O miniconda.sh; - else - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - fi + - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" - hash -r @@ -86,16 +60,28 @@ before_install: - conda info -a install: - - conda env create --file ci/requirements-$CONDA_ENV.yml + - if [[ "$CONDA_ENV" == "docs" ]]; then + conda env create -n test_env --file doc/environment.yml; + else + conda env create -n test_env --file ci/requirements-$CONDA_ENV.yml; + fi - source activate test_env - conda list - - python setup.py install + - pip install --no-deps -e . - python xarray/util/print_versions.py script: - - flake8 -j auto xarray + - which python + - python --version - python -OO -c "import xarray" - - py.test xarray --cov=xarray --cov-config ci/.coveragerc --cov-report term-missing --verbose $EXTRA_FLAGS + - if [[ "$CONDA_ENV" == "docs" ]]; then + conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc; + sphinx-build -n -j auto -b html -d _build/doctrees doc _build/html; + elif [[ "$CONDA_ENV" == "py36-hypothesis" ]]; then + pytest properties ; + else + py.test xarray --cov=xarray --cov-config ci/.coveragerc --cov-report term-missing --verbose $EXTRA_FLAGS; + fi after_success: - coveralls diff --git a/HOW_TO_RELEASE b/HOW_TO_RELEASE index f1fee59e177..80f37e672a5 100644 --- a/HOW_TO_RELEASE +++ b/HOW_TO_RELEASE @@ -7,21 +7,21 @@ Time required: about an hour. 2. Look over whats-new.rst and the docs. Make sure "What's New" is complete (check the date!) and add a brief summary note describing the release at the top. - 3. Update the version in setup.py and switch to `ISRELEASED = True`. - 4. If you have any doubts, run the full test suite one final time! + 3. If you have any doubts, run the full test suite one final time! py.test - 5. On the master branch, commit the release in git: + 4. On the master branch, commit the release in git: git commit -a -m 'Release v0.X.Y' - 6. Tag the release: + 5. Tag the release: git tag -a v0.X.Y -m 'v0.X.Y' - 7. Build source and binary wheels for pypi: + 6. Build source and binary wheels for pypi: + git clean -xdf # this deletes all uncommited changes! 
python setup.py bdist_wheel sdist - 8. Use twine to register and upload the release on pypi. Be careful, you can't + 7. Use twine to register and upload the release on pypi. Be careful, you can't take this back! twine upload dist/xarray-0.X.Y* You will need to be listed as a package owner at https://pypi.python.org/pypi/xarray for this to work. - 9. Push your changes to master: + 8. Push your changes to master: git push upstream master git push upstream --tags 9. Update the stable branch (used by ReadTheDocs) and switch back to master: @@ -32,25 +32,18 @@ Time required: about an hour. It's OK to force push to 'stable' if necessary. We also update the stable branch with `git cherrypick` for documentation only fixes that apply the current released version. -10. Revert ISRELEASED in setup.py back to False. Don't change the version - number: in normal development, we keep the version number in setup.py as the - last released version. -11. Add a section for the next release (v.X.(Y+1)) to doc/whats-new.rst. -12. Commit your changes and push to master again: +10. Add a section for the next release (v.X.(Y+1)) to doc/whats-new.rst. +11. Commit your changes and push to master again: git commit -a -m 'Revert to dev version' git push upstream master You're done pushing to master! -13. Issue the release on GitHub. Click on "Draft a new release" at - https://github.com/pydata/xarray/releases and paste in the latest from - whats-new.rst. -14. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ +12. Issue the release on GitHub. Click on "Draft a new release" at + https://github.com/pydata/xarray/releases. Type in the version number, but + don't bother to describe it -- we maintain that on the docs instead. +13. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ and switch your new release tag (at the bottom) from "Inactive" to "Active". It should now build automatically. -15. Update conda-forge. Clone https://github.com/conda-forge/xarray-feedstock - and update the version number and sha256 in meta.yaml. (On OS X, you can - calculate sha256 with `shasum -a 256 xarray-0.X.Y.tar.gz`). Submit a pull - request (and merge it, once CI passes). -16. Issue the release announcement! For bug fix releases, I usually only email +14. Issue the release announcement! For bug fix releases, I usually only email xarray@googlegroups.com. For major/feature releases, I will email a broader list (no more than once every 3-6 months): pydata@googlegroups.com, xarray@googlegroups.com, diff --git a/MANIFEST.in b/MANIFEST.in index a49c49cd396..a006660e5fb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,3 +4,5 @@ recursive-include doc * prune doc/_build prune doc/generated global-exclude .DS_Store +include versioneer.py +include xarray/_version.py diff --git a/README.rst b/README.rst index 8e77f55ccbb..0ac71d33954 100644 --- a/README.rst +++ b/README.rst @@ -7,12 +7,16 @@ xarray: N-D labeled arrays and datasets :target: https://ci.appveyor.com/project/shoyer/xray .. image:: https://coveralls.io/repos/pydata/xarray/badge.svg :target: https://coveralls.io/r/pydata/xarray +.. image:: https://readthedocs.org/projects/xray/badge/?version=latest + :target: http://xarray.pydata.org/ .. image:: https://img.shields.io/pypi/v/xarray.svg :target: https://pypi.python.org/pypi/xarray/ .. image:: https://zenodo.org/badge/13221727.svg :target: https://zenodo.org/badge/latestdoi/13221727 .. 
image:: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat - :target: https://tomaugspurger.github.io/asv-collection/xarray/ + :target: http://pandas.pydata.org/speed/xarray/ +.. image:: https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A + :target: http://numfocus.org **xarray** (formerly **xray**) is an open source project and Python package that aims to bring the labeled data power of pandas_ to the physical sciences, by providing @@ -84,6 +88,11 @@ Documentation The official documentation is hosted on ReadTheDocs at http://xarray.pydata.org/ +Contributing +------------ + +You can find information about contributing to xarray at our `Contributing page `_. + Get in touch ------------ @@ -96,20 +105,36 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray +NumFOCUS +-------- + +.. image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png + :scale: 25 % + :target: https://numfocus.org/ + +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated +to supporting the open source scientific computing community. If you like +Xarray and want to support our mission, please consider making a donation_ +to support our efforts. + +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= + History ------- xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org License ------- -Copyright 2014-2017, xarray Developers +Copyright 2014-2018, xarray Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
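The README's pitch of N-dimensional variants of the core pandas data structures, with named dimensions and coordinates, is easiest to see in a tiny example. A minimal sketch (the array values and labels are made up for illustration):

```python
import numpy as np
import xarray as xr

# Dimensions are referred to by name ("time", "space") rather than by axis number.
temps = xr.DataArray(
    np.random.rand(3, 2),
    dims=("time", "space"),
    coords={"time": ["2018-01-01", "2018-01-02", "2018-01-03"],
            "space": ["IA", "IL"]},
    name="temperature",
)

# Label-based selection and reduction, in the spirit of pandas, but N-dimensional.
print(temps.sel(space="IA"))
print(temps.mean(dim="time"))
```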
diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index a2878a7bf50..e3933b400e6 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -63,7 +63,8 @@ "netcdf4": [""], "scipy": [""], "bottleneck": ["", null], - "dask": ["", null], + "dask": [""], + "distributed": [""], }, diff --git a/asv_bench/benchmarks/__init__.py b/asv_bench/benchmarks/__init__.py index e2f49e6ab48..997fdfd0db0 100644 --- a/asv_bench/benchmarks/__init__.py +++ b/asv_bench/benchmarks/__init__.py @@ -2,16 +2,23 @@ from __future__ import division from __future__ import print_function import itertools -import random import numpy as np _counter = itertools.count() +def parameterized(names, params): + def decorator(func): + func.param_names = names + func.params = params + return func + return decorator + + def requires_dask(): try: - import dask + import dask # noqa except ImportError: raise NotImplementedError diff --git a/asv_bench/benchmarks/dataarray_missing.py b/asv_bench/benchmarks/dataarray_missing.py index c6aa8f428bd..29a9e78f82c 100644 --- a/asv_bench/benchmarks/dataarray_missing.py +++ b/asv_bench/benchmarks/dataarray_missing.py @@ -1,23 +1,21 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import pandas as pd -try: - import dask -except ImportError: - pass - import xarray as xr from . import randn, requires_dask +try: + import dask # noqa +except ImportError: + pass + def make_bench_data(shape, frac_nan, chunks): vals = randn(shape, frac_nan) coords = {'time': pd.date_range('2000-01-01', freq='D', - periods=shape[0])} + periods=shape[0])} da = xr.DataArray(vals, dims=('time', 'x', 'y'), coords=coords) if chunks is not None: diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index d7766d99a3d..3e070e1355b 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -1,19 +1,22 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import os import numpy as np import pandas as pd +import xarray as xr + +from . import randint, randn, requires_dask + try: import dask import dask.multiprocessing except ImportError: pass -import xarray as xr -from . 
import randn, requires_dask +os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' class IOSingleNetCDF(object): @@ -73,6 +76,15 @@ def make_ds(self): self.ds.attrs = {'history': 'created for xarray benchmarking'} + self.oinds = {'time': randint(0, self.nt, 120), + 'lon': randint(0, self.nx, 20), + 'lat': randint(0, self.ny, 10)} + self.vinds = {'time': xr.DataArray(randint(0, self.nt, 120), + dims='x'), + 'lon': xr.DataArray(randint(0, self.nx, 120), + dims='x'), + 'lat': slice(3, 20)} + class IOWriteSingleNetCDF3(IOSingleNetCDF): def setup(self): @@ -100,6 +112,14 @@ def setup(self): def time_load_dataset_netcdf4(self): xr.open_dataset(self.filepath, engine='netcdf4').load() + def time_orthogonal_indexing(self): + ds = xr.open_dataset(self.filepath, engine='netcdf4') + ds = ds.isel(**self.oinds).load() + + def time_vectorized_indexing(self): + ds = xr.open_dataset(self.filepath, engine='netcdf4') + ds = ds.isel(**self.vinds).load() + class IOReadSingleNetCDF3(IOReadSingleNetCDF4): def setup(self): @@ -113,6 +133,14 @@ def setup(self): def time_load_dataset_scipy(self): xr.open_dataset(self.filepath, engine='scipy').load() + def time_orthogonal_indexing(self): + ds = xr.open_dataset(self.filepath, engine='scipy') + ds = ds.isel(**self.oinds).load() + + def time_vectorized_indexing(self): + ds = xr.open_dataset(self.filepath, engine='scipy') + ds = ds.isel(**self.vinds).load() + class IOReadSingleNetCDF4Dask(IOSingleNetCDF): def setup(self): @@ -129,8 +157,18 @@ def time_load_dataset_netcdf4_with_block_chunks(self): xr.open_dataset(self.filepath, engine='netcdf4', chunks=self.block_chunks).load() + def time_load_dataset_netcdf4_with_block_chunks_oindexing(self): + ds = xr.open_dataset(self.filepath, engine='netcdf4', + chunks=self.block_chunks) + ds = ds.isel(**self.oinds).load() + + def time_load_dataset_netcdf4_with_block_chunks_vindexing(self): + ds = xr.open_dataset(self.filepath, engine='netcdf4', + chunks=self.block_chunks) + ds = ds.isel(**self.vinds).load() + def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='netcdf4', chunks=self.block_chunks).load() @@ -139,7 +177,7 @@ def time_load_dataset_netcdf4_with_time_chunks(self): chunks=self.time_chunks).load() def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='netcdf4', chunks=self.time_chunks).load() @@ -156,12 +194,22 @@ def setup(self): self.ds.to_netcdf(self.filepath, format=self.format) def time_load_dataset_scipy_with_block_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='scipy', chunks=self.block_chunks).load() + def time_load_dataset_scipy_with_block_chunks_oindexing(self): + ds = xr.open_dataset(self.filepath, engine='scipy', + chunks=self.block_chunks) + ds = ds.isel(**self.oinds).load() + + def time_load_dataset_scipy_with_block_chunks_vindexing(self): + ds = xr.open_dataset(self.filepath, engine='scipy', + chunks=self.block_chunks) + ds = ds.isel(**self.vinds).load() + def time_load_dataset_scipy_with_time_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='scipy', 
chunks=self.time_chunks).load() @@ -301,7 +349,7 @@ def time_load_dataset_netcdf4_with_block_chunks(self): chunks=self.block_chunks).load() def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.block_chunks).load() @@ -310,7 +358,7 @@ def time_load_dataset_netcdf4_with_time_chunks(self): chunks=self.time_chunks).load() def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.time_chunks).load() @@ -319,7 +367,7 @@ def time_open_dataset_netcdf4_with_block_chunks(self): chunks=self.block_chunks) def time_open_dataset_netcdf4_with_block_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.block_chunks) @@ -328,7 +376,7 @@ def time_open_dataset_netcdf4_with_time_chunks(self): chunks=self.time_chunks) def time_open_dataset_netcdf4_with_time_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.time_chunks) @@ -344,21 +392,57 @@ def setup(self): format=self.format) def time_load_dataset_scipy_with_block_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.block_chunks).load() def time_load_dataset_scipy_with_time_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks).load() def time_open_dataset_scipy_with_block_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.block_chunks) def time_open_dataset_scipy_with_time_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks) + + +def create_delayed_write(): + import dask.array as da + vals = da.random.random(300, chunks=(1,)) + ds = xr.Dataset({'vals': (['a'], vals)}) + return ds.to_netcdf('file.nc', engine='netcdf4', compute=False) + + +class IOWriteNetCDFDask(object): + timeout = 60 + repeat = 1 + number = 5 + + def setup(self): + requires_dask() + self.write = create_delayed_write() + + def time_write(self): + self.write.compute() + + +class IOWriteNetCDFDaskDistributed(object): + def setup(self): + try: + import distributed + except ImportError: + raise NotImplementedError + self.client = distributed.Client() + self.write = create_delayed_write() + + def cleanup(self): + self.client.shutdown() + + def time_write(self): + self.write.compute() diff --git a/asv_bench/benchmarks/indexing.py b/asv_bench/benchmarks/indexing.py index e9a85115a49..54262b12a19 100644 --- a/asv_bench/benchmarks/indexing.py +++ b/asv_bench/benchmarks/indexing.py @@ -1,13 +1,11 @@ -from __future__ import absolute_import -from __future__ import 
division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import numpy as np import pandas as pd -import xarray as xr -from . import randn, randint, requires_dask +import xarray as xr +from . import randint, randn, requires_dask nx = 3000 ny = 2000 @@ -29,7 +27,7 @@ outer_indexes = { '1d': {'x': randint(0, nx, 400)}, - '2d': {'x': randint(0, nx, 500), 'y': randint(0, ny, 400)}, + '2d': {'x': randint(0, nx, 500), 'y': randint(0, ny, 400)}, '2d-1scalar': {'x': randint(0, nx, 100), 'y': 1, 't': randint(0, nt, 400)} } diff --git a/asv_bench/benchmarks/interp.py b/asv_bench/benchmarks/interp.py new file mode 100644 index 00000000000..edec6df34dd --- /dev/null +++ b/asv_bench/benchmarks/interp.py @@ -0,0 +1,54 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import pandas as pd + +import xarray as xr + +from . import parameterized, randn, requires_dask + +nx = 3000 +long_nx = 30000000 +ny = 2000 +nt = 1000 +window = 20 + +randn_xy = randn((nx, ny), frac_nan=0.1) +randn_xt = randn((nx, nt)) +randn_t = randn((nt, )) +randn_long = randn((long_nx, ), frac_nan=0.1) + + +new_x_short = np.linspace(0.3 * nx, 0.7 * nx, 100) +new_x_long = np.linspace(0.3 * nx, 0.7 * nx, 1000) +new_y_long = np.linspace(0.1, 0.9, 1000) + + +class Interpolation(object): + def setup(self, *args, **kwargs): + self.ds = xr.Dataset( + {'var1': (('x', 'y'), randn_xy), + 'var2': (('x', 't'), randn_xt), + 'var3': (('t', ), randn_t)}, + coords={'x': np.arange(nx), + 'y': np.linspace(0, 1, ny), + 't': pd.date_range('1970-01-01', periods=nt, freq='D'), + 'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) + + @parameterized(['method', 'is_short'], + (['linear', 'cubic'], [True, False])) + def time_interpolation(self, method, is_short): + new_x = new_x_short if is_short else new_x_long + self.ds.interp(x=new_x, method=method).load() + + @parameterized(['method'], + (['linear', 'nearest'])) + def time_interpolation_2d(self, method): + self.ds.interp(x=new_x_long, y=new_y_long, method=method).load() + + +class InterpolationDask(Interpolation): + def setup(self, *args, **kwargs): + requires_dask() + super(InterpolationDask, self).setup(**kwargs) + self.ds = self.ds.chunk({'t': 50}) diff --git a/asv_bench/benchmarks/reindexing.py b/asv_bench/benchmarks/reindexing.py new file mode 100644 index 00000000000..28e14d52e89 --- /dev/null +++ b/asv_bench/benchmarks/reindexing.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +import xarray as xr + +from . 
import requires_dask + + +class Reindex(object): + def setup(self): + data = np.random.RandomState(0).randn(1000, 100, 100) + self.ds = xr.Dataset({'temperature': (('time', 'x', 'y'), data)}, + coords={'time': np.arange(1000), + 'x': np.arange(100), + 'y': np.arange(100)}) + + def time_1d_coarse(self): + self.ds.reindex(time=np.arange(0, 1000, 5)).load() + + def time_1d_fine_all_found(self): + self.ds.reindex(time=np.arange(0, 1000, 0.5), method='nearest').load() + + def time_1d_fine_some_missing(self): + self.ds.reindex(time=np.arange(0, 1000, 0.5), method='nearest', + tolerance=0.1).load() + + def time_2d_coarse(self): + self.ds.reindex(x=np.arange(0, 100, 2), y=np.arange(0, 100, 2)).load() + + def time_2d_fine_all_found(self): + self.ds.reindex(x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5), + method='nearest').load() + + def time_2d_fine_some_missing(self): + self.ds.reindex(x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5), + method='nearest', tolerance=0.1).load() + + +class ReindexDask(Reindex): + def setup(self): + requires_dask() + super(ReindexDask, self).setup() + self.ds = self.ds.chunk({'time': 100}) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py new file mode 100644 index 00000000000..5ba7406f6e0 --- /dev/null +++ b/asv_bench/benchmarks/rolling.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import pandas as pd + +import xarray as xr + +from . import parameterized, randn, requires_dask + +nx = 3000 +long_nx = 30000000 +ny = 2000 +nt = 1000 +window = 20 + +randn_xy = randn((nx, ny), frac_nan=0.1) +randn_xt = randn((nx, nt)) +randn_t = randn((nt, )) +randn_long = randn((long_nx, ), frac_nan=0.1) + + +class Rolling(object): + def setup(self, *args, **kwargs): + self.ds = xr.Dataset( + {'var1': (('x', 'y'), randn_xy), + 'var2': (('x', 't'), randn_xt), + 'var3': (('t', ), randn_t)}, + coords={'x': np.arange(nx), + 'y': np.linspace(0, 1, ny), + 't': pd.date_range('1970-01-01', periods=nt, freq='D'), + 'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) + self.da_long = xr.DataArray(randn_long, dims='x', + coords={'x': np.arange(long_nx) * 0.1}) + + @parameterized(['func', 'center'], + (['mean', 'count'], [True, False])) + def time_rolling(self, func, center): + getattr(self.ds.rolling(x=window, center=center), func)().load() + + @parameterized(['func', 'pandas'], + (['mean', 'count'], [True, False])) + def time_rolling_long(self, func, pandas): + if pandas: + se = self.da_long.to_series() + getattr(se.rolling(window=window), func)() + else: + getattr(self.da_long.rolling(x=window), func)().load() + + @parameterized(['window_', 'min_periods'], + ([20, 40], [5, None])) + def time_rolling_np(self, window_, min_periods): + self.ds.rolling(x=window_, center=False, + min_periods=min_periods).reduce( + getattr(np, 'nanmean')).load() + + @parameterized(['center', 'stride'], + ([True, False], [1, 200])) + def time_rolling_construct(self, center, stride): + self.ds.rolling(x=window, center=center).construct( + 'window_dim', stride=stride).mean(dim='window_dim').load() + + +class RollingDask(Rolling): + def setup(self, *args, **kwargs): + requires_dask() + super(RollingDask, self).setup(**kwargs) + self.ds = self.ds.chunk({'x': 100, 'y': 50, 't': 50}) + self.da_long = self.da_long.chunk({'x': 10000}) diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py new file mode 100644 index 00000000000..54436b422e9 --- /dev/null +++ b/asv_bench/benchmarks/unstacking.py @@ -0,0 
+1,26 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +import xarray as xr + +from . import requires_dask + + +class Unstacking(object): + def setup(self): + data = np.random.RandomState(0).randn(1, 1000, 500) + self.ds = xr.DataArray(data).stack(flat_dim=['dim_1', 'dim_2']) + + def time_unstack_fast(self): + self.ds.unstack('flat_dim') + + def time_unstack_slow(self): + self.ds[:, ::-1].unstack('flat_dim') + + +class UnstackingDask(Unstacking): + def setup(self, *args, **kwargs): + requires_dask() + super(UnstackingDask, self).setup(**kwargs) + self.ds = self.ds.chunk({'flat_dim': 50}) diff --git a/ci/requirements-py27-cdat+iris+pynio.yml b/ci/requirements-py27-cdat+iris+pynio.yml index 6c7e3b87318..116e323d517 100644 --- a/ci/requirements-py27-cdat+iris+pynio.yml +++ b/ci/requirements-py27-cdat+iris+pynio.yml @@ -4,6 +4,7 @@ channels: dependencies: - python=2.7 - cdat-lite + - cftime - cyordereddict - dask - distributed diff --git a/ci/requirements-py27-min.yml b/ci/requirements-py27-min.yml index 50f6724ec51..118b629271e 100644 --- a/ci/requirements-py27-min.yml +++ b/ci/requirements-py27-min.yml @@ -4,8 +4,8 @@ dependencies: - pytest - flake8 - mock - - numpy==1.11 - - pandas==0.18.0 + - numpy=1.12 + - pandas=0.19 - pip: - coveralls - pytest-cov diff --git a/ci/requirements-py27-windows.yml b/ci/requirements-py27-windows.yml index a39b24b887c..967b7c584b9 100644 --- a/ci/requirements-py27-windows.yml +++ b/ci/requirements-py27-windows.yml @@ -8,7 +8,6 @@ dependencies: - h5py - h5netcdf - matplotlib - - netcdf4 - pathlib2 - pytest - flake8 @@ -20,3 +19,5 @@ dependencies: - toolz - rasterio - zarr + - pip: + - netcdf4 diff --git a/ci/requirements-py34.yml b/ci/requirements-py34.yml deleted file mode 100644 index ba79e00bb12..00000000000 --- a/ci/requirements-py34.yml +++ /dev/null @@ -1,10 +0,0 @@ -name: test_env -dependencies: - - python=3.4 - - bottleneck - - flake8 - - pandas - - pip: - - coveralls - - pytest-cov - - pytest diff --git a/ci/requirements-py35.yml b/ci/requirements-py35.yml index 6f9ae2490b9..9615aeba9aa 100644 --- a/ci/requirements-py35.yml +++ b/ci/requirements-py35.yml @@ -3,11 +3,11 @@ channels: - conda-forge dependencies: - python=3.5 - - dask - - distributed + - cftime + - dask=0.16 - h5py - h5netcdf - - matplotlib + - matplotlib=1.5 - netcdf4 - pytest - flake8 diff --git a/ci/requirements-py36-bottleneck-dev.yml b/ci/requirements-py36-bottleneck-dev.yml index 571a2e1294f..b8619658929 100644 --- a/ci/requirements-py36-bottleneck-dev.yml +++ b/ci/requirements-py36-bottleneck-dev.yml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - python=3.6 + - cftime - dask - distributed - h5py diff --git a/ci/requirements-py36-condaforge-rc.yml b/ci/requirements-py36-condaforge-rc.yml index 6519d0d0f47..8436d4e3e83 100644 --- a/ci/requirements-py36-condaforge-rc.yml +++ b/ci/requirements-py36-condaforge-rc.yml @@ -4,6 +4,7 @@ channels: - conda-forge dependencies: - python=3.6 + - cftime - dask - distributed - h5py diff --git a/ci/requirements-py36-dask-dev.yml b/ci/requirements-py36-dask-dev.yml index ae359a13356..e580aaf3889 100644 --- a/ci/requirements-py36-dask-dev.yml +++ b/ci/requirements-py36-dask-dev.yml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - python=3.6 + - cftime - h5py - h5netcdf - matplotlib @@ -11,9 +12,13 @@ dependencies: - flake8 - numpy - pandas - - seaborn - scipy + - seaborn - toolz + - rasterio + - bottleneck + - zarr + - pseudonetcdf>=3.0.1 - pip: - coveralls - pytest-cov diff --git 
a/ci/requirements-py36-hypothesis.yml b/ci/requirements-py36-hypothesis.yml new file mode 100644 index 00000000000..29f4ae33538 --- /dev/null +++ b/ci/requirements-py36-hypothesis.yml @@ -0,0 +1,27 @@ +name: test_env +channels: + - conda-forge +dependencies: + - python=3.6 + - dask + - distributed + - h5py + - h5netcdf + - matplotlib + - netcdf4 + - pytest + - flake8 + - numpy + - pandas + - scipy + - seaborn + - toolz + - rasterio + - bottleneck + - zarr + - pip: + - coveralls + - pytest-cov + - pydap + - lxml + - hypothesis diff --git a/ci/requirements-py36-netcdf4-dev.yml b/ci/requirements-py36-netcdf4-dev.yml index 2daa02756bb..a473ceb5b0a 100644 --- a/ci/requirements-py36-netcdf4-dev.yml +++ b/ci/requirements-py36-netcdf4-dev.yml @@ -19,3 +19,4 @@ dependencies: - coveralls - pytest-cov - git+https://github.com/Unidata/netcdf4-python.git + - git+https://github.com/Unidata/cftime.git diff --git a/ci/requirements-py36-pandas-dev.yml b/ci/requirements-py36-pandas-dev.yml index fe4eb226204..1f1acabcae9 100644 --- a/ci/requirements-py36-pandas-dev.yml +++ b/ci/requirements-py36-pandas-dev.yml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - python=3.6 + - cftime - cython - dask - distributed diff --git a/ci/requirements-py36-pynio-dev.yml b/ci/requirements-py36-pynio-dev.yml index e19c6537c68..2caaa8affe5 100644 --- a/ci/requirements-py36-pynio-dev.yml +++ b/ci/requirements-py36-pynio-dev.yml @@ -1,9 +1,10 @@ name: test_env channels: - conda-forge - - ncar + - conda-forge/label/dev dependencies: - python=3.6 + - cftime - dask - distributed - h5py diff --git a/ci/requirements-py36-rasterio1.0alpha.yml b/ci/requirements-py36-rasterio-0.36.yml similarity index 86% rename from ci/requirements-py36-rasterio1.0alpha.yml rename to ci/requirements-py36-rasterio-0.36.yml index 3c32ebb0e43..5c724e1b981 100644 --- a/ci/requirements-py36-rasterio1.0alpha.yml +++ b/ci/requirements-py36-rasterio-0.36.yml @@ -1,9 +1,9 @@ name: test_env channels: - conda-forge - - conda-forge/label/dev dependencies: - python=3.6 + - cftime - dask - distributed - h5py @@ -16,7 +16,7 @@ dependencies: - scipy - seaborn - toolz - - rasterio>=1.* + - rasterio=0.36.0 - bottleneck - pip: - coveralls diff --git a/ci/requirements-py36-windows.yml b/ci/requirements-py36-windows.yml index ea366bd04f7..62f08318087 100644 --- a/ci/requirements-py36-windows.yml +++ b/ci/requirements-py36-windows.yml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - python=3.6 + - cftime - dask - distributed - h5py @@ -17,3 +18,4 @@ dependencies: - toolz - rasterio - zarr + diff --git a/ci/requirements-py36-zarr-dev.yml b/ci/requirements-py36-zarr-dev.yml index 9be522882c5..7fbce63aa81 100644 --- a/ci/requirements-py36-zarr-dev.yml +++ b/ci/requirements-py36-zarr-dev.yml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - python=3.6 + - cftime - dask - distributed - matplotlib diff --git a/ci/requirements-py36.yml b/ci/requirements-py36.yml index cc02d6e92bf..321f3087ea2 100644 --- a/ci/requirements-py36.yml +++ b/ci/requirements-py36.yml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - python=3.6 + - cftime - dask - distributed - h5py @@ -19,8 +20,11 @@ dependencies: - rasterio - bottleneck - zarr + - pseudonetcdf>=3.0.1 + - eccodes - pip: - coveralls - pytest-cov - pydap - lxml + - cfgrib>=0.9.2 diff --git a/ci/requirements-py37.yml b/ci/requirements-py37.yml new file mode 100644 index 00000000000..6292c4c5eb6 --- /dev/null +++ b/ci/requirements-py37.yml @@ -0,0 +1,30 @@ +name: test_env 
+channels: + - conda-forge +dependencies: + - python=3.7 + - cftime + - dask + - distributed + - h5py + - h5netcdf + - matplotlib + - netcdf4 + - pytest + - flake8 + - numpy + - pandas + - scipy + - seaborn + - toolz + - rasterio + - bottleneck + - zarr + - pseudonetcdf>=3.0.1 + - eccodes + - pip: + - coveralls + - pytest-cov + - pydap + - lxml + - cfgrib>=0.9.2 \ No newline at end of file diff --git a/doc/README.rst b/doc/README.rst new file mode 100644 index 00000000000..af7bc96092c --- /dev/null +++ b/doc/README.rst @@ -0,0 +1,4 @@ +xarray +------ + +You can find information about building the docs at our `Contributing page `_. diff --git a/doc/_static/advanced_selection_interpolation.svg b/doc/_static/advanced_selection_interpolation.svg new file mode 100644 index 00000000000..096563a604f --- /dev/null +++ b/doc/_static/advanced_selection_interpolation.svg @@ -0,0 +1,731 @@ [731 lines of SVG markup omitted: a two-panel vector diagram with x, y, and z axes, captioned "Advanced indexing" and "Advanced interpolation"] diff --git a/doc/_static/ci.png b/doc/_static/ci.png new file mode 100644 index 00000000000..f535b594454 Binary files /dev/null and b/doc/_static/ci.png differ diff --git a/doc/_static/dataset-diagram-square-logo.png b/doc/_static/dataset-diagram-square-logo.png new file mode 100644 index 00000000000..d1eeda092c4 Binary files /dev/null and b/doc/_static/dataset-diagram-square-logo.png differ diff --git a/doc/_static/dataset-diagram-square-logo.tex b/doc/_static/dataset-diagram-square-logo.tex new file mode 100644 index 00000000000..0a784770b50 --- /dev/null +++ b/doc/_static/dataset-diagram-square-logo.tex @@ -0,0 +1,277 @@ +\documentclass[class=minimal,border=0pt,convert={size=600,outext=.png}]{standalone} +% \documentclass[class=minimal,border=0pt]{standalone} +\usepackage[scaled]{helvet} +\renewcommand*\familydefault{\sfdefault} + +% =========================================================================== +% The code below (used to define the \tikzcuboid command) is copied, +% unmodified, from a tex.stackexchange.com answer by the user "Tom Bombadil": +% http://tex.stackexchange.com/a/29882/8335 +% +% It is licensed under the Creative Commons Attribution-ShareAlike 3.0 +% Unported license: http://creativecommons.org/licenses/by-sa/3.0/ +% =========================================================================== + +\usepackage[usenames,dvipsnames]{color} +\usepackage{tikz} +\usepackage{keyval} +\usepackage{ifthen} + +%==================================== +%emphasize vertices --> switch and emph style (e.g.
thick,black) +%==================================== +\makeatletter +% Standard Values for Parameters +\newcommand{\tikzcuboid@shiftx}{0} +\newcommand{\tikzcuboid@shifty}{0} +\newcommand{\tikzcuboid@dimx}{3} +\newcommand{\tikzcuboid@dimy}{3} +\newcommand{\tikzcuboid@dimz}{3} +\newcommand{\tikzcuboid@scale}{1} +\newcommand{\tikzcuboid@densityx}{1} +\newcommand{\tikzcuboid@densityy}{1} +\newcommand{\tikzcuboid@densityz}{1} +\newcommand{\tikzcuboid@rotation}{0} +\newcommand{\tikzcuboid@anglex}{0} +\newcommand{\tikzcuboid@angley}{90} +\newcommand{\tikzcuboid@anglez}{225} +\newcommand{\tikzcuboid@scalex}{1} +\newcommand{\tikzcuboid@scaley}{1} +\newcommand{\tikzcuboid@scalez}{sqrt(0.5)} +\newcommand{\tikzcuboid@linefront}{black} +\newcommand{\tikzcuboid@linetop}{black} +\newcommand{\tikzcuboid@lineright}{black} +\newcommand{\tikzcuboid@fillfront}{white} +\newcommand{\tikzcuboid@filltop}{white} +\newcommand{\tikzcuboid@fillright}{white} +\newcommand{\tikzcuboid@shaded}{N} +\newcommand{\tikzcuboid@shadecolor}{black} +\newcommand{\tikzcuboid@shadeperc}{25} +\newcommand{\tikzcuboid@emphedge}{N} +\newcommand{\tikzcuboid@emphstyle}{thick} + +% Definition of Keys +\define@key{tikzcuboid}{shiftx}[\tikzcuboid@shiftx]{\renewcommand{\tikzcuboid@shiftx}{#1}} +\define@key{tikzcuboid}{shifty}[\tikzcuboid@shifty]{\renewcommand{\tikzcuboid@shifty}{#1}} +\define@key{tikzcuboid}{dimx}[\tikzcuboid@dimx]{\renewcommand{\tikzcuboid@dimx}{#1}} +\define@key{tikzcuboid}{dimy}[\tikzcuboid@dimy]{\renewcommand{\tikzcuboid@dimy}{#1}} +\define@key{tikzcuboid}{dimz}[\tikzcuboid@dimz]{\renewcommand{\tikzcuboid@dimz}{#1}} +\define@key{tikzcuboid}{scale}[\tikzcuboid@scale]{\renewcommand{\tikzcuboid@scale}{#1}} +\define@key{tikzcuboid}{densityx}[\tikzcuboid@densityx]{\renewcommand{\tikzcuboid@densityx}{#1}} +\define@key{tikzcuboid}{densityy}[\tikzcuboid@densityy]{\renewcommand{\tikzcuboid@densityy}{#1}} +\define@key{tikzcuboid}{densityz}[\tikzcuboid@densityz]{\renewcommand{\tikzcuboid@densityz}{#1}} +\define@key{tikzcuboid}{rotation}[\tikzcuboid@rotation]{\renewcommand{\tikzcuboid@rotation}{#1}} +\define@key{tikzcuboid}{anglex}[\tikzcuboid@anglex]{\renewcommand{\tikzcuboid@anglex}{#1}} +\define@key{tikzcuboid}{angley}[\tikzcuboid@angley]{\renewcommand{\tikzcuboid@angley}{#1}} +\define@key{tikzcuboid}{anglez}[\tikzcuboid@anglez]{\renewcommand{\tikzcuboid@anglez}{#1}} +\define@key{tikzcuboid}{scalex}[\tikzcuboid@scalex]{\renewcommand{\tikzcuboid@scalex}{#1}} +\define@key{tikzcuboid}{scaley}[\tikzcuboid@scaley]{\renewcommand{\tikzcuboid@scaley}{#1}} +\define@key{tikzcuboid}{scalez}[\tikzcuboid@scalez]{\renewcommand{\tikzcuboid@scalez}{#1}} +\define@key{tikzcuboid}{linefront}[\tikzcuboid@linefront]{\renewcommand{\tikzcuboid@linefront}{#1}} +\define@key{tikzcuboid}{linetop}[\tikzcuboid@linetop]{\renewcommand{\tikzcuboid@linetop}{#1}} +\define@key{tikzcuboid}{lineright}[\tikzcuboid@lineright]{\renewcommand{\tikzcuboid@lineright}{#1}} +\define@key{tikzcuboid}{fillfront}[\tikzcuboid@fillfront]{\renewcommand{\tikzcuboid@fillfront}{#1}} +\define@key{tikzcuboid}{filltop}[\tikzcuboid@filltop]{\renewcommand{\tikzcuboid@filltop}{#1}} +\define@key{tikzcuboid}{fillright}[\tikzcuboid@fillright]{\renewcommand{\tikzcuboid@fillright}{#1}} +\define@key{tikzcuboid}{shaded}[\tikzcuboid@shaded]{\renewcommand{\tikzcuboid@shaded}{#1}} +\define@key{tikzcuboid}{shadecolor}[\tikzcuboid@shadecolor]{\renewcommand{\tikzcuboid@shadecolor}{#1}} +\define@key{tikzcuboid}{shadeperc}[\tikzcuboid@shadeperc]{\renewcommand{\tikzcuboid@shadeperc}{#1}} 
+\define@key{tikzcuboid}{emphedge}[\tikzcuboid@emphedge]{\renewcommand{\tikzcuboid@emphedge}{#1}} +\define@key{tikzcuboid}{emphstyle}[\tikzcuboid@emphstyle]{\renewcommand{\tikzcuboid@emphstyle}{#1}} +% Commands +\newcommand{\tikzcuboid}[1]{ + \setkeys{tikzcuboid}{#1} % Process Keys passed to command + \pgfmathsetmacro{\vectorxx}{\tikzcuboid@scalex*cos(\tikzcuboid@anglex)} + \pgfmathsetmacro{\vectorxy}{\tikzcuboid@scalex*sin(\tikzcuboid@anglex)} + \pgfmathsetmacro{\vectoryx}{\tikzcuboid@scaley*cos(\tikzcuboid@angley)} + \pgfmathsetmacro{\vectoryy}{\tikzcuboid@scaley*sin(\tikzcuboid@angley)} + \pgfmathsetmacro{\vectorzx}{\tikzcuboid@scalez*cos(\tikzcuboid@anglez)} + \pgfmathsetmacro{\vectorzy}{\tikzcuboid@scalez*sin(\tikzcuboid@anglez)} + \begin{scope}[xshift=\tikzcuboid@shiftx, yshift=\tikzcuboid@shifty, scale=\tikzcuboid@scale, rotate=\tikzcuboid@rotation, x={(\vectorxx,\vectorxy)}, y={(\vectoryx,\vectoryy)}, z={(\vectorzx,\vectorzy)}] + \pgfmathsetmacro{\steppingx}{1/\tikzcuboid@densityx} + \pgfmathsetmacro{\steppingy}{1/\tikzcuboid@densityy} + \pgfmathsetmacro{\steppingz}{1/\tikzcuboid@densityz} + \newcommand{\dimx}{\tikzcuboid@dimx} + \newcommand{\dimy}{\tikzcuboid@dimy} + \newcommand{\dimz}{\tikzcuboid@dimz} + \pgfmathsetmacro{\secondx}{2*\steppingx} + \pgfmathsetmacro{\secondy}{2*\steppingy} + \pgfmathsetmacro{\secondz}{2*\steppingz} + \foreach \x in {\steppingx,\secondx,...,\dimx} + { \foreach \y in {\steppingy,\secondy,...,\dimy} + { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} + \pgfmathsetmacro{\lowy}{(\y-\steppingy)} + \filldraw[fill=\tikzcuboid@fillfront,draw=\tikzcuboid@linefront] (\lowx,\lowy,\dimz) -- (\lowx,\y,\dimz) -- (\x,\y,\dimz) -- (\x,\lowy,\dimz) -- cycle; + + } + } + \foreach \x in {\steppingx,\secondx,...,\dimx} + { \foreach \z in {\steppingz,\secondz,...,\dimz} + { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} + \pgfmathsetmacro{\lowz}{(\z-\steppingz)} + \filldraw[fill=\tikzcuboid@filltop,draw=\tikzcuboid@linetop] (\lowx,\dimy,\lowz) -- (\lowx,\dimy,\z) -- (\x,\dimy,\z) -- (\x,\dimy,\lowz) -- cycle; + } + } + \foreach \y in {\steppingy,\secondy,...,\dimy} + { \foreach \z in {\steppingz,\secondz,...,\dimz} + { \pgfmathsetmacro{\lowy}{(\y-\steppingy)} + \pgfmathsetmacro{\lowz}{(\z-\steppingz)} + \filldraw[fill=\tikzcuboid@fillright,draw=\tikzcuboid@lineright] (\dimx,\lowy,\lowz) -- (\dimx,\lowy,\z) -- (\dimx,\y,\z) -- (\dimx,\y,\lowz) -- cycle; + } + } + \ifthenelse{\equal{\tikzcuboid@emphedge}{Y}}% + {\draw[\tikzcuboid@emphstyle](0,\dimy,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (0,\dimy,\dimz) -- cycle;% + \draw[\tikzcuboid@emphstyle] (0,0,\dimz) -- (0,\dimy,\dimz) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% + \draw[\tikzcuboid@emphstyle](\dimx,0,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% + }% + {} + \end{scope} +} + +\makeatother + +\begin{document} + +\begin{tikzpicture} + \tikzcuboid{% + shiftx=21cm,% + shifty=8cm,% + scale=1.00,% + rotation=0,% + densityx=2,% + densityy=2,% + densityz=2,% + dimx=4,% + dimy=3,% + dimz=3,% + linefront=purple!75!black,% + linetop=purple!50!black,% + lineright=purple!25!black,% + fillfront=purple!25!white,% + filltop=purple!50!white,% + fillright=purple!75!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=21cm,% + shifty=11.6cm,% + scale=1.00,% + rotation=0,% + densityx=2,% + densityy=2,% + densityz=2,% + dimx=4,% + dimy=3,% + dimz=3,% + linefront=teal!75!black,% + linetop=teal!50!black,% + lineright=teal!25!black,% + fillfront=teal!25!white,% + 
filltop=teal!50!white,% + fillright=teal!75!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=26.8cm,% + shifty=8cm,% + scale=1.00,% + rotation=0,% + densityx=10000,% + densityy=2,% + densityz=2,% + dimx=0,% + dimy=3,% + dimz=3,% + linefront=orange!75!black,% + linetop=orange!50!black,% + lineright=orange!25!black,% + fillfront=orange!25!white,% + filltop=orange!50!white,% + fillright=orange!100!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=28.6cm,% + shifty=8cm,% + scale=1.00,% + rotation=0,% + densityx=10000,% + densityy=2,% + densityz=2,% + dimx=0,% + dimy=3,% + dimz=3,% + linefront=purple!75!black,% + linetop=purple!50!black,% + lineright=purple!25!black,% + fillfront=purple!25!white,% + filltop=purple!50!white,% + fillright=red!75!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + % \tikzcuboid{% + % shiftx=27.1cm,% + % shifty=10.1cm,% + % scale=1.00,% + % rotation=0,% + % densityx=100,% + % densityy=2,% + % densityz=100,% + % dimx=0,% + % dimy=3,% + % dimz=0,% + % emphedge=Y,% + % emphstyle=ultra thick, + % } + % \tikzcuboid{% + % shiftx=27.1cm,% + % shifty=10.1cm,% + % scale=1.00,% + % rotation=180,% + % densityx=100,% + % densityy=100,% + % densityz=2,% + % dimx=0,% + % dimy=0,% + % dimz=3,% + % emphedge=Y,% + % emphstyle=ultra thick, + % } + \tikzcuboid{% + shiftx=26.8cm,% + shifty=11.4cm,% + scale=1.00,% + rotation=0,% + densityx=100,% + densityy=2,% + densityz=100,% + dimx=0,% + dimy=3,% + dimz=0,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=25.3cm,% + shifty=12.9cm,% + scale=1.00,% + rotation=180,% + densityx=100,% + densityy=100,% + densityz=2,% + dimx=0,% + dimy=0,% + dimz=3,% + emphedge=Y,% + emphstyle=ultra thick, + } + % \fill (27.1,10.1) circle[radius=2pt]; + \node [font=\fontsize{130}{100}\fontfamily{phv}\selectfont, anchor=east, text width=2cm, align=right, color=white!50!black] at (19.8,4.4) {\textbf{\emph{x}}}; + \node [font=\fontsize{130}{100}\fontfamily{phv}\selectfont, anchor=west, text width=10cm, align=left] at (20.3,4) {{array}}; +\end{tikzpicture} + +\end{document} diff --git a/doc/_static/favicon.ico b/doc/_static/favicon.ico new file mode 100644 index 00000000000..a1536e3ef76 Binary files /dev/null and b/doc/_static/favicon.ico differ diff --git a/doc/_static/numfocus_logo.png b/doc/_static/numfocus_logo.png new file mode 100644 index 00000000000..af3c84209e0 Binary files /dev/null and b/doc/_static/numfocus_logo.png differ diff --git a/doc/_static/style.css b/doc/_static/style.css new file mode 100644 index 00000000000..7257d57db66 --- /dev/null +++ b/doc/_static/style.css @@ -0,0 +1,18 @@ +@import url("theme.css"); + +.wy-side-nav-search>a img.logo, +.wy-side-nav-search .wy-dropdown>a img.logo { + width: 12rem +} + +.wy-side-nav-search { + background-color: #eee; +} + +.wy-side-nav-search>div.version { + display: none; +} + +.wy-nav-top { + background-color: #555; +} diff --git a/doc/_templates/layout.html b/doc/_templates/layout.html new file mode 100644 index 00000000000..4c57ba83056 --- /dev/null +++ b/doc/_templates/layout.html @@ -0,0 +1,2 @@ +{% extends "!layout.html" %} +{% set css_files = css_files + ["_static/style.css"] %} diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index b8fbfbc288f..4b2fed8be37 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -9,6 +9,9 @@ auto_combine + Dataset.nbytes + Dataset.chunks + Dataset.all Dataset.any Dataset.argmax @@ -22,13 +25,6 @@ Dataset.std Dataset.var - Dataset.isnull - Dataset.notnull - Dataset.count - 
Dataset.dropna - Dataset.fillna - Dataset.where - core.groupby.DatasetGroupBy.assign core.groupby.DatasetGroupBy.assign_coords core.groupby.DatasetGroupBy.first @@ -43,15 +39,18 @@ Dataset.imag Dataset.round Dataset.real - Dataset.T Dataset.cumsum Dataset.cumprod Dataset.rank DataArray.ndim + DataArray.nbytes DataArray.shape DataArray.size DataArray.dtype + DataArray.nbytes + DataArray.chunks + DataArray.astype DataArray.item @@ -68,13 +67,6 @@ DataArray.std DataArray.var - DataArray.isnull - DataArray.notnull - DataArray.count - DataArray.dropna - DataArray.fillna - DataArray.where - core.groupby.DataArrayGroupBy.assign_coords core.groupby.DataArrayGroupBy.first core.groupby.DataArrayGroupBy.last @@ -158,3 +150,6 @@ plot.FacetGrid.set_titles plot.FacetGrid.set_ticks plot.FacetGrid.map + + CFTimeIndex.shift + CFTimeIndex.to_datetimeindex diff --git a/doc/api.rst b/doc/api.rst index 10386fe3a9b..662ef567710 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -24,6 +24,7 @@ Top-level functions full_like zeros_like ones_like + dot Dataset ======= @@ -51,6 +52,8 @@ Attributes Dataset.encoding Dataset.indexes Dataset.get_index + Dataset.chunks + Dataset.nbytes Dictionary interface -------------------- @@ -107,12 +110,32 @@ Indexing Dataset.isel Dataset.sel Dataset.squeeze + Dataset.interp + Dataset.interp_like Dataset.reindex Dataset.reindex_like Dataset.set_index Dataset.reset_index Dataset.reorder_levels +Missing value handling +---------------------- + +.. autosummary:: + :toctree: generated/ + + Dataset.isnull + Dataset.notnull + Dataset.combine_first + Dataset.count + Dataset.dropna + Dataset.fillna + Dataset.ffill + Dataset.bfill + Dataset.interpolate_na + Dataset.where + Dataset.isin + Computation ----------- @@ -127,6 +150,7 @@ Computation Dataset.resample Dataset.diff Dataset.quantile + Dataset.differentiate **Aggregation**: :py:attr:`~Dataset.all` @@ -142,18 +166,8 @@ Computation :py:attr:`~Dataset.std` :py:attr:`~Dataset.var` -**Missing values**: -:py:attr:`~Dataset.isnull` -:py:attr:`~Dataset.notnull` -:py:attr:`~Dataset.count` -:py:attr:`~Dataset.dropna` -:py:attr:`~Dataset.fillna` -:py:attr:`~Dataset.ffill` -:py:attr:`~Dataset.bfill` -:py:attr:`~Dataset.interpolate_na` -:py:attr:`~Dataset.where` - **ndarray methods**: +:py:attr:`~Dataset.astype` :py:attr:`~Dataset.argsort` :py:attr:`~Dataset.clip` :py:attr:`~Dataset.conj` @@ -216,6 +230,8 @@ Attributes :py:attr:`~DataArray.shape` :py:attr:`~DataArray.size` :py:attr:`~DataArray.dtype` +:py:attr:`~DataArray.nbytes` +:py:attr:`~DataArray.chunks` DataArray contents ------------------ @@ -250,12 +266,32 @@ Indexing DataArray.isel DataArray.sel DataArray.squeeze + DataArray.interp + DataArray.interp_like DataArray.reindex DataArray.reindex_like DataArray.set_index DataArray.reset_index DataArray.reorder_levels +Missing value handling +---------------------- + +.. 
autosummary:: + :toctree: generated/ + + DataArray.isnull + DataArray.notnull + DataArray.combine_first + DataArray.count + DataArray.dropna + DataArray.fillna + DataArray.ffill + DataArray.bfill + DataArray.interpolate_na + DataArray.where + DataArray.isin + Comparisons ----------- @@ -276,11 +312,13 @@ Computation DataArray.groupby DataArray.groupby_bins DataArray.rolling + DataArray.dt DataArray.resample DataArray.get_axis_num DataArray.diff DataArray.dot DataArray.quantile + DataArray.differentiate **Aggregation**: :py:attr:`~DataArray.all` @@ -296,17 +334,6 @@ Computation :py:attr:`~DataArray.std` :py:attr:`~DataArray.var` -**Missing values**: -:py:attr:`~DataArray.isnull` -:py:attr:`~DataArray.notnull` -:py:attr:`~DataArray.count` -:py:attr:`~DataArray.dropna` -:py:attr:`~DataArray.fillna` -:py:attr:`~DataArray.ffill` -:py:attr:`~DataArray.bfill` -:py:attr:`~DataArray.interpolate_na` -:py:attr:`~DataArray.where` - **ndarray methods**: :py:attr:`~DataArray.argsort` :py:attr:`~DataArray.clip` @@ -347,6 +374,13 @@ Reshaping and reorganizing Universal functions =================== +.. warning:: + + With recent versions of numpy, dask and xarray, NumPy ufuncs are now + supported directly on all xarray and dask objects. This obliviates the need + for the ``xarray.ufuncs`` module, which should not be used for new code + unless compatibility with versions of NumPy prior to v1.13 is required. + This functions are copied from NumPy, but extended to work on NumPy arrays, dask arrays and all xarray objects. You can find them in the ``xarray.ufuncs`` module: @@ -462,17 +496,81 @@ DataArray methods DataArray.from_series DataArray.from_cdms2 DataArray.from_dict + DataArray.close DataArray.compute DataArray.persist DataArray.load DataArray.chunk +GroupBy objects +=============== + +.. autosummary:: + :toctree: generated/ + + core.groupby.DataArrayGroupBy + core.groupby.DataArrayGroupBy.apply + core.groupby.DataArrayGroupBy.reduce + core.groupby.DatasetGroupBy + core.groupby.DatasetGroupBy.apply + core.groupby.DatasetGroupBy.reduce + +Rolling objects +=============== + +.. autosummary:: + :toctree: generated/ + + core.rolling.DataArrayRolling + core.rolling.DataArrayRolling.construct + core.rolling.DataArrayRolling.reduce + core.rolling.DatasetRolling + core.rolling.DatasetRolling.construct + core.rolling.DatasetRolling.reduce + +Resample objects +================ + +Resample objects also implement the GroupBy interface +(methods like ``apply()``, ``reduce()``, ``mean()``, ``sum()``, etc.). + +.. autosummary:: + :toctree: generated/ + + core.resample.DataArrayResample + core.resample.DataArrayResample.asfreq + core.resample.DataArrayResample.backfill + core.resample.DataArrayResample.interpolate + core.resample.DataArrayResample.nearest + core.resample.DataArrayResample.pad + core.resample.DatasetResample + core.resample.DatasetResample.asfreq + core.resample.DatasetResample.backfill + core.resample.DatasetResample.interpolate + core.resample.DatasetResample.nearest + core.resample.DatasetResample.pad + +Custom Indexes +============== +.. autosummary:: + :toctree: generated/ + + CFTimeIndex + +Creating custom indexes +----------------------- +.. autosummary:: + :toctree: generated/ + + cftime_range + Plotting ======== .. autosummary:: :toctree: generated/ + DataArray.plot plot.plot plot.contourf plot.contour @@ -507,6 +605,8 @@ Advanced API .. 
autosummary:: :toctree: generated/ + Dataset.variables + DataArray.variable Variable IndexVariable as_variable @@ -524,3 +624,6 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: backends.H5NetCDFStore backends.PydapDataStore backends.ScipyDataStore + backends.FileManager + backends.CachingFileManager + backends.DummyFileManager diff --git a/doc/computation.rst index 420b97923d7..759c87a6cc7 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -71,8 +71,8 @@ methods for working with missing data from pandas: x.count() x.dropna(dim='x') x.fillna(-1) - x.ffill() - x.bfill() + x.ffill('x') + x.bfill('x') Like pandas, xarray uses the float value ``np.nan`` (not-a-number) to represent missing values. @@ -158,13 +158,11 @@ Aggregation and summary methods can be applied directly to the ``Rolling`` objec r.mean() r.reduce(np.std) -Note that rolling window aggregations are much faster (both asymptotically and -because they avoid a loop in Python) when bottleneck_ is installed. Otherwise, -we fall back to a slower, pure Python implementation. +Note that rolling window aggregations are faster when bottleneck_ is installed. .. _bottleneck: https://github.com/kwgoodman/bottleneck/ -Finally, we can manually iterate through ``Rolling`` objects: +We can also manually iterate through ``Rolling`` objects: .. ipython:: python @@ -172,6 +170,61 @@ Finally, we can manually iterate through ``Rolling`` objects: for label, arr_window in r: # arr_window is a view of x +Finally, the rolling object has a ``construct`` method which returns a +view of the original ``DataArray`` with the windowed dimension in +the last position. +You can use this for more advanced rolling operations such as strided rolling, +windowed rolling, convolution, short-time FFT, etc. + +.. ipython:: python + + # rolling with 2-point stride + rolling_da = r.construct('window_dim', stride=2) + rolling_da + rolling_da.mean('window_dim', skipna=False) + +Because the ``DataArray`` given by ``r.construct('window_dim')`` is a view +of the original array, it is memory efficient. +You can also use ``construct`` to compute a weighted rolling sum: + +.. ipython:: python + + weight = xr.DataArray([0.25, 0.5, 0.25], dims=['window']) + arr.rolling(y=3).construct('window').dot(weight) + +.. note:: + NumPy's NaN-aggregation functions such as ``nansum`` copy the original array. + In xarray, we internally use these functions in our aggregation methods + (such as ``.sum()``) if the ``skipna`` argument is not specified or is set to True. + This means ``rolling_da.mean('window_dim')`` is memory inefficient. + To avoid this, use ``skipna=False`` as in the example above. + + +Computation using Coordinates +============================= + +Xarray objects have some handy methods for computation using their +coordinates. :py:meth:`~xarray.DataArray.differentiate` computes derivatives by +central finite differences using their coordinates, + +.. ipython:: python + + a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[[0.1, 0.11, 0.2, 0.3]]) + a + a.differentiate('x') + +This method can also be used for multidimensional arrays, + +.. ipython:: python + + a = xr.DataArray(np.arange(8).reshape(4, 2), dims=['x', 'y'], + coords={'x': [0.1, 0.11, 0.2, 0.3]}) + a.differentiate('x') + +.. note:: + This method is limited to simple cartesian geometry. Differentiation along + a multidimensional coordinate is not supported. + ..
_compute.broadcasting: Broadcasting by dimension name @@ -319,21 +372,15 @@ Datasets support most of the same methods found on data arrays: ds.mean(dim='x') abs(ds) -Unfortunately, we currently do not support NumPy ufuncs for datasets [1]_. -:py:meth:`~xarray.Dataset.apply` works around this -limitation, by applying the given function to each variable in the dataset: +Datasets also support NumPy ufuncs (requires NumPy v1.13 or newer), or +alternatively you can use :py:meth:`~xarray.Dataset.apply` to apply a function +to each variable in a dataset: .. ipython:: python + np.sin(ds) ds.apply(np.sin) -You can also use the wrapped functions in the ``xarray.ufuncs`` module: - -.. ipython:: python - - import xarray.ufuncs as xu - xu.sin(ds) - Datasets also use looping over variables for *broadcasting* in binary arithmetic. You can do arithmetic between any ``DataArray`` and a dataset: @@ -351,10 +398,6 @@ Arithmetic between two datasets matches data variables of the same name: Similarly to index based alignment, the result has the intersection of all matching data variables. -.. [1] This was previously due to a limitation of NumPy, but with NumPy 1.13 - we should be able to support this by leveraging ``__array_ufunc__`` - (:issue:`1617`). - .. _comput.wrapping-custom: Wrapping custom computation diff --git a/doc/conf.py b/doc/conf.py index eb71c926375..897c0443054 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -11,19 +11,22 @@ # # All configuration values have a default; values that are commented out # serve to show the default. -from __future__ import print_function -from __future__ import division -from __future__ import absolute_import +from __future__ import absolute_import, division, print_function -import sys -import os import datetime import importlib +import os +import sys + +import xarray + +allowed_failures = set() print("python exec:", sys.executable) print("sys.path:", sys.path) for name in ('numpy scipy pandas matplotlib dask IPython seaborn ' - 'cartopy netCDF4 rasterio zarr').split(): + 'cartopy netCDF4 rasterio zarr iris flake8 ' + 'sphinx_gallery cftime').split(): try: module = importlib.import_module(name) if name == 'matplotlib': @@ -32,8 +35,16 @@ print("%s: %s, %s" % (name, module.__version__, fname)) except ImportError: print("no %s" % name) + # neither rasterio nor cartopy should be hard requirements for + # the doc build. + if name == 'rasterio': + allowed_failures.update(['gallery/plot_rasterio_rgb.py', + 'gallery/plot_rasterio.py']) + elif name == 'cartopy': + allowed_failures.update(['gallery/plot_cartopy_facetgrid.py', + 'gallery/plot_rasterio_rgb.py', + 'gallery/plot_rasterio.py']) -import xarray print("xarray: %s, %s" % (xarray.__version__, xarray.__file__)) # -- General configuration ------------------------------------------------ @@ -62,7 +73,8 @@ sphinx_gallery_conf = {'examples_dirs': 'gallery', 'gallery_dirs': 'auto_gallery', - 'backreferences_dir': False + 'backreferences_dir': False, + 'expected_failing_examples': list(allowed_failures) } autosummary_generate = True @@ -91,7 +103,7 @@ # built documents. # # The short X.Y version. -version = xarray.version.short_version +version = xarray.__version__.split('+')[0] # The full version, including alpha/beta/rc tags. release = xarray.__version__ @@ -138,22 +150,14 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. 
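The ``version = xarray.__version__.split('+')[0]`` change in ``conf.py`` above simply strips the local-version suffix from a development version string, so Sphinx reports only the public release number while ``release`` keeps the full string. A minimal sketch of that idea, reusing the example version string quoted later in the contributing guide (the variable names here are illustrative):

.. code-block:: python

    # Reduce a development version string to its public part for Sphinx's
    # ``version`` setting; ``release`` keeps the full string.
    release = '0.10.0+dev46.g015daca'   # e.g. what xarray.__version__ may return
    version = release.split('+')[0]     # -> '0.10.0'
    print(version, release)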
- -# on_rtd is whether we are on readthedocs.org, this line of code grabbed from -# docs.readthedocs.org -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' - -if not on_rtd: # only import and set the theme if we're building docs locally - import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' - html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] - -# otherwise, readthedocs.org uses their theme by default, so no need to specify it +html_theme = 'sphinx_rtd_theme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +html_theme_options = { + 'logo_only': True, +} # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] @@ -167,12 +171,12 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = "_static/dataset-diagram-logo.png" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +html_favicon = '_static/favicon.ico' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/doc/contributing.rst b/doc/contributing.rst new file mode 100644 index 00000000000..ceba81d9319 --- /dev/null +++ b/doc/contributing.rst @@ -0,0 +1,820 @@ +.. _contributing: + +********************** +Contributing to xarray +********************** + +.. contents:: Table of contents: + :local: + +.. note:: + + Large parts of this document came from the `Pandas Contributing + Guide `_. + +Where to start? +=============== + +All contributions, bug reports, bug fixes, documentation improvements, +enhancements, and ideas are welcome. + +If you are brand new to *xarray* or open-source development, we recommend going +through the `GitHub "issues" tab `_ +to find issues that interest you. There are a number of issues listed under +`Documentation `_ +and `good first issue +`_ +where you could start out. Once you've found an interesting issue, you can +return here to get your development environment setup. + +Feel free to ask questions on the `mailing list +`_. + +.. _contributing.bug_reports: + +Bug reports and enhancement requests +==================================== + +Bug reports are an important part of making *xarray* more stable. Having a complete bug +report will allow others to reproduce the bug and provide insight into fixing. See +`this stackoverflow article `_ for tips on +writing a good bug report. + +Trying the bug-producing code out on the *master* branch is often a worthwhile exercise +to confirm the bug still exists. It is also worth searching existing bug reports and +pull requests to see if the issue has already been reported and/or fixed. + +Bug reports must: + +#. Include a short, self-contained Python snippet reproducing the problem. + You can format the code nicely by using `GitHub Flavored Markdown + `_:: + + ```python + >>> from xarray import Dataset + >>> df = Dataset(...) + ... + ``` + +#. Include the full version string of *xarray* and its dependencies. You can use the + built in function:: + + >>> import xarray as xr + >>> xr.show_versions() + +#. Explain why the current behavior is wrong/not desired and what you expect instead. 
+ +The issue will then show up to the *xarray* community and be open to comments/ideas +from others. + +.. _contributing.github: + +Working with the code +===================== + +Now that you have an issue you want to fix, enhancement to add, or documentation +to improve, you need to learn how to work with GitHub and the *xarray* code base. + +.. _contributing.version_control: + +Version control, Git, and GitHub +-------------------------------- + +To the new user, working with Git is one of the more daunting aspects of contributing +to *xarray*. It can very quickly become overwhelming, but sticking to the guidelines +below will help keep the process straightforward and mostly trouble free. As always, +if you are having difficulties please feel free to ask for help. + +The code is hosted on `GitHub `_. To +contribute you will need to sign up for a `free GitHub account +`_. We use `Git `_ for +version control to allow many people to work together on the project. + +Some great resources for learning Git: + +* the `GitHub help pages `_. +* the `NumPy's documentation `_. +* Matthew Brett's `Pydagogue `_. + +Getting started with Git +------------------------ + +`GitHub has instructions `__ for installing git, +setting up your SSH key, and configuring git. All these steps need to be completed before +you can work seamlessly between your local repository and GitHub. + +.. _contributing.forking: + +Forking +------- + +You will need your own fork to work on the code. Go to the `xarray project +page `_ and hit the ``Fork`` button. You will +want to clone your fork to your machine:: + + git clone https://github.com/your-user-name/xarray.git + cd xarray + git remote add upstream https://github.com/pydata/xarray.git + +This creates the directory `xarray` and connects your repository to +the upstream (main project) *xarray* repository. + +.. _contributing.dev_env: + +Creating a development environment +---------------------------------- + +To test out code changes, you'll need to build *xarray* from source, which +requires a Python environment. If you're making documentation changes, you can +skip to :ref:`contributing.documentation` but you won't be able to build the +documentation locally before pushing your changes. + +.. _contributiong.dev_python: + +Creating a Python Environment +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Before starting any development, you'll need to create an isolated xarray +development environment: + +- Install either `Anaconda `_ or `miniconda + `_ +- Make sure your conda is up to date (``conda update conda``) +- Make sure that you have :ref:`cloned the repository ` +- ``cd`` to the *xarray* source directory + +We'll now kick off a two-step process: + +1. Install the build dependencies +2. Build and install xarray + +.. code-block:: none + + # Create and activate the build environment + conda env create -f ci/requirements-py36.yml + conda activate test_env + + # or with older versions of Anaconda: + source activate test_env + + # Build and install xarray + pip install -e . + +At this point you should be able to import *xarray* from your locally built version:: + + $ python # start an interpreter + >>> import xarray + >>> xarray.__version__ + '0.10.0+dev46.g015daca' + +This will create the new environment, and not touch any of your existing environments, +nor any existing Python installation. + +To view your environments:: + + conda info -e + +To return to your root environment:: + + conda deactivate + +See the full conda docs `here `__. 
+ +Creating a branch +----------------- + +You want your master branch to reflect only production-ready code, so create a +feature branch for making your changes. For example:: + + git branch shiny-new-feature + git checkout shiny-new-feature + +The above can be simplified to:: + + git checkout -b shiny-new-feature + +This changes your working directory to the shiny-new-feature branch. Keep any +changes in this branch specific to one bug or feature so it is clear +what the branch brings to *xarray*. You can have many "shiny-new-features" +and switch in between them using the ``git checkout`` command. + +To update this branch, you need to retrieve the changes from the master branch:: + + git fetch upstream + git rebase upstream/master + +This will replay your commits on top of the latest *xarray* git master. If this +leads to merge conflicts, you must resolve these before submitting your pull +request. If you have uncommitted changes, you will need to ``git stash`` them +prior to updating. This will effectively store your changes and they can be +reapplied after updating. + +.. _contributing.documentation: + +Contributing to the documentation +================================= + +If you're not the developer type, contributing to the documentation is still of +huge value. You don't even have to be an expert on *xarray* to do so! In fact, +there are sections of the docs that are worse off after being written by +experts. If something in the docs doesn't make sense to you, updating the +relevant section after you figure it out is a great way to ensure it will help +the next person. + +.. contents:: Documentation: + :local: + + +About the *xarray* documentation +-------------------------------- + +The documentation is written in **reStructuredText**, which is almost like writing +in plain English, and built using `Sphinx `__. The +Sphinx Documentation has an excellent `introduction to reST +`__. Review the Sphinx docs to perform more +complex changes to the documentation as well. + +Some other important things to know about the docs: + +- The *xarray* documentation consists of two parts: the docstrings in the code + itself and the docs in this folder ``xarray/doc/``. + + The docstrings are meant to provide a clear explanation of the usage of the + individual functions, while the documentation in this folder consists of + tutorial-like overviews per topic together with some other information + (what's new, installation, etc). + +- The docstrings follow the **Numpy Docstring Standard**, which is used widely + in the Scientific Python community. This standard specifies the format of + the different sections of the docstring. See `this document + `_ + for a detailed explanation, or look at some of the existing functions to + extend it in a similar manner. + +- The tutorials make heavy use of the `ipython directive + `_ sphinx extension. + This directive lets you put code in the documentation which will be run + during the doc build. For example:: + + .. ipython:: python + + x = 2 + x**3 + + will be rendered as:: + + In [1]: x = 2 + + In [2]: x**3 + Out[2]: 8 + + Almost all code examples in the docs are run (and the output saved) during the + doc build. This approach means that code examples will always be up to date, + but it does make the doc building a bit more complex. + +- Our API documentation in ``doc/api.rst`` houses the auto-generated + documentation from the docstrings. For classes, there are a few subtleties + around controlling which methods and attributes have pages auto-generated. 
+ + Every method should be included in a ``toctree`` in ``api.rst``, else Sphinx + will emit a warning. + + +How to build the *xarray* documentation +--------------------------------------- + +Requirements +~~~~~~~~~~~~ + +First, you need to have a development environment to be able to build xarray +(see the docs on :ref:`creating a development environment above `). + +Building the documentation +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In your development environment, install ``sphinx``, ``sphinx_rtd_theme``, +``sphinx-gallery`` and ``numpydoc``:: + + conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc + +Navigate to your local ``xarray/doc/`` directory in the console and run:: + + make html + +Then you can find the HTML output in the folder ``xarray/doc/_build/html/``. + +The first time you build the docs, it will take quite a while because it has to run +all the code examples and build all the generated docstring pages. In subsequent +evocations, sphinx will try to only build the pages that have been modified. + +If you want to do a full clean build, do:: + + make clean + make html + +.. _contributing.code: + +Contributing to the code base +============================= + +.. contents:: Code Base: + :local: + +Code standards +-------------- + +Writing good code is not just about what you write. It is also about *how* you +write it. During :ref:`Continuous Integration ` testing, several +tools will be run to check your code for stylistic errors. +Generating any warnings will cause the test to fail. +Thus, good style is a requirement for submitting code to *xarray*. + +In addition, because a lot of people use our library, it is important that we +do not make sudden changes to the code that could have the potential to break +a lot of user code as a result, that is, we need it to be as *backwards compatible* +as possible to avoid mass breakages. + +Python (PEP8) +~~~~~~~~~~~~~ + +*xarray* uses the `PEP8 `_ standard. +There are several tools to ensure you abide by this standard. Here are *some* of +the more common ``PEP8`` issues: + + - we restrict line-length to 79 characters to promote readability + - passing arguments should have spaces after commas, e.g. ``foo(arg1, arg2, kw1='bar')`` + +:ref:`Continuous Integration ` will run +the `flake8 `_ tool +and report any stylistic errors in your code. Therefore, it is helpful before +submitting code to run the check yourself:: + + flake8 + +If you install `isort `_ and +`flake8-isort `_, this will also show +any errors from incorrectly sorted imports. These aren't currently enforced in +CI. To automatically sort imports, you can run:: + + isort -y + + +Backwards Compatibility +~~~~~~~~~~~~~~~~~~~~~~~ + +Please try to maintain backward compatibility. *xarray* has growing number of users with +lots of existing code, so don't break it if at all possible. If you think breakage is +required, clearly state why as part of the pull request. Also, be careful when changing +method signatures and add deprecation warnings where needed. Also, add the deprecated +sphinx directive to the deprecated functions or methods. + +.. _contributing.ci: + +Testing With Continuous Integration +----------------------------------- + +The *xarray* test suite will run automatically on `Travis-CI `__, +and `Appveyor `__, continuous integration services, once +your pull request is submitted. 
However, if you wish to run the test suite on a +branch prior to submitting the pull request, then the continuous integration +services need to be hooked to your GitHub repository. Instructions are here +for `Travis-CI `__, and +`Appveyor `__. + +A pull-request will be considered for merging when you have an all 'green' build. If any +tests are failing, then you will get a red 'X', where you can click through to see the +individual failed tests. This is an example of a green build. + +.. image:: _static/ci.png + +.. note:: + + Each time you push to your PR branch, a new run of the tests will be triggered on the CI. + Appveyor will auto-cancel any non-currently-running tests for that same pull-request. + You can also enable the auto-cancel feature for `Travis-CI here + `__. + +.. _contributing.tdd: + + +Test-driven development/code writing +------------------------------------ + +*xarray* is serious about testing and strongly encourages contributors to embrace +`test-driven development (TDD) `_. +This development process "relies on the repetition of a very short development cycle: +first the developer writes an (initially failing) automated test case that defines a desired +improvement or new function, then produces the minimum amount of code to pass that test." +So, before actually writing any code, you should write your tests. Often the test can be +taken from the original GitHub issue. However, it is always worth considering additional +use cases and writing corresponding tests. + +Adding tests is one of the most common requests after code is pushed to *xarray*. Therefore, +it is worth getting in the habit of writing tests ahead of time so this is never an issue. + +Like many packages, *xarray* uses `pytest +`_ and the convenient +extensions in `numpy.testing +`_. + +Writing tests +~~~~~~~~~~~~~ + +All tests should go into the ``tests`` subdirectory of the specific package. +This folder contains many current examples of tests, and we suggest looking to these for +inspiration. If your test requires working with files or +network connectivity, there is more information on the `testing page +`_ of the wiki. + +The ``xarray.testing`` module has many special ``assert`` functions that +make it easier to make statements about whether DataArray or Dataset objects are +equivalent. The easiest way to verify that your code is correct is to +explicitly construct the result you expect, then compare the actual result to +the expected correct result:: + + def test_constructor_from_0d(self): + expected = Dataset({None: ([], 0)})[None] + actual = DataArray(0) + assert_identical(expected, actual) + +Transitioning to ``pytest`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +*xarray* existing test structure is *mostly* classed based, meaning that you will +typically find tests wrapped in a class. + +.. code-block:: python + + class TestReallyCoolFeature(object): + .... + +Going forward, we are moving to a more *functional* style using the +`pytest `__ framework, which offers a richer +testing framework that will facilitate testing and developing. Thus, instead of +writing test classes, we will write test functions like this: + +.. code-block:: python + + def test_really_cool_feature(): + .... + +Using ``pytest`` +~~~~~~~~~~~~~~~~ + +Here is an example of a self-contained set of tests that illustrate multiple +features that we like to use. + +- functional style: tests are like ``test_*`` and *only* take arguments that are either + fixtures or parameters +- ``pytest.mark`` can be used to set metadata on test functions, e.g. 
``skip`` or ``xfail``. +- using ``parametrize``: allow testing of multiple cases +- to set a mark on a parameter, ``pytest.param(..., marks=...)`` syntax should be used +- ``fixture``, code for object construction, on a per-test basis +- using bare ``assert`` for scalars and truth-testing +- ``tm.assert_series_equal`` (and its counter part ``tm.assert_frame_equal``), for xarray + object comparisons. +- the typical pattern of constructing an ``expected`` and comparing versus the ``result`` + +We would name this file ``test_cool_feature.py`` and put in an appropriate place in the +``xarray/tests/`` structure. + +.. TODO: confirm that this actually works + +.. code-block:: python + + import pytest + import numpy as np + import xarray as xr + from xarray.testing import assert_equal + + + @pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) + def test_dtypes(dtype): + assert str(np.dtype(dtype)) == dtype + + + @pytest.mark.parametrize('dtype', ['float32', + pytest.param('int16', marks=pytest.mark.skip), + pytest.param('int32', marks=pytest.mark.xfail( + reason='to show how it works'))]) + def test_mark(dtype): + assert str(np.dtype(dtype)) == 'float32' + + + @pytest.fixture + def dataarray(): + return xr.DataArray([1, 2, 3]) + + + @pytest.fixture(params=['int8', 'int16', 'int32', 'int64']) + def dtype(request): + return request.param + + + def test_series(dataarray, dtype): + result = dataarray.astype(dtype) + assert result.dtype == dtype + + expected = xr.DataArray(np.array([1, 2, 3], dtype=dtype)) + assert_equal(result, expected) + + + +A test run of this yields + +.. code-block:: shell + + ((xarray) $ pytest test_cool_feature.py -v + =============================== test session starts ================================ + platform darwin -- Python 3.6.4, pytest-3.2.1, py-1.4.34, pluggy-0.4.0 -- + cachedir: ../../.cache + plugins: cov-2.5.1, hypothesis-3.23.0 + collected 11 items + + test_cool_feature.py::test_dtypes[int8] PASSED + test_cool_feature.py::test_dtypes[int16] PASSED + test_cool_feature.py::test_dtypes[int32] PASSED + test_cool_feature.py::test_dtypes[int64] PASSED + test_cool_feature.py::test_mark[float32] PASSED + test_cool_feature.py::test_mark[int16] SKIPPED + test_cool_feature.py::test_mark[int32] xfail + test_cool_feature.py::test_series[int8] PASSED + test_cool_feature.py::test_series[int16] PASSED + test_cool_feature.py::test_series[int32] PASSED + test_cool_feature.py::test_series[int64] PASSED + + ================== 9 passed, 1 skipped, 1 xfailed in 1.83 seconds ================== + +Tests that we have ``parametrized`` are now accessible via the test name, for +example we could run these with ``-k int8`` to sub-select *only* those tests +which match ``int8``. + + +.. code-block:: shell + + ((xarray) bash-3.2$ pytest test_cool_feature.py -v -k int8 + =========================== test session starts =========================== + platform darwin -- Python 3.6.2, pytest-3.2.1, py-1.4.31, pluggy-0.4.0 + collected 11 items + + test_cool_feature.py::test_dtypes[int8] PASSED + test_cool_feature.py::test_series[int8] PASSED + + +Running the test suite +---------------------- + +The tests can then be run directly inside your Git clone (without having to +install *xarray*) by typing:: + + pytest xarray + +The tests suite is exhaustive and takes a few minutes. Often it is +worth running only a subset of tests first around your changes before running the +entire suite. 
+ +The easiest way to do this is with:: + + pytest xarray/path/to/test.py -k regex_matching_test_name + +Or with one of the following constructs:: + + pytest xarray/tests/[test-module].py + pytest xarray/tests/[test-module].py::[TestClass] + pytest xarray/tests/[test-module].py::[TestClass]::[test_method] + +Using `pytest-xdist `_, one can +speed up local testing on multicore machines. To use this feature, you will +need to install `pytest-xdist` via:: + + pip install pytest-xdist + + +Then, run pytest with the optional -n argument: + + pytest xarray -n 4 + +This can significantly reduce the time it takes to locally run tests before +submitting a pull request. + +For more, see the `pytest `_ documentation. + +Running the performance test suite +---------------------------------- + +Performance matters and it is worth considering whether your code has introduced +performance regressions. *xarray* is starting to write a suite of benchmarking tests +using `asv `__ +to enable easy monitoring of the performance of critical *xarray* operations. +These benchmarks are all found in the ``xarray/asv_bench`` directory. asv +supports both python2 and python3. + +To use all features of asv, you will need either ``conda`` or +``virtualenv``. For more details please check the `asv installation +webpage `_. + +To install asv:: + + pip install git+https://github.com/spacetelescope/asv + +If you need to run a benchmark, change your directory to ``asv_bench/`` and run:: + + asv continuous -f 1.1 upstream/master HEAD + +You can replace ``HEAD`` with the name of the branch you are working on, +and report benchmarks that changed by more than 10%. +The command uses ``conda`` by default for creating the benchmark +environments. If you want to use virtualenv instead, write:: + + asv continuous -f 1.1 -E virtualenv upstream/master HEAD + +The ``-E virtualenv`` option should be added to all ``asv`` commands +that run benchmarks. The default value is defined in ``asv.conf.json``. + +Running the full benchmark suite can take up to one hour and use up a few GBs of RAM. +Usually it is sufficient to paste only a subset of the results into the pull +request to show that the committed changes do not cause unexpected performance +regressions. You can run specific benchmarks using the ``-b`` flag, which +takes a regular expression. For example, this will only run tests from a +``xarray/asv_bench/benchmarks/groupby.py`` file:: + + asv continuous -f 1.1 upstream/master HEAD -b ^groupby + +If you want to only run a specific group of tests from a file, you can do it +using ``.`` as a separator. For example:: + + asv continuous -f 1.1 upstream/master HEAD -b groupby.GroupByMethods + +will only run the ``GroupByMethods`` benchmark defined in ``groupby.py``. + +You can also run the benchmark suite using the version of *xarray* +already installed in your current Python environment. This can be +useful if you do not have ``virtualenv`` or ``conda``, or are using the +``setup.py develop`` approach discussed above; for the in-place build +you need to set ``PYTHONPATH``, e.g. +``PYTHONPATH="$PWD/.." asv [remaining arguments]``. +You can run benchmarks using an existing Python +environment by:: + + asv run -e -E existing + +or, to use a specific Python interpreter,:: + + asv run -e -E existing:python3.5 + +This will display stderr from the benchmarks, and use your local +``python`` that comes from your ``$PATH``. + +Information on how to write a benchmark and how to use asv can be found in the +`asv documentation `_. 
+ +The *xarray* benchmarking suite is run remotely and the results are +available `here `_. + +Documenting your code +--------------------- + +Changes should be reflected in the release notes located in ``doc/whats-new.rst``. +This file contains an ongoing change log for each release. Add an entry to this file to +document your fix, enhancement or (unavoidable) breaking change. Make sure to include the +GitHub issue number when adding your entry (using ``:issue:`1234```, where ``1234`` is the +issue/pull request number). + +If your code is an enhancement, it is most likely necessary to add usage +examples to the existing documentation. This can be done following the section +regarding documentation :ref:`above `. + +Contributing your changes to *xarray* +===================================== + +Committing your code +-------------------- + +Keep style fixes to a separate commit to make your pull request more readable. + +Once you've made changes, you can see them by typing:: + + git status + +If you have created a new file, it is not being tracked by git. Add it by typing:: + + git add path/to/file-to-be-added.py + +Doing 'git status' again should give something like:: + + # On branch shiny-new-feature + # + # modified: /relative/path/to/file-you-added.py + # + +Finally, commit your changes to your local repository with an explanatory message. +*Xarray* uses a convention for commit message prefixes and layout. Here are +some common prefixes along with general guidelines for when to use them: + + * ``ENH``: Enhancement, new functionality + * ``BUG``: Bug fix + * ``DOC``: Additions/updates to documentation + * ``TST``: Additions/updates to tests + * ``BLD``: Updates to the build process/scripts + * ``PERF``: Performance improvement + * ``CLN``: Code cleanup + +The following defines how a commit message should be structured: + + * A subject line with `< 72` chars. + * One blank line. + * Optionally, a commit message body. + +Please reference the relevant GitHub issues in your commit message using ``GH1234`` or +``#1234``. Either style is fine, but the former is generally preferred. + +Now you can commit your changes in your local repository:: + + git commit -m + +Pushing your changes +-------------------- + +When you want your changes to appear publicly on your GitHub page, push your +forked feature branch's commits:: + + git push origin shiny-new-feature + +Here ``origin`` is the default name given to your remote repository on GitHub. +You can see the remote repositories:: + + git remote -v + +If you added the upstream repository as described above you will see something +like:: + + origin git@github.com:yourname/xarray.git (fetch) + origin git@github.com:yourname/xarray.git (push) + upstream git://github.com/pydata/xarray.git (fetch) + upstream git://github.com/pydata/xarray.git (push) + +Now your code is on GitHub, but it is not yet a part of the *xarray* project. For that to +happen, a pull request needs to be submitted on GitHub. + +Review your code +---------------- + +When you're ready to ask for a code review, file a pull request. Before you do, once +again make sure that you have followed all the guidelines outlined in this document +regarding code style, tests, performance tests, and documentation. You should also +double check your branch changes against the branch it was based on: + +#. Navigate to your repository on GitHub -- https://github.com/your-user-name/xarray +#. 
Click on ``Branches`` +#. Click on the ``Compare`` button for your feature branch +#. Select the ``base`` and ``compare`` branches, if necessary. This will be ``master`` and + ``shiny-new-feature``, respectively. + +Finally, make the pull request +------------------------------ + +If everything looks good, you are ready to make a pull request. A pull request is how +code from a local repository becomes available to the GitHub community and can be looked +at and eventually merged into the master version. This pull request and its associated +changes will eventually be committed to the master branch and available in the next +release. To submit a pull request: + +#. Navigate to your repository on GitHub +#. Click on the ``Pull Request`` button +#. You can then click on ``Commits`` and ``Files Changed`` to make sure everything looks + okay one last time +#. Write a description of your changes in the ``Preview Discussion`` tab +#. Click ``Send Pull Request``. + +This request then goes to the repository maintainers, and they will review +the code. If you need to make more changes, you can make them in +your branch, add them to a new commit, push them to GitHub, and the pull request +will be automatically updated. Pushing them to GitHub again is done by:: + + git push origin shiny-new-feature + +This will automatically update your pull request with the latest code and restart the +:ref:`Continuous Integration ` tests. + + +Delete your merged branch (optional) +------------------------------------ + +Once your feature branch is accepted into upstream, you'll probably want to get rid of +the branch. First, merge upstream master into your branch so git knows it is safe to +delete your branch:: + + git fetch upstream + git checkout master + git merge upstream/master + +Then you can do:: + + git branch -d shiny-new-feature + +Make sure you use a lower-case ``-d``, or else git won't warn you if your feature +branch has not actually been merged. + +The branch will still exist on GitHub, so to delete it there do:: + + git push origin --delete shiny-new-feature diff --git a/doc/dask.rst b/doc/dask.rst index 65ebd643e1e..672450065cb 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -13,7 +13,7 @@ dependency in a future version of xarray. For a full example of how to use xarray's dask integration, read the `blog post introducing xarray and dask`_. -.. _blog post introducing xarray and dask: https://www.anaconda.com/blog/developer-blog/xray-dask-out-core-labeled-arrays-python/ +.. _blog post introducing xarray and dask: http://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/ What is a dask array? --------------------- @@ -49,7 +49,7 @@ argument to :py:func:`~xarray.open_dataset` or using the :py:func:`~xarray.open_mfdataset` function. .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd @@ -100,6 +100,29 @@ Once you've manipulated a dask array, you can still write a dataset too big to fit into memory back to disk by using :py:meth:`~xarray.Dataset.to_netcdf` in the usual way. +.. ipython:: python + + ds.to_netcdf('manipulated-example-data.nc') + +By setting the ``compute`` argument to ``False``, :py:meth:`~xarray.Dataset.to_netcdf` +will return a dask delayed object that can be computed later. + +.. ipython:: python + + from dask.diagnostics import ProgressBar + # or distributed.progress when using the distributed scheduler + delayed_obj = ds.to_netcdf('manipulated-example-data.nc', compute=False) + with ProgressBar(): + results = delayed_obj.compute() + +.. 
note:: + + When using dask's distributed scheduler to write NETCDF4 files, + it may be necessary to set the environment variable `HDF5_USE_FILE_LOCKING=FALSE` + to avoid competing locks within the HDF5 SWMR file locking scheme. Note that + writing netCDF files with dask's distributed scheduler is only supported for + the `netcdf4` backend. + A dataset can also be converted to a dask DataFrame using :py:meth:`~xarray.Dataset.to_dask_dataframe`. .. ipython:: python diff --git a/doc/data-structures.rst b/doc/data-structures.rst index 10d83ca448f..618ccccff3e 100644 --- a/doc/data-structures.rst +++ b/doc/data-structures.rst @@ -408,13 +408,6 @@ operations keep around coordinates: list(ds[['x']]) list(ds.drop('temperature')) -If a dimension name is given as an argument to ``drop``, it also drops all -variables that use that dimension: - -.. ipython:: python - - list(ds.drop('time')) - As an alternate to dictionary-like modifications, you can use :py:meth:`~xarray.Dataset.assign` and :py:meth:`~xarray.Dataset.assign_coords`. These methods return a new dataset with additional (or replaced) or values: diff --git a/doc/environment.yml b/doc/environment.yml index b14fba351c1..bd134a7656f 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -1,20 +1,23 @@ name: xarray-docs channels: - conda-forge - - defaults dependencies: - - python=3.5 - - numpy=1.11.2 - - pandas=0.21.0 - - numpydoc=0.6.0 - - matplotlib=2.0.0 - - seaborn=0.8 - - dask=0.16.0 - - ipython=5.1.0 - - sphinx=1.5 - - netCDF4=1.3.1 - - cartopy=0.15.1 - - rasterio=0.36.0 - - sphinx-gallery - - zarr - - iris + - python=3.6 + - numpy=1.14.5 + - pandas=0.23.3 + - scipy=1.1.0 + - matplotlib=2.2.2 + - seaborn=0.9.0 + - dask=0.18.2 + - ipython=6.4.0 + - netCDF4=1.4.0 + - cartopy=0.16.0 + - rasterio=1.0.1 + - zarr=2.2.0 + - iris=2.1.0 + - flake8=3.5.0 + - cftime=1.0.0 + - bottleneck=1.2 + - sphinx=1.7.6 + - numpydoc=0.8.0 + - sphinx-gallery=0.2.0 diff --git a/doc/examples/_code/accessor_example.py b/doc/examples/_code/accessor_example.py index 1c846b38687..a11ebf9329b 100644 --- a/doc/examples/_code/accessor_example.py +++ b/doc/examples/_code/accessor_example.py @@ -1,5 +1,6 @@ import xarray as xr + @xr.register_dataset_accessor('geo') class GeoAccessor(object): def __init__(self, xarray_obj): diff --git a/doc/examples/_code/weather_data_setup.py b/doc/examples/_code/weather_data_setup.py index a6190ad3cfe..89470542d5a 100644 --- a/doc/examples/_code/weather_data_setup.py +++ b/doc/examples/_code/weather_data_setup.py @@ -1,7 +1,8 @@ -import xarray as xr import numpy as np import pandas as pd -import seaborn as sns # pandas aware plotting library +import seaborn as sns # pandas aware plotting library + +import xarray as xr np.random.seed(123) diff --git a/doc/examples/multidimensional-coords.rst b/doc/examples/multidimensional-coords.rst index a54e6058921..eed818ba064 100644 --- a/doc/examples/multidimensional-coords.rst +++ b/doc/examples/multidimensional-coords.rst @@ -3,7 +3,7 @@ Working with Multidimensional Coordinates ========================================= -Author: `Ryan Abernathey `__ +Author: `Ryan Abernathey `__ Many datasets have *physical coordinates* which differ from their *logical coordinates*. Xarray provides several ways to plot and analyze diff --git a/doc/faq.rst b/doc/faq.rst index 68670d0f5a4..44bc021024b 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -1,3 +1,5 @@ +.. _faq: + Frequently Asked Questions ========================== @@ -117,7 +119,8 @@ conventions`_. 
(An exception is serialization to and from netCDF files.) An implication of this choice is that we do not propagate ``attrs`` through most operations unless explicitly flagged (some methods have a ``keep_attrs`` -option). Similarly, xarray does not check for conflicts between ``attrs`` when +option, and there is a global flag for setting this to be always True or +False). Similarly, xarray does not check for conflicts between ``attrs`` when combining arrays and datasets, unless explicitly requested with the option ``compat='identical'``. The guiding principle is that metadata should not be allowed to get in the way. @@ -129,8 +132,8 @@ What other netCDF related Python libraries should I know about? `netCDF4-python`__ provides a lower level interface for working with netCDF and OpenDAP datasets in Python. We use netCDF4-python internally in xarray, and have contributed a number of improvements and fixes upstream. xarray -does not yet support all of netCDF4-python's features, such as writing to -netCDF groups or modifying files on-disk. +does not yet support all of netCDF4-python's features, such as modifying files +on-disk. __ https://github.com/Unidata/netcdf4-python @@ -153,9 +156,15 @@ __ http://drclimate.wordpress.com/2014/01/02/a-beginners-guide-to-scripting-with We think the design decisions we have made for xarray (namely, basing it on pandas) make it a faster and more flexible data analysis tool. That said, Iris -and CDAT have some great domain specific functionality, and we would love to -have support for converting their native objects to and from xarray (see -:issue:`37` and :issue:`133`) +and CDAT have some great domain specific functionality, and xarray includes +methods for converting back and forth between xarray and these libraries. See +:py:meth:`~xarray.DataArray.to_iris` and :py:meth:`~xarray.DataArray.to_cdms2` +for more details. + +What other projects leverage xarray? +------------------------------------ + +See section :ref:`related-projects`. How should I cite xarray? ------------------------- @@ -199,5 +208,5 @@ would certainly appreciate it. We recommend two citations. month = aug, year = 2016, doi = {10.5281/zenodo.59499}, - url = {http://dx.doi.org/10.5281/zenodo.59499} + url = {https://doi.org/10.5281/zenodo.59499} } diff --git a/doc/gallery/README.txt b/doc/gallery/README.txt index 242c4f7dc91..b17f803696b 100644 --- a/doc/gallery/README.txt +++ b/doc/gallery/README.txt @@ -1,5 +1,5 @@ .. _recipes: -Recipes +Gallery ======= diff --git a/doc/gallery/plot_cartopy_facetgrid.py b/doc/gallery/plot_cartopy_facetgrid.py index 525ae7054b0..3eded115263 100644 --- a/doc/gallery/plot_cartopy_facetgrid.py +++ b/doc/gallery/plot_cartopy_facetgrid.py @@ -12,12 +12,15 @@ For more details see `this discussion`_ on github. .. 
_this discussion: https://github.com/pydata/xarray/issues/1397#issuecomment-299190567 -""" +""" # noqa + +from __future__ import division -import xarray as xr import cartopy.crs as ccrs import matplotlib.pyplot as plt +import xarray as xr + # Load the data ds = xr.tutorial.load_dataset('air_temperature') air = ds.air.isel(time=[0, 724]) - 273.15 @@ -27,7 +30,7 @@ p = air.plot(transform=ccrs.PlateCarree(), # the data's projection col='time', col_wrap=1, # multiplot settings - aspect=ds.dims['lon']/ds.dims['lat'], # for a sensible figsize + aspect=ds.dims['lon'] / ds.dims['lat'], # for a sensible figsize subplot_kws={'projection': map_proj}) # the plot's projection # We have to set the map's options on all four axes diff --git a/doc/gallery/plot_colorbar_center.py b/doc/gallery/plot_colorbar_center.py index 00c25af50d4..4818b737632 100644 --- a/doc/gallery/plot_colorbar_center.py +++ b/doc/gallery/plot_colorbar_center.py @@ -8,33 +8,34 @@ """ -import xarray as xr import matplotlib.pyplot as plt +import xarray as xr + # Load the data ds = xr.tutorial.load_dataset('air_temperature') air = ds.air.isel(time=0) -f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6)) +f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6)) # The first plot (in kelvins) chooses "viridis" and uses the data's min/max -air.plot(ax=ax1, cbar_kwargs={'label':'K'}) +air.plot(ax=ax1, cbar_kwargs={'label': 'K'}) ax1.set_title('Kelvins: default') ax2.set_xlabel('') # The second plot (in celsius) now chooses "BuRd" and centers min/max around 0 airc = air - 273.15 -airc.plot(ax=ax2, cbar_kwargs={'label':'°C'}) +airc.plot(ax=ax2, cbar_kwargs={'label': '°C'}) ax2.set_title('Celsius: default') ax2.set_xlabel('') ax2.set_ylabel('') # The center doesn't have to be 0 -air.plot(ax=ax3, center=273.15, cbar_kwargs={'label':'K'}) +air.plot(ax=ax3, center=273.15, cbar_kwargs={'label': 'K'}) ax3.set_title('Kelvins: center=273.15') # Or it can be ignored -airc.plot(ax=ax4, center=False, cbar_kwargs={'label':'°C'}) +airc.plot(ax=ax4, center=False, cbar_kwargs={'label': '°C'}) ax4.set_title('Celsius: center=False') ax4.set_ylabel('') diff --git a/doc/gallery/plot_control_colorbar.py b/doc/gallery/plot_control_colorbar.py new file mode 100644 index 00000000000..5802a57cf31 --- /dev/null +++ b/doc/gallery/plot_control_colorbar.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +""" +=========================== +Control the plot's colorbar +=========================== + +Use ``cbar_kwargs`` keyword to specify the number of ticks. +The ``spacing`` kwarg can be used to draw proportional ticks. 
+""" +import matplotlib.pyplot as plt + +import xarray as xr + +# Load the data +air_temp = xr.tutorial.load_dataset('air_temperature') +air2d = air_temp.air.isel(time=500) + +# Prepare the figure +f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4)) + +# Irregular levels to illustrate the use of a proportional colorbar +levels = [245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 310, 340] + +# Plot data +air2d.plot(ax=ax1, levels=levels) +air2d.plot(ax=ax2, levels=levels, cbar_kwargs={'ticks': levels}) +air2d.plot(ax=ax3, levels=levels, cbar_kwargs={'ticks': levels, + 'spacing': 'proportional'}) + +# Show plots +plt.tight_layout() +plt.show() diff --git a/doc/gallery/plot_lines_from_2d.py b/doc/gallery/plot_lines_from_2d.py index 1e5875ea70e..93d7770238e 100644 --- a/doc/gallery/plot_lines_from_2d.py +++ b/doc/gallery/plot_lines_from_2d.py @@ -12,9 +12,10 @@ """ -import xarray as xr import matplotlib.pyplot as plt +import xarray as xr + # Load the data ds = xr.tutorial.load_dataset('air_temperature') air = ds.air - 273.15 # to celsius diff --git a/doc/gallery/plot_rasterio.py b/doc/gallery/plot_rasterio.py index 2ec58b884eb..98801990af3 100644 --- a/doc/gallery/plot_rasterio.py +++ b/doc/gallery/plot_rasterio.py @@ -13,17 +13,18 @@ These new coordinates might be handy for plotting and indexing, but it should be kept in mind that a grid which is regular in projection coordinates will likely be irregular in lon/lat. It is often recommended to work in the data's -original map projection. +original map projection (see :ref:`recipes.rasterio_rgb`). """ import os import urllib.request -import numpy as np -import xarray as xr + import cartopy.crs as ccrs import matplotlib.pyplot as plt +import numpy as np from rasterio.warp import transform +import xarray as xr # Download the file from rasterio's repository url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif' @@ -44,10 +45,13 @@ da.coords['lon'] = (('y', 'x'), lon) da.coords['lat'] = (('y', 'x'), lat) +# Compute a greyscale out of the rgb image +greyscale = da.mean(dim='band') + # Plot on a map ax = plt.subplot(projection=ccrs.PlateCarree()) -da.plot.imshow(ax=ax, x='lon', y='lat', rgb='band', - transform=ccrs.PlateCarree()) +greyscale.plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree(), + cmap='Greys_r', add_colorbar=False) ax.coastlines('10m', color='r') plt.show() diff --git a/doc/gallery/plot_rasterio_rgb.py b/doc/gallery/plot_rasterio_rgb.py new file mode 100644 index 00000000000..2733bf149e5 --- /dev/null +++ b/doc/gallery/plot_rasterio_rgb.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +""" +.. _recipes.rasterio_rgb: + +============================ +imshow() and map projections +============================ + +Using rasterio's projection information for more accurate plots. + +This example extends :ref:`recipes.rasterio` and plots the image in the +original map projection instead of relying on pcolormesh and a map +transformation. +""" + +import os +import urllib.request + +import cartopy.crs as ccrs +import matplotlib.pyplot as plt + +import xarray as xr + +# Download the file from rasterio's repository +url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif' +urllib.request.urlretrieve(url, 'RGB.byte.tif') + +# Read the data +da = xr.open_rasterio('RGB.byte.tif') + +# The data is in UTM projection. 
We have to set it manually until +# https://github.com/SciTools/cartopy/issues/813 is implemented +crs = ccrs.UTM('18N') + +# Plot on a map +ax = plt.subplot(projection=crs) +da.plot.imshow(ax=ax, rgb='band', transform=crs) +ax.coastlines('10m', color='r') +plt.show() + +# Delete the file +os.remove('RGB.byte.tif') diff --git a/doc/groupby.rst b/doc/groupby.rst index 4851cbe5dcc..6e42dbbc9f0 100644 --- a/doc/groupby.rst +++ b/doc/groupby.rst @@ -207,3 +207,12 @@ may be desirable: .. ipython:: python da.groupby_bins('lon', [0,45,50]).sum() + +These methods group by `lon` values. It is also possible to groupby each +cell in a grid, regardless of value, by stacking multiple dimensions, +applying your function, and then unstacking the result: + +.. ipython:: python + + stacked = da.stack(gridcell=['ny', 'nx']) + stacked.groupby('gridcell').sum().unstack('gridcell') diff --git a/doc/index.rst b/doc/index.rst index 607dce2ed50..45897f4bccb 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,12 +1,5 @@ - -.. image:: _static/dataset-diagram-logo.png - :width: 300 px - :align: center - -| - -N-D labeled arrays and datasets in Python -========================================= +xarray: N-D labeled arrays and datasets in Python +================================================= **xarray** (formerly **xray**) is an open source project and Python package that aims to bring the labeled data power of pandas_ to the physical sciences, @@ -18,12 +11,6 @@ pandas excels. Our approach adopts the `Common Data Model`_ for self- describing scientific data in widespread use in the Earth sciences: ``xarray.Dataset`` is an in-memory representation of a netCDF file. -.. note:: - - xray is now xarray! See :ref:`the v0.7.0 release notes` - for more details. The preferred URL for these docs is now - http://xarray.pydata.org. - .. _pandas: http://pandas.pydata.org .. _Common Data Model: http://www.unidata.ucar.edu/software/thredds/current/netcdf-java/CDM .. _netCDF: http://www.unidata.ucar.edu/software/netcdf @@ -32,16 +19,46 @@ describing scientific data in widespread use in the Earth sciences: Documentation ------------- +**Getting Started** + +* :doc:`why-xarray` +* :doc:`faq` +* :doc:`examples` +* :doc:`installing` + .. toctree:: :maxdepth: 1 + :hidden: + :caption: Getting Started - whats-new why-xarray faq examples installing + +**User Guide** + +* :doc:`data-structures` +* :doc:`indexing` +* :doc:`interpolation` +* :doc:`computation` +* :doc:`groupby` +* :doc:`reshaping` +* :doc:`combining` +* :doc:`time-series` +* :doc:`pandas` +* :doc:`io` +* :doc:`dask` +* :doc:`plotting` + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: User Guide + data-structures indexing + interpolation computation groupby reshaping @@ -51,8 +68,27 @@ Documentation io dask plotting + +**Help & reference** + +* :doc:`whats-new` +* :doc:`api` +* :doc:`internals` +* :doc:`roadmap` +* :doc:`contributing` +* :doc:`related-projects` + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Help & reference + + whats-new api internals + roadmap + contributing + related-projects See also -------- @@ -84,12 +120,20 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray -License -------- +NumFOCUS +-------- -xarray is available under the open source `Apache License`__. +.. 
image:: _static/numfocus_logo.png + :scale: 50 % + :target: https://numfocus.org/ + +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated +to supporting the open source scientific computing community. If you like +Xarray and want to support our mission, please consider making a donation_ +to support our efforts. + +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= -__ http://www.apache.org/licenses/LICENSE-2.0.html History ------- @@ -97,6 +141,15 @@ History xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org + +License +------- + +xarray is available under the open source `Apache License`__. + +__ http://www.apache.org/licenses/LICENSE-2.0.html diff --git a/doc/indexing.rst b/doc/indexing.rst index 6b01471ecfb..3878d983cf6 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -35,15 +35,15 @@ below and summarized in this table: +------------------+--------------+---------------------------------+--------------------------------+ | Dimension lookup | Index lookup | ``DataArray`` syntax | ``Dataset`` syntax | +==================+==============+=================================+================================+ -| Positional | By integer | ``arr[:, 0]`` | *not available* | +| Positional | By integer | ``da[:, 0]`` | *not available* | +------------------+--------------+---------------------------------+--------------------------------+ -| Positional | By label | ``arr.loc[:, 'IA']`` | *not available* | +| Positional | By label | ``da.loc[:, 'IA']`` | *not available* | +------------------+--------------+---------------------------------+--------------------------------+ -| By name | By integer | ``arr.isel(space=0)`` or |br| | ``ds.isel(space=0)`` or |br| | -| | | ``arr[dict(space=0)]`` | ``ds[dict(space=0)]`` | +| By name | By integer | ``da.isel(space=0)`` or |br| | ``ds.isel(space=0)`` or |br| | +| | | ``da[dict(space=0)]`` | ``ds[dict(space=0)]`` | +------------------+--------------+---------------------------------+--------------------------------+ -| By name | By label | ``arr.sel(space='IA')`` or |br| | ``ds.sel(space='IA')`` or |br| | -| | | ``arr.loc[dict(space='IA')]`` | ``ds.loc[dict(space='IA')]`` | +| By name | By label | ``da.sel(space='IA')`` or |br| | ``ds.sel(space='IA')`` or |br| | +| | | ``da.loc[dict(space='IA')]`` | ``ds.loc[dict(space='IA')]`` | +------------------+--------------+---------------------------------+--------------------------------+ More advanced indexing is also possible for all the methods by @@ -60,19 +60,19 @@ DataArray: .. ipython:: python - arr = xr.DataArray(np.random.rand(4, 3), - [('time', pd.date_range('2000-01-01', periods=4)), - ('space', ['IA', 'IL', 'IN'])]) - arr[:2] - arr[0, 0] - arr[:, [2, 1]] + da = xr.DataArray(np.random.rand(4, 3), + [('time', pd.date_range('2000-01-01', periods=4)), + ('space', ['IA', 'IL', 'IN'])]) + da[:2] + da[0, 0] + da[:, [2, 1]] Attributes are persisted in all indexing operations. .. 
warning:: Positional indexing deviates from the NumPy when indexing with multiple - arrays like ``arr[[0, 1], [0, 1]]``, as described in + arrays like ``da[[0, 1], [0, 1]]``, as described in :ref:`vectorized_indexing`. xarray also supports label-based indexing, just like pandas. Because @@ -81,7 +81,7 @@ fast. To do label based indexing, use the :py:attr:`~xarray.DataArray.loc` attri .. ipython:: python - arr.loc['2000-01-01':'2000-01-02', 'IA'] + da.loc['2000-01-01':'2000-01-02', 'IA'] In this example, the selected is a subpart of the array in the range '2000-01-01':'2000-01-02' along the first coordinate `time` @@ -98,8 +98,8 @@ Setting values with label based indexing is also supported: .. ipython:: python - arr.loc['2000-01-01', ['IL', 'IN']] = -10 - arr + da.loc['2000-01-01', ['IL', 'IN']] = -10 + da Indexing with dimension names @@ -114,10 +114,10 @@ use them explicitly to slice data. There are two ways to do this: .. ipython:: python # index by integer array indices - arr[dict(space=0, time=slice(None, 2))] + da[dict(space=0, time=slice(None, 2))] # index by dimension coordinate labels - arr.loc[dict(time=slice('2000-01-01', '2000-01-02'))] + da.loc[dict(time=slice('2000-01-01', '2000-01-02'))] 2. Use the :py:meth:`~xarray.DataArray.sel` and :py:meth:`~xarray.DataArray.isel` convenience methods: @@ -125,10 +125,10 @@ use them explicitly to slice data. There are two ways to do this: .. ipython:: python # index by integer array indices - arr.isel(space=0, time=slice(None, 2)) + da.isel(space=0, time=slice(None, 2)) # index by dimension coordinate labels - arr.sel(time=slice('2000-01-01', '2000-01-02')) + da.sel(time=slice('2000-01-01', '2000-01-02')) The arguments to these methods can be any objects that could index the array along the dimension given by the keyword, e.g., labels for an individual value, @@ -138,7 +138,7 @@ Python :py:func:`slice` objects or 1-dimensional arrays. We would love to be able to do indexing with labeled dimension names inside brackets, but unfortunately, Python `does yet not support`__ indexing with - keyword arguments like ``arr[space=0]`` + keyword arguments like ``da[space=0]`` __ http://legacy.python.org/dev/peps/pep-0472/ @@ -156,16 +156,16 @@ enabling nearest neighbor (inexact) lookups by use of the methods ``'pad'``, .. ipython:: python - data = xr.DataArray([1, 2, 3], [('x', [0, 1, 2])]) - data.sel(x=[1.1, 1.9], method='nearest') - data.sel(x=0.1, method='backfill') - data.reindex(x=[0.5, 1, 1.5, 2, 2.5], method='pad') + da = xr.DataArray([1, 2, 3], [('x', [0, 1, 2])]) + da.sel(x=[1.1, 1.9], method='nearest') + da.sel(x=0.1, method='backfill') + da.reindex(x=[0.5, 1, 1.5, 2, 2.5], method='pad') Tolerance limits the maximum distance for valid matches with an inexact lookup: .. ipython:: python - data.reindex(x=[1.1, 1.5], method='nearest', tolerance=0.2) + da.reindex(x=[1.1, 1.5], method='nearest', tolerance=0.2) The method parameter is not yet supported if any of the arguments to ``.sel()`` is a ``slice`` object: @@ -173,7 +173,7 @@ to ``.sel()`` is a ``slice`` object: .. ipython:: :verbatim: - In [1]: data.sel(x=slice(1, 3), method='nearest') + In [1]: da.sel(x=slice(1, 3), method='nearest') NotImplementedError However, you don't need to use ``method`` to do inexact slicing. Slicing @@ -182,15 +182,23 @@ labels are monotonic increasing: .. ipython:: python - data.sel(x=slice(0.9, 3.1)) + da.sel(x=slice(0.9, 3.1)) Indexing axes with monotonic decreasing labels also works, as long as the ``slice`` or ``.loc`` arguments are also decreasing: .. 
ipython:: python - reversed_data = data[::-1] - reversed_data.loc[3.1:0.9] + reversed_da = da[::-1] + reversed_da.loc[3.1:0.9] + + +.. note:: + + If you want to interpolate along coordinates rather than looking up the + nearest neighbors, use :py:meth:`~xarray.Dataset.interp` and + :py:meth:`~xarray.Dataset.interp_like`. + See :ref:`interpolation ` for the details. Dataset indexing @@ -201,7 +209,10 @@ simultaneously, returning a new dataset: .. ipython:: python - ds = arr.to_dataset(name='foo') + da = xr.DataArray(np.random.rand(4, 3), + [('time', pd.date_range('2000-01-01', periods=4)), + ('space', ['IA', 'IL', 'IN'])]) + ds = da.to_dataset(name='foo') ds.isel(space=[0], time=[0]) ds.sel(time='2000-01-01') @@ -243,8 +254,8 @@ xarray, use :py:meth:`~xarray.DataArray.where`: .. ipython:: python - arr2 = xr.DataArray(np.arange(16).reshape(4, 4), dims=['x', 'y']) - arr2.where(arr2.x + arr2.y < 4) + da = xr.DataArray(np.arange(16).reshape(4, 4), dims=['x', 'y']) + da.where(da.x + da.y < 4) This is particularly useful for ragged indexing of multi-dimensional data, e.g., to apply a 2D mask to an image. Note that ``where`` follows all the @@ -254,7 +265,7 @@ usual xarray broadcasting and alignment rules for binary operations (e.g., .. ipython:: python - arr2.where(arr2.y < 2) + da.where(da.y < 2) By default ``where`` maintains the original size of the data. For cases where the selected data size is much smaller than the original data, @@ -263,8 +274,33 @@ elements that are fully masked: .. ipython:: python - arr2.where(arr2.y < 2, drop=True) + da.where(da.y < 2, drop=True) + +.. _selecting values with isin: + +Selecting values with ``isin`` +------------------------------ + +To check whether elements of an xarray object contain a single object, you can +compare with the equality operator ``==`` (e.g., ``arr == 3``). To check +multiple values, use :py:meth:`~xarray.DataArray.isin`: + +.. ipython:: python + + da = xr.DataArray([1, 2, 3, 4, 5], dims=['x']) + da.isin([2, 4]) + +:py:meth:`~xarray.DataArray.isin` works particularly well with +:py:meth:`~xarray.DataArray.where` to support indexing by arrays that are not +already labels of an array: + +.. ipython:: python + + lookup = xr.DataArray([-1, -2, -3, -4, -5], dims=['x']) + da.where(lookup.isin([-2, -4]), drop=True) +However, some caution is in order: when done repeatedly, this type of indexing +is significantly slower than using :py:meth:`~xarray.DataArray.sel`. .. _vectorized_indexing: @@ -339,8 +375,8 @@ These methods may and also be applied to ``Dataset`` objects .. ipython:: python - ds2 = da.to_dataset(name='bar') - ds2.isel(x=xr.DataArray([0, 1, 2], dims=['points'])) + ds = da.to_dataset(name='bar') + ds.isel(x=xr.DataArray([0, 1, 2], dims=['points'])) .. tip:: @@ -370,7 +406,37 @@ These methods may and also be applied to ``Dataset`` objects Assigning values with indexing ------------------------------ -Vectorized indexing can be used to assign values to xarray object. +To select and assign values to a portion of a :py:meth:`~xarray.DataArray` you +can use indexing with ``.loc`` : + +.. ipython:: python + + ds = xr.tutorial.open_dataset('air_temperature') + + #add an empty 2D dataarray + ds['empty']= xr.full_like(ds.air.mean('time'),fill_value=0) + + #modify one grid point using loc() + ds['empty'].loc[dict(lon=260, lat=30)] = 100 + + #modify a 2D region using loc() + lc = ds.coords['lon'] + la = ds.coords['lat'] + ds['empty'].loc[dict(lon=lc[(lc>220)&(lc<260)], lat=la[(la>20)&(la<60)])] = 100 + +or :py:meth:`~xarray.where`: + +.. 
ipython:: python + + #modify one grid point using xr.where() + ds['empty'] = xr.where((ds.coords['lat']==20)&(ds.coords['lon']==260), 100, ds['empty']) + + #or modify a 2D region using xr.where() + mask = (ds.coords['lat']>20)&(ds.coords['lat']<60)&(ds.coords['lon']>220)&(ds.coords['lon']<260) + ds['empty'] = xr.where(mask, 100, ds['empty']) + + +Vectorized indexing can also be used to assign values to xarray object. .. ipython:: python @@ -421,7 +487,7 @@ __ https://docs.scipy.org/doc/numpy/user/basics.indexing.html#assigning-values-t or ``sel``:: # DO NOT do this - arr.isel(space=0) = 0 + da.isel(space=0) = 0 Assigning values with the chained indexing using ``.sel`` or ``.isel`` fails silently. @@ -452,7 +518,7 @@ where three elements at ``(ix, iy) = ((0, 0), (1, 1), (6, 0))`` are selected and mapped along a new dimension ``z``. If you want to add a coordinate to the new dimension ``z``, -you can supply a :py:meth:`~xarray.DataArray` with a coordinate, +you can supply a :py:class:`~xarray.DataArray` with a coordinate, .. ipython:: python @@ -465,10 +531,13 @@ method: .. ipython:: python + da = xr.DataArray(np.random.rand(4, 3), + [('time', pd.date_range('2000-01-01', periods=4)), + ('space', ['IA', 'IL', 'IN'])]) times = xr.DataArray(pd.to_datetime(['2000-01-03', '2000-01-02', '2000-01-01']), dims='new_time') - arr.sel(space=xr.DataArray(['IA', 'IL', 'IN'], dims=['new_time']), - time=times) + da.sel(space=xr.DataArray(['IA', 'IL', 'IN'], dims=['new_time']), + time=times) .. _align and reindex: @@ -490,15 +559,15 @@ To reindex a particular dimension, use :py:meth:`~xarray.DataArray.reindex`: .. ipython:: python - arr.reindex(space=['IA', 'CA']) + da.reindex(space=['IA', 'CA']) The :py:meth:`~xarray.DataArray.reindex_like` method is a useful shortcut. To demonstrate, we will make a subset DataArray with new values: .. ipython:: python - foo = arr.rename('foo') - baz = (10 * arr[:2, :2]).rename('baz') + foo = da.rename('foo') + baz = (10 * da[:2, :2]).rename('baz') baz Reindexing ``foo`` with ``baz`` selects out the first two values along each @@ -545,8 +614,8 @@ integer-based indexing as a fallback for dimensions without a coordinate label: .. ipython:: python - array = xr.DataArray([1, 2, 3], dims='x') - array.sel(x=[0, -1]) + da = xr.DataArray([1, 2, 3], dims='x') + da.sel(x=[0, -1]) Alignment between xarray objects where one or both do not have coordinate labels succeeds only if all dimensions of the same name have the same length. @@ -555,7 +624,7 @@ Otherwise, it raises an informative error: .. ipython:: :verbatim: - In [62]: xr.align(array, array[:2]) + In [62]: xr.align(da, da[:2]) ValueError: arguments without labels along dimension 'x' cannot be aligned because they have different dimension sizes: {2, 3} Underlying Indexes @@ -567,9 +636,12 @@ through the :py:attr:`~xarray.DataArray.indexes` attribute. .. ipython:: python - arr - arr.indexes - arr.indexes['time'] + da = xr.DataArray(np.random.rand(4, 3), + [('time', pd.date_range('2000-01-01', periods=4)), + ('space', ['IA', 'IL', 'IN'])]) + da + da.indexes + da.indexes['time'] Use :py:meth:`~xarray.DataArray.get_index` to get an index for a dimension, falling back to a default :py:class:`pandas.RangeIndex` if it has no coordinate @@ -577,8 +649,9 @@ labels: .. ipython:: python - array - array.get_index('x') + da = xr.DataArray([1, 2, 3], dims='x') + da + da.get_index('x') .. 
_copies_vs_views: diff --git a/doc/installing.rst b/doc/installing.rst index e9fd9885b31..64751eea637 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -6,9 +6,9 @@ Installation Required dependencies --------------------- -- Python 2.7, 3.4, 3.5, or 3.6 -- `numpy `__ (1.11 or later) -- `pandas `__ (0.18.0 or later) +- Python 2.7 [1]_, 3.5, 3.6, or 3.7 +- `numpy `__ (1.12 or later) +- `pandas `__ (0.19.2 or later) Optional dependencies --------------------- @@ -25,10 +25,23 @@ For netCDF and IO - `pynio `__: for reading GRIB and other geoscience specific file formats - `zarr `__: for chunked, compressed, N-dimensional arrays. +- `cftime `__: recommended if you + want to encode/decode datetimes for non-standard calendars or dates before + year 1678 or after year 2262. +- `PseudoNetCDF `__: recommended + for accessing CAMx, GEOS-Chem (bpch), NOAA ARL files, ICARTT files + (ffi1001) and many other. +- `rasterio `__: for reading GeoTiffs and + other gridded raster datasets. +- `iris `__: for conversion to and from iris' + Cube objects +- `cfgrib `__: for reading GRIB files via the + *ECMWF ecCodes* library. For accelerating xarray ~~~~~~~~~~~~~~~~~~~~~~~ +- `scipy `__: necessary to enable the interpolation features for xarray objects - `bottleneck `__: speeds up NaN-skipping and rolling window aggregations by a large factor (1.1 or later) @@ -38,13 +51,14 @@ For accelerating xarray For parallel computing ~~~~~~~~~~~~~~~~~~~~~~ -- `dask.array `__ (0.9.0 or later): required for +- `dask.array `__ (0.16 or later): required for :ref:`dask`. For plotting ~~~~~~~~~~~~ - `matplotlib `__: required for :ref:`plotting` + (1.5 or later) - `cartopy `__: recommended for :ref:`plot-maps` - `seaborn `__: for better @@ -62,9 +76,9 @@ with its recommended dependencies using the conda command line tool:: .. _conda: http://conda.io/ -We recommend using the community maintained `conda-forge `__ channel if you need difficult\-to\-build dependencies such as cartopy or pynio:: +We recommend using the community maintained `conda-forge `__ channel if you need difficult\-to\-build dependencies such as cartopy, pynio or PseudoNetCDF:: - $ conda install -c conda-forge xarray cartopy pynio + $ conda install -c conda-forge xarray cartopy pynio pseudonetcdf New releases may also appear in conda-forge before being updated in the default channel. @@ -78,6 +92,7 @@ Testing ------- To run the test suite after installing xarray, first install (via pypi or conda) + - `py.test `__: Simple unit testing library - `mock `__: additional testing library required for python version 2 @@ -92,7 +107,18 @@ A fixed-point performance monitoring of (a part of) our codes can be seen on `this page `__. To run these benchmark tests in a local machine, first install + - `airspeed-velocity `__: a tool for benchmarking Python packages over their lifetime. and run ``asv run # this will install some conda environments in ./.asv/envs`` + +.. [1] Xarray plans to drop support for python 2.7 at the end of 2018. This + means that new releases of xarray published after this date will only be + installable on python 3+ environments, but older versions of xarray will + always be available to python 2.7 users. 
For more information see the + following references: + + - `Xarray Github issue discussing dropping Python 2 `__ + - `Python 3 Statement `__ + - `Tips on porting to Python 3 `__ diff --git a/doc/internals.rst b/doc/internals.rst index e5e14896472..170e2d0b0cc 100644 --- a/doc/internals.rst +++ b/doc/internals.rst @@ -130,20 +130,3 @@ To help users keep things straight, please `let us know `_ if you plan to write a new accessor for an open source library. In the future, we will maintain a list of accessors and the libraries that implement them on this page. - -Here are several existing libraries that build functionality upon xarray. -They may be useful points of reference for your work: - -- `xgcm `_: General Circulation Model - Postprocessing. Uses subclassing and custom xarray backends. -- `PyGDX `_: Python 3 package for - accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom - subclass. -- `windspharm `_: Spherical - harmonic wind analysis in Python. -- `eofs `_: EOF analysis in Python. -- `salem `_: Adds geolocalised subsetting, - masking, and plotting operations to xarray's data structures via accessors. - -.. TODO: consider adding references to these projects somewhere more prominent -.. in the documentation? maybe the FAQ page? diff --git a/doc/interpolation.rst b/doc/interpolation.rst new file mode 100644 index 00000000000..71e88079676 --- /dev/null +++ b/doc/interpolation.rst @@ -0,0 +1,304 @@ +.. _interp: + +Interpolating data +================== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + np.random.seed(123456) + +xarray offers flexible interpolation routines, which have a similar interface +to our :ref:`indexing `. + +.. note:: + + ``interp`` requires `scipy` installed. + + +Scalar and 1-dimensional interpolation +-------------------------------------- + +Interpolating a :py:class:`~xarray.DataArray` works mostly like labeled +indexing of a :py:class:`~xarray.DataArray`, + +.. ipython:: python + + da = xr.DataArray(np.sin(0.3 * np.arange(12).reshape(4, 3)), + [('time', np.arange(4)), + ('space', [0.1, 0.2, 0.3])]) + # label lookup + da.sel(time=3) + + # interpolation + da.interp(time=2.5) + + +Similar to the indexing, :py:meth:`~xarray.DataArray.interp` also accepts an +array-like, which gives the interpolated result as an array. + +.. ipython:: python + + # label lookup + da.sel(time=[2, 3]) + + # interpolation + da.interp(time=[2.5, 3.5]) + +To interpolate data with a :py:func:`numpy.datetime64` coordinate you can pass a string. + +.. ipython:: python + + da_dt64 = xr.DataArray([1, 3], + [('time', pd.date_range('1/1/2000', '1/3/2000', periods=2))]) + da_dt64.interp(time='2000-01-02') + +The interpolated data can be merged into the original :py:class:`~xarray.DataArray` +by specifing the time periods required. + +.. ipython:: python + + da_dt64.interp(time=pd.date_range('1/1/2000', '1/3/2000', periods=3)) + +Interpolation of data indexed by a :py:class:`~xarray.CFTimeIndex` is also +allowed. See :ref:`CFTimeIndex` for examples. + +.. note:: + + Currently, our interpolation only works for regular grids. + Therefore, similarly to :py:meth:`~xarray.DataArray.sel`, + only 1D coordinates along a dimension can be used as the + original coordinate to be interpolated. + + +Multi-dimensional Interpolation +------------------------------- + +Like :py:meth:`~xarray.DataArray.sel`, :py:meth:`~xarray.DataArray.interp` +accepts multiple coordinates. In this case, multidimensional interpolation +is carried out. 
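Under the hood, multi-dimensional interpolation on a regular grid is delegated to ``scipy.interpolate.interpn`` (see "Interpolation methods" below). For intuition only, here is a minimal, self-contained sketch of the equivalent raw scipy call on a plain NumPy array, assuming scipy is installed and reusing the illustrative 4 x 3 grid defined at the top of this page; the result should match ``da.interp(time=2.5, space=0.15)`` in the example that follows.

.. code-block:: python

    import numpy as np
    from scipy.interpolate import interpn

    # the same illustrative grid as above: 4 time steps x 3 space points
    time = np.arange(4)
    space = np.array([0.1, 0.2, 0.3])
    values = np.sin(0.3 * np.arange(12).reshape(4, 3))

    # linear (default) interpolation at a single (time, space) location
    interpn((time, space), values, np.array([[2.5, 0.15]]))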
+ +.. ipython:: python + + # label lookup + da.sel(time=2, space=0.1) + + # interpolation + da.interp(time=2.5, space=0.15) + +Array-like coordinates are also accepted: + +.. ipython:: python + + # label lookup + da.sel(time=[2, 3], space=[0.1, 0.2]) + + # interpolation + da.interp(time=[1.5, 2.5], space=[0.15, 0.25]) + + +:py:meth:`~xarray.DataArray.interp_like` method is a useful shortcut. This +method interpolates an xarray object onto the coordinates of another xarray +object. For example, if we want to compute the difference between +two :py:class:`~xarray.DataArray` s (``da`` and ``other``) staying on slightly +different coordinates, + +.. ipython:: python + + other = xr.DataArray(np.sin(0.4 * np.arange(9).reshape(3, 3)), + [('time', [0.9, 1.9, 2.9]), + ('space', [0.15, 0.25, 0.35])]) + +it might be a good idea to first interpolate ``da`` so that it will stay on the +same coordinates of ``other``, and then subtract it. +:py:meth:`~xarray.DataArray.interp_like` can be used for such a case, + +.. ipython:: python + + # interpolate da along other's coordinates + interpolated = da.interp_like(other) + interpolated + +It is now possible to safely compute the difference ``other - interpolated``. + + +Interpolation methods +--------------------- + +We use :py:func:`scipy.interpolate.interp1d` for 1-dimensional interpolation and +:py:func:`scipy.interpolate.interpn` for multi-dimensional interpolation. + +The interpolation method can be specified by the optional ``method`` argument. + +.. ipython:: python + + da = xr.DataArray(np.sin(np.linspace(0, 2 * np.pi, 10)), dims='x', + coords={'x': np.linspace(0, 1, 10)}) + + da.plot.line('o', label='original') + da.interp(x=np.linspace(0, 1, 100)).plot.line(label='linear (default)') + da.interp(x=np.linspace(0, 1, 100), method='cubic').plot.line(label='cubic') + @savefig interpolation_sample1.png width=4in + plt.legend() + +Additional keyword arguments can be passed to scipy's functions. + +.. ipython:: python + + # fill 0 for the outside of the original coordinates. + da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={'fill_value': 0.0}) + # extrapolation + da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={'fill_value': 'extrapolate'}) + + +Advanced Interpolation +---------------------- + +:py:meth:`~xarray.DataArray.interp` accepts :py:class:`~xarray.DataArray` +as similar to :py:meth:`~xarray.DataArray.sel`, which enables us more advanced interpolation. +Based on the dimension of the new coordinate passed to :py:meth:`~xarray.DataArray.interp`, the dimension of the result are determined. + +For example, if you want to interpolate a two dimensional array along a particular dimension, as illustrated below, +you can pass two 1-dimensional :py:class:`~xarray.DataArray` s with +a common dimension as new coordinate. + +.. image:: _static/advanced_selection_interpolation.svg + :height: 200px + :width: 400 px + :alt: advanced indexing and interpolation + :align: center + +For example: + +.. 
ipython:: python + + da = xr.DataArray(np.sin(0.3 * np.arange(20).reshape(5, 4)), + [('x', np.arange(5)), + ('y', [0.1, 0.2, 0.3, 0.4])]) + # advanced indexing + x = xr.DataArray([0, 2, 4], dims='z') + y = xr.DataArray([0.1, 0.2, 0.3], dims='z') + da.sel(x=x, y=y) + + # advanced interpolation + x = xr.DataArray([0.5, 1.5, 2.5], dims='z') + y = xr.DataArray([0.15, 0.25, 0.35], dims='z') + da.interp(x=x, y=y) + +where values on the original coordinates +``(x, y) = ((0.5, 0.15), (1.5, 0.25), (2.5, 0.35))`` are obtained by the +2-dimensional interpolation and mapped along a new dimension ``z``. + +If you want to add a coordinate to the new dimension ``z``, you can supply +:py:class:`~xarray.DataArray` s with a coordinate, + +.. ipython:: python + + x = xr.DataArray([0.5, 1.5, 2.5], dims='z', coords={'z': ['a', 'b','c']}) + y = xr.DataArray([0.15, 0.25, 0.35], dims='z', + coords={'z': ['a', 'b','c']}) + da.interp(x=x, y=y) + +For the details of the advanced indexing, +see :ref:`more advanced indexing `. + + +Interpolating arrays with NaN +----------------------------- + +Our :py:meth:`~xarray.DataArray.interp` works with arrays with NaN +the same way that +`scipy.interpolate.interp1d `_ and +`scipy.interpolate.interpn `_ do. +``linear`` and ``nearest`` methods return arrays including NaN, +while other methods such as ``cubic`` or ``quadratic`` return all NaN arrays. + +.. ipython:: python + + da = xr.DataArray([0, 2, np.nan, 3, 3.25], dims='x', + coords={'x': range(5)}) + da.interp(x=[0.5, 1.5, 2.5]) + da.interp(x=[0.5, 1.5, 2.5], method='cubic') + +To avoid this, you can drop NaN by :py:meth:`~xarray.DataArray.dropna`, and +then make the interpolation + +.. ipython:: python + + dropped = da.dropna('x') + dropped + dropped.interp(x=[0.5, 1.5, 2.5], method='cubic') + +If NaNs are distributed rondomly in your multidimensional array, +dropping all the columns containing more than one NaNs by +:py:meth:`~xarray.DataArray.dropna` may lose a significant amount of information. +In such a case, you can fill NaN by :py:meth:`~xarray.DataArray.interpolate_na`, +which is similar to :py:meth:`pandas.Series.interpolate`. + +.. ipython:: python + + filled = da.interpolate_na(dim='x') + filled + +This fills NaN by interpolating along the specified dimension. +After filling NaNs, you can interpolate: + +.. ipython:: python + + filled.interp(x=[0.5, 1.5, 2.5], method='cubic') + +For the details of :py:meth:`~xarray.DataArray.interpolate_na`, +see :ref:`Missing values `. + + +Example +------- + +Let's see how :py:meth:`~xarray.DataArray.interp` works on real data. + +.. ipython:: python + + # Raw data + ds = xr.tutorial.open_dataset('air_temperature').isel(time=0) + fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) + ds.air.plot(ax=axes[0]) + axes[0].set_title('Raw data') + + # Interpolated data + new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims['lon'] * 4) + new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims['lat'] * 4) + dsi = ds.interp(lat=new_lat, lon=new_lon) + dsi.air.plot(ax=axes[1]) + @savefig interpolation_sample3.png width=8in + axes[1].set_title('Interpolated data') + +Our advanced interpolation can be used to remap the data to the new coordinate. +Consider the new coordinates x and z on the two dimensional plane. +The remapping can be done as follows + +.. 
ipython:: python + + # new coordinate + x = np.linspace(240, 300, 100) + z = np.linspace(20, 70, 100) + # relation between new and original coordinates + lat = xr.DataArray(z, dims=['z'], coords={'z': z}) + lon = xr.DataArray((x[:, np.newaxis]-270)/np.cos(z*np.pi/180)+270, + dims=['x', 'z'], coords={'x': x, 'z': z}) + + fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) + ds.air.plot(ax=axes[0]) + # draw the new coordinate on the original coordinates. + for idx in [0, 33, 66, 99]: + axes[0].plot(lon.isel(x=idx), lat, '--k') + for idx in [0, 33, 66, 99]: + axes[0].plot(*xr.broadcast(lon.isel(z=idx), lat.isel(z=idx)), '--k') + axes[0].set_title('Raw data') + + dsi = ds.interp(lon=lon, lat=lat) + dsi.air.plot(ax=axes[1]) + @savefig interpolation_sample4.png width=8in + axes[1].set_title('Remapped data') diff --git a/doc/io.rst b/doc/io.rst index c177496f6f2..e841e665308 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -534,7 +534,7 @@ longitudes and latitudes. considered as being experimental. Please report any bug you may find on xarray's github repository. -.. _rasterio: https://mapbox.github.io/rasterio/ +.. _rasterio: https://rasterio.readthedocs.io/en/latest/ .. _test files: https://github.com/mapbox/rasterio/blob/master/tests/data/RGB.byte.tif .. _pyproj: https://github.com/jswhit/pyproj @@ -603,7 +603,7 @@ pass to xarray:: # write to the bucket ds.to_zarr(store=gcsmap) # read it back - ds_gcs = xr.open_zarr(gcsmap, mode='r') + ds_gcs = xr.open_zarr(gcsmap) .. _Zarr: http://zarr.readthedocs.io/ .. _Amazon S3: https://aws.amazon.com/s3/ @@ -635,6 +635,28 @@ For example: Not all native zarr compression and filtering options have been tested with xarray. +.. _io.cfgrib: + +GRIB format via cfgrib +---------------------- + +xarray supports reading GRIB files via ECMWF cfgrib_ python driver and ecCodes_ +C-library, if they are installed. To open a GRIB file supply ``engine='cfgrib'`` +to :py:func:`~xarray.open_dataset`: + +.. ipython:: + :verbatim: + + In [1]: ds_grib = xr.open_dataset('example.grib', engine='cfgrib') + +We recommend installing ecCodes via conda:: + + conda install -c conda-forge eccodes + pip install cfgrib + +.. _cfgrib: https://github.com/ecmwf/cfgrib +.. _ecCodes: https://confluence.ecmwf.int/display/ECC/ecCodes+Home + .. _io.pynio: Formats supported by PyNIO @@ -650,7 +672,26 @@ We recommend installing PyNIO via conda:: .. _PyNIO: https://www.pyngl.ucar.edu/Nio.shtml -.. _combining multiple files: +.. _io.PseudoNetCDF: + +Formats supported by PseudoNetCDF +--------------------------------- + +xarray can also read CAMx, BPCH, ARL PACKED BIT, and many other file +formats supported by PseudoNetCDF_, if PseudoNetCDF is installed. +PseudoNetCDF can also provide Climate Forecasting Conventions to +CMAQ files. In addition, PseudoNetCDF can automatically register custom +readers that subclass PseudoNetCDF.PseudoNetCDFFile. PseudoNetCDF can +identify readers heuristically, or format can be specified via a key in +`backend_kwargs`. + +To use PseudoNetCDF to read such files, supply +``engine='pseudonetcdf'`` to :py:func:`~xarray.open_dataset`. + +Add ``backend_kwargs={'format': ''}`` where `` +options are listed on the PseudoNetCDF page. + +.. _PseudoNetCDF: http://github.com/barronh/PseudoNetCDF Formats supported by Pandas @@ -662,6 +703,8 @@ exporting your objects to pandas and using its broad range of `IO tools`_. .. _IO tools: http://pandas.pydata.org/pandas-docs/stable/io.html +.. 
_combining multiple files: + Combining multiple files ------------------------ @@ -672,9 +715,9 @@ files into a single Dataset by making use of :py:func:`~xarray.concat`. .. note:: - Version 0.5 includes support for manipulating datasets that - don't fit into memory with dask_. If you have dask installed, you can open - multiple files simultaneously using :py:func:`~xarray.open_mfdataset`:: + Xarray includes support for manipulating datasets that don't fit into memory + with dask_. If you have dask installed, you can open multiple files + simultaneously using :py:func:`~xarray.open_mfdataset`:: xr.open_mfdataset('my/files/*.nc') diff --git a/doc/plotting.rst b/doc/plotting.rst index 2b816a24563..f8ba82febb0 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -60,12 +60,19 @@ For these examples we'll use the North American air temperature dataset. .. ipython:: python - airtemps = xr.tutorial.load_dataset('air_temperature') + airtemps = xr.tutorial.open_dataset('air_temperature') airtemps # Convert to celsius air = airtemps.air - 273.15 + # copy attributes to get nice figure labels and change Kelvin to Celsius + air.attrs = airtemps.air.attrs + air.attrs['units'] = 'deg C' + +.. note:: + Until :issue:`1614` is solved, you might need to copy over the metadata in ``attrs`` to get informative figure labels (as was done above). + One Dimension ------------- @@ -73,7 +80,7 @@ One Dimension Simple Example ~~~~~~~~~~~~~~ -xarray uses the coordinate name to label the x axis. +The simplest way to make a plot is to call the :py:func:`xarray.DataArray.plot()` method. .. ipython:: python @@ -82,6 +89,12 @@ xarray uses the coordinate name to label the x axis. @savefig plotting_1d_simple.png width=4in air1d.plot() +xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec `_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``. + +.. ipython:: python + + air1d.attrs + Additional Arguments ~~~~~~~~~~~~~~~~~~~~~ @@ -195,7 +208,64 @@ It is required to explicitly specify either 2. ``hue``: the dimension you want to represent by multiple lines. Thus, we could have made the previous plot by specifying ``hue='lat'`` instead of ``x='time'``. -If required, the automatic legend can be turned off using ``add_legend=False``. +If required, the automatic legend can be turned off using ``add_legend=False``. Alternatively, +``hue`` can be passed directly to :py:func:`xarray.plot` as `air.isel(lon=10, lat=[19,21,22]).plot(hue='lat')`. + + +Dimension along y-axis +~~~~~~~~~~~~~~~~~~~~~~ + +It is also possible to make line plots such that the data are on the x-axis and a dimension is on the y-axis. This can be done by specifying the appropriate ``y`` keyword argument. + +.. ipython:: python + + @savefig plotting_example_xy_kwarg.png + air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon') + +Step plots +~~~~~~~~~~ + +As an alternative, also a step plot similar to matplotlib's ``plt.step`` can be +made using 1D data. + +.. 
ipython:: python + + @savefig plotting_example_step.png width=4in + air1d[:20].plot.step(where='mid') + +The argument ``where`` defines where the steps should be placed, options are +``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy +when plotting data grouped with :py:func:`xarray.Dataset.groupby_bins`. + +.. ipython:: python + + air_grp = air.mean(['time','lon']).groupby_bins('lat',[0,23.5,66.5,90]) + air_mean = air_grp.mean() + air_std = air_grp.std() + air_mean.plot.step() + (air_mean + air_std).plot.step(ls=':') + (air_mean - air_std).plot.step(ls=':') + plt.ylim(-20,30) + @savefig plotting_example_step_groupby.png width=4in + plt.title('Zonal mean temperature') + +In this case, the actual boundaries of the bins are used and the ``where`` argument +is ignored. + + +Other axes kwargs +----------------- + + +The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction. + +.. ipython:: python + + @savefig plotting_example_xincrease_yincrease_kwarg.png + air.isel(time=10, lon=[10, 11]).plot.line(y='lat', hue='lon', xincrease=False, yincrease=False) + +In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively. + Two Dimensions -------------- @@ -416,9 +486,16 @@ arguments to the xarray plotting methods/functions. This returns a .. ipython:: python - @savefig plot_facet_dataarray.png height=12in + @savefig plot_facet_dataarray.png g_simple = t.plot(x='lon', y='lat', col='time', col_wrap=3) +Faceting also works for line plots. + +.. ipython:: python + + @savefig plot_facet_dataarray_line.png + g_simple_line = t.isel(lat=slice(0,None,4)).plot(x='lon', hue='lat', col='time', col_wrap=3) + 4 dimensional ~~~~~~~~~~~~~ @@ -434,7 +511,7 @@ one were much hotter. # This is a 4d array t4d.coords - @savefig plot_facet_4d.png height=12in + @savefig plot_facet_4d.png t4d.plot(x='lon', y='lat', col='time', row='fourth_dim') Other features @@ -448,9 +525,10 @@ Faceted plotting supports other arguments common to xarray 2d plots. hasoutliers[0, 0, 0] = -100 hasoutliers[-1, -1, -1] = 400 - @savefig plot_facet_robust.png height=12in + @savefig plot_facet_robust.png g = hasoutliers.plot.pcolormesh('lon', 'lat', col='time', col_wrap=3, - robust=True, cmap='viridis') + robust=True, cmap='viridis', + cbar_kwargs={'label': 'this has outliers'}) FacetGrid Objects ~~~~~~~~~~~~~~~~~ @@ -489,7 +567,7 @@ they have been plotted. bottomright = g.axes[-1, -1] bottomright.annotate('bottom right', (240, 40)) - @savefig plot_facet_iterator.png height=12in + @savefig plot_facet_iterator.png plt.show() TODO: add an example of using the ``map`` method to plot dataset variables @@ -507,7 +585,7 @@ This script will plot the air temperature on a map. .. ipython:: python import cartopy.crs as ccrs - air = xr.tutorial.load_dataset('air_temperature').air + air = xr.tutorial.open_dataset('air_temperature').air ax = plt.axes(projection=ccrs.Orthographic(-80, 35)) air.isel(time=0).plot.contourf(ax=ax, transform=ccrs.PlateCarree()); @savefig plotting_maps_cartopy.png width=100% @@ -657,3 +735,12 @@ You can however decide to infer the cell boundaries and use the outside the xarray framework. .. _cell boundaries: http://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#cell-boundaries + +One can also make line plots with multidimensional coordinates. 
In this case, ``hue`` must be a dimension name, not a coordinate name. + +.. ipython:: python + + f, ax = plt.subplots(2, 1) + da.plot.line(x='lon', hue='y', ax=ax[0]); + @savefig plotting_example_2d_hue_xy.png + da.plot.line(x='lon', hue='x', ax=ax[1]); diff --git a/doc/related-projects.rst b/doc/related-projects.rst new file mode 100644 index 00000000000..cf89c715bc7 --- /dev/null +++ b/doc/related-projects.rst @@ -0,0 +1,69 @@ +.. _related-projects: + +Xarray related projects +----------------------- + +Here below is a list of several existing libraries that build +functionality upon xarray. See also section :ref:`internals` for more +details on how to build xarray extensions. + +Geosciences +~~~~~~~~~~~ + +- `aospy `_: Automated analysis and management of gridded climate data. +- `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meterology data +- `marc_analysis `_: Analysis package for CESM/MARC experiments and output. +- `MPAS-Analysis `_: Analysis for simulations produced with Model for Prediction Across Scales (MPAS) components and the Accelerated Climate Model for Energy (ACME). +- `OGGM `_: Open Global Glacier Model +- `Oocgcm `_: Analysis of large gridded geophysical datasets +- `Open Data Cube `_: Analysis toolkit of continental scale Earth Observation data from satellites. +- `Pangaea: `_: xarray extension for gridded land surface & weather model output). +- `Pangeo `_: A community effort for big data geoscience in the cloud. +- `PyGDX `_: Python 3 package for + accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom + subclass. +- `Regionmask `_: plotting and creation of masks of spatial regions +- `salem `_: Adds geolocalised subsetting, masking, and plotting operations to xarray's data structures via accessors. +- `SatPy `_ : Library for reading and manipulating meteorological remote sensing data and writing it to various image and data file formats. +- `Spyfit `_: FTIR spectroscopy of the atmosphere +- `windspharm `_: Spherical + harmonic wind analysis in Python. +- `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. +- `xarray-simlab `_: xarray extension for computer model simulations. +- `xarray-topo `_: xarray extension for topographic analysis and modelling. +- `xbpch `_: xarray interface for bpch files. +- `xESMF `_: Universal Regridder for Geospatial Data. +- `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids. +- `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures. +- `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray. + +Machine Learning +~~~~~~~~~~~~~~~~ +- `cesium `_: machine learning for time series analysis +- `Elm `_: Parallel machine learning on xarray data structures +- `sklearn-xarray (1) `_: Combines scikit-learn and xarray (1). +- `sklearn-xarray (2) `_: Combines scikit-learn and xarray (2). + +Extend xarray capabilities +~~~~~~~~~~~~~~~~~~~~~~~~~~ +- `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions +- `eofs `_: EOF analysis in Python. +- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). +- `xrft `_: Fourier transforms for xarray data. +- `xr-scipy `_: A lightweight scipy wrapper for xarray. 
+- `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. +- `xskillscore `_: Metrics for verifying forecasts. +- `xyzpy `_: Easily generate high dimensional data, including parallelization. + +Visualization +~~~~~~~~~~~~~ +- `Datashader `_, `geoviews `_, `holoviews `_, : visualization packages for large data. +- `hvplot `_ : A high-level plotting API for the PyData ecosystem built on HoloViews. +- `psyplot `_: Interactive data visualization with python. + +Other +~~~~~ +- `ptsa `_: EEG Time Series Analysis +- `pycalphad `_: Computational Thermodynamics in Python + +More projects can be found at the `"xarray" Github topic `_. diff --git a/doc/roadmap.rst b/doc/roadmap.rst new file mode 100644 index 00000000000..34d203c3f48 --- /dev/null +++ b/doc/roadmap.rst @@ -0,0 +1,227 @@ +.. _roadmap: + +Development roadmap +=================== + +Authors: Stephan Hoyer, Joe Hamman and xarray developers + +Date: July 24, 2018 + +Xarray is an open source Python library for labeled multidimensional +arrays and datasets. + +Our philosophy +-------------- + +Why has xarray been successful? In our opinion: + +- Xarray does a great job of solving **specific use-cases** for + multidimensional data analysis: + + - The dominant use-case for xarray is for analysis of gridded + dataset in the geosciences, e.g., as part of the + `Pangeo `__ project. + - Xarray is also used more broadly in the physical sciences, where + we've found the needs for analyzing multidimensional datasets are + remarkably consistent (e.g., see + `SunPy `__ and + `PlasmaPy `__). + - Finally, xarray is used in a variety of other domains, including + finance, `probabilistic + programming `__ and + genomics. + +- Xarray is also a **domain agnostic** solution: + + - We focus on providing a flexible set of functionality related + labeled multidimensional arrays, rather than solving particular + problems. + - This facilitates collaboration between users with different needs, + and helps us attract a broad community of contributers. + - Importantly, this retains flexibility, for use cases that don't + fit particularly well into existing frameworks. + +- Xarray **integrates well** with other libraries in the scientific + Python stack. + + - We leverage first-class external libraries for core features of + xarray (e.g., NumPy for ndarrays, pandas for indexing, dask for + parallel computing) + - We expose our internal abstractions to users (e.g., + ``apply_ufunc()``), which facilitates extending xarray in various + ways. + +Together, these features have made xarray a first-class choice for +labeled multidimensional arrays in Python. + +We want to double-down on xarray's strengths by making it an even more +flexible and powerful tool for multidimensional data analysis. We want +to continue to engage xarray's core geoscience users, and to also reach +out to new domains to learn from other successful data models like those +of `yt `__ or the `OLAP +cube `__. + +Specific needs +-------------- + +The user community has voiced a number specific needs related to how +xarray interfaces with domain specific problems. Xarray may not solve +all of these issues directly, but these areas provide opportunities for +xarray to provide better, more extensible, interfaces. Some examples of +these common needs are: + +- Non-regular grids (e.g., staggered and unstructured meshes). +- Physical units. +- Lazily computed arrays (e.g., for coordinate systems). +- New file-formats. 
+ +Technical vision +---------------- + +We think the right approach to extending xarray's user community and the +usefulness of the project is to focus on improving key interfaces that +can be used externally to meet domain-specific needs. + +We can generalize the community's needs into three main catagories: + +- More flexible grids/indexing. +- More flexible arrays/computing. +- More flexible storage backends. + +Each of these are detailed further in the subsections below. + +Flexible indexes +~~~~~~~~~~~~~~~~ + +Xarray currently keeps track of indexes associated with coordinates by +storing them in the form of a ``pandas.Index`` in special +``xarray.IndexVariable`` objects. + +The limitations of this model became clear with the addition of +``pandas.MultiIndex`` support in xarray 0.9, where a single index +corresponds to multiple xarray variables. MultiIndex support is highly +useful, but xarray now has numerous special cases to check for +MultiIndex levels. + +A cleaner model would be to elevate ``indexes`` to an explicit part of +xarray's data model, e.g., as attributes on the ``Dataset`` and +``DataArray`` classes. Indexes would need to be propagated along with +coordinates in xarray operations, but will no longer would need to have +a one-to-one correspondance with coordinate variables. Instead, an index +should be able to refer to multiple (possibly multidimensional) +coordinates that define it. See `GH +1603 `__ for full details + +Specific tasks: + +- Add an ``indexes`` attribute to ``xarray.Dataset`` and + ``xarray.Dataset``, as dictionaries that map from coordinate names to + xarray index objects. +- Use the new index interface to write wrappers for ``pandas.Index``, + ``pandas.MultiIndex`` and ``scipy.spatial.KDTree``. +- Expose the interface externally to allow third-party libraries to + implement custom indexing routines, e.g., for geospatial look-ups on + the surface of the Earth. + +In addition to the new features it directly enables, this clean up will +allow xarray to more easily implement some long-awaited features that +build upon indexing, such as groupby operations with multiple variables. + +Flexible arrays +~~~~~~~~~~~~~~~ + +Xarray currently supports wrapping multidimensional arrays defined by +NumPy, dask and to a limited-extent pandas. It would be nice to have +interfaces that allow xarray to wrap alternative N-D array +implementations, e.g.: + +- Arrays holding physical units. +- Lazily computed arrays. +- Other ndarray objects, e.g., sparse, xnd, xtensor. + +Our strategy has been to pursue upstream improvements in NumPy (see +`NEP-22 `__) +for supporting a complete duck-typing interface using with NumPy's +higher level array API. Improvements in NumPy's support for custom data +types would also be highly useful for xarray users. + +By pursuing these improvements in NumPy we hope to extend the benefits +to the full scientific Python community, and avoid tight coupling +between xarray and specific third-party libraries (e.g., for +implementing untis). This will allow xarray to maintain its domain +agnostic strengths. + +We expect that we may eventually add some minimal interfaces in xarray +for features that we delegate to external array libraries (e.g., for +getting units and changing units). If we do add these features, we +expect them to be thin wrappers, with core functionality implemented by +third-party libraries. + +Flexible storage +~~~~~~~~~~~~~~~~ + +The xarray backends module has grown in size and complexity. 
Much of +this growth has been "organic" and mostly to support incremental +additions to the supported backends. This has left us with a fragile +internal API that is difficult for even experienced xarray developers to +use. Moreover, the lack of a public facing API for building xarray +backends means that users can not easily build backend interface for +xarray in third-party libraries. + +The idea of refactoring the backends API and exposing it to users was +originally proposed in `GH +1970 `__. The idea would +be to develop a well tested and generic backend base class and +associated utilities for external use. Specific tasks for this +development would include: + +- Exposing an abstract backend for writing new storage systems. +- Exposing utilities for features like automatic closing of files, + LRU-caching and explicit/lazy indexing. +- Possibly moving some infrequently used backends to third-party + packages. + +Engaging more users +------------------- + +Like many open-source projects, the documentation of xarray has grown +together with the library's features. While we think that the xarray +documentation is comprehensive already, we aknowledge that the adoption +of xarray might be slowed down because of the substantial time +investment required to learn its working principles. In particular, +non-computer scientists or users less familiar with the pydata ecosystem +might find it difficult to learn xarray and realize how xarray can help +them in their daily work. + +In order to lower this adoption barrier, we propose to: + +- Develop entry-level tutorials for users with different backgrounds. For + example, we would like to develop tutorials for users with or without + previous knowledge of pandas, numpy, netCDF, etc. These tutorials may be + built as part of xarray's documentation or included in a seperate repository + to enable interactive use (e.g. mybinder.org). +- Document typical user workflows in a dedicated website, following the example + of `dask-stories + `__. +- Write a basic glossary that defines terms that might not be familiar to all + (e.g. "lazy", "labeled", "serialization", "indexing", "backend"). + +Administrative +-------------- + +Current core developers +~~~~~~~~~~~~~~~~~~~~~~~ + +- Stephan Hoyer +- Ryan Abernathey +- Joe Hamman +- Benoit Bovy +- Fabien Maussion +- Keisuke Fujii +- Maximilian Roos + +NumFOCUS +~~~~~~~~ + +On July 16, 2018, Joe and Stephan submitted xarray's fiscal sponsorship +application to NumFOCUS. diff --git a/doc/time-series.rst b/doc/time-series.rst index bdf8b1e7f81..c225c246a8c 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -47,6 +47,17 @@ attribute like ``'days since 2000-01-01'``). .. _CF conventions: http://cfconventions.org +.. note:: + + When decoding/encoding datetimes for non-standard calendars or for dates + before year 1678 or after year 2262, xarray uses the `cftime`_ library. + It was previously packaged with the ``netcdf4-python`` package under the + name ``netcdftime`` but is now distributed separately. ``cftime`` is an + :ref:`optional dependency` of xarray. + +.. _cftime: https://unidata.github.io/cftime + + You can manual decode arrays in this form by passing a dataset to :py:func:`~xarray.decode_cf`: @@ -59,7 +70,12 @@ You can manual decode arrays in this form by passing a dataset to One unfortunate limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262. 
When a netCDF file contains dates outside of these bounds, dates will be -returned as arrays of ``netcdftime.datetime`` objects. +returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex` +will be used for indexing. :py:class:`~xarray.CFTimeIndex` enables a subset of +the indexing functionality of a :py:class:`pandas.DatetimeIndex` and is only +fully compatible with the standalone version of ``cftime`` (not the version +packaged with earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more +information. Datetime indexing ----------------- @@ -129,6 +145,14 @@ the first letters of the corresponding months. You can use these shortcuts with both Datasets and DataArray coordinates. +In addition, xarray supports rounding operations ``floor``, ``ceil``, and ``round``. These operations require that you supply a `rounding frequency as a string argument.`__ + +__ http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases + +.. ipython:: python + + ds['time'].dt.floor('D') + .. _resampling: Resampling and grouped operations @@ -175,16 +199,145 @@ and ``interpolate``. ``interpolate`` extends ``scipy.interpolate.interp1d`` and supports all of its schemes. All of these resampling operations work on both Dataset and DataArray objects with an arbitrary number of dimensions. +For more examples of using grouped operations on a time dimension, see +:ref:`toy weather data`. + + +.. _CFTimeIndex: + +Non-standard calendars and dates outside the Timestamp-valid range +------------------------------------------------------------------ + +Through the standalone ``cftime`` library and a custom subclass of +:py:class:`pandas.Index`, xarray supports a subset of the indexing +functionality enabled through the standard :py:class:`pandas.DatetimeIndex` for +dates from non-standard calendars commonly used in climate science or dates +using a standard calendar, but outside the `Timestamp-valid range`_ +(approximately between years 1678 and 2262). + .. note:: - The ``resample`` api was updated in version 0.10.0 to reflect similar - updates in pandas ``resample`` api to be more groupby-like. Older style - calls to ``resample`` will still be supported for a short period: + As of xarray version 0.11, by default, :py:class:`cftime.datetime` objects + will be used to represent times (either in indexes, as a + :py:class:`~xarray.CFTimeIndex`, or in data arrays with dtype object) if + any of the following are true: - .. ipython:: python + - The dates are from a non-standard calendar + - Any dates are outside the Timestamp-valid range. - ds.resample('6H', dim='time', how='mean') + Otherwise pandas-compatible dates from a standard calendar will be + represented with the ``np.datetime64[ns]`` data type, enabling the use of a + :py:class:`pandas.DatetimeIndex` or arrays with dtype ``np.datetime64[ns]`` + and their full set of associated features. +For example, you can create a DataArray indexed by a time +coordinate with dates from a no-leap calendar and a +:py:class:`~xarray.CFTimeIndex` will automatically be used: -For more examples of using grouped operations on a time dimension, see -:ref:`toy weather data`. +.. 
ipython:: python + + from itertools import product + from cftime import DatetimeNoLeap + + dates = [DatetimeNoLeap(year, month, 1) for year, month in + product(range(1, 3), range(1, 13))] + da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') + +xarray also includes a :py:func:`~xarray.cftime_range` function, which enables +creating a :py:class:`~xarray.CFTimeIndex` with regularly-spaced dates. For +instance, we can create the same dates and DataArray we created above using: + +.. ipython:: python + + dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap') + da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') + +For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: + +- `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial + datetime strings: + +.. ipython:: python + + da.sel(time='0001') + da.sel(time=slice('0001-05', '0002-02')) + +- Access of basic datetime components via the ``dt`` accessor (in this case + just "year", "month", "day", "hour", "minute", "second", "microsecond", and + "season"): + +.. ipython:: python + + da.time.dt.year + da.time.dt.month + da.time.dt.season + +- Group-by operations based on datetime accessor attributes (e.g. by month of + the year): + +.. ipython:: python + + da.groupby('time.month').sum() + +- Interpolation using :py:class:`cftime.datetime` objects: + +.. ipython:: python + + da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)]) + +- Interpolation using datetime strings: + +.. ipython:: python + + da.interp(time=['0001-01-15', '0001-02-15']) + +- Differentiation: + +.. ipython:: python + + da.differentiate('time') + +- And serialization: + +.. ipython:: python + + da.to_netcdf('example-no-leap.nc') + xr.open_dataset('example-no-leap.nc') + +.. note:: + + While much of the time series functionality that is possible for standard + dates has been implemented for dates from non-standard calendars, there are + still some remaining important features that have yet to be implemented, + for example: + + - Resampling along the time dimension for data indexed by a + :py:class:`~xarray.CFTimeIndex` (:issue:`2191`, :issue:`2458`) + - Built-in plotting of data with :py:class:`cftime.datetime` coordinate axes + (:issue:`2164`). + + For some use-cases it may still be useful to convert from + a :py:class:`~xarray.CFTimeIndex` to a :py:class:`pandas.DatetimeIndex`, + despite the difference in calendar types (e.g. to allow the use of some + forms of resample with non-standard calendars). The recommended way of + doing this is to use the built-in + :py:meth:`~xarray.CFTimeIndex.to_datetimeindex` method: + + .. ipython:: python + + modern_times = xr.cftime_range('2000', periods=24, freq='MS', calendar='noleap') + da = xr.DataArray(range(24), [('time', modern_times)]) + da + datetimeindex = da.indexes['time'].to_datetimeindex() + da['time'] = datetimeindex + da.resample(time='Y').mean('time') + + However in this case one should use caution to only perform operations which + do not depend on differences between dates (e.g. differentiation, + interpolation, or upsampling with resample), as these could introduce subtle + and silent errors due to the difference in calendar types between the dates + encoded in your data and the dates stored in memory. + +.. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#timestamp-limitations +.. _ISO 8601-format: https://en.wikipedia.org/wiki/ISO_8601 +.. 
_partial datetime string indexing: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#partial-string-indexing diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a1fee8d5961..1da1da700e7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,63 +13,782 @@ What's New import xarray as xr np.random.seed(123456) -.. _whats-new.0.10.1: +.. warning:: + + Xarray plans to drop support for python 2.7 at the end of 2018. This + means that new releases of xarray published after this date will only be + installable on python 3+ environments, but older versions of xarray will + always be available to python 2.7 users. For more information see the + following references + + - `Xarray Github issue discussing dropping Python 2 `__ + - `Python 3 Statement `__ + - `Tips on porting to Python 3 `__ + +.. _whats-new.0.11.1: -v0.10.1 (unreleased) +v0.11.1 (unreleased) -------------------- +Breaking changes +~~~~~~~~~~~~~~~~ + +Enhancements +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +.. _whats-new.0.11.0: + +v0.11.0 (7 November 2018) +------------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Finished deprecations (changed behavior with this release): + + - ``Dataset.T`` has been removed as a shortcut for :py:meth:`Dataset.transpose`. + Call :py:meth:`Dataset.transpose` directly instead. + - Iterating over a ``Dataset`` now includes only data variables, not coordinates. + Similarily, calling ``len`` and ``bool`` on a ``Dataset`` now + includes only data variables. + - ``DataArray.__contains__`` (used by Python's ``in`` operator) now checks + array data, not coordinates. + - The old resample syntax from before xarray 0.10, e.g., + ``data.resample('1D', dim='time', how='mean')``, is no longer supported will + raise an error in most cases. You need to use the new resample syntax + instead, e.g., ``data.resample(time='1D').mean()`` or + ``data.resample({'time': '1D'}).mean()``. + + +- New deprecations (behavior will be changed in xarray 0.12): + + - Reduction of :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` + without dimension argument will change in the next release. + Now we warn a FutureWarning. + By `Keisuke Fujii `_. + - The ``inplace`` kwarg of a number of `DataArray` and `Dataset` methods is being + deprecated and will be removed in the next release. + By `Deepak Cherian `_. + + +- Refactored storage backends: + + - Xarray's storage backends now automatically open and close files when + necessary, rather than requiring opening a file with ``autoclose=True``. A + global least-recently-used cache is used to store open files; the default + limit of 128 open files should suffice in most cases, but can be adjusted if + necessary with + ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument + to ``open_dataset`` and related functions has been deprecated and is now a + no-op. + + This change, along with an internal refactor of xarray's storage backends, + should significantly improve performance when reading and writing + netCDF files with Dask, especially when working with many files or using + Dask Distributed. By `Stephan Hoyer `_ + + +- Support for non-standard calendars used in climate science: + + - Xarray will now always use :py:class:`cftime.datetime` objects, rather + than by default trying to coerce them into ``np.datetime64[ns]`` objects. + A :py:class:`~xarray.CFTimeIndex` will be used for indexing along time + coordinates in these cases. 
+ - A new method :py:meth:`~xarray.CFTimeIndex.to_datetimeindex` has been added + to aid in converting from a :py:class:`~xarray.CFTimeIndex` to a + :py:class:`pandas.DatetimeIndex` for the remaining use-cases where + using a :py:class:`~xarray.CFTimeIndex` is still a limitation (e.g. for + resample or plotting). + - Setting the ``enable_cftimeindex`` option is now a no-op and emits a + ``FutureWarning``. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`xarray.DataArray.plot.line` can now accept multidimensional + coordinate variables as input. `hue` must be a dimension name in this case. + (:issue:`2407`) + By `Deepak Cherian `_. +- Added support for Python 3.7. (:issue:`2271`). + By `Joe Hamman `_. +- Added support for plotting data with `pandas.Interval` coordinates, such as those + created by :py:meth:`~xarray.DataArray.groupby_bins` + By `Maximilian Maahn `_. +- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a + CFTimeIndex by a specified frequency. (:issue:`2244`). + By `Spencer Clark `_. +- Added support for using ``cftime.datetime`` coordinates with + :py:meth:`~xarray.DataArray.differentiate`, + :py:meth:`~xarray.Dataset.differentiate`, + :py:meth:`~xarray.DataArray.interp`, and + :py:meth:`~xarray.Dataset.interp`. + By `Spencer Clark `_ +- There is now a global option to either always keep or always discard + dataset and dataarray attrs upon operations. The option is set with + ``xarray.set_options(keep_attrs=True)``, and the default is to use the old + behaviour. + By `Tom Nicholas `_. +- Added a new backend for the GRIB file format based on ECMWF *cfgrib* + python driver and *ecCodes* C-library. (:issue:`2475`) + By `Alessandro Amici `_, + sponsored by `ECMWF `_. +- Resample now supports a dictionary mapping from dimension to frequency as + its first argument, e.g., ``data.resample({'time': '1D'}).mean()``. This is + consistent with other xarray functions that accept either dictionaries or + keyword arguments. By `Stephan Hoyer `_. + +- The preferred way to access tutorial data is now to load it lazily with + :py:meth:`xarray.tutorial.open_dataset`. + :py:meth:`xarray.tutorial.load_dataset` calls `Dataset.load()` prior + to returning (and is now deprecated). This was changed in order to facilitate + using tutorial datasets with dask. + By `Joe Hamman `_. + +Bug fixes +~~~~~~~~~ + +- ``FacetGrid`` now properly uses the ``cbar_kwargs`` keyword argument. + (:issue:`1504`, :issue:`1717`) + By `Deepak Cherian `_. +- Addition and subtraction operators used with a CFTimeIndex now preserve the + index's type. (:issue:`2244`). + By `Spencer Clark `_. +- We now properly handle arrays of ``datetime.datetime`` and ``datetime.timedelta`` + provided as coordinates. (:issue:`2512`) + By `Deepak Cherian `_. +- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override + the norm's ``vmin`` and ``vmax``. (:issue:`2381`) + By `Deepak Cherian `_. +- ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. + (:issue:`2240`) + By `Keisuke Fujii `_. +- Restore matplotlib's default of plotting dashed negative contours when + a single color is passed to ``DataArray.contour()`` e.g. ``colors='k'``. + By `Deepak Cherian `_. + + +- Fix a bug that caused some indexing operations on arrays opened with + ``open_rasterio`` to error (:issue:`2454`). + By `Stephan Hoyer `_. + +- Subtracting one CFTimeIndex from another now returns a + ``pandas.TimedeltaIndex``, analogous to the behavior for DatetimeIndexes + (:issue:`2484`). By `Spencer Clark `_. 
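The CFTimeIndex arithmetic fixes in the surrounding bug-fix entries can be illustrated with a small sketch (again assuming xarray >= 0.11 with ``cftime`` installed; the index names are illustrative):

.. ipython:: python

    import xarray as xr

    start = xr.cftime_range('2000-01-01', periods=3, freq='D', calendar='noleap')
    later = xr.cftime_range('2000-01-02', periods=3, freq='D', calendar='noleap')
    later - start            # now a pandas.TimedeltaIndex, as for DatetimeIndex
    start + (later - start)  # adding a TimedeltaIndex back yields a CFTimeIndex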
+- Adding a TimedeltaIndex to, or subtracting a TimedeltaIndex from a + CFTimeIndex is now allowed (:issue:`2484`). + By `Spencer Clark `_. +- Avoid use of Dask's deprecated ``get=`` parameter in tests + by `Matthew Rocklin `_. +- An ``OverflowError`` is now accurately raised and caught during the + encoding process if a reference date is used that is so distant that + the dates must be encoded using cftime rather than NumPy (:issue:`2272`). + By `Spencer Clark `_. + +- Chunked datasets can now roundtrip to Zarr storage continually + with `to_zarr` and ``open_zarr`` (:issue:`2300`). + By `Lily Wang `_. + +.. _whats-new.0.10.9: + +v0.10.9 (21 September 2018) +--------------------------- + +This minor release contains a number of backwards compatible enhancements. + +Announcements of note: + +- Xarray is now a NumFOCUS fiscally sponsored project! Read + `the anouncement `_ + for more details. +- We have a new :doc:`roadmap` that outlines our future development plans. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.differentiate` and + :py:meth:`~xarray.Dataset.differentiate` are newly added. + (:issue:`1332`) + By `Keisuke Fujii `_. +- Default colormap for sequential and divergent data can now be set via + :py:func:`~xarray.set_options()` + (:issue:`2394`) + By `Julius Busecke `_. + +- min_count option is newly supported in :py:meth:`~xarray.DataArray.sum`, + :py:meth:`~xarray.DataArray.prod` and :py:meth:`~xarray.Dataset.sum`, and + :py:meth:`~xarray.Dataset.prod`. + (:issue:`2230`) + By `Keisuke Fujii `_. + +- :py:meth:`plot()` now accepts the kwargs + ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. + By `Deepak Cherian `_. (:issue:`2224`) + +- DataArray coordinates and Dataset coordinates and data variables are + now displayed as `a b ... y z` rather than `a b c d ...`. + (:issue:`1186`) + By `Seth P `_. +- A new CFTimeIndex-enabled :py:func:`cftime_range` function for use in + generating dates from standard or non-standard calendars. By `Spencer Clark + `_. + +- When interpolating over a ``datetime64`` axis, you can now provide a datetime string instead of a ``datetime64`` object. E.g. ``da.interp(time='1991-02-01')`` + (:issue:`2284`) + By `Deepak Cherian `_. + +- A clear error message is now displayed if a ``set`` or ``dict`` is passed in place of an array + (:issue:`2331`) + By `Maximilian Roos `_. + +- Applying ``unstack`` to a large DataArray or Dataset is now much faster if the MultiIndex has not been modified after stacking the indices. + (:issue:`1560`) + By `Maximilian Maahn `_. + +- You can now control whether or not to offset the coordinates when using + the ``roll`` method and the current behavior, coordinates rolled by default, + raises a deprecation warning unless explicitly setting the keyword argument. + (:issue:`1875`) + By `Andrew Huang `_. + +- You can now call ``unstack`` without arguments to unstack every MultiIndex in a DataArray or Dataset. + By `Julia Signell `_. + +- Added the ability to pass a data kwarg to ``copy`` to create a new object with the + same metadata as the original object but using new values. + By `Julia Signell `_. + +Bug fixes +~~~~~~~~~ + +- ``xarray.plot.imshow()`` correctly uses the ``origin`` argument. + (:issue:`2379`) + By `Deepak Cherian `_. + +- Fixed ``DataArray.to_iris()`` failure while creating ``DimCoord`` by + falling back to creating ``AuxCoord``. 
Fixed dependency on ``var_name`` + attribute being set. + (:issue:`2201`) + By `Thomas Voigt `_. +- Fixed a bug in ``zarr`` backend which prevented use with datasets with + invalid chunk size encoding after reading from an existing store + (:issue:`2278`). + By `Joe Hamman `_. + +- Tests can be run in parallel with pytest-xdist + By `Tony Tung `_. + +- Follow up the renamings in dask; from dask.ghost to dask.overlap + By `Keisuke Fujii `_. + +- Now raises a ValueError when there is a conflict between dimension names and + level names of MultiIndex. (:issue:`2299`) + By `Keisuke Fujii `_. + +- Follow up the renamings in dask; from dask.ghost to dask.overlap + By `Keisuke Fujii `_. + +- Now :py:func:`xr.apply_ufunc` raises a ValueError when the size of + ``input_core_dims`` is inconsistent with the number of arguments. + (:issue:`2341`) + By `Keisuke Fujii `_. + +- Fixed ``Dataset.filter_by_attrs()`` behavior not matching ``netCDF4.Dataset.get_variables_by_attributes()``. + When more than one ``key=value`` is passed into ``Dataset.filter_by_attrs()`` it will now return a Dataset with variables which pass + all the filters. + (:issue:`2315`) + By `Andrew Barna `_. + +.. _whats-new.0.10.8: + +v0.10.8 (18 July 2018) +---------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Xarray no longer supports python 3.4. Additionally, the minimum supported + versions of the following dependencies has been updated and/or clarified: + + - Pandas: 0.18 -> 0.19 + - NumPy: 1.11 -> 1.12 + - Dask: 0.9 -> 0.16 + - Matplotlib: unspecified -> 1.5 + + (:issue:`2204`). By `Joe Hamman `_. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.interp_like` and + :py:meth:`~xarray.Dataset.interp_like` methods are newly added. + (:issue:`2218`) + By `Keisuke Fujii `_. + +- Added support for curvilinear and unstructured generic grids + to :py:meth:`~xarray.DataArray.to_cdms2` and + :py:meth:`~xarray.DataArray.from_cdms2` (:issue:`2262`). + By `Stephane Raynaud `_. + +Bug fixes +~~~~~~~~~ + +- Fixed a bug in ``zarr`` backend which prevented use with datasets with + incomplete chunks in multiple dimensions (:issue:`2225`). + By `Joe Hamman `_. + +- Fixed a bug in :py:meth:`~Dataset.to_netcdf` which prevented writing + datasets when the arrays had different chunk sizes (:issue:`2254`). + By `Mike Neish `_. + +- Fixed masking during the conversion to cdms2 objects by + :py:meth:`~xarray.DataArray.to_cdms2` (:issue:`2262`). + By `Stephane Raynaud `_. + +- Fixed a bug in 2D plots which incorrectly raised an error when 2D coordinates + weren't monotonic (:issue:`2250`). + By `Fabien Maussion `_. + +- Fixed warning raised in :py:meth:`~Dataset.to_netcdf` due to deprecation of + `effective_get` in dask (:issue:`2238`). + By `Joe Hamman `_. + +.. _whats-new.0.10.7: + +v0.10.7 (7 June 2018) +--------------------- + +Enhancements +~~~~~~~~~~~~ + +- Plot labels now make use of metadata that follow CF conventions + (:issue:`2135`). + By `Deepak Cherian `_ and `Ryan Abernathey `_. + +- Line plots now support facetting with ``row`` and ``col`` arguments + (:issue:`2107`). + By `Yohai Bar Sinai `_. + +- :py:meth:`~xarray.DataArray.interp` and :py:meth:`~xarray.Dataset.interp` + methods are newly added. + See :ref:`interpolating values with interp` for the detail. + (:issue:`2079`) + By `Keisuke Fujii `_. + +Bug fixes +~~~~~~~~~ + +- Fixed a bug in ``rasterio`` backend which prevented use with ``distributed``. + The ``rasterio`` backend now returns pickleable objects (:issue:`2021`). + By `Joe Hamman `_. + +.. 
_whats-new.0.10.6: + +v0.10.6 (31 May 2018) +--------------------- + +The minor release includes a number of bug-fixes and backwards compatible +enhancements. + +Enhancements +~~~~~~~~~~~~ + +- New PseudoNetCDF backend for many Atmospheric data formats including + GEOS-Chem, CAMx, NOAA arlpacked bit and many others. See + :ref:`io.PseudoNetCDF` for more details. + By `Barron Henderson `_. + +- The :py:class:`Dataset` constructor now aligns :py:class:`DataArray` + arguments in ``data_vars`` to indexes set explicitly in ``coords``, + where previously an error would be raised. + (:issue:`674`) + By `Maximilian Roos `_. + +- :py:meth:`~DataArray.sel`, :py:meth:`~DataArray.isel` & :py:meth:`~DataArray.reindex`, + (and their :py:class:`Dataset` counterparts) now support supplying a ``dict`` + as a first argument, as an alternative to the existing approach + of supplying `kwargs`. This allows for more robust behavior + of dimension names which conflict with other keyword names, or are + not strings. + By `Maximilian Roos `_. + +- :py:meth:`~DataArray.rename` now supports supplying ``**kwargs``, as an + alternative to the existing approach of supplying a ``dict`` as the + first argument. + By `Maximilian Roos `_. + +- :py:meth:`~DataArray.cumsum` and :py:meth:`~DataArray.cumprod` now support + aggregation over multiple dimensions at the same time. This is the default + behavior when dimensions are not specified (previously this raised an error). + By `Stephan Hoyer `_ + +- :py:meth:`DataArray.dot` and :py:func:`dot` are partly supported with older + dask<0.17.4. (related to :issue:`2203`) + By `Keisuke Fujii `_. + +- Xarray now uses `Versioneer `__ + to manage its version strings. (:issue:`1300`). + By `Joe Hamman `_. + +Bug fixes +~~~~~~~~~ + +- Fixed a regression in 0.10.4, where explicitly specifying ``dtype='S1'`` or + ``dtype=str`` in ``encoding`` with ``to_netcdf()`` raised an error + (:issue:`2149`). + `Stephan Hoyer `_ + +- :py:func:`apply_ufunc` now directly validates output variables + (:issue:`1931`). + By `Stephan Hoyer `_. + +- Fixed a bug where ``to_netcdf(..., unlimited_dims='bar')`` yielded NetCDF + files with spurious 0-length dimensions (i.e. ``b``, ``a``, and ``r``) + (:issue:`2134`). + By `Joe Hamman `_. + +- Removed spurious warnings with ``Dataset.update(Dataset)`` (:issue:`2161`) + and ``array.equals(array)`` when ``array`` contains ``NaT`` (:issue:`2162`). + By `Stephan Hoyer `_. + +- Aggregations with :py:meth:`Dataset.reduce` (including ``mean``, ``sum``, + etc) no longer drop unrelated coordinates (:issue:`1470`). Also fixed a + bug where non-scalar data-variables that did not include the aggregation + dimension were improperly skipped. + By `Stephan Hoyer `_ + +- Fix :meth:`~DataArray.stack` with non-unique coordinates on pandas 0.23 + (:issue:`2160`). + By `Stephan Hoyer `_ + +- Selecting data indexed by a length-1 ``CFTimeIndex`` with a slice of strings + now behaves as it does when using a length-1 ``DatetimeIndex`` (i.e. it no + longer falsely returns an empty array when the slice includes the value in + the index) (:issue:`2165`). + By `Spencer Clark `_. + +- Fix ``DataArray.groupby().reduce()`` mutating coordinates on the input array + when grouping over dimension coordinates with duplicated entries + (:issue:`2153`). + By `Stephan Hoyer `_ + +- Fix ``Dataset.to_netcdf()`` cannot create group with ``engine="h5netcdf"`` + (:issue:`2177`). + By `Stephan Hoyer `_ + +.. 
_whats-new.0.10.4: + +v0.10.4 (16 May 2018) +---------------------- + +The minor release includes a number of bug-fixes and backwards compatible +enhancements. A highlight is ``CFTimeIndex``, which offers support for +non-standard calendars used in climate modeling. + +Documentation +~~~~~~~~~~~~~ + +- New FAQ entry, :ref:`faq.other_projects`. + By `Deepak Cherian `_. +- :ref:`assigning_values` now includes examples on how to select and assign + values to a :py:class:`~xarray.DataArray` with ``.loc``. + By `Chiara Lepore `_. + +Enhancements +~~~~~~~~~~~~ + +- Add an option for using a ``CFTimeIndex`` for indexing times with + non-standard calendars and/or outside the Timestamp-valid range; this index + enables a subset of the functionality of a standard + ``pandas.DatetimeIndex``. + See :ref:`CFTimeIndex` for full details. + (:issue:`789`, :issue:`1084`, :issue:`1252`) + By `Spencer Clark `_ with help from + `Stephan Hoyer `_. +- Allow for serialization of ``cftime.datetime`` objects (:issue:`789`, + :issue:`1084`, :issue:`2008`, :issue:`1252`) using the standalone ``cftime`` + library. + By `Spencer Clark `_. +- Support writing lists of strings as netCDF attributes (:issue:`2044`). + By `Dan Nowacki `_. +- :py:meth:`~xarray.Dataset.to_netcdf` with ``engine='h5netcdf'`` now accepts h5py + encoding settings ``compression`` and ``compression_opts``, along with the + NetCDF4-Python style settings ``gzip=True`` and ``complevel``. + This allows using any compression plugin installed in hdf5, e.g. LZF + (:issue:`1536`). By `Guido Imperiale `_. +- :py:meth:`~xarray.dot` on dask-backed data will now call :func:`dask.array.einsum`. + This greatly boosts speed and allows chunking on the core dims. + The function now requires dask >= 0.17.3 to work on dask-backed data + (:issue:`2074`). By `Guido Imperiale `_. +- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change + the direction of the respective axes. + By `Deepak Cherian `_. + +- Added the ``parallel`` option to :py:func:`open_mfdataset`. This option uses + ``dask.delayed`` to parallelize the open and preprocessing steps within + ``open_mfdataset``. This is expected to provide performance improvements when + opening many files, particularly when used in conjunction with dask's + multiprocessing or distributed schedulers (:issue:`1981`). + By `Joe Hamman `_. + +- New ``compute`` option in :py:meth:`~xarray.Dataset.to_netcdf`, + :py:meth:`~xarray.Dataset.to_zarr`, and :py:func:`~xarray.save_mfdataset` to + allow for the lazy computation of netCDF and zarr stores. This feature is + currently only supported by the netCDF4 and zarr backends. (:issue:`1784`). + By `Joe Hamman `_. + + +Bug fixes +~~~~~~~~~ + +- ``ValueError`` is raised when coordinates with the wrong size are assigned to + a :py:class:`DataArray`. (:issue:`2112`) + By `Keisuke Fujii `_. +- Fixed a bug in :py:meth:`~xarary.DatasArray.rolling` with bottleneck. Also, + fixed a bug in rolling an integer dask array. (:issue:`2113`) + By `Keisuke Fujii `_. +- Fixed a bug where `keep_attrs=True` flag was neglected if + :py:func:`apply_ufunc` was used with :py:class:`Variable`. (:issue:`2114`) + By `Keisuke Fujii `_. +- When assigning a :py:class:`DataArray` to :py:class:`Dataset`, any conflicted + non-dimensional coordinates of the DataArray are now dropped. + (:issue:`2068`) + By `Keisuke Fujii `_. +- Better error handling in ``open_mfdataset`` (:issue:`2077`). + By `Stephan Hoyer `_. +- ``plot.line()`` does not call ``autofmt_xdate()`` anymore. 
Instead it changes + the rotation and horizontal alignment of labels without removing the x-axes of + any other subplots in the figure (if any). + By `Deepak Cherian `_. +- Colorbar limits are now determined by excluding ±Infs too. + By `Deepak Cherian `_. + By `Joe Hamman `_. +- Fixed ``to_iris`` to maintain lazy dask array after conversion (:issue:`2046`). + By `Alex Hilson `_ and `Stephan Hoyer `_. + +.. _whats-new.0.10.3: + +v0.10.3 (13 April 2018) +------------------------ + +The minor release includes a number of bug-fixes and backwards compatible enhancements. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.isin` and :py:meth:`~xarray.Dataset.isin` methods, + which test each value in the array for whether it is contained in the + supplied list, returning a bool array. See :ref:`selecting values with isin` + for full details. Similar to the ``np.isin`` function. + By `Maximilian Roos `_. +- Some speed improvement to construct :py:class:`~xarray.DataArrayRolling` + object (:issue:`1993`) + By `Keisuke Fujii `_. +- Handle variables with different values for ``missing_value`` and + ``_FillValue`` by masking values for both attributes; previously this + resulted in a ``ValueError``. (:issue:`2016`) + By `Ryan May `_. + +Bug fixes +~~~~~~~~~ + +- Fixed ``decode_cf`` function to operate lazily on dask arrays + (:issue:`1372`). By `Ryan Abernathey `_. +- Fixed labeled indexing with slice bounds given by xarray objects with + datetime64 or timedelta64 dtypes (:issue:`1240`). + By `Stephan Hoyer `_. +- Attempting to convert an xarray.Dataset into a numpy array now raises an + informative error message. + By `Stephan Hoyer `_. +- Fixed a bug in decode_cf_datetime where ``int32`` arrays weren't parsed + correctly (:issue:`2002`). + By `Fabien Maussion `_. +- When calling `xr.auto_combine()` or `xr.open_mfdataset()` with a `concat_dim`, + the resulting dataset will have that one-element dimension (it was + silently dropped, previously) (:issue:`1988`). + By `Ben Root `_. + +.. _whats-new.0.10.2: + +v0.10.2 (13 March 2018) +----------------------- + +The minor release includes a number of bug-fixes and enhancements, along with +one possibly **backwards incompatible change**. + +Backwards incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The addition of ``__array_ufunc__`` for xarray objects (see below) means that + NumPy `ufunc methods`_ (e.g., ``np.add.reduce``) that previously worked on + ``xarray.DataArray`` objects by converting them into NumPy arrays will now + raise ``NotImplementedError`` instead. In all cases, the work-around is + simple: convert your objects explicitly into NumPy arrays before calling the + ufunc (e.g., with ``.values``). + +.. _ufunc methods: https://docs.scipy.org/doc/numpy/reference/ufuncs.html#methods + +Enhancements +~~~~~~~~~~~~ + +- Added :py:func:`~xarray.dot`, equivalent to :py:func:`np.einsum`. + Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option, + which specifies the dimensions to sum over. + (:issue:`1951`) + By `Keisuke Fujii `_. + +- Support for writing xarray datasets to netCDF files (netcdf4 backend only) + when using the `dask.distributed `_ + scheduler (:issue:`1464`). + By `Joe Hamman `_. + +- Support lazy vectorized-indexing. After this change, flexible indexing such + as orthogonal/vectorized indexing, becomes possible for all the backend + arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`) + By `Keisuke Fujii `_. 
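The vectorized-indexing entry above refers to the pointwise indexing pattern sketched below on in-memory data; with this release the same call is also evaluated lazily for arrays backed by a file or by dask (a sketch, not taken from the patch):

.. ipython:: python

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(12).reshape(3, 4), dims=['x', 'y'])
    # the two indexers share a new 'points' dimension, so values are picked pointwise
    ix = xr.DataArray([0, 2], dims='points')
    iy = xr.DataArray([1, 3], dims='points')
    da.isel(x=ix, y=iy)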
+ +- Implemented NumPy's ``__array_ufunc__`` protocol for all xarray objects + (:issue:`1617`). This enables using NumPy ufuncs directly on + ``xarray.Dataset`` objects with recent versions of NumPy (v1.13 and newer): + + .. ipython:: python + + ds = xr.Dataset({'a': 1}) + np.sin(ds) + + This obliviates the need for the ``xarray.ufuncs`` module, which will be + deprecated in the future when xarray drops support for older versions of + NumPy. By `Stephan Hoyer `_. + +- Improve :py:func:`~xarray.DataArray.rolling` logic. + :py:func:`~xarray.DataArrayRolling` object now supports + :py:func:`~xarray.DataArrayRolling.construct` method that returns a view + of the DataArray / Dataset object with the rolling-window dimension added + to the last axis. This enables more flexible operation, such as strided + rolling, windowed rolling, ND-rolling, short-time FFT and convolution. + (:issue:`1831`, :issue:`1142`, :issue:`819`) + By `Keisuke Fujii `_. +- :py:func:`~plot.line()` learned to make plots with data on x-axis if so specified. (:issue:`575`) + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ + +- Raise an informative error message when using ``apply_ufunc`` with numpy + v1.11 (:issue:`1956`). + By `Stephan Hoyer `_. +- Fix the precision drop after indexing datetime64 arrays (:issue:`1932`). + By `Keisuke Fujii `_. +- Silenced irrelevant warnings issued by ``open_rasterio`` (:issue:`1964`). + By `Stephan Hoyer `_. +- Fix kwarg `colors` clashing with auto-inferred `cmap` (:issue:`1461`) + By `Deepak Cherian `_. +- Fix :py:func:`~xarray.plot.imshow` error when passed an RGB array with + size one in a spatial dimension. + By `Zac Hatfield-Dodds `_. + +.. _whats-new.0.10.1: + +v0.10.1 (25 February 2018) +-------------------------- + +The minor release includes a number of bug-fixes and backwards compatible enhancements. + Documentation ~~~~~~~~~~~~~ -- Added apply_ufunc example to toy weather data page (:issue:`1844`). - By `Liam Brannigan ` _. +- Added a new guide on :ref:`contributing` (:issue:`640`) + By `Joe Hamman `_. +- Added apply_ufunc example to :ref:`toy weather data` (:issue:`1844`). + By `Liam Brannigan `_. - New entry `Why don’t aggregations return Python scalars?` in the :doc:`faq` (:issue:`1726`). By `0x0L `_. Enhancements ~~~~~~~~~~~~ -- reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype`` +**New functions and methods**: + +- Added :py:meth:`DataArray.to_iris` and + :py:meth:`DataArray.from_iris` for + converting data arrays to and from Iris_ Cubes with the same data and coordinates + (:issue:`621` and :issue:`37`). + By `Neil Parley `_ and `Duncan Watson-Parris `_. +- Experimental support for using `Zarr`_ as storage layer for xarray + (:issue:`1223`). + By `Ryan Abernathey `_ and + `Joe Hamman `_. +- New :py:meth:`~xarray.DataArray.rank` on arrays and datasets. Requires + bottleneck (:issue:`1731`). + By `0x0L `_. +- ``.dt`` accessor can now ceil, floor and round timestamps to specified frequency. + By `Deepak Cherian `_. + +**Plotting enhancements**: + +- :func:`xarray.plot.imshow` now handles RGB and RGBA images. + Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``. + By `Zac Hatfield-Dodds `_. +- :py:func:`~plot.contourf()` learned to contour 2D variables that have both a + 1D coordinate (e.g. time) and a 2D coordinate (e.g. depth as a function of + time) (:issue:`1737`). + By `Deepak Cherian `_. +- :py:func:`~plot()` rotates x-axis ticks if x-axis is time. + By `Deepak Cherian `_. 
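The RGB/RGBA ``imshow`` entry above can be exercised with a small synthetic image (a sketch only; requires matplotlib, and the dimension names are illustrative):

.. ipython:: python

    import numpy as np
    import xarray as xr

    # a 10 x 20 RGB image; the length-3 dimension is interpreted as the color band
    rgb = xr.DataArray(np.random.rand(10, 20, 3), dims=['y', 'x', 'band'])
    rgb.plot.imshow()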
+- :py:func:`~plot.line()` can draw multiple lines if provided with a + 2D variable. + By `Deepak Cherian `_. + +**Other enhancements**: + +- Reduce methods such as :py:func:`DataArray.sum()` now handles object-type array. + + .. ipython:: python + + da = xr.DataArray(np.array([True, False, np.nan], dtype=object), dims='x') + da.sum() + + (:issue:`1866`) + By `Keisuke Fujii `_. +- Reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype`` arguments. (:issue:`1838`) By `Keisuke Fujii `_. - Added nodatavals attribute to DataArray when using :py:func:`~xarray.open_rasterio`. (:issue:`1736`). By `Alan Snow `_. -- :py:func:`~plot.contourf()` learned to contour 2D variables that have both a - 1D co-ordinate (e.g. time) and a 2D co-ordinate (e.g. depth as a function of - time) (:issue:`1737`). - By `Deepak Cherian `_. -- Added :py:meth:`DataArray.to_iris ` and :py:meth:`DataArray.from_iris ` for - converting data arrays to and from Iris_ Cubes with the same data and coordinates (:issue:`621` and :issue:`37`). - By `Neil Parley `_ and `Duncan Watson-Parris `_. - Use ``pandas.Grouper`` class in xarray resample methods rather than the deprecated ``pandas.TimeGrouper`` class (:issue:`1766`). By `Joe Hamman `_. -- Support for using `Zarr`_ as storage layer for xarray. (:issue:`1223`). - By `Ryan Abernathey `_ and - `Joe Hamman `_. -- Support for using `Zarr`_ as storage layer for xarray. - By `Ryan Abernathey `_. -- :func:`xarray.plot.imshow` now handles RGB and RGBA images. - Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``. - By `Zac Hatfield-Dodds `_. - Experimental support for parsing ENVI metadata to coordinates and attributes in :py:func:`xarray.open_rasterio`. By `Matti Eskelinen `_. -- :py:func:`~plot()` learned to rotate x-axis ticks if x-axis is time. - By `Deepak Cherian `_. -- :py:func:`~plot.line()` learned to draw multiple lines if provided with a - 2D variable. - By `Deepak Cherian `_. - Reduce memory usage when decoding a variable with a scale_factor, by converting 8-bit and 16-bit integers to float32 instead of float64 (:pull:`1840`), and keeping float16 and float32 as float32 (:issue:`1842`). Correspondingly, encoded variables may also be saved with a smaller dtype. By `Zac Hatfield-Dodds `_. +- Speed of reindexing/alignment with dask array is orders of magnitude faster + when inserting missing values (:issue:`1847`). + By `Stephan Hoyer `_. +- Fix ``axis`` keyword ignored when applying ``np.squeeze`` to ``DataArray`` (:issue:`1487`). + By `Florian Pinault `_. +- ``netcdf4-python`` has moved the its time handling in the ``netcdftime`` module to + a standalone package (`netcdftime`_). As such, xarray now considers `netcdftime`_ + an optional dependency. One benefit of this change is that it allows for + encoding/decoding of datetimes with non-standard calendars without the + ``netcdf4-python`` dependency (:issue:`1084`). + By `Joe Hamman `_. .. _Zarr: http://zarr.readthedocs.io/ .. _Iris: http://scitools.org.uk/iris +.. _netcdftime: https://unidata.github.io/netcdftime + **New functions/methods** - New :py:meth:`~xarray.DataArray.rank` on arrays and datasets. Requires @@ -78,9 +797,15 @@ Enhancements Bug fixes ~~~~~~~~~ +- Rolling aggregation with ``center=True`` option now gives the same result + with pandas including the last element (:issue:`1046`). + By `Keisuke Fujii `_. + +- Support indexing with a 0d-np.ndarray (:issue:`1921`). + By `Keisuke Fujii `_. 
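A short sketch of the rolling ``center=True`` fix and the 0-d ``np.ndarray`` indexing fix noted just above (illustrative values only):

.. ipython:: python

    import numpy as np
    import pandas as pd
    import xarray as xr

    da = xr.DataArray(np.arange(5.0), dims='x')
    # the centered rolling mean now matches pandas, including the final element
    da.rolling(x=3, center=True).mean()
    pd.Series(np.arange(5.0)).rolling(3, center=True).mean()
    # indexing with a 0-d integer ndarray is also supported
    da[np.array(2)]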
- Added warning in api.py of a netCDF4 bug that occurs when the filepath has 88 characters (:issue:`1745`). - By `Liam Brannigan ` _. + By `Liam Brannigan `_. - Fixed encoding of multi-dimensional coordinates in :py:meth:`~Dataset.to_netcdf` (:issue:`1763`). By `Mike Neish `_. @@ -102,7 +827,7 @@ Bug fixes with size one in some dimension can now be plotted, which is good for exploring satellite imagery (:issue:`1780`). By `Zac Hatfield-Dodds `_. -- Fixed ``UnboundLocalError`` when opening netCDF file `` (:issue:`1781`). +- Fixed ``UnboundLocalError`` when opening netCDF file (:issue:`1781`). By `Stephan Hoyer `_. - The ``variables``, ``attrs``, and ``dimensions`` properties have been deprecated as part of a bug fix addressing an issue where backends were @@ -112,9 +837,28 @@ Bug fixes - Compatibility fixes to plotting module for Numpy 1.14 and Pandas 0.22 (:issue:`1813`). By `Joe Hamman `_. +- Bug fix in encoding coordinates with ``{'_FillValue': None}`` in netCDF + metadata (:issue:`1865`). + By `Chris Roth `_. - Fix indexing with lists for arrays loaded from netCDF files with ``engine='h5netcdf'`` (:issue:`1864`). By `Stephan Hoyer `_. +- Corrected a bug with incorrect coordinates for non-georeferenced geotiff + files (:issue:`1686`). Internally, we now use the rasterio coordinate + transform tool instead of doing the computations ourselves. A + ``parse_coordinates`` kwarg has been added to :py:func:`~open_rasterio` + (set to ``True`` by default). + By `Fabien Maussion `_. +- The colors of discrete colormaps are now the same regardless of whether `seaborn` + is installed (:issue:`1896`). + By `Fabien Maussion `_. +- Fixed dtype promotion rules in :py:func:`where` and :py:func:`concat` to + match pandas (:issue:`1847`). A combination of strings/numbers or + unicode/bytes now promotes to object dtype, instead of strings or unicode. + By `Stephan Hoyer `_. +- Fixed bug where :py:meth:`~xarray.DataArray.isnull` was loading data + stored as dask arrays (:issue:`1937`). + By `Joe Hamman `_. .. _whats-new.0.10.0: @@ -451,6 +1195,9 @@ Bug fixes ``apionly`` module was deprecated. (:issue:`1633`). By `Joe Hamman `_. +- Fix COMPAT: MultiIndex checking is fragile + (:issue:`1833`). By `Florian Pinault `_. + - Fix ``rasterio`` backend for Rasterio versions 1.0alpha10 and newer. (:issue:`1641`). By `Chris Holden `_. @@ -547,7 +1294,7 @@ Enhancements By `Stephan Hoyer `_. - New function :py:func:`~xarray.open_rasterio` for opening raster files with - the `rasterio `_ library. + the `rasterio `_ library. See :ref:`the docs ` for details. By `Joe Hamman `_, `Nic Wayand `_ and diff --git a/examples/xarray_multidimensional_coords.ipynb b/examples/xarray_multidimensional_coords.ipynb index bed7e8b962f..6bd942c5ba2 100644 --- a/examples/xarray_multidimensional_coords.ipynb +++ b/examples/xarray_multidimensional_coords.ipynb @@ -6,7 +6,7 @@ "source": [ "# Working with Multidimensional Coordinates\n", "\n", - "Author: [Ryan Abernathey](http://github.org/rabernat)\n", + "Author: [Ryan Abernathey](https://github.com/rabernat)\n", "\n", "Many datasets have _physical coordinates_ which differ from their _logical coordinates_. Xarray provides several ways to plot and analyze such datasets." ] diff --git a/licenses/DASK_LICENSE b/licenses/DASK_LICENSE new file mode 100644 index 00000000000..893bddfb933 --- /dev/null +++ b/licenses/DASK_LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2014-2018, Anaconda, Inc. and contributors +All rights reserved.
+ +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +Neither the name of Anaconda nor the names of any contributors may be used to +endorse or promote products derived from this software without specific prior +written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. diff --git a/properties/README.md b/properties/README.md new file mode 100644 index 00000000000..711062a2473 --- /dev/null +++ b/properties/README.md @@ -0,0 +1,22 @@ +# Property-based tests using Hypothesis + +This directory contains property-based tests using a library +called [Hypothesis](https://github.com/HypothesisWorks/hypothesis-python). + +The property tests for Xarray are a work in progress - more are always welcome. +They are stored in a separate directory because they tend to run more examples +and thus take longer, and so that local development can run a test suite +without needing to `pip install hypothesis`. + +## Hang on, "property-based" tests? + +Instead of making assertions about operations on a particular piece of +data, you use Hypothesis to describe a *kind* of data, then make assertions +that should hold for *any* example of this kind. + +For example: "given a 2d ndarray of dtype uint8 `arr`, +`xr.DataArray(arr).plot.imshow()` never raises an exception". + +Hypothesis will then try many random examples, and report a minimised +failing input for each error it finds. +[See the docs for more info.](https://hypothesis.readthedocs.io/en/master/) diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py new file mode 100644 index 00000000000..13f63f259cf --- /dev/null +++ b/properties/test_encode_decode.py @@ -0,0 +1,47 @@ +""" +Property-based tests for encoding/decoding methods. + +These ones pass, just as you'd hope! 
+ +""" +from __future__ import absolute_import, division, print_function + +import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import given, settings + +import xarray as xr + +# Run for a while - arrays are a bigger search space than usual +settings.register_profile("ci", deadline=None) +settings.load_profile("ci") + + +an_array = npst.arrays( + dtype=st.one_of( + npst.unsigned_integer_dtypes(), + npst.integer_dtypes(), + npst.floating_dtypes(), + ), + shape=npst.array_shapes(max_side=3), # max_side specified for performance +) + + +@given(st.data(), an_array) +def test_CFMask_coder_roundtrip(data, arr): + names = data.draw(st.lists(st.text(), min_size=arr.ndim, + max_size=arr.ndim, unique=True).map(tuple)) + original = xr.Variable(names, arr) + coder = xr.coding.variables.CFMaskCoder() + roundtripped = coder.decode(coder.encode(original)) + xr.testing.assert_identical(original, roundtripped) + + +@given(st.data(), an_array) +def test_CFScaleOffset_coder_roundtrip(data, arr): + names = data.draw(st.lists(st.text(), min_size=arr.ndim, + max_size=arr.ndim, unique=True).map(tuple)) + original = xr.Variable(names, arr) + coder = xr.coding.variables.CFScaleOffsetCoder() + roundtripped = coder.decode(coder.encode(original)) + xr.testing.assert_identical(original, roundtripped) diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 0132dbd4752..00000000000 --- a/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -addopts = -p no:hypothesis diff --git a/readthedocs.yml b/readthedocs.yml index 0129abe15aa..8e9c09c9414 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,5 +1,8 @@ +build: + image: latest conda: file: doc/environment.yml python: - version: 3 - setup_py_install: true + version: 3.6 + setup_py_install: true +formats: [] diff --git a/setup.cfg b/setup.cfg index d2f336aa1d0..17f24b3f1ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,27 @@ universal = 1 [tool:pytest] python_files=test_*.py +testpaths=xarray/tests [flake8] max-line-length=79 +ignore= + W503 +exclude= + doc/ + +[isort] +default_section=THIRDPARTY +known_first_party=xarray +multi_line_output=4 + +[versioneer] +VCS = git +style = pep440 +versionfile_source = xarray/_version.py +versionfile_build = xarray/_version.py +tag_prefix = v +parentdir_prefix = xarray- + +[aliases] +test = pytest diff --git a/setup.py b/setup.py index ccffc6369e8..3b56d9265af 100644 --- a/setup.py +++ b/setup.py @@ -1,19 +1,8 @@ #!/usr/bin/env python -import os -import re import sys -import warnings - -from setuptools import setup, find_packages -from setuptools import Command - -MAJOR = 0 -MINOR = 10 -MICRO = 0 -ISRELEASED = False -VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO) -QUALIFIER = '' +import versioneer +from setuptools import find_packages, setup DISTNAME = 'xarray' LICENSE = 'Apache' @@ -29,13 +18,13 @@ 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Scientific/Engineering', ] -INSTALL_REQUIRES = ['numpy >= 1.11', 'pandas >= 0.18.0'] +INSTALL_REQUIRES = ['numpy >= 1.12', 'pandas >= 0.19.2'] TESTS_REQUIRE = ['pytest >= 2.7.1'] if sys.version_info[0] < 3: TESTS_REQUIRE.append('mock') @@ -64,79 +53,12 @@ - Issue tracker: http://github.com/pydata/xarray/issues - Source code: http://github.com/pydata/xarray 
- SciPy2015 talk: https://www.youtube.com/watch?v=X0pAhJgySxk -""" - -# Code to extract and write the version copied from pandas. -# Used under the terms of pandas's license, see licenses/PANDAS_LICENSE. -FULLVERSION = VERSION -write_version = True - -if not ISRELEASED: - import subprocess - FULLVERSION += '.dev' - - pipe = None - for cmd in ['git', 'git.cmd']: - try: - pipe = subprocess.Popen( - [cmd, "describe", "--always", "--match", "v[0-9]*"], - stdout=subprocess.PIPE) - (so, serr) = pipe.communicate() - if pipe.returncode == 0: - break - except: - pass - - if pipe is None or pipe.returncode != 0: - # no git, or not in git dir - if os.path.exists('xarray/version.py'): - warnings.warn("WARNING: Couldn't get git revision, using existing xarray/version.py") - write_version = False - else: - warnings.warn("WARNING: Couldn't get git revision, using generic version string") - else: - # have git, in git dir, but may have used a shallow clone (travis does this) - rev = so.strip() - # makes distutils blow up on Python 2.7 - if sys.version_info[0] >= 3: - rev = rev.decode('ascii') - - if not rev.startswith('v') and re.match("[a-zA-Z0-9]{7,9}", rev): - # partial clone, manually construct version string - # this is the format before we started using git-describe - # to get an ordering on dev version strings. - rev = "v%s+dev.%s" % (VERSION, rev) - - # Strip leading v from tags format "vx.y.z" to get th version string - FULLVERSION = rev.lstrip('v') - - # make sure we respect PEP 440 - FULLVERSION = FULLVERSION.replace("-", "+dev", 1).replace("-", ".") - -else: - FULLVERSION += QUALIFIER - - -def write_version_py(filename=None): - cnt = """\ -version = '%s' -short_version = '%s' -""" - if not filename: - filename = os.path.join( - os.path.dirname(__file__), 'xarray', 'version.py') - - a = open(filename, 'w') - try: - a.write(cnt % (FULLVERSION, VERSION)) - finally: - a.close() +""" # noqa -if write_version: - write_version_py() setup(name=DISTNAME, - version=FULLVERSION, + version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), license=LICENSE, author=AUTHOR, author_email=AUTHOR_EMAIL, @@ -146,5 +68,6 @@ def write_version_py(filename=None): install_requires=INSTALL_REQUIRES, tests_require=TESTS_REQUIRE, url=URL, + python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', packages=find_packages(), - package_data={'xarray': ['tests/data/*', 'plot/default_colormap.csv']}) + package_data={'xarray': ['tests/data/*']}) diff --git a/versioneer.py b/versioneer.py new file mode 100644 index 00000000000..dffd66b69a6 --- /dev/null +++ b/versioneer.py @@ -0,0 +1,1824 @@ + +# Version: 0.18 + +"""The Versioneer - like a rocketeer, but for versions. + +The Versioneer +============== + +* like a rocketeer, but for versions! +* https://github.com/warner/python-versioneer +* Brian Warner +* License: Public Domain +* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy +* [![Latest Version] +(https://pypip.in/version/versioneer/badge.svg?style=flat) +](https://pypi.python.org/pypi/versioneer/) +* [![Build Status] +(https://travis-ci.org/warner/python-versioneer.png?branch=master) +](https://travis-ci.org/warner/python-versioneer) + +This is a tool for managing a recorded version number in distutils-based +python projects. The goal is to remove the tedious and error-prone "update +the embedded version string" step from your release process. 
Making a new +release should be as easy as recording a new tag in your version-control +system, and maybe making new tarballs. + + +## Quick Install + +* `pip install versioneer` to somewhere to your $PATH +* add a `[versioneer]` section to your setup.cfg (see below) +* run `versioneer install` in your source tree, commit the results + +## Version Identifiers + +Source trees come from a variety of places: + +* a version-control system checkout (mostly used by developers) +* a nightly tarball, produced by build automation +* a snapshot tarball, produced by a web-based VCS browser, like github's + "tarball from tag" feature +* a release tarball, produced by "setup.py sdist", distributed through PyPI + +Within each source tree, the version identifier (either a string or a number, +this tool is format-agnostic) can come from a variety of places: + +* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows + about recent "tags" and an absolute revision-id +* the name of the directory into which the tarball was unpacked +* an expanded VCS keyword ($Id$, etc) +* a `_version.py` created by some earlier build step + +For released software, the version identifier is closely related to a VCS +tag. Some projects use tag names that include more than just the version +string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool +needs to strip the tag prefix to extract the version identifier. For +unreleased software (between tags), the version identifier should provide +enough information to help developers recreate the same tree, while also +giving them an idea of roughly how old the tree is (after version 1.2, before +version 1.3). Many VCS systems can report a description that captures this, +for example `git describe --tags --dirty --always` reports things like +"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the +0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has +uncommitted changes. + +The version identifier is used for multiple purposes: + +* to allow the module to self-identify its version: `myproject.__version__` +* to choose a name and prefix for a 'setup.py sdist' tarball + +## Theory of Operation + +Versioneer works by adding a special `_version.py` file into your source +tree, where your `__init__.py` can import it. This `_version.py` knows how to +dynamically ask the VCS tool for version information at import time. + +`_version.py` also contains `$Revision$` markers, and the installation +process marks `_version.py` to have this marker rewritten with a tag name +during the `git archive` command. As a result, generated tarballs will +contain enough information to get the proper version. + +To allow `setup.py` to compute a version too, a `versioneer.py` is added to +the top level of your source tree, next to `setup.py` and the `setup.cfg` +that configures it. This overrides several distutils/setuptools commands to +compute the version when invoked, and changes `setup.py build` and `setup.py +sdist` to replace `_version.py` with a small static file that contains just +the generated version data. + +## Installation + +See [INSTALL.md](./INSTALL.md) for detailed installation instructions. + +## Version-String Flavors + +Code which uses Versioneer can learn about its version string at runtime by +importing `_version` from your main `__init__.py` file and running the +`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can +import the top-level `versioneer.py` and run `get_versions()`. 
+ +Both functions return a dictionary with different flavors of version +information: + +* `['version']`: A condensed version string, rendered using the selected + style. This is the most commonly used value for the project's version + string. The default "pep440" style yields strings like `0.11`, + `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section + below for alternative styles. + +* `['full-revisionid']`: detailed revision identifier. For Git, this is the + full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". + +* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the + commit date in ISO 8601 format. This will be None if the date is not + available. + +* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that + this is only accurate if run in a VCS checkout, otherwise it is likely to + be False or None + +* `['error']`: if the version string could not be computed, this will be set + to a string describing the problem, otherwise it will be None. It may be + useful to throw an exception in setup.py if this is set, to avoid e.g. + creating tarballs with a version string of "unknown". + +Some variants are more useful than others. Including `full-revisionid` in a +bug report should allow developers to reconstruct the exact code being tested +(or indicate the presence of local changes that should be shared with the +developers). `version` is suitable for display in an "about" box or a CLI +`--version` output: it can be easily compared against release notes and lists +of bugs fixed in various releases. + +The installer adds the following text to your `__init__.py` to place a basic +version in `YOURPROJECT.__version__`: + + from ._version import get_versions + __version__ = get_versions()['version'] + del get_versions + +## Styles + +The setup.cfg `style=` configuration controls how the VCS information is +rendered into a version string. + +The default style, "pep440", produces a PEP440-compliant string, equal to the +un-prefixed tag name for actual releases, and containing an additional "local +version" section with more detail for in-between builds. For Git, this is +TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags +--dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the +tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and +that this commit is two revisions ("+2") beyond the "0.11" tag. For released +software (exactly equal to a known tag), the identifier will only contain the +stripped tag, e.g. "0.11". + +Other styles are available. See [details.md](details.md) in the Versioneer +source tree for descriptions. + +## Debugging + +Versioneer tries to avoid fatal errors: if something goes wrong, it will tend +to return a version of "0+unknown". To investigate the problem, run `setup.py +version`, which will run the version-lookup code in a verbose mode, and will +display the full contents of `get_versions()` (including the `error` string, +which may help identify what went wrong). + +## Known Limitations + +Some situations are known to cause problems for Versioneer. This details the +most significant ones. More can be found on Github +[issues page](https://github.com/warner/python-versioneer/issues). + +### Subprojects + +Versioneer has limited support for source trees in which `setup.py` is not in +the root directory (e.g. `setup.py` and `.git/` are *not* siblings). 
The are +two common reasons why `setup.py` might not be in the root: + +* Source trees which contain multiple subprojects, such as + [Buildbot](https://github.com/buildbot/buildbot), which contains both + "master" and "slave" subprojects, each with their own `setup.py`, + `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI + distributions (and upload multiple independently-installable tarballs). +* Source trees whose main purpose is to contain a C library, but which also + provide bindings to Python (and perhaps other langauges) in subdirectories. + +Versioneer will look for `.git` in parent directories, and most operations +should get the right version string. However `pip` and `setuptools` have bugs +and implementation details which frequently cause `pip install .` from a +subproject directory to fail to find a correct version string (so it usually +defaults to `0+unknown`). + +`pip install --editable .` should work correctly. `setup.py install` might +work too. + +Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in +some later version. + +[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +this issue. The discussion in +[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +issue from the Versioneer side in more detail. +[pip PR#3176](https://github.com/pypa/pip/pull/3176) and +[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve +pip to let Versioneer work correctly. + +Versioneer-0.16 and earlier only looked for a `.git` directory next to the +`setup.cfg`, so subprojects were completely unsupported with those releases. + +### Editable installs with setuptools <= 18.5 + +`setup.py develop` and `pip install --editable .` allow you to install a +project into a virtualenv once, then continue editing the source code (and +test) without re-installing after every change. + +"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a +convenient way to specify executable scripts that should be installed along +with the python package. + +These both work as expected when using modern setuptools. When using +setuptools-18.5 or earlier, however, certain operations will cause +`pkg_resources.DistributionNotFound` errors when running the entrypoint +script, which must be resolved by re-installing the package. This happens +when the install happens with one version, then the egg_info data is +regenerated while a different version is checked out. Many setup.py commands +cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into +a different virtualenv), so this can be surprising. + +[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +this one, but upgrading to a newer version of setuptools should probably +resolve it. + +### Unicode version strings + +While Versioneer works (and is continually tested) with both Python 2 and +Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. +Newer releases probably generate unicode version strings on py2. It's not +clear that this is wrong, but it may be surprising for applications when then +write these strings to a network connection or include them in bytes-oriented +APIs like cryptographic checksums. + +[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates +this question. 
+ + +## Updating Versioneer + +To upgrade your project to a new release of Versioneer, do the following: + +* install the new Versioneer (`pip install -U versioneer` or equivalent) +* edit `setup.cfg`, if necessary, to include any new configuration settings + indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install` in your source tree, to replace + `SRC/_version.py` +* commit any changed files + +## Future Directions + +This tool is designed to make it easily extended to other version-control +systems: all VCS-specific components are in separate directories like +src/git/ . The top-level `versioneer.py` script is assembled from these +components by running make-versioneer.py . In the future, make-versioneer.py +will take a VCS name as an argument, and will construct a version of +`versioneer.py` that is specific to the given VCS. It might also take the +configuration arguments that are currently provided manually during +installation by editing setup.py . Alternatively, it might go the other +direction and include code from all supported VCS systems, reducing the +number of intermediate scripts. + + +## License + +To make Versioneer easier to embed, all its code is dedicated to the public +domain. The `_version.py` that it creates is also in the public domain. +Specifically, both are released under the Creative Commons "Public Domain +Dedication" license (CC0-1.0), as described in +https://creativecommons.org/publicdomain/zero/1.0/ . + +""" + +from __future__ import print_function + +import errno +import json +import os +import re +import subprocess +import sys + +try: + import configparser +except ImportError: + import ConfigParser as configparser + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_root(): + """Get the project root directory. + + We require that all commands are run from the project root, i.e. the + directory that contains setup.py, setup.cfg, and versioneer.py . + """ + root = os.path.realpath(os.path.abspath(os.getcwd())) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + # allow 'python path/to/setup.py COMMAND' + root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") + raise VersioneerBadRootError(err) + try: + # Certain runtime workflows (setup.py install/develop in a setuptools + # tree) execute all dependencies in a single python process, so + # "versioneer" may be imported multiple times, and python's shared + # module-import table will cache the first one. So we can't use + # os.path.dirname(__file__), as that will find whichever + # versioneer.py was first imported, even in later projects. 
+ me = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(me)[0]) + vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) + if me_dir != vsr_dir: + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py)) + except NameError: + pass + return root + + +def get_config_from_root(root): + """Read the project setup.cfg file to determine Versioneer config.""" + # This might raise EnvironmentError (if setup.cfg is missing), or + # configparser.NoSectionError (if it lacks a [versioneer] section), or + # configparser.NoOptionError (if it lacks "VCS="). See the docstring at + # the top of versioneer.py for instructions on writing your setup.cfg . + setup_cfg = os.path.join(root, "setup.cfg") + parser = configparser.SafeConfigParser() + with open(setup_cfg, "r") as f: + parser.readfp(f) + VCS = parser.get("versioneer", "VCS") # mandatory + + def get(parser, name): + if parser.has_option("versioneer", name): + return parser.get("versioneer", name) + return None + cfg = VersioneerConfig() + cfg.VCS = VCS + cfg.style = get(parser, "style") or "" + cfg.versionfile_source = get(parser, "versionfile_source") + cfg.versionfile_build = get(parser, "versionfile_build") + cfg.tag_prefix = get(parser, "tag_prefix") + if cfg.tag_prefix in ("''", '""'): + cfg.tag_prefix = "" + cfg.parentdir_prefix = get(parser, "parentdir_prefix") + cfg.verbose = get(parser, "verbose") + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +# these dictionaries contain VCS-specific tools +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, p.returncode + return stdout, p.returncode + + +LONG_VERSION_PY['git'] = ''' +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. 
Generated by +# versioneer-0.18 (https://github.com/warner/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" + git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" + git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "%(STYLE)s" + cfg.tag_prefix = "%(TAG_PREFIX)s" + cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" + cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %%s" %% dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %%s" %% (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %%s (error)" %% dispcmd) + print("stdout was %%s" %% stdout) + return None, p.returncode + return stdout, p.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. 
We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %%s but none started with prefix %%s" %% + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %%d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%%s', no digits" %% ",".join(refs - tags)) + if verbose: + print("likely tags: %%s" %% ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. 
"2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %%s" %% r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %%s not under git control" %% root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%%s*" %% tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%%s'" + %% describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%%s' doesn't start with prefix '%%s'" + print(fmt %% (full_tag, tag_prefix)) + pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" + %% (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%%d" %% pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%%d" %% pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 
0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%%s'" %% style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. 
+ for i in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} +''' + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. 
"2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def do_vcs_install(manifest_in, versionfile_source, ipy): + """Git-specific installation logic for Versioneer. + + For Git, this means creating/changing .gitattributes to mark _version.py + for export-subst keyword substitution. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + files = [manifest_in, versionfile_source] + if ipy: + files.append(ipy) + try: + me = __file__ + if me.endswith(".pyc") or me.endswith(".pyo"): + me = os.path.splitext(me)[0] + ".py" + versioneer_file = os.path.relpath(me) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) + present = False + try: + f = open(".gitattributes", "r") + for line in f.readlines(): + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + f.close() + except EnvironmentError: + pass + if not present: + f = open(".gitattributes", "a+") + f.write("%s export-subst\n" % versionfile_source) + f.close() + files.append(".gitattributes") + run_command(GITS, ["add", "--"] + files) + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +SHORT_VERSION_PY = """ +# This file was generated by 'versioneer.py' (0.18) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. 
+ +import json + +version_json = ''' +%s +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) +""" + + +def versions_from_file(filename): + """Try to determine the version from _version.py if present.""" + try: + with open(filename) as f: + contents = f.read() + except EnvironmentError: + raise NotThisMethod("unable to read _version.py") + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + raise NotThisMethod("no version_json in _version.py") + return json.loads(mo.group(1)) + + +def write_to_version_file(filename, versions): + """Write the given version number to the given _version.py file.""" + os.unlink(filename) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) + with open(filename, "w") as f: + f.write(SHORT_VERSION_PY % contents) + + print("set %s to '%s'" % (filename, versions["version"])) + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 
0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +class VersioneerBadRootError(Exception): + """The project root directory is unknown or missing key files.""" + + +def get_versions(verbose=False): + """Get the project version from whatever source is available. + + Returns dict with two keys: 'version' and 'full'. + """ + if "versioneer" in sys.modules: + # see the discussion in cmdclass.py:get_cmdclass() + del sys.modules["versioneer"] + + root = get_root() + cfg = get_config_from_root(root) + + assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" + handlers = HANDLERS.get(cfg.VCS) + assert handlers, "unrecognized VCS '%s'" % cfg.VCS + verbose = verbose or cfg.verbose + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" + assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" + + versionfile_abs = os.path.join(root, cfg.versionfile_source) + + # extract version from first of: _version.py, VCS command (e.g. 'git + # describe'), parentdir. This is meant to work for developers using a + # source checkout, for users of a tarball created by 'setup.py sdist', + # and for users of a tarball/zipball created by 'git archive' or github's + # download-from-tag feature or the equivalent in other VCSes. 
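To tie the helpers above together, here is a rough end-to-end sketch of how a made-up 'git describe' result becomes a pep440 version string, using the same regex as git_pieces_from_vcs() and the render_pep440() defined above; the tag, distance and hash are hypothetical, and the tag prefix is assumed to be "v".

import re

# Made-up 'git describe --tags --dirty --always --long' output.
describe_out = "v0.10.8-87-g1a2b3c4-dirty"
git_describe = describe_out[:describe_out.rindex("-dirty")]
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
pieces = {"closest-tag": mo.group(1)[len("v"):],   # strip the assumed "v" prefix
          "distance": int(mo.group(2)),
          "short": mo.group(3),
          "long": "1a2b3c4" + "0" * 33,            # stand-in full revision id
          "dirty": describe_out.endswith("-dirty"),
          "error": None}
print(render_pep440(pieces))   # -> "0.10.8+87.g1a2b3c4.dirty"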
+ + get_keywords_f = handlers.get("get_keywords") + from_keywords_f = handlers.get("keywords") + if get_keywords_f and from_keywords_f: + try: + keywords = get_keywords_f(versionfile_abs) + ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) + if verbose: + print("got version from expanded keyword %s" % ver) + return ver + except NotThisMethod: + pass + + try: + ver = versions_from_file(versionfile_abs) + if verbose: + print("got version from file %s %s" % (versionfile_abs, ver)) + return ver + except NotThisMethod: + pass + + from_vcs_f = handlers.get("pieces_from_vcs") + if from_vcs_f: + try: + pieces = from_vcs_f(cfg.tag_prefix, root, verbose) + ver = render(pieces, cfg.style) + if verbose: + print("got version from VCS %s" % ver) + return ver + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + if verbose: + print("got version from parentdir %s" % ver) + return ver + except NotThisMethod: + pass + + if verbose: + print("unable to compute version") + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, "error": "unable to compute version", + "date": None} + + +def get_version(): + """Get the short version string for this project.""" + return get_versions()["version"] + + +def get_cmdclass(): + """Get the custom setuptools/distutils subclasses used by Versioneer.""" + if "versioneer" in sys.modules: + del sys.modules["versioneer"] + # this fixes the "python setup.py develop" case (also 'install' and + # 'easy_install .'), in which subdependencies of the main project are + # built (using setup.py bdist_egg) in the same python process. Assume + # a main project A and a dependency B, which use different versions + # of Versioneer. A's setup.py imports A's Versioneer, leaving it in + # sys.modules by the time B's setup.py is executed, causing B to run + # with the wrong versioneer. Setuptools wraps the sub-dep builds in a + # sandbox that restores sys.modules to it's pre-build state, so the + # parent is protected against the child's "import versioneer". By + # removing ourselves from sys.modules here, before the child build + # happens, we protect the child from the parent's versioneer too. + # Also see https://github.com/warner/python-versioneer/issues/52 + + cmds = {} + + # we add "version" to both distutils and setuptools + from distutils.core import Command + + class cmd_version(Command): + description = "report generated version string" + user_options = [] + boolean_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + vers = get_versions(verbose=True) + print("Version: %s" % vers["version"]) + print(" full-revisionid: %s" % vers.get("full-revisionid")) + print(" dirty: %s" % vers.get("dirty")) + print(" date: %s" % vers.get("date")) + if vers["error"]: + print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version + + # we override "build_py" in both distutils and setuptools + # + # most invocation pathways end up running build_py: + # distutils/build -> build_py + # distutils/install -> distutils/build ->.. + # setuptools/bdist_wheel -> distutils/install ->.. + # setuptools/bdist_egg -> distutils/install_lib -> build_py + # setuptools/install -> bdist_egg ->.. + # setuptools/develop -> ? 
+ # pip install: + # copies source tree to a tempdir before running egg_info/etc + # if .git isn't copied too, 'git describe' will fail + # then does setup.py bdist_wheel, or sometimes setup.py install + # setup.py egg_info -> ? + + # we override different "build_py" commands for both environments + if "setuptools" in sys.modules: + from setuptools.command.build_py import build_py as _build_py + else: + from distutils.command.build_py import build_py as _build_py + + class cmd_build_py(_build_py): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_py.run(self) + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if cfg.versionfile_build: + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py + + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string + # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. + # setup(console=[{ + # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION + # "product_version": versioneer.get_version(), + # ... + + class cmd_build_exe(_build_exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _build_exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["build_exe"] = cmd_build_exe + del cmds["build_py"] + + if 'py2exe' in sys.modules: # py2exe enabled? 
+ try: + from py2exe.distutils_buildexe import py2exe as _py2exe # py3 + except ImportError: + from py2exe.build_exe import py2exe as _py2exe # py2 + + class cmd_py2exe(_py2exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _py2exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["py2exe"] = cmd_py2exe + + # we override different "sdist" commands for both environments + if "setuptools" in sys.modules: + from setuptools.command.sdist import sdist as _sdist + else: + from distutils.command.sdist import sdist as _sdist + + class cmd_sdist(_sdist): + def run(self): + versions = get_versions() + self._versioneer_generated_versions = versions + # unless we update this, the command will keep using the old + # version + self.distribution.metadata.version = versions["version"] + return _sdist.run(self) + + def make_release_tree(self, base_dir, files): + root = get_root() + cfg = get_config_from_root(root) + _sdist.make_release_tree(self, base_dir, files) + # now locate _version.py in the new base_dir directory + # (remembering that it may be a hardlink) and replace it with an + # updated value + target_versionfile = os.path.join(base_dir, cfg.versionfile_source) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, + self._versioneer_generated_versions) + cmds["sdist"] = cmd_sdist + + return cmds + + +CONFIG_ERROR = """ +setup.cfg is missing the necessary Versioneer configuration. You need +a section like: + + [versioneer] + VCS = git + style = pep440 + versionfile_source = src/myproject/_version.py + versionfile_build = myproject/_version.py + tag_prefix = + parentdir_prefix = myproject- + +You will also need to edit your setup.py to use the results: + + import versioneer + setup(version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), ...) + +Please read the docstring in ./versioneer.py for configuration instructions, +edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. +""" + +SAMPLE_CONFIG = """ +# See the docstring in versioneer.py for instructions. Note that you must +# re-run 'versioneer.py setup' after changing this section, and commit the +# resulting files. 
+ +[versioneer] +#VCS = git +#style = pep440 +#versionfile_source = +#versionfile_build = +#tag_prefix = +#parentdir_prefix = + +""" + +INIT_PY_SNIPPET = """ +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions +""" + + +def do_setup(): + """Main VCS-independent setup function for installing Versioneer.""" + root = get_root() + try: + cfg = get_config_from_root(root) + except (EnvironmentError, configparser.NoSectionError, + configparser.NoOptionError) as e: + if isinstance(e, (EnvironmentError, configparser.NoSectionError)): + print("Adding sample versioneer config to setup.cfg", + file=sys.stderr) + with open(os.path.join(root, "setup.cfg"), "a") as f: + f.write(SAMPLE_CONFIG) + print(CONFIG_ERROR, file=sys.stderr) + return 1 + + print(" creating %s" % cfg.versionfile_source) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), + "__init__.py") + if os.path.exists(ipy): + try: + with open(ipy, "r") as f: + old = f.read() + except EnvironmentError: + old = "" + if INIT_PY_SNIPPET not in old: + print(" appending to %s" % ipy) + with open(ipy, "a") as f: + f.write(INIT_PY_SNIPPET) + else: + print(" %s unmodified" % ipy) + else: + print(" %s doesn't exist, ok" % ipy) + ipy = None + + # Make sure both the top-level "versioneer.py" and versionfile_source + # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so + # they'll be copied into source distributions. Pip won't be able to + # install the package without this. + manifest_in = os.path.join(root, "MANIFEST.in") + simple_includes = set() + try: + with open(manifest_in, "r") as f: + for line in f: + if line.startswith("include "): + for include in line.split()[1:]: + simple_includes.add(include) + except EnvironmentError: + pass + # That doesn't cover everything MANIFEST.in can do + # (http://docs.python.org/2/distutils/sourcedist.html#commands), so + # it might give some false negatives. Appending redundant 'include' + # lines is safe, though. + if "versioneer.py" not in simple_includes: + print(" appending 'versioneer.py' to MANIFEST.in") + with open(manifest_in, "a") as f: + f.write("include versioneer.py\n") + else: + print(" 'versioneer.py' already in MANIFEST.in") + if cfg.versionfile_source not in simple_includes: + print(" appending versionfile_source ('%s') to MANIFEST.in" % + cfg.versionfile_source) + with open(manifest_in, "a") as f: + f.write("include %s\n" % cfg.versionfile_source) + else: + print(" versionfile_source already in MANIFEST.in") + + # Make VCS-specific changes. For git, this means creating/changing + # .gitattributes to mark _version.py for export-subst keyword + # substitution. 
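The export-subst marker is what makes the $Format:...$ keywords in _version.py expand when 'git archive' builds a tarball. A hypothetical illustration of what the expanded keywords look like and how git_versions_from_keywords() (defined above) turns them into version information; the refnames, hash, date and tag prefix below are made up:

# Hypothetical keywords as git-archive would have substituted them into a
# _version.py marked with export-subst (all values are made up).
expanded = {
    "refnames": " (HEAD -> master, tag: v0.10.9)",
    "full": "0123456789abcdef0123456789abcdef01234567",
    "date": "2018-09-01 12:00:00 -0700",
}
info = git_versions_from_keywords(expanded, tag_prefix="v", verbose=False)
print(info["version"])   # -> "0.10.9"
print(info["date"])      # -> "2018-09-01T12:00:00-0700"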
+ do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + return 0 + + +def scan_setup_py(): + """Validate the contents of setup.py against Versioneer's expectations.""" + found = set() + setters = False + errors = 0 + with open("setup.py", "r") as f: + for line in f.readlines(): + if "import versioneer" in line: + found.add("import") + if "versioneer.get_cmdclass()" in line: + found.add("cmdclass") + if "versioneer.get_version()" in line: + found.add("get_version") + if "versioneer.VCS" in line: + setters = True + if "versioneer.versionfile_source" in line: + setters = True + if len(found) != 3: + print("") + print("Your setup.py appears to be missing some important items") + print("(but I might be wrong). Please make sure it has something") + print("roughly like the following:") + print("") + print(" import versioneer") + print(" setup( version=versioneer.get_version(),") + print(" cmdclass=versioneer.get_cmdclass(), ...)") + print("") + errors += 1 + if setters: + print("You should remove lines like 'versioneer.VCS = ' and") + print("'versioneer.versionfile_source = ' . This configuration") + print("now lives in setup.cfg, and should be removed from setup.py") + print("") + errors += 1 + return errors + + +if __name__ == "__main__": + cmd = sys.argv[1] + if cmd == "setup": + errors = do_setup() + errors += scan_setup_py() + if errors: + sys.exit(1) diff --git a/xarray/__init__.py b/xarray/__init__.py index 3e80acd1572..59a961c6b56 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -3,10 +3,14 @@ from __future__ import division from __future__ import print_function +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions + from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine -from .core.computation import apply_ufunc, where +from .core.computation import apply_ufunc, dot, where from .core.extensions import (register_dataarray_accessor, register_dataset_accessor) from .core.variable import as_variable, Variable, IndexVariable, Coordinate @@ -22,15 +26,13 @@ from .conventions import decode_cf, SerializationWarning -try: - from .version import version as __version__ -except ImportError: # pragma: no cover - raise ImportError('xarray not properly installed. If you are running from ' - 'the source directory, please instead create a new ' - 'virtual environment (using conda or virtualenv) and ' - 'then install it in-place by running: pip install -e .') +from .coding.cftime_offsets import cftime_range +from .coding.cftimeindex import CFTimeIndex + from .util.print_versions import show_versions from . import tutorial from . import ufuncs from . import testing + +from .core.common import ALL_DIMS diff --git a/xarray/_version.py b/xarray/_version.py new file mode 100644 index 00000000000..df4ee95ade4 --- /dev/null +++ b/xarray/_version.py @@ -0,0 +1,520 @@ + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. 
Generated by +# versioneer-0.18 (https://github.com/warner/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + git_date = "$Format:%ci$" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "v" + cfg.parentdir_prefix = "xarray-" + cfg.versionfile_source = "xarray/_version.py" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, p.returncode + return stdout, p.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. 
We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. 
"2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 
0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. 
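+ # For example, with versionfile_source == 'xarray/_version.py' the loop
+ # below strips two path components: .../xarray/_version.py -> .../xarray
+ # -> the repository root.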
+ for i in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index d85893afb0b..9b9e04d9346 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -4,16 +4,23 @@ formats. They should not be used directly, but rather through Dataset objects. """ from .common import AbstractDataStore +from .file_manager import FileManager, CachingFileManager, DummyFileManager +from .cfgrib_ import CfGribDataStore from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore from .pydap_ import PydapDataStore from .pynio_ import NioDataStore from .scipy_ import ScipyDataStore from .h5netcdf_ import H5NetCDFStore +from .pseudonetcdf_ import PseudoNetCDFDataStore from .zarr import ZarrStore __all__ = [ 'AbstractDataStore', + 'FileManager', + 'CachingFileManager', + 'CfGribDataStore', + 'DummyFileManager', 'InMemoryDataStore', 'NetCDF4DataStore', 'PydapDataStore', @@ -21,4 +28,5 @@ 'ScipyDataStore', 'H5NetCDFStore', 'ZarrStore', + 'PseudoNetCDFDataStore', ] diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 4359868feae..ca440872d73 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1,48 +1,91 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import os.path from glob import glob from io import BytesIO from numbers import Number - +import warnings import numpy as np -from .. import backends, conventions, Dataset -from .common import ArrayWriter, GLOBAL_LOCK +from .. 
import Dataset, backends, conventions from ..core import indexing from ..core.combine import auto_combine -from ..core.utils import close_on_error, is_remote_uri from ..core.pycompat import basestring, path_type +from ..core.utils import close_on_error, is_remote_uri, is_grib_path +from .common import ArrayWriter +from .locks import _get_scheduler + DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' -def _get_default_engine(path, allow_remote=False): - if allow_remote and is_remote_uri(path): # pragma: no cover +def _get_default_engine_remote_uri(): + try: + import netCDF4 + engine = 'netcdf4' + except ImportError: # pragma: no cover try: - import netCDF4 - engine = 'netcdf4' + import pydap # flake8: noqa + engine = 'pydap' except ImportError: - try: - import pydap # flake8: noqa - engine = 'pydap' - except ImportError: - raise ValueError('netCDF4 or pydap is required for accessing ' - 'remote datasets via OPeNDAP') + raise ValueError('netCDF4 or pydap is required for accessing ' + 'remote datasets via OPeNDAP') + return engine + + +def _get_default_engine_grib(): + msgs = [] + try: + import Nio # flake8: noqa + msgs += ["set engine='pynio' to access GRIB files with PyNIO"] + except ImportError: # pragma: no cover + pass + try: + import cfgrib # flake8: noqa + msgs += ["set engine='cfgrib' to access GRIB files with cfgrib"] + except ImportError: # pragma: no cover + pass + if msgs: + raise ValueError(' or\n'.join(msgs)) else: + raise ValueError('PyNIO or cfgrib is required for accessing ' + 'GRIB files') + + +def _get_default_engine_gz(): + try: + import scipy # flake8: noqa + engine = 'scipy' + except ImportError: # pragma: no cover + raise ValueError('scipy is required for accessing .gz files') + return engine + + +def _get_default_engine_netcdf(): + try: + import netCDF4 # flake8: noqa + engine = 'netcdf4' + except ImportError: # pragma: no cover try: - import netCDF4 # flake8: noqa - engine = 'netcdf4' - except ImportError: # pragma: no cover - try: - import scipy.io.netcdf # flake8: noqa - engine = 'scipy' - except ImportError: - raise ValueError('cannot read or write netCDF files without ' - 'netCDF4-python or scipy installed') + import scipy.io.netcdf # flake8: noqa + engine = 'scipy' + except ImportError: + raise ValueError('cannot read or write netCDF files without ' + 'netCDF4-python or scipy installed') + return engine + + +def _get_default_engine(path, allow_remote=False): + if allow_remote and is_remote_uri(path): + engine = _get_default_engine_remote_uri() + elif is_grib_path(path): + engine = _get_default_engine_grib() + elif path.endswith('.gz'): + engine = _get_default_engine_gz() + else: + engine = _get_default_engine_netcdf() return engine @@ -53,27 +96,6 @@ def _normalize_path(path): return os.path.abspath(os.path.expanduser(path)) -def _default_lock(filename, engine): - if filename.endswith('.gz'): - lock = False - else: - if engine is None: - engine = _get_default_engine(filename, allow_remote=True) - - if engine == 'netcdf4': - if is_remote_uri(filename): - lock = False - else: - # TODO: identify netcdf3 files and don't use the global lock - # for them - lock = GLOBAL_LOCK - elif engine in {'h5netcdf', 'pynio'}: - lock = GLOBAL_LOCK - else: - lock = False - return lock - - def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" def check_name(name): @@ -131,10 +153,17 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data +def _finalize_store(write, 
store): + """ Finalize this store by explicitly syncing and closing""" + del write # ensure writing is done first + store.close() + + def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=True, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, cache=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None, + backend_kwargs=None): """Load and decode a dataset from a file or file-like object. Parameters @@ -158,7 +187,8 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. + be replaced by NA. mask_and_scale defaults to True except for the + pseudonetcdf backend. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. @@ -174,7 +204,8 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib', + 'pseudonetcdf'}, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. @@ -182,12 +213,11 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, If chunks is provided, it used to load the new dataset into dask arrays. ``chunks={}`` loads the dataset with dask using a single chunk for all arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -199,6 +229,10 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + backend_kwargs: dictionary, optional + A dictionary of keyword arguments to pass on to the backend. This + may be useful when backend options would improve performance or + allow user control of dataset processing. Returns ------- @@ -209,6 +243,18 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ + if autoclose is not None: + warnings.warn( + 'The autoclose argument is no longer used by ' + 'xarray.open_dataset() and is now ignored; it will be removed in ' + 'xarray v0.12. 
If necessary, you can control the maximum number ' + 'of simultaneous open files with ' + 'xarray.set_options(file_cache_maxsize=...).', + FutureWarning, stacklevel=2) + + if mask_and_scale is None: + mask_and_scale = not engine == 'pseudonetcdf' + if not decode_cf: mask_and_scale = False decode_times = False @@ -218,6 +264,9 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, if cache is None: cache = chunks is None + if backend_kwargs is None: + backend_kwargs = {} + def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( store, mask_and_scale=mask_and_scale, decode_times=decode_times, @@ -239,18 +288,11 @@ def maybe_decode_store(store, lock=False): mask_and_scale, decode_times, concat_characters, decode_coords, engine, chunks, drop_variables) name_prefix = 'open_dataset-%s' % token - ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token, - lock=lock) + ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) ds2._file_obj = ds._file_obj else: ds2 = ds - # protect so that dataset store isn't necessarily closed, e.g., - # streams like BytesIO can't be reopened - # datastore backend is responsible for determining this capability - if store._autoclose: - store.close() - return ds2 if isinstance(filename_or_obj, path_type): @@ -270,39 +312,35 @@ def maybe_decode_store(store, lock=False): elif isinstance(filename_or_obj, basestring): filename_or_obj = _normalize_path(filename_or_obj) - if filename_or_obj.endswith('.gz'): - if engine is not None and engine != 'scipy': - raise ValueError('can only read gzipped netCDF files with ' - "default engine or engine='scipy'") - else: - engine = 'scipy' - if engine is None: engine = _get_default_engine(filename_or_obj, allow_remote=True) if engine == 'netcdf4': - store = backends.NetCDF4DataStore.open(filename_or_obj, - group=group, - autoclose=autoclose) + store = backends.NetCDF4DataStore.open( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj, - autoclose=autoclose) + store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pydap': - store = backends.PydapDataStore.open(filename_or_obj) + store = backends.PydapDataStore.open( + filename_or_obj, **backend_kwargs) elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group, - autoclose=autoclose) + store = backends.H5NetCDFStore( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj, - autoclose=autoclose) + store = backends.NioDataStore( + filename_or_obj, lock=lock, **backend_kwargs) + elif engine == 'pseudonetcdf': + store = backends.PseudoNetCDFDataStore.open( + filename_or_obj, lock=lock, **backend_kwargs) + elif engine == 'cfgrib': + store = backends.CfGribDataStore( + filename_or_obj, lock=lock, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) - if lock is None: - lock = _default_lock(filename_or_obj, engine) with close_on_error(store): - return maybe_decode_store(store, lock) + return maybe_decode_store(store) else: if engine is not None and engine != 'scipy': raise ValueError('can only read file-like objects with ' @@ -314,9 +352,10 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=True, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, 
engine=None, - chunks=None, lock=None, cache=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None, + backend_kwargs=None): """Open an DataArray from a netCDF file containing a single data variable. This is designed to read netCDF files with only one data variable. If @@ -343,14 +382,11 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. + be replaced by NA. mask_and_scale defaults to True except for the + pseudonetcdf backend. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and @@ -359,19 +395,19 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib'}, + optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -383,6 +419,10 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + backend_kwargs: dictionary, optional + A dictionary of keyword arguments to pass on to the backend. This + may be useful when backend options would improve performance or + allow user control of dataset processing. 
Notes ----- @@ -397,13 +437,15 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, -------- open_dataset """ + dataset = open_dataset(filename_or_obj, group=group, decode_cf=decode_cf, mask_and_scale=mask_and_scale, decode_times=decode_times, autoclose=autoclose, concat_characters=concat_characters, decode_coords=decode_coords, engine=engine, chunks=chunks, lock=lock, cache=cache, - drop_variables=drop_variables) + drop_variables=drop_variables, + backend_kwargs=backend_kwargs) if len(dataset.data_vars) != 1: raise ValueError('Given file dataset contains more than one data ' @@ -440,11 +482,12 @@ def close(self): def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', preprocess=None, engine=None, - lock=None, data_vars='all', coords='different', **kwargs): + lock=None, data_vars='all', coords='different', + autoclose=None, parallel=False, **kwargs): """Open multiple files as a single dataset. - Requires dask to be installed. Attributes from the first dataset file - are used for the combined dataset. + Requires dask to be installed. See documentation for details on dask [1]. + Attributes from the first dataset file are used for the combined dataset. Parameters ---------- @@ -458,7 +501,7 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, If int, chunk each dimension by ``chunks``. By default, chunks will be chosen to load entire input files into memory at once. This has a major impact on performance: please see the - full documentation for more details. + full documentation for more details [2]. concat_dim : None, str, DataArray or Index, optional Dimension to concatenate files along. This argument is passed on to :py:func:`xarray.auto_combine` along with the dataset objects. You only @@ -483,19 +526,16 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, of all non-null values. preprocess : callable, optional If provided, call this function on each dataset prior to concatenation. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib'}, + optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. - lock : False, True or threading.Lock, optional - This argument is passed on to :py:func:`dask.array.from_array`. By - default, a per-variable lock is used when reading data from netCDF - files with the netcdf4 and h5netcdf engines to avoid issues with - concurrent access when using dask's multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. data_vars : {'minimal', 'different', 'all' or list of str}, optional These data variables will be concatenated together: * 'minimal': Only data variables in which the dimension already @@ -521,7 +561,9 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, those corresponding to other dimensions. * list of str: The listed coordinate variables will be concatenated, in addition the 'minimal' coordinates. 
- + parallel : bool, optional + If True, the open and preprocess steps of this function will be + performed in parallel using ``dask.delayed``. Default is False. **kwargs : optional Additional arguments passed on to :py:func:`xarray.open_dataset`. @@ -533,8 +575,18 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, -------- auto_combine open_dataset + + References + ---------- + .. [1] http://xarray.pydata.org/en/stable/dask.html + .. [2] http://xarray.pydata.org/en/stable/dask.html#chunking-and-performance """ if isinstance(paths, basestring): + if is_remote_uri(paths): + raise ValueError( + 'cannot do wild-card matching for paths that are remote URLs: ' + '{!r}. Instead, supply paths as an explicit list of strings.' + .format(paths)) paths = sorted(glob(paths)) else: paths = [str(p) if isinstance(p, path_type) else p for p in paths] @@ -542,15 +594,30 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, if not paths: raise IOError('no files to open') - if lock is None: - lock = _default_lock(paths[0], engine) - datasets = [open_dataset(p, engine=engine, chunks=chunks or {}, lock=lock, - **kwargs) for p in paths] - file_objs = [ds._file_obj for ds in datasets] + open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, + autoclose=autoclose, **kwargs) + if parallel: + import dask + # wrap the open_dataset, getattr, and preprocess with delayed + open_ = dask.delayed(open_dataset) + getattr_ = dask.delayed(getattr) + if preprocess is not None: + preprocess = dask.delayed(preprocess) + else: + open_ = open_dataset + getattr_ = getattr + + datasets = [open_(p, **open_kwargs) for p in paths] + file_objs = [getattr_(ds, '_file_obj') for ds in datasets] if preprocess is not None: datasets = [preprocess(ds) for ds in datasets] + if parallel: + # calling compute here will return the datasets/file_objs lists, + # the underlying datasets will still be stored as dask arrays + datasets, file_objs = dask.compute(datasets, file_objs) + # close datasets in case of a ValueError try: if concat_dim is _CONCAT_DIM_DEFAULT: @@ -576,18 +643,21 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, - engine=None, writer=None, encoding=None, unlimited_dims=None): + engine=None, encoding=None, unlimited_dims=None, compute=True, + multifile=False): """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file See `Dataset.to_netcdf` for full API docs. - The ``writer`` argument is only for the private use of save_mfdataset. + The ``multifile`` argument is only for the private use of save_mfdataset. """ if isinstance(path_or_file, path_type): path_or_file = str(path_or_file) + if encoding is None: encoding = {} + if path_or_file is None: if engine is None: engine = 'scipy' @@ -595,6 +665,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. 
Only the default engine ' "or engine='scipy' is supported" % engine) + if not compute: + raise NotImplementedError( + 'to_netcdf() with compute=False is not yet implemented when ' + 'returning bytes') elif isinstance(path_or_file, basestring): if engine is None: engine = _get_default_engine(path_or_file) @@ -614,29 +688,81 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, if format is not None: format = format.upper() - # if a writer is provided, store asynchronously - sync = writer is None + # handle scheduler specific logic + scheduler = _get_scheduler() + have_chunks = any(v.chunks for v in dataset.variables.values()) + + autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing'] + if autoclose and engine == 'scipy': + raise NotImplementedError("Writing netCDF files with the %s backend " + "is not currently supported with dask's %s " + "scheduler" % (engine, scheduler)) target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer) + kwargs = dict(autoclose=True) if autoclose else {} + store = store_open(target, mode, format, group, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) + if isinstance(unlimited_dims, basestring): + unlimited_dims = [unlimited_dims] + + writer = ArrayWriter() + + # TODO: figure out how to refactor this logic (here and in save_mfdataset) + # to avoid this mess of conditionals try: - dataset.dump_to_store(store, sync=sync, encoding=encoding, - unlimited_dims=unlimited_dims) + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + dump_to_store(dataset, store, writer, encoding=encoding, + unlimited_dims=unlimited_dims) + if autoclose: + store.close() + + if multifile: + return writer, store + + writes = writer.sync(compute=compute) + if path_or_file is None: + store.sync() return target.getvalue() finally: - if sync and isinstance(path_or_file, basestring): + if not multifile and compute: store.close() - if not sync: - return store + if not compute: + import dask + return dask.delayed(_finalize_store)(writes, store) + + +def dump_to_store(dataset, store, writer=None, encoder=None, + encoding=None, unlimited_dims=None): + """Store dataset contents to a backends.*DataStore object.""" + if writer is None: + writer = ArrayWriter() + + if encoding is None: + encoding = {} + + variables, attrs = conventions.encode_dataset_coordinates(dataset) + + check_encoding = set() + for k, enc in encoding.items(): + # no need to shallow copy the variable again; that already happened + # in encode_dataset_coordinates + variables[k].encoding = enc + check_encoding.add(k) + + if encoder: + variables, attrs = encoder(variables, attrs) + + store.store(variables, attrs, check_encoding, writer, + unlimited_dims=unlimited_dims) def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, - engine=None): + engine=None, compute=True): """Write multiple datasets to disk as netCDF files simultaneously. This function is intended for use with datasets consisting of dask.array @@ -685,6 +811,10 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, Engine to use when writing netCDF files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4' if writing to a file on disk. + See `Dataset.to_netcdf` for additional information. 
+ compute: boolean + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. Examples -------- @@ -702,7 +832,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, for obj in datasets: if not isinstance(obj, Dataset): raise TypeError('save_mfdataset only supports writing Dataset ' - 'objects, recieved type %s' % type(obj)) + 'objects, received type %s' % type(obj)) if groups is None: groups = [None] * len(datasets) @@ -712,20 +842,26 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, 'datasets, paths and groups arguments to ' 'save_mfdataset') - writer = ArrayWriter() - stores = [to_netcdf(ds, path, mode, format, group, engine, writer) - for ds, path, group in zip(datasets, paths, groups)] + writers, stores = zip(*[ + to_netcdf(ds, path, mode, format, group, engine, compute=compute, + multifile=True) + for ds, path, group in zip(datasets, paths, groups)]) + try: - writer.sync() - for store in stores: - store.sync() + writes = [w.sync(compute=compute) for w in writers] finally: - for store in stores: - store.close() + if compute: + for store in stores: + store.close() + + if not compute: + import dask + return dask.delayed([dask.delayed(_finalize_store)(w, s) + for w, s in zip(writes, stores)]) def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, - encoding=None): + encoding=None, compute=True): """This function creates an appropriate datastore for writing a dataset to a zarr ztore @@ -742,9 +878,14 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, store = backends.ZarrStore.open_group(store=store, mode=mode, synchronizer=synchronizer, - group=group, writer=None) + group=group) - # I think zarr stores should always be sync'd immediately + writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims - dataset.dump_to_store(store, sync=True, encoding=encoding) + dump_to_store(dataset, store, writer, encoding=encoding) + writes = writer.sync(compute=compute) + + if not compute: + import dask + return dask.delayed(_finalize_store)(writes, store) return store diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py new file mode 100644 index 00000000000..0807900054a --- /dev/null +++ b/xarray/backends/cfgrib_.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +from .. import Variable +from ..core import indexing +from ..core.utils import Frozen, FrozenOrderedDict +from .common import AbstractDataStore, BackendArray +from .locks import ensure_lock, SerializableLock + +# FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe +# in most circumstances. See: +# https://confluence.ecmwf.int/display/ECC/Frequently+Asked+Questions +ECCODES_LOCK = SerializableLock() + + +class CfGribArrayWrapper(BackendArray): + def __init__(self, datastore, array): + self.datastore = datastore + self.shape = array.shape + self.dtype = array.dtype + self.array = array + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem) + + def _getitem(self, key): + with self.datastore.lock: + return self.array[key] + + +class CfGribDataStore(AbstractDataStore): + """ + Implements the ``xr.AbstractDataStore`` read-only API for a GRIB file. 
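+
+ Lazy array access goes through ``CfGribArrayWrapper`` and is serialized
+ with the store's lock (``ECCODES_LOCK`` unless another lock is passed in).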
+ """ + def __init__(self, filename, lock=None, **backend_kwargs): + import cfgrib + if lock is None: + lock = ECCODES_LOCK + self.lock = ensure_lock(lock) + self.ds = cfgrib.open_file(filename, **backend_kwargs) + + def open_store_variable(self, name, var): + if isinstance(var.data, np.ndarray): + data = var.data + else: + wrapped_array = CfGribArrayWrapper(self, var.data) + data = indexing.LazilyOuterIndexedArray(wrapped_array) + + encoding = self.ds.encoding.copy() + encoding['original_shape'] = var.data.shape + + return Variable(var.dimensions, data, var.attributes, encoding) + + def get_variables(self): + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) + + def get_attrs(self): + return Frozen(self.ds.attributes) + + def get_dimensions(self): + return Frozen(self.ds.dimensions) + + def get_encoding(self): + dims = self.get_dimensions() + encoding = { + 'unlimited_dims': {k for k, v in dims.items() if v is None}, + } + return encoding diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 157ee494067..405d989f4af 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,24 +1,17 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import numpy as np +from __future__ import absolute_import, division, print_function + import logging import time import traceback -import contextlib -from collections import Mapping, OrderedDict import warnings +from collections import Mapping, OrderedDict + +import numpy as np from ..conventions import cf_encoder from ..core import indexing +from ..core.pycompat import dask_array_type, iteritems from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin -from ..core.pycompat import iteritems, dask_array_type - -try: - from dask.utils import SerializableLock as Lock -except ImportError: - from threading import Lock - # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -27,10 +20,6 @@ NONE_VAR_NAME = '__values__' -# dask.utils.SerializableLock if available, otherwise just a threading.Lock -GLOBAL_LOCK = Lock() - - def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME @@ -85,7 +74,6 @@ def __array__(self, dtype=None): class AbstractDataStore(Mapping): - _autoclose = False def __iter__(self): return iter(self.variables) @@ -114,7 +102,7 @@ def load(self): A centralized loading function makes it easier to create data stores that do automatic encoding/decoding. - For example: + For example:: class SuffixAppendingDataStore(AbstractDataStore): @@ -168,7 +156,7 @@ def __exit__(self, exception_type, exception_value, traceback): class ArrayWriter(object): - def __init__(self, lock=GLOBAL_LOCK): + def __init__(self, lock=None): self.sources = [] self.targets = [] self.lock = lock @@ -178,25 +166,23 @@ def add(self, source, target): self.sources.append(source) self.targets.append(target) else: - try: - target[...] = source - except TypeError: - # workaround for GH: scipy/scipy#6880 - target[:] = source + target[...] 
= source - def sync(self): + def sync(self, compute=True): if self.sources: import dask.array as da - da.store(self.sources, self.targets, lock=self.lock) + # TODO: consider wrapping targets with dask.delayed, if this makes + # for any discernible difference in performance, e.g., + # targets = [dask.delayed(t) for t in self.targets] + delayed_store = da.store(self.sources, self.targets, + lock=self.lock, compute=compute, + flush=True) self.sources = [] self.targets = [] + return delayed_store class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None): - if writer is None: - writer = ArrayWriter() - self.writer = writer def encode(self, variables, attributes): """ @@ -238,9 +224,6 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no cover raise NotImplementedError - def sync(self): - self.writer.sync() - def store_dataset(self, dataset): """ in stores, variables are all variables AND coordinates @@ -251,7 +234,7 @@ def store_dataset(self, dataset): self.store(dataset, dataset.attrs) def store(self, variables, attributes, check_encoding_set=frozenset(), - unlimited_dims=None): + writer=None, unlimited_dims=None): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -267,16 +250,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. """ + if writer is None: + writer = ArrayWriter() variables, attributes = self.encode(variables, attributes) self.set_attributes(attributes) self.set_dimensions(variables, unlimited_dims=unlimited_dims) - self.set_variables(variables, check_encoding_set, + self.set_variables(variables, check_encoding_set, writer, unlimited_dims=unlimited_dims) def set_attributes(self, attributes): @@ -292,7 +278,7 @@ def set_attributes(self, attributes): for k, v in iteritems(attributes): self.set_attribute(k, v) - def set_variables(self, variables, check_encoding_set, + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data @@ -305,6 +291,7 @@ def set_variables(self, variables, check_encoding_set, check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. @@ -316,7 +303,7 @@ def set_variables(self, variables, check_encoding_set, target, source = self.prepare_variable( name, v, check, unlimited_dims=unlimited_dims) - self.writer.add(source, target) + writer.add(source, target) def set_dimensions(self, variables, unlimited_dims=None): """ @@ -363,46 +350,3 @@ def encode(self, variables, attributes): attributes = OrderedDict([(k, self.encode_attribute(v)) for k, v in attributes.items()]) return variables, attributes - - -class DataStorePickleMixin(object): - """Subclasses must define `ds`, `_opener` and `_mode` attributes. - - Do not subclass this class: it is not part of xarray's external API. 
- """ - - def __getstate__(self): - state = self.__dict__.copy() - del state['ds'] - if self._mode == 'w': - # file has already been created, don't override when restoring - state['_mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self.ds = self._opener(mode=self._mode) - - @contextlib.contextmanager - def ensure_open(self, autoclose): - """ - Helper function to make sure datasets are closed and opened - at appropriate times to avoid too many open file errors. - - Use requires `autoclose=True` argument to `open_mfdataset`. - """ - if self._autoclose and not self._isopen: - try: - self.ds = self._opener() - self._isopen = True - yield - finally: - if autoclose: - self.close() - else: - yield - - def assert_open(self): - if not self._isopen: - raise AssertionError('internal failure: file must be open ' - 'if `autoclose=True` is used.') diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py new file mode 100644 index 00000000000..a93285370b2 --- /dev/null +++ b/xarray/backends/file_manager.py @@ -0,0 +1,206 @@ +import threading + +from ..core import utils +from ..core.options import OPTIONS +from .lru_cache import LRUCache + + +# Global cache for storing open files. +FILE_CACHE = LRUCache( + OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) +assert FILE_CACHE.maxsize, 'file cache must be at least size one' + + +_DEFAULT_MODE = utils.ReprObject('') + + +class FileManager(object): + """Manager for acquiring and closing a file object. + + Use FileManager subclasses (CachingFileManager in particular) on backend + storage classes to automatically handle issues related to keeping track of + many open files and transferring them between multiple processes. + """ + + def acquire(self): + """Acquire the file object from this manager.""" + raise NotImplementedError + + def close(self, needs_lock=True): + """Close the file object associated with this manager, if needed.""" + raise NotImplementedError + + +class CachingFileManager(FileManager): + """Wrapper for automatically opening and closing file objects. + + Unlike files, CachingFileManager objects can be safely pickled and passed + between processes. They should be explicitly closed to release resources, + but a per-process least-recently-used cache for open files ensures that you + can safely create arbitrarily large numbers of FileManager objects. + + Don't directly close files acquired from a FileManager. Instead, call + FileManager.close(), which ensures that closed files are removed from the + cache as well. + + Example usage: + + manager = FileManager(open, 'example.txt', mode='w') + f = manager.acquire() + f.write(...) + manager.close() # ensures file is closed + + Note that as long as previous files are still cached, acquiring a file + multiple times from the same FileManager is essentially free: + + f1 = manager.acquire() + f2 = manager.acquire() + assert f1 is f2 + + """ + + def __init__(self, opener, *args, **keywords): + """Initialize a FileManager. + + Parameters + ---------- + opener : callable + Function that when called like ``opener(*args, **kwargs)`` returns + an open file object. The file object must implement a ``close()`` + method. + *args + Positional arguments for opener. A ``mode`` argument should be + provided as a keyword argument (see below). All arguments must be + hashable. + mode : optional + If provided, passed as a keyword argument to ``opener`` along with + ``**kwargs``. 
``mode='w'`` has special treatment: after the first + call it is replaced by ``mode='a'`` in all subsequent calls to + avoid overwriting the newly created file. + kwargs : dict, optional + Keyword arguments for opener, excluding ``mode``. All values must + be hashable. + lock : duck-compatible threading.Lock, optional + Lock to use when modifying the cache inside acquire() and close(). + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. + cache : MutableMapping, optional + Mapping to use as a cache for open files. By default, uses xarray's + global LRU file cache. Because ``cache`` typically points to a + global variable and contains non-picklable file objects, an + unpickled FileManager object will be restored with the default + cache. + """ + # TODO: replace with real keyword arguments when we drop Python 2 + # support + mode = keywords.pop('mode', _DEFAULT_MODE) + kwargs = keywords.pop('kwargs', None) + lock = keywords.pop('lock', None) + cache = keywords.pop('cache', FILE_CACHE) + if keywords: + raise TypeError('FileManager() got unexpected keyword arguments: ' + '%s' % list(keywords)) + + self._opener = opener + self._args = args + self._mode = mode + self._kwargs = {} if kwargs is None else dict(kwargs) + self._default_lock = lock is None or lock is False + self._lock = threading.Lock() if self._default_lock else lock + self._cache = cache + self._key = self._make_key() + + def _make_key(self): + """Make a key for caching files in the LRU cache.""" + value = (self._opener, + self._args, + self._mode, + tuple(sorted(self._kwargs.items()))) + return _HashedSequence(value) + + def acquire(self): + """Acquire a file object from the manager. + + A new file is only opened if it has expired from the + least-recently-used cache. + + This method uses a lock, which ensures that it is + thread-safe. You can safely acquire a file in multiple threads at the + same time, as long as the underlying file object is thread-safe. + + Returns + ------- + An open file object, as returned by ``opener(*args, **kwargs)``. + """ + with self._lock: + try: + file = self._cache[self._key] + except KeyError: + kwargs = self._kwargs + if self._mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs['mode'] = self._mode + file = self._opener(*self._args, **kwargs) + if self._mode == 'w': + # ensure file doesn't get overwritten when opened again + self._mode = 'a' + self._key = self._make_key() + self._cache[self._key] = file + return file + + def _close(self): + default = None + file = self._cache.pop(self._key, default) + if file is not None: + file.close() + + def close(self, needs_lock=True): + """Explicitly close any associated file object (if necessary).""" + # TODO: remove needs_lock if/when we have a reentrant lock in + # dask.distributed: https://github.com/dask/dask/issues/3832 + if needs_lock: + with self._lock: + self._close() + else: + self._close() + + def __getstate__(self): + """State for pickling.""" + lock = None if self._default_lock else self._lock + return (self._opener, self._args, self._mode, self._kwargs, lock) + + def __setstate__(self, state): + """Restore from a pickle.""" + opener, args, mode, kwargs, lock = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock) + + +class _HashedSequence(list): + """Speed up repeated look-ups by caching hash values. + + Based on what Python uses internally in functools.lru_cache. 
+ + Python doesn't perform this optimization automatically: + https://bugs.python.org/issue1462796 + """ + + def __init__(self, tuple_value): + self[:] = tuple_value + self.hashvalue = hash(tuple_value) + + def __hash__(self): + return self.hashvalue + + +class DummyFileManager(FileManager): + """FileManager that simply wraps an open file in the FileManager interface. + """ + def __init__(self, value): + self._value = value + + def acquire(self): + return self._value + + def close(self, needs_lock=True): + del needs_lock # ignored + self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index cba1d33115f..59cd4e84793 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -1,31 +1,34 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import functools import numpy as np from .. import Variable from ..core import indexing +from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict - -from .common import WritableCFDataStore, DataStorePickleMixin, find_root -from .netCDF4_ import (_nc4_group, _encode_nc4_variable, _get_datatype, - _extract_nc4_variable_encoding, BaseNetCDF4Array) +from .common import WritableCFDataStore +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock +from .netCDF4_ import ( + BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, + _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): def __getitem__(self, key): - key = indexing.unwrap_explicit_indexer( - key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer)) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, + self._getitem) + + def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 - # OuterIndexer only holds 1D integer ndarrays, so it's safe to convert - # them to lists. 
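+ # explicit_indexing_adapter above only passes OUTER_1VECTOR-style keys
+ # (at most one 1D integer array per key), so converting ndarray entries
+ # to lists is all h5py needs for fancy indexing.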
key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + array = self.get_array() + with self.datastore.lock: + return array[key] def maybe_decode_bytes(txt): @@ -40,99 +43,122 @@ def _read_attributes(h5netcdf_var): # to ensure conventions decoding works properly on Python 3, decode all # bytes attributes to strings attrs = OrderedDict() - for k in h5netcdf_var.ncattrs(): - v = h5netcdf_var.getncattr(k) + for k, v in h5netcdf_var.attrs.items(): if k not in ['_FillValue', 'missing_value']: v = maybe_decode_bytes(v) attrs[k] = v return attrs -_extract_h5nc_encoding = functools.partial(_extract_nc4_variable_encoding, - lsd_okay=False, backend='h5netcdf') +_extract_h5nc_encoding = functools.partial( + _extract_nc4_variable_encoding, + lsd_okay=False, h5py_okay=True, backend='h5netcdf') + + +def _h5netcdf_create_group(dataset, name): + return dataset.create_group(name) def _open_h5netcdf_group(filename, mode, group): - import h5netcdf.legacyapi - ds = h5netcdf.legacyapi.Dataset(filename, mode=mode) + import h5netcdf + ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_group(ds, group, mode) + ds = _nc4_require_group( + ds, group, mode, create_group=_h5netcdf_create_group) + return GroupWrapper(ds) -class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): +class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, autoclose=False): + lock=None, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, - group=group) - self.ds = opener() - if autoclose: - raise NotImplementedError('autoclose=True is not implemented ' - 'for the h5netcdf backend pending ' - 'further exploration, e.g., bug fixes ' - '(in h5netcdf?)') - self._autoclose = False - self._isopen = True + self._manager = CachingFileManager( + _open_h5netcdf_group, filename, mode=mode, + kwargs=dict(group=group)) + + if lock is None: + if mode == 'r': + lock = HDF5_LOCK + else: + lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) + self.format = format - self._opener = opener self._filename = filename self._mode = mode - super(H5NetCDFStore, self).__init__(writer) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @property + def ds(self): + return self._manager.acquire().value def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyIndexedArray( - H5NetCDFArrayWrapper(name, self)) - attrs = _read_attributes(var) - - # netCDF4 specific encoding - encoding = dict(var.filters()) - chunking = var.chunking() - encoding['chunksizes'] = chunking \ - if chunking != 'contiguous' else None - - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape + import h5py + + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) + + # netCDF4 specific encoding + encoding = { + 'chunksizes': var.chunks, + 'fletcher32': var.fletcher32, + 'shuffle': var.shuffle, + } + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == 'gzip': + encoding['zlib'] = True + encoding['complevel'] 
= var.compression_opts + elif var.compression is not None: + encoding['compression'] = var.compression + encoding['compression_opts'] = var.compression_opts + + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is unicode_type: + encoding['dtype'] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. + pass + else: + encoding['dtype'] = var.dtype return Variable(dimensions, data, attrs, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return FrozenOrderedDict(_read_attributes(self.ds)) + return FrozenOrderedDict(_read_attributes(self.ds)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return self.ds.dimensions + return self.ds.dimensions def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v is None} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if is_unlimited: - self.ds.createDimension(name, size=None) - self.ds.resize_dimension(name, length) - else: - self.ds.createDimension(name, size=length) + if is_unlimited: + self.ds.dimensions[name] = None + self.ds.resize_dimension(name, length) + else: + self.ds.dimensions[name] = length def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self.ds.setncattr(key, value) + self.ds.attrs[key] = value def encode_variable(self, variable): return _encode_nc4_variable(variable) @@ -142,10 +168,11 @@ def prepare_variable(self, name, variable, check_encoding=False, import h5py attrs = variable.attrs.copy() - dtype = _get_datatype(variable) + dtype = _get_datatype( + variable, raise_on_invalid_encoding=check_encoding) - fill_value = attrs.pop('_FillValue', None) - if dtype is str and fill_value is not None: + fillvalue = attrs.pop('_FillValue', None) + if dtype is str and fillvalue is not None: raise NotImplementedError( 'h5netcdf does not yet support setting a fill value for ' 'variable-length strings ' @@ -161,29 +188,49 @@ def prepare_variable(self, name, variable, check_encoding=False, raise_on_invalid=check_encoding) kwargs = {} - for key in ['zlib', 'complevel', 'shuffle', - 'chunksizes', 'fletcher32']: + # Convert from NetCDF4-Python style compression settings to h5py style + # If both styles are used together, h5py takes precedence + # If set_encoding=True, raise ValueError in case of mismatch + if encoding.pop('zlib', False): + if (check_encoding and encoding.get('compression') + not in (None, 'gzip')): + raise ValueError("'zlib' and 'compression' encodings mismatch") + encoding.setdefault('compression', 'gzip') + + if (check_encoding and + 'complevel' in encoding and 'compression_opts' in encoding and + encoding['complevel'] != encoding['compression_opts']): + raise ValueError("'complevel' and 'compression_opts' encodings " + "mismatch") + complevel = encoding.pop('complevel', 0) + if complevel != 0: + 
encoding.setdefault('compression_opts', complevel) + + encoding['chunks'] = encoding.pop('chunksizes', None) + + for key in ['compression', 'compression_opts', 'shuffle', + 'chunks', 'fletcher32']: if key in encoding: kwargs[key] = encoding[key] - if name not in self.ds.variables: - nc4_var = self.ds.createVariable(name, dtype, variable.dims, - fill_value=fill_value, **kwargs) + if name not in self.ds: + nc4_var = self.ds.create_variable( + name, dtype=dtype, dimensions=variable.dims, + fillvalue=fillvalue, **kwargs) else: - nc4_var = self.ds.variables[name] + nc4_var = self.ds[name] for k, v in iteritems(attrs): - nc4_var.setncattr(k, v) - return nc4_var, variable.data + nc4_var.attrs[k] = v + + target = H5NetCDFArrayWrapper(name, self) + + return target, variable.data def sync(self): - with self.ensure_open(autoclose=True): - super(H5NetCDFStore, self).sync() - self.ds.sync() - - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if not ds._closed: - ds.close() - self._isopen = False + self.ds.sync() + # if self.autoclose: + # self.close() + # super(H5NetCDFStore, self).sync(compute=compute) + + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py new file mode 100644 index 00000000000..f633280ef1d --- /dev/null +++ b/xarray/backends/locks.py @@ -0,0 +1,191 @@ +import multiprocessing +import threading +import weakref + +try: + from dask.utils import SerializableLock +except ImportError: + # no need to worry about serializing the lock + SerializableLock = threading.Lock + + +# Locks used by multiple backends. +# Neither HDF5 nor the netCDF-C library are thread-safe. +HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() + + +_FILE_LOCKS = weakref.WeakValueDictionary() + + +def _get_threaded_lock(key): + try: + lock = _FILE_LOCKS[key] + except KeyError: + lock = _FILE_LOCKS[key] = threading.Lock() + return lock + + +def _get_multiprocessing_lock(key): + # TODO: make use of the key -- maybe use locket.py? + # https://github.com/mwilliamson/locket.py + del key # unused + return multiprocessing.Lock() + + +def _get_distributed_lock(key): + from dask.distributed import Lock + return Lock(key) + + +_LOCK_MAKERS = { + None: _get_threaded_lock, + 'threaded': _get_threaded_lock, + 'multiprocessing': _get_multiprocessing_lock, + 'distributed': _get_distributed_lock, +} + + +def _get_lock_maker(scheduler=None): + """Returns an appropriate function for creating resource locks. + + Parameters + ---------- + scheduler : str or None + Dask scheduler being used. + + See Also + -------- + dask.utils.get_scheduler_lock + """ + return _LOCK_MAKERS[scheduler] + + +def _get_scheduler(get=None, collection=None): + """Determine the dask scheduler that is being used. + + None is returned if no dask scheduler is active. 
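+ Otherwise one of 'threaded', 'multiprocessing' or 'distributed' is
+ typically returned.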
+ + See also + -------- + dask.base.get_scheduler + """ + try: + # dask 0.18.1 and later + from dask.base import get_scheduler + actual_get = get_scheduler(get, collection) + except ImportError: + try: + from dask.utils import effective_get + actual_get = effective_get(get, collection) + except ImportError: + return None + + try: + from dask.distributed import Client + if isinstance(actual_get.__self__, Client): + return 'distributed' + except (ImportError, AttributeError): + try: + import dask.multiprocessing + if actual_get == dask.multiprocessing.get: + return 'multiprocessing' + else: + return 'threaded' + except ImportError: + return 'threaded' + + +def get_write_lock(key): + """Get a scheduler appropriate lock for writing to the given resource. + + Parameters + ---------- + key : str + Name of the resource for which to acquire a lock. Typically a filename. + + Returns + ------- + Lock object that can be used like a threading.Lock object. + """ + scheduler = _get_scheduler() + lock_maker = _get_lock_maker(scheduler) + return lock_maker(key) + + +class CombinedLock(object): + """A combination of multiple locks. + + Like a locked door, a CombinedLock is locked if any of its constituent + locks are locked. + """ + + def __init__(self, locks): + self.locks = tuple(set(locks)) # remove duplicates + + def acquire(self, *args): + return all(lock.acquire(*args) for lock in self.locks) + + def release(self, *args): + for lock in self.locks: + lock.release(*args) + + def __enter__(self): + for lock in self.locks: + lock.__enter__() + + def __exit__(self, *args): + for lock in self.locks: + lock.__exit__(*args) + + @property + def locked(self): + return any(lock.locked for lock in self.locks) + + def __repr__(self): + return "CombinedLock(%r)" % list(self.locks) + + +class DummyLock(object): + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, *args): + pass + + def release(self, *args): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + @property + def locked(self): + return False + + +def combine_locks(locks): + """Combine a sequence of locks into a single lock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return DummyLock() + + +def ensure_lock(lock): + """Ensure that the given object is a lock.""" + if lock is None or lock is False: + return DummyLock() + return lock diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py new file mode 100644 index 00000000000..321a1ca4da4 --- /dev/null +++ b/xarray/backends/lru_cache.py @@ -0,0 +1,91 @@ +import collections +import threading + +from ..core.pycompat import move_to_end + + +class LRUCache(collections.MutableMapping): + """Thread-safe LRUCache based on an OrderedDict. + + All dict operations (__getitem__, __setitem__, __contains__) update the + priority of the relevant key and take O(1) time. The dict is iterated over + in order from the oldest to newest key, which means that a complete pass + over the dict should not affect the order of any entries. + + When a new item is set and the maximum size of the cache is exceeded, the + oldest item is dropped and called with ``on_evict(key, value)``. + + The ``maxsize`` property can be used to view or adjust the capacity of + the cache, e.g., ``cache.maxsize = new_size``. 
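`combine_locks` and `ensure_lock` are the small helpers the backends below use to build their module-level locks (for example `NETCDF4_PYTHON_LOCK`, which combines the HDF5 and netCDF-C locks). A short usage sketch follows; note that `xarray.backends.locks` is an internal module introduced by this change, so the import path is not public API and may move.

```python
import threading

from xarray.backends.locks import DummyLock, combine_locks, ensure_lock

# Two resources that must both be held while writing, e.g. the HDF5 library
# and one particular output file.
hdf5_lock = threading.Lock()
file_lock = threading.Lock()

write_lock = combine_locks([hdf5_lock, file_lock])   # -> CombinedLock
with write_lock:
    pass  # call into the non-thread-safe C libraries here

# ensure_lock turns None (or False) into a no-op lock with the same API.
read_lock = ensure_lock(None)
assert isinstance(read_lock, DummyLock)
with read_lock:
    pass
```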
+ """ + def __init__(self, maxsize, on_evict=None): + """ + Parameters + ---------- + maxsize : int + Integer maximum number of items to hold in the cache. + on_evict: callable, optional + Function to call like ``on_evict(key, value)`` when items are + evicted. + """ + if not isinstance(maxsize, int): + raise TypeError('maxsize must be an integer') + if maxsize < 0: + raise ValueError('maxsize must be non-negative') + self._maxsize = maxsize + self._on_evict = on_evict + self._cache = collections.OrderedDict() + self._lock = threading.RLock() + + def __getitem__(self, key): + # record recent use of the key by moving it to the front of the list + with self._lock: + value = self._cache[key] + move_to_end(self._cache, key) + return value + + def _enforce_size_limit(self, capacity): + """Shrink the cache if necessary, evicting the oldest items.""" + while len(self._cache) > capacity: + key, value = self._cache.popitem(last=False) + if self._on_evict is not None: + self._on_evict(key, value) + + def __setitem__(self, key, value): + with self._lock: + if key in self._cache: + # insert the new value at the end + del self._cache[key] + self._cache[key] = value + elif self._maxsize: + # make room if necessary + self._enforce_size_limit(self._maxsize - 1) + self._cache[key] = value + elif self._on_evict is not None: + # not saving, immediately evict + self._on_evict(key, value) + + def __delitem__(self, key): + del self._cache[key] + + def __iter__(self): + # create a list, so accessing the cache during iteration cannot change + # the iteration order + return iter(list(self._cache)) + + def __len__(self): + return len(self._cache) + + @property + def maxsize(self): + """Maximum number of items can be held in the cache.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, size): + """Resize the cache, evicting the oldest items if necessary.""" + if size < 0: + raise ValueError('maxsize must be non-negative') + with self._lock: + self._enforce_size_limit(size) + self._maxsize = size diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py index 8c09277b2d0..195d4647534 100644 --- a/xarray/backends/memory.py +++ b/xarray/backends/memory.py @@ -1,13 +1,11 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import copy import numpy as np -from ..core.variable import Variable from ..core.pycompat import OrderedDict - +from ..core.variable import Variable from .common import AbstractWritableDataStore @@ -19,10 +17,9 @@ class InMemoryDataStore(AbstractWritableDataStore): This store exists purely for internal testing purposes. """ - def __init__(self, variables=None, attributes=None, writer=None): + def __init__(self, variables=None, attributes=None): self._variables = OrderedDict() if variables is None else variables self._attributes = OrderedDict() if attributes is None else attributes - super(InMemoryDataStore, self).__init__(writer) def get_attrs(self): return self._attributes @@ -39,9 +36,6 @@ def get_dimensions(self): def prepare_variable(self, k, v, *args, **kwargs): new_var = Variable(v.dims, np.empty_like(v), v.attrs) - # we copy the variable and stuff all encodings in the - # attributes to imitate what happens when writing to disk. 
- new_var.attrs.update(v.encoding) self._variables[k] = new_var return new_var, v.data diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b3cb1d8e49f..08ba085b77e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -1,6 +1,5 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import functools import operator import warnings @@ -8,16 +7,17 @@ import numpy as np -from .. import conventions -from .. import Variable -from ..conventions import pop_to +from .. import Variable, coding +from ..coding.variables import pop_to from ..core import indexing -from ..core.utils import (FrozenOrderedDict, close_on_error, is_remote_uri) -from ..core.pycompat import iteritems, basestring, OrderedDict, PY3, suppress - -from .common import (WritableCFDataStore, robust_getitem, BackendArray, - DataStorePickleMixin, find_root) -from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable) +from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress +from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri +from .common import ( + BackendArray, WritableCFDataStore, find_root, robust_getitem) +from .locks import (NETCDFC_LOCK, HDF5_LOCK, + combine_locks, ensure_lock, get_write_lock) +from .file_manager import CachingFileManager, DummyFileManager +from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. @@ -27,6 +27,9 @@ '|': 'native'} +NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + + class BaseNetCDF4Array(BackendArray): def __init__(self, variable_name, datastore): self.datastore = datastore @@ -43,64 +46,96 @@ def __init__(self, variable_name, datastore): dtype = np.dtype('O') self.dtype = dtype + def __setitem__(self, key, value): + with self.datastore.lock: + data = self.get_array() + data[key] = value + if self.datastore.autoclose: + self.datastore.close(needs_lock=False) + def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] class NetCDF4ArrayWrapper(BaseNetCDF4Array): def __getitem__(self, key): - key = indexing.unwrap_explicit_indexer( - key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer)) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, + self._getitem) + def _getitem(self, key): if self.datastore.is_remote: # pragma: no cover getitem = functools.partial(robust_getitem, catch=RuntimeError) else: getitem = operator.getitem - with self.datastore.ensure_open(autoclose=True): - try: - data = getitem(self.get_array(), key) - except IndexError: - # Catch IndexError in netCDF4 and return a more informative - # error message. This is most often called when an unsorted - # indexer is used before the data is loaded from disk. - msg = ('The indexing operation you are attempting to perform ' - 'is not valid on netCDF4.Variable object. Try loading ' - 'your data into memory first by calling .load().') - if not PY3: - import traceback - msg += '\n\nOriginal traceback:\n' + traceback.format_exc() - raise IndexError(msg) - - return data + original_array = self.get_array() + + try: + with self.datastore.lock: + array = getitem(original_array, key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. 
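`explicit_indexing_adapter` replaces the old `unwrap_explicit_indexer` calls: each array wrapper now declares the indexing the underlying library supports (`BASIC`, `OUTER`, `OUTER_1VECTOR`, ...), and xarray decomposes anything fancier into a supported read plus a NumPy post-index. The pure-NumPy sketch below illustrates that decomposition for a backend that only supports orthogonal (outer) indexing; it is a conceptual illustration, not the code xarray actually uses.

```python
import numpy as np

# Pretend this array lives in a file whose library only supports
# orthogonal (outer) indexing, like netCDF4.
data = np.arange(20).reshape(4, 5)

rows = np.array([0, 2, 3])
cols = np.array([1, 1, 4])

# Vectorized request: the element pairs (0, 1), (2, 1), (3, 4).
want = data[rows, cols]

# Serve it with one outer read plus a NumPy post-index, the way the
# adapter decomposes unsupported indexers before handing them to _getitem:
urows, row_inv = np.unique(rows, return_inverse=True)
ucols, col_inv = np.unique(cols, return_inverse=True)
block = data[np.ix_(urows, ucols)]   # outer read the backend can perform
got = block[row_inv, col_inv]        # finish in memory with NumPy

assert np.array_equal(got, want)
```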
This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ('The indexing operation you are attempting to perform ' + 'is not valid on netCDF4.Variable object. Try loading ' + 'your data into memory first by calling .load().') + if not PY3: + import traceback + msg += '\n\nOriginal traceback:\n' + traceback.format_exc() + raise IndexError(msg) + return array def _encode_nc4_variable(var): - if var.dtype.kind == 'S': - var = conventions.maybe_encode_as_char_array(var) + for coder in [coding.strings.EncodedStringCoder(allows_unicode=True), + coding.strings.CharacterArrayCoder()]: + var = coder.encode(var) return var -def _get_datatype(var, nc_format='NETCDF4'): +def _check_encoding_dtype_is_vlen_string(dtype): + if dtype is not str: + raise AssertionError( # pragma: no cover + "unexpected dtype encoding %r. This shouldn't happen: please " + "file a bug report at github.com/pydata/xarray" % dtype) + + +def _get_datatype(var, nc_format='NETCDF4', raise_on_invalid_encoding=False): if nc_format == 'NETCDF4': datatype = _nc4_dtype(var) else: + if 'dtype' in var.encoding: + encoded_dtype = var.encoding['dtype'] + _check_encoding_dtype_is_vlen_string(encoded_dtype) + if raise_on_invalid_encoding: + raise ValueError( + 'encoding dtype=str for vlen strings is only supported ' + 'with format=\'NETCDF4\'.') datatype = var.dtype return datatype def _nc4_dtype(var): - if var.dtype.kind == 'U': + if 'dtype' in var.encoding: + dtype = var.encoding.pop('dtype') + _check_encoding_dtype_is_vlen_string(dtype) + elif coding.strings.is_unicode_dtype(var.dtype): dtype = str elif var.dtype.kind in ['i', 'u', 'f', 'c', 'S']: dtype = var.dtype else: - raise ValueError('cannot infer dtype for netCDF4 variable') + raise ValueError('unsupported dtype for netCDF4 variable: {}' + .format(var.dtype)) return dtype -def _nc4_group(ds, group, mode): +def _netcdf4_create_group(dataset, name): + return dataset.createGroup(name) + + +def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): if group in set([None, '', '/']): # use the root group return ds @@ -115,7 +150,7 @@ def _nc4_group(ds, group, mode): ds = ds.groups[key] except KeyError as e: if mode != 'r': - ds = ds.createGroup(key) + ds = create_group(ds, key) else: # wrap error to provide slightly more helpful message raise IOError('group not found: %s' % key, e) @@ -151,8 +186,8 @@ def _force_native_endianness(var): def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, - lsd_okay=True, backend='netCDF4', - unlimited_dims=None): + lsd_okay=True, h5py_okay=False, + backend='netCDF4', unlimited_dims=None): if unlimited_dims is None: unlimited_dims = () @@ -160,9 +195,12 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, safe_to_drop = set(['source', 'original_shape']) valid_encodings = set(['zlib', 'complevel', 'fletcher32', 'contiguous', - 'chunksizes', 'shuffle']) + 'chunksizes', 'shuffle', '_FillValue', 'dtype']) if lsd_okay: valid_encodings.add('least_significant_digit') + if h5py_okay: + valid_encodings.add('compression') + valid_encodings.add('compression_opts') if not raise_on_invalid and encoding.get('chunksizes') is not None: # It's possible to get encoded chunksizes larger than a dimension size @@ -193,17 +231,27 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, return encoding -def _open_netcdf4_group(filename, mode, group=None, **kwargs): +class GroupWrapper(object): + """Wrap netCDF4.Group objects so closing them 
closes the root group.""" + def __init__(self, value): + self.value = value + + def close(self): + # netCDF4 only allows closing the root group + find_root(self.value).close() + + +def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): import netCDF4 as nc4 ds = nc4.Dataset(filename, mode=mode, **kwargs) with close_on_error(ds): - ds = _nc4_group(ds, group, mode) + ds = _nc4_require_group(ds, group, mode) _disable_auto_decode_group(ds) - return ds + return GroupWrapper(ds) def _disable_auto_decode_variable(var): @@ -224,128 +272,153 @@ def _disable_auto_decode_group(ds): _disable_auto_decode_variable(var) -class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): +def _is_list_of_strings(value): + if (np.asarray(value).dtype.kind in ['U', 'S'] and + np.asarray(value).size > 1): + return True + else: + return False + + +def _set_nc_attribute(obj, key, value): + if _is_list_of_strings(value): + # encode as NC_STRING if attr is list of strings + try: + obj.setncattr_string(key, value) + except AttributeError: + # Inform users with old netCDF that does not support + # NC_STRING that we can't serialize lists of strings + # as attrs + msg = ('Attributes which are lists of strings are not ' + 'supported with this version of netCDF. Please ' + 'upgrade to netCDF4-python 1.2.4 or greater.') + raise AttributeError(msg) + else: + obj.setncattr(key, value) + + +class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ - def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, - autoclose=False): - - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=NETCDF4_PYTHON_LOCK, autoclose=False): + import netCDF4 - _disable_auto_decode_group(netcdf4_dataset) + if isinstance(manager, netCDF4.Dataset): + _disable_auto_decode_group(manager) + manager = DummyFileManager(GroupWrapper(manager)) - self.ds = netcdf4_dataset - self._autoclose = autoclose - self._isopen = True + self._manager = manager self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self._mode = mode = 'a' if mode == 'w' else mode - if opener: - self._opener = functools.partial(opener, mode=self._mode) - else: - self._opener = opener - super(NetCDF4DataStore, self).__init__(writer) + self.lock = ensure_lock(lock) + self.autoclose = autoclose @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, - writer=None, clobber=True, diskless=False, persist=False, - autoclose=False): - import netCDF4 as nc4 + clobber=True, diskless=False, persist=False, + lock=None, lock_maker=None, autoclose=False): + import netCDF4 if (len(filename) == 88 and - LooseVersion(nc4.__version__) < "1.3.1"): + LooseVersion(netCDF4.__version__) < "1.3.1"): warnings.warn( - '\nA segmentation fault may occur when the\n' - 'file path has exactly 88 characters as it does.\n' - 'in this case. The issue is known to occur with\n' - 'version 1.2.4 of netCDF4 and can be addressed by\n' - 'upgrading netCDF4 to at least version 1.3.1.\n' - 'More details can be found here:\n' - 'https://github.com/pydata/xarray/issues/1745 \n') + 'A segmentation fault may occur when the ' + 'file path has exactly 88 characters as it does ' + 'in this case. 
The issue is known to occur with ' + 'version 1.2.4 of netCDF4 and can be addressed by ' + 'upgrading netCDF4 to at least version 1.3.1. ' + 'More details can be found here: ' + 'https://github.com/pydata/xarray/issues/1745') if format is None: format = 'NETCDF4' - opener = functools.partial(_open_netcdf4_group, filename, mode=mode, - group=group, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - ds = opener() - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose) - def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) - _ensure_fill_value_valid(data, attributes) - # netCDF4 specific encoding; save _FillValue for later - encoding = {} - filters = var.filters() - if filters is not None: - encoding.update(filters) - chunking = var.chunking() - if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None + if lock is None: + if mode == 'r': + if is_remote_uri(filename): + lock = NETCDFC_LOCK + else: + lock = NETCDF4_PYTHON_LOCK + else: + if format is None or format.startswith('NETCDF4'): + base_lock = NETCDF4_PYTHON_LOCK else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) - # TODO: figure out how to round-trip "endian-ness" without raising - # warnings from netCDF4 - # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape + base_lock = NETCDFC_LOCK + lock = combine_locks([base_lock, get_write_lock(filename)]) + + manager = CachingFileManager( + _open_netcdf4_group, filename, lock, mode=mode, + kwargs=dict(group=group, clobber=clobber, diskless=diskless, + persist=persist, format=format)) + return cls(manager, lock=lock, autoclose=autoclose) + + @property + def ds(self): + return self._manager.acquire().value + + def open_store_variable(self, name, var): + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) + for k in var.ncattrs()) + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + encoding = {} + filters = var.filters() + if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == 'contiguous': + encoding['contiguous'] = True + encoding['chunksizes'] = None + else: + encoding['contiguous'] = False + encoding['chunksizes'] = tuple(chunking) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, 'least_significant_digit') + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + encoding['dtype'] = var.dtype return Variable(dimensions, data, attributes, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in - iteritems(self.ds.variables)) + dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in + iteritems(self.ds.variables)) return dsvars 
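With `CachingFileManager` in place, the store no longer holds a `netCDF4.Dataset` open for its whole lifetime; the `ds` property re-acquires the (possibly reopened) handle on demand. A hedged usage sketch follows: `'example.nc'` is a placeholder filename, and handing a store to `open_dataset` is the usual way these store classes are exercised.

```python
import xarray as xr
from xarray.backends import NetCDF4DataStore

# 'example.nc' is a hypothetical local netCDF file.
store = NetCDF4DataStore.open('example.nc', mode='r')

# The store can be passed straight to open_dataset; the underlying
# netCDF4.Dataset is acquired lazily through the store's file manager.
ds = xr.open_dataset(store)
print(ds)

ds.close()   # closes the file manager and the file it holds
```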
def get_attrs(self): - with self.ensure_open(autoclose=True): - attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) return attrs def get_dimensions(self): - with self.ensure_open(autoclose=True): - dims = FrozenOrderedDict((k, len(v)) - for k, v in iteritems(self.ds.dimensions)) + dims = FrozenOrderedDict((k, len(v)) + for k, v in iteritems(self.ds.dimensions)) return dims def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v.isunlimited()} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, size=dim_length) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, size=dim_length) def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - if self.format != 'NETCDF4': - value = encode_nc3_attr_value(value) - self.ds.setncattr(key, value) - - def set_variables(self, *args, **kwargs): - with self.ensure_open(autoclose=False): - super(NetCDF4DataStore, self).set_variables(*args, **kwargs) + if self.format != 'NETCDF4': + value = encode_nc3_attr_value(value) + _set_nc_attribute(self.ds, key, value) def encode_variable(self, variable): variable = _force_native_endianness(variable) @@ -357,7 +430,8 @@ def encode_variable(self, variable): def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): - datatype = _get_datatype(variable, self.format) + datatype = _get_datatype(variable, self.format, + raise_on_invalid_encoding=check_encoding) attrs = variable.attrs.copy() fill_value = attrs.pop('_FillValue', None) @@ -396,19 +470,14 @@ def prepare_variable(self, name, variable, check_encoding=False, for k, v in iteritems(attrs): # set attributes one-by-one since netCDF4<1.0.10 can't handle # OrderedDict as the input to setncatts - nc4_var.setncattr(k, v) + _set_nc_attribute(nc4_var, k, v) + + target = NetCDF4ArrayWrapper(name, self) - return nc4_var, variable.data + return target, variable.data def sync(self): - with self.ensure_open(autoclose=True): - super(NetCDF4DataStore, self).sync() - self.ds.sync() + self.ds.sync() - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if ds._isopen: - ds.close() - self._isopen = False + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index 7aa054bc119..c7bfa0ea20b 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -1,13 +1,11 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import unicodedata import numpy as np -from .. import conventions, Variable -from ..core.pycompat import basestring, unicode_type, OrderedDict - +from .. 
import Variable, coding +from ..core.pycompat import OrderedDict, basestring, unicode_type # Special characters that are permitted in netCDF names except in the # 0th position of the string @@ -67,7 +65,9 @@ def encode_nc3_attrs(attrs): def encode_nc3_variable(var): - var = conventions.maybe_encode_as_char_array(var) + for coder in [coding.strings.EncodedStringCoder(allows_unicode=False), + coding.strings.CharacterArrayCoder()]: + var = coder.encode(var) data = coerce_nc3_dtype(var.data) attrs = encode_nc3_attrs(var.attrs) return Variable(var.dims, data, attrs, var.encoding) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py new file mode 100644 index 00000000000..606ed5251ac --- /dev/null +++ b/xarray/backends/pseudonetcdf_.py @@ -0,0 +1,94 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +from .. import Variable +from ..core import indexing +from ..core.pycompat import OrderedDict +from ..core.utils import Frozen, FrozenOrderedDict +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock + + +# psuedonetcdf can invoke netCDF libraries internally +PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) + + +class PncArrayWrapper(BackendArray): + + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + array = self.get_array() + self.shape = array.shape + self.dtype = np.dtype(array.dtype) + + def get_array(self): + return self.datastore.ds.variables[self.variable_name] + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, + self._getitem) + + def _getitem(self, key): + array = self.get_array() + with self.datastore.lock: + return array[key] + + +class PseudoNetCDFDataStore(AbstractDataStore): + """Store for accessing datasets via PseudoNetCDF + """ + @classmethod + def open(cls, filename, lock=None, **format_kwds): + from PseudoNetCDF import pncopen + + keywords = dict(kwargs=format_kwds) + # only include mode if explicitly passed + mode = format_kwds.pop('mode', None) + if mode is not None: + keywords['mode'] = mode + + if lock is None: + lock = PNETCDF_LOCK + + manager = CachingFileManager(pncopen, filename, lock=lock, **keywords) + return cls(manager, lock) + + def __init__(self, manager, lock=None): + self._manager = manager + self.lock = ensure_lock(lock) + + @property + def ds(self): + return self._manager.acquire() + + def open_store_variable(self, name, var): + data = indexing.LazilyOuterIndexedArray( + PncArrayWrapper(name, self) + ) + attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) + return Variable(var.dimensions, data, attrs) + + def get_variables(self): + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) + + def get_attrs(self): + return Frozen(dict([(k, getattr(self.ds, k)) + for k in self.ds.ncattrs()])) + + def get_dimensions(self): + return Frozen(self.ds.dimensions) + + def get_encoding(self): + encoding = {} + encoding['unlimited_dims'] = set( + [k for k in self.ds.dimensions + if self.ds.dimensions[k].isunlimited()]) + return encoding + + def close(self): + self._manager.close() diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 297d96e47f4..71ea4841b71 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -1,13 +1,11 @@ -from __future__ 
import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import numpy as np from .. import Variable -from ..core.utils import FrozenOrderedDict, Frozen, is_dict_like from ..core import indexing from ..core.pycompat import integer_types - +from ..core.utils import Frozen, FrozenOrderedDict, is_dict_like from .common import AbstractDataStore, BackendArray, robust_getitem @@ -24,9 +22,10 @@ def dtype(self): return self.array.dtype def __getitem__(self, key): - key = indexing.unwrap_explicit_indexer( - key, target=self, allow=indexing.BasicIndexer) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) + def _getitem(self, key): # pull the data from the array attribute if possible, to avoid # downloading coordinate data twice array = getattr(self.array, 'array', self.array) @@ -36,6 +35,7 @@ def __getitem__(self, key): if isinstance(k, integer_types)) if len(axis) > 0: result = np.squeeze(result, axis) + return result @@ -76,7 +76,7 @@ def open(cls, url, session=None): return cls(ds) def open_store_variable(self, var): - data = indexing.LazilyIndexedArray(PydapArrayWrapper(var)) + data = indexing.LazilyOuterIndexedArray(PydapArrayWrapper(var)) return Variable(var.dimensions, data, _fix_attributes(var.attributes)) diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 37f1db1f6a7..574fff744e3 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,16 +1,20 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools +from __future__ import absolute_import, division, print_function import numpy as np from .. import Variable -from ..core.utils import (FrozenOrderedDict, Frozen) from ..core import indexing +from ..core.utils import Frozen, FrozenOrderedDict +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import ( + HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, SerializableLock) + -from .common import AbstractDataStore, DataStorePickleMixin, BackendArray +# PyNIO can invoke netCDF libraries internally +# Add a dedicated lock just in case NCL as well isn't thread-safe. 
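The new PseudoNetCDF backend follows the same manager/lock pattern as the stores above. A hedged sketch of opening a file through it: the filename and `format='uamiv'` are illustrative only (any format `PseudoNetCDF.pncopen` understands would do), and the optional PseudoNetCDF package must be installed.

```python
import xarray as xr
from xarray.backends.pseudonetcdf_ import PseudoNetCDFDataStore

# 'camx_output.uamiv' is a hypothetical CAMx file; the format keyword is
# simply forwarded by the store to PseudoNetCDF.pncopen.
store = PseudoNetCDFDataStore.open('camx_output.uamiv', format='uamiv')
ds = xr.open_dataset(store)
```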
+NCL_LOCK = SerializableLock() +PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK]) class NioArrayWrapper(BackendArray): @@ -23,52 +27,52 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.typecode()) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): - key = indexing.unwrap_explicit_indexer( - key, target=self, allow=indexing.BasicIndexer) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) - with self.datastore.ensure_open(autoclose=True): - array = self.get_array() + def _getitem(self, key): + array = self.get_array() + with self.datastore.lock: if key == () and self.ndim == 0: return array.get_value() return array[key] -class NioDataStore(AbstractDataStore, DataStorePickleMixin): +class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', autoclose=False): + def __init__(self, filename, mode='r', lock=None): import Nio - opener = functools.partial(Nio.open_file, filename, mode=mode) - self.ds = opener() + if lock is None: + lock = PYNIO_LOCK + self.lock = ensure_lock(lock) + self._manager = CachingFileManager( + Nio.open_file, filename, lock=lock, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. self.ds.set_option('MaskedArrayMode', 'MaskedNever') - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self)) + data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.attributes) + return Frozen(self.ds.attributes) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -77,6 +81,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index c624c1f5ff8..7a343a6529e 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -2,18 +2,19 @@ import warnings from collections import OrderedDict from distutils.version import LooseVersion + import numpy as np from .. import DataArray -from ..core.utils import is_scalar from ..core import indexing +from ..core.utils import is_scalar from .common import BackendArray -try: - from dask.utils import SerializableLock as Lock -except ImportError: - from threading import Lock +from .file_manager import CachingFileManager +from .locks import SerializableLock -RASTERIO_LOCK = Lock() + +# TODO: should this be GDAL_LOCK instead? +RASTERIO_LOCK = SerializableLock() _ERROR_MSG = ('The kind of indexing operation you are trying to do is not ' 'valid on rasterio files. 
Try to load your data with ds.load()' @@ -23,65 +24,103 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, rasterio_ds): - self.rasterio_ds = rasterio_ds - self._shape = (rasterio_ds.count, rasterio_ds.height, - rasterio_ds.width) - self._ndims = len(self.shape) + def __init__(self, manager): + self.manager = manager - @property - def dtype(self): - dtypes = self.rasterio_ds.dtypes + # cannot save riods as an attribute: this would break pickleability + riods = manager.acquire() + + self._shape = (riods.count, riods.height, riods.width) + + dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') - return np.dtype(dtypes[0]) + self._dtype = np.dtype(dtypes[0]) + + @property + def dtype(self): + return self._dtype @property def shape(self): return self._shape - def __getitem__(self, key): - key = indexing.unwrap_explicit_indexer( - key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer)) + def _get_indexer(self, key): + """ Get indexer for rasterio array. + + Parameter + --------- + key: tuple of int + + Returns + ------- + band_key: an indexer for the 1st dimension + window: two tuples. Each consists of (start, stop). + squeeze_axis: axes to be squeezed + np_ind: indexer for loaded numpy array + + See also + -------- + indexing.decompose_indexer + """ + assert len(key) == 3, 'rasterio datasets should always be 3D' # bands cannot be windowed but they can be listed band_key = key[0] - n_bands = self.shape[0] + np_inds = [] + # bands (axis=0) cannot be windowed but they can be listed if isinstance(band_key, slice): - start, stop, step = band_key.indices(n_bands) - if step is not None and step != 1: - raise IndexError(_ERROR_MSG) - band_key = np.arange(start, stop) + start, stop, step = band_key.indices(self.shape[0]) + band_key = np.arange(start, stop, step) # be sure we give out a list band_key = (np.asarray(band_key) + 1).tolist() + if isinstance(band_key, list): # if band_key is not a scalar + np_inds.append(slice(None)) # but other dims can only be windowed window = [] squeeze_axis = [] for i, (k, n) in enumerate(zip(key[1:], self.shape[1:])): if isinstance(k, slice): + # step is always positive. 
see indexing.decompose_indexer start, stop, step = k.indices(n) - if step is not None and step != 1: - raise IndexError(_ERROR_MSG) + np_inds.append(slice(None, None, step)) elif is_scalar(k): # windowed operations will always return an array # we will have to squeeze it later - squeeze_axis.append(i + 1) + squeeze_axis.append(- (2 - i)) start = k stop = k + 1 else: - k = np.asarray(k) - start = k[0] - stop = k[-1] + 1 - ids = np.arange(start, stop) - if not ((k.shape == ids.shape) and np.all(k == ids)): - raise IndexError(_ERROR_MSG) + start, stop = np.min(k), np.max(k) + 1 + np_inds.append(k - start) window.append((start, stop)) - out = self.rasterio_ds.read(band_key, window=tuple(window)) + if isinstance(key[1], np.ndarray) and isinstance(key[2], np.ndarray): + # do outer-style indexing + np_inds[-2:] = np.ix_(*np_inds[-2:]) + + return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds) + + def _getitem(self, key): + band_key, window, squeeze_axis, np_inds = self._get_indexer(key) + + if not band_key or any(start == stop for (start, stop) in window): + # no need to do IO + shape = (len(band_key),) + tuple( + stop - start for (start, stop) in window) + out = np.zeros(shape, dtype=self.dtype) + else: + riods = self.manager.acquire() + out = riods.read(band_key, window=window) + if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) - return out + return out[np_inds] + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem) def _parse_envi(meta): @@ -130,7 +169,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, from affine import Affine da = xr.open_rasterio('path_to_file.tif') - transform = Affine(*da.attrs['transform']) + transform = Affine.from_gdal(*da.attrs['transform']) nx, ny = da.sizes['x'], da.sizes['y'] x, y = np.meshgrid(np.arange(nx)+0.5, np.arange(ny)+0.5) * transform @@ -167,7 +206,9 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, """ import rasterio - riods = rasterio.open(filename, mode='r') + + manager = CachingFileManager(rasterio.open, filename, mode='r') + riods = manager.acquire() if cache is None: cache = chunks is None @@ -190,26 +231,28 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, if parse: nx, ny = riods.width, riods.height # xarray coordinates are pixel centered - x, _ = (np.arange(nx)+0.5, np.zeros(nx)+0.5) * transform - _, y = (np.zeros(ny)+0.5, np.arange(ny)+0.5) * transform + x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform + _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform coords['y'] = y coords['x'] = x else: # 2d coordinates parse = False if (parse_coordinates is None) else parse_coordinates if parse: - warnings.warn("The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in this case. Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, stacklevel=3) + warnings.warn( + "The file coordinates' transformation isn't " + "rectilinear: xarray won't parse the coordinates " + "in this case. 
Set `parse_coordinates=False` to " + "suppress this warning.", + RuntimeWarning, stacklevel=3) # Attributes attrs = dict() # Affine transformation matrix (always available) # This describes coefficients mapping pixel coordinates to CRS # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] if hasattr(riods, 'crs') and riods.crs: # CRS is a dict-like object specific to rasterio @@ -223,14 +266,11 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # Is the TIF tiled? (bool) # We cast it to an int for netCDF compatibility attrs['is_tiled'] = np.uint8(riods.is_tiled) - if hasattr(riods, 'transform'): - # Affine transformation matrix (tuple of floats) - # Describes coefficients mapping pixel coordinates to CRS - attrs['transform'] = tuple(riods.transform) if hasattr(riods, 'nodatavals'): # The nodata values for the raster bands - attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.nodatavals]) + attrs['nodatavals'] = tuple( + np.nan if nodataval is None else nodataval + for nodataval in riods.nodatavals) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} @@ -242,16 +282,17 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise - if isinstance(v, (list, np.ndarray)) and len(v) == riods.count: + if (isinstance(v, (list, np.ndarray)) and + len(v) == riods.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v - data = indexing.LazilyIndexedArray(RasterioArrayWrapper(riods)) + data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) - if cache and (chunks is None): + if cache and chunks is None: data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=('band', 'y', 'x'), @@ -273,6 +314,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=lock) # Make the file closeable - result._file_obj = riods + result._file_obj = manager return result diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index dba2e5672a2..b009342efb6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,20 +1,20 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import functools +from __future__ import absolute_import, division, print_function + +import warnings +from distutils.version import LooseVersion from io import BytesIO import numpy as np -import warnings from .. 
import Variable -from ..core.pycompat import iteritems, OrderedDict, basestring -from ..core.utils import (Frozen, FrozenOrderedDict) from ..core.indexing import NumpyIndexingAdapter - -from .common import WritableCFDataStore, DataStorePickleMixin, BackendArray -from .netcdf3 import (is_valid_nc3_name, encode_nc3_attr_value, - encode_nc3_variable) +from ..core.pycompat import OrderedDict, basestring, iteritems +from ..core.utils import Frozen, FrozenOrderedDict +from .common import BackendArray, WritableCFDataStore +from .locks import get_write_lock +from .file_manager import CachingFileManager, DummyFileManager +from .netcdf3 import ( + encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) def _decode_string(s): @@ -41,19 +41,26 @@ def __init__(self, variable_name, datastore): str(array.dtype.itemsize)) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name].data def __getitem__(self, key): - with self.datastore.ensure_open(autoclose=True): - data = NumpyIndexingAdapter(self.get_array())[key] - # Copy data if the source file is mmapped. - # This makes things consistent - # with the netCDF4 library by ensuring - # we can safely read arrays even - # after closing associated files. - copy = self.datastore.ds.use_mmap - return np.array(data, dtype=self.dtype, copy=copy) + data = NumpyIndexingAdapter(self.get_array())[key] + # Copy data if the source file is mmapped. This makes things consistent + # with the netCDF4 library by ensuring we can safely read arrays even + # after closing associated files. + copy = self.datastore.ds.use_mmap + return np.array(data, dtype=self.dtype, copy=copy) + + def __setitem__(self, key, value): + data = self.datastore.ds.variables[self.variable_name] + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise def _open_scipy_netcdf(filename, mode, mmap, version): @@ -95,7 +102,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise -class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): +class ScipyDataStore(WritableCFDataStore): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a @@ -105,11 +112,12 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False): + mmap=None, lock=None): import scipy import scipy.io - if mode != 'r' and scipy.__version__ < '0.13': # pragma: no cover + if (mode != 'r' and + scipy.__version__ < LooseVersion('0.13')): # pragma: no cover warnings.warn('scipy %s detected; ' 'the minimal recommended version is 0.13. 
' 'Older version of this library do not reliably ' @@ -128,34 +136,38 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - opener = functools.partial(_open_scipy_netcdf, - filename=filename_or_obj, - mode=mode, mmap=mmap, version=version) - self.ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if (lock is None and mode != 'r' and + isinstance(filename_or_obj, basestring)): + lock = get_write_lock(filename_or_obj) + + if isinstance(filename_or_obj, basestring): + manager = CachingFileManager( + _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock, + kwargs=dict(mmap=mmap, version=version)) + else: + scipy_dataset = _open_scipy_netcdf( + filename_or_obj, mode=mode, mmap=mmap, version=version) + manager = DummyFileManager(scipy_dataset) + + self._manager = manager - super(ScipyDataStore, self).__init__(writer) + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + return Variable(var.dimensions, ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes)) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(_decode_attrs(self.ds._attributes)) + return Frozen(_decode_attrs(self.ds._attributes)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -164,22 +176,20 @@ def get_encoding(self): return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if name in self.ds.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, dim_length) + if name in self.ds.dimensions: + raise ValueError('%s does not support modifying dimensions' + % type(self).__name__) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, dim_length) def _validate_attr_key(self, key): if not is_valid_nc3_name(key): raise ValueError("Not a valid attribute name") def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self._validate_attr_key(key) - value = encode_nc3_attr_value(value) - setattr(self.ds, key, value) + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) def encode_variable(self, variable): variable = encode_nc3_variable(variable) @@ -188,8 +198,9 @@ def encode_variable(self, variable): def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): if check_encoding and variable.encoding: - raise ValueError('unexpected encoding for scipy backend: %r' - % list(variable.encoding)) + if variable.encoding != {'_FillValue': None}: + raise ValueError('unexpected encoding for scipy backend: %r' + % list(variable.encoding)) data = variable.data # nb. 
this still creates a numpy array in all memory, even though we @@ -201,25 +212,13 @@ def prepare_variable(self, name, variable, check_encoding=False, for k, v in iteritems(variable.attrs): self._validate_attr_key(k) setattr(scipy_var, k, v) - return scipy_var, data + + target = ScipyArrayWrapper(name, self) + + return target, data def sync(self): - with self.ensure_open(autoclose=True): - super(ScipyDataStore, self).sync() - self.ds.flush() + self.ds.sync() def close(self): - self.ds.close() - self._isopen = False - - def __exit__(self, type, value, tb): - self.close() - - def __setstate__(self, state): - filename = state['_opener'].keywords['filename'] - if hasattr(filename, 'seek'): - # it's a file-like object - # seek to the start of the file so scipy can read it - filename.seek(0) - super(ScipyDataStore, self).__setstate__(state) - self._isopen = True + self._manager.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 02753f6cca9..06fe7f04e4f 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,18 +1,14 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from itertools import product -from base64 import b64encode +from __future__ import absolute_import, division, print_function + +from distutils.version import LooseVersion import numpy as np -from .. import coding -from .. import Variable +from .. import Variable, coding, conventions from ..core import indexing +from ..core.pycompat import OrderedDict, integer_types, iteritems from ..core.utils import FrozenOrderedDict, HiddenKeyDict -from ..core.pycompat import iteritems, OrderedDict, integer_types -from .common import AbstractWritableDataStore, BackendArray, ArrayWriter -from .. import conventions +from .common import AbstractWritableDataStore, ArrayWriter, BackendArray # need some special secret attributes to tell us the dimensions _DIMENSION_KEY = '_ARRAY_DIMENSIONS' @@ -27,47 +23,11 @@ def _encode_zarr_attr_value(value): # this checks if it's a scalar number elif isinstance(value, np.generic): encoded = value.item() - # np.string_('X').item() returns a type `bytes` - # zarr still doesn't like that - if type(encoded) is bytes: - encoded = b64encode(encoded) else: encoded = value return encoded -def _ensure_valid_fill_value(value, dtype): - if dtype.type == np.string_ and type(value) == bytes: - valid = b64encode(value) - else: - valid = value - return _encode_zarr_attr_value(valid) - - -def _replace_slices_with_arrays(key, shape): - """Replace slice objects in vindex with equivalent ndarray objects.""" - num_slices = sum(1 for k in key if isinstance(k, slice)) - ndims = [k.ndim for k in key if isinstance(k, np.ndarray)] - array_subspace_size = max(ndims) if ndims else 0 - assert len(key) == len(shape) - new_key = [] - slice_count = 0 - for k, size in zip(key, shape): - if isinstance(k, slice): - # the slice subspace always appears after the ndarray subspace - array = np.arange(*k.indices(size)) - sl = [np.newaxis] * len(shape) - sl[array_subspace_size + slice_count] = slice(None) - k = array[tuple(sl)] - slice_count += 1 - else: - assert isinstance(k, np.ndarray) - k = k[(slice(None),) * array_subspace_size + - (np.newaxis,) * num_slices] - new_key.append(k) - return tuple(new_key) - - class ZarrArrayWrapper(BackendArray): def __init__(self, variable_name, datastore): self.datastore = datastore @@ -87,8 +47,8 @@ def __getitem__(self, key): if isinstance(key, indexing.BasicIndexer): return array[key.tuple] elif 
isinstance(key, indexing.VectorizedIndexer): - return array.vindex[_replace_slices_with_arrays(key.tuple, - self.shape)] + return array.vindex[indexing._arrayize_vectorized_indexer( + key.tuple, self.shape).tuple] else: assert isinstance(key, indexing.OuterIndexer) return array.oindex[key.tuple] @@ -117,24 +77,18 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): # while dask chunks can be variable sized # http://dask.pydata.org/en/latest/array-design.html#chunks if var_chunks and enc_chunks is None: - all_var_chunks = list(product(*var_chunks)) - first_var_chunk = all_var_chunks[0] - # all but the last chunk have to match exactly - for this_chunk in all_var_chunks[:-1]: - if this_chunk != first_var_chunk: - raise ValueError( - "Zarr requires uniform chunk sizes excpet for final chunk." - " Variable %r has incompatible chunks. Consider " - "rechunking using `chunk()`." % (var_chunks,)) - # last chunk is allowed to be smaller - last_var_chunk = all_var_chunks[-1] - for len_first, len_last in zip(first_var_chunk, last_var_chunk): - if len_last > len_first: - raise ValueError( - "Final chunk of Zarr array must be smaller than first. " - "Variable %r has incompatible chunks. Consider rechunking " - "using `chunk()`." % var_chunks) - return first_var_chunk + if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): + raise ValueError( + "Zarr requires uniform chunk sizes except for final chunk." + " Variable dask chunks %r are incompatible. Consider " + "rechunking using `chunk()`." % (var_chunks,)) + if any((chunks[0] < chunks[-1]) for chunks in var_chunks): + raise ValueError( + "Final chunk of Zarr array must be the same size or smaller " + "than the first. Variable Dask chunks %r are incompatible. " + "Consider rechunking using `chunk()`." % var_chunks) + # return the first chunk for each dimension + return tuple(chunk[0] for chunk in var_chunks) # from here on, we are dealing with user-specified chunks in encoding # zarr allows chunks to be an integer, in which case it uses the same chunk @@ -148,9 +102,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): enc_chunks_tuple = tuple(enc_chunks) if len(enc_chunks_tuple) != ndim: - raise ValueError("zarr chunks tuple %r must have same length as " - "variable.ndim %g" % - (enc_chunks_tuple, ndim)) + # throw away encoding chunks, start over + return _determine_zarr_chunks(None, var_chunks, ndim) for x in enc_chunks_tuple: if not isinstance(x, int): @@ -173,7 +126,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): # threads if var_chunks and enc_chunks_tuple: for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks): - for dchunk in dchunks: + for dchunk in dchunks[:-1]: if dchunk % zchunk: raise NotImplementedError( "Specified zarr chunks %r would overlap multiple dask " @@ -181,6 +134,13 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): " Consider rechunking the data using " "`chunk()` or specifying different chunks in encoding." % (enc_chunks_tuple, var_chunks)) + if dchunks[-1] > zchunk: + raise ValueError( + "Final chunk of Zarr array must be the same size or " + "smaller than the first. The specified Zarr chunk " + "encoding is %r, but %r in variable Dask chunks %r is " + "incompatible. Consider rechunking using `chunk()`." + % (enc_chunks_tuple, dchunks, var_chunks)) return enc_chunks_tuple raise AssertionError( @@ -247,22 +207,15 @@ def encode_zarr_variable(var, needs_copy=True, name=None): A variable which has been encoded as described above. 
""" - if var.dtype.kind == 'O': - raise NotImplementedError("Variable `%s` is an object. Zarr " - "store can't yet encode objects." % name) - - for coder in [coding.times.CFDatetimeCoder(), - coding.times.CFTimedeltaCoder(), - coding.variables.CFScaleOffsetCoder(), - coding.variables.CFMaskCoder(), - coding.variables.UnsignedIntegerCoder()]: - var = coder.encode(var, name=name) - - var = conventions.maybe_encode_nonstring_dtype(var, name=name) - var = conventions.maybe_default_fill_value(var) - var = conventions.maybe_encode_bools(var) - var = conventions.ensure_dtype_not_object(var, name=name) - var = conventions.maybe_encode_string_dtype(var, name=name) + var = conventions.encode_cf_variable(var, name=name) + + # zarr allows unicode, but not variable-length strings, so it's both + # simpler and more compact to always encode as UTF-8 explicitly. + # TODO: allow toggling this explicitly via dtype in encoding. + coder = coding.strings.EncodedStringCoder(allows_unicode=False) + var = coder.encode(var, name=name) + var = coding.strings.ensure_fixed_length_bytes(var) + return var @@ -271,31 +224,28 @@ class ZarrStore(AbstractWritableDataStore): """ @classmethod - def open_group(cls, store, mode='r', synchronizer=None, group=None, - writer=None): + def open_group(cls, store, mode='r', synchronizer=None, group=None): import zarr + min_zarr = '2.2' + + if LooseVersion(zarr.__version__) < min_zarr: # pragma: no cover + raise NotImplementedError("Zarr version %s or greater is " + "required by xarray. See zarr " + "installation " + "http://zarr.readthedocs.io/en/stable/" + "#installation" % min_zarr) zarr_group = zarr.open_group(store=store, mode=mode, synchronizer=synchronizer, path=group) - return cls(zarr_group, writer=writer) + return cls(zarr_group) - def __init__(self, zarr_group, writer=None): + def __init__(self, zarr_group): self.ds = zarr_group self._read_only = self.ds.read_only self._synchronizer = self.ds.synchronizer self._group = self.ds.path - if writer is None: - # by default, we should not need a lock for writing zarr because - # we do not (yet) allow overlapping chunks during write - zarr_writer = ArrayWriter(lock=False) - else: - zarr_writer = writer - - # do we need to define attributes for all of the opener keyword args? - super(ZarrStore, self).__init__(zarr_writer) - def open_store_variable(self, name, zarr_array): - data = indexing.LazilyIndexedArray(ZarrArrayWrapper(name, self)) + data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, _DIMENSION_KEY) attributes = OrderedDict(attributes) @@ -357,8 +307,9 @@ def prepare_variable(self, name, variable, check_encoding=False, dtype = variable.dtype shape = variable.shape - fill_value = _ensure_valid_fill_value(attrs.pop('_FillValue', None), - dtype) + fill_value = attrs.pop('_FillValue', None) + if variable.encoding == {'_FillValue': None} and fill_value is None: + variable.encoding = {} encoding = _extract_zarr_variable_encoding( variable, raise_on_invalid=check_encoding) @@ -379,6 +330,9 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) + def sync(self): + pass + def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, decode_cf=True, mask_and_scale=True, decode_times=True, @@ -482,7 +436,7 @@ def maybe_chunk(name, var): if (var.ndim > 0) and (chunks is not None): # does this cause any data to be read? 
token2 = tokenize(name, var._data) - name2 = 'zarr-%s-%s' % (name, token2) + name2 = 'zarr-%s' % token2 return var.chunk(chunks, name=name2, lock=None) else: return var diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py new file mode 100644 index 00000000000..83e8c7a7e4b --- /dev/null +++ b/xarray/coding/cftime_offsets.py @@ -0,0 +1,735 @@ +"""Time offset classes for use with cftime.datetime objects""" +# The offset classes and mechanisms for generating time ranges defined in +# this module were copied/adapted from those defined in pandas. See in +# particular the objects and methods defined in pandas.tseries.offsets +# and pandas.core.indexes.datetimes. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re +from datetime import timedelta +from functools import partial + +import numpy as np + +from ..core.pycompat import basestring +from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso +from .times import format_cftime_datetime + + +def get_date_type(calendar): + """Return the cftime date type for a given calendar name.""" + try: + import cftime + except ImportError: + raise ImportError( + 'cftime is required for dates with non-standard calendars') + else: + calendars = { + 'noleap': cftime.DatetimeNoLeap, + '360_day': cftime.Datetime360Day, + '365_day': cftime.DatetimeNoLeap, + '366_day': cftime.DatetimeAllLeap, + 'gregorian': cftime.DatetimeGregorian, + 'proleptic_gregorian': cftime.DatetimeProlepticGregorian, + 'julian': cftime.DatetimeJulian, + 'all_leap': cftime.DatetimeAllLeap, + 'standard': cftime.DatetimeProlepticGregorian + } + return calendars[calendar] + + +class BaseCFTimeOffset(object): + _freq = None + + def __init__(self, n=1): + if not isinstance(n, int): + raise TypeError( + "The provided multiple 'n' must be an integer. 
" + "Instead a value of type {!r} was provided.".format(type(n))) + self.n = n + + def rule_code(self): + return self._freq + + def __eq__(self, other): + return self.n == other.n and self.rule_code() == other.rule_code() + + def __ne__(self, other): + return not self == other + + def __add__(self, other): + return self.__apply__(other) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract a cftime.datetime ' + 'from a time offset.') + elif type(other) == type(self): + return type(self)(self.n - other.n) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n) + + def __neg__(self): + return self * -1 + + def __rmul__(self, other): + return self.__mul__(other) + + def __radd__(self, other): + return self.__add__(other) + + def __rsub__(self, other): + if isinstance(other, BaseCFTimeOffset) and type(self) != type(other): + raise TypeError('Cannot subtract cftime offsets of differing ' + 'types') + return -self + other + + def __apply__(self): + return NotImplemented + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + test_date = (self + date) - self + return date == test_date + + def rollforward(self, date): + if self.onOffset(date): + return date + else: + return date + type(self)() + + def rollback(self, date): + if self.onOffset(date): + return date + else: + return date - type(self)() + + def __str__(self): + return '<{}: n={}>'.format(type(self).__name__, self.n) + + def __repr__(self): + return str(self) + + +def _days_in_month(date): + """The number of days in the month of the given date""" + if date.month == 12: + reference = type(date)(date.year + 1, 1, 1) + else: + reference = type(date)(date.year, date.month + 1, 1) + return (reference - timedelta(days=1)).day + + +def _adjust_n_months(other_day, n, reference_day): + """Adjust the number of times a monthly offset is applied based + on the day of a given date, and the reference day provided. + """ + if n > 0 and other_day < reference_day: + n = n - 1 + elif n <= 0 and other_day > reference_day: + n = n + 1 + return n + + +def _adjust_n_years(other, n, month, reference_day): + """Adjust the number of times an annual offset is applied based on + another date, and the reference day provided""" + if n > 0: + if other.month < month or (other.month == month and + other.day < reference_day): + n -= 1 + else: + if other.month > month or (other.month == month and + other.day > reference_day): + n += 1 + return n + + +def _shift_months(date, months, day_option='start'): + """Shift the date to a month start or end a given number of months away. 
+ """ + delta_year = (date.month + months) // 12 + month = (date.month + months) % 12 + + if month == 0: + month = 12 + delta_year = delta_year - 1 + year = date.year + delta_year + + if day_option == 'start': + day = 1 + elif day_option == 'end': + reference = type(date)(year, month, 1) + day = _days_in_month(reference) + else: + raise ValueError(day_option) + return date.replace(year=year, month=month, day=day) + + +class MonthBegin(BaseCFTimeOffset): + _freq = 'MS' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, 1) + return _shift_months(other, n, 'start') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == 1 + + +class MonthEnd(BaseCFTimeOffset): + _freq = 'M' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, _days_in_month(other)) + return _shift_months(other, n, 'end') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) + + +_MONTH_ABBREVIATIONS = { + 1: 'JAN', + 2: 'FEB', + 3: 'MAR', + 4: 'APR', + 5: 'MAY', + 6: 'JUN', + 7: 'JUL', + 8: 'AUG', + 9: 'SEP', + 10: 'OCT', + 11: 'NOV', + 12: 'DEC' +} + + +class YearOffset(BaseCFTimeOffset): + _freq = None + _day_option = None + _default_month = None + + def __init__(self, n=1, month=None): + BaseCFTimeOffset.__init__(self, n) + if month is None: + self.month = self._default_month + else: + self.month = month + if not isinstance(self.month, int): + raise TypeError("'self.month' must be an integer value between 1 " + "and 12. Instead, it was set to a value of " + "{!r}".format(self.month)) + elif not (1 <= self.month <= 12): + raise ValueError("'self.month' must be an integer value between 1 " + "and 12. 
Instead, it was set to a value of " + "{!r}".format(self.month)) + + def __apply__(self, other): + if self._day_option == 'start': + reference_day = 1 + elif self._day_option == 'end': + reference_day = _days_in_month(other) + else: + raise ValueError(self._day_option) + years = _adjust_n_years(other, self.n, self.month, reference_day) + months = years * 12 + (self.month - other.month) + return _shift_months(other, months, self._day_option) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract cftime.datetime from offset.') + elif type(other) == type(self) and other.month == self.month: + return type(self)(self.n - other.n, month=self.month) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n, month=self.month) + + def rule_code(self): + return '{}-{}'.format(self._freq, _MONTH_ABBREVIATIONS[self.month]) + + def __str__(self): + return '<{}: n={}, month={}>'.format( + type(self).__name__, self.n, self.month) + + +class YearBegin(YearOffset): + _freq = 'AS' + _day_option = 'start' + _default_month = 1 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == 1 and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date + YearBegin(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date - YearBegin(month=self.month) + + +class YearEnd(YearOffset): + _freq = 'A' + _day_option = 'end' + _default_month = 12 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date + YearEnd(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date - YearEnd(month=self.month) + + +class Day(BaseCFTimeOffset): + _freq = 'D' + + def __apply__(self, other): + return other + timedelta(days=self.n) + + +class Hour(BaseCFTimeOffset): + _freq = 'H' + + def __apply__(self, other): + return other + timedelta(hours=self.n) + + +class Minute(BaseCFTimeOffset): + _freq = 'T' + + def __apply__(self, other): + return other + timedelta(minutes=self.n) + + +class Second(BaseCFTimeOffset): + _freq = 'S' + + def __apply__(self, other): + return other + timedelta(seconds=self.n) + + +_FREQUENCIES = { + 'A': YearEnd, + 'AS': YearBegin, + 'Y': YearEnd, + 'YS': YearBegin, + 'M': MonthEnd, + 'MS': MonthBegin, + 'D': Day, + 'H': Hour, + 'T': Minute, + 'min': Minute, + 'S': Second, + 'AS-JAN': partial(YearBegin, month=1), + 'AS-FEB': partial(YearBegin, month=2), + 'AS-MAR': partial(YearBegin, month=3), + 'AS-APR': partial(YearBegin, month=4), + 'AS-MAY': partial(YearBegin, month=5), + 'AS-JUN': partial(YearBegin, month=6), + 'AS-JUL': partial(YearBegin, month=7), + 'AS-AUG': partial(YearBegin, month=8), + 'AS-SEP': partial(YearBegin, month=9), + 'AS-OCT': partial(YearBegin, month=10), + 'AS-NOV': partial(YearBegin, month=11), + 'AS-DEC': partial(YearBegin, month=12), + 'A-JAN': 
partial(YearEnd, month=1), + 'A-FEB': partial(YearEnd, month=2), + 'A-MAR': partial(YearEnd, month=3), + 'A-APR': partial(YearEnd, month=4), + 'A-MAY': partial(YearEnd, month=5), + 'A-JUN': partial(YearEnd, month=6), + 'A-JUL': partial(YearEnd, month=7), + 'A-AUG': partial(YearEnd, month=8), + 'A-SEP': partial(YearEnd, month=9), + 'A-OCT': partial(YearEnd, month=10), + 'A-NOV': partial(YearEnd, month=11), + 'A-DEC': partial(YearEnd, month=12) +} + + +_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys()) +_PATTERN = '^((?P<multiple>\d+)|())(?P<freq>({0}))$'.format( + _FREQUENCY_CONDITION) + + +def to_offset(freq): + """Convert a frequency string to the appropriate subclass of + BaseCFTimeOffset.""" + if isinstance(freq, BaseCFTimeOffset): + return freq + else: + try: + freq_data = re.match(_PATTERN, freq).groupdict() + except AttributeError: + raise ValueError('Invalid frequency string provided') + + freq = freq_data['freq'] + multiples = freq_data['multiple'] + if multiples is None: + multiples = 1 + else: + multiples = int(multiples) + + return _FREQUENCIES[freq](n=multiples) + + +def to_cftime_datetime(date_str_or_date, calendar=None): + import cftime + + if isinstance(date_str_or_date, basestring): + if calendar is None: + raise ValueError( + 'If converting a string to a cftime.datetime object, ' + 'a calendar type must be provided') + date, _ = _parse_iso8601_with_reso(get_date_type(calendar), + date_str_or_date) + return date + elif isinstance(date_str_or_date, cftime.datetime): + return date_str_or_date + else: + raise TypeError("date_str_or_date must be a string or a " + 'subclass of cftime.datetime. Instead got ' + '{!r}.'.format(date_str_or_date)) + + +def normalize_date(date): + """Round datetime down to midnight.""" + return date.replace(hour=0, minute=0, second=0, microsecond=0) + + +def _maybe_normalize_date(date, normalize): + """Round datetime down to midnight if normalize is True.""" + if normalize: + return normalize_date(date) + else: + return date + + +def _generate_linear_range(start, end, periods): + """Generate an equally-spaced sequence of cftime.datetime objects between + and including two dates (whose length equals the number of periods).""" + import cftime + + total_seconds = (end - start).total_seconds() + values = np.linspace(0., total_seconds, periods, endpoint=True) + units = 'seconds since {}'.format(format_cftime_datetime(start)) + calendar = start.calendar + return cftime.num2date(values, units=units, calendar=calendar, + only_use_cftime_datetimes=True) + + +def _generate_range(start, end, periods, offset): + """Generate a regular range of cftime.datetime objects with a + given time offset. + + Adapted from pandas.tseries.offsets.generate_range.
+ + Parameters + ---------- + start : cftime.datetime, or None + Start of range + end : cftime.datetime, or None + End of range + periods : int, or None + Number of elements in the sequence + offset : BaseCFTimeOffset + An offset class designed for working with cftime.datetime objects + + Returns + ------- + A generator object + """ + if start: + start = offset.rollforward(start) + + if end: + end = offset.rollback(end) + + if periods is None and end < start: + end = None + periods = 0 + + if end is None: + end = start + (periods - 1) * offset + + if start is None: + start = end - (periods - 1) * offset + + current = start + if offset.n >= 0: + while current <= end: + yield current + + next_date = current + offset + if next_date <= current: + raise ValueError('Offset {offset} did not increment date' + .format(offset=offset)) + current = next_date + else: + while current >= end: + yield current + + next_date = current + offset + if next_date >= current: + raise ValueError('Offset {offset} did not decrement date' + .format(offset=offset)) + current = next_date + + +def _count_not_none(*args): + """Compute the number of non-None arguments.""" + return sum([arg is not None for arg in args]) + + +def cftime_range(start=None, end=None, periods=None, freq='D', + tz=None, normalize=False, name=None, closed=None, + calendar='standard'): + """Return a fixed frequency CFTimeIndex. + + Parameters + ---------- + start : str or cftime.datetime, optional + Left bound for generating dates. + end : str or cftime.datetime, optional + Right bound for generating dates. + periods : integer, optional + Number of periods to generate. + freq : str, default 'D', BaseCFTimeOffset, or None + Frequency strings can have multiples, e.g. '5H'. + normalize : bool, default False + Normalize start/end dates to midnight before generating date range. + name : str, default None + Name of the resulting index + closed : {None, 'left', 'right'}, optional + Make the interval closed with respect to the given frequency to the + 'left', 'right', or both sides (None, the default). + calendar : str + Calendar type for the datetimes (default 'standard'). + + Returns + ------- + CFTimeIndex + + Notes + ----- + + This function is an analog of ``pandas.date_range`` for use in generating + sequences of ``cftime.datetime`` objects. It supports most of the + features of ``pandas.date_range`` (e.g. specifying how the index is + ``closed`` on either side, or whether or not to ``normalize`` the start and + end bounds); however, there are some notable exceptions: + + - You cannot specify a ``tz`` (time zone) argument. + - Start or end dates specified as partial-datetime strings must use the + `ISO-8601 format `_. + - It supports many, but not all, frequencies supported by + ``pandas.date_range``. For example it does not currently support any of + the business-related, semi-monthly, or sub-second frequencies. + - Compound sub-monthly frequencies are not supported, e.g. '1H1min', as + these can easily be written in terms of the finest common resolution, + e.g. '61min'. + + Valid simple frequency strings for use with ``cftime``-calendars include + any multiples of the following. 
+ + +--------+-----------------------+ + | Alias | Description | + +========+=======================+ + | A, Y | Year-end frequency | + +--------+-----------------------+ + | AS, YS | Year-start frequency | + +--------+-----------------------+ + | M | Month-end frequency | + +--------+-----------------------+ + | MS | Month-start frequency | + +--------+-----------------------+ + | D | Day frequency | + +--------+-----------------------+ + | H | Hour frequency | + +--------+-----------------------+ + | T, min | Minute frequency | + +--------+-----------------------+ + | S | Second frequency | + +--------+-----------------------+ + + Any multiples of the following anchored offsets are also supported. + + +----------+-------------------------------------------------------------------+ + | Alias | Description | + +==========+===================================================================+ + | A(S)-JAN | Annual frequency, anchored at the end (or beginning) of January | + +----------+-------------------------------------------------------------------+ + | A(S)-FEB | Annual frequency, anchored at the end (or beginning) of February | + +----------+-------------------------------------------------------------------+ + | A(S)-MAR | Annual frequency, anchored at the end (or beginning) of March | + +----------+-------------------------------------------------------------------+ + | A(S)-APR | Annual frequency, anchored at the end (or beginning) of April | + +----------+-------------------------------------------------------------------+ + | A(S)-MAY | Annual frequency, anchored at the end (or beginning) of May | + +----------+-------------------------------------------------------------------+ + | A(S)-JUN | Annual frequency, anchored at the end (or beginning) of June | + +----------+-------------------------------------------------------------------+ + | A(S)-JUL | Annual frequency, anchored at the end (or beginning) of July | + +----------+-------------------------------------------------------------------+ + | A(S)-AUG | Annual frequency, anchored at the end (or beginning) of August | + +----------+-------------------------------------------------------------------+ + | A(S)-SEP | Annual frequency, anchored at the end (or beginning) of September | + +----------+-------------------------------------------------------------------+ + | A(S)-OCT | Annual frequency, anchored at the end (or beginning) of October | + +----------+-------------------------------------------------------------------+ + | A(S)-NOV | Annual frequency, anchored at the end (or beginning) of November | + +----------+-------------------------------------------------------------------+ + | A(S)-DEC | Annual frequency, anchored at the end (or beginning) of December | + +----------+-------------------------------------------------------------------+ + + Finally, the following calendar aliases are supported. 
+ + +--------------------------------+---------------------------------------+ + | Alias | Date type | + +================================+=======================================+ + | standard, proleptic_gregorian | ``cftime.DatetimeProlepticGregorian`` | + +--------------------------------+---------------------------------------+ + | gregorian | ``cftime.DatetimeGregorian`` | + +--------------------------------+---------------------------------------+ + | noleap, 365_day | ``cftime.DatetimeNoLeap`` | + +--------------------------------+---------------------------------------+ + | all_leap, 366_day | ``cftime.DatetimeAllLeap`` | + +--------------------------------+---------------------------------------+ + | 360_day | ``cftime.Datetime360Day`` | + +--------------------------------+---------------------------------------+ + | julian | ``cftime.DatetimeJulian`` | + +--------------------------------+---------------------------------------+ + + Examples + -------- + + This function returns a ``CFTimeIndex``, populated with ``cftime.datetime`` + objects associated with the specified calendar type, e.g. + + >>> xr.cftime_range(start='2000', periods=6, freq='2MS', calendar='noleap') + CFTimeIndex([2000-01-01 00:00:00, 2000-03-01 00:00:00, 2000-05-01 00:00:00, + 2000-07-01 00:00:00, 2000-09-01 00:00:00, 2000-11-01 00:00:00], + dtype='object') + + As in the standard pandas function, three of the ``start``, ``end``, + ``periods``, or ``freq`` arguments must be specified at a given time, with + the other set to ``None``. See the `pandas documentation + `_ + for more examples of the behavior of ``date_range`` with each of the + parameters. + + See Also + -------- + pandas.date_range + """ # noqa: E501 + # Adapted from pandas.core.indexes.datetimes._generate_range. + if _count_not_none(start, end, periods, freq) != 3: + raise ValueError( + "Of the arguments 'start', 'end', 'periods', and 'freq', three " + "must be specified at a time.") + + if start is not None: + start = to_cftime_datetime(start, calendar) + start = _maybe_normalize_date(start, normalize) + if end is not None: + end = to_cftime_datetime(end, calendar) + end = _maybe_normalize_date(end, normalize) + + if freq is None: + dates = _generate_linear_range(start, end, periods) + else: + offset = to_offset(freq) + dates = np.array(list(_generate_range(start, end, periods, offset))) + + left_closed = False + right_closed = False + + if closed is None: + left_closed = True + right_closed = True + elif closed == 'left': + left_closed = True + elif closed == 'right': + right_closed = True + else: + raise ValueError("Closed must be either 'left', 'right' or None") + + if (not left_closed and len(dates) and + start is not None and dates[0] == start): + dates = dates[1:] + if (not right_closed and len(dates) and + end is not None and dates[-1] == end): + dates = dates[:-1] + + return CFTimeIndex(dates, name=name) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py new file mode 100644 index 00000000000..2ce996b2bd2 --- /dev/null +++ b/xarray/coding/cftimeindex.py @@ -0,0 +1,461 @@ +"""DatetimeIndex analog for cftime.datetime objects""" +# The pandas.Index subclass defined here was copied and adapted for +# use with cftime.datetime objects based on the source code defining +# pandas.DatetimeIndex. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. 
+ +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import absolute_import + +import re +import warnings +from datetime import timedelta + +import numpy as np +import pandas as pd + +from xarray.core import pycompat +from xarray.core.utils import is_scalar + +from .times import cftime_to_nptime, infer_calendar_name, _STANDARD_CALENDARS + + +def named(name, pattern): + return '(?P<' + name + '>' + pattern + ')' + + +def optional(x): + return '(?:' + x + ')?' + + +def trailing_optional(xs): + if not xs: + return '' + return xs[0] + optional(trailing_optional(xs[1:])) + + +def build_pattern(date_sep='\-', datetime_sep='T', time_sep='\:'): + pieces = [(None, 'year', '\d{4}'), + (date_sep, 'month', '\d{2}'), + (date_sep, 'day', '\d{2}'), + (datetime_sep, 'hour', '\d{2}'), + (time_sep, 'minute', '\d{2}'), + (time_sep, 'second', '\d{2}')] + pattern_list = [] + for sep, name, sub_pattern in pieces: + pattern_list.append((sep if sep else '') + named(name, sub_pattern)) + # TODO: allow timezone offsets? + return '^' + trailing_optional(pattern_list) + '$' + + +_BASIC_PATTERN = build_pattern(date_sep='', time_sep='') +_EXTENDED_PATTERN = build_pattern() +_PATTERNS = [_BASIC_PATTERN, _EXTENDED_PATTERN] + + +def parse_iso8601(datetime_string): + for pattern in _PATTERNS: + match = re.match(pattern, datetime_string) + if match: + return match.groupdict() + raise ValueError('no ISO-8601 match for string: %s' % datetime_string) + + +def _parse_iso8601_with_reso(date_type, timestr): + default = date_type(1, 1, 1) + result = parse_iso8601(timestr) + replace = {} + + for attr in ['year', 'month', 'day', 'hour', 'minute', 'second']: + value = result.get(attr, None) + if value is not None: + # Note ISO8601 conventions allow for fractional seconds. + # TODO: Consider adding support for sub-second resolution? 
+ replace[attr] = int(value) + resolution = attr + + return default.replace(**replace), resolution + + +def _parsed_string_to_bounds(date_type, resolution, parsed): + """Generalization of + pandas.tseries.index.DatetimeIndex._parsed_string_to_bounds + for use with non-standard calendars and cftime.datetime + objects. + """ + if resolution == 'year': + return (date_type(parsed.year, 1, 1), + date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1)) + elif resolution == 'month': + if parsed.month == 12: + end = date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1) + else: + end = (date_type(parsed.year, parsed.month + 1, 1) - + timedelta(microseconds=1)) + return date_type(parsed.year, parsed.month, 1), end + elif resolution == 'day': + start = date_type(parsed.year, parsed.month, parsed.day) + return start, start + timedelta(days=1, microseconds=-1) + elif resolution == 'hour': + start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour) + return start, start + timedelta(hours=1, microseconds=-1) + elif resolution == 'minute': + start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour, + parsed.minute) + return start, start + timedelta(minutes=1, microseconds=-1) + elif resolution == 'second': + start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour, + parsed.minute, parsed.second) + return start, start + timedelta(seconds=1, microseconds=-1) + else: + raise KeyError + + +def get_date_field(datetimes, field): + """Adapted from pandas.tslib.get_date_field""" + return np.array([getattr(date, field) for date in datetimes]) + + +def _field_accessor(name, docstring=None): + """Adapted from pandas.tseries.index._field_accessor""" + def f(self): + return get_date_field(self._data, name) + + f.__name__ = name + f.__doc__ = docstring + return property(f) + + +def get_date_type(self): + if self._data.size: + return type(self._data[0]) + else: + return None + + +def assert_all_valid_date_type(data): + import cftime + + if data.size: + sample = data[0] + date_type = type(sample) + if not isinstance(sample, cftime.datetime): + raise TypeError( + 'CFTimeIndex requires cftime.datetime ' + 'objects. Got object of {}.'.format(date_type)) + if not all(isinstance(value, date_type) for value in data): + raise TypeError( + 'CFTimeIndex requires using datetime ' + 'objects of all the same type. Got\n{}.'.format(data)) + + +class CFTimeIndex(pd.Index): + """Custom Index for working with CF calendars and dates + + All elements of a CFTimeIndex must be cftime.datetime objects. 
+ + Parameters + ---------- + data : array or CFTimeIndex + Sequence of cftime.datetime objects to use in index + name : str, default None + Name of the resulting index + + See Also + -------- + cftime_range + """ + year = _field_accessor('year', 'The year of the datetime') + month = _field_accessor('month', 'The month of the datetime') + day = _field_accessor('day', 'The days of the datetime') + hour = _field_accessor('hour', 'The hours of the datetime') + minute = _field_accessor('minute', 'The minutes of the datetime') + second = _field_accessor('second', 'The seconds of the datetime') + microsecond = _field_accessor('microsecond', + 'The microseconds of the datetime') + date_type = property(get_date_type) + + def __new__(cls, data, name=None): + if name is None and hasattr(data, 'name'): + name = data.name + + result = object.__new__(cls) + result._data = np.array(data, dtype='O') + assert_all_valid_date_type(result._data) + result.name = name + return result + + def _partial_date_slice(self, resolution, parsed): + """Adapted from + pandas.tseries.index.DatetimeIndex._partial_date_slice + + Note that when using a CFTimeIndex, if a partial-date selection + returns a single element, it will never be converted to a scalar + coordinate; this is in slight contrast to the behavior when using + a DatetimeIndex, which sometimes will return a DataArray with a scalar + coordinate depending on the resolution of the datetimes used in + defining the index. For example: + + >>> from cftime import DatetimeNoLeap + >>> import pandas as pd + >>> import xarray as xr + >>> da = xr.DataArray([1, 2], + coords=[[DatetimeNoLeap(2001, 1, 1), + DatetimeNoLeap(2001, 2, 1)]], + dims=['time']) + >>> da.sel(time='2001-01-01') + + array([1]) + Coordinates: + * time (time) object 2001-01-01 00:00:00 + >>> da = xr.DataArray([1, 2], + coords=[[pd.Timestamp(2001, 1, 1), + pd.Timestamp(2001, 2, 1)]], + dims=['time']) + >>> da.sel(time='2001-01-01') + + array(1) + Coordinates: + time datetime64[ns] 2001-01-01 + >>> da = xr.DataArray([1, 2], + coords=[[pd.Timestamp(2001, 1, 1, 1), + pd.Timestamp(2001, 2, 1)]], + dims=['time']) + >>> da.sel(time='2001-01-01') + + array([1]) + Coordinates: + * time (time) datetime64[ns] 2001-01-01T01:00:00 + """ + start, end = _parsed_string_to_bounds(self.date_type, resolution, + parsed) + lhs_mask = (self._data >= start) + rhs_mask = (self._data <= end) + return (lhs_mask & rhs_mask).nonzero()[0] + + def _get_string_slice(self, key): + """Adapted from pandas.tseries.index.DatetimeIndex._get_string_slice""" + parsed, resolution = _parse_iso8601_with_reso(self.date_type, key) + loc = self._partial_date_slice(resolution, parsed) + return loc + + def get_loc(self, key, method=None, tolerance=None): + """Adapted from pandas.tseries.index.DatetimeIndex.get_loc""" + if isinstance(key, pycompat.basestring): + return self._get_string_slice(key) + else: + return pd.Index.get_loc(self, key, method=method, + tolerance=tolerance) + + def _maybe_cast_slice_bound(self, label, side, kind): + """Adapted from + pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound""" + if isinstance(label, pycompat.basestring): + parsed, resolution = _parse_iso8601_with_reso(self.date_type, + label) + start, end = _parsed_string_to_bounds(self.date_type, resolution, + parsed) + if self.is_monotonic_decreasing and len(self) > 1: + return end if side == 'left' else start + return start if side == 'left' else end + else: + return label + + # TODO: Add ability to use integer range outside of iloc? + # e.g. series[1:5]. 
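A minimal usage sketch of the index behavior described in the docstrings above (illustrative only, not part of the patch; assumes cftime is installed and, on xarray versions where CFTimeIndex is opt-in, xr.set_options(enable_cftimeindex=True)):

    import xarray as xr

    # Month-end dates in a no-leap calendar, built with the new cftime_range helper
    index = xr.cftime_range('2000-01-31', periods=3, freq='M', calendar='noleap')
    index.shift(1, 'M')  # 2000-02-28, 2000-03-31, 2000-04-30 (no leap day in 'noleap')

    # Partial-string selection resolved through CFTimeIndex.get_loc
    da = xr.DataArray([10, 20, 30], coords=[index], dims=['time'])
    da.sel(time='2000-02')  # returns the single element falling in February 2000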
+ def get_value(self, series, key): + """Adapted from pandas.tseries.index.DatetimeIndex.get_value""" + if not isinstance(key, slice): + return series.iloc[self.get_loc(key)] + else: + return series.iloc[self.slice_indexer( + key.start, key.stop, key.step)] + + def __contains__(self, key): + """Adapted from + pandas.tseries.base.DatetimeIndexOpsMixin.__contains__""" + try: + result = self.get_loc(key) + return (is_scalar(result) or type(result) == slice or + (isinstance(result, np.ndarray) and result.size)) + except (KeyError, TypeError, ValueError): + return False + + def contains(self, key): + """Needed for .loc based partial-string indexing""" + return self.__contains__(key) + + def shift(self, n, freq): + """Shift the CFTimeIndex a multiple of the given frequency. + + See the documentation for :py:func:`~xarray.cftime_range` for a + complete listing of valid frequency strings. + + Parameters + ---------- + n : int + Periods to shift by + freq : str or datetime.timedelta + A frequency string or datetime.timedelta object to shift by + + Returns + ------- + CFTimeIndex + + See also + -------- + pandas.DatetimeIndex.shift + + Examples + -------- + >>> index = xr.cftime_range('2000', periods=1, freq='M') + >>> index + CFTimeIndex([2000-01-31 00:00:00], dtype='object') + >>> index.shift(1, 'M') + CFTimeIndex([2000-02-29 00:00:00], dtype='object') + """ + from .cftime_offsets import to_offset + + if not isinstance(n, int): + raise TypeError("'n' must be an int, got {}.".format(n)) + if isinstance(freq, timedelta): + return self + n * freq + elif isinstance(freq, pycompat.basestring): + return self + n * to_offset(freq) + else: + raise TypeError( + "'freq' must be of type " + "str or datetime.timedelta, got {}.".format(freq)) + + def __add__(self, other): + if isinstance(other, pd.TimedeltaIndex): + other = other.to_pytimedelta() + return CFTimeIndex(np.array(self) + other) + + def __radd__(self, other): + if isinstance(other, pd.TimedeltaIndex): + other = other.to_pytimedelta() + return CFTimeIndex(other + np.array(self)) + + def __sub__(self, other): + if isinstance(other, CFTimeIndex): + return pd.TimedeltaIndex(np.array(self) - np.array(other)) + elif isinstance(other, pd.TimedeltaIndex): + return CFTimeIndex(np.array(self) - other.to_pytimedelta()) + else: + return CFTimeIndex(np.array(self) - other) + + def _add_delta(self, deltas): + # To support TimedeltaIndex + CFTimeIndex with older versions of + # pandas. No longer used as of pandas 0.23. + return self + deltas + + def to_datetimeindex(self, unsafe=False): + """If possible, convert this index to a pandas.DatetimeIndex. + + Parameters + ---------- + unsafe : bool + Flag to turn off warning when converting from a CFTimeIndex with + a non-standard calendar to a DatetimeIndex (default ``False``). + + Returns + ------- + pandas.DatetimeIndex + + Raises + ------ + ValueError + If the CFTimeIndex contains dates that are not possible in the + standard calendar or outside the pandas.Timestamp-valid range. + + Warns + ----- + RuntimeWarning + If converting from a non-standard calendar to a DatetimeIndex. + + Warnings + -------- + Note that for non-standard calendars, this will change the calendar + type of the index. In that case the result of this method should be + used with caution. 
+ + Examples + -------- + >>> import xarray as xr + >>> times = xr.cftime_range('2000', periods=2, calendar='gregorian') + >>> times + CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00], dtype='object') + >>> times.to_datetimeindex() + DatetimeIndex(['2000-01-01', '2000-01-02'], dtype='datetime64[ns]', freq=None) + """ # noqa: E501 + nptimes = cftime_to_nptime(self) + calendar = infer_calendar_name(self) + if calendar not in _STANDARD_CALENDARS and not unsafe: + warnings.warn( + 'Converting a CFTimeIndex with dates from a non-standard ' + 'calendar, {!r}, to a pandas.DatetimeIndex, which uses dates ' + 'from the standard calendar. This may lead to subtle errors ' + 'in operations that depend on the length of time between ' + 'dates.'.format(calendar), RuntimeWarning) + return pd.DatetimeIndex(nptimes) + + +def _parse_iso8601_without_reso(date_type, datetime_str): + date, _ = _parse_iso8601_with_reso(date_type, datetime_str) + return date + + +def _parse_array_of_cftime_strings(strings, date_type): + """Create a numpy array from an array of strings. + + For use in generating dates from strings for use with interp. Assumes the + array is either 0-dimensional or 1-dimensional. + + Parameters + ---------- + strings : array of strings + Strings to convert to dates + date_type : cftime.datetime type + Calendar type to use for dates + + Returns + ------- + np.array + """ + return np.array([_parse_iso8601_without_reso(date_type, s) + for s in strings.ravel()]).reshape(strings.shape) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py new file mode 100644 index 00000000000..3502fd773d7 --- /dev/null +++ b/xarray/coding/strings.py @@ -0,0 +1,222 @@ +"""Coders for strings.""" +from __future__ import absolute_import, division, print_function + +from functools import partial + +import numpy as np + +from ..core import indexing +from ..core.pycompat import bytes_type, dask_array_type, unicode_type +from ..core.variable import Variable +from .variables import ( + VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, + unpack_for_decoding, unpack_for_encoding) + + +def create_vlen_dtype(element_type): + # based on h5py.special_dtype + return np.dtype('O', metadata={'element_type': element_type}) + + +def check_vlen_dtype(dtype): + if dtype.kind != 'O' or dtype.metadata is None: + return None + else: + return dtype.metadata.get('element_type') + + +def is_unicode_dtype(dtype): + return dtype.kind == 'U' or check_vlen_dtype(dtype) == unicode_type + + +def is_bytes_dtype(dtype): + return dtype.kind == 'S' or check_vlen_dtype(dtype) == bytes_type + + +class EncodedStringCoder(VariableCoder): + """Transforms between unicode strings and fixed-width UTF-8 bytes.""" + + def __init__(self, allows_unicode=True): + self.allows_unicode = allows_unicode + + def encode(self, variable, name=None): + dims, data, attrs, encoding = unpack_for_encoding(variable) + + contains_unicode = is_unicode_dtype(data.dtype) + encode_as_char = encoding.get('dtype') == 'S1' + + if encode_as_char: + del encoding['dtype'] # no longer relevant + + if contains_unicode and (encode_as_char or not self.allows_unicode): + if '_FillValue' in attrs: + raise NotImplementedError( + 'variable {!r} has a _FillValue specified, but ' + '_FillValue is not yet supported on unicode strings: ' + 'https://github.com/pydata/xarray/issues/1647' + .format(name)) + + string_encoding = encoding.pop('_Encoding', 'utf-8') + safe_setitem(attrs, '_Encoding', string_encoding, name=name) + # TODO: figure out how to handle this in 
a lazy way with dask + data = encode_string_array(data, string_encoding) + + return Variable(dims, data, attrs, encoding) + + def decode(self, variable, name=None): + dims, data, attrs, encoding = unpack_for_decoding(variable) + + if '_Encoding' in attrs: + string_encoding = pop_to(attrs, encoding, '_Encoding') + func = partial(decode_bytes_array, encoding=string_encoding) + data = lazy_elemwise_func(data, func, np.dtype(object)) + + return Variable(dims, data, attrs, encoding) + + +def decode_bytes_array(bytes_array, encoding='utf-8'): + # This is faster than using np.char.decode() or np.vectorize() + bytes_array = np.asarray(bytes_array) + decoded = [x.decode(encoding) for x in bytes_array.ravel()] + return np.array(decoded, dtype=object).reshape(bytes_array.shape) + + +def encode_string_array(string_array, encoding='utf-8'): + string_array = np.asarray(string_array) + encoded = [x.encode(encoding) for x in string_array.ravel()] + return np.array(encoded, dtype=bytes).reshape(string_array.shape) + + +def ensure_fixed_length_bytes(var): + """Ensure that a variable with vlen bytes is converted to fixed width.""" + dims, data, attrs, encoding = unpack_for_encoding(var) + if check_vlen_dtype(data.dtype) == bytes_type: + # TODO: figure out how to handle this with dask + data = np.asarray(data, dtype=np.string_) + return Variable(dims, data, attrs, encoding) + + +class CharacterArrayCoder(VariableCoder): + """Transforms between arrays containing bytes and character arrays.""" + + def encode(self, variable, name=None): + variable = ensure_fixed_length_bytes(variable) + + dims, data, attrs, encoding = unpack_for_encoding(variable) + if data.dtype.kind == 'S' and encoding.get('dtype') is not str: + data = bytes_to_char(data) + dims = dims + ('string%s' % data.shape[-1],) + return Variable(dims, data, attrs, encoding) + + def decode(self, variable, name=None): + dims, data, attrs, encoding = unpack_for_decoding(variable) + + if data.dtype == 'S1' and dims: + dims = dims[:-1] + data = char_to_bytes(data) + + return Variable(dims, data, attrs, encoding) + + +def bytes_to_char(arr): + """Convert numpy/dask arrays from fixed width bytes to characters.""" + if arr.dtype.kind != 'S': + raise ValueError('argument must have a fixed-width bytes dtype') + + if isinstance(arr, dask_array_type): + import dask.array as da + return da.map_blocks(_numpy_bytes_to_char, arr, + dtype='S1', + chunks=arr.chunks + ((arr.dtype.itemsize,)), + new_axis=[arr.ndim]) + else: + return _numpy_bytes_to_char(arr) + + +def _numpy_bytes_to_char(arr): + """Like netCDF4.stringtochar, but faster and more flexible. 
+ """ + # ensure the array is contiguous + arr = np.array(arr, copy=False, order='C', dtype=np.string_) + return arr.reshape(arr.shape + (1,)).view('S1') + + +def char_to_bytes(arr): + """Convert numpy/dask arrays from characters to fixed width bytes.""" + if arr.dtype != 'S1': + raise ValueError("argument must have dtype='S1'") + + if not arr.ndim: + # no dimension to concatenate along + return arr + + size = arr.shape[-1] + + if not size: + # can't make an S0 dtype + return np.zeros(arr.shape[:-1], dtype=np.string_) + + if isinstance(arr, dask_array_type): + import dask.array as da + + if len(arr.chunks[-1]) > 1: + raise ValueError('cannot stacked dask character array with ' + 'multiple chunks in the last dimension: {}' + .format(arr)) + + dtype = np.dtype('S' + str(arr.shape[-1])) + return da.map_blocks(_numpy_char_to_bytes, arr, + dtype=dtype, + chunks=arr.chunks[:-1], + drop_axis=[arr.ndim - 1]) + else: + return StackedBytesArray(arr) + + +def _numpy_char_to_bytes(arr): + """Like netCDF4.chartostring, but faster and more flexible. + """ + # based on: http://stackoverflow.com/a/10984878/809705 + arr = np.array(arr, copy=False, order='C') + dtype = 'S' + str(arr.shape[-1]) + return arr.view(dtype).reshape(arr.shape[:-1]) + + +class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin): + """Wrapper around array-like objects to create a new indexable object where + values, when accessed, are automatically stacked along the last dimension. + + >>> StackedBytesArray(np.array(['a', 'b', 'c']))[:] + array('abc', + dtype='|S3') + """ + + def __init__(self, array): + """ + Parameters + ---------- + array : array-like + Original array of values to wrap. + """ + if array.dtype != 'S1': + raise ValueError( + "can only use StackedBytesArray if argument has dtype='S1'") + self.array = indexing.as_indexable(array) + + @property + def dtype(self): + return np.dtype('S' + str(self.array.shape[-1])) + + @property + def shape(self): + return self.array.shape[:-1] + + def __repr__(self): + return ('%s(%r)' % (type(self).__name__, self.array)) + + def __getitem__(self, key): + # require slicing the last dimension completely + key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim)) + if key.tuple[-1] != slice(None): + raise IndexError('too many indices') + return _numpy_char_to_bytes(self.array[key]) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index e00769af884..dfc4b2fb023 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -1,6 +1,4 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import re import traceback @@ -9,24 +7,25 @@ from functools import partial import numpy as np - import pandas as pd -try: - from pandas.errors import OutOfBoundsDatetime -except ImportError: - # pandas < 0.20 - from pandas.tslib import OutOfBoundsDatetime -from .variables import (SerializationWarning, VariableCoder, - lazy_elemwise_func, pop_to, safe_setitem, - unpack_for_decoding, unpack_for_encoding) from ..core import indexing +from ..core.common import contains_cftime_datetimes from ..core.formatting import first_n_items, format_timestamp, last_item from ..core.pycompat import PY3 from ..core.variable import Variable +from .variables import ( + SerializationWarning, VariableCoder, lazy_elemwise_func, pop_to, + safe_setitem, unpack_for_decoding, unpack_for_encoding) +try: + from pandas.errors import OutOfBoundsDatetime +except ImportError: + # 
pandas < 0.20 + from pandas.tslib import OutOfBoundsDatetime -# standard calendars recognized by netcdftime + +# standard calendars recognized by cftime _STANDARD_CALENDARS = set(['standard', 'gregorian', 'proleptic_gregorian']) _NS_PER_TIME_DELTA = {'us': int(1e3), @@ -40,6 +39,32 @@ 'milliseconds', 'microseconds']) +def _import_cftime(): + ''' + helper function handle the transition to netcdftime/cftime + as a stand-alone package + ''' + try: + import cftime + except ImportError: + # in netCDF4 the num2date/date2num function are top-level api + try: + import netCDF4 as cftime + except ImportError: + raise ImportError("Failed to import cftime") + return cftime + + +def _require_standalone_cftime(): + """Raises an ImportError if the standalone cftime is not found""" + try: + import cftime # noqa: F401 + except ImportError: + raise ImportError('Decoding times with non-standard calendars ' + 'or outside the pandas.Timestamp-valid range ' + 'requires the standalone cftime package.') + + def _netcdf_to_numpy_timeunit(units): units = units.lower() if not units.endswith('s'): @@ -59,24 +84,28 @@ def _unpack_netcdf_time_units(units): return delta_units, ref_date -def _decode_datetime_with_netcdf4(num_dates, units, calendar): - import netCDF4 as nc4 +def _decode_datetime_with_cftime(num_dates, units, calendar): + cftime = _import_cftime() + + if cftime.__name__ == 'cftime': + dates = np.asarray(cftime.num2date(num_dates, units, calendar, + only_use_cftime_datetimes=True)) + else: + # Must be using num2date from an old version of netCDF4 which + # does not have the only_use_cftime_datetimes option. + dates = np.asarray(cftime.num2date(num_dates, units, calendar)) - dates = np.asarray(nc4.num2date(num_dates, units, calendar)) if (dates[np.nanargmin(num_dates)].year < 1678 or dates[np.nanargmax(num_dates)].year >= 2262): - warnings.warn('Unable to decode time axis into full ' - 'numpy.datetime64 objects, continuing using dummy ' - 'netCDF4.datetime objects instead, reason: dates out' - ' of range', SerializationWarning, stacklevel=3) + if calendar in _STANDARD_CALENDARS: + warnings.warn( + 'Unable to decode time axis into full ' + 'numpy.datetime64 objects, continuing using dummy ' + 'cftime.datetime objects instead, reason: dates out ' + 'of range', SerializationWarning, stacklevel=3) else: - try: - dates = nctime_to_nptime(dates) - except ValueError as e: - warnings.warn('Unable to decode time axis into full ' - 'numpy.datetime64 objects, continuing using ' - 'dummy netCDF4.datetime objects instead, reason:' - '{0}'.format(e), SerializationWarning, stacklevel=3) + if calendar in _STANDARD_CALENDARS: + dates = cftime_to_nptime(dates) return dates @@ -111,7 +140,7 @@ def decode_cf_datetime(num_dates, units, calendar=None): numpy array of date time objects. For standard (Gregorian) calendars, this function uses vectorized - operations, which makes it much faster than netCDF4.num2date. In such a + operations, which makes it much faster than cftime.num2date. In such a case, the returned array will be of type np.datetime64. 
Note that time unit in `units` must not be smaller than microseconds and @@ -119,7 +148,7 @@ def decode_cf_datetime(num_dates, units, calendar=None): See also -------- - netCDF4.num2date + cftime.num2date """ num_dates = np.asarray(num_dates) flat_num_dates = num_dates.ravel() @@ -137,27 +166,30 @@ def decode_cf_datetime(num_dates, units, calendar=None): ref_date = pd.Timestamp(ref_date) except ValueError: # ValueError is raised by pd.Timestamp for non-ISO timestamp - # strings, in which case we fall back to using netCDF4 + # strings, in which case we fall back to using cftime raise OutOfBoundsDatetime # fixes: https://github.com/pydata/pandas/issues/14068 # these lines check if the the lowest or the highest value in dates # cause an OutOfBoundsDatetime (Overflow) error - pd.to_timedelta(flat_num_dates.min(), delta) + ref_date - pd.to_timedelta(flat_num_dates.max(), delta) + ref_date + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'invalid value encountered', + RuntimeWarning) + pd.to_timedelta(flat_num_dates.min(), delta) + ref_date + pd.to_timedelta(flat_num_dates.max(), delta) + ref_date # Cast input dates to integers of nanoseconds because `pd.to_datetime` # works much faster when dealing with integers - flat_num_dates_ns_int = (flat_num_dates * + # make _NS_PER_TIME_DELTA an array to ensure type upcasting + flat_num_dates_ns_int = (flat_num_dates.astype(np.float64) * _NS_PER_TIME_DELTA[delta]).astype(np.int64) dates = (pd.to_timedelta(flat_num_dates_ns_int, 'ns') + ref_date).values except (OutOfBoundsDatetime, OverflowError): - dates = _decode_datetime_with_netcdf4(flat_num_dates.astype(np.float), - units, - calendar) + dates = _decode_datetime_with_cftime( + flat_num_dates.astype(np.float), units, calendar) return dates.reshape(num_dates.shape) @@ -189,18 +221,45 @@ def _infer_time_units_from_diff(unique_timedeltas): return 'seconds' +def infer_calendar_name(dates): + """Given an array of datetimes, infer the CF calendar name""" + if np.asarray(dates).dtype == 'datetime64[ns]': + return 'proleptic_gregorian' + else: + return np.asarray(dates).ravel()[0].calendar + + def infer_datetime_units(dates): """Given an array of datetimes, returns a CF compatible time-unit string of the form "{time_unit} since {date[0]}", where `time_unit` is 'days', 'hours', 'minutes' or 'seconds' (the first one that can evenly divide all unique time deltas in `dates`) """ - dates = pd.to_datetime(np.asarray(dates).ravel(), box=False) - dates = dates[pd.notnull(dates)] + dates = np.asarray(dates).ravel() + if np.asarray(dates).dtype == 'datetime64[ns]': + dates = pd.to_datetime(dates, box=False) + dates = dates[pd.notnull(dates)] + reference_date = dates[0] if len(dates) > 0 else '1970-01-01' + reference_date = pd.Timestamp(reference_date) + else: + reference_date = dates[0] if len(dates) > 0 else '1970-01-01' + reference_date = format_cftime_datetime(reference_date) unique_timedeltas = np.unique(np.diff(dates)) + if unique_timedeltas.dtype == np.dtype('O'): + # Convert to np.timedelta64 objects using pandas to work around a + # NumPy casting bug: https://github.com/numpy/numpy/issues/11096 + unique_timedeltas = pd.to_timedelta(unique_timedeltas, box=False) units = _infer_time_units_from_diff(unique_timedeltas) - reference_date = dates[0] if len(dates) > 0 else '1970-01-01' - return '%s since %s' % (units, pd.Timestamp(reference_date)) + return '%s since %s' % (units, reference_date) + + +def format_cftime_datetime(date): + """Converts a 
cftime.datetime object to a string with the format: + YYYY-MM-DD HH:MM:SS.UUUUUU + """ + return '{:04d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}.{:06d}'.format( + date.year, date.month, date.day, date.hour, date.minute, date.second, + date.microsecond) def infer_timedelta_units(deltas): @@ -214,13 +273,22 @@ def infer_timedelta_units(deltas): return units -def nctime_to_nptime(times): - """Given an array of netCDF4.datetime objects, return an array of +def cftime_to_nptime(times): + """Given an array of cftime.datetime objects, return an array of numpy.datetime64 objects of the same size""" times = np.asarray(times) new = np.empty(times.shape, dtype='M8[ns]') for i, t in np.ndenumerate(times): - dt = datetime(t.year, t.month, t.day, t.hour, t.minute, t.second) + try: + # Use pandas.Timestamp in place of datetime.datetime, because + # NumPy casts it safely it np.datetime64[ns] for dates outside + # 1678 to 2262 (this is not currently the case for + # datetime.datetime). + dt = pd.Timestamp(t.year, t.month, t.day, t.hour, t.minute, + t.second, t.microsecond) + except ValueError as e: + raise ValueError('Cannot convert date {} to a date in the ' + 'standard calendar. Reason: {}.'.format(t, e)) new[i] = np.datetime64(dt) return new @@ -235,20 +303,20 @@ def _cleanup_netcdf_time_units(units): return units -def _encode_datetime_with_netcdf4(dates, units, calendar): - """Fallback method for encoding dates using netCDF4-python. +def _encode_datetime_with_cftime(dates, units, calendar): + """Fallback method for encoding dates using cftime. This method is more flexible than xarray's parsing using datetime64[ns] arrays but also slower because it loops over each element. """ - import netCDF4 as nc4 + cftime = _import_cftime() if np.issubdtype(dates.dtype, np.datetime64): # numpy's broken datetime conversion only works for us precision dates = dates.astype('M8[us]').astype(datetime) def encode_datetime(d): - return np.nan if d is None else nc4.date2num(d, units, calendar) + return np.nan if d is None else cftime.date2num(d, units, calendar) return np.vectorize(encode_datetime)(dates) @@ -268,7 +336,7 @@ def encode_cf_datetime(dates, units=None, calendar=None): See also -------- - netCDF4.date2num + cftime.date2num """ dates = np.asarray(dates) @@ -278,22 +346,27 @@ def encode_cf_datetime(dates, units=None, calendar=None): units = _cleanup_netcdf_time_units(units) if calendar is None: - calendar = 'proleptic_gregorian' + calendar = infer_calendar_name(dates) delta, ref_date = _unpack_netcdf_time_units(units) try: if calendar not in _STANDARD_CALENDARS or dates.dtype.kind == 'O': - # parse with netCDF4 instead + # parse with cftime instead raise OutOfBoundsDatetime assert dates.dtype == 'datetime64[ns]' delta_units = _netcdf_to_numpy_timeunit(delta) time_delta = np.timedelta64(1, delta_units).astype('timedelta64[ns]') ref_date = np.datetime64(pd.Timestamp(ref_date)) - num = (dates - ref_date) / time_delta + + # Wrap the dates in a DatetimeIndex to do the subtraction to ensure + # an OverflowError is raised if the ref_date is too far away from + # dates to be encoded (GH 2272). 
+ num = (pd.DatetimeIndex(dates.ravel()) - ref_date) / time_delta + num = num.values.reshape(dates.shape) except (OutOfBoundsDatetime, OverflowError): - num = _encode_datetime_with_netcdf4(dates, units, calendar) + num = _encode_datetime_with_cftime(dates, units, calendar) num = cast_to_int_if_safe(num) return (num, units, calendar) @@ -314,8 +387,8 @@ class CFDatetimeCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) - - if np.issubdtype(data.dtype, np.datetime64): + if (np.issubdtype(data.dtype, np.datetime64) or + contains_cftime_datetimes(variable)): (data, units, calendar) = encode_cf_datetime( data, encoding.pop('units', None), diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 5d32970e2ed..b86b77a3707 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -1,18 +1,13 @@ """Coders for individual Variable objects.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -from functools import partial import warnings +from functools import partial import numpy as np import pandas as pd -from ..core import dtypes -from ..core import duck_array_ops -from ..core import indexing -from ..core import utils +from ..core import dtypes, duck_array_ops, indexing from ..core.pycompat import dask_array_type from ..core.variable import Variable @@ -68,7 +63,10 @@ def dtype(self): return np.dtype(self._dtype) def __getitem__(self, key): - return self.func(self.array[key]) + return type(self)(self.array[key], self.func, self.dtype) + + def __array__(self, dtype=None): + return self.func(self.array) def __repr__(self): return ("%s(%r, func=%r, dtype=%r)" % @@ -134,10 +132,10 @@ def _apply_mask(data, # type: np.ndarray dtype, # type: Any ): # type: np.ndarray """Mask all matching values in a NumPy arrays.""" + data = np.asarray(data, dtype=dtype) condition = False for fv in encoded_fill_values: condition |= data == fv - data = np.asarray(data, dtype=dtype) return np.where(condition, decoded_fill_value, data) @@ -157,26 +155,12 @@ def encode(self, variable, name=None): def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - if 'missing_value' in attrs: - # missing_value is deprecated, but we still want to support it as - # an alias for _FillValue. - if ('_FillValue' in attrs and - not utils.equivalent(attrs['_FillValue'], - attrs['missing_value'])): - raise ValueError("Conflicting _FillValue and missing_value " - "attrs on a variable {!r}: {} vs. {}\n\n" - "Consider opening the offending dataset " - "using decode_cf=False, correcting the " - "attrs and decoding explicitly using " - "xarray.decode_cf()." 
- .format(name, attrs['_FillValue'], - attrs['missing_value'])) - attrs['_FillValue'] = attrs.pop('missing_value') - - if '_FillValue' in attrs: - raw_fill_value = pop_to(attrs, encoding, '_FillValue', name=name) - encoded_fill_values = [ - fv for fv in np.ravel(raw_fill_value) if not pd.isnull(fv)] + raw_fill_values = [pop_to(attrs, encoding, attr, name=name) + for attr in ('missing_value', '_FillValue')] + if raw_fill_values: + encoded_fill_values = {fv for option in raw_fill_values + for fv in np.ravel(option) + if not pd.isnull(fv)} if len(encoded_fill_values) > 1: warnings.warn("variable {!r} has multiple fill values {}, " diff --git a/xarray/conventions.py b/xarray/conventions.py index fe75d9e3e6a..f60ee6b2c15 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -1,109 +1,20 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import warnings from collections import defaultdict import numpy as np - import pandas as pd -from .coding import times -from .coding import variables +from .coding import strings, times, variables from .coding.variables import SerializationWarning from .core import duck_array_ops, indexing -from .core.pycompat import OrderedDict, basestring, iteritems +from .core.pycompat import ( + OrderedDict, basestring, bytes_type, dask_array_type, iteritems, + unicode_type) from .core.variable import IndexVariable, Variable, as_variable -class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin): - """Wrapper around array-like objects to create a new indexable object where - values, when accessed, are automatically stacked along the last dimension. - - >>> StackedBytesArray(np.array(['a', 'b', 'c']))[:] - array('abc', - dtype='|S3') - """ - - def __init__(self, array): - """ - Parameters - ---------- - array : array-like - Original array of values to wrap. - """ - if array.dtype != 'S1': - raise ValueError( - "can only use StackedBytesArray if argument has dtype='S1'") - self.array = indexing.as_indexable(array) - - @property - def dtype(self): - return np.dtype('S' + str(self.array.shape[-1])) - - @property - def shape(self): - return self.array.shape[:-1] - - def __str__(self): - # TODO(shoyer): figure out why we need this special case? - if self.ndim == 0: - return str(np.array(self).item()) - else: - return repr(self) - - def __repr__(self): - return ('%s(%r)' % (type(self).__name__, self.array)) - - def __getitem__(self, key): - # require slicing the last dimension completely - key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim)) - if key.tuple[-1] != slice(None): - raise IndexError('too many indices') - return char_to_bytes(self.array[key]) - - -class BytesToStringArray(indexing.ExplicitlyIndexedNDArrayMixin): - """Wrapper that decodes bytes to unicode when values are read. - - >>> BytesToStringArray(np.array([b'abc']))[:] - array(['abc'], - dtype=object) - """ - - def __init__(self, array, encoding='utf-8'): - """ - Parameters - ---------- - array : array-like - Original array of values to wrap. - encoding : str - String encoding to use. - """ - self.array = indexing.as_indexable(array) - self.encoding = encoding - - @property - def dtype(self): - # variable length string - return np.dtype(object) - - def __str__(self): - # TODO(shoyer): figure out why we need this special case? 
- if self.ndim == 0: - return str(np.array(self).item()) - else: - return repr(self) - - def __repr__(self): - return ('%s(%r, encoding=%r)' - % (type(self).__name__, self.array, self.encoding)) - - def __getitem__(self, key): - return decode_bytes_array(self.array[key], self.encoding) - - class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin): """Decode arrays on the fly from non-native to native endianness @@ -163,114 +74,13 @@ def __getitem__(self, key): return np.asarray(self.array[key], dtype=self.dtype) -def bytes_to_char(arr): - """Like netCDF4.stringtochar, but faster and more flexible. - """ - # ensure the array is contiguous - arr = np.array(arr, copy=False, order='C') - kind = arr.dtype.kind - if kind not in ['U', 'S']: - raise ValueError('argument must be a string array') - return arr.reshape(arr.shape + (1,)).view(kind + '1') - - -def char_to_bytes(arr): - """Like netCDF4.chartostring, but faster and more flexible. - """ - # based on: http://stackoverflow.com/a/10984878/809705 - arr = np.array(arr, copy=False, order='C') - - kind = arr.dtype.kind - if kind not in ['U', 'S']: - raise ValueError('argument must be a string array') - - if not arr.ndim: - # no dimension to concatenate along - return arr - - size = arr.shape[-1] - if not size: - # can't make an S0 dtype - return np.zeros(arr.shape[:-1], dtype=kind) - - dtype = kind + str(size) - return arr.view(dtype).reshape(arr.shape[:-1]) - - -def decode_bytes_array(bytes_array, encoding='utf-8'): - # This is faster than using np.char.decode() or np.vectorize() - bytes_array = np.asarray(bytes_array) - decoded = [x.decode(encoding) for x in bytes_array.ravel()] - return np.array(decoded, dtype=object).reshape(bytes_array.shape) - - -def encode_string_array(string_array, encoding='utf-8'): - string_array = np.asarray(string_array) - encoded = [x.encode(encoding) for x in string_array.ravel()] - return np.array(encoded, dtype=bytes).reshape(string_array.shape) - - -def safe_setitem(dest, key, value, name=None): - if key in dest: - var_str = ' on variable {!r}'.format(name) if name else '' - raise ValueError( - 'failed to prevent overwriting existing key {} in attrs{}. ' - 'This is probably an encoding field used by xarray to describe ' - 'how a variable is serialized. To proceed, remove this key from ' - "the variable's attributes manually.".format(key, var_str)) - dest[key] = value - - -def pop_to(source, dest, key, name=None): - """ - A convenience function which pops a key k from source to dest. - None values are not passed on. If k already exists in dest an - error is raised. 
- """ - value = source.pop(key, None) - if value is not None: - safe_setitem(dest, key, value, name=name) - return value - - def _var_as_tuple(var): return var.dims, var.data, var.attrs.copy(), var.encoding.copy() -def maybe_encode_as_char_array(var, name=None): - if var.dtype.kind in {'S', 'U'}: - dims, data, attrs, encoding = _var_as_tuple(var) - if data.dtype.kind == 'U': - if '_FillValue' in attrs: - raise NotImplementedError( - 'variable {!r} has a _FillValue specified, but ' - '_FillValue is yet supported on unicode strings: ' - 'https://github.com/pydata/xarray/issues/1647' - .format(name)) - - string_encoding = encoding.pop('_Encoding', 'utf-8') - safe_setitem(attrs, '_Encoding', string_encoding, name=name) - data = encode_string_array(data, string_encoding) - - if data.dtype.itemsize > 1: - data = bytes_to_char(data) - dims = dims + ('string%s' % data.shape[-1],) - - var = Variable(dims, data, attrs, encoding) - return var - - -def maybe_encode_string_dtype(var, name=None): - # need to apply after ensure_dtype_not_object() - if 'dtype' in var.encoding and var.encoding['dtype'] == 'S1': - assert var.dtype.kind in {'S', 'U'} - var = maybe_encode_as_char_array(var, name=name) - del var.encoding['dtype'] - return var - - def maybe_encode_nonstring_dtype(var, name=None): - if 'dtype' in var.encoding and var.encoding['dtype'] != 'S1': + if ('dtype' in var.encoding and + var.encoding['dtype'] not in ('S1', str)): dims, data, attrs, encoding = _var_as_tuple(var) dtype = np.dtype(encoding.pop('dtype')) if dtype != var.dtype: @@ -280,7 +90,7 @@ def maybe_encode_nonstring_dtype(var, name=None): warnings.warn('saving variable %s with floating ' 'point data as an integer dtype without ' 'any _FillValue to use for NaNs' % name, - SerializationWarning, stacklevel=3) + SerializationWarning, stacklevel=10) data = duck_array_ops.around(data)[...] data = data.astype(dtype=dtype) var = Variable(dims, data, attrs, encoding) @@ -310,19 +120,23 @@ def _infer_dtype(array, name=None): """Given an object array with no missing values, infer its dtype from its first element """ + if array.dtype.kind != 'O': + raise TypeError('infer_type must be called on a dtype=object array') + if array.size == 0: - dtype = np.dtype(float) - else: - dtype = np.array(array[(0,) * array.ndim]).dtype - if dtype.kind in ['S', 'U']: - # don't just use inferred dtype to avoid truncating arrays to - # the length of their first element - dtype = np.dtype(dtype.kind) - elif dtype.kind == 'O': - raise ValueError('unable to infer dtype on variable {!r}; xarray ' - 'cannot serialize arbitrary Python objects' - .format(name)) - return dtype + return np.dtype(float) + + element = array[(0,) * array.ndim] + if isinstance(element, (bytes_type, unicode_type)): + return strings.create_vlen_dtype(type(element)) + + dtype = np.array(element).dtype + if dtype.kind != 'O': + return dtype + + raise ValueError('unable to infer dtype on variable {!r}; xarray ' + 'cannot serialize arbitrary Python objects' + .format(name)) def ensure_not_multiindex(var, name=None): @@ -336,10 +150,32 @@ def ensure_not_multiindex(var, name=None): 'variables instead.'.format(name)) +def _copy_with_dtype(data, dtype): + """Create a copy of an array with the given dtype. + + We use this instead of np.array() to ensure that custom object dtypes end + up on the resulting array. + """ + result = np.empty(data.shape, dtype) + result[...] = data + return result + + def ensure_dtype_not_object(var, name=None): # TODO: move this from conventions to backends? 
(it's not CF related) if var.dtype.kind == 'O': dims, data, attrs, encoding = _var_as_tuple(var) + + if isinstance(data, dask_array_type): + warnings.warn( + 'variable {} has data in the form of a dask array with ' + 'dtype=object, which means it is being loaded into memory ' + 'to determine a data type that can be safely stored on disk. ' + 'To avoid this, coerce this variable to a fixed-size dtype ' + 'with astype() before saving it.'.format(name), + SerializationWarning) + data = data.compute() + missing = pd.isnull(data) if missing.any(): # nb. this will fail for dask.array data @@ -349,9 +185,9 @@ def ensure_dtype_not_object(var, name=None): # There is no safe bit-pattern for NA in typical binary string # formats, we so can't set a fill_value. Unfortunately, this means # we can't distinguish between missing values and empty strings. - if inferred_dtype.kind == 'S': + if strings.is_bytes_dtype(inferred_dtype): fill_value = b'' - elif inferred_dtype.kind == 'U': + elif strings.is_unicode_dtype(inferred_dtype): fill_value = u'' else: # insist on using float for numeric values @@ -359,10 +195,12 @@ def ensure_dtype_not_object(var, name=None): inferred_dtype = np.dtype(float) fill_value = inferred_dtype.type(np.nan) - data = np.array(data, dtype=inferred_dtype, copy=True) + data = _copy_with_dtype(data, dtype=inferred_dtype) data[missing] = fill_value else: - data = data.astype(dtype=_infer_dtype(data, name)) + data = _copy_with_dtype(data, dtype=_infer_dtype(data, name)) + + assert data.dtype.kind != 'O' or data.dtype.metadata var = Variable(dims, data, attrs, encoding) return var @@ -401,7 +239,6 @@ def encode_cf_variable(var, needs_copy=True, name=None): var = maybe_default_fill_value(var) var = maybe_encode_bools(var) var = ensure_dtype_not_object(var, name=name) - var = maybe_encode_string_dtype(var, name=name) return var @@ -443,32 +280,20 @@ def decode_cf_variable(name, var, concat_characters=True, mask_and_scale=True, out : Variable A variable holding the decoded equivalent of var. 
""" - # use _data instead of data so as not to trigger loading data var = as_variable(var) - data = var._data - dimensions = var.dims - attributes = var.attrs.copy() - encoding = var.encoding.copy() - - original_dtype = data.dtype + original_dtype = var.dtype - if concat_characters and data.dtype.kind == 'S': + if concat_characters: if stack_char_dim: - dimensions = dimensions[:-1] - data = StackedBytesArray(data) - - string_encoding = pop_to(attributes, encoding, '_Encoding') - if string_encoding is not None: - data = BytesToStringArray(data, string_encoding) - - # TODO(shoyer): convert everything above to use coders - var = Variable(dimensions, data, attributes, encoding) + var = strings.CharacterArrayCoder().decode(var, name=name) + var = strings.EncodedStringCoder().decode(var) if mask_and_scale: for coder in [variables.UnsignedIntegerCoder(), variables.CFMaskCoder(), variables.CFScaleOffsetCoder()]: var = coder.decode(var, name=name) + if decode_times: for coder in [times.CFTimedeltaCoder(), times.CFDatetimeCoder()]: @@ -483,19 +308,16 @@ def decode_cf_variable(name, var, concat_characters=True, mask_and_scale=True, data = NativeEndiannessArray(data) original_dtype = data.dtype - if 'dtype' in encoding: - if original_dtype != encoding['dtype']: - warnings.warn("CF decoding is overwriting dtype on variable {!r}" - .format(name)) - else: - encoding['dtype'] = original_dtype + encoding.setdefault('dtype', original_dtype) if 'dtype' in attributes and attributes['dtype'] == 'bool': del attributes['dtype'] data = BoolTypeArray(data) - return Variable(dimensions, indexing.LazilyIndexedArray(data), - attributes, encoding=encoding) + if not isinstance(data, dask_array_type): + data = indexing.LazilyOuterIndexedArray(data) + + return Variable(dimensions, data, attributes, encoding=encoding) def decode_cf_variables(variables, attributes, concat_characters=True, diff --git a/xarray/convert.py b/xarray/convert.py index caf665b421d..6cff72103ff 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -1,16 +1,18 @@ """Functions for converting to and from xarray objects """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +from collections import Counter import numpy as np +import pandas as pd from .coding.times import CFDatetimeCoder, CFTimedeltaCoder +from .conventions import decode_cf +from .core import duck_array_ops from .core.dataarray import DataArray -from .core.pycompat import OrderedDict, range from .core.dtypes import get_fill_value -from .conventions import decode_cf +from .core.pycompat import OrderedDict, range cdms2_ignored_attrs = {'name', 'tileIndex'} iris_forbidden_keys = {'standard_name', 'long_name', 'units', 'bounds', 'axis', @@ -39,15 +41,28 @@ def from_cdms2(variable): """ values = np.asarray(variable) name = variable.id - coords = [(v.id, np.asarray(v), - _filter_attrs(v.attributes, cdms2_ignored_attrs)) - for v in variable.getAxisList()] + dims = variable.getAxisIds() + coords = {} + for axis in variable.getAxisList(): + coords[axis.id] = DataArray( + np.asarray(axis), dims=[axis.id], + attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs)) + grid = variable.getGrid() + if grid is not None: + ids = [a.id for a in grid.getAxisList()] + for axis in grid.getLongitude(), grid.getLatitude(): + if axis.id not in variable.getAxisIds(): + coords[axis.id] = DataArray( + np.asarray(axis[:]), dims=ids, + attrs=_filter_attrs(axis.attributes, + 
cdms2_ignored_attrs)) attrs = _filter_attrs(variable.attributes, cdms2_ignored_attrs) - dataarray = DataArray(values, coords=coords, name=name, attrs=attrs) + dataarray = DataArray(values, dims=dims, coords=coords, name=name, + attrs=attrs) return decode_cf(dataarray.to_dataset())[dataarray.name] -def to_cdms2(dataarray): +def to_cdms2(dataarray, copy=True): """Convert a DataArray into a cdms2 variable """ # we don't want cdms2 to be a hard dependency @@ -57,6 +72,7 @@ def set_cdms2_attrs(var, attrs): for k, v in attrs.items(): setattr(var, k, v) + # 1D axes axes = [] for dim in dataarray.dims: coord = encode(dataarray.coords[dim]) @@ -64,9 +80,42 @@ def set_cdms2_attrs(var, attrs): set_cdms2_attrs(axis, coord.attrs) axes.append(axis) + # Data var = encode(dataarray) - cdms2_var = cdms2.createVariable(var.values, axes=axes, id=dataarray.name) + cdms2_var = cdms2.createVariable(var.values, axes=axes, id=dataarray.name, + mask=pd.isnull(var.values), copy=copy) + + # Attributes set_cdms2_attrs(cdms2_var, var.attrs) + + # Curvilinear and unstructured grids + if dataarray.name not in dataarray.coords: + + cdms2_axes = OrderedDict() + for coord_name in set(dataarray.coords.keys()) - set(dataarray.dims): + + coord_array = dataarray.coords[coord_name].to_cdms2() + + cdms2_axis_cls = (cdms2.coord.TransientAxis2D + if coord_array.ndim else + cdms2.auxcoord.TransientAuxAxis1D) + cdms2_axis = cdms2_axis_cls(coord_array) + if cdms2_axis.isLongitude(): + cdms2_axes['lon'] = cdms2_axis + elif cdms2_axis.isLatitude(): + cdms2_axes['lat'] = cdms2_axis + + if 'lon' in cdms2_axes and 'lat' in cdms2_axes: + if len(cdms2_axes['lon'].shape) == 2: + cdms2_grid = cdms2.hgrid.TransientCurveGrid( + cdms2_axes['lat'], cdms2_axes['lon']) + else: + cdms2_grid = cdms2.gengrid.AbstractGenericGrid( + cdms2_axes['lat'], cdms2_axes['lon']) + for axis in cdms2_grid.getAxisList(): + cdms2_var.setAxis(cdms2_var.getAxisIds().index(axis.id), axis) + cdms2_var.setGrid(cdms2_grid) + return cdms2_var @@ -96,7 +145,6 @@ def to_iris(dataarray): # Iris not a hard dependency import iris from iris.fileformats.netcdf import parse_cell_methods - from xarray.core.pycompat import dask_array_type dim_coords = [] aux_coords = [] @@ -109,8 +157,12 @@ def to_iris(dataarray): if coord.dims: axis = dataarray.get_axis_num(coord.dims) if coord_name in dataarray.dims: - iris_coord = iris.coords.DimCoord(coord.values, **coord_args) - dim_coords.append((iris_coord, axis)) + try: + iris_coord = iris.coords.DimCoord(coord.values, **coord_args) + dim_coords.append((iris_coord, axis)) + except ValueError: + iris_coord = iris.coords.AuxCoord(coord.values, **coord_args) + aux_coords.append((iris_coord, axis)) else: iris_coord = iris.coords.AuxCoord(coord.values, **coord_args) aux_coords.append((iris_coord, axis)) @@ -123,13 +175,7 @@ def to_iris(dataarray): args['cell_methods'] = \ parse_cell_methods(dataarray.attrs['cell_methods']) - # Create the right type of masked array (should be easier after #1769) - if isinstance(dataarray.data, dask_array_type): - from dask.array import ma as dask_ma - masked_data = dask_ma.masked_invalid(dataarray) - else: - masked_data = np.ma.masked_invalid(dataarray) - + masked_data = duck_array_ops.masked_invalid(dataarray.data) cube = iris.cube.Cube(masked_data, **args) return cube @@ -142,7 +188,7 @@ def _iris_obj_to_attrs(obj): 'long_name': obj.long_name} if obj.units.calendar: attrs['calendar'] = obj.units.calendar - if obj.units.origin != '1': + if obj.units.origin != '1' and not obj.units.is_unknown(): attrs['units'] = 
obj.units.origin attrs.update(obj.attributes) return dict((k, v) for k, v in attrs.items() if v is not None) @@ -165,34 +211,46 @@ def _iris_cell_methods_to_str(cell_methods_obj): return ' '.join(cell_methods) +def _name(iris_obj, default='unknown'): + """ Mimicks `iris_obj.name()` but with different name resolution order. + + Similar to iris_obj.name() method, but using iris_obj.var_name first to + enable roundtripping. + """ + return (iris_obj.var_name or iris_obj.standard_name or + iris_obj.long_name or default) + + def from_iris(cube): """ Convert a Iris cube into an DataArray """ import iris.exceptions from xarray.core.pycompat import dask_array_type - name = cube.var_name + name = _name(cube) + if name == 'unknown': + name = None dims = [] for i in range(cube.ndim): try: dim_coord = cube.coord(dim_coords=True, dimensions=(i,)) - dims.append(dim_coord.var_name) + dims.append(_name(dim_coord)) except iris.exceptions.CoordinateNotFoundError: dims.append("dim_{}".format(i)) + if len(set(dims)) != len(dims): + duplicates = [k for k, v in Counter(dims).items() if v > 1] + raise ValueError('Duplicate coordinate name {}.'.format(duplicates)) + coords = OrderedDict() for coord in cube.coords(): coord_attrs = _iris_obj_to_attrs(coord) coord_dims = [dims[i] for i in cube.coord_dims(coord)] - if not coord.var_name: - raise ValueError("Coordinate '{}' has no " - "var_name attribute".format(coord.name())) if coord_dims: - coords[coord.var_name] = (coord_dims, coord.points, coord_attrs) + coords[_name(coord)] = (coord_dims, coord.points, coord_attrs) else: - coords[coord.var_name] = ((), - np.asscalar(coord.points), coord_attrs) + coords[_name(coord)] = ((), np.asscalar(coord.points), coord_attrs) array_attrs = _iris_obj_to_attrs(cube) cell_methods = _iris_cell_methods_to_str(cube.cell_methods) diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py index 5052b555c73..72791ed73ec 100644 --- a/xarray/core/accessors.py +++ b/xarray/core/accessors.py @@ -1,13 +1,11 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from .dtypes import is_datetime_like -from .pycompat import dask_array_type +from __future__ import absolute_import, division, print_function import numpy as np import pandas as pd +from .common import _contains_datetime_like_objects, is_np_datetime_like +from .pycompat import dask_array_type + def _season_from_months(months): """Compute season (DJF, MAM, JJA, SON) from month ordinal @@ -18,6 +16,20 @@ def _season_from_months(months): return seasons[(months // 3) % 4] +def _access_through_cftimeindex(values, name): + """Coerce an array of datetime-like values to a CFTimeIndex + and access requested datetime component + """ + from ..coding.cftimeindex import CFTimeIndex + values_as_cftimeindex = CFTimeIndex(values.ravel()) + if name == 'season': + months = values_as_cftimeindex.month + field_values = _season_from_months(months) + else: + field_values = getattr(values_as_cftimeindex, name) + return field_values.reshape(values.shape) + + def _access_through_series(values, name): """Coerce an array of datetime-like values to a pandas Series and access requested datetime component @@ -50,12 +62,54 @@ def _get_date_field(values, name, dtype): Array-like of datetime fields accessed for each element in values """ + if is_np_datetime_like(values.dtype): + access_method = _access_through_series + else: + access_method = _access_through_cftimeindex + if isinstance(values, dask_array_type): from dask.array import map_blocks - 
return map_blocks(_access_through_series, + return map_blocks(access_method, values, name, dtype=dtype) else: - return _access_through_series(values, name) + return access_method(values, name) + + +def _round_series(values, name, freq): + """Coerce an array of datetime-like values to a pandas Series and + apply requested rounding + """ + values_as_series = pd.Series(values.ravel()) + method = getattr(values_as_series.dt, name) + field_values = method(freq=freq).values + + return field_values.reshape(values.shape) + + +def _round_field(values, name, freq): + """Indirectly access pandas rounding functions by wrapping data + as a Series and calling through `.dt` attribute. + + Parameters + ---------- + values : np.ndarray or dask.array-like + Array-like container of datetime-like values + name : str (ceil, floor, round) + Name of rounding function + freq : a freq string indicating the rounding resolution + + Returns + ------- + rounded timestamps : same type as values + Array-like of datetime fields accessed for each element in values + + """ + if isinstance(values, dask_array_type): + from dask.array import map_blocks + return map_blocks(_round_series, + values, name, freq=freq, dtype=np.datetime64) + else: + return _round_series(values, name, freq) class DatetimeAccessor(object): @@ -76,15 +130,17 @@ class DatetimeAccessor(object): All of the pandas fields are accessible here. Note that these fields are not calendar-aware; if your datetimes are encoded with a non-Gregorian - calendar (e.g. a 360-day calendar) using netcdftime, then some fields like + calendar (e.g. a 360-day calendar) using cftime, then some fields like `dayofyear` may not be accurate. """ def __init__(self, xarray_obj): - if not is_datetime_like(xarray_obj.dtype): + if not _contains_datetime_like_objects(xarray_obj): raise TypeError("'dt' accessor only available for " - "DataArray with datetime64 or timedelta64 dtype") + "DataArray with datetime64 timedelta64 dtype or " + "for arrays containing cftime datetime " + "objects.") self._obj = xarray_obj def _tslib_field_accessor(name, docstring=None, dtype=None): @@ -147,3 +203,58 @@ def f(self, dtype=dtype): time = _tslib_field_accessor( "time", "Timestamps corresponding to datetimes", object ) + + def _tslib_round_accessor(self, name, freq): + obj_type = type(self._obj) + result = _round_field(self._obj.data, name, freq) + return obj_type(result, name=name, + coords=self._obj.coords, dims=self._obj.dims) + + def floor(self, freq): + ''' + Round timestamps downward to specified frequency resolution. + + Parameters + ---------- + freq : a freq string indicating the rounding resolution + e.g. 'D' for daily resolution + + Returns + ------- + floor-ed timestamps : same type as values + Array-like of datetime fields accessed for each element in values + ''' + + return self._tslib_round_accessor("floor", freq) + + def ceil(self, freq): + ''' + Round timestamps upward to specified frequency resolution. + + Parameters + ---------- + freq : a freq string indicating the rounding resolution + e.g. 'D' for daily resolution + + Returns + ------- + ceil-ed timestamps : same type as values + Array-like of datetime fields accessed for each element in values + ''' + return self._tslib_round_accessor("ceil", freq) + + def round(self, freq): + ''' + Round timestamps to specified frequency resolution. + + Parameters + ---------- + freq : a freq string indicating the rounding resolution + e.g. 
'D' for daily resolution + + Returns + ------- + rounded timestamps : same type as values + Array-like of datetime fields accessed for each element in values + ''' + return self._tslib_round_accessor("round", freq) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 876245322fa..f82ddef25ba 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -1,20 +1,17 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import functools import operator -from collections import defaultdict import warnings +from collections import defaultdict import numpy as np -from . import duck_array_ops -from . import dtypes from . import utils from .indexing import get_indexer_nd -from .pycompat import iteritems, OrderedDict, suppress -from .utils import is_full_slice, is_dict_like -from .variable import Variable, IndexVariable +from .pycompat import OrderedDict, iteritems, suppress +from .utils import is_dict_like, is_full_slice +from .variable import IndexVariable def _get_joiner(join): @@ -177,11 +174,14 @@ def deep_align(objects, join='inner', copy=True, indexes=None, This function is not public API. """ + from .dataarray import DataArray + from .dataset import Dataset + if indexes is None: indexes = {} def is_alignable(obj): - return hasattr(obj, 'indexes') and hasattr(obj, 'reindex') + return isinstance(obj, (DataArray, Dataset)) positions = [] keys = [] @@ -306,59 +306,51 @@ def reindex_variables(variables, sizes, indexes, indexers, method=None, from .dataarray import DataArray # build up indexers for assignment along each dimension - to_indexers = {} - from_indexers = {} + int_indexers = {} + targets = {} + masked_dims = set() + unchanged_dims = set() + # size of reindexed dimensions new_sizes = {} for name, index in iteritems(indexes): if name in indexers: - target = utils.safe_cast_to_index(indexers[name]) if not index.is_unique: raise ValueError( 'cannot reindex or align along dimension %r because the ' 'index has duplicate values' % name) - indexer = get_indexer_nd(index, target, method, tolerance) + target = utils.safe_cast_to_index(indexers[name]) new_sizes[name] = len(target) - # Note pandas uses negative values from get_indexer_nd to signify - # values that are missing in the index - # The non-negative values thus indicate the non-missing values - to_indexers[name] = indexer >= 0 - if to_indexers[name].all(): - # If an indexer includes no negative values, then the - # assignment can be to a full-slice (which is much faster, - # and means we won't need to fill in any missing values) - to_indexers[name] = slice(None) - - from_indexers[name] = indexer[to_indexers[name]] - if np.array_equal(from_indexers[name], np.arange(len(index))): - # If the indexer is equal to the original index, use a full - # slice object to speed up selection and so we can avoid - # unnecessary copies - from_indexers[name] = slice(None) + + int_indexer = get_indexer_nd(index, target, method, tolerance) + + # We uses negative values from get_indexer_nd to signify + # values that are missing in the index. 
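# --- Editor's illustrative sketch, not part of the patch above. The reindexing
# logic here relies on the pandas convention that get_indexer (wrapped by
# get_indexer_nd) returns -1 for labels missing from the index; plain pandas
# shows the same behaviour. Assumes only pandas.
import pandas as pd

index = pd.Index([10, 20, 30])
target = pd.Index([20, 30, 40])

int_indexer = index.get_indexer(target)
print(int_indexer)              # [ 1  2 -1]  -> -1 marks the missing label 40
print((int_indexer < 0).any())  # True: this dimension would need masking
# --- end of editor's sketch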
+ if (int_indexer < 0).any(): + masked_dims.add(name) + elif np.array_equal(int_indexer, np.arange(len(index))): + unchanged_dims.add(name) + + int_indexers[name] = int_indexer + targets[name] = target for dim in sizes: if dim not in indexes and dim in indexers: existing_size = sizes[dim] - new_size = utils.safe_cast_to_index(indexers[dim]).size + new_size = indexers[dim].size if existing_size != new_size: raise ValueError( 'cannot reindex or align along dimension %r without an ' 'index because its size %r is different from the size of ' 'the new index %r' % (dim, existing_size, new_size)) - def any_not_full_slices(indexers): - return any(not is_full_slice(idx) for idx in indexers) - - def var_indexers(var, indexers): - return tuple(indexers.get(d, slice(None)) for d in var.dims) - # create variables for the new dataset reindexed = OrderedDict() for dim, indexer in indexers.items(): - if isinstance(indexer, DataArray) and indexer.dims != (dim, ): + if isinstance(indexer, DataArray) and indexer.dims != (dim,): warnings.warn( "Indexer has dimensions {0:s} that are different " "from that to be indexed along {1:s}. " @@ -375,47 +367,24 @@ def var_indexers(var, indexers): for name, var in iteritems(variables): if name not in indexers: - assign_to = var_indexers(var, to_indexers) - assign_from = var_indexers(var, from_indexers) - - if any_not_full_slices(assign_to): - # there are missing values to in-fill - data = var[assign_from].data - dtype, fill_value = dtypes.maybe_promote(var.dtype) - - if isinstance(data, np.ndarray): - shape = tuple(new_sizes.get(dim, size) - for dim, size in zip(var.dims, var.shape)) - new_data = np.empty(shape, dtype=dtype) - new_data[...] = fill_value - # create a new Variable so we can use orthogonal indexing - # use fastpath=True to avoid dtype inference - new_var = Variable(var.dims, new_data, var.attrs, - fastpath=True) - new_var[assign_to] = data - - else: # dask array - data = data.astype(dtype, copy=False) - for axis, indexer in enumerate(assign_to): - if not is_full_slice(indexer): - indices = np.cumsum(indexer)[~indexer] - data = duck_array_ops.insert( - data, indices, fill_value, axis=axis) - new_var = Variable(var.dims, data, var.attrs, - fastpath=True) - - elif any_not_full_slices(assign_from): - # type coercion is not necessary as there are no missing - # values - new_var = var[assign_from] - - else: - # no reindexing is necessary + key = tuple(slice(None) + if d in unchanged_dims + else int_indexers.get(d, slice(None)) + for d in var.dims) + needs_masking = any(d in masked_dims for d in var.dims) + + if needs_masking: + new_var = var._getitem_with_mask(key) + elif all(is_full_slice(k) for k in key): + # no reindexing necessary # here we need to manually deal with copying data, since # we neither created a new ndarray nor used fancy indexing new_var = var.copy(deep=copy) + else: + new_var = var[key] reindexed[name] = new_var + return reindexed diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py new file mode 100644 index 00000000000..a3bb135af24 --- /dev/null +++ b/xarray/core/arithmetic.py @@ -0,0 +1,77 @@ +"""Base classes implementing arithmetic for xarray objects.""" +from __future__ import absolute_import, division, print_function + +import numbers + +import numpy as np + +from .options import OPTIONS +from .pycompat import bytes_type, dask_array_type, unicode_type +from .utils import not_implemented + + +class SupportsArithmetic(object): + """Base class for xarray types that support arithmetic. 
+ + Used by Dataset, DataArray, Variable and GroupBy. + """ + + # TODO: implement special methods for arithmetic here rather than injecting + # them in xarray/core/ops.py. Ideally, do so by inheriting from + # numpy.lib.mixins.NDArrayOperatorsMixin. + + # TODO: allow extending this with some sort of registration system + _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type, + unicode_type) + dask_array_type + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + from .computation import apply_ufunc + + # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. + out = kwargs.get('out', ()) + for x in inputs + out: + if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): + return NotImplemented + + if ufunc.signature is not None: + raise NotImplementedError( + '{} not supported: xarray objects do not directly implement ' + 'generalized ufuncs. Instead, use xarray.apply_ufunc or ' + 'explicitly convert to xarray objects to NumPy arrays ' + '(e.g., with `.values`).' + .format(ufunc)) + + if method != '__call__': + # TODO: support other methods, e.g., reduce and accumulate. + raise NotImplementedError( + '{} method for ufunc {} is not implemented on xarray objects, ' + 'which currently only support the __call__ method. As an ' + 'alternative, consider explicitly converting xarray objects ' + 'to NumPy arrays (e.g., with `.values`).' + .format(method, ufunc)) + + if any(isinstance(o, SupportsArithmetic) for o in out): + # TODO: implement this with logic like _inplace_binary_op. This + # will be necessary to use NDArrayOperatorsMixin. + raise NotImplementedError( + 'xarray objects are not yet supported in the `out` argument ' + 'for ufuncs. As an alternative, consider explicitly ' + 'converting xarray objects to NumPy arrays (e.g., with ' + '`.values`).') + + join = dataset_join = OPTIONS['arithmetic_join'] + + return apply_ufunc(ufunc, *inputs, + input_core_dims=((),) * ufunc.nin, + output_core_dims=((),) * ufunc.nout, + join=join, + dataset_join=dataset_join, + dataset_fill_value=np.nan, + kwargs=kwargs, + dask='allowed') + + # this has no runtime function - these are listed so IDEs know these + # methods are defined and don't warn on these operations + __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ + __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ + __or__ = __div__ = __eq__ = __ne__ = not_implemented diff --git a/xarray/core/combine.py b/xarray/core/combine.py index b14d085f383..6853939c02d 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,6 +1,5 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import warnings import pandas as pd @@ -8,9 +7,9 @@ from . import utils from .alignment import align from .merge import merge -from .pycompat import iteritems, OrderedDict, basestring -from .variable import Variable, as_variable, IndexVariable, \ - concat as concat_vars +from .pycompat import OrderedDict, basestring, iteritems +from .variable import IndexVariable, Variable, as_variable +from .variable import concat as concat_vars def concat(objs, dim=None, data_vars='all', coords='different', @@ -126,16 +125,17 @@ def _calc_concat_dim_coord(dim): Infer the dimension name and 1d coordinate variable (if appropriate) for concatenating along the new dimension. 
""" + from .dataarray import DataArray + if isinstance(dim, basestring): coord = None - elif not hasattr(dim, 'dims'): - # dim is not a DataArray or IndexVariable + elif not isinstance(dim, (DataArray, Variable)): dim_name = getattr(dim, 'name', None) if dim_name is None: dim_name = 'concat_dim' coord = IndexVariable(dim_name, dim) dim = dim_name - elif not hasattr(dim, 'name'): + elif not isinstance(dim, DataArray): coord = as_variable(dim).to_index_variable() dim, = coord.dims else: @@ -341,7 +341,8 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat, def _auto_concat(datasets, dim=None, data_vars='all', coords='different'): - if len(datasets) == 1: + if len(datasets) == 1 and dim is None: + # There is nothing more to combine, so kick out early. return datasets[0] else: if dim is None: diff --git a/xarray/core/common.py b/xarray/core/common.py index 1366d0ff03d..34057e3715d 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,15 +1,20 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import warnings +from distutils.version import LooseVersion +from textwrap import dedent + import numpy as np import pandas as pd -import warnings -from .pycompat import basestring, suppress, dask_array_type, OrderedDict -from . import dtypes -from . import formatting -from . import ops -from .utils import SortedKeysDict, not_implemented, Frozen +from . import dtypes, duck_array_ops, formatting, ops +from .arithmetic import SupportsArithmetic +from .pycompat import OrderedDict, basestring, dask_array_type, suppress +from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs +from .options import _get_keep_attrs + +# Used as a sentinel value to indicate a all dimensions +ALL_DIMS = ReprObject('') class ImplementsArrayReduce(object): @@ -17,44 +22,44 @@ class ImplementsArrayReduce(object): def _reduce_method(cls, func, include_skipna, numeric_only): if include_skipna: def wrapped_func(self, dim=None, axis=None, skipna=None, - keep_attrs=False, **kwargs): - return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + **kwargs): + return self.reduce(func, dim, axis, skipna=skipna, allow_lazy=True, **kwargs) else: - def wrapped_func(self, dim=None, axis=None, keep_attrs=False, + def wrapped_func(self, dim=None, axis=None, **kwargs): - return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + return self.reduce(func, dim, axis, allow_lazy=True, **kwargs) return wrapped_func - _reduce_extra_args_docstring = \ - """dim : str or sequence of str, optional + _reduce_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension(s) over which to apply `{name}`. axis : int or sequence of int, optional Axis(es) over which to apply `{name}`. Only one of the 'dim' and 'axis' arguments can be supplied. If neither are supplied, then - `{name}` is calculated over axes.""" + `{name}` is calculated over axes.""") - _cum_extra_args_docstring = \ - """dim : str or sequence of str, optional + _cum_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension over which to apply `{name}`. axis : int or sequence of int, optional Axis over which to apply `{name}`. 
Only one of the 'dim' - and 'axis' arguments can be supplied.""" + and 'axis' arguments can be supplied.""") class ImplementsDatasetReduce(object): @classmethod def _reduce_method(cls, func, include_skipna, numeric_only): if include_skipna: - def wrapped_func(self, dim=None, keep_attrs=False, skipna=None, + def wrapped_func(self, dim=None, skipna=None, **kwargs): - return self.reduce(func, dim, keep_attrs, skipna=skipna, + return self.reduce(func, dim, skipna=skipna, numeric_only=numeric_only, allow_lazy=True, **kwargs) else: - def wrapped_func(self, dim=None, keep_attrs=False, **kwargs): - return self.reduce(func, dim, keep_attrs, + def wrapped_func(self, dim=None, **kwargs): + return self.reduce(func, dim, numeric_only=numeric_only, allow_lazy=True, **kwargs) return wrapped_func @@ -211,24 +216,36 @@ def _ipython_key_completions_(self): return list(set(item_lists)) -def get_squeeze_dims(xarray_obj, dim): +def get_squeeze_dims(xarray_obj, dim, axis=None): """Get a list of dimensions to squeeze out. """ - if dim is None: + if dim is not None and axis is not None: + raise ValueError('cannot use both parameters `axis` and `dim`') + + if dim is None and axis is None: dim = [d for d, s in xarray_obj.sizes.items() if s == 1] else: if isinstance(dim, basestring): dim = [dim] + if isinstance(axis, int): + axis = (axis, ) + if isinstance(axis, tuple): + for a in axis: + if not isinstance(a, int): + raise ValueError( + 'parameter `axis` must be int or tuple of int.') + alldims = list(xarray_obj.sizes.keys()) + dim = [alldims[a] for a in axis] if any(xarray_obj.sizes[k] > 1 for k in dim): raise ValueError('cannot select a dimension to squeeze out ' 'which has length greater than one') return dim -class BaseDataObject(AttrAccessMixin): +class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" - def squeeze(self, dim=None, drop=False): + def squeeze(self, dim=None, drop=False, axis=None): """Return a new object with squeezed data. Parameters @@ -240,6 +257,8 @@ def squeeze(self, dim=None, drop=False): drop : bool, optional If ``drop=True``, drop squeezed coordinates instead of making them scalar. + axis : int, optional + Select the dimension to squeeze. Added for compatibility reasons. Returns ------- @@ -251,7 +270,7 @@ def squeeze(self, dim=None, drop=False): -------- numpy.squeeze """ - dims = get_squeeze_dims(self, dim) + dims = get_squeeze_dims(self, dim, axis) return self.isel(drop=drop, **{d: 0 for d in dims}) def get_index(self, key): @@ -295,6 +314,25 @@ def assign_coords(self, **kwargs): A new object with the new coordinates in addition to the existing data. + Examples + -------- + + Convert longitude coordinates from 0-359 to -180-179: + + >>> da = xr.DataArray(np.random.rand(4), + ... coords=[np.array([358, 359, 0, 1])], + ... 
dims='lon') + >>> da + + array([0.28298 , 0.667347, 0.657938, 0.177683]) + Coordinates: + * lon (lon) int64 358 359 0 1 + >>> da.assign_coords(lon=(((da.lon + 180) % 360) - 180)) + + array([0.28298 , 0.667347, 0.657938, 0.177683]) + Coordinates: + * lon (lon) int64 -2 -1 0 1 + Notes ----- Since ``kwargs`` is a dictionary, the order of your arguments may not @@ -306,6 +344,7 @@ def assign_coords(self, **kwargs): See also -------- Dataset.assign + Dataset.swap_dims """ data = self.copy(deep=False) results = self._calc_assign_results(kwargs) @@ -412,6 +451,31 @@ def groupby(self, group, squeeze=True): grouped : GroupBy A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. + + Examples + -------- + Calculate daily anomalies for daily data: + + >>> da = xr.DataArray(np.linspace(0, 1826, num=1827), + ... coords=[pd.date_range('1/1/2000', '31/12/2004', + ... freq='D')], + ... dims='time') + >>> da + + array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03, 1.826e+03]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + >>> da.groupby('time.dayofyear') - da.groupby('time.dayofyear').mean('time') + + array([-730.8, -730.8, -730.8, ..., 730.2, 730.2, 730.5]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ... + + See Also + -------- + core.groupby.DataArrayGroupBy + core.groupby.DatasetGroupBy """ return self._groupby_cls(self, group, squeeze=squeeze) @@ -467,37 +531,35 @@ def groupby_bins(self, group, bins, right=True, labels=None, precision=3, 'precision': precision, 'include_lowest': include_lowest}) - def rolling(self, min_periods=None, center=False, **windows): + def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs): """ Rolling window object. - Rolling window aggregations are much faster when bottleneck is - installed. - Parameters ---------- + dim: dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. + **dim_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or dim_kwargs must be provided. Returns ------- - rolling : type of input argument + Rolling object (core.rolling.DataArrayRolling for DataArray, + core.rolling.DatasetRolling for Dataset.) Examples -------- Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON: - >>> da = xr.DataArray(np.linspace(0,11,num=12), + >>> da = xr.DataArray(np.linspace(0, 11, num=12), ... coords=[pd.date_range('15/12/1999', ... periods=12, freq=pd.DateOffset(months=1))], ... dims='time') @@ -506,34 +568,40 @@ def rolling(self, min_periods=None, center=False, **windows): array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) Coordinates: * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... 
- >>> da.rolling(time=3).mean() + >>> da.rolling(time=3, center=True).mean() - array([ nan, nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + array([nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., nan]) Coordinates: * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... Remove the NaNs using ``dropna()``: - >>> da.rolling(time=3).mean().dropna('time') + >>> da.rolling(time=3, center=True).mean().dropna('time') - array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) Coordinates: - * time (time) datetime64[ns] 2000-02-15 2000-03-15 2000-04-15 ... - """ + * time (time) datetime64[ns] 2000-01-15 2000-02-15 2000-03-15 ... - return self._rolling_cls(self, min_periods=min_periods, - center=center, **windows) + See Also + -------- + core.rolling.DataArrayRolling + core.rolling.DatasetRolling + """ + dim = either_dict_or_kwargs(dim, dim_kwargs, 'rolling') + return self._rolling_cls(self, dim, min_periods=min_periods, + center=center) - def resample(self, freq=None, dim=None, how=None, skipna=None, - closed=None, label=None, base=0, keep_attrs=False, **indexer): + def resample(self, indexer=None, skipna=None, closed=None, label=None, + base=0, keep_attrs=None, **indexer_kwargs): """Returns a Resample object for performing resampling operations. - Handles both downsampling and upsampling. Upsampling with filling is - not supported; if any intervals contain no values from the original - object, they will be given the value ``NaN``. + Handles both downsampling and upsampling. If any intervals contain no + values from the original object, they will be given the value ``NaN``. Parameters ---------- + indexer : {dim: freq}, optional + Mapping from the dimension name to resample frequency. skipna : bool, optional Whether to skip missing values when aggregating in downsampling. closed : 'left' or 'right', optional @@ -548,48 +616,90 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, If True, the object's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. - **indexer : {dim: freq} - Dictionary with a key indicating the dimension name to resample - over and a value corresponding to the resampling frequency. + **indexer_kwargs : {dim: freq} + The keyword arguments form of ``indexer``. + One of indexer or indexer_kwargs must be provided. Returns ------- resampled : same type as caller This object resampled. + Examples + -------- + Downsample monthly time-series data to seasonal data: + + >>> da = xr.DataArray(np.linspace(0, 11, num=12), + ... coords=[pd.date_range('15/12/1999', + ... periods=12, freq=pd.DateOffset(months=1))], + ... dims='time') + >>> da + + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... + >>> da.resample(time="QS-DEC").mean() + + array([ 1., 4., 7., 10.]) + Coordinates: + * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 + + Upsample monthly time-series data to daily data: + + >>> da.resample(time='1D').interpolate('linear') + + array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ]) + Coordinates: + * time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ... + References ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases """ + # TODO support non-string indexer after removing the old API. 
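# --- Editor's illustrative sketch, not part of the patch above. A small,
# hedged usage example of the two equivalent spellings accepted by the new
# resample() signature (explicit `indexer` mapping vs. keyword shorthand),
# assuming xarray with this patch applied.
import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(8.0),
    coords=[pd.date_range("2000-01-01", periods=8, freq="6H")],
    dims="time",
)

daily_kw = da.resample(time="1D").mean()                # keyword shorthand
daily_map = da.resample(indexer={"time": "1D"}).mean()  # explicit mapping
assert daily_kw.identical(daily_map)
# --- end of editor's sketch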
+ from .dataarray import DataArray from .resample import RESAMPLE_DIM + from ..coding.cftimeindex import CFTimeIndex - if dim is not None: - if how is None: - how = 'mean' - return self._resample_immediately(freq, dim, how, skipna, closed, - label, base, keep_attrs) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) - if (how is not None) and indexer: - raise TypeError("If passing an 'indexer' then 'dim' " - "and 'how' should not be used") + # note: the second argument (now 'skipna') use to be 'dim' + if ((skipna is not None and not isinstance(skipna, bool)) + or ('how' in indexer_kwargs and 'how' not in self.dims) + or ('dim' in indexer_kwargs and 'dim' not in self.dims)): + raise TypeError('resample() no longer supports the `how` or ' + '`dim` arguments. Instead call methods on resample ' + "objects, e.g., data.resample(time='1D').mean()") + + indexer = either_dict_or_kwargs(indexer, indexer_kwargs, 'resample') - # More than one indexer is ambiguous, but we do in fact need one if - # "dim" was not provided, until the old API is fully deprecated if len(indexer) != 1: raise ValueError( "Resampling only supported along single dimensions." ) dim, freq = indexer.popitem() - if isinstance(dim, basestring): - dim_name = dim - dim = self[dim] - else: - raise TypeError("Dimension name should be a string; " - "was passed %r" % dim) - group = DataArray(dim, [(dim.dims, dim)], name=RESAMPLE_DIM) + dim_name = dim + dim_coord = self[dim] + + if isinstance(self.indexes[dim_name], CFTimeIndex): + raise NotImplementedError( + 'Resample is currently not supported along a dimension ' + 'indexed by a CFTimeIndex. For certain kinds of downsampling ' + 'it may be possible to work around this by converting your ' + 'time index to a DatetimeIndex using ' + 'CFTimeIndex.to_datetimeindex. Use caution when doing this ' + 'however, because switching to a DatetimeIndex from a ' + 'CFTimeIndex with a non-standard calendar entails a change ' + 'in the calendar type, which could lead to subtle and silent ' + 'errors.' + ) + + group = DataArray(dim_coord, coords=dim_coord.coords, + dims=dim_coord.dims, name=RESAMPLE_DIM) grouper = pd.Grouper(freq=freq, closed=closed, label=label, base=base) resampler = self._resample_cls(self, group=group, dim=dim_name, grouper=grouper, @@ -597,39 +707,6 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, return resampler - def _resample_immediately(self, freq, dim, how, skipna, - closed, label, base, keep_attrs): - """Implement the original version of .resample() which immediately - executes the desired resampling operation. """ - from .dataarray import DataArray - RESAMPLE_DIM = '__resample_dim__' - - warnings.warn("\n.resample() has been modified to defer " - "calculations. 
Instead of passing 'dim' and " - "'how=\"{how}\", instead consider using " - ".resample({dim}=\"{freq}\").{how}() ".format( - dim=dim, freq=freq, how=how - ), DeprecationWarning, stacklevel=3) - - if isinstance(dim, basestring): - dim = self[dim] - group = DataArray(dim, [(dim.dims, dim)], name=RESAMPLE_DIM) - grouper = pd.Grouper(freq=freq, how=how, closed=closed, label=label, - base=base) - gb = self._groupby_cls(self, group, grouper=grouper) - if isinstance(how, basestring): - f = getattr(gb, how) - if how in ['first', 'last']: - result = f(skipna=skipna, keep_attrs=keep_attrs) - elif how == 'count': - result = f(dim=dim.name, keep_attrs=keep_attrs) - else: - result = f(dim=dim.name, skipna=skipna, keep_attrs=keep_attrs) - else: - result = gb.reduce(how, dim=dim.name, keep_attrs=keep_attrs) - result = result.rename({RESAMPLE_DIM: dim.name}) - return result - def where(self, cond, other=dtypes.NA, drop=False): """Filter elements from this object according to a condition. @@ -723,18 +800,61 @@ def close(self): self._file_obj.close() self._file_obj = None + def isin(self, test_elements): + """Tests each value in the array for whether it is in the supplied list. + + Parameters + ---------- + test_elements : array_like + The values against which to test each value of `element`. + This argument is flattened if an array or array_like. + See numpy notes for behavior with non-array-like parameters. + + Returns + ------- + isin : same as object, bool + Has the same shape as this object. + + Examples + -------- + + >>> array = xr.DataArray([1, 2, 3], dims='x') + >>> array.isin([1, 3]) + + array([ True, False, True]) + Dimensions without coordinates: x + + See also + -------- + numpy.isin + """ + from .computation import apply_ufunc + from .dataset import Dataset + from .dataarray import DataArray + from .variable import Variable + + if isinstance(test_elements, Dataset): + raise TypeError( + 'isin() argument must be convertible to an array: {}' + .format(test_elements)) + elif isinstance(test_elements, (Variable, DataArray)): + # need to explicitly pull out data to support dask arrays as the + # second argument + test_elements = test_elements.data + + return apply_ufunc( + duck_array_ops.isin, + self, + kwargs=dict(test_elements=test_elements), + dask='allowed', + ) + def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() - # this has no runtime function - these are listed so IDEs know these - # methods are defined and don't warn on these operations - __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ - __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ - __or__ = __div__ = __eq__ = __ne__ = not_implemented - def full_like(other, fill_value, dtype=None): """Return a new object with the same shape and type as a given object. 
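# --- Editor's illustrative sketch, not part of the patch above. The new
# isin() method is built on apply_ufunc with an elementwise membership test;
# the same pattern can be written against plain numpy.isin for illustration.
# Hedged sketch, not the library's exact internals (those use
# duck_array_ops.isin as shown in the hunk above).
import numpy as np
import xarray as xr

array = xr.DataArray([1, 2, 3], dims="x")

result = xr.apply_ufunc(
    np.isin,                             # elementwise membership test
    array,
    kwargs=dict(test_elements=[1, 3]),   # values to test against, flattened
    dask="allowed",                      # let dask arrays pass through lazily
)
print(result)  # boolean DataArray: [ True, False,  True]
# --- end of editor's sketch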
@@ -803,3 +923,34 @@ def ones_like(other, dtype=None): """Shorthand for full_like(other, 1, dtype) """ return full_like(other, 1, dtype) + + +def is_np_datetime_like(dtype): + """Check if a dtype is a subclass of the numpy datetime types + """ + return (np.issubdtype(dtype, np.datetime64) or + np.issubdtype(dtype, np.timedelta64)) + + +def contains_cftime_datetimes(var): + """Check if a variable contains cftime datetime objects""" + try: + from cftime import datetime as cftime_datetime + except ImportError: + return False + else: + if var.dtype == np.dtype('O') and var.data.size > 0: + sample = var.data.ravel()[0] + if isinstance(sample, dask_array_type): + sample = sample.compute() + if isinstance(sample, np.ndarray): + sample = sample.item() + return isinstance(sample, cftime_datetime) + else: + return False + + +def _contains_datetime_like_objects(var): + """Check if a variable contains datetime like objects (either + np.datetime64, np.timedelta64, or cftime.datetime)""" + return is_np_datetime_like(var.dtype) or contains_cftime_datetimes(var) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f1519027398..7998cc4f72f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1,22 +1,22 @@ """ Functions for applying functions that act on arrays to xarray's labeled data. - -NOT PUBLIC API. """ +from __future__ import absolute_import, division, print_function + import functools import itertools import operator +from collections import Counter +from distutils.version import LooseVersion import numpy as np -from . import duck_array_ops -from . import utils +from . import duck_array_ops, utils from .alignment import deep_align from .merge import expand_and_merge_variables -from .pycompat import OrderedDict, dask_array_type +from .pycompat import OrderedDict, basestring, dask_array_type from .utils import is_dict_like - _DEFAULT_FROZEN_SET = frozenset() _NO_FILL_VALUE = utils.ReprObject('') _DEFAULT_NAME = utils.ReprObject('') @@ -196,7 +196,6 @@ def apply_dataarray_ufunc(func, *args, **kwargs): signature = kwargs.pop('signature') join = kwargs.pop('join', 'inner') exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) - keep_attrs = kwargs.pop('keep_attrs', False) if kwargs: raise TypeError('apply_dataarray_ufunc() got unexpected keyword ' 'arguments: %s' % list(kwargs)) @@ -218,11 +217,6 @@ def apply_dataarray_ufunc(func, *args, **kwargs): coords, = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) - if keep_attrs and isinstance(args[0], DataArray): - if isinstance(out, tuple): - out = tuple(ds._copy_attrs_from(args[0]) for ds in out) - else: - out._copy_attrs_from(args[0]) return out @@ -520,13 +514,14 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims): def apply_variable_ufunc(func, *args, **kwargs): """apply_variable_ufunc(func, *args, signature, exclude_dims=frozenset()) """ - from .variable import Variable + from .variable import Variable, as_compatible_data signature = kwargs.pop('signature') exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) dask = kwargs.pop('dask', 'forbidden') output_dtypes = kwargs.pop('output_dtypes', None) output_sizes = kwargs.pop('output_sizes', None) + keep_attrs = kwargs.pop('keep_attrs', False) if kwargs: raise TypeError('apply_variable_ufunc() got unexpected keyword ' 'arguments: %s' % list(kwargs)) @@ -565,14 +560,42 @@ def func(*arrays): 'apply_ufunc: {}'.format(dask)) result_data = func(*input_data) - if signature.num_outputs > 1: - output = [] - for 
dims, data in zip(output_dims, result_data): - output.append(Variable(dims, data)) - return tuple(output) + if signature.num_outputs == 1: + result_data = (result_data,) + elif (not isinstance(result_data, tuple) or + len(result_data) != signature.num_outputs): + raise ValueError('applied function does not have the number of ' + 'outputs specified in the ufunc signature. ' + 'Result is not a tuple of {} elements: {!r}' + .format(signature.num_outputs, result_data)) + + output = [] + for dims, data in zip(output_dims, result_data): + data = as_compatible_data(data) + if data.ndim != len(dims): + raise ValueError( + 'applied function returned data with unexpected ' + 'number of dimensions: {} vs {}, for dimensions {}' + .format(data.ndim, len(dims), dims)) + + var = Variable(dims, data, fastpath=True) + for dim, new_size in var.sizes.items(): + if dim in dim_sizes and new_size != dim_sizes[dim]: + raise ValueError( + 'size of dimension {!r} on inputs was unexpectedly ' + 'changed by applied function from {} to {}. Only ' + 'dimensions specified in ``exclude_dims`` with ' + 'xarray.apply_ufunc are allowed to change size.' + .format(dim, dim_sizes[dim], new_size)) + + if keep_attrs and isinstance(args[0], Variable): + var.attrs.update(args[0].attrs) + output.append(var) + + if signature.num_outputs == 1: + return output[0] else: - dims, = output_dims - return Variable(dims, result_data) + return tuple(output) def _apply_with_dask_atop(func, args, input_dims, output_dims, signature, @@ -699,7 +722,7 @@ def apply_ufunc(func, *args, **kwargs): on each input argument that should not be broadcast. By default, we assume there are no core dimensions on any input arguments. - For example ,``input_core_dims=[[], ['time']]`` indicates that all + For example, ``input_core_dims=[[], ['time']]`` indicates that all dimensions on the first argument and all dimensions other than 'time' on the second argument should be broadcast. @@ -719,7 +742,8 @@ def apply_ufunc(func, *args, **kwargs): Core dimensions on the inputs to exclude from alignment and broadcasting entirely. Any input coordinates along these dimensions will be dropped. Each excluded dimension must also appear in - ``input_core_dims`` for at least one argument. + ``input_core_dims`` for at least one argument. Only dimensions listed + here are allowed to change size between input and output objects. vectorize : bool, optional If True, then assume ``func`` only takes arrays defined over core dimensions as input and vectorize it automatically with @@ -777,15 +801,38 @@ def apply_ufunc(func, *args, **kwargs): Examples -------- - For illustrative purposes only, here are examples of how you could use - ``apply_ufunc`` to write functions to (very nearly) replicate existing - xarray functionality: - Calculate the vector magnitude of two arguments:: + Calculate the vector magnitude of two arguments: + + >>> def magnitude(a, b): + ... func = lambda x, y: np.sqrt(x ** 2 + y ** 2) + ... 
return xr.apply_ufunc(func, a, b) + + You can now apply ``magnitude()`` to ``xr.DataArray`` and ``xr.Dataset`` + objects, with automatically preserved dimensions and coordinates, e.g., + + >>> array = xr.DataArray([1, 2, 3], coords=[('x', [0.1, 0.2, 0.3])]) + >>> magnitude(array, -array) + + array([1.414214, 2.828427, 4.242641]) + Coordinates: + * x (x) float64 0.1 0.2 0.3 + + Plain scalars, numpy arrays and a mix of these with xarray objects is also + supported: + + >>> magnitude(4, 5) + 5.0 + >>> magnitude(3, np.array([0, 4])) + array([3., 5.]) + >>> magnitude(array, 0) + + array([1., 2., 3.]) + Coordinates: + * x (x) float64 0.1 0.2 0.3 - def magnitude(a, b): - func = lambda x, y: np.sqrt(x ** 2 + y ** 2) - return xr.apply_func(func, a, b) + Other examples of how you could use ``apply_ufunc`` to write functions to + (very nearly) replicate existing xarray functionality: Compute the mean (``.mean``) over one dimension:: @@ -795,7 +842,7 @@ def mean(obj, dim): input_core_dims=[[dim]], kwargs={'axis': -1}) - Inner product over a specific dimension:: + Inner product over a specific dimension (like ``xr.dot``):: def _inner(x, y): result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis]) @@ -836,7 +883,8 @@ def earth_mover_distance(first_samples, Most of NumPy's builtin functions already broadcast their inputs appropriately for use in `apply`. You may find helper functions such as numpy.broadcast_arrays helpful in writing your function. `apply_ufunc` also - works well with numba's vectorize and guvectorize. + works well with numba's vectorize and guvectorize. Further explanation with + examples are provided in the xarray documentation [3]. See also -------- @@ -848,6 +896,7 @@ def earth_mover_distance(first_samples, ---------- .. [1] http://docs.scipy.org/doc/numpy/reference/ufuncs.html .. [2] http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html + .. [3] http://xarray.pydata.org/en/stable/computation.html#wrapping-custom-computation """ # noqa: E501 # don't error on that URL one line up from .groupby import GroupBy from .dataarray import DataArray @@ -871,6 +920,11 @@ def earth_mover_distance(first_samples, if input_core_dims is None: input_core_dims = ((),) * (len(args)) + elif len(input_core_dims) != len(args): + raise ValueError( + 'input_core_dims must be None or a tuple with the length same to ' + 'the number of arguments. 
Given input_core_dims: {}, ' + 'number of args: {}.'.format(input_core_dims, len(args))) signature = _UFuncSignature(input_core_dims, output_core_dims) @@ -882,14 +936,26 @@ def earth_mover_distance(first_samples, func = functools.partial(func, **kwargs_) if vectorize: - func = np.vectorize(func, - otypes=output_dtypes, - signature=signature.to_gufunc_string(), - excluded=set(kwargs)) + if signature.all_core_dims: + # we need the signature argument + if LooseVersion(np.__version__) < '1.12': # pragma: no cover + raise NotImplementedError( + 'numpy 1.12 or newer required when using vectorize=True ' + 'in xarray.apply_ufunc with non-scalar output core ' + 'dimensions.') + func = np.vectorize(func, + otypes=output_dtypes, + signature=signature.to_gufunc_string(), + excluded=set(kwargs)) + else: + func = np.vectorize(func, + otypes=output_dtypes, + excluded=set(kwargs)) variables_ufunc = functools.partial(apply_variable_ufunc, func, signature=signature, exclude_dims=exclude_dims, + keep_attrs=keep_attrs, dask=dask, output_dtypes=output_dtypes, output_sizes=output_sizes) @@ -918,14 +984,119 @@ def earth_mover_distance(first_samples, return apply_dataarray_ufunc(variables_ufunc, *args, signature=signature, join=join, - exclude_dims=exclude_dims, - keep_attrs=keep_attrs) + exclude_dims=exclude_dims) elif any(isinstance(a, Variable) for a in args): return variables_ufunc(*args) else: return apply_array_ufunc(func, *args, dask=dask) +def dot(*arrays, **kwargs): + """ dot(*arrays, dims=None) + + Generalized dot product for xarray objects. Like np.einsum, but + provides a simpler interface based on array dimensions. + + Parameters + ---------- + arrays: DataArray (or Variable) objects + Arrays to compute. + dims: str or tuple of strings, optional + Which dimensions to sum over. + If not speciified, then all the common dimensions are summed over. + **kwargs: dict + Additional keyword arguments passed to numpy.einsum or + dask.array.einsum + + Returns + ------- + dot: DataArray + + Examples + -------- + + >>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b']) + >>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5), + >>> dims=['a', 'b', 'c']) + >>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) + >>> + >>> xr.dot(da_a, da_b, dims=['a', 'b']).dims + ('c', ) + >>> xr.dot(da_a, da_b, dims=['a']).dims + ('b', 'c') + >>> xr.dot(da_a, da_b, da_c, dims=['b', 'c']).dims + ('a', 'd') + """ + from .dataarray import DataArray + from .variable import Variable + + dims = kwargs.pop('dims', None) + + if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): + raise TypeError('Only xr.DataArray and xr.Variable are supported.' 
+ 'Given {}.'.format([type(arr) for arr in arrays])) + + if len(arrays) == 0: + raise TypeError('At least one array should be given.') + + if isinstance(dims, basestring): + dims = (dims, ) + + common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) + all_dims = [] + for arr in arrays: + all_dims += [d for d in arr.dims if d not in all_dims] + + einsum_axes = 'abcdefghijklmnopqrstuvwxyz' + dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} + + if dims is None: + # find dimensions that occur more than one times + dim_counts = Counter() + for arr in arrays: + dim_counts.update(arr.dims) + dims = tuple(d for d, c in dim_counts.items() if c > 1) + + dims = tuple(dims) # make dims a tuple + + # dimensions to be parallelized + broadcast_dims = tuple(d for d in all_dims + if d in common_dims and d not in dims) + input_core_dims = [[d for d in arr.dims if d not in broadcast_dims] + for arr in arrays] + output_core_dims = [tuple(d for d in all_dims if d not in + dims + broadcast_dims)] + + # older dask than 0.17.4, we use tensordot if possible. + if isinstance(arr.data, dask_array_type): + import dask + if LooseVersion(dask.__version__) < LooseVersion('0.17.4'): + if len(broadcast_dims) == 0 and len(arrays) == 2: + axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] + for arr in arrays] + return apply_ufunc(duck_array_ops.tensordot, *arrays, + dask='allowed', + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + kwargs={'axes': axes}) + + # construct einsum subscripts, such as '...abc,...ab->...c' + # Note: input_core_dims are always moved to the last position + subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds + in input_core_dims] + subscripts = ','.join(subscripts_list) + subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) + + # subscripts should be passed to np.einsum as arg, not as kwargs. We need + # to construct a partial function for apply_ufunc to work. + func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) + result = apply_ufunc(func, *arrays, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + dask='allowed') + return result.transpose(*[d for d in all_dims if d in result.dims]) + + def where(cond, x, y): """Return elements from `x` or `y` depending on `cond`. diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 60c01e8be72..efe8affb2a3 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1,17 +1,21 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + from collections import Mapping from contextlib import contextmanager + import pandas as pd from . 
import formatting, indexing -from .utils import Frozen from .merge import ( - merge_coords, expand_and_merge_variables, merge_coords_for_inplace_math) + expand_and_merge_variables, merge_coords, merge_coords_for_inplace_math) from .pycompat import OrderedDict +from .utils import Frozen, ReprObject, either_dict_or_kwargs from .variable import Variable +# Used as the key corresponding to a DataArray's variable when converting +# arbitrary DataArray objects to datasets +_THIS_ARRAY = ReprObject('') + class AbstractCoordinates(Mapping, formatting.ReprMixin): def __getitem__(self, key): @@ -225,7 +229,9 @@ def __getitem__(self, key): def _update_coords(self, coords): from .dataset import calculate_dimensions - dims = calculate_dimensions(coords) + coords_plus_data = coords.copy() + coords_plus_data[_THIS_ARRAY] = self._data.variable + dims = calculate_dimensions(coords_plus_data) if not set(dims) <= set(self.dims): raise ValueError('cannot add coordinates with new dimensions to ' 'a DataArray') @@ -277,8 +283,8 @@ class Indexes(Mapping, formatting.ReprMixin): def __init__(self, variables, sizes): """Not for public consumption. - Arguments - --------- + Parameters + ---------- variables : OrderedDict[Any, Variable] Reference to OrderedDict holding variable objects. Should be the same dictionary used by the source object. @@ -325,7 +331,8 @@ def assert_coordinate_consistent(obj, coords): .format(k, obj[k], coords[k])) -def remap_label_indexers(obj, method=None, tolerance=None, **indexers): +def remap_label_indexers(obj, indexers=None, method=None, tolerance=None, + **indexers_kwargs): """ Remap **indexers from obj.coords. If indexer is an instance of DataArray and it has coordinate, then this @@ -338,6 +345,8 @@ def remap_label_indexers(obj, method=None, tolerance=None, **indexers): new_indexes: mapping of new dimensional-coordinate. """ from .dataarray import DataArray + indexers = either_dict_or_kwargs( + indexers, indexers_kwargs, 'remap_label_indexers') v_indexers = {k: v.variable.data if isinstance(v, DataArray) else v for k, v in indexers.items()} diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py new file mode 100644 index 00000000000..6b53dcffe6e --- /dev/null +++ b/xarray/core/dask_array_compat.py @@ -0,0 +1,162 @@ +from __future__ import absolute_import, division, print_function + +from distutils.version import LooseVersion + +import dask.array as da +import numpy as np +from dask import __version__ as dask_version + +try: + from dask.array import isin +except ImportError: # pragma: no cover + # Copied from dask v0.17.3. + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. 
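A note on the vendored `isin` fallback defined just below: it broadcasts a membership kernel over trailing test-element axes and then reduces with `any`. A minimal NumPy-only sketch of that idea, using hypothetical toy arrays (not part of the patch itself), might look like:

```python
import numpy as np

# Hypothetical toy inputs, only for illustration.
element = np.array([[1, 2, 3], [4, 5, 6]])
test_elements = np.array([2, 6, 7])

# Kernel step: elementwise membership, reshaped so that one trailing
# axis per test-element dimension is appended to element.shape.
values = np.in1d(element.ravel(), test_elements)
values = values.reshape(element.shape + (1,) * test_elements.ndim)

# Reduction step: any() over the trailing test axes recovers the usual
# np.isin result, with the same shape as `element`.
test_axes = tuple(range(element.ndim, element.ndim + test_elements.ndim))
result = values.any(axis=test_axes)

assert np.array_equal(result, np.isin(element, test_elements))
```

The dask version applies the same kernel per block via `da.atop` and ORs the per-block results along the test axes.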
+ + def _isin_kernel(element, test_elements, assume_unique=False): + values = np.in1d(element.ravel(), test_elements, + assume_unique=assume_unique) + return values.reshape(element.shape + (1,) * test_elements.ndim) + + def isin(element, test_elements, assume_unique=False, invert=False): + element = da.asarray(element) + test_elements = da.asarray(test_elements) + element_axes = tuple(range(element.ndim)) + test_axes = tuple(i + element.ndim for i in range(test_elements.ndim)) + mapped = da.atop(_isin_kernel, element_axes + test_axes, + element, element_axes, + test_elements, test_axes, + adjust_chunks={axis: lambda _: 1 + for axis in test_axes}, + dtype=bool, + assume_unique=assume_unique) + result = mapped.any(axis=test_axes) + if invert: + result = ~result + return result + + +if LooseVersion(dask_version) > LooseVersion('1.19.2'): + gradient = da.gradient + +else: # pragma: no cover + # Copied from dask v0.19.2 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + import math + from numbers import Integral, Real + + try: + AxisError = np.AxisError + except AttributeError: + try: + np.array([0]).sum(axis=5) + except Exception as e: + AxisError = type(e) + + def validate_axis(axis, ndim): + """ Validate an input to axis= keywords """ + if isinstance(axis, (tuple, list)): + return tuple(validate_axis(ax, ndim) for ax in axis) + if not isinstance(axis, Integral): + raise TypeError("Axis value must be an integer, got %s" % axis) + if axis < -ndim or axis >= ndim: + raise AxisError("Axis %d is out of bounds for array of dimension " + "%d" % (axis, ndim)) + if axis < 0: + axis += ndim + return axis + + def _gradient_kernel(x, block_id, coord, axis, array_locs, grad_kwargs): + """ + x: nd-array + array of one block + coord: 1d-array or scalar + coordinate along which the gradient is computed. + axis: int + axis along which the gradient is computed + array_locs: + actual location along axis. None if coordinate is scalar + grad_kwargs: + keyword to be passed to np.gradient + """ + block_loc = block_id[axis] + if array_locs is not None: + coord = coord[array_locs[0][block_loc]:array_locs[1][block_loc]] + grad = np.gradient(x, coord, axis=axis, **grad_kwargs) + return grad + + def gradient(f, *varargs, **kwargs): + f = da.asarray(f) + + kwargs["edge_order"] = math.ceil(kwargs.get("edge_order", 1)) + if kwargs["edge_order"] > 2: + raise ValueError("edge_order must be less than or equal to 2.") + + drop_result_list = False + axis = kwargs.pop("axis", None) + if axis is None: + axis = tuple(range(f.ndim)) + elif isinstance(axis, Integral): + drop_result_list = True + axis = (axis,) + + axis = validate_axis(axis, f.ndim) + + if len(axis) != len(set(axis)): + raise ValueError("duplicate axes not allowed") + + axis = tuple(ax % f.ndim for ax in axis) + + if varargs == (): + varargs = (1,) + if len(varargs) == 1: + varargs = len(axis) * varargs + if len(varargs) != len(axis): + raise TypeError( + "Spacing must either be a single scalar, or a scalar / " + "1d-array per axis" + ) + + if issubclass(f.dtype.type, (np.bool8, Integral)): + f = f.astype(float) + elif issubclass(f.dtype.type, Real) and f.dtype.itemsize < 4: + f = f.astype(float) + + results = [] + for i, ax in enumerate(axis): + for c in f.chunks[ax]: + if np.min(c) < kwargs["edge_order"] + 1: + raise ValueError( + 'Chunk size must be larger than edge_order + 1. ' + 'Minimum chunk for aixs {} is {}. 
Rechunk to ' + 'proceed.'.format(np.min(c), ax)) + + if np.isscalar(varargs[i]): + array_locs = None + else: + if isinstance(varargs[i], da.Array): + raise NotImplementedError( + 'dask array coordinated is not supported.') + # coordinate position for each block taking overlap into + # account + chunk = np.array(f.chunks[ax]) + array_loc_stop = np.cumsum(chunk) + 1 + array_loc_start = array_loc_stop - chunk - 2 + array_loc_stop[-1] -= 1 + array_loc_start[0] = 0 + array_locs = (array_loc_start, array_loc_stop) + + results.append(f.map_overlap( + _gradient_kernel, + dtype=f.dtype, + depth={j: 1 if j == ax else 0 for j in range(f.ndim)}, + boundary="none", + coord=varargs[i], + axis=ax, + array_locs=array_locs, + grad_kwargs=kwargs, + )) + + if drop_result_list: + results = results[0] + + return results diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 3aefd114517..25c572edd54 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,26 +1,106 @@ -"""Define core operations for xarray objects. -""" +from __future__ import absolute_import, division, print_function + +from distutils.version import LooseVersion + import numpy as np +from . import dtypes, nputils + try: + import dask import dask.array as da + # Note: dask has used `ghost` before 0.18.2 + if LooseVersion(dask.__version__) <= LooseVersion('0.18.2'): + overlap = da.ghost.ghost + trim_internal = da.ghost.trim_internal + else: + overlap = da.overlap.overlap + trim_internal = da.overlap.trim_internal except ImportError: pass def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): '''wrapper to apply bottleneck moving window funcs on dask arrays''' - # inputs for ghost + dtype, fill_value = dtypes.maybe_promote(a.dtype) + a = a.astype(dtype) + # inputs for overlap if axis < 0: axis = a.ndim + axis depth = {d: 0 for d in range(a.ndim)} - depth[axis] = window - 1 - boundary = {d: np.nan for d in range(a.ndim)} - # create ghosted arrays - ag = da.ghost.ghost(a, depth=depth, boundary=boundary) + depth[axis] = (window + 1) // 2 + boundary = {d: fill_value for d in range(a.ndim)} + # Create overlap array. + ag = overlap(a, depth=depth, boundary=boundary) # apply rolling func out = ag.map_blocks(moving_func, window, min_count=min_count, axis=axis, dtype=a.dtype) # trim array - result = da.ghost.trim_internal(out, depth) + result = trim_internal(out, depth) return result + + +def rolling_window(a, axis, window, center, fill_value): + """ Dask's equivalence to np.utils.rolling_window """ + orig_shape = a.shape + if axis < 0: + axis = a.ndim + axis + depth = {d: 0 for d in range(a.ndim)} + depth[axis] = int(window / 2) + # For evenly sized window, we need to crop the first point of each block. + offset = 1 if window % 2 == 0 else 0 + + if depth[axis] > min(a.chunks[axis]): + raise ValueError( + "For window size %d, every chunk should be larger than %d, " + "but the smallest chunk size is %d. Rechunk your array\n" + "with a larger chunk size or a chunk size that\n" + "more evenly divides the shape of your array." % + (window, depth[axis], min(a.chunks[axis]))) + + # Although dask.overlap pads values to boundaries of the array, + # the size of the generated array is smaller than what we want + # if center == False. 
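The block-wise worker defined a few lines below relies on `nputils._rolling_window`, a strided rolling-window view. As a rough illustration of that idea (the standard NumPy stride-tricks recipe, not necessarily the exact helper used in this patch), consider:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def rolling_window_view(a, window):
    # Zero-copy view of `a` with an extra trailing axis of length `window`.
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return as_strided(a, shape=shape, strides=strides)

x = np.arange(6)
print(rolling_window_view(x, 3))
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]]
```

The dask `rolling_window` helper then only has to pad the array and overlap chunk boundaries so that every block sees enough neighbouring values to build this view, before trimming the result back to the original shape.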
+ if center: + start = int(window / 2) # 10 -> 5, 9 -> 4 + end = window - 1 - start + else: + start, end = window - 1, 0 + pad_size = max(start, end) + offset - depth[axis] + drop_size = 0 + # pad_size becomes more than 0 when the overlapped array is smaller than + # needed. In this case, we need to enlarge the original array by padding + # before overlapping. + if pad_size > 0: + if pad_size < depth[axis]: + # overlapping requires each chunk larger than depth. If pad_size is + # smaller than the depth, we enlarge this and truncate it later. + drop_size = depth[axis] - pad_size + pad_size = depth[axis] + shape = list(a.shape) + shape[axis] = pad_size + chunks = list(a.chunks) + chunks[axis] = (pad_size, ) + fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks) + a = da.concatenate([fill_array, a], axis=axis) + + boundary = {d: fill_value for d in range(a.ndim)} + + # create overlap arrays + ag = overlap(a, depth=depth, boundary=boundary) + + # apply rolling func + def func(x, window, axis=-1): + x = np.asarray(x) + rolling = nputils._rolling_window(x, window, axis) + return rolling[(slice(None), ) * axis + (slice(offset, None), )] + + chunks = list(a.chunks) + chunks.append(window) + out = ag.map_blocks(func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, + window=window, axis=axis) + + # crop boundary. + index = (slice(None),) * axis + (slice(drop_size, + drop_size + orig_shape[axis]), ) + return out[index] diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8e1ec8ab7b8..17af3cf2cd1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,35 +1,29 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import functools import warnings import numpy as np import pandas as pd +from . import computation, groupby, indexing, ops, resample, rolling, utils from ..plot.plot import _PlotMethods - -from . import duck_array_ops -from . import indexing -from . import groupby -from . import resample -from . import rolling -from . import ops -from . 
import utils from .accessors import DatetimeAccessor from .alignment import align, reindex_like_indexers -from .common import AbstractArray, BaseDataObject -from .coordinates import (DataArrayCoordinates, LevelCoordinatesSource, - Indexes, assert_coordinate_consistent, - remap_label_indexers) +from .common import AbstractArray, DataWithCoords +from .coordinates import ( + DataArrayCoordinates, Indexes, LevelCoordinatesSource, + assert_coordinate_consistent, remap_label_indexers) from .dataset import Dataset, merge_indexes, split_indexes -from .pycompat import iteritems, basestring, OrderedDict, zip, range -from .variable import (as_variable, Variable, as_compatible_data, - IndexVariable, - assert_unique_multiindex_level_names) from .formatting import format_item -from .utils import decode_numpy_dict_values, ensure_us_time_resolution -from .options import OPTIONS +from .options import OPTIONS, _get_keep_attrs +from .pycompat import OrderedDict, basestring, iteritems, range, zip +from .utils import ( + _check_inplace, decode_numpy_dict_values, either_dict_or_kwargs, + ensure_us_time_resolution) +from .variable import ( + IndexVariable, Variable, as_compatible_data, as_variable, + assert_unique_multiindex_level_names) def _infer_coords_and_dims(shape, coords, dims): @@ -125,7 +119,7 @@ def __setitem__(self, key, value): _THIS_ARRAY = utils.ReprObject('') -class DataArray(AbstractArray, BaseDataObject): +class DataArray(AbstractArray, DataWithCoords): """N-dimensional array with labeled coordinates and dimensions. DataArray provides a wrapper around numpy ndarrays that uses labeled @@ -259,7 +253,7 @@ def _replace(self, variable=None, coords=None, name=__default): def _replace_maybe_drop_dims(self, variable, name=__default): if variable.dims == self.dims: - coords = None + coords = self._coords.copy() else: allowed_dims = set(variable.dims) coords = OrderedDict((k, v) for k, v in self._coords.items() @@ -371,6 +365,7 @@ def name(self, value): @property def variable(self): + """Low level interface to the Variable object for this DataArray.""" return self._variable @property @@ -477,7 +472,7 @@ def __getitem__(self, key): return self._getitem_coord(key) else: # xarray-style array indexing - return self.isel(**self._item_key_to_dict(key)) + return self.isel(indexers=self._item_key_to_dict(key)) def __setitem__(self, key, value): if isinstance(key, basestring): @@ -505,15 +500,11 @@ def _attr_sources(self): @property def _item_sources(self): """List of places to look-up items for key-completion""" - return [self.coords, {d: self[d] for d in self.dims}, + return [self.coords, {d: self.coords[d] for d in self.dims}, LevelCoordinatesSource(self)] def __contains__(self, key): - warnings.warn( - 'xarray.DataArray.__contains__ currently checks membership in ' - 'DataArray.coords, but in xarray v0.11 will change to check ' - 'membership in array values.', FutureWarning, stacklevel=2) - return key in self._coords + return key in self.data @property def loc(self): @@ -552,7 +543,7 @@ def coords(self): """ return DataArrayCoordinates(self) - def reset_coords(self, names=None, drop=False, inplace=False): + def reset_coords(self, names=None, drop=False, inplace=None): """Given names of coordinates, reset them to become variables. 
Parameters @@ -571,6 +562,7 @@ def reset_coords(self, names=None, drop=False, inplace=False): ------- Dataset, or DataArray if ``drop == True`` """ + inplace = _check_inplace(inplace) if inplace and not drop: raise ValueError('cannot reset coordinates in-place on a ' 'DataArray without ``drop == True``') @@ -683,14 +675,77 @@ def persist(self, **kwargs): ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this array. - If `deep=True`, a deep copy is made of all variables in the underlying - dataset. Otherwise, a shallow copy is made, so each variable in the new + If `deep=True`, a deep copy is made of the data array. + Otherwise, a shallow copy is made, so each variable in the new array's dataset is also a variable in this array's dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether the data array and its coordinates are loaded into memory + and copied onto the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored for all data variables, + and only used for coords. + + Returns + ------- + object : DataArray + New object with dimensions, attributes, coordinates, name, + encoding, and optionally data copied from original. + + Examples + -------- + + Shallow versus deep copy + + >>> array = xr.DataArray([1, 2, 3], dims='x', + ... coords={'x': ['a', 'b', 'c']}) + >>> array.copy() + + array([1, 2, 3]) + Coordinates: + * x (x) >> array_0 = array.copy(deep=False) + >>> array_0[0] = 7 + >>> array_0 + + array([7, 2, 3]) + Coordinates: + * x (x) >> array + + array([7, 2, 3]) + Coordinates: + * x (x) >> array.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + Coordinates: + * x (x) >> array + + array([1, 2, 3]) + Coordinates: + * x (x) >> da = xr.DataArray([1, 3], [('x', np.arange(2))]) + >>> da.interp(x=0.5) + + array(2.0) + Coordinates: + x float64 0.5 + """ + if self.dtype.kind not in 'uifc': + raise TypeError('interp only works for a numeric type array. ' + 'Given {}.'.format(self.dtype)) + ds = self._to_temp_dataset().interp( + coords, method=method, kwargs=kwargs, assume_sorted=assume_sorted, + **coords_kwargs) + return self._from_temp_dataset(ds) + + def interp_like(self, other, method='linear', assume_sorted=False, + kwargs={}): + """Interpolate this object onto the coordinates of another object, + filling out of range values with NaN. Parameters ---------- - new_name_or_name_dict : str or dict-like + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to an 1d array-like, which provides coordinates upon + which to index the variables in this dataset. + method: string, optional. + {'linear', 'nearest'} for multidimensional array, + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + for 1-dimensional array. 'linear' is used by default. + assume_sorted: boolean, optional + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs: dictionary, optional + Additional keyword passed to scipy's interpolator. 
+ + Returns + ------- + interpolated: xr.DataArray + Another dataarray by interpolating this dataarray's data along the + coordinates of the other object. + + Note + ---- + scipy is required. + If the dataarray has object-type coordinates, reindex is used for these + coordinates instead of the interpolation. + + See Also + -------- + DataArray.interp + DataArray.reindex_like + """ + if self.dtype.kind not in 'uifc': + raise TypeError('interp only works for a numeric type array. ' + 'Given {}.'.format(self.dtype)) + + ds = self._to_temp_dataset().interp_like( + other, method=method, kwargs=kwargs, assume_sorted=assume_sorted) + return self._from_temp_dataset(ds) + + def rename(self, new_name_or_name_dict=None, **names): + """Returns a new DataArray with renamed coordinates or a new name. + + Parameters + ---------- + new_name_or_name_dict : str or dict-like, optional If the argument is dict-like, it it used as a mapping from old names to new names for coordinates. Otherwise, use the argument as the new name for this array. + **names, optional + The keyword arguments form of a mapping from old names to + new names for coordinates. + One of new_name_or_name_dict or names must be provided. Returns @@ -911,8 +1094,10 @@ def rename(self, new_name_or_name_dict): Dataset.rename DataArray.swap_dims """ - if utils.is_dict_like(new_name_or_name_dict): - dataset = self._to_temp_dataset().rename(new_name_or_name_dict) + if names or utils.is_dict_like(new_name_or_name_dict): + name_dict = either_dict_or_kwargs( + new_name_or_name_dict, names, 'rename') + dataset = self._to_temp_dataset().rename(name_dict) return self._from_temp_dataset(dataset) else: return self._replace(name=new_name_or_name_dict) @@ -968,22 +1153,26 @@ def expand_dims(self, dim, axis=None): ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) - def set_index(self, append=False, inplace=False, **indexes): + def set_index(self, indexes=None, append=False, inplace=None, + **indexes_kwargs): """Set DataArray (multi-)indexes using one or more existing coordinates. Parameters ---------- + indexes : {dim: index, ...} + Mapping from names matching dimensions and values given + by (lists of) the names of existing coordinates or variables to set + as new (multi-)index. append : bool, optional If True, append the supplied index(es) to the existing index(es). Otherwise replace the existing index(es) (default). inplace : bool, optional If True, set new index(es) in-place. Otherwise, return a new DataArray object. - **indexes : {dim: index, ...} - Keyword arguments with names matching dimensions and values given - by (lists of) the names of existing coordinates or variables to set - as new (multi-)index. + **indexes_kwargs: optional + The keyword arguments form of ``indexes``. + One of indexes or indexes_kwargs must be provided. Returns ------- @@ -994,13 +1183,15 @@ def set_index(self, append=False, inplace=False, **indexes): -------- DataArray.reset_index """ + inplace = _check_inplace(inplace) + indexes = either_dict_or_kwargs(indexes, indexes_kwargs, 'set_index') coords, _ = merge_indexes(indexes, self._coords, set(), append=append) if inplace: self._coords = coords else: return self._replace(coords=coords) - def reset_index(self, dims_or_levels, drop=False, inplace=False): + def reset_index(self, dims_or_levels, drop=False, inplace=None): """Reset the specified index(es) or multi-index level(s). 
Parameters @@ -1025,6 +1216,7 @@ def reset_index(self, dims_or_levels, drop=False, inplace=False): -------- DataArray.set_index """ + inplace = _check_inplace(inplace) coords, _ = split_indexes(dims_or_levels, self._coords, set(), self._level_coords, drop=drop) if inplace: @@ -1032,18 +1224,22 @@ def reset_index(self, dims_or_levels, drop=False, inplace=False): else: return self._replace(coords=coords) - def reorder_levels(self, inplace=False, **dim_order): + def reorder_levels(self, dim_order=None, inplace=None, + **dim_order_kwargs): """Rearrange index levels using input order. Parameters ---------- + dim_order : optional + Mapping from names matching dimensions and values given + by lists representing new level orders. Every given dimension + must have a multi-index. inplace : bool, optional If True, modify the dataarray in-place. Otherwise, return a new DataArray object. - **dim_order : optional - Keyword arguments with names matching dimensions and values given - by lists representing new level orders. Every given dimension - must have a multi-index. + **dim_order_kwargs: optional + The keyword arguments form of ``dim_order``. + One of dim_order or dim_order_kwargs must be provided. Returns ------- @@ -1051,6 +1247,9 @@ def reorder_levels(self, inplace=False, **dim_order): Another dataarray, with this dataarray's data but replaced coordinates. """ + inplace = _check_inplace(inplace) + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, + 'reorder_levels') replace_coords = {} for dim, order in dim_order.items(): coord = self._coords[dim] @@ -1066,7 +1265,7 @@ def reorder_levels(self, inplace=False, **dim_order): else: return self._replace(coords=coords) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. @@ -1075,9 +1274,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of the form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1106,26 +1308,48 @@ def stack(self, **dimensions): -------- DataArray.unstack """ - ds = self._to_temp_dataset().stack(**dimensions) + ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) return self._from_temp_dataset(ds) - def unstack(self, dim): + def unstack(self, dim=None): """ - Unstack an existing dimension corresponding to a MultiIndex into + Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. New dimensions will be added at the end. Parameters ---------- - dim : str - Name of the existing dimension to unstack. + dim : str or sequence of str, optional + Dimension(s) over which to unstack. By default unstacks all + MultiIndexes. Returns ------- unstacked : DataArray Array with unstacked data. + Examples + -------- + + >>> arr = DataArray(np.arange(6).reshape(2, 3), + ... 
coords=[('x', ['a', 'b']), ('y', [0, 1, 2])]) + >>> arr + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) |S1 'a' 'b' + * y (y) int64 0 1 2 + >>> stacked = arr.stack(z=('x', 'y')) + >>> stacked.indexes['z'] + MultiIndex(levels=[[u'a', u'b'], [0, 1, 2]], + labels=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + names=[u'x', u'y']) + >>> roundtripped = stacked.unstack() + >>> arr.identical(roundtripped) + True + See also -------- DataArray.stack @@ -1336,7 +1560,7 @@ def combine_first(self, other): """ return ops.fillna(self, other, join="outer") - def reduce(self, func, dim=None, axis=None, keep_attrs=False, **kwargs): + def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs): """Reduce this array by applying `func` along some dimension(s). Parameters @@ -1365,6 +1589,7 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, **kwargs): DataArray with this object's array replaced with an array with summarized data and the indicated dimension(s) removed. """ + var = self.variable.reduce(func, dim, axis, keep_attrs, **kwargs) return self._replace_maybe_drop_dims(var) @@ -1450,8 +1675,7 @@ def to_masked_array(self, copy=True): return np.ma.MaskedArray(data=self.values, mask=isnull, copy=copy) def to_netcdf(self, *args, **kwargs): - """ - Write DataArray contents to a netCDF file. + """Write DataArray contents to a netCDF file. Parameters ---------- @@ -1733,7 +1957,7 @@ def _binary_op(f, reflexive=False, join=None, **ignored_kwargs): def func(self, other): if isinstance(other, (Dataset, groupby.GroupBy)): return NotImplemented - if hasattr(other, 'indexes'): + if isinstance(other, DataArray): align_type = (OPTIONS['arithmetic_join'] if join is None else join) self, other = align(self, other, join=align_type, copy=False) @@ -1851,11 +2075,14 @@ def diff(self, dim, n=1, label='upper'): Coordinates: * x (x) int64 3 4 + See Also + -------- + DataArray.differentiate """ ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) return self._from_temp_dataset(ds) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """Shift this array by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. Values shifted from @@ -1864,10 +2091,13 @@ def shift(self, **shifts): Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : Mapping with the form of {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- @@ -1889,17 +2119,23 @@ def shift(self, **shifts): Coordinates: * x (x) int64 0 1 2 """ - variable = self.variable.shift(**shifts) - return self._replace(variable) + ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs) + return self._from_temp_dataset(ds) - def roll(self, **shifts): + def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this array by an offset along one or more dimensions. - Unlike shift, roll rotates all variables, including coordinates. The - direction of rotation is consistent with :py:func:`numpy.roll`. + Unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. 
Parameters ---------- + roll_coords : bool + Indicates whether to roll the coordinates by the offset + The current default of roll_coords (None, equivalent to True) is + deprecated and will change to False in a future version. + Explicitly pass roll_coords to silence the warning. **shifts : keyword arguments of the form {dim: offset} Integer offset to rotate each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. @@ -1923,7 +2159,8 @@ def roll(self, **shifts): Coordinates: * x (x) int64 2 0 1 """ - ds = self._to_temp_dataset().roll(**shifts) + ds = self._to_temp_dataset().roll( + shifts=shifts, roll_coords=roll_coords, **shifts_kwargs) return self._from_temp_dataset(ds) @property @@ -1934,7 +2171,7 @@ def real(self): def imag(self): return self._replace(self.variable.imag) - def dot(self, other): + def dot(self, other, dims=None): """Perform dot product of two DataArrays along their shared dims. Equivalent to taking taking tensordot over all shared dims. @@ -1943,6 +2180,9 @@ def dot(self, other): ---------- other : DataArray The other array with which the dot product is performed. + dims: list of strings, optional + Along which dimensions to be summed over. Default all the common + dimensions are summed over. Returns ------- @@ -1951,6 +2191,7 @@ def dot(self, other): See also -------- + dot numpy.tensordot Examples @@ -1976,23 +2217,7 @@ def dot(self, other): if not isinstance(other, DataArray): raise TypeError('dot only operates on DataArrays.') - # sum over the common dims - dims = set(self.dims) & set(other.dims) - if len(dims) == 0: - raise ValueError('DataArrays have no shared dimensions over which ' - 'to perform dot.') - - self, other = align(self, other, join='inner', copy=False) - - axes = (self.get_axis_num(dims), other.get_axis_num(dims)) - new_data = duck_array_ops.tensordot(self.data, other.data, axes=axes) - - new_coords = self.coords.merge(other.coords) - new_coords = new_coords.drop([d for d in dims if d in new_coords]) - new_dims = ([d for d in self.dims if d not in dims] + - [d for d in other.dims if d not in dims]) - - return type(self)(new_data, new_coords.variables, new_dims) + return computation.dot(self, other, dims=dims) def sortby(self, variables, ascending=True): """ @@ -2047,7 +2272,7 @@ def sortby(self, variables, ascending=True): ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) - def quantile(self, q, dim=None, interpolation='linear', keep_attrs=False): + def quantile(self, q, dim=None, interpolation='linear', keep_attrs=None): """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -2093,7 +2318,7 @@ def quantile(self, q, dim=None, interpolation='linear', keep_attrs=False): q, dim=dim, keep_attrs=keep_attrs, interpolation=interpolation) return self._from_temp_dataset(ds) - def rank(self, dim, pct=False, keep_attrs=False): + def rank(self, dim, pct=False, keep_attrs=None): """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -2129,9 +2354,65 @@ def rank(self, dim, pct=False, keep_attrs=False): array([ 1., 2., 3.]) Dimensions without coordinates: x """ + ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs) return self._from_temp_dataset(ds) + def differentiate(self, coord, edge_order=1, datetime_unit=None): + """ Differentiate the array with the second order accurate central + differences. + + .. 
note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord: str + The coordinate to be used to compute the gradient. + edge_order: 1 or 2. Default 1 + N-th order accurate differences at the boundaries. + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + Unit to compute gradient. Only valid for datetime coordinate. + + Returns + ------- + differentiated: DataArray + + See also + -------- + numpy.gradient: corresponding numpy function + + Examples + -------- + + >>> da = xr.DataArray(np.arange(12).reshape(4, 3), dims=['x', 'y'], + ... coords={'x': [0, 0.1, 1.1, 1.2]}) + >>> da + + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.differentiate('x') + + array([[30. , 30. , 30. ], + [27.545455, 27.545455, 27.545455], + [27.545455, 27.545455, 27.545455], + [30. , 30. , 30. ]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().differentiate( + coord, edge_order, datetime_unit) + return self._from_temp_dataset(ds) + # priority most be higher than Variable to properly work with binary ufuncs ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 62ad2b9b653..4f9c61b3269 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,43 +1,41 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import functools +import sys +import warnings from collections import Mapping, defaultdict from distutils.version import LooseVersion from numbers import Number -import warnings - -import sys import numpy as np import pandas as pd -from . import ops -from . import utils -from . import groupby -from . import resample -from . import rolling -from . import indexing -from . import alignment -from . import formatting -from . import duck_array_ops +import xarray as xr + +from . import ( + alignment, computation, duck_array_ops, formatting, groupby, indexing, ops, + resample, rolling, utils) from .. 
import conventions +from ..coding.cftimeindex import _parse_array_of_cftime_strings from .alignment import align -from .coordinates import (DatasetCoordinates, LevelCoordinatesSource, Indexes, - assert_coordinate_consistent, remap_label_indexers) -from .common import ImplementsDatasetReduce, BaseDataObject +from .common import ( + ALL_DIMS, DataWithCoords, ImplementsDatasetReduce, + _contains_datetime_like_objects) +from .coordinates import ( + DatasetCoordinates, Indexes, LevelCoordinatesSource, + assert_coordinate_consistent, remap_label_indexers) from .dtypes import is_datetime_like -from .merge import (dataset_update_method, dataset_merge_method, - merge_data_and_coords, merge_variables) -from .utils import (Frozen, SortedKeysDict, maybe_wrap_array, hashable, - decode_numpy_dict_values, ensure_us_time_resolution) -from .variable import (Variable, as_variable, IndexVariable, - broadcast_variables) -from .pycompat import (iteritems, basestring, OrderedDict, - integer_types, dask_array_type, range) -from .options import OPTIONS - -import xarray as xr +from .merge import ( + dataset_merge_method, dataset_update_method, merge_data_and_coords, + merge_variables) +from .options import OPTIONS, _get_keep_attrs +from .pycompat import ( + OrderedDict, basestring, dask_array_type, integer_types, iteritems, range) +from .utils import ( + _check_inplace, Frozen, SortedKeysDict, datetime_to_numeric, + decode_numpy_dict_values, either_dict_or_kwargs, ensure_us_time_resolution, + hashable, maybe_wrap_array) +from .variable import IndexVariable, Variable, as_variable, broadcast_variables # list of attributes of pd.DatetimeIndex that are ndarrays of time info _DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute', @@ -81,7 +79,7 @@ def _get_virtual_variable(variables, key, level_vars=None, dim_sizes=None): virtual_var = ref_var var_name = key else: - if is_datetime_like(ref_var.dtype): + if _contains_datetime_like_objects(ref_var): ref_var = xr.DataArray(ref_var) data = getattr(ref_var.dt, var_name).data else: @@ -304,7 +302,7 @@ def __getitem__(self, key): return self.dataset.sel(**key) -class Dataset(Mapping, ImplementsDatasetReduce, BaseDataObject, +class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords, formatting.ReprMixin): """A multi-dimensional, in memory, array database. @@ -405,8 +403,12 @@ def load_store(cls, store, decoder=None): @property def variables(self): - """Frozen dictionary of xarray.Variable objects constituting this - dataset's data + """Low level interface to Dataset contents as dict of Variable objects. + + This ordered dictionary is frozen to prevent mutation that could + violate Dataset invariants. It contains all variable objects + constituting the Dataset, including both data variables and + coordinates. """ return Frozen(self._variables) @@ -710,16 +712,120 @@ def _replace_indexes(self, indexes): obj = obj.rename(dim_names) return obj - def copy(self, deep=False): + def copy(self, deep=False, data=None): """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. Otherwise, a shallow copy of each of the component variable is made, so that the underlying memory region of the new dataset is the same as in the original dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether each component variable is loaded into memory and copied onto + the new object. Default is False. 
+ data : dict-like, optional + Data to use in the new object. Each item in `data` must have same + shape as corresponding data variable in original. When `data` is + used, `deep` is ignored for the data variables and only used for + coords. + + Returns + ------- + object : Dataset + New object with dimensions, attributes, coordinates, name, encoding, + and optionally data copied from original. + + Examples + -------- + + Shallow copy versus deep copy + + >>> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset({'foo': da, 'bar': ('x', [-1, 2])}, + coords={'x': ['one', 'two']}) + >>> ds.copy() + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds_0 = ds.copy(deep=False) + >>> ds_0['foo'][0, 0] = 7 + >>> ds_0 + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds.copy(data={'foo': np.arange(6).reshape(2, 3), 'bar': ['a', 'b']}) + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) =0.16): @@ -1439,13 +1565,9 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): drop : bool, optional If ``drop=True``, drop coordinates variables in `indexers` instead of making them scalar. - **indexers : {dim: indexer, ...} - Keyword arguments with names matching dimensions and values given - by scalars, slices or arrays of tick labels. For dimensions with - multi-index, the indexer may also be a dict-like object with keys - matching index level names. - If DataArrays are passed as indexers, xarray-style indexing will be - carried out. See :ref:`indexing` for the details. + **indexers_kwarg : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. 
Returns ------- @@ -1464,9 +1586,10 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): Dataset.isel DataArray.sel """ - pos_indexers, new_indexes = remap_label_indexers(self, method, - tolerance, **indexers) - result = self.isel(drop=drop, **pos_indexers) + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel') + pos_indexers, new_indexes = remap_label_indexers( + self, indexers=indexers, method=method, tolerance=tolerance) + result = self.isel(indexers=pos_indexers, drop=drop) return result._replace_indexes(new_indexes) def isel_points(self, dim='points', **indexers): @@ -1516,7 +1639,8 @@ def take(variable, slices): # Note: remove helper function when once when numpy # supports vindex https://github.com/numpy/numpy/pull/6075 if hasattr(variable.data, 'vindex'): - # Special case for dask backed arrays to use vectorised list indexing + # Special case for dask backed arrays to use vectorised list + # indexing sel = variable.data.vindex[slices] else: # Otherwise assume backend is numpy array with 'fancy' indexing @@ -1579,7 +1703,8 @@ def relevant_keys(mapping): variables = OrderedDict() for name, var in reordered.variables.items(): - if name in indexers_dict or any(d in indexer_dims for d in var.dims): + if name in indexers_dict or any( + d in indexer_dims for d in var.dims): # slice if var is an indexer or depends on an indexed dim slc = [indexers_dict[k] if k in indexers_dict @@ -1708,11 +1833,11 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True): align """ indexers = alignment.reindex_like_indexers(self, other) - return self.reindex(method=method, copy=copy, tolerance=tolerance, - **indexers) + return self.reindex(indexers=indexers, method=method, copy=copy, + tolerance=tolerance) def reindex(self, indexers=None, method=None, tolerance=None, copy=True, - **kw_indexers): + **indexers_kwargs): """Conform this object onto a new set of indexes, filling in missing values with NaN. @@ -1723,6 +1848,7 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, arrays of coordinates tick labels. Any mis-matched coordinate values will be filled in with NaN, and any mis-matched dimension names will simply be ignored. + One of indexers or indexers_kwargs must be provided. method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional Method to use for filling index values in ``indexers`` not found in this dataset: @@ -1741,8 +1867,9 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, ``copy=False`` and reindexing is unnecessary, or can be performed with only slice operations, then the output may share memory with the input. In either case, a new xarray object is always returned. - **kw_indexers : optional + **indexers_kwarg : {dim: indexer, ...}, optional Keyword arguments in the same form as ``indexers``. + One of indexers or indexers_kwargs must be provided. 
Returns ------- @@ -1755,8 +1882,8 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, align pandas.Index.get_indexer """ - indexers = utils.combine_pos_and_kw_args(indexers, kw_indexers, - 'reindex') + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, + 'reindex') bad_dims = [d for d in indexers if d not in self.dims] if bad_dims: @@ -1769,17 +1896,171 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, coord_names.update(indexers) return self._replace_vars_and_dims(variables, coord_names) - def rename(self, name_dict, inplace=False): + def interp(self, coords=None, method='linear', assume_sorted=False, + kwargs={}, **coords_kwargs): + """ Multidimensional interpolation of Dataset. + + Parameters + ---------- + coords : dict, optional + Mapping from dimension names to the new coordinates. + New coordinate can be a scalar, array-like or DataArray. + If DataArrays are passed as new coordates, their dimensions are + used for the broadcasting. + method: string, optional. + {'linear', 'nearest'} for multidimensional array, + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + for 1-dimensional array. 'linear' is used by default. + assume_sorted: boolean, optional + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs: dictionary, optional + Additional keyword passed to scipy's interpolator. + **coords_kwarg : {dim: coordinate, ...}, optional + The keyword arguments form of ``coords``. + One of coords or coords_kwargs must be provided. + + Returns + ------- + interpolated: xr.Dataset + New dataset on the new coordinates. + + Note + ---- + scipy is required. + + See Also + -------- + scipy.interpolate.interp1d + scipy.interpolate.interpn + """ + from . import missing + + coords = either_dict_or_kwargs(coords, coords_kwargs, 'rename') + indexers = OrderedDict(self._validate_indexers(coords)) + + obj = self if assume_sorted else self.sortby([k for k in coords]) + + def maybe_variable(obj, k): + # workaround to get variable for dimension without coordinate. + try: + return obj._variables[k] + except KeyError: + return as_variable((k, range(obj.dims[k]))) + + def _validate_interp_indexer(x, new_x): + # In the case of datetimes, the restrictions placed on indexers + # used with interp are stronger than those which are placed on + # isel, so we need an additional check after _validate_indexers. + if (_contains_datetime_like_objects(x) and + not _contains_datetime_like_objects(new_x)): + raise TypeError('When interpolating over a datetime-like ' + 'coordinate, the coordinates to ' + 'interpolate to must be either datetime ' + 'strings or datetimes. 
' + 'Instead got\n{}'.format(new_x)) + else: + return (x, new_x) + + variables = OrderedDict() + for name, var in iteritems(obj._variables): + if name not in indexers: + if var.dtype.kind in 'uifc': + var_indexers = {k: _validate_interp_indexer( + maybe_variable(obj, k), v) for k, v + in indexers.items() if k in var.dims} + variables[name] = missing.interp( + var, var_indexers, method, **kwargs) + elif all(d not in indexers for d in var.dims): + # keep unrelated object array + variables[name] = var + + coord_names = set(variables).intersection(obj._coord_names) + selected = obj._replace_vars_and_dims(variables, + coord_names=coord_names) + # attach indexer as coordinate + variables.update(indexers) + # Extract coordinates from indexers + coord_vars = selected._get_indexers_coordinates(coords) + variables.update(coord_vars) + coord_names = (set(variables) + .intersection(obj._coord_names) + .union(coord_vars)) + return obj._replace_vars_and_dims(variables, coord_names=coord_names) + + def interp_like(self, other, method='linear', assume_sorted=False, + kwargs={}): + """Interpolate this object onto the coordinates of another object, + filling the out of range values with NaN. + + Parameters + ---------- + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to an 1d array-like, which provides coordinates upon + which to index the variables in this dataset. + method: string, optional. + {'linear', 'nearest'} for multidimensional array, + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + for 1-dimensional array. 'linear' is used by default. + assume_sorted: boolean, optional + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs: dictionary, optional + Additional keyword passed to scipy's interpolator. + + Returns + ------- + interpolated: xr.Dataset + Another dataset by interpolating this dataset's data along the + coordinates of the other object. + + Note + ---- + scipy is required. + If the dataset has object-type coordinates, reindex is used for these + coordinates instead of the interpolation. + + See Also + -------- + Dataset.interp + Dataset.reindex_like + """ + coords = alignment.reindex_like_indexers(self, other) + + numeric_coords = OrderedDict() + object_coords = OrderedDict() + for k, v in coords.items(): + if v.dtype.kind in 'uifcMm': + numeric_coords[k] = v + else: + object_coords[k] = v + + ds = self + if object_coords: + # We do not support interpolation along object coordinate. + # reindex instead. + ds = self.reindex(object_coords) + return ds.interp(numeric_coords, method, assume_sorted, kwargs) + + def rename(self, name_dict=None, inplace=None, **names): """Returns a new object with renamed variables and dimensions. Parameters ---------- - name_dict : dict-like + name_dict : dict-like, optional Dictionary whose keys are current variable or dimension names and whose values are the desired names. inplace : bool, optional If True, rename variables and dimensions in-place. Otherwise, return a new dataset object. + **names, optional + Keyword form of ``name_dict``. + One of name_dict or names must be provided. 
Returns ------- @@ -1791,6 +2072,8 @@ def rename(self, name_dict, inplace=False): Dataset.swap_dims DataArray.rename """ + inplace = _check_inplace(inplace) + name_dict = either_dict_or_kwargs(name_dict, names, 'rename') for k, v in name_dict.items(): if k not in self and k not in self.dims: raise ValueError("cannot rename %r because it is not a " @@ -1815,7 +2098,7 @@ def rename(self, name_dict, inplace=False): return self._replace_vars_and_dims(variables, coord_names, dims=dims, inplace=inplace) - def swap_dims(self, dims_dict, inplace=False): + def swap_dims(self, dims_dict, inplace=None): """Returns a new object with swapped dimensions. Parameters @@ -1839,6 +2122,7 @@ def swap_dims(self, dims_dict, inplace=False): Dataset.rename DataArray.swap_dims """ + inplace = _check_inplace(inplace) for k, v in dims_dict.items(): if k not in self.dims: raise ValueError('cannot swap from dimension %r because it is ' @@ -1951,22 +2235,26 @@ def expand_dims(self, dim, axis=None): return self._replace_vars_and_dims(variables, self._coord_names) - def set_index(self, append=False, inplace=False, **indexes): + def set_index(self, indexes=None, append=False, inplace=None, + **indexes_kwargs): """Set Dataset (multi-)indexes using one or more existing coordinates or variables. Parameters ---------- + indexes : {dim: index, ...} + Mapping from names matching dimensions and values given + by (lists of) the names of existing coordinates or variables to set + as new (multi-)index. append : bool, optional If True, append the supplied index(es) to the existing index(es). Otherwise replace the existing index(es) (default). inplace : bool, optional If True, set new index(es) in-place. Otherwise, return a new Dataset object. - **indexes : {dim: index, ...} - Keyword arguments with names matching dimensions and values given - by (lists of) the names of existing coordinates or variables to set - as new (multi-)index. + **indexes_kwargs: optional + The keyword arguments form of ``indexes``. + One of indexes or indexes_kwargs must be provided. Returns ------- @@ -1976,14 +2264,17 @@ def set_index(self, append=False, inplace=False, **indexes): See Also -------- Dataset.reset_index + Dataset.swap_dims """ + inplace = _check_inplace(inplace) + indexes = either_dict_or_kwargs(indexes, indexes_kwargs, 'set_index') variables, coord_names = merge_indexes(indexes, self._variables, self._coord_names, append=append) return self._replace_vars_and_dims(variables, coord_names=coord_names, inplace=inplace) - def reset_index(self, dims_or_levels, drop=False, inplace=False): + def reset_index(self, dims_or_levels, drop=False, inplace=None): """Reset the specified index(es) or multi-index level(s). Parameters @@ -2007,24 +2298,29 @@ def reset_index(self, dims_or_levels, drop=False, inplace=False): -------- Dataset.set_index """ + inplace = _check_inplace(inplace) variables, coord_names = split_indexes(dims_or_levels, self._variables, self._coord_names, self._level_coords, drop=drop) return self._replace_vars_and_dims(variables, coord_names=coord_names, inplace=inplace) - def reorder_levels(self, inplace=False, **dim_order): + def reorder_levels(self, dim_order=None, inplace=None, + **dim_order_kwargs): """Rearrange index levels using input order. Parameters ---------- + dim_order : optional + Mapping from names matching dimensions and values given + by lists representing new level orders. Every given dimension + must have a multi-index. inplace : bool, optional If True, modify the dataset in-place. 
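# Illustrative sketch (not part of the patch): with either_dict_or_kwargs,
# rename/set_index and related methods accept either a mapping or keyword
# arguments. The names below are made up.
import xarray as xr

ds = xr.Dataset({'data': ('x', [1, 2, 3])},
                coords={'x': [10, 20, 30], 'station': ('x', ['a', 'b', 'c'])})

ds.rename({'x': 'site'})        # mapping form
ds.rename(x='site')             # equivalent keyword form

ds.set_index(x='station')       # keyword form
ds.set_index({'x': 'station'})  # equivalent mapping form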
Otherwise, return a new DataArray object. - **dim_order : optional - Keyword arguments with names matching dimensions and values given - by lists representing new level orders. Every given dimension - must have a multi-index. + **dim_order_kwargs: optional + The keyword arguments form of ``dim_order``. + One of dim_order or dim_order_kwargs must be provided. Returns ------- @@ -2032,6 +2328,9 @@ def reorder_levels(self, inplace=False, **dim_order): Another dataset, with this dataset's data but replaced coordinates. """ + inplace = _check_inplace(inplace) + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, + 'reorder_levels') replace_variables = {} for dim, order in dim_order.items(): coord = self._variables[dim] @@ -2060,7 +2359,7 @@ def _stack_once(self, dims, new_dim): # consider dropping levels that are unused? levels = [self.get_index(dim) for dim in dims] - if hasattr(pd, 'RangeIndex'): + if LooseVersion(pd.__version__) < LooseVersion('0.19.0'): # RangeIndex levels in a MultiIndex are broken for appending in # pandas before v0.19.0 levels = [pd.Int64Index(level) @@ -2074,7 +2373,7 @@ def _stack_once(self, dims, new_dim): return self._replace_vars_and_dims(variables, coord_names) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. @@ -2083,9 +2382,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of the form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -2096,42 +2398,22 @@ def stack(self, **dimensions): -------- Dataset.unstack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'stack') result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) return result - def unstack(self, dim): - """ - Unstack an existing dimension corresponding to a MultiIndex into - multiple new dimensions. - - New dimensions will be added at the end. - - Parameters - ---------- - dim : str - Name of the existing dimension to unstack. - - Returns - ------- - unstacked : Dataset - Dataset with unstacked data. - - See also - -------- - Dataset.stack - """ - if dim not in self.dims: - raise ValueError('invalid dimension: %s' % dim) - + def _unstack_once(self, dim): index = self.get_index(dim) - if not isinstance(index, pd.MultiIndex): - raise ValueError('cannot unstack a dimension that does not have ' - 'a MultiIndex') - full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) - obj = self.reindex(copy=False, **{dim: full_idx}) + + # take a shortcut in case the MultiIndex was not modified. 
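# Illustrative sketch of the stack() signature change above (not part of the
# patch): a mapping argument is now accepted alongside the keyword form.
import xarray as xr

ds = xr.Dataset({'foo': (('x', 'y'), [[0, 1], [2, 3]])},
                coords={'x': ['a', 'b'], 'y': [0, 1]})

stacked = ds.stack(z=('x', 'y'))        # keyword form
stacked = ds.stack({'z': ('x', 'y')})   # equivalent mapping form
stacked['foo']                          # now 1-d along the MultiIndexed 'z'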
+ if index.equals(full_idx): + obj = self + else: + obj = self.reindex({dim: full_idx}, copy=False) new_dim_names = index.names new_dim_sizes = [lev.size for lev in index.levels] @@ -2141,7 +2423,7 @@ def unstack(self, dim): if name != dim: if dim in var.dims: new_dims = OrderedDict(zip(new_dim_names, new_dim_sizes)) - variables[name] = var.unstack(**{dim: new_dims}) + variables[name] = var.unstack({dim: new_dims}) else: variables[name] = var @@ -2152,7 +2434,52 @@ def unstack(self, dim): return self._replace_vars_and_dims(variables, coord_names) - def update(self, other, inplace=True): + def unstack(self, dim=None): + """ + Unstack existing dimensions corresponding to MultiIndexes into + multiple new dimensions. + + New dimensions will be added at the end. + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to unstack. By default unstacks all + MultiIndexes. + + Returns + ------- + unstacked : Dataset + Dataset with unstacked data. + + See also + -------- + Dataset.stack + """ + + if dim is None: + dims = [d for d in self.dims if isinstance(self.get_index(d), + pd.MultiIndex)] + else: + dims = [dim] if isinstance(dim, basestring) else dim + + missing_dims = [d for d in dims if d not in self.dims] + if missing_dims: + raise ValueError('Dataset does not contain the dimensions: %s' + % missing_dims) + + non_multi_dims = [d for d in dims if not + isinstance(self.get_index(d), pd.MultiIndex)] + if non_multi_dims: + raise ValueError('cannot unstack dimensions that do not ' + 'have a MultiIndex: %s' % non_multi_dims) + + result = self.copy(deep=False) + for dim in dims: + result = result._unstack_once(dim) + return result + + def update(self, other, inplace=None): """Update this dataset's variables with those from another dataset. Parameters @@ -2174,12 +2501,13 @@ def update(self, other, inplace=True): If any dimensions would have inconsistent sizes in the updated dataset. """ + inplace = _check_inplace(inplace, default=True) variables, coord_names, dims = dataset_update_method(self, other) return self._replace_vars_and_dims(variables, coord_names, dims, inplace=inplace) - def merge(self, other, inplace=False, overwrite_vars=frozenset(), + def merge(self, other, inplace=None, overwrite_vars=frozenset(), compat='no_conflicts', join='outer'): """Merge the arrays of two datasets into a single dataset. @@ -2230,6 +2558,7 @@ def merge(self, other, inplace=False, overwrite_vars=frozenset(), MergeError If any variables conflict (see ``compat``). """ + inplace = _check_inplace(inplace) variables, coord_names, dims = dataset_merge_method( self, other, overwrite_vars=overwrite_vars, compat=compat, join=join) @@ -2320,13 +2649,6 @@ def transpose(self, *dims): ds._variables[name] = var.transpose(*var_dims) return ds - @property - def T(self): - warnings.warn('xarray.Dataset.T has been deprecated as an alias for ' - '`.transpose()`. It will be removed in xarray v0.11.', - FutureWarning, stacklevel=2) - return self.transpose() - def dropna(self, dim, how='any', thresh=None, subset=None): """Returns a new dataset with dropped labels for missing values along the provided dimension. 
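# Illustrative sketch of the reworked unstack() (not part of the patch):
# called with no argument it unstacks every dimension backed by a MultiIndex.
import xarray as xr

ds = xr.Dataset({'foo': (('x', 'y'), [[0, 1], [2, 3]])},
                coords={'x': ['a', 'b'], 'y': [0, 1]})
stacked = ds.stack(z=('x', 'y'))

stacked.unstack('z')    # unstack a named dimension, as before
stacked.unstack()       # new: unstacks all MultiIndex dimensions ('z' here)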
@@ -2366,7 +2688,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += array.count(dims) + count += np.asarray(array.count(dims)) size += np.prod([self.dims[d] for d in dims]) if thresh is not None: @@ -2380,7 +2702,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): else: raise TypeError('must specify how or thresh') - return self.isel(**{dim: mask}) + return self.isel({dim: mask}) def fillna(self, value): """Fill missing values in this object. @@ -2529,7 +2851,7 @@ def combine_first(self, other): out = ops.fillna(self, other, join="outer", dataset_join="outer") return out - def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False, + def reduce(self, func, dim=None, keep_attrs=None, numeric_only=False, allow_lazy=False, **kwargs): """Reduce this dataset by applying `func` along some dimension(s). @@ -2557,6 +2879,8 @@ def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False, Dataset with this object's DataArrays replaced with new DataArrays of summarized data and the indicated dimension(s) removed. """ + if dim is ALL_DIMS: + dim = None if isinstance(dim, basestring): dims = set([dim]) elif dim is None: @@ -2569,35 +2893,38 @@ def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False, raise ValueError('Dataset does not contain the dimensions: %s' % missing_dimensions) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + variables = OrderedDict() for name, var in iteritems(self._variables): reduce_dims = [dim for dim in var.dims if dim in dims] - if reduce_dims or not var.dims: - if name not in self.coords: - if (not numeric_only or - np.issubdtype(var.dtype, np.number) or - (var.dtype == np.bool_)): - if len(reduce_dims) == 1: - # unpack dimensions for the benefit of functions - # like np.argmin which can't handle tuple arguments - reduce_dims, = reduce_dims - elif len(reduce_dims) == var.ndim: - # prefer to aggregate over axis=None rather than - # axis=(0, 1) if they will be equivalent, because - # the former is often more efficient - reduce_dims = None - variables[name] = var.reduce(func, dim=reduce_dims, - keep_attrs=keep_attrs, - allow_lazy=allow_lazy, - **kwargs) + if name in self.coords: + if not reduce_dims: + variables[name] = var else: - variables[name] = var + if (not numeric_only or + np.issubdtype(var.dtype, np.number) or + (var.dtype == np.bool_)): + if len(reduce_dims) == 1: + # unpack dimensions for the benefit of functions + # like np.argmin which can't handle tuple arguments + reduce_dims, = reduce_dims + elif len(reduce_dims) == var.ndim: + # prefer to aggregate over axis=None rather than + # axis=(0, 1) if they will be equivalent, because + # the former is often more efficient + reduce_dims = None + variables[name] = var.reduce(func, dim=reduce_dims, + keep_attrs=keep_attrs, + allow_lazy=allow_lazy, + **kwargs) coord_names = set(k for k in self.coords if k in variables) attrs = self.attrs if keep_attrs else None return self._replace_vars_and_dims(variables, coord_names, attrs=attrs) - def apply(self, func, keep_attrs=False, args=(), **kwargs): + def apply(self, func, keep_attrs=None, args=(), **kwargs): """Apply a function over the data variables in this dataset. 
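# Illustrative sketch for the keep_attrs changes in reduce() above (not part
# of the patch). With this patch the default is read from xarray's global
# options; here keep_attrs is simply passed explicitly.
import numpy as np
import xarray as xr

ds = xr.Dataset({'t': ('time', [1.0, np.nan, 3.0])})
ds['t'].attrs['units'] = 'K'

ds.mean()                          # NaN skipped by default, attrs dropped
ds.mean(keep_attrs=True)           # attrs preserved on request
ds.reduce(np.nanmax, dim='time')   # generic reduction with a NumPy function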
Parameters @@ -2642,20 +2969,25 @@ def apply(self, func, keep_attrs=False, args=(), **kwargs): variables = OrderedDict( (k, maybe_wrap_array(v, func(v, *args, **kwargs))) for k, v in iteritems(self.data_vars)) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None return type(self)(variables, attrs=attrs) - def assign(self, **kwargs): + def assign(self, variables=None, **variables_kwargs): """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. Parameters ---------- - kwargs : keyword, value pairs - keywords are the variables names. If the values are callable, they - are computed on the Dataset and assigned to new data variables. If - the values are not callable, (e.g. a DataArray, scalar, or array), - they are simply assigned. + variables : mapping, value pairs + Mapping from variables names to the new values. If the new values + are callable, they are computed on the Dataset and assigned to new + data variables. If the values are not callable, (e.g. a DataArray, + scalar, or array), they are simply assigned. + **variables_kwargs: + The keyword arguments form of ``variables``. + One of variables or variables_kwarg must be provided. Returns ------- @@ -2675,9 +3007,10 @@ def assign(self, **kwargs): -------- pandas.DataFrame.assign """ + variables = either_dict_or_kwargs(variables, variables_kwargs, 'assign') data = self.copy() # do all calculations first... - results = data._calc_assign_results(kwargs) + results = data._calc_assign_results(variables) # ... and then assign data.update(results) return data @@ -2752,7 +3085,7 @@ def from_dataframe(cls, dataframe): idx = dataframe.index obj = cls() - if hasattr(idx, 'levels'): + if isinstance(idx, pd.MultiIndex): # it's a multi-index # expand the DataFrame to include the product of all levels full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names) @@ -2779,8 +3112,8 @@ def to_dask_dataframe(self, dim_order=None, set_index=False): The dimensions, coordinates and data variables in this dataset form the columns of the DataFrame. - Arguments - --------- + Parameters + ---------- dim_order : list, optional Hierarchical dimension order for the resulting dataframe. 
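# Illustrative sketch of the new assign() signature (not part of the patch):
# a mapping is now accepted in addition to keyword arguments. Names made up.
import xarray as xr

ds = xr.Dataset({'temp_c': ('x', [0.0, 10.0, 20.0])})

ds.assign(temp_f=lambda d: d.temp_c * 9 / 5 + 32)       # keyword form
ds.assign({'temp_f': lambda d: d.temp_c * 9 / 5 + 32})  # equivalent mapping form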
All arrays are transposed to this order and then written out as flat @@ -2957,10 +3290,12 @@ def func(self, *args, **kwargs): def _binary_op(f, reflexive=False, join=None): @functools.wraps(f) def func(self, other): + from .dataarray import DataArray + if isinstance(other, groupby.GroupBy): return NotImplemented align_type = OPTIONS['arithmetic_join'] if join is None else join - if hasattr(other, 'indexes'): + if isinstance(other, (DataArray, Dataset)): self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) ds = self._calculate_binary_op(g, other, join=align_type) @@ -2972,12 +3307,14 @@ def func(self, other): def _inplace_binary_op(f): @functools.wraps(f) def func(self, other): + from .dataarray import DataArray + if isinstance(other, groupby.GroupBy): raise TypeError('in-place operations between a Dataset and ' 'a grouped object are not permitted') # we don't actually modify arrays in-place with in-place Dataset # arithmetic -- this lets us automatically align things - if hasattr(other, 'indexes'): + if isinstance(other, (DataArray, Dataset)): other = other.reindex_like(self, copy=False) g = ops.inplace_to_noninplace_op(f) ds = self._calculate_binary_op(g, other, inplace=True) @@ -2989,7 +3326,6 @@ def func(self, other): def _calculate_binary_op(self, f, other, join='inner', inplace=False): - def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): if inplace and set(lhs_data_vars) != set(rhs_data_vars): raise ValueError('datasets must have the same data variables ' @@ -3073,6 +3409,9 @@ def diff(self, dim, n=1, label='upper'): Data variables: foo (x) int64 1 -1 + See Also + -------- + Dataset.differentiate """ if n == 0: return self @@ -3112,7 +3451,7 @@ def diff(self, dim, n=1, label='upper'): else: return difference - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is @@ -3120,10 +3459,13 @@ def shift(self, **shifts): Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : Mapping with the form of {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- @@ -3147,6 +3489,7 @@ def shift(self, **shifts): Data variables: foo (x) object nan nan 'a' 'b' 'c' """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') invalid = [k for k in shifts if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) @@ -3162,18 +3505,28 @@ def shift(self, **shifts): return self._replace_vars_and_dims(variables) - def roll(self, **shifts): + def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this dataset by an offset along one or more dimensions. - Unlike shift, roll rotates all variables, including coordinates. The - direction of rotation is consistent with :py:func:`numpy.roll`. + Unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} - Integer offset to rotate each of the given dimensions. Positive - offsets roll to the right; negative offsets roll to the left. 
+ shifts : dict, optional + A dict with keys matching dimensions and values given + by integers to rotate each of the given dimensions. Positive + offsets roll to the right; negative offsets roll to the left. + roll_coords : bool + Indicates whether to roll the coordinates by the offset + The current default of roll_coords (None, equivalent to True) is + deprecated and will change to False in a future version. + Explicitly pass roll_coords to silence the warning. + **shifts_kwargs : {dim: offset, ...}, optional + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. Returns ------- rolled : Dataset @@ -3196,15 +3549,26 @@ def roll(self, **shifts): Data variables: foo (x) object 'd' 'e' 'a' 'b' 'c' """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') invalid = [k for k in shifts if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) + if roll_coords is None: + warnings.warn("roll_coords will be set to False in the future." + " Explicitly set roll_coords to silence warning.", + FutureWarning, stacklevel=2) + roll_coords = True + + unrolled_vars = () if roll_coords else self.coords + variables = OrderedDict() - for name, var in iteritems(self.variables): - var_shifts = dict((k, v) for k, v in shifts.items() - if k in var.dims) - variables[name] = var.roll(**var_shifts) + for k, v in iteritems(self.variables): + if k not in unrolled_vars: + variables[k] = v.roll(**{k: s for k, s in shifts.items() + if k in v.dims}) + else: + variables[k] = v return self._replace_vars_and_dims(variables) @@ -3270,7 +3634,7 @@ def sortby(self, variables, ascending=True): return aligned_self.isel(**indices) def quantile(self, q, dim=None, interpolation='linear', - numeric_only=False, keep_attrs=False): + numeric_only=False, keep_attrs=None): """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements for each variable @@ -3348,6 +3712,8 @@ def quantile(self, q, dim=None, interpolation='linear', # construct the new dataset coord_names = set(k for k in self.coords if k in variables) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None new = self._replace_vars_and_dims(variables, coord_names, attrs=attrs) if 'quantile' in new.dims: @@ -3356,7 +3722,7 @@ def quantile(self, q, dim=None, interpolation='linear', new.coords['quantile'] = q return new - def rank(self, dim, pct=False, keep_attrs=False): + def rank(self, dim, pct=False, keep_attrs=None): """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -3396,24 +3762,89 @@ def rank(self, dim, pct=False, keep_attrs=False): variables[name] = var coord_names = set(self.coords) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None return self._replace_vars_and_dims(variables, coord_names, attrs=attrs) + def differentiate(self, coord, edge_order=1, datetime_unit=None): + """ Differentiate with the second order accurate central + differences. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord: str + The coordinate to be used to compute the gradient. + edge_order: 1 or 2. Default 1 + N-th order accurate differences at the boundaries. + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + Unit to compute gradient. 
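# Illustrative sketch contrasting shift() and the new roll_coords argument of
# roll() (not part of the patch); names are made up.
import xarray as xr

ds = xr.Dataset({'foo': ('x', list('abcde'))}, coords={'x': list(range(5))})

ds.shift(x=2)                    # data shifted, coordinates stay in place
ds.roll(x=2, roll_coords=True)   # rotate data and coordinates
ds.roll(x=2, roll_coords=False)  # rotate data only (the future default)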
Only valid for datetime coordinate. + + Returns + ------- + differentiated: Dataset + + See also + -------- + numpy.gradient: corresponding numpy function + """ + from .variable import Variable + + if coord not in self.variables and coord not in self.dims: + raise ValueError('Coordinate {} does not exist.'.format(coord)) + + coord_var = self[coord].variable + if coord_var.ndim != 1: + raise ValueError('Coordinate {} must be 1 dimensional but is {}' + ' dimensional'.format(coord, coord_var.ndim)) + + dim = coord_var.dims[0] + if _contains_datetime_like_objects(coord_var): + if coord_var.dtype.kind in 'mM' and datetime_unit is None: + datetime_unit, _ = np.datetime_data(coord_var.dtype) + elif datetime_unit is None: + datetime_unit = 's' # Default to seconds for cftime objects + coord_var = datetime_to_numeric(coord_var, datetime_unit=datetime_unit) + + variables = OrderedDict() + for k, v in self.variables.items(): + if (k in self.data_vars and dim in v.dims and + k not in self.coords): + if _contains_datetime_like_objects(v): + v = datetime_to_numeric(v, datetime_unit=datetime_unit) + grad = duck_array_ops.gradient( + v.data, coord_var, edge_order=edge_order, + axis=v.get_axis_num(dim)) + variables[k] = Variable(v.dims, grad) + else: + variables[k] = v + return self._replace_vars_and_dims(variables) + @property def real(self): - return self._unary_op(lambda x: x.real, keep_attrs=True)(self) + return self._unary_op(lambda x: x.real, + keep_attrs=True)(self) @property def imag(self): - return self._unary_op(lambda x: x.imag, keep_attrs=True)(self) + return self._unary_op(lambda x: x.imag, + keep_attrs=True)(self) def filter_by_attrs(self, **kwargs): """Returns a ``Dataset`` with variables that match specific conditions. - Can pass in ``key=value`` or ``key=callable``. Variables are returned - that contain all of the matches or callable returns True. If using a - callable note that it should accept a single parameter only, - the attribute value. + Can pass in ``key=value`` or ``key=callable``. A Dataset is returned + containing only the variables for which all the filter tests pass. + These tests are either ``key=value`` for which the attribute ``key`` + has the exact value ``value`` or the callable passed into + ``key=callable`` returns True. The callable will be passed a single + value, either the value of the attribute ``key`` or ``None`` if the + DataArray does not have an attribute with the name ``key``. Parameters ---------- @@ -3484,11 +3915,17 @@ def filter_by_attrs(self, **kwargs): """ selection = [] for var_name, variable in self.data_vars.items(): + has_value_flag = False for attr_name, pattern in kwargs.items(): attr_value = variable.attrs.get(attr_name) if ((callable(pattern) and pattern(attr_value)) or attr_value == pattern): - selection.append(var_name) + has_value_flag = True + else: + has_value_flag = False + break + if has_value_flag is True: + selection.append(var_name) return self[selection] diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index ccbe48edc32..a2f11728b4d 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -1,12 +1,70 @@ +import functools + import numpy as np from . import utils - # Use as a sentinel value to indicate a dtype appropriate NA value. 
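# Illustrative sketch for the new Dataset.differentiate() (not part of the
# patch); the coordinate must be one dimensional.
import numpy as np
import xarray as xr

x = np.array([0.0, 0.1, 1.1, 1.2])
ds = xr.Dataset({'y': ('x', x ** 2)}, coords={'x': x})

ds.differentiate('x')                 # central differences along 'x'
ds.differentiate('x', edge_order=2)   # second order accurate boundaries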
NA = utils.ReprObject('') +@functools.total_ordering +class AlwaysGreaterThan(object): + def __gt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan(object): + def __lt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html +PROMOTE_TO_OBJECT = [ + {np.number, np.character}, # numpy promotes to character + {np.bool_, np.character}, # numpy promotes to character + {np.bytes_, np.unicode_}, # numpy promotes to unicode +] + + +@functools.total_ordering +class AlwaysGreaterThan(object): + def __gt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan(object): + def __lt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + def maybe_promote(dtype): """Simpler equivalent of pandas.core.common._maybe_promote @@ -22,6 +80,11 @@ def maybe_promote(dtype): # N.B. these casting rules should match pandas if np.issubdtype(dtype, np.floating): fill_value = np.nan + elif np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64('NaT') elif np.issubdtype(dtype, np.integer): if dtype.itemsize <= 2: dtype = np.float32 @@ -32,14 +95,15 @@ def maybe_promote(dtype): fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): fill_value = np.datetime64('NaT') - elif np.issubdtype(dtype, np.timedelta64): - fill_value = np.timedelta64('NaT') else: dtype = object fill_value = np.nan return np.dtype(dtype), fill_value +NAT_TYPES = (np.datetime64('NaT'), np.timedelta64('NaT')) + + def get_fill_value(dtype): """Return an appropriate fill value for this dtype. @@ -55,8 +119,74 @@ def get_fill_value(dtype): return fill_value +def get_pos_infinity(dtype): + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, (np.floating, np.integer)): + return np.inf + + if issubclass(dtype.type, np.complexfloating): + return np.inf + 1j * np.inf + + return INF + + +def get_neg_infinity(dtype): + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. 
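# Illustrative sketch (not part of the patch) of the pandas-style promotion
# that PROMOTE_TO_OBJECT encodes: mixing numbers and strings should yield
# object dtype rather than NumPy's default promotion to a string dtype.
import xarray as xr

numbers = xr.DataArray([1.0, 2.0], dims='x')
strings = xr.DataArray(['a', 'b'], dims='x')

combined = xr.concat([numbers, strings], dim='x')
combined.dtype   # expected: dtype('O') under these promotion rules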
+ """ + if issubclass(dtype.type, (np.floating, np.integer)): + return -np.inf + + if issubclass(dtype.type, np.complexfloating): + return -np.inf - 1j * np.inf + + return NINF + + def is_datetime_like(dtype): """Check if a dtype is a subclass of the numpy datetime types """ return (np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)) + + +def result_type(*arrays_and_dtypes): + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. + """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if (any(issubclass(t, left) for t in types) and + any(issubclass(t, right) for t in types)): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 2058ce86a99..ef89dba2ab8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -3,52 +3,47 @@ Currently, this means Dask or NumPy arrays. None of these functions should accept or return xarray objects. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -from functools import partial import contextlib import inspect import warnings +from functools import partial import numpy as np import pandas as pd -from . import npcompat -from . import dtypes -from .pycompat import dask_array_type +from . import dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast +from .pycompat import dask_array_type try: - import bottleneck as bn - has_bottleneck = True -except ImportError: - # use numpy methods instead - bn = np - has_bottleneck = False - -try: - import dask.array as da - has_dask = True + import dask.array as dask_array + from . 
import dask_array_compat except ImportError: - has_dask = False + dask_array = None + dask_array_compat = None -def _dask_or_eager_func(name, eager_module=np, list_of_args=False, - n_array_args=1): +def _dask_or_eager_func(name, eager_module=np, dask_module=dask_array, + list_of_args=False, array_args=slice(1), + requires_dask=None): """Create a function that dispatches to dask for dask array inputs.""" - if has_dask: + if dask_module is not None: def f(*args, **kwargs): if list_of_args: dispatch_args = args[0] else: - dispatch_args = args[:n_array_args] - if any(isinstance(a, da.Array) for a in dispatch_args): - module = da + dispatch_args = args[array_args] + if any(isinstance(a, dask_array.Array) for a in dispatch_args): + try: + wrapped = getattr(dask_module, name) + except AttributeError as e: + raise AttributeError("%s: requires dask >=%s" % + (e, requires_dask)) else: - module = eager_module - return getattr(module, name)(*args, **kwargs) + wrapped = getattr(eager_module, name) + return wrapped(*args, ** kwargs) else: def f(data, *args, **kwargs): return getattr(eager_module, name)(data, *args, **kwargs) @@ -66,8 +61,8 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): around = _dask_or_eager_func('around') isclose = _dask_or_eager_func('isclose') -notnull = _dask_or_eager_func('notnull', pd) -_isnull = _dask_or_eager_func('isnull', pd) +notnull = _dask_or_eager_func('notnull', eager_module=pd) +_isnull = _dask_or_eager_func('isnull', eager_module=pd) def isnull(data): @@ -82,24 +77,50 @@ def isnull(data): transpose = _dask_or_eager_func('transpose') -where = _dask_or_eager_func('where', n_array_args=3) -insert = _dask_or_eager_func('insert') +_where = _dask_or_eager_func('where', array_args=slice(3)) +isin = _dask_or_eager_func('isin', eager_module=npcompat, + dask_module=dask_array_compat, array_args=slice(2)) take = _dask_or_eager_func('take') broadcast_to = _dask_or_eager_func('broadcast_to') -concatenate = _dask_or_eager_func('concatenate', list_of_args=True) -stack = _dask_or_eager_func('stack', list_of_args=True) +_concatenate = _dask_or_eager_func('concatenate', list_of_args=True) +_stack = _dask_or_eager_func('stack', list_of_args=True) array_all = _dask_or_eager_func('all') array_any = _dask_or_eager_func('any') -tensordot = _dask_or_eager_func('tensordot', n_array_args=2) +tensordot = _dask_or_eager_func('tensordot', array_args=slice(2)) +einsum = _dask_or_eager_func('einsum', array_args=slice(1, None), + requires_dask='0.17.3') + + +def gradient(x, coord, axis, edge_order): + if isinstance(x, dask_array_type): + return dask_array_compat.gradient( + x, coord, axis=axis, edge_order=edge_order) + return npcompat.gradient(x, coord, axis=axis, edge_order=edge_order) + + +masked_invalid = _dask_or_eager_func( + 'masked_invalid', eager_module=np.ma, + dask_module=getattr(dask_array, 'ma', None)) def asarray(data): return data if isinstance(data, dask_array_type) else np.asarray(data) +def as_shared_dtype(scalars_or_arrays): + """Cast a arrays to a shared dtype using xarray's type promotion rules.""" + arrays = [asarray(x) for x in scalars_or_arrays] + # Pass arrays directly instead of dtypes to result_type so scalars + # get handled properly. + # Note that result_type() safely gets the dtype from dask arrays without + # evaluating them. 
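# Illustrative, stand-alone sketch of the dispatch idea behind
# _dask_or_eager_func (not part of the patch): call dask.array for dask
# inputs and NumPy otherwise. The function name is made up.
import numpy as np

try:
    import dask.array as dask_array
except ImportError:
    dask_array = None


def dispatching_mean(values, axis=None):
    if dask_array is not None and isinstance(values, dask_array.Array):
        return dask_array.mean(values, axis=axis)   # lazy, chunk-aware
    return np.mean(np.asarray(values), axis=axis)   # eager NumPy fallback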
+ out_type = dtypes.result_type(*arrays) + return [x.astype(out_type, copy=False) for x in arrays] + + def as_like_arrays(*data): if all(isinstance(d, dask_array_type) for d in data): return data @@ -124,10 +145,13 @@ def array_equiv(arr1, arr2): if arr1.shape != arr2.shape: return False - flag_array = (arr1 == arr2) - flag_array |= (isnull(arr1) & isnull(arr2)) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', "In the future, 'NAT == x'") + + flag_array = (arr1 == arr2) + flag_array |= (isnull(arr1) & isnull(arr2)) - return bool(flag_array.all()) + return bool(flag_array.all()) def array_notnull_equiv(arr1, arr2): @@ -138,17 +162,25 @@ def array_notnull_equiv(arr1, arr2): if arr1.shape != arr2.shape: return False - flag_array = (arr1 == arr2) - flag_array |= isnull(arr1) - flag_array |= isnull(arr2) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', "In the future, 'NAT == x'") - return bool(flag_array.all()) + flag_array = (arr1 == arr2) + flag_array |= isnull(arr1) + flag_array |= isnull(arr2) + + return bool(flag_array.all()) def count(data, axis=None): """Count the number of non-NA in this array along the given axis or axes """ - return sum(~isnull(data), axis=axis) + return np.sum(~isnull(data), axis=axis) + + +def where(condition, x, y): + """Three argument where() with better dtype promotion rules.""" + return _where(condition, *as_shared_dtype([x, y])) def where_method(data, cond, other=dtypes.NA): @@ -161,6 +193,16 @@ def fillna(data, other): return where(isnull(data), other, data) +def concatenate(arrays, axis=0): + """concatenate() with better dtype promotion rules.""" + return _concatenate(as_shared_dtype(arrays), axis=axis) + + +def stack(arrays, axis=0): + """stack() with better dtype promotion rules.""" + return _stack(as_shared_dtype(arrays), axis=axis) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -171,76 +213,93 @@ def _ignore_warnings_if(condition): yield -def _create_nan_agg_method(name, numeric_only=False, np_compat=False, - no_bottleneck=False, coerce_strings=False, - keep_dims=False): +def _create_nan_agg_method(name, coerce_strings=False): + from . import nanops + def f(values, axis=None, skipna=None, **kwargs): if kwargs.pop('out', None) is not None: raise TypeError('`out` is not valid for {}'.format(name)) - # If dtype is supplied, we use numpy's method. 
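# Illustrative sketch (not part of the patch) of the dtype-promoting where()
# and the skipna behaviour handled by the nan-aggregation wrappers above.
import numpy as np
import xarray as xr

da = xr.DataArray([1, 2, 3], dims='x')   # integer data

da.where(da > 1).dtype                   # float64: promoted so NaN can be used

xr.DataArray([1.0, np.nan, 3.0]).sum()              # 4.0 (skipna by default)
xr.DataArray([1.0, np.nan, 3.0]).sum(skipna=False)  # nan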
- dtype = kwargs.get('dtype', None) values = asarray(values) if coerce_strings and values.dtype.kind in 'SU': values = values.astype(object) - if skipna or (skipna is None and values.dtype.kind in 'cf'): - if values.dtype.kind not in ['u', 'i', 'f', 'c']: - raise NotImplementedError( - 'skipna=True not yet implemented for %s with dtype %s' - % (name, values.dtype)) + func = None + if skipna or (skipna is None and values.dtype.kind in 'cfO'): nanname = 'nan' + name - if (isinstance(axis, tuple) or not values.dtype.isnative or - no_bottleneck or - (dtype is not None and np.dtype(dtype) != values.dtype)): - # bottleneck can't handle multiple axis arguments or non-native - # endianness - if np_compat: - eager_module = npcompat - else: - eager_module = np - else: - kwargs.pop('dtype', None) - eager_module = bn - func = _dask_or_eager_func(nanname, eager_module) - using_numpy_nan_func = (eager_module is np or - eager_module is npcompat) + func = getattr(nanops, nanname) else: func = _dask_or_eager_func(name) - using_numpy_nan_func = False - with _ignore_warnings_if(using_numpy_nan_func): - try: - return func(values, axis=axis, **kwargs) - except AttributeError: - if isinstance(values, dask_array_type): + + try: + return func(values, axis=axis, **kwargs) + except AttributeError: + if isinstance(values, dask_array_type): + try: # dask/dask#3133 dask sometimes needs dtype argument + # if func does not accept dtype, then raises TypeError + return func(values, axis=axis, dtype=values.dtype, + **kwargs) + except (AttributeError, TypeError): msg = '%s is not yet implemented on dask arrays' % name - else: - assert using_numpy_nan_func - msg = ('%s is not available with skipna=False with the ' - 'installed version of numpy; upgrade to numpy 1.12 ' - 'or newer to use skipna=True or skipna=None' % name) - raise NotImplementedError(msg) - f.numeric_only = numeric_only - f.keep_dims = keep_dims + else: + msg = ('%s is not available with skipna=False with the ' + 'installed version of numpy; upgrade to numpy 1.12 ' + 'or newer to use skipna=True or skipna=None' % name) + raise NotImplementedError(msg) + f.__name__ = name return f +# Attributes `numeric_only`, `available_min_count` is used for docs. 
+# See ops.inject_reduce_methods argmax = _create_nan_agg_method('argmax', coerce_strings=True) argmin = _create_nan_agg_method('argmin', coerce_strings=True) max = _create_nan_agg_method('max', coerce_strings=True) min = _create_nan_agg_method('min', coerce_strings=True) -sum = _create_nan_agg_method('sum', numeric_only=True) -mean = _create_nan_agg_method('mean', numeric_only=True) -std = _create_nan_agg_method('std', numeric_only=True) -var = _create_nan_agg_method('var', numeric_only=True) -median = _create_nan_agg_method('median', numeric_only=True) -prod = _create_nan_agg_method('prod', numeric_only=True, no_bottleneck=True) -cumprod = _create_nan_agg_method('cumprod', numeric_only=True, np_compat=True, - no_bottleneck=True, keep_dims=True) -cumsum = _create_nan_agg_method('cumsum', numeric_only=True, np_compat=True, - no_bottleneck=True, keep_dims=True) +sum = _create_nan_agg_method('sum') +sum.numeric_only = True +sum.available_min_count = True +mean = _create_nan_agg_method('mean') +mean.numeric_only = True +std = _create_nan_agg_method('std') +std.numeric_only = True +var = _create_nan_agg_method('var') +var.numeric_only = True +median = _create_nan_agg_method('median') +median.numeric_only = True +prod = _create_nan_agg_method('prod') +prod.numeric_only = True +sum.available_min_count = True +cumprod_1d = _create_nan_agg_method('cumprod') +cumprod_1d.numeric_only = True +cumsum_1d = _create_nan_agg_method('cumsum') +cumsum_1d.numeric_only = True + + +def _nd_cum_func(cum_func, array, axis, **kwargs): + array = asarray(array) + if axis is None: + axis = tuple(range(array.ndim)) + if isinstance(axis, int): + axis = (axis,) + + out = array + for ax in axis: + out = cum_func(out, axis=ax, **kwargs) + return out + + +def cumprod(array, axis=None, **kwargs): + """N-dimensional version of cumprod.""" + return _nd_cum_func(cumprod_1d, array, axis, **kwargs) + + +def cumsum(array, axis=None, **kwargs): + """N-dimensional version of cumsum.""" + return _nd_cum_func(cumsum_1d, array, axis, **kwargs) + _fail_on_dask_array_input_skipna = partial( fail_on_dask_array_input, @@ -265,3 +324,16 @@ def last(values, axis, skipna=None): _fail_on_dask_array_input_skipna(values) return nanlast(values, axis) return take(values, -1, axis=axis) + + +def rolling_window(array, axis, window, center, fill_value): + """ + Make an ndarray with a rolling window of axis-th dimension. + The rolling dimension will be placed at the last dimension. + """ + if isinstance(array, dask_array_type): + return dask_array_ops.rolling_window( + array, axis, window, center, fill_value) + else: # np.ndarray + return nputils.rolling_window( + array, axis, window, center, fill_value) diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py index 90639e47f43..8070e07a5ef 100644 --- a/xarray/core/extensions.py +++ b/xarray/core/extensions.py @@ -1,6 +1,5 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import traceback import warnings diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 83f8e2719d6..5dd3cf06025 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -4,24 +4,25 @@ be returned by the __unicode__ special method. We use ReprMixin to provide the __repr__ method so that things can work on Python 2. 
""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import contextlib -from datetime import datetime, timedelta import functools +from datetime import datetime, timedelta import numpy as np import pandas as pd + +from .options import OPTIONS +from .pycompat import ( + PY2, bytes_type, dask_array_type, unicode_type, zip_longest) + try: from pandas.errors import OutOfBoundsDatetime except ImportError: # pandas < 0.20 from pandas.tslib import OutOfBoundsDatetime -from .options import OPTIONS -from .pycompat import PY2, unicode_type, bytes_type, dask_array_type - def pretty_print(x, numchars): """Given an object `x`, call `str(x)` and format the returned string so @@ -64,13 +65,13 @@ def __repr__(self): return ensure_valid_repr(self.__unicode__()) -def _get_indexer_at_least_n_items(shape, n_desired): +def _get_indexer_at_least_n_items(shape, n_desired, from_end): assert 0 < n_desired <= np.prod(shape) cum_items = np.cumprod(shape[::-1]) n_steps = np.argmax(cum_items >= n_desired) stop = int(np.ceil(float(n_desired) / np.r_[1, cum_items][n_steps])) - indexer = ((0,) * (len(shape) - 1 - n_steps) + - (slice(stop),) + + indexer = (((-1 if from_end else 0),) * (len(shape) - 1 - n_steps) + + ((slice(-stop, None) if from_end else slice(stop)),) + (slice(None),) * n_steps) return indexer @@ -89,11 +90,28 @@ def first_n_items(array, n_desired): return [] if n_desired < array.size: - indexer = _get_indexer_at_least_n_items(array.shape, n_desired) + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, + from_end=False) array = array[indexer] return np.asarray(array).flat[:n_desired] +def last_n_items(array, n_desired): + """Returns the last n_desired items of an array""" + # Unfortunately, we can't just do array.flat[-n_desired:] here because it + # might not be a numpy.ndarray. Moreover, access to elements of the array + # could be very expensive (e.g. if it's only available over DAP), so go out + # of our way to get them in a single call to __getitem__ using only slices. + if (n_desired == 0) or (array.size == 0): + return [] + + if n_desired < array.size: + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, + from_end=True) + array = array[indexer] + return np.asarray(array).flat[-n_desired:] + + def last_item(array): """Returns the last item of an array in a list or an empty list.""" if array.size == 0: @@ -164,7 +182,7 @@ def format_items(x): day_part = (x[~pd.isnull(x)] .astype('timedelta64[D]') .astype('timedelta64[ns]')) - time_needed = x != day_part + time_needed = x[~pd.isnull(x)] != day_part day_needed = day_part != np.timedelta64(0, 'ns') if np.logical_not(day_needed).all(): timedelta_format = 'time' @@ -180,20 +198,36 @@ def format_array_flat(array, max_width): array that will fit within max_width characters. """ # every item will take up at least two characters, but we always want to - # print at least one item - max_possibly_relevant = max(int(np.ceil(max_width / 2.0)), 1) - relevant_items = first_n_items(array, max_possibly_relevant) - pprint_items = format_items(relevant_items) - - cum_len = np.cumsum([len(s) + 1 for s in pprint_items]) - 1 - if (max_possibly_relevant < array.size or (cum_len > max_width).any()): - end_padding = u' ...' 
- count = max(np.argmax((cum_len + len(end_padding)) > max_width), 1) - pprint_items = pprint_items[:count] + # print at least first and last items + max_possibly_relevant = min(max(array.size, 1), + max(int(np.ceil(max_width / 2.)), 2)) + relevant_front_items = format_items( + first_n_items(array, (max_possibly_relevant + 1) // 2)) + relevant_back_items = format_items( + last_n_items(array, max_possibly_relevant // 2)) + # interleave relevant front and back items: + # [a, b, c] and [y, z] -> [a, z, b, y, c] + relevant_items = sum(zip_longest(relevant_front_items, + reversed(relevant_back_items)), + ())[:max_possibly_relevant] + + cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1 + if (array.size > 2) and ((max_possibly_relevant < array.size) or + (cum_len > max_width).any()): + padding = u' ... ' + count = min(array.size, + max(np.argmax(cum_len + len(padding) - 1 > max_width), 2)) else: - end_padding = u'' - - pprint_str = u' '.join(pprint_items) + end_padding + count = array.size + padding = u'' if (count <= 1) else u' ' + + num_front = (count + 1) // 2 + num_back = count - num_front + # note that num_back is 0 <--> array.size is 0 or 1 + # <--> relevant_back_items is [] + pprint_str = (u' '.join(relevant_front_items[:num_front]) + + padding + + u' '.join(relevant_back_items[-num_back:])) return pprint_str diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index c4b25741d5b..defe72ab3ee 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,21 +1,19 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import functools +import warnings + import numpy as np import pandas as pd -from . import dtypes -from . import duck_array_ops -from . import nputils -from . import ops +from . import dtypes, duck_array_ops, nputils, ops, utils +from .arithmetic import SupportsArithmetic from .combine import concat -from .common import ( - ImplementsArrayReduce, ImplementsDatasetReduce, -) -from .pycompat import range, zip, integer_types -from .utils import hashable, peek_at, maybe_wrap_array, safe_cast_to_index -from .variable import as_variable, Variable, IndexVariable +from .common import ALL_DIMS, ImplementsArrayReduce, ImplementsDatasetReduce +from .pycompat import integer_types, range, zip +from .utils import hashable, maybe_wrap_array, peek_at, safe_cast_to_index +from .variable import IndexVariable, Variable, as_variable +from .options import _get_keep_attrs def unique_value_groups(ar, sort=True): @@ -156,7 +154,7 @@ def _unique_and_monotonic(group): return index.is_unique and index.is_monotonic -class GroupBy(object): +class GroupBy(SupportsArithmetic): """A object that implements the split-apply-combine pattern. Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over @@ -407,15 +405,17 @@ def _first_or_last(self, op, skipna, keep_attrs): # NB. 
this is currently only used for reductions along an existing # dimension return self._obj + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) return self.reduce(op, self._group_dim, skipna=skipna, keep_attrs=keep_attrs, allow_lazy=True) - def first(self, skipna=None, keep_attrs=True): + def first(self, skipna=None, keep_attrs=None): """Return the first element of each group along the group dimension """ return self._first_or_last(duck_array_ops.first, skipna, keep_attrs) - def last(self, skipna=None, keep_attrs=True): + def last(self, skipna=None, keep_attrs=None): """Return the last element of each group along the group dimension """ return self._first_or_last(duck_array_ops.last, skipna, keep_attrs) @@ -426,6 +426,7 @@ def assign_coords(self, **kwargs): See also -------- Dataset.assign_coords + Dataset.swap_dims """ return self.apply(lambda ds: ds.assign_coords(**kwargs)) @@ -541,8 +542,8 @@ def _combine(self, applied, shortcut=False): combined = self._maybe_unstack(combined) return combined - def reduce(self, func, dim=None, axis=None, keep_attrs=False, - shortcut=True, **kwargs): + def reduce(self, func, dim=None, axis=None, + keep_attrs=None, shortcut=True, **kwargs): """Reduce the items in this group by applying `func` along some dimension(s). @@ -571,10 +572,42 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = ALL_DIMS + # TODO change this to dim = self._group_dim after + # the deprecation process + if self._obj.ndim > 1: + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + def reduce_array(ar): return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs) return self.apply(reduce_array, shortcut=shortcut) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, skipna=None, + keep_attrs=None, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + skipna=skipna, allow_lazy=True, **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, + keep_attrs=None, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + allow_lazy=True, **kwargs) + return wrapped_func + + +DEFAULT_DIMS = utils.ReprObject('') ops.inject_reduce_methods(DataArrayGroupBy) ops.inject_binary_ops(DataArrayGroupBy) @@ -624,7 +657,7 @@ def _combine(self, applied): combined = self._maybe_unstack(combined) return combined - def reduce(self, func, dim=None, keep_attrs=False, **kwargs): + def reduce(self, func, dim=None, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along some dimension(s). @@ -653,10 +686,43 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = ALL_DIMS + # TODO change this to dim = self._group_dim after + # the deprecation process. Do not forget to remove _reduce_method + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. 
To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + elif dim is None: + dim = self._group_dim + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + def reduce_dataset(ds): return ds.reduce(func, dim, keep_attrs, **kwargs) return self.apply(reduce_dataset) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, + skipna=None, **kwargs): + return self.reduce(func, dim, + skipna=skipna, numeric_only=numeric_only, + allow_lazy=True, **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, + **kwargs): + return self.reduce(func, dim, + numeric_only=numeric_only, allow_lazy=True, + **kwargs) + return wrapped_func + def assign(self, **kwargs): """Assign data variables by group. diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index e06b045ad88..d51da471c8d 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1,18 +1,16 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from datetime import timedelta -from collections import defaultdict, Hashable +from __future__ import absolute_import, division, print_function + import functools import operator +from collections import Hashable, defaultdict +from datetime import timedelta + import numpy as np import pandas as pd -from . import nputils -from . import utils -from . import duck_array_ops -from .pycompat import (iteritems, range, integer_types, dask_array_type, - suppress) +from . import duck_array_ops, nputils, utils +from .pycompat import ( + dask_array_type, integer_types, iteritems, range, suppress) from .utils import is_dict_like @@ -50,11 +48,25 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def _try_get_item(x): - try: - return x.item() - except AttributeError: - return x +def _sanitize_slice_element(x): + from .variable import Variable + from .dataarray import DataArray + + if isinstance(x, (Variable, DataArray)): + x = x.values + + if isinstance(x, np.ndarray): + if x.ndim != 0: + raise ValueError('cannot use non-scalar arrays in a slice for ' + 'xarray indexing: {}'.format(x)) + x = x[()] + + if isinstance(x, np.timedelta64): + # pandas does not support indexing with np.timedelta64 yet: + # https://github.com/pandas-dev/pandas/issues/20393 + x = pd.Timedelta(x) + + return x def _asarray_tuplesafe(values): @@ -121,9 +133,9 @@ def convert_label_indexer(index, label, index_name='', method=None, raise NotImplementedError( 'cannot use ``method`` argument if any indexers are ' 'slice objects') - indexer = index.slice_indexer(_try_get_item(label.start), - _try_get_item(label.stop), - _try_get_item(label.step)) + indexer = index.slice_indexer(_sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step)) if not isinstance(indexer, slice): # unlike pandas, in xarray we never want to silently convert a # slice indexer into an array indexer @@ -256,7 +268,7 @@ def slice_slice(old_slice, applied_slice, size): items = _expand_slice(old_slice, size)[applied_slice] if len(items) > 0: start = items[0] - stop = items[-1] + step + stop = items[-1] + int(np.sign(step)) if stop < 0: stop = None else: @@ -292,7 +304,7 @@ class ExplicitIndexer(object): """ def __init__(self, key): - if type(self) is ExplicitIndexer: + 
if type(self) is ExplicitIndexer: # noqa raise TypeError('cannot instantiate base ExplicitIndexer objects') self._key = tuple(key) @@ -427,23 +439,6 @@ def __array__(self, dtype=None): return np.asarray(self[key], dtype=dtype) -def unwrap_explicit_indexer(key, target, allow): - """Unwrap an explicit key into a tuple.""" - if not isinstance(key, ExplicitIndexer): - raise TypeError('unexpected key type: {}'.format(key)) - if not isinstance(key, allow): - key_type_name = { - BasicIndexer: 'Basic', - OuterIndexer: 'Outer', - VectorizedIndexer: 'Vectorized' - }[type(key)] - raise NotImplementedError( - '{} indexing for {} is not implemented. Load your data first with ' - '.load(), .compute() or .persist(), or disable caching by setting ' - 'cache=False in open_dataset.'.format(key_type_name, type(target))) - return key.tuple - - class ImplicitToExplicitIndexingAdapter(utils.NDArrayMixin): """Wrap an array, converting tuples into the indicated explicit indexer.""" @@ -459,8 +454,8 @@ def __getitem__(self, key): return self.array[self.indexer_cls(key)] -class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin): - """Wrap an array to make basic and orthogonal indexing lazy. +class LazilyOuterIndexedArray(ExplicitlyIndexedNDArrayMixin): + """Wrap an array to make basic and outer indexing lazy. """ def __init__(self, array, key=None): @@ -485,13 +480,6 @@ def __init__(self, array, key=None): self.key = key def _updated_key(self, new_key): - # TODO should suport VectorizedIndexer - if isinstance(new_key, VectorizedIndexer): - raise NotImplementedError( - 'Vectorized indexing for {} is not implemented. Load your ' - 'data first with .load() or .compute(), or disable caching by ' - 'setting cache=False in open_dataset.'.format(type(self))) - iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim)) full_key = [] for size, k in zip(self.array.shape, self.key.tuple): @@ -519,10 +507,21 @@ def __array__(self, dtype=None): array = as_indexable(self.array) return np.asarray(array[self.key], dtype=None) + def transpose(self, order): + return LazilyVectorizedIndexedArray( + self.array, self.key).transpose(order) + def __getitem__(self, indexer): + if isinstance(indexer, VectorizedIndexer): + array = LazilyVectorizedIndexedArray(self.array, self.key) + return array[indexer] return type(self)(self.array, self._updated_key(indexer)) def __setitem__(self, key, value): + if isinstance(key, VectorizedIndexer): + raise NotImplementedError( + 'Lazy item assignment with the vectorized indexer is not yet ' + 'implemented. Load your data first by .load() or compute().') full_key = self._updated_key(key) self.array[full_key] = value @@ -531,6 +530,56 @@ def __repr__(self): (type(self).__name__, self.array, self.key)) +class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin): + """Wrap an array to make vectorized indexing lazy. + """ + + def __init__(self, array, key): + """ + Parameters + ---------- + array : array_like + Array like object to index. 
+ key : VectorizedIndexer + """ + if isinstance(key, (BasicIndexer, OuterIndexer)): + self.key = _outer_to_vectorized_indexer(key, array.shape) + else: + self.key = _arrayize_vectorized_indexer(key, array.shape) + self.array = as_indexable(array) + + @property + def shape(self): + return np.broadcast(*self.key.tuple).shape + + def __array__(self, dtype=None): + return np.asarray(self.array[self.key], dtype=None) + + def _updated_key(self, new_key): + return _combine_indexers(self.key, self.shape, new_key) + + def __getitem__(self, indexer): + # If the indexed array becomes a scalar, return LazilyOuterIndexedArray + if all(isinstance(ind, integer_types) for ind in indexer.tuple): + key = BasicIndexer(tuple(k[indexer.tuple] for k in self.key.tuple)) + return LazilyOuterIndexedArray(self.array, key) + return type(self)(self.array, self._updated_key(indexer)) + + def transpose(self, order): + key = VectorizedIndexer(tuple( + k.transpose(order) for k in self.key.tuple)) + return type(self)(self.array, key) + + def __setitem__(self, key, value): + raise NotImplementedError( + 'Lazy item assignment with the vectorized indexer is not yet ' + 'implemented. Load your data first by .load() or compute().') + + def __repr__(self): + return ('%s(array=%r, key=%r)' % + (type(self).__name__, self.array, self.key)) + + def _wrap_numpy_scalars(array): """Wrap NumPy scalars in 0d arrays.""" if np.isscalar(array): @@ -555,6 +604,9 @@ def __array__(self, dtype=None): def __getitem__(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) + def transpose(self, order): + return self.array.transpose(order) + def __setitem__(self, key, value): self._ensure_copied() self.array[key] = value @@ -575,6 +627,9 @@ def __array__(self, dtype=None): def __getitem__(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) + def transpose(self, order): + return self.array.transpose(order) + def __setitem__(self, key, value): self.array[key] = value @@ -601,24 +656,26 @@ def _outer_to_vectorized_indexer(key, shape): Parameters ---------- - key : tuple - Tuple from an OuterIndexer to convert. + key : Outer/Basic Indexer + An indexer to convert. shape : tuple Shape of the array subject to the indexing. Returns ------- - tuple + VectorizedIndexer Tuple suitable for use to index a NumPy array with vectorized indexing. - Each element is an integer or array: broadcasting them together gives - the shape of the result. + Each element is an array: broadcasting them together gives the shape + of the result. """ + key = key.tuple + n_dim = len([k for k in key if not isinstance(k, integer_types)]) i_dim = 0 new_key = [] for k, size in zip(key, shape): if isinstance(k, integer_types): - new_key.append(k) + new_key.append(np.array(k).reshape((1,) * n_dim)) else: # np.ndarray or slice if isinstance(k, slice): k = np.arange(*k.indices(size)) @@ -627,7 +684,7 @@ def _outer_to_vectorized_indexer(key, shape): (1,) * (n_dim - i_dim - 1)] new_key.append(k.reshape(*shape)) i_dim += 1 - return tuple(new_key) + return VectorizedIndexer(tuple(new_key)) def _outer_to_numpy_indexer(key, shape): @@ -635,8 +692,8 @@ def _outer_to_numpy_indexer(key, shape): Parameters ---------- - key : tuple - Tuple from an OuterIndexer to convert. + key : Basic/OuterIndexer + An indexer to convert. shape : tuple Shape of the array subject to the indexing. @@ -645,13 +702,315 @@ def _outer_to_numpy_indexer(key, shape): tuple Tuple suitable for use to index a NumPy array. 
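Both `LazilyOuterIndexedArray` and the new `LazilyVectorizedIndexedArray` work by composing indexing keys rather than applying them, so the wrapped array is only touched once, when the result is materialized. A stripped-down sketch of that key-composition idea, handling only slices and 1-D integer arrays (names and structure are illustrative, not the real classes):

```python
import numpy as np

class LazySlicedArray(object):
    """Toy lazy outer indexing: defer all indexing until materialization."""

    def __init__(self, array, key=None):
        self.array = array
        self.key = key or tuple(slice(None) for _ in array.shape)

    def _compose(self, new_key):
        # compose old and new keys on index ranges, not on the data itself
        return tuple(np.arange(size)[old][new]
                     for size, old, new in zip(self.array.shape, self.key, new_key))

    def __getitem__(self, new_key):
        return LazySlicedArray(self.array, self._compose(new_key))

    def __array__(self, dtype=None):
        # expand remaining slices to index arrays and do a single outer index
        expanded = [np.arange(s)[k] if isinstance(k, slice) else np.asarray(k)
                    for s, k in zip(self.array.shape, self.key)]
        return np.asarray(self.array[np.ix_(*expanded)], dtype=dtype)

base = np.arange(100).reshape(10, 10)
lazy = LazySlicedArray(base)[::2, 1:5][1:3, [0, 2]]
assert (np.asarray(lazy) == base[::2, 1:5][1:3, [0, 2]]).all()
```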
""" - if len([k for k in key if not isinstance(k, slice)]) <= 1: + if len([k for k in key.tuple if not isinstance(k, slice)]) <= 1: # If there is only one vector and all others are slice, # it can be safely used in mixed basic/advanced indexing. # Boolean index should already be converted to integer array. - return tuple(key) + return key.tuple + else: + return _outer_to_vectorized_indexer(key, shape).tuple + + +def _combine_indexers(old_key, shape, new_key): + """ Combine two indexers. + + Parameters + ---------- + old_key: ExplicitIndexer + The first indexer for the original array + shape: tuple of ints + Shape of the original array to be indexed by old_key + new_key: + The second indexer for indexing original[old_key] + """ + if not isinstance(old_key, VectorizedIndexer): + old_key = _outer_to_vectorized_indexer(old_key, shape) + if len(old_key.tuple) == 0: + return new_key + + new_shape = np.broadcast(*old_key.tuple).shape + if isinstance(new_key, VectorizedIndexer): + new_key = _arrayize_vectorized_indexer(new_key, new_shape) + else: + new_key = _outer_to_vectorized_indexer(new_key, new_shape) + + return VectorizedIndexer(tuple(o[new_key.tuple] for o in + np.broadcast_arrays(*old_key.tuple))) + + +class IndexingSupport(object): # could inherit from enum.Enum on Python 3 + # for backends that support only basic indexer + BASIC = 'BASIC' + # for backends that support basic / outer indexer + OUTER = 'OUTER' + # for backends that support outer indexer including at most 1 vector. + OUTER_1VECTOR = 'OUTER_1VECTOR' + # for backends that support full vectorized indexer. + VECTORIZED = 'VECTORIZED' + + +def explicit_indexing_adapter( + key, shape, indexing_support, raw_indexing_method): + """Support explicit indexing by delegating to a raw indexing method. + + Outer and/or vectorized indexers are supported by indexing a second time + with a NumPy array. + + Parameters + ---------- + key : ExplicitIndexer + Explicit indexing object. + shape : Tuple[int, ...] + Shape of the indexed array. + indexing_support : IndexingSupport enum + Form of indexing supported by raw_indexing_method. + raw_indexing_method: callable + Function (like ndarray.__getitem__) that when called with indexing key + in the form of a tuple returns an indexed array. + + Returns + ------- + Indexing result, in the form of a duck numpy-array. + """ + raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support) + result = raw_indexing_method(raw_key.tuple) + if numpy_indices.tuple: + # index the loaded np.ndarray + result = NumpyIndexingAdapter(np.asarray(result))[numpy_indices] + return result + + +def decompose_indexer(indexer, shape, indexing_support): + if isinstance(indexer, VectorizedIndexer): + return _decompose_vectorized_indexer(indexer, shape, indexing_support) + if isinstance(indexer, (BasicIndexer, OuterIndexer)): + return _decompose_outer_indexer(indexer, shape, indexing_support) + raise TypeError('unexpected key type: {}'.format(indexer)) + + +def _decompose_slice(key, size): + """ convert a slice to successive two slices. The first slice always has + a positive step. + """ + start, stop, step = key.indices(size) + if step > 0: + # If key already has a positive step, use it as is in the backend + return key, slice(None) else: - return _outer_to_vectorized_indexer(key, shape) + # determine stop precisely for step > 1 case + # e.g. 
[98:2:-2] -> [98:3:-2] + stop = start + int((stop - start - 1) / step) * step + 1 + start, stop = stop + 1, start + 1 + return slice(start, stop, -step), slice(None, None, -1) + + +def _decompose_vectorized_indexer(indexer, shape, indexing_support): + """ + Decompose vectorized indexer to the successive two indexers, where the + first indexer will be used to index backend arrays, while the second one + is used to index loaded on-memory np.ndarray. + + Parameters + ---------- + indexer: VectorizedIndexer + indexing_support: one of IndexerSupport entries + + Returns + ------- + backend_indexer: OuterIndexer or BasicIndexer + np_indexers: an ExplicitIndexer (VectorizedIndexer / BasicIndexer) + + Notes + ----- + This function is used to realize the vectorized indexing for the backend + arrays that only support basic or outer indexing. + + As an example, let us consider to index a few elements from a backend array + with a vectorized indexer ([0, 3, 1], [2, 3, 2]). + Even if the backend array only supports outer indexing, it is more + efficient to load a subslice of the array than loading the entire array, + + >>> backend_indexer = OuterIndexer([0, 1, 3], [2, 3]) + >>> array = array[backend_indexer] # load subslice of the array + >>> np_indexer = VectorizedIndexer([0, 2, 1], [0, 1, 0]) + >>> array[np_indexer] # vectorized indexing for on-memory np.ndarray. + """ + assert isinstance(indexer, VectorizedIndexer) + + if indexing_support is IndexingSupport.VECTORIZED: + return indexer, BasicIndexer(()) + + backend_indexer = [] + np_indexer = [] + # convert negative indices + indexer = [np.where(k < 0, k + s, k) if isinstance(k, np.ndarray) else k + for k, s in zip(indexer.tuple, shape)] + + for k, s in zip(indexer, shape): + if isinstance(k, slice): + # If it is a slice, then we will slice it as-is + # (but make its step positive) in the backend, + # and then use all of it (slice(None)) for the in-memory portion. + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + else: + # If it is a (multidimensional) np.ndarray, just pickup the used + # keys without duplication and store them as a 1d-np.ndarray. + oind, vind = np.unique(k, return_inverse=True) + backend_indexer.append(oind) + np_indexer.append(vind.reshape(*k.shape)) + + backend_indexer = OuterIndexer(tuple(backend_indexer)) + np_indexer = VectorizedIndexer(tuple(np_indexer)) + + if indexing_support is IndexingSupport.OUTER: + return backend_indexer, np_indexer + + # If the backend does not support outer indexing, + # backend_indexer (OuterIndexer) is also decomposed. + backend_indexer, np_indexer1 = _decompose_outer_indexer( + backend_indexer, shape, indexing_support) + np_indexer = _combine_indexers(np_indexer1, shape, np_indexer) + return backend_indexer, np_indexer + + +def _decompose_outer_indexer(indexer, shape, indexing_support): + """ + Decompose outer indexer to the successive two indexers, where the + first indexer will be used to index backend arrays, while the second one + is used to index the loaded on-memory np.ndarray. + + Parameters + ---------- + indexer: VectorizedIndexer + indexing_support: One of the entries of IndexingSupport + + Returns + ------- + backend_indexer: OuterIndexer or BasicIndexer + np_indexers: an ExplicitIndexer (OuterIndexer / BasicIndexer) + + Notes + ----- + This function is used to realize the vectorized indexing for the backend + arrays that only support basic or outer indexing. 
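The core trick of `_decompose_vectorized_indexer` above is `np.unique(..., return_inverse=True)`: it turns an arbitrary (possibly duplicated, unsorted, multidimensional) array key into a sorted, duplicate-free key that an outer-indexing backend can read once, plus an inverse map applied to the loaded block in memory. A small self-contained check of that property, assuming nothing beyond NumPy:

```python
import numpy as np

k = np.array([[0, 3], [1, 3]])            # a vectorized key with a duplicate
oind, vind = np.unique(k, return_inverse=True)
# oind = [0, 1, 3]: sorted, duplicate-free -> safe for an outer-indexing backend
# vind maps each requested element back into the loaded block
data = np.arange(10) * 2                  # stands in for an on-disk variable
loaded = data[oind]                       # one outer read
assert (loaded[vind.reshape(k.shape)] == data[k]).all()
```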
+ + As an example, let us consider to index a few elements from a backend array + with a orthogonal indexer ([0, 3, 1], [2, 3, 2]). + Even if the backend array only supports basic indexing, it is more + efficient to load a subslice of the array than loading the entire array, + + >>> backend_indexer = BasicIndexer(slice(0, 3), slice(2, 3)) + >>> array = array[backend_indexer] # load subslice of the array + >>> np_indexer = OuterIndexer([0, 2, 1], [0, 1, 0]) + >>> array[np_indexer] # outer indexing for on-memory np.ndarray. + """ + if indexing_support == IndexingSupport.VECTORIZED: + return indexer, BasicIndexer(()) + assert isinstance(indexer, (OuterIndexer, BasicIndexer)) + + backend_indexer = [] + np_indexer = [] + # make indexer positive + pos_indexer = [] + for k, s in zip(indexer.tuple, shape): + if isinstance(k, np.ndarray): + pos_indexer.append(np.where(k < 0, k + s, k)) + elif isinstance(k, integer_types) and k < 0: + pos_indexer.append(k + s) + else: + pos_indexer.append(k) + indexer = pos_indexer + + if indexing_support is IndexingSupport.OUTER_1VECTOR: + # some backends such as h5py supports only 1 vector in indexers + # We choose the most efficient axis + gains = [(np.max(k) - np.min(k) + 1.0) / len(np.unique(k)) + if isinstance(k, np.ndarray) else 0 for k in indexer] + array_index = np.argmax(np.array(gains)) if len(gains) > 0 else None + + for i, (k, s) in enumerate(zip(indexer, shape)): + if isinstance(k, np.ndarray) and i != array_index: + # np.ndarray key is converted to slice that covers the entire + # entries of this key. + backend_indexer.append(slice(np.min(k), np.max(k) + 1)) + np_indexer.append(k - np.min(k)) + elif isinstance(k, np.ndarray): + # Remove duplicates and sort them in the increasing order + pkey, ekey = np.unique(k, return_inverse=True) + backend_indexer.append(pkey) + np_indexer.append(ekey) + elif isinstance(k, integer_types): + backend_indexer.append(k) + else: # slice: convert positive step slice for backend + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + + return (OuterIndexer(tuple(backend_indexer)), + OuterIndexer(tuple(np_indexer))) + + if indexing_support == IndexingSupport.OUTER: + for k, s in zip(indexer, shape): + if isinstance(k, slice): + # slice: convert positive step slice for backend + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + elif isinstance(k, integer_types): + backend_indexer.append(k) + elif isinstance(k, np.ndarray) and (np.diff(k) >= 0).all(): + backend_indexer.append(k) + np_indexer.append(slice(None)) + else: + # Remove duplicates and sort them in the increasing order + oind, vind = np.unique(k, return_inverse=True) + backend_indexer.append(oind) + np_indexer.append(vind.reshape(*k.shape)) + + return (OuterIndexer(tuple(backend_indexer)), + OuterIndexer(tuple(np_indexer))) + + # basic indexer + assert indexing_support == IndexingSupport.BASIC + + for k, s in zip(indexer, shape): + if isinstance(k, np.ndarray): + # np.ndarray key is converted to slice that covers the entire + # entries of this key. 
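For backends that only understand basic indexing, the branch here replaces an integer-array key by a contiguous slice from its minimum to its maximum (one cheap read) plus a shifted residual array applied to the loaded block. A hedged re-statement of just that slice-plus-residual step, with an illustrative helper name:

```python
import numpy as np

def decompose_array_key(k):
    """Split an integer-array key into a contiguous backend slice and a
    residual in-memory key (mirrors the basic-indexing branch above)."""
    k = np.asarray(k)
    backend = slice(int(k.min()), int(k.max()) + 1)   # contiguous read
    residual = k - int(k.min())                       # reorder/duplicate in memory
    return backend, residual

backend_key, np_key = decompose_array_key([7, 2, 2, 5])
data = np.arange(10) * 10          # stands in for an on-disk variable
loaded = data[backend_key]         # backend only supports slices: read 2..7 once
assert (loaded[np_key] == data[[7, 2, 2, 5]]).all()
```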
+ backend_indexer.append(slice(np.min(k), np.max(k) + 1)) + np_indexer.append(k - np.min(k)) + elif isinstance(k, integer_types): + backend_indexer.append(k) + else: # slice: convert positive step slice for backend + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + + return (BasicIndexer(tuple(backend_indexer)), + OuterIndexer(tuple(np_indexer))) + + +def _arrayize_vectorized_indexer(indexer, shape): + """ Return an identical vindex but slices are replaced by arrays """ + slices = [v for v in indexer.tuple if isinstance(v, slice)] + if len(slices) == 0: + return indexer + + arrays = [v for v in indexer.tuple if isinstance(v, np.ndarray)] + n_dim = arrays[0].ndim if len(arrays) > 0 else 0 + i_dim = 0 + new_key = [] + for v, size in zip(indexer.tuple, shape): + if isinstance(v, np.ndarray): + new_key.append(np.reshape(v, v.shape + (1, ) * len(slices))) + else: # slice + shape = ((1,) * (n_dim + i_dim) + (-1,) + + (1,) * (len(slices) - i_dim - 1)) + new_key.append(np.arange(*v.indices(size)).reshape(shape)) + i_dim += 1 + return VectorizedIndexer(tuple(new_key)) def _dask_array_with_chunks_hint(array, chunks): @@ -700,7 +1059,7 @@ def create_mask(indexer, shape, chunks_hint=None): same shape as the indexing result. """ if isinstance(indexer, OuterIndexer): - key = _outer_to_vectorized_indexer(indexer.tuple, shape) + key = _outer_to_vectorized_indexer(indexer, shape).tuple assert not any(isinstance(k, slice) for k in key) mask = _masked_result_drop_slice(key, chunks_hint) @@ -795,7 +1154,7 @@ def _ensure_ndarray(self, value): def _indexing_array_and_key(self, key): if isinstance(key, OuterIndexer): array = self.array - key = _outer_to_numpy_indexer(key.tuple, self.array.shape) + key = _outer_to_numpy_indexer(key, self.array.shape) elif isinstance(key, VectorizedIndexer): array = nputils.NumpyVIndexAdapter(self.array) key = key.tuple @@ -807,6 +1166,9 @@ def _indexing_array_and_key(self, key): return array, key + def transpose(self, order): + return self.array.transpose(order) + def __getitem__(self, key): array, key = self._indexing_array_and_key(key) return self._ensure_ndarray(array[key]) @@ -850,6 +1212,9 @@ def __setitem__(self, key, value): 'into memory explicitly using the .load() ' 'method or accessing its .values attribute.') + def transpose(self, order): + return self.array.transpose(order) + class PandasIndexAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" @@ -911,6 +1276,10 @@ def __getitem__(self, indexer): result = np.datetime64('NaT', 'ns') elif isinstance(result, timedelta): result = np.timedelta64(getattr(result, 'value', result), 'ns') + elif isinstance(result, pd.Timestamp): + # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 + # numpy fails to convert pd.Timestamp to np.datetime64[ns] + result = np.asarray(result.to_datetime64()) elif self.dtype != object: result = np.asarray(result, dtype=self.dtype) @@ -920,6 +1289,9 @@ def __getitem__(self, indexer): return result + def transpose(self, order): + return self.array # self.array should be always one-dimensional + def __repr__(self): return ('%s(array=%r, dtype=%r)' % (type(self).__name__, self.array, self.dtype)) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index c5e643adb0d..984dd2fa204 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -1,6 +1,5 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function 
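The `PandasIndexAdapter` change above works around numpy failing to convert a `pd.Timestamp` directly to `datetime64[ns]` (pydata/xarray#1932, numpy/numpy#10668) by going through `to_datetime64()` explicitly. A quick illustration of the conversion it relies on:

```python
import numpy as np
import pandas as pd

ts = pd.Timestamp('2000-01-01')
# rather than np.asarray(ts, dtype='datetime64[ns]'), which is unreliable on
# some numpy versions, convert via the Timestamp itself:
value = np.asarray(ts.to_datetime64())
print(value, value.dtype)      # 2000-01-01T00:00:00.000000000 datetime64[ns]
```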
+from __future__ import absolute_import, division, print_function + import pandas as pd from .alignment import deep_align @@ -8,7 +7,6 @@ from .utils import Frozen from .variable import as_variable, assert_unique_multiindex_level_names - PANDAS_TYPES = (pd.Series, pd.DataFrame, pd.Panel) _VALID_COMPAT = Frozen({'identical': 0, @@ -192,10 +190,13 @@ def expand_variable_dicts(list_of_variable_dicts): an input's values. The values of each ordered dictionary are all xarray.Variable objects. """ + from .dataarray import DataArray + from .dataset import Dataset + var_dicts = [] for variables in list_of_variable_dicts: - if hasattr(variables, 'variables'): # duck-type Dataset + if isinstance(variables, Dataset): sanitized_vars = variables.variables else: # append coords to var_dicts before appending sanitized_vars, @@ -203,7 +204,7 @@ def expand_variable_dicts(list_of_variable_dicts): sanitized_vars = OrderedDict() for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): # use private API for speed coords = var._coords.copy() # explicitly overwritten variables should take precedence @@ -234,17 +235,19 @@ def determine_coords(list_of_variable_dicts): All variable found in the input should appear in either the set of coordinate or non-coordinate names. """ + from .dataarray import DataArray + from .dataset import Dataset + coord_names = set() noncoord_names = set() for variables in list_of_variable_dicts: - if hasattr(variables, 'coords') and hasattr(variables, 'data_vars'): - # duck-type Dataset + if isinstance(variables, Dataset): coord_names.update(variables.coords) noncoord_names.update(variables.data_vars) else: for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): coords = set(var._coords) # use private API for speed # explicitly overwritten variables should take precedence coords.discard(name) @@ -362,7 +365,17 @@ def merge_data_and_coords(data, coords, compat='broadcast_equals', """Used in Dataset.__init__.""" objs = [data, coords] explicit_coords = coords.keys() - return merge_core(objs, compat, join, explicit_coords=explicit_coords) + indexes = dict(extract_indexes(coords)) + return merge_core(objs, compat, join, explicit_coords=explicit_coords, + indexes=indexes) + + +def extract_indexes(coords): + """Yields the name & index of valid indexes from a mapping of coords""" + for name, variable in coords.items(): + variable = as_variable(variable, name=name) + if variable.dims == (name,): + yield name, variable.to_index() def assert_valid_explicit_coords(variables, dims, explicit_coords): @@ -549,6 +562,24 @@ def dataset_merge_method(dataset, other, overwrite_vars, compat, join): def dataset_update_method(dataset, other): - """Guts of the Dataset.update method""" + """Guts of the Dataset.update method. + + This drops a duplicated coordinates from `other` if `other` is not an + `xarray.Dataset`, e.g., if it's a dict with DataArray values (GH2068, + GH2180). 
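The new `extract_indexes` generator encodes a simple rule: a coordinate only contributes an index when it is one-dimensional and named after its own dimension. A toy illustration of that rule on plain Python data (the dict layout here is purely for demonstration):

```python
coords = {
    'x': (('x',), [10, 20, 30]),      # dims == ('x',)  -> becomes an index
    'y2': (('y',), [1, 2]),           # name differs from its dim -> no index
    'scalar': ((), 5),                # 0-d coordinate  -> no index
}
indexes = {name: values
           for name, (dims, values) in coords.items()
           if dims == (name,)}
print(indexes)                        # {'x': [10, 20, 30]}
```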
+ """ + from .dataset import Dataset + from .dataarray import DataArray + + if not isinstance(other, Dataset): + other = OrderedDict(other) + for key, value in other.items(): + if isinstance(value, DataArray): + # drop conflicting coordinates + coord_names = [c for c in value.coords + if c not in value.dims and c in dataset.coords] + if coord_names: + other[key] = value.drop(coord_names) + return merge_core([dataset, other], priority_arg=1, indexes=dataset.indexes) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index e26e976a11b..3f4e0fc3ac9 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -1,18 +1,19 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function +import warnings from collections import Iterable from functools import partial import numpy as np import pandas as pd - -from .pycompat import iteritems +from . import rolling +from .common import _contains_datetime_like_objects from .computation import apply_ufunc -from .utils import is_scalar -from .npcompat import flip +from .duck_array_ops import dask_array_type +from .pycompat import iteritems +from .utils import OrderedSet, datetime_to_numeric, is_scalar +from .variable import Variable, broadcast_variables class BaseInterpolator(object): @@ -58,7 +59,7 @@ def __init__(self, xi, yi, method='linear', fill_value=None, **kwargs): if self.cons_kwargs: raise ValueError( - 'recieved invalid kwargs: %r' % self.cons_kwargs.keys()) + 'received invalid kwargs: %r' % self.cons_kwargs.keys()) if fill_value is None: self._left = np.nan @@ -205,15 +206,19 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, # method index = get_clean_interp_index(self, dim, use_coordinate=use_coordinate, **kwargs) - interpolator = _get_interpolator(method, **kwargs) - - arr = apply_ufunc(interpolator, index, self, - input_core_dims=[[dim], [dim]], - output_core_dims=[[dim]], - output_dtypes=[self.dtype], - dask='parallelized', - vectorize=True, - keep_attrs=True).transpose(*self.dims) + interp_class, kwargs = _get_interpolator(method, **kwargs) + interpolator = partial(func_interpolate_na, interp_class, **kwargs) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'overflow', RuntimeWarning) + warnings.filterwarnings('ignore', 'invalid value', RuntimeWarning) + arr = apply_ufunc(interpolator, index, self, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim]], + output_dtypes=[self.dtype], + dask='parallelized', + vectorize=True, + keep_attrs=True).transpose(*self.dims) if limit is not None: arr = arr.where(valids) @@ -221,7 +226,7 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, return arr -def wrap_interpolator(interpolator, x, y, **kwargs): +def func_interpolate_na(interpolator, x, y, **kwargs): '''helper function to apply interpolation along 1 dimension''' # it would be nice if this wasn't necessary, works around: # "ValueError: assignment destination is read-only" in assignment below @@ -244,13 +249,13 @@ def _bfill(arr, n=None, axis=-1): '''inverse of ffill''' import bottleneck as bn - arr = flip(arr, axis=axis) + arr = np.flip(arr, axis=axis) # fill arr = bn.push(arr, axis=axis, n=n) # reverse back to original - return flip(arr, axis=axis) + return np.flip(arr, axis=axis) def ffill(arr, dim=None, limit=None): @@ -283,29 +288,41 @@ def bfill(arr, dim=None, limit=None): kwargs=dict(n=_limit, 
axis=axis)).transpose(*arr.dims) -def _get_interpolator(method, **kwargs): +def _get_interpolator(method, vectorizeable_only=False, **kwargs): '''helper function to select the appropriate interpolator class - returns a partial of wrap_interpolator + returns interpolator class and keyword arguments for the class ''' interp1d_methods = ['linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial'] valid_methods = interp1d_methods + ['barycentric', 'krog', 'pchip', 'spline', 'akima'] + has_scipy = True + try: + from scipy import interpolate + except ImportError: + has_scipy = False + + # prioritize scipy.interpolate if (method == 'linear' and not - kwargs.get('fill_value', None) == 'extrapolate'): + kwargs.get('fill_value', None) == 'extrapolate' and + not vectorizeable_only): kwargs.update(method=method) interp_class = NumpyInterpolator + elif method in valid_methods: - try: - from scipy import interpolate - except ImportError: + if not has_scipy: raise ImportError( 'Interpolation with method `%s` requires scipy' % method) + if method in interp1d_methods: kwargs.update(method=method) interp_class = ScipyInterpolator + elif vectorizeable_only: + raise ValueError('{} is not a vectorizeable interpolator. ' + 'Available methods are {}'.format( + method, interp1d_methods)) elif method == 'barycentric': interp_class = interpolate.BarycentricInterpolator elif method == 'krog': @@ -322,11 +339,225 @@ def _get_interpolator(method, **kwargs): else: raise ValueError('%s is not a valid interpolator' % method) - return partial(wrap_interpolator, interp_class, **kwargs) + return interp_class, kwargs + + +def _get_interpolator_nd(method, **kwargs): + '''helper function to select the appropriate interpolator class + + returns interpolator class and keyword arguments for the class + ''' + valid_methods = ['linear', 'nearest'] + + try: + from scipy import interpolate + except ImportError: + raise ImportError( + 'Interpolation with method `%s` requires scipy' % method) + + if method in valid_methods: + kwargs.update(method=method) + interp_class = interpolate.interpn + else: + raise ValueError('%s is not a valid interpolator for interpolating ' + 'over multiple dimensions.' % method) + + return interp_class, kwargs def _get_valid_fill_mask(arr, dim, limit): '''helper function to determine values that can be filled when limit is not None''' kw = {dim: limit + 1} - return arr.isnull().rolling(min_periods=1, **kw).sum() <= limit + # we explicitly use construct method to avoid copy. + new_dim = rolling._get_new_dimname(arr.dims, '_window') + return (arr.isnull().rolling(min_periods=1, **kw) + .construct(new_dim, fill_value=False) + .sum(new_dim, skipna=False)) <= limit + + +def _assert_single_chunk(var, axes): + for axis in axes: + if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]: + raise NotImplementedError( + 'Chunking along the dimension to be interpolated ' + '({}) is not yet supported.'.format(axis)) + + +def _localize(var, indexes_coords): + """ Speed up for linear and nearest neighbor method. 
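The `vectorizeable_only` flag above restricts the choice to interpolators that can process many series in one call, which is what makes the dask-parallelized path work. The property it relies on is that `scipy.interpolate.interp1d` interpolates along the last axis for all leading dimensions at once; a short demonstration (assumes scipy is installed):

```python
import numpy as np
from scipy.interpolate import interp1d

x = np.linspace(0, 1, 5)
y = np.stack([x, x ** 2])            # two series sharing the same coordinate
new_x = np.array([0.1, 0.45, 0.8])

# one interp1d call handles every leading dimension along the last axis
f = interp1d(x, y, kind='linear', assume_sorted=True, bounds_error=False)
print(f(new_x).shape)                # (2, 3)
```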
+ Only consider a subspace that is needed for the interpolation + """ + indexes = {} + for dim, [x, new_x] in indexes_coords.items(): + index = x.to_index() + imin = index.get_loc(np.min(new_x.values), method='nearest') + imax = index.get_loc(np.max(new_x.values), method='nearest') + + indexes[dim] = slice(max(imin - 2, 0), imax + 2) + indexes_coords[dim] = (x[indexes[dim]], new_x) + return var.isel(**indexes), indexes_coords + + +def _floatize_x(x, new_x): + """ Make x and new_x float. + This is particulary useful for datetime dtype. + x, new_x: tuple of np.ndarray + """ + x = list(x) + new_x = list(new_x) + for i in range(len(x)): + if _contains_datetime_like_objects(x[i]): + # Scipy casts coordinates to np.float64, which is not accurate + # enough for datetime64 (uses 64bit integer). + # We assume that the most of the bits are used to represent the + # offset (min(x)) and the variation (x - min(x)) can be + # represented by float. + xmin = x[i].min() + x[i] = datetime_to_numeric(x[i], offset=xmin, dtype=np.float64) + new_x[i] = datetime_to_numeric( + new_x[i], offset=xmin, dtype=np.float64) + return x, new_x + + +def interp(var, indexes_coords, method, **kwargs): + """ Make an interpolation of Variable + + Parameters + ---------- + var: Variable + index_coords: + Mapping from dimension name to a pair of original and new coordinates. + Original coordinates should be sorted in strictly ascending order. + Note that all the coordinates should be Variable objects. + method: string + One of {'linear', 'nearest', 'zero', 'slinear', 'quadratic', + 'cubic'}. For multidimensional interpolation, only + {'linear', 'nearest'} can be used. + **kwargs: + keyword arguments to be passed to scipy.interpolate + + Returns + ------- + Interpolated Variable + + See Also + -------- + DataArray.interp + Dataset.interp + """ + if not indexes_coords: + return var.copy() + + # simple speed up for the local interpolation + if method in ['linear', 'nearest']: + var, indexes_coords = _localize(var, indexes_coords) + + # default behavior + kwargs['bounds_error'] = kwargs.get('bounds_error', False) + + # target dimensions + dims = list(indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in dims]) + destination = broadcast_variables(*new_x) + + # transpose to make the interpolated axis to the last position + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + interped = interp_func(var.transpose(*original_dims).data, + x, destination, method, kwargs) + + result = Variable(new_dims, interped, attrs=var.attrs) + + # dimension of the output array + out_dims = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indexes_coords[d][1].dims) + else: + out_dims.add(d) + return result.transpose(*tuple(out_dims)) + + +def interp_func(var, x, new_x, method, kwargs): + """ + multi-dimensional interpolation for array-like. Interpolated axes should be + located in the last position. + + Parameters + ---------- + var: np.ndarray or dask.array.Array + Array to be interpolated. The final dimension is interpolated. + x: a list of 1d array. + Original coordinates. Should not contain NaN. + new_x: a list of 1d array + New coordinates. Should not contain NaN. + method: string + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for + 1-dimensional itnterpolation. 
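`_floatize_x` converts datetime coordinates to floats only after subtracting an offset, because scipy casts coordinates to float64 and raw nanosecond int64 values would lose precision. A minimal illustration of the offset trick (mirroring the idea, not the exact `datetime_to_numeric` signature):

```python
import numpy as np

times = np.array(['2000-01-01T00:00', '2000-01-01T06:00', '2000-01-02T00:00'],
                 dtype='datetime64[ns]')
offset = times.min()
# subtracting the offset first keeps the float64 representation accurate;
# casting the raw int64 nanosecond epoch values directly would lose precision
numeric = (times - offset) / np.timedelta64(1, 'ns')
print(numeric)    # 0.0, 2.16e13, 8.64e13 nanoseconds since the offset
```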
+ {'linear', 'nearest'} for multidimensional interpolation + **kwargs: + Optional keyword arguments to be passed to scipy.interpolator + + Returns + ------- + interpolated: array + Interpolated array + + Note + ---- + This requiers scipy installed. + + See Also + -------- + scipy.interpolate.interp1d + """ + if not x: + return var.copy() + + if len(x) == 1: + func, kwargs = _get_interpolator(method, vectorizeable_only=True, + **kwargs) + else: + func, kwargs = _get_interpolator_nd(method, **kwargs) + + if isinstance(var, dask_array_type): + import dask.array as da + + _assert_single_chunk(var, range(var.ndim - len(x), var.ndim)) + chunks = var.chunks[:-len(x)] + new_x[0].shape + drop_axis = range(var.ndim - len(x), var.ndim) + new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) + return da.map_blocks(_interpnd, var, x, new_x, func, kwargs, + dtype=var.dtype, chunks=chunks, + new_axis=new_axis, drop_axis=drop_axis) + + return _interpnd(var, x, new_x, func, kwargs) + + +def _interp1d(var, x, new_x, func, kwargs): + # x, new_x are tuples of size 1. + x, new_x = x[0], new_x[0] + rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x)) + if new_x.ndim > 1: + return rslt.reshape(var.shape[:-1] + new_x.shape) + if new_x.ndim == 0: + return rslt[..., -1] + return rslt + + +def _interpnd(var, x, new_x, func, kwargs): + x, new_x = _floatize_x(x, new_x) + + if len(x) == 1: + return _interp1d(var, x, new_x, func, kwargs) + + # move the interpolation axes to the start position + var = var.transpose(range(-len(x), var.ndim - len(x))) + # stack new_x to 1 vector, with reshape + xi = np.stack([x1.values.ravel() for x1 in new_x], axis=-1) + rslt = func(x, var, xi, **kwargs) + # move back the interpolation axes to the last position + rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) + return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py new file mode 100644 index 00000000000..4d3f03c899e --- /dev/null +++ b/xarray/core/nanops.py @@ -0,0 +1,207 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +from . import dtypes, nputils +from .duck_array_ops import ( + _dask_or_eager_func, count, fillna, isnull, where_method) +from .pycompat import dask_array_type + +try: + import dask.array as dask_array +except ImportError: + dask_array = None + + +def _replace_nan(a, val): + """ + replace nan in a by val, and returns the replaced array and the nan + position + """ + mask = isnull(a) + return where_method(val, mask, a), mask + + +def _maybe_null_out(result, axis, mask, min_count=1): + """ + xarray version of pandas.core.nanops._maybe_null_out + """ + if hasattr(axis, '__len__'): # if tuple or list + raise ValueError('min_count is not available for reduction ' + 'with more than one dimensions.') + + if axis is not None and getattr(result, 'ndim', False): + null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0 + if null_mask.any(): + dtype, fill_value = dtypes.maybe_promote(result.dtype) + result = result.astype(dtype) + result[null_mask] = fill_value + + elif getattr(result, 'dtype', None) not in dtypes.NAT_TYPES: + null_mask = mask.size - mask.sum() + if null_mask < min_count: + result = np.nan + + return result + + +def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanargmin, nanargmax for object arrays. 
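`_interp1d` further down handles multidimensional destination coordinates by raveling them, interpolating once, and restoring the destination shape. The same ravel/reshape dance in isolation (assumes scipy is installed; variable names are illustrative):

```python
import numpy as np
from scipy.interpolate import interp1d

x = np.linspace(0, 1, 6)
var = np.sin(x)
new_x = np.linspace(0, 1, 12).reshape(3, 4)       # destination coords are 2-D

# interp1d only understands 1-D targets: ravel, interpolate, then reshape back
rslt = interp1d(x, var, assume_sorted=True)(np.ravel(new_x))
print(rslt.reshape(var.shape[:-1] + new_x.shape).shape)   # (3, 4)
```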
Always return integer + type + """ + valid_count = count(value, axis=axis) + value = fillna(value, fill_value) + data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) + + # TODO This will evaluate dask arrays and might be costly. + if (valid_count == 0).any(): + raise ValueError('All-NaN slice encountered') + + return data + + +def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanmin and nanmax for object array """ + valid_count = count(value, axis=axis) + filled_value = fillna(value, fill_value) + data = getattr(np, func)(filled_value, axis=axis, **kwargs) + if not hasattr(data, 'dtype'): # scalar case + data = dtypes.fill_value(value.dtype) if valid_count == 0 else data + return np.array(data, dtype=value.dtype) + return where_method(data, valid_count != 0) + + +def nanmin(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'min', dtypes.get_pos_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmin(a, axis=axis) + + +def nanmax(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'max', dtypes.get_neg_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmax(a, axis=axis) + + +def nanargmin(a, axis=None): + fill_value = dtypes.get_pos_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmin', fill_value, a, axis=axis) + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmin(a, axis=axis) + else: + res = np.argmin(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice encountered") + return res + + +def nanargmax(a, axis=None): + fill_value = dtypes.get_neg_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmax', fill_value, a, axis=axis) + + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmax(a, axis=axis) + else: + res = np.argmax(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice encountered") + return res + + +def nansum(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 0) + result = _dask_or_eager_func('sum')(a, axis=axis, dtype=dtype) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def _nanmean_ddof_object(ddof, value, axis=None, **kwargs): + """ In house nanmean. 
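The `min_count` handling added to `nansum` (via `_maybe_null_out`) means a reduction returns NA unless enough valid values were present. A tiny scalar-only sketch of the semantics, not the dask-aware implementation:

```python
import numpy as np

def nansum_min_count(a, min_count):
    """Sum ignoring NaN, but return NaN if fewer than min_count values are valid."""
    valid = int(np.sum(~np.isnan(a)))
    return np.nansum(a) if valid >= min_count else np.nan

print(nansum_min_count(np.array([1.0, np.nan, 2.0]), min_count=2))   # 3.0
print(nansum_min_count(np.array([1.0, np.nan, np.nan]), min_count=2))  # nan
```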
ddof argument will be used in _nanvar method """ + from .duck_array_ops import (count, fillna, _dask_or_eager_func, + where_method) + + valid_count = count(value, axis=axis) + value = fillna(value, 0) + # As dtype inference is impossible for object dtype, we assume float + # https://github.com/dask/dask/issues/3162 + dtype = kwargs.pop('dtype', None) + if dtype is None and value.dtype.kind == 'O': + dtype = value.dtype if value.dtype.kind in ['cf'] else float + + data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs) + data = data / (valid_count - ddof) + return where_method(data, valid_count != 0) + + +def nanmean(a, axis=None, dtype=None, out=None): + if a.dtype.kind == 'O': + return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype) + + if isinstance(a, dask_array_type): + return dask_array.nanmean(a, axis=axis, dtype=dtype) + + return np.nanmean(a, axis=axis, dtype=dtype) + + +def nanmedian(a, axis=None, out=None): + return _dask_or_eager_func('nanmedian', eager_module=nputils)(a, axis=axis) + + +def _nanvar_object(value, axis=None, **kwargs): + ddof = kwargs.pop('ddof', 0) + kwargs_mean = kwargs.copy() + kwargs_mean.pop('keepdims', None) + value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis, + keepdims=True, **kwargs_mean) + squared = (value.astype(value_mean.dtype) - value_mean)**2 + return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs) + + +def nanvar(a, axis=None, dtype=None, out=None, ddof=0): + if a.dtype.kind == 'O': + return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof) + + return _dask_or_eager_func('nanvar', eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanstd(a, axis=None, dtype=None, out=None, ddof=0): + return _dask_or_eager_func('nanstd', eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanprod(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 1) + result = _dask_or_eager_func('nanprod')(a, axis=axis, dtype=dtype, out=out) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def nancumsum(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumsum', eager_module=nputils)( + a, axis=axis, dtype=dtype) + + +def nancumprod(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumprod', eager_module=nputils)( + a, axis=axis, dtype=dtype) diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index bbe7b745621..efa68c8bad5 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -1,247 +1,285 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import numpy as np - -try: - from numpy import nancumsum, nancumprod, flip -except ImportError: # pragma: no cover - # Code copied from newer versions of NumPy (v1.12). - # Used under the terms of NumPy's license, see licenses/NUMPY_LICENSE. - - def _replace_nan(a, val): - """ - If `a` is of inexact type, make a copy of `a`, replace NaNs with - the `val` value, and return the copy together with a boolean mask - marking the locations where NaNs were present. If `a` is not of - inexact type, do nothing and return `a` together with a mask of None. - - Note that scalars will end up as array scalars, which is important - for using the result as the value of the out argument in some - operations. - - Parameters - ---------- - a : array-like - Input array. 
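`_nanvar_object` builds the variance for object arrays out of two nanmean passes: mean, then mean of squared deviations with the requested `ddof`. The same two-pass formula spelled out on a plain float array as a sanity check:

```python
import numpy as np

a = np.array([1.0, np.nan, 2.0, 4.0])
valid = ~np.isnan(a)
n = valid.sum()
mean = np.nansum(a) / n
# NaNs are replaced by the mean so their squared deviation contributes zero,
# which is the same effect as skipping them
var = np.nansum((np.where(valid, a, mean) - mean) ** 2) / (n - 1)
assert np.isclose(var, np.nanvar(a, ddof=1))
```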
- val : float - NaN values are set to val before doing the operation. - - Returns - ------- - y : ndarray - If `a` is of inexact type, return a copy of `a` with the NaNs - replaced by the fill value, otherwise return `a`. - mask: {bool, None} - If `a` is of inexact type, return a boolean mask marking locations - of NaNs, otherwise return None. - - """ - is_new = not isinstance(a, np.ndarray) - if is_new: - a = np.array(a) - if not issubclass(a.dtype.type, np.inexact): - return a, None - if not is_new: - # need copy - a = np.array(a, subok=True) - - mask = np.isnan(a) - np.copyto(a, val, where=mask) - return a, mask - - def nancumsum(a, axis=None, dtype=None, out=None): - """ - Return the cumulative sum of array elements over a given axis treating - Not a Numbers (NaNs) as zero. The cumulative sum does not change when - NaNs are encountered and leading NaNs are replaced by zeros. +from __future__ import absolute_import, division, print_function - Zeros are returned for slices that are all-NaN or empty. +from distutils.version import LooseVersion - .. versionadded:: 1.12.0 - - Parameters - ---------- - a : array_like - Input array. - axis : int, optional - Axis along which the cumulative sum is computed. The default - (None) is to compute the cumsum over the flattened array. - dtype : dtype, optional - Type of the returned array and of the accumulator in which the - elements are summed. If `dtype` is not specified, it defaults - to the dtype of `a`, unless `a` has an integer dtype with a - precision less than that of the default platform integer. In - that case, the default platform integer is used. - out : ndarray, optional - Alternative output array in which to place the result. It must - have the same shape and buffer length as the expected output - but the type will be cast if necessary. See `doc.ufuncs` - (Section "Output arguments") for more details. - - Returns - ------- - nancumsum : ndarray. - A new array holding the result is returned unless `out` is - specified, in which it is returned. The result has the same - size as `a`, and the same shape as `a` if `axis` is not None - or `a` is a 1-d array. - - See Also - -------- - numpy.cumsum : Cumulative sum across array propagating NaNs. - isnan : Show which elements are NaN. - - Examples - -------- - >>> np.nancumsum(1) - array([1]) - >>> np.nancumsum([1]) - array([1]) - >>> np.nancumsum([1, np.nan]) - array([ 1., 1.]) - >>> a = np.array([[1, 2], [3, np.nan]]) - >>> np.nancumsum(a) - array([ 1., 3., 6., 6.]) - >>> np.nancumsum(a, axis=0) - array([[ 1., 2.], - [ 4., 2.]]) - >>> np.nancumsum(a, axis=1) - array([[ 1., 3.], - [ 3., 3.]]) - - """ - a, mask = _replace_nan(a, 0) - return np.cumsum(a, axis=axis, dtype=dtype, out=out) - - def nancumprod(a, axis=None, dtype=None, out=None): - """ - Return the cumulative product of array elements over a given axis - treating Not a Numbers (NaNs) as one. The cumulative product does not - change when NaNs are encountered and leading NaNs are replaced by ones. - - Ones are returned for slices that are all-NaN or empty. - - .. versionadded:: 1.12.0 - - Parameters - ---------- - a : array_like - Input array. - axis : int, optional - Axis along which the cumulative product is computed. By default - the input is flattened. - dtype : dtype, optional - Type of the returned array, as well as of the accumulator in which - the elements are multiplied. 
If *dtype* is not specified, it - defaults to the dtype of `a`, unless `a` has an integer dtype with - a precision less than that of the default platform integer. In - that case, the default platform integer is used instead. - out : ndarray, optional - Alternative output array in which to place the result. It must - have the same shape and buffer length as the expected output - but the type of the resulting values will be cast if necessary. - - Returns - ------- - nancumprod : ndarray - A new array holding the result is returned unless `out` is - specified, in which case it is returned. - - See Also - -------- - numpy.cumprod : Cumulative product across array propagating NaNs. - isnan : Show which elements are NaN. - - Examples - -------- - >>> np.nancumprod(1) - array([1]) - >>> np.nancumprod([1]) - array([1]) - >>> np.nancumprod([1, np.nan]) - array([ 1., 1.]) - >>> a = np.array([[1, 2], [3, np.nan]]) - >>> np.nancumprod(a) - array([ 1., 2., 6., 6.]) - >>> np.nancumprod(a, axis=0) - array([[ 1., 2.], - [ 3., 2.]]) - >>> np.nancumprod(a, axis=1) - array([[ 1., 2.], - [ 3., 3.]]) +import numpy as np - """ - a, mask = _replace_nan(a, 1) - return np.cumprod(a, axis=axis, dtype=dtype, out=out) +try: + from numpy import isin +except ImportError: - def flip(m, axis): + def isin(element, test_elements, assume_unique=False, invert=False): """ - Reverse the order of elements in an array along the given axis. - - The shape of the array is preserved, but the elements are reordered. - - .. versionadded:: 1.12.0 + Calculates `element in test_elements`, broadcasting over `element` + only. Returns a boolean array of the same shape as `element` that is + True where an element of `element` is in `test_elements` and False + otherwise. Parameters ---------- - m : array_like + element : array_like Input array. - axis : integer - Axis in array, which entries are reversed. - + test_elements : array_like + The values against which to test each value of `element`. + This argument is flattened if it is an array or array_like. + See notes for behavior with non-array-like parameters. + assume_unique : bool, optional + If True, the input arrays are both assumed to be unique, which + can speed up the calculation. Default is False. + invert : bool, optional + If True, the values in the returned array are inverted, as if + calculating `element not in test_elements`. Default is False. + ``np.isin(a, b, invert=True)`` is equivalent to (but faster + than) ``np.invert(np.isin(a, b))``. Returns ------- - out : array_like - A view of `m` with the entries of axis reversed. Since a view is - returned, this operation is done in constant time. + isin : ndarray, bool + Has the same shape as `element`. The values `element[isin]` + are in `test_elements`. See Also -------- - flipud : Flip an array vertically (axis=0). - fliplr : Flip an array horizontally (axis=1). + in1d : Flattened version of this function. + numpy.lib.arraysetops : Module with a number of other functions for + performing set operations on arrays. Notes ----- - flip(m, 0) is equivalent to flipud(m). - flip(m, 1) is equivalent to fliplr(m). - flip(m, n) corresponds to ``m[...,::-1,...]`` with ``::-1`` at index n. - Examples - -------- - >>> A = np.arange(8).reshape((2,2,2)) - >>> A - array([[[0, 1], - [2, 3]], - - [[4, 5], - [6, 7]]]) - - >>> flip(A, 0) - array([[[4, 5], - [6, 7]], + `isin` is an element-wise function version of the python keyword `in`. 
+ ``isin(a, b)`` is roughly equivalent to + ``np.array([item in b for item in a])`` if `a` and `b` are 1-D + sequences. - [[0, 1], - [2, 3]]]) + `element` and `test_elements` are converted to arrays if they are not + already. If `test_elements` is a set (or other non-sequence collection) + it will be converted to an object array with one element, rather than + an array of the values contained in `test_elements`. This is a + consequence of the `array` constructor's way of handling non-sequence + collections. Converting the set to a list usually gives the desired + behavior. - >>> flip(A, 1) - array([[[2, 3], - [0, 1]], + .. versionadded:: 1.13.0 - [[6, 7], - [4, 5]]]) - - >>> A = np.random.randn(3,4,5) - >>> np.all(flip(A,2) == A[:,:,::-1,...]) - True + Examples + -------- + >>> element = 2*np.arange(4).reshape((2, 2)) + >>> element + array([[0, 2], + [4, 6]]) + >>> test_elements = [1, 2, 4, 8] + >>> mask = np.isin(element, test_elements) + >>> mask + array([[ False, True], + [ True, False]]) + >>> element[mask] + array([2, 4]) + >>> mask = np.isin(element, test_elements, invert=True) + >>> mask + array([[ True, False], + [ False, True]]) + >>> element[mask] + array([0, 6]) + + Because of how `array` handles sets, the following does not + work as expected: + + >>> test_set = {1, 2, 4, 8} + >>> np.isin(element, test_set) + array([[ False, False], + [ False, False]]) + + Casting the set to a list gives the expected result: + + >>> np.isin(element, list(test_set)) + array([[ False, True], + [ True, False]]) """ - if not hasattr(m, 'ndim'): - m = np.asarray(m) - indexer = [slice(None)] * m.ndim - try: - indexer[axis] = slice(None, None, -1) - except IndexError: - raise ValueError("axis=%i is invalid for the %i-dimensional " - "input array" % (axis, m.ndim)) - return m[tuple(indexer)] + element = np.asarray(element) + return np.in1d(element, test_elements, assume_unique=assume_unique, + invert=invert).reshape(element.shape) + + +if LooseVersion(np.__version__) >= LooseVersion('1.13'): + gradient = np.gradient +else: + def normalize_axis_tuple(axes, N): + if isinstance(axes, int): + axes = (axes, ) + return tuple([N + a if a < 0 else a for a in axes]) + + def gradient(f, *varargs, **kwargs): + f = np.asanyarray(f) + N = f.ndim # number of dimensions + + axes = kwargs.pop('axis', None) + if axes is None: + axes = tuple(range(N)) + else: + axes = normalize_axis_tuple(axes, N) + + len_axes = len(axes) + n = len(varargs) + if n == 0: + # no spacing argument - use 1 in all axes + dx = [1.0] * len_axes + elif n == 1 and np.ndim(varargs[0]) == 0: + # single scalar for all axes + dx = varargs * len_axes + elif n == len_axes: + # scalar or 1d array for each axis + dx = list(varargs) + for i, distances in enumerate(dx): + if np.ndim(distances) == 0: + continue + elif np.ndim(distances) != 1: + raise ValueError("distances must be either scalars or 1d") + if len(distances) != f.shape[axes[i]]: + raise ValueError("when 1d, distances must match the " + "length of the corresponding dimension") + diffx = np.diff(distances) + # if distances are constant reduce to the scalar case + # since it brings a consistent speedup + if (diffx == diffx[0]).all(): + diffx = diffx[0] + dx[i] = diffx + else: + raise TypeError("invalid number of arguments") + + edge_order = kwargs.pop('edge_order', 1) + if kwargs: + raise TypeError('"{}" are not valid keyword arguments.'.format( + '", "'.join(kwargs.keys()))) + if edge_order > 2: + raise ValueError("'edge_order' greater than 2 not supported") + + # use central differences on 
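The `isin` backport above is just `np.in1d` followed by a reshape back to the element shape. A quick equivalence check (comparing against `np.isin` assumes a NumPy recent enough to provide it, which is exactly the case the backport skips):

```python
import numpy as np

element = 2 * np.arange(4).reshape((2, 2))
test_elements = [1, 2, 4, 8]
# the backport: flatten with in1d, then restore the element shape
mask = np.in1d(element, test_elements).reshape(element.shape)
assert (mask == np.isin(element, test_elements)).all()
```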
interior and one-sided differences on the + # endpoints. This preserves second order-accuracy over the full domain. + + outvals = [] + + # create slice objects --- initially all are [:, :, ..., :] + slice1 = [slice(None)] * N + slice2 = [slice(None)] * N + slice3 = [slice(None)] * N + slice4 = [slice(None)] * N + + otype = f.dtype.char + if otype not in ['f', 'd', 'F', 'D', 'm', 'M']: + otype = 'd' + + # Difference of datetime64 elements results in timedelta64 + if otype == 'M': + # Need to use the full dtype name because it contains unit + # information + otype = f.dtype.name.replace('datetime', 'timedelta') + elif otype == 'm': + # Needs to keep the specific units, can't be a general unit + otype = f.dtype + + # Convert datetime64 data into ints. Make dummy variable `y` + # that is a view of ints if the data is datetime64, otherwise + # just set y equal to the array `f`. + if f.dtype.char in ["M", "m"]: + y = f.view('int64') + else: + y = f + + for i, axis in enumerate(axes): + if y.shape[axis] < edge_order + 1: + raise ValueError( + "Shape of array too small to calculate a numerical " + "gradient, at least (edge_order + 1) elements are " + "required.") + # result allocation + out = np.empty_like(y, dtype=otype) + + uniform_spacing = np.ndim(dx[i]) == 0 + + # Numerical differentiation: 2nd order interior + slice1[axis] = slice(1, -1) + slice2[axis] = slice(None, -2) + slice3[axis] = slice(1, -1) + slice4[axis] = slice(2, None) + + if uniform_spacing: + out[slice1] = (f[slice4] - f[slice2]) / (2. * dx[i]) + else: + dx1 = dx[i][0:-1] + dx2 = dx[i][1:] + a = -(dx2) / (dx1 * (dx1 + dx2)) + b = (dx2 - dx1) / (dx1 * dx2) + c = dx1 / (dx2 * (dx1 + dx2)) + # fix the shape for broadcasting + shape = np.ones(N, dtype=int) + shape[axis] = -1 + a.shape = b.shape = c.shape = shape + # 1D equivalent -- + # out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] + out[slice1] = a * f[slice2] + b * f[slice3] + c * f[slice4] + + # Numerical differentiation: 1st order edges + if edge_order == 1: + slice1[axis] = 0 + slice2[axis] = 1 + slice3[axis] = 0 + dx_0 = dx[i] if uniform_spacing else dx[i][0] + # 1D equivalent -- out[0] = (y[1] - y[0]) / (x[1] - x[0]) + out[slice1] = (y[slice2] - y[slice3]) / dx_0 + + slice1[axis] = -1 + slice2[axis] = -1 + slice3[axis] = -2 + dx_n = dx[i] if uniform_spacing else dx[i][-1] + # 1D equivalent -- out[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2]) + out[slice1] = (y[slice2] - y[slice3]) / dx_n + + # Numerical differentiation: 2nd order edges + else: + slice1[axis] = 0 + slice2[axis] = 0 + slice3[axis] = 1 + slice4[axis] = 2 + if uniform_spacing: + a = -1.5 / dx[i] + b = 2. / dx[i] + c = -0.5 / dx[i] + else: + dx1 = dx[i][0] + dx2 = dx[i][1] + a = -(2. * dx1 + dx2) / (dx1 * (dx1 + dx2)) + b = (dx1 + dx2) / (dx1 * dx2) + c = - dx1 / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[0] = a * y[0] + b * y[1] + c * y[2] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + slice1[axis] = -1 + slice2[axis] = -3 + slice3[axis] = -2 + slice4[axis] = -1 + if uniform_spacing: + a = 0.5 / dx[i] + b = -2. / dx[i] + c = 1.5 / dx[i] + else: + dx1 = dx[i][-2] + dx2 = dx[i][-1] + a = (dx2) / (dx1 * (dx1 + dx2)) + b = - (dx2 + dx1) / (dx1 * dx2) + c = (2. 
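The interior stencil of the backported `gradient` uses the non-uniform-spacing weights `a`, `b`, `c` shown above, which are exact for quadratics. A one-point check of those weights (values chosen only for illustration):

```python
import numpy as np

x = np.array([0.0, 1.0, 3.0])          # non-uniform grid: dx1 = 1, dx2 = 2
f = x ** 2
dx1, dx2 = np.diff(x)
a = -dx2 / (dx1 * (dx1 + dx2))
b = (dx2 - dx1) / (dx1 * dx2)
c = dx1 / (dx2 * (dx1 + dx2))
deriv = a * f[0] + b * f[1] + c * f[2]
print(deriv)                            # 2.0: exact derivative of x**2 at x = 1
```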
* dx2 + dx1) / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + outvals.append(out) + + # reset the slice object in this dimension to ":" + slice1[axis] = slice(None) + slice2[axis] = slice(None) + slice3[axis] = slice(None) + slice4[axis] = slice(None) + + if len_axes == 1: + return outvals[0] + else: + return outvals diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 8ac04752e85..a8d596abd86 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -1,9 +1,17 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import warnings + import numpy as np import pandas as pd -import warnings + +try: + import bottleneck as bn + _USE_BOTTLENECK = True +except ImportError: + # use numpy methods instead + bn = np + _USE_BOTTLENECK = False def _validate_axis(data, axis): @@ -133,3 +141,98 @@ def __setitem__(self, key, value): mixed_positions, vindex_positions = _advanced_indexer_subspaces(key) self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) + + +def rolling_window(a, axis, window, center, fill_value): + """ rolling window with padding. """ + pads = [(0, 0) for s in a.shape] + if center: + start = int(window / 2) # 10 -> 5, 9 -> 4 + end = window - 1 - start + pads[axis] = (start, end) + else: + pads[axis] = (window - 1, 0) + a = np.pad(a, pads, mode='constant', constant_values=fill_value) + return _rolling_window(a, window, axis) + + +def _rolling_window(a, window, axis=-1): + """ + Make an ndarray with a rolling window along axis. + + Parameters + ---------- + a : array_like + Array to add rolling window to + axis: int + axis position along which rolling window will be applied. + window : int + Size of rolling window + + Returns + ------- + Array that is a view of the original array with a added dimension + of size w. + + Examples + -------- + >>> x=np.arange(10).reshape((2,5)) + >>> np.rolling_window(x, 3, axis=-1) + array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]], + [[5, 6, 7], [6, 7, 8], [7, 8, 9]]]) + + Calculate rolling mean of last dimension: + >>> np.mean(np.rolling_window(x, 3, axis=-1), -1) + array([[ 1., 2., 3.], + [ 6., 7., 8.]]) + + This function is taken from https://github.com/numpy/numpy/pull/31 + but slightly modified to accept axis option. + """ + axis = _validate_axis(a, axis) + a = np.swapaxes(a, axis, -1) + + if window < 1: + raise ValueError( + "`window` must be at least 1. Given : {}".format(window)) + if window > a.shape[-1]: + raise ValueError("`window` is too long. 
Given : {}".format(window)) + + shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) + strides = a.strides + (a.strides[-1],) + rolling = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides, + writeable=False) + return np.swapaxes(rolling, -2, axis) + + +def _create_bottleneck_method(name, npmodule=np): + def f(values, axis=None, **kwds): + dtype = kwds.get('dtype', None) + bn_func = getattr(bn, name, None) + + if (_USE_BOTTLENECK and bn_func is not None and + not isinstance(axis, tuple) and + values.dtype.kind in 'uifc' and + values.dtype.isnative and + (dtype is None or np.dtype(dtype) == values.dtype)): + # bottleneck does not take care dtype, min_count + kwds.pop('dtype', None) + result = bn_func(values, axis=axis, **kwds) + else: + result = getattr(npmodule, name)(values, axis=axis, **kwds) + + return result + + f.__name__ = name + return f + + +nanmin = _create_bottleneck_method('nanmin') +nanmax = _create_bottleneck_method('nanmax') +nanmean = _create_bottleneck_method('nanmean') +nanmedian = _create_bottleneck_method('nanmedian') +nanvar = _create_bottleneck_method('nanvar') +nanstd = _create_bottleneck_method('nanstd') +nanprod = _create_bottleneck_method('nanprod') +nancumsum = _create_bottleneck_method('nancumsum') +nancumprod = _create_bottleneck_method('nancumprod') diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d02b8fa3108..a0dd2212a8f 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -5,19 +5,15 @@ functions. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import operator import numpy as np -import pandas as pd -from . import dtypes -from . import duck_array_ops -from .pycompat import PY3 +from . import dtypes, duck_array_ops from .nputils import array_eq, array_ne +from .pycompat import PY3 try: import bottleneck as bn @@ -90,7 +86,7 @@ If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64). + implemented (object, datetime64 or timedelta64).{min_count_docs} keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -106,6 +102,12 @@ indicated dimension(s) removed. """ +_MINCOUNT_DOCSTRING = """ +min_count : int, default None + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None.""" + _ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ Reduce this {da_or_ds}'s data windows by applying `{name}` along its dimension. 
@@ -227,20 +229,8 @@ def func(self, *args, **kwargs): def rolling_count(rolling): - not_null = rolling.obj.notnull() - instance_attr_dict = {'center': rolling.center, - 'min_periods': rolling.min_periods, - rolling.dim: rolling.window} - rolling_count = not_null.rolling(**instance_attr_dict).sum() - - if rolling.min_periods is None: - return rolling_count - - # otherwise we need to filter out points where there aren't enough periods - # but not_null is False, and so the NaNs don't flow through - # array with points where there are enough values given min_periods - enough_periods = rolling_count >= rolling.min_periods - + rolling_count = rolling._counts() + enough_periods = rolling_count >= rolling._min_periods return rolling_count.where(enough_periods) @@ -252,11 +242,15 @@ def inject_reduce_methods(cls): [('count', duck_array_ops.count, False)]) for name, f, include_skipna in methods: numeric_only = getattr(f, 'numeric_only', False) + available_min_count = getattr(f, 'available_min_count', False) + min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else '' + func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format( name=name, cls=cls.__name__, - extra_args=cls._reduce_extra_args_docstring.format(name=name)) + extra_args=cls._reduce_extra_args_docstring.format(name=name), + min_count_docs=min_count_docs) setattr(cls, name, func) @@ -320,7 +314,8 @@ def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): setattr(cls, name, cls._unary_op(_method_wrapper(name))) for name in PANDAS_UNARY_FUNCTIONS: - f = _func_slash_method_wrapper(getattr(pd, name), name=name) + f = _func_slash_method_wrapper( + getattr(duck_array_ops, name), name=name) setattr(cls, name, cls._unary_op(f)) f = _func_slash_method_wrapper(duck_array_ops.around, name='round') diff --git a/xarray/core/options.py b/xarray/core/options.py index 9f06f8dbbae..ab461ca86bc 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,14 +1,71 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import warnings + +DISPLAY_WIDTH = 'display_width' +ARITHMETIC_JOIN = 'arithmetic_join' +ENABLE_CFTIMEINDEX = 'enable_cftimeindex' +FILE_CACHE_MAXSIZE = 'file_cache_maxsize' +CMAP_SEQUENTIAL = 'cmap_sequential' +CMAP_DIVERGENT = 'cmap_divergent' +KEEP_ATTRS = 'keep_attrs' OPTIONS = { - 'display_width': 80, - 'arithmetic_join': 'inner', + DISPLAY_WIDTH: 80, + ARITHMETIC_JOIN: 'inner', + ENABLE_CFTIMEINDEX: True, + FILE_CACHE_MAXSIZE: 128, + CMAP_SEQUENTIAL: 'viridis', + CMAP_DIVERGENT: 'RdBu_r', + KEEP_ATTRS: 'default' +} + +_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact']) + + +def _positive_integer(value): + return isinstance(value, int) and value > 0 + + +_VALIDATORS = { + DISPLAY_WIDTH: _positive_integer, + ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, + ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), + FILE_CACHE_MAXSIZE: _positive_integer, + KEEP_ATTRS: lambda choice: choice in [True, False, 'default'] } +def _set_file_cache_maxsize(value): + from ..backends.file_manager import FILE_CACHE + FILE_CACHE.maxsize = value + + +def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): + warnings.warn( + 'The enable_cftimeindex option is now a no-op ' + 'and will be removed in a future version of xarray.', + FutureWarning) + + +_SETTERS = { + FILE_CACHE_MAXSIZE: 
_set_file_cache_maxsize, + ENABLE_CFTIMEINDEX: _warn_on_setting_enable_cftimeindex +} + + +def _get_keep_attrs(default): + global_choice = OPTIONS['keep_attrs'] + + if global_choice is 'default': + return default + elif global_choice in [True, False]: + return global_choice + else: + raise ValueError("The global option keep_attrs must be one of True, False or 'default'.") + + class set_options(object): """Set options for xarray in a controlled context. @@ -18,6 +75,21 @@ class set_options(object): Default: ``80``. - ``arithmetic_join``: DataArray/Dataset alignment in binary operations. Default: ``'inner'``. + - ``file_cache_maxsize``: maximum number of open files to hold in xarray's + global least-recently-usage cached. This should be smaller than your + system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. + Default: 128. + - ``cmap_sequential``: colormap to use for nondivergent data plots. + Default: ``viridis``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) + - ``cmap_divergent``: colormap to use for divergent data plots. + Default: ``RdBu_r``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) + - ``keep_attrs``: rule for whether to keep attributes on xarray + Datasets/dataarrays after operations. Either ``True`` to always keep + attrs, ``False`` to always discard them, or ``'default'`` to use original + logic that attrs should only be kept in unambiguous circumstances. + Default: ``'default'``. You can use ``set_options`` either as a context manager: @@ -37,16 +109,26 @@ class set_options(object): """ def __init__(self, **kwargs): - invalid_options = {k for k in kwargs if k not in OPTIONS} - if invalid_options: - raise ValueError('argument names %r are not in the set of valid ' - 'options %r' % (invalid_options, set(OPTIONS))) - self.old = OPTIONS.copy() - OPTIONS.update(kwargs) + self.old = {} + for k, v in kwargs.items(): + if k not in OPTIONS: + raise ValueError( + 'argument name %r is not in the set of valid options %r' + % (k, set(OPTIONS))) + if k in _VALIDATORS and not _VALIDATORS[k](v): + raise ValueError( + 'option %r given an invalid value: %r' % (k, v)) + self.old[k] = OPTIONS[k] + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + for k, v in options_dict.items(): + if k in _SETTERS: + _SETTERS[k](v) + OPTIONS.update(options_dict) def __enter__(self): return def __exit__(self, type, value, traceback): - OPTIONS.clear() - OPTIONS.update(self.old) + self._apply_update(self.old) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 4b83df9e14f..b980bc279b0 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,8 +1,7 @@ # flake8: noqa -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import sys import numpy as np @@ -24,10 +23,14 @@ def itervalues(d): range = range zip = zip + from itertools import zip_longest from functools import reduce import builtins from urllib.request import urlretrieve from inspect import getfullargspec as getargspec + + def move_to_end(ordered_dict, key): + ordered_dict.move_to_end(key) else: # pragma: no cover # Python 2 basestring = basestring # noqa @@ -42,12 +45,19 @@ def itervalues(d): return d.itervalues() range = xrange - from itertools import izip as zip, imap as map + from itertools import ( + izip as zip, imap as map, izip_longest as 
zip_longest, + ) reduce = reduce import __builtin__ as builtins from urllib import urlretrieve from inspect import getargspec + def move_to_end(ordered_dict, key): + value = ordered_dict[key] + del ordered_dict[key] + ordered_dict[key] = value + integer_types = native_int_types + (np.integer,) try: @@ -74,7 +84,6 @@ def itervalues(d): except ImportError as e: path_type = () - try: from contextlib import suppress except ImportError: @@ -108,7 +117,8 @@ def __exit__(self, exctype, excinst, exctb): # exactly reproduce the limitations of the CPython interpreter. # # See http://bugs.python.org/issue12029 for more details - return exctype is not None and issubclass(exctype, self._exceptions) + return exctype is not None and issubclass( + exctype, self._exceptions) try: from contextlib import ExitStack except ImportError: @@ -185,7 +195,8 @@ def enter_context(self, cm): If successful, also pushes its __exit__ method as a callback and returns the result of the __enter__ method. """ - # We look up the special methods on the type to match the with statement + # We look up the special methods on the type to match the with + # statement _cm_type = type(cm) _exit = _cm_type.__exit__ result = _cm_type.__enter__(cm) @@ -208,7 +219,7 @@ def __exit__(self, *exc_details): def _fix_exception_context(new_exc, old_exc): # Context may not be correct, so find the end of the chain - while 1: + while True: exc_context = new_exc.__context__ if exc_context is old_exc: # Context is already set correctly (see issue 20317) @@ -231,7 +242,7 @@ def _fix_exception_context(new_exc, old_exc): suppressed_exc = True pending_raise = False exc_details = (None, None, None) - except: + except BaseException: new_exc_details = sys.exc_info() # simulate the stack of exceptions by setting the context _fix_exception_context(new_exc_details[1], exc_details[1]) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 78fd39d3245..edf7dfc3d41 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,10 +1,8 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function from . import ops -from .groupby import DataArrayGroupBy, DatasetGroupBy -from .pycompat import dask_array_type, OrderedDict +from .groupby import DEFAULT_DIMS, DataArrayGroupBy, DatasetGroupBy +from .pycompat import OrderedDict, dask_array_type RESAMPLE_DIM = '__resample_dim__' @@ -275,19 +273,18 @@ def apply(self, func, **kwargs): return combined.rename({self._resample_dim: self._dim}) - def reduce(self, func, dim=None, keep_attrs=False, **kwargs): + def reduce(self, func, dim=None, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along the pre-defined resampling dimension. - Note that `dim` is by default here and ignored if passed by the user; - this ensures compatibility with the existing reduce interface. - Parameters ---------- func : function Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. + dim : str or sequence of str, optional + Dimension(s) over which to apply `func`. keep_attrs : bool, optional If True, the datasets's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -301,8 +298,11 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. 
""" + if dim == DEFAULT_DIMS: + dim = None + return super(DatasetResample, self).reduce( - func, self._dim, keep_attrs, **kwargs) + func, dim, keep_attrs, **kwargs) def _interpolate(self, kind='linear'): """Apply scipy.interpolate.interp1d along resampling dimension.""" diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 8209e70e5a8..883dbb34dff 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -1,16 +1,34 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import numpy as np +from __future__ import absolute_import, division, print_function + import warnings from distutils.version import LooseVersion -from .pycompat import OrderedDict, zip, dask_array_type -from .common import full_like -from .combine import concat -from .ops import (inject_bottleneck_rolling_methods, - inject_datasetrolling_methods, has_bottleneck, bn) +import numpy as np + +from . import dtypes from .dask_array_ops import dask_rolling_wrapper +from .ops import ( + bn, has_bottleneck, inject_bottleneck_rolling_methods, + inject_datasetrolling_methods) +from .pycompat import OrderedDict, dask_array_type, zip + + +def _get_new_dimname(dims, new_dim): + """ Get an new dimension name based on new_dim, that is not used in dims. + If the same name exists, we add an underscore(s) in the head. + + Example1: + dims: ['a', 'b', 'c'] + new_dim: ['_rolling'] + -> ['_rolling'] + Example2: + dims: ['a', 'b', 'c', '_rolling'] + new_dim: ['_rolling'] + -> ['__rolling'] + """ + while new_dim in dims: + new_dim = '_' + new_dim + return new_dim class Rolling(object): @@ -26,7 +44,7 @@ class Rolling(object): _attributes = ['window', 'min_periods', 'center', 'dim'] - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object. @@ -34,18 +52,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : Dataset or DataArray Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. Returns ------- @@ -97,92 +115,103 @@ def __len__(self): class DataArrayRolling(Rolling): - """ - This class adds the following class methods; - + _reduce_method(cls, func) - + _bottleneck_reduce(cls, func) - - These class methods will be used to inject numpy or bottleneck function - by doing - - >>> func = cls._reduce_method(f) - >>> func.__name__ = name - >>> setattr(cls, name, func) - - in ops.inject_bottleneck_rolling_methods. - - After the injection, the Rolling object will have `name` (such as `mean` or - `median`) methods, - e.g. it enables the following call, - >>> data.rolling().mean() - - If bottleneck is installed, some bottleneck methods will be used instdad of - the numpy method. + def __init__(self, obj, windows, min_periods=None, center=False): + """ + Moving window object for DataArray. 
+ You should use DataArray.rolling() method to construct this object + instead of the class constructor. - see also - + rolling.DataArrayRolling - + ops.inject_bottleneck_rolling_methods - """ + Parameters + ---------- + obj : DataArray + Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. + min_periods : int, default None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : boolean, default False + Set the labels at the center of the window. - def __init__(self, obj, min_periods=None, center=False, **windows): - super(DataArrayRolling, self).__init__(obj, min_periods=min_periods, - center=center, **windows) - self._windows = None - self._valid_windows = None - self.window_indices = None - self.window_labels = None + Returns + ------- + rolling : type of input argument - self._setup_windows() + See Also + -------- + DataArray.rolling + DataArray.groupby + Dataset.rolling + Dataset.groupby + """ + super(DataArrayRolling, self).__init__( + obj, windows, min_periods=min_periods, center=center) - @property - def windows(self): - if self._windows is None: - self._windows = OrderedDict(zip(self.window_labels, - self.window_indices)) - return self._windows + self.window_labels = self.obj[self.dim] def __iter__(self): - for (label, indices, valid) in zip(self.window_labels, - self.window_indices, - self._valid_windows): - - window = self.obj.isel(**{self.dim: indices}) + stops = np.arange(1, len(self.window_labels) + 1) + starts = stops - int(self.window) + starts[:int(self.window)] = 0 + for (label, start, stop) in zip(self.window_labels, starts, stops): + window = self.obj.isel(**{self.dim: slice(start, stop)}) - if not valid: - window = full_like(window, fill_value=True, dtype=bool) + counts = window.count(dim=self.dim) + window = window.where(counts >= self._min_periods) yield (label, window) - def _setup_windows(self): + def construct(self, window_dim, stride=1, fill_value=dtypes.NA): """ - Find the indices and labels for each window - """ - from .dataarray import DataArray - - self.window_labels = self.obj[self.dim] + Convert this rolling object to xr.DataArray, + where the window dimension is stacked as a new dimension - window = int(self.window) - - dim_size = self.obj[self.dim].size - - stops = np.arange(dim_size) + 1 - starts = np.maximum(stops - window, 0) + Parameters + ---------- + window_dim: str + New name of the window dimension. + stride: integer, optional + Size of stride for the rolling window. + fill_value: optional. Default dtypes.NA + Filling value to match the dimension size. - if self._min_periods > 1: - valid_windows = (stops - starts) >= self._min_periods - else: - # No invalid windows - valid_windows = np.ones(dim_size, dtype=bool) - self._valid_windows = DataArray(valid_windows, dims=(self.dim, ), - coords=self.obj[self.dim].coords) + Returns + ------- + DataArray that is a view of the original array. The returned array is + not writeable. 
+ + Examples + -------- + >>> da = DataArray(np.arange(8).reshape(2, 4), dims=('a', 'b')) + >>> + >>> rolling = da.rolling(a=3) + >>> rolling.construct('window_dim') + + array([[[np.nan, np.nan, 0], [np.nan, 0, 1], [0, 1, 2], [1, 2, 3]], + [[np.nan, np.nan, 4], [np.nan, 4, 5], [4, 5, 6], [5, 6, 7]]]) + Dimensions without coordinates: a, b, window_dim + >>> + >>> rolling = da.rolling(a=3, center=True) + >>> rolling.construct('window_dim') + + array([[[np.nan, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, np.nan]], + [[np.nan, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, np.nan]]]) + Dimensions without coordinates: a, b, window_dim + """ - self.window_indices = [slice(start, stop) - for start, stop in zip(starts, stops)] + from .dataarray import DataArray - def _center_result(self, result): - """center result""" - shift = (-self.window // 2) + 1 - return result.shift(**{self.dim: shift}) + window = self.obj.variable.rolling_window(self.dim, self.window, + window_dim, self.center, + fill_value=fill_value) + result = DataArray(window, dims=self.obj.dims + (window_dim,), + coords=self.obj.coords) + return result.isel(**{self.dim: slice(None, None, stride)}) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some @@ -202,27 +231,27 @@ def reduce(self, func, **kwargs): reduced : DataArray Array with summarized data. """ - - windows = [window.reduce(func, dim=self.dim, **kwargs) - for _, window in self] - - # Find valid windows based on count - if self.dim in self.obj.coords: - concat_dim = self.window_labels - else: - concat_dim = self.dim - counts = concat([window.count(dim=self.dim) for _, window in self], - dim=concat_dim) - result = concat(windows, dim=concat_dim) - # restore dim order - result = result.transpose(*self.obj.dims) - - result = result.where(counts >= self._min_periods) - - if self.center: - result = self._center_result(result) - - return result + rolling_dim = _get_new_dimname(self.obj.dims, '_rolling_dim') + windows = self.construct(rolling_dim) + result = windows.reduce(func, dim=rolling_dim, **kwargs) + + # Find valid windows based on count. + counts = self._counts() + return result.where(counts >= self._min_periods) + + def _counts(self): + """ Number of non-nan entries in each rolling window. """ + + rolling_dim = _get_new_dimname(self.obj.dims, '_rolling_dim') + # We use False as the fill_value instead of np.nan, since a boolean + # array is faster to reduce than an object array. + # The use of skipna==False is also faster since it does not need to + # copy the strided array. + counts = (self.obj.notnull() + .rolling(center=self.center, **{self.dim: self.window}) + .construct(rolling_dim, fill_value=False) + .sum(dim=rolling_dim, skipna=False)) + return counts @classmethod def _reduce_method(cls, func): @@ -254,69 +283,79 @@ def wrapped_func(self, **kwargs): axis = self.obj.get_axis_num(self.dim) - if isinstance(self.obj.data, dask_array_type): - values = dask_rolling_wrapper(func, self.obj.data, + padded = self.obj.variable + if self.center: + if (LooseVersion(np.__version__) < LooseVersion('1.13') and + self.obj.dtype.kind == 'b'): + # with numpy < 1.13 bottleneck cannot handle np.nan-Boolean + # mixed array correctly. We cast boolean array to float.
+ padded = padded.astype(float) + + if isinstance(padded.data, dask_array_type): + # Workaround to make the padded chunk size is larger than + # self.window-1 + shift = - (self.window + 1) // 2 + offset = (self.window - 1) // 2 + valid = (slice(None), ) * axis + ( + slice(offset, offset + self.obj.shape[axis]), ) + else: + shift = (-self.window // 2) + 1 + valid = (slice(None), ) * axis + (slice(-shift, None), ) + padded = padded.pad_with_fill_value(**{self.dim: (0, -shift)}) + + if isinstance(padded.data, dask_array_type): + values = dask_rolling_wrapper(func, padded, window=self.window, min_count=min_count, axis=axis) else: - values = func(self.obj.data, window=self.window, + values = func(padded.data, window=self.window, min_count=min_count, axis=axis) - result = DataArray(values, self.obj.coords) - if self.center: - result = self._center_result(result) + values = values[valid] + result = DataArray(values, self.obj.coords) return result return wrapped_func class DatasetRolling(Rolling): - """An object that implements the moving window pattern for Dataset. - - This class has an OrderedDict named self.rollings, that is a collection of - DataArrayRollings for all the DataArrays in the Dataset, except for those - not depending on rolling dimension. - - reduce() method returns a new Dataset generated from a set of - self.rollings[key].reduce(). - - See Also - -------- - Dataset.groupby - DataArray.groupby - Dataset.rolling - DataArray.rolling - """ - - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for Dataset. + You should use Dataset.rolling() method to construct this object + instead of the class constructor. Parameters ---------- obj : Dataset Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. 
Returns ------- rolling : type of input argument + + See Also + -------- + Dataset.rolling + DataArray.rolling + Dataset.groupby + DataArray.groupby """ - super(DatasetRolling, self).__init__(obj, - min_periods, center, **windows) + super(DatasetRolling, self).__init__(obj, windows, min_periods, center) if self.dim not in self.obj.dims: raise KeyError(self.dim) # Keep each Rolling object as an OrderedDict @@ -324,8 +363,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows): for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim if self.dim in da.dims: - self.rollings[key] = DataArrayRolling(da, min_periods, - center, **windows) + self.rollings[key] = DataArrayRolling( + da, windows, min_periods, center) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some @@ -354,6 +393,16 @@ def reduce(self, func, **kwargs): reduced[key] = self.obj[key] return Dataset(reduced, coords=self.obj.coords) + def _counts(self): + from .dataset import Dataset + reduced = OrderedDict() + for key, da in self.obj.data_vars.items(): + if self.dim in da.dims: + reduced[key] = self.rollings[key]._counts() + else: + reduced[key] = self.obj[key] + return Dataset(reduced, coords=self.obj.coords) + @classmethod def _reduce_method(cls, func): """ @@ -373,6 +422,37 @@ def wrapped_func(self, **kwargs): return Dataset(reduced, coords=self.obj.coords) return wrapped_func + def construct(self, window_dim, stride=1, fill_value=dtypes.NA): + """ + Convert this rolling object to xr.Dataset, + where the window dimension is stacked as a new dimension + + Parameters + ---------- + window_dim: str + New name of the window dimension. + stride: integer, optional + size of stride for the rolling window. + fill_value: optional. Default dtypes.NA + Filling value to match the dimension size. + + Returns + ------- + Dataset with variables converted from rolling object. 
+ """ + + from .dataset import Dataset + + dataset = OrderedDict() + for key, da in self.obj.data_vars.items(): + if self.dim in da.dims: + dataset[key] = self.rollings[key].construct( + window_dim, fill_value=fill_value) + else: + dataset[key] = da + return Dataset(dataset, coords=self.obj.coords).isel( + **{self.dim: slice(None, None, stride)}) + inject_bottleneck_rolling_methods(DataArrayRolling) inject_datasetrolling_methods(DatasetRolling) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index de6b5825390..50d6ec7e05a 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -1,20 +1,30 @@ """Internal utilties; not for external use """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import contextlib import functools import itertools +import os.path import re import warnings -from collections import Mapping, MutableMapping, MutableSet, Iterable +from collections import Iterable, Mapping, MutableMapping, MutableSet import numpy as np import pandas as pd -from .pycompat import (iteritems, OrderedDict, basestring, bytes_type, - dask_array_type) +from .pycompat import ( + OrderedDict, basestring, bytes_type, dask_array_type, iteritems) + + +def _check_inplace(inplace, default=False): + if inplace is None: + inplace = default + else: + warnings.warn('The inplace argument has been deprecated and will be ' + 'removed in xarray 0.12.0.', FutureWarning, stacklevel=3) + + return inplace def alias_message(old_name, new_name): @@ -37,6 +47,18 @@ def wrapper(*args, **kwargs): return wrapper +def _maybe_cast_to_cftimeindex(index): + from ..coding.cftimeindex import CFTimeIndex + + if index.dtype == 'O': + try: + return CFTimeIndex(index) + except (ImportError, TypeError): + return index + else: + return index + + def safe_cast_to_index(array): """Given an array, safely cast it to a pandas.Index. @@ -55,19 +77,18 @@ def safe_cast_to_index(array): if hasattr(array, 'dtype') and array.dtype.kind == 'O': kwargs['dtype'] = object index = pd.Index(np.asarray(array), **kwargs) - return index + return _maybe_cast_to_cftimeindex(index) def multiindex_from_product_levels(levels, names=None): """Creating a MultiIndex from a product without refactorizing levels. - Keeping levels the same is faster, and also gives back the original labels - when we unstack. + Keeping levels the same gives back the original labels when we unstack. Parameters ---------- - levels : sequence of arrays - Unique labels for each level. + levels : sequence of pd.Index + Values for each MultiIndex level. names : optional sequence of objects Names for each level. 
@@ -75,8 +96,11 @@ def multiindex_from_product_levels(levels, names=None): ------- pandas.MultiIndex """ - labels_mesh = np.meshgrid(*[np.arange(len(lev)) for lev in levels], - indexing='ij') + if any(not isinstance(lev, pd.Index) for lev in levels): + raise TypeError('levels must be a list of pd.Index objects') + + split_labels, levels = zip(*[lev.factorize() for lev in levels]) + labels_mesh = np.meshgrid(*split_labels, indexing='ij') labels = [x.ravel() for x in labels_mesh] return pd.MultiIndex(levels, labels, sortorder=0, names=names) @@ -168,7 +192,7 @@ def is_full_slice(value): return isinstance(value, slice) and value == slice(None) -def combine_pos_and_kw_args(pos_kwargs, kw_kwargs, func_name): +def either_dict_or_kwargs(pos_kwargs, kw_kwargs, func_name): if pos_kwargs is not None: if not is_dict_like(pos_kwargs): raise ValueError('the first argument to .%s must be a dictionary' @@ -487,6 +511,11 @@ def is_remote_uri(path): return bool(re.search('^https?\://', path)) +def is_grib_path(path): + _, ext = os.path.splitext(path) + return ext in ['.grib', '.grb', '.grib2', '.grb2'] + + def is_uniform_spaced(arr, **kwargs): """Return True if values of an array are uniformly spaced and sorted. @@ -574,3 +603,29 @@ def __iter__(self): def __len__(self): num_hidden = sum([k in self._hidden_keys for k in self._data]) return len(self._data) - num_hidden + + +def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): + """Convert an array containing datetime-like data to an array of floats. + + Parameters + ---------- + da : array + Input data + offset: Scalar with the same type of array or None + If None, subtract minimum values to reduce round off error + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + dtype: target dtype + + Returns + ------- + array + """ + if offset is None: + offset = array.min() + array = array - offset + + if datetime_unit: + return (array / np.timedelta64(1, datetime_unit)).astype(dtype) + return array.astype(dtype) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d4863014f59..0bff06e7546 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1,29 +1,25 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from datetime import timedelta -from collections import defaultdict +from __future__ import absolute_import, division, print_function + import functools import itertools +from collections import defaultdict +from datetime import timedelta import numpy as np import pandas as pd -from . import common -from . import duck_array_ops -from . import dtypes -from . import indexing -from . import nputils -from . import ops -from . import utils -from .pycompat import (basestring, OrderedDict, zip, integer_types, - dask_array_type) -from .indexing import (PandasIndexAdapter, as_indexable, BasicIndexer, - OuterIndexer, VectorizedIndexer) -from .utils import OrderedSet - import xarray as xr # only for Dataset and DataArray +from . 
import ( + arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils) +from .indexing import ( + BasicIndexer, OuterIndexer, PandasIndexAdapter, VectorizedIndexer, + as_indexable) +from .pycompat import ( + OrderedDict, basestring, dask_array_type, integer_types, zip) +from .utils import OrderedSet, either_dict_or_kwargs +from .options import _get_keep_attrs + try: import dask.array as da except ImportError: @@ -69,34 +65,30 @@ def as_variable(obj, name=None): The newly created variable. """ + from .dataarray import DataArray + # TODO: consider extending this method to automatically handle Iris and - # pandas objects. - if hasattr(obj, 'variable'): + if isinstance(obj, DataArray): # extract the primary Variable from DataArrays obj = obj.variable if isinstance(obj, Variable): obj = obj.copy(deep=False) - elif hasattr(obj, 'dims') and (hasattr(obj, 'data') or - hasattr(obj, 'values')): - obj_data = getattr(obj, 'data', None) - if obj_data is None: - obj_data = getattr(obj, 'values') - obj = Variable(obj.dims, obj_data, - getattr(obj, 'attrs', None), - getattr(obj, 'encoding', None)) elif isinstance(obj, tuple): try: obj = Variable(*obj) - except TypeError: + except (TypeError, ValueError) as error: # use .format() instead of % because it handles tuples consistently - raise TypeError('tuples to convert into variables must be of the ' - 'form (dims, data[, attrs, encoding]): ' - '{}'.format(obj)) + raise error.__class__('Could not convert tuple of form ' + '(dims, data[, attrs, encoding]): ' + '{} to Variable.'.format(obj)) elif utils.is_scalar(obj): obj = Variable([], obj) elif isinstance(obj, (pd.Index, IndexVariable)) and obj.name is not None: obj = Variable(obj.name, obj) + elif isinstance(obj, (set, dict)): + raise TypeError( + "variable %r has invalid type %r" % (name, type(obj))) elif name is not None: data = as_compatible_data(obj) if data.ndim != 1: @@ -104,7 +96,7 @@ def as_variable(obj, name=None): 'cannot set variable %r with %r-dimensional data ' 'without explicit dimension names. Pass a tuple of ' '(dims, data) instead.' % (name, data.ndim)) - obj = Variable(name, obj, fastpath=True) + obj = Variable(name, data, fastpath=True) else: raise TypeError('unable to convert object into a variable without an ' 'explicit list of dimensions: %r' % obj) @@ -127,7 +119,7 @@ def _maybe_wrap_data(data): Put pandas.Index and numpy.ndarray arguments in adapter objects to ensure they can be indexed properly. - NumpyArrayAdapter, PandasIndexAdapter and LazilyIndexedArray should + NumpyArrayAdapter, PandasIndexAdapter and LazilyOuterIndexedArray should all pass through unmodified. """ if isinstance(data, pd.Index): @@ -222,8 +214,8 @@ def _as_array_or_item(data): return data -class Variable(common.AbstractArray, utils.NdimSizeLenMixin): - +class Variable(common.AbstractArray, arithmetic.SupportsArithmetic, + utils.NdimSizeLenMixin): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. 
A single Variable object is not fully described outside the context of its parent Dataset (if you want such a @@ -463,10 +455,14 @@ def _broadcast_indexes(self, key): key = self._item_key_to_tuple(key) # key is a tuple # key is a tuple of full size key = indexing.expanded_indexer(key, self.ndim) - # Convert a scalar Variable as an integer + # Convert a scalar Variable to an integer key = tuple( k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key) + # Convert a 0d-array to an integer + key = tuple( + k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k + for k in key) if all(isinstance(k, BASIC_INDEXING_TYPES) for k in key): return self._broadcast_indexes_basic(key) @@ -680,7 +676,7 @@ def __setitem__(self, key, value): value = as_compatible_data(value) if value.ndim > len(dims): raise ValueError( - 'shape mismatch: value array of shape %s could not be' + 'shape mismatch: value array of shape %s could not be ' 'broadcast to indexing result with %s dimensions' % (value.shape, len(dims))) if value.ndim == 0: @@ -726,24 +722,81 @@ def encoding(self, value): except ValueError: raise ValueError('encoding must be castable to a dictionary') - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this object. If `deep=True`, the data array is loaded into memory and copied onto the new object. Dimensions, attributes and encodings are always copied. - """ - data = self._data - if isinstance(data, indexing.MemoryCachedArray): - # don't share caching between copies - data = indexing.MemoryCachedArray(data.array) + Use `data` to create a new object with the same structure as + original but entirely new data. - if deep: - if isinstance(data, dask_array_type): - data = data.copy() - elif not isinstance(data, PandasIndexAdapter): - # pandas.Index is immutable - data = np.array(data) + Parameters + ---------- + deep : bool, optional + Whether the data array is loaded into memory and copied onto + the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. + + Examples + -------- + + Shallow copy versus deep copy + + >>> var = xr.Variable(data=[1, 2, 3], dims='x') + >>> var.copy() + + array([1, 2, 3]) + >>> var_0 = var.copy(deep=False) + >>> var_0[0] = 7 + >>> var_0 + + array([7, 2, 3]) + >>> var + + array([7, 2, 3]) + + Changing the data using the ``data`` argument maintains the + structure of the original object, but with the new data. Original + object is unaffected. 
+ + >>> var.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + >>> var + + array([7, 2, 3]) + + See Also + -------- + pandas.DataFrame.copy + """ + if data is None: + data = self._data + + if isinstance(data, indexing.MemoryCachedArray): + # don't share caching between copies + data = indexing.MemoryCachedArray(data.array) + + if deep: + if isinstance(data, dask_array_type): + data = data.copy() + elif not isinstance(data, PandasIndexAdapter): + # pandas.Index is immutable + data = np.array(data) + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) # note: # dims is already an immutable tuple @@ -825,7 +878,7 @@ def chunk(self, chunks=None, name=None, lock=False): return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) - def isel(self, **indexers): + def isel(self, indexers=None, drop=False, **indexers_kwargs): """Return a new array indexed along the specified dimension(s). Parameters @@ -842,6 +895,8 @@ def isel(self, **indexers): unless numpy fancy indexing was triggered by using an array indexer, in which case the data will be a copy. """ + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel') + invalid = [k for k in indexers if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) @@ -873,7 +928,7 @@ def squeeze(self, dim=None): numpy.squeeze """ dims = common.get_squeeze_dims(self, dim) - return self.isel(**{d: 0 for d in dims}) + return self.isel({d: 0 for d in dims}) def _shift_one_dim(self, dim, count): axis = self.get_axis_num(dim) @@ -915,27 +970,84 @@ def _shift_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """ Return a new Variable with shifted data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but shifted data. """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') result = self for dim, count in shifts.items(): result = result._shift_one_dim(dim, count) return result + def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, + **pad_widths_kwargs): + """ + Return a new Variable with paddings. + + Parameters + ---------- + pad_width: Mapping of the form {dim: (before, after)} + Number of values padded to the edges of each dimension. + **pad_widths_kwargs: + Keyword argument for pad_widths + """ + pad_widths = either_dict_or_kwargs(pad_widths, pad_widths_kwargs, + 'pad') + + if fill_value is dtypes.NA: # np.nan is passed + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = self.dtype + + if isinstance(self.data, dask_array_type): + array = self.data + + # Dask does not yet support pad. We manually implement it. 
+ # https://github.com/dask/dask/issues/1926 + for d, pad in pad_widths.items(): + axis = self.get_axis_num(d) + before_shape = list(array.shape) + before_shape[axis] = pad[0] + before_chunks = list(array.chunks) + before_chunks[axis] = (pad[0], ) + after_shape = list(array.shape) + after_shape[axis] = pad[1] + after_chunks = list(array.chunks) + after_chunks[axis] = (pad[1], ) + + arrays = [] + if pad[0] > 0: + arrays.append(da.full(before_shape, fill_value, + dtype=dtype, chunks=before_chunks)) + arrays.append(array) + if pad[1] > 0: + arrays.append(da.full(after_shape, fill_value, + dtype=dtype, chunks=after_chunks)) + if len(arrays) > 1: + array = da.concatenate(arrays, axis=axis) + else: + pads = [(0, 0) if d not in pad_widths else pad_widths[d] + for d in self.dims] + array = np.pad(self.data.astype(dtype, copy=False), pads, + mode='constant', constant_values=fill_value) + return type(self)(self.dims, array) + def _roll_one_dim(self, dim, count): axis = self.get_axis_num(dim) @@ -958,22 +1070,27 @@ def _roll_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def roll(self, **shifts): + def roll(self, shifts=None, **shifts_kwargs): """ Return a new Variable with rolld data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to roll along each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but rolled data. """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') + result = self for dim, count in shifts.items(): result = result._roll_one_dim(dim, count) @@ -1008,7 +1125,8 @@ def transpose(self, *dims): axes = self.get_axis_num(dims) if len(dims) < 2: # no need to transpose if only one dimension return self.copy(deep=False) - data = duck_array_ops.transpose(self.data, axes) + + data = as_indexable(self._data).transpose(axes) return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) @@ -1090,7 +1208,7 @@ def _stack_once(self, dims, new_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. @@ -1099,9 +1217,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1112,6 +1233,8 @@ def stack(self, **dimensions): -------- Variable.unstack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'stack') result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) @@ -1143,7 +1266,7 @@ def _unstack_once(self, dims, old_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def unstack(self, **dimensions): + def unstack(self, dimensions=None, **dimensions_kwargs): """ Unstack an existing dimension into multiple new dimensions. 
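For orientation, the dict-or-kwargs convention that ``either_dict_or_kwargs`` adds to these Variable methods could be used roughly like this (illustrative sketch, not part of the patch):

import numpy as np
import xarray as xr

v = xr.Variable(('x', 'y'), np.arange(6).reshape(2, 3))
stacked = v.stack(z=('x', 'y'))       # keyword form, unchanged behaviour
stacked = v.stack({'z': ('x', 'y')})  # new: explicit mapping form
shifted = v.shift({'x': 1})           # shift/roll/unstack accept the same pattern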
@@ -1152,9 +1275,12 @@ def unstack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form old_dim={dim1: size1, ...} + dimensions : mapping of the form old_dim={dim1: size1, ...} Names of existing dimensions, and the new dimensions and sizes that they map to. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1165,6 +1291,8 @@ def unstack(self, **dimensions): -------- Variable.stack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'unstack') result = self for old_dim, dims in dimensions.items(): result = result._unstack_once(dims, old_dim) @@ -1176,8 +1304,8 @@ def fillna(self, value): def where(self, cond, other=dtypes.NA): return ops.where_method(self, cond, other) - def reduce(self, func, dim=None, axis=None, keep_attrs=False, - allow_lazy=False, **kwargs): + def reduce(self, func, dim=None, axis=None, + keep_attrs=None, allow_lazy=False, **kwargs): """Reduce this array by applying `func` along some dimension(s). Parameters @@ -1206,14 +1334,11 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, Array with summarized data and the indicated dimension(s) removed. """ + if dim is common.ALL_DIMS: + dim = None if dim is not None and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") - if getattr(func, 'keep_dims', False): - if dim is None and axis is None: - raise ValueError("must supply either single 'dim' or 'axis' " - "argument to %s" % (func.__name__)) - if dim is not None: axis = self.get_axis_num(dim) data = func(self.data if allow_lazy else self.values, @@ -1227,6 +1352,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, dims = [adim for n, adim in enumerate(self.dims) if n not in removed_axes] + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) attrs = self._attrs if keep_attrs else None return Variable(dims, data, attrs=attrs) @@ -1273,8 +1400,6 @@ def concat(cls, variables, dim='concat_dim', positions=None, arrays = [v.data for v in variables] - # TODO: use our own type promotion rules to ensure that - # [str, float] -> object, not str like numpy if dim in first_var.dims: axis = first_var.get_axis_num(dim) dims = first_var.dims @@ -1456,6 +1581,57 @@ def rank(self, dim, pct=False): ranked /= count return Variable(self.dims, ranked) + def rolling_window(self, dim, window, window_dim, center=False, + fill_value=dtypes.NA): + """ + Make a rolling_window along dim and add a new_dim to the last place. + + Parameters + ---------- + dim: str + Dimension over which to compute rolling_window + window: int + Window size of the rolling + window_dim: str + New name of the window dimension. + center: boolean. default False. + If True, pad fill_value for both ends. Otherwise, pad in the head + of the axis. + fill_value: + value to be filled. + + Returns + ------- + Variable that is a view of the original array with a added dimension of + size w. 
+ The return dim: self.dims + (window_dim, ) + The return shape: self.shape + (window, ) + + Examples + -------- + >>> v=Variable(('a', 'b'), np.arange(8).reshape((2,4))) + >>> v.rolling_window('b', 3, 'window_dim') + + array([[[nan, nan, 0], [nan, 0, 1], [0, 1, 2], [1, 2, 3]], + [[nan, nan, 4], [nan, 4, 5], [4, 5, 6], [5, 6, 7]]]) + + >>> v.rolling_window('b', 3, 'window_dim', center=True) + + array([[[nan, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, nan]], + [[nan, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, nan]]]) + """ + if fill_value is dtypes.NA: # np.nan is passed + dtype, fill_value = dtypes.maybe_promote(self.dtype) + array = self.astype(dtype, copy=False).data + else: + dtype = self.dtype + array = self.data + + new_dims = self.dims + (window_dim, ) + return Variable(new_dims, duck_array_ops.rolling_window( + array, axis=self.get_axis_num(dim), window=window, + center=center, fill_value=fill_value)) + @property def real(self): return type(self)(self.dims, self.data.real, self._attrs) @@ -1595,14 +1771,37 @@ def concat(cls, variables, dim='concat_dim', positions=None, return cls(first_var.dims, data, attrs) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this object. - `deep` is ignored since data is stored in the form of pandas.Index, - which is already immutable. Dimensions, attributes and encodings are - always copied. + `deep` is ignored since data is stored in the form of + pandas.Index, which is already immutable. Dimensions, attributes + and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Deep is always ignored. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. """ - return type(self)(self.dims, self._data, self._attrs, + if data is None: + data = self._data + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) + return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) def equals(self, other, equiv=None): @@ -1780,12 +1979,15 @@ def assert_unique_multiindex_level_names(variables): objects. """ level_names = defaultdict(list) + all_level_names = set() for var_name, var in variables.items(): if isinstance(var._data, PandasIndexAdapter): idx_level_names = var.to_index_variable().level_names if idx_level_names is not None: for n in idx_level_names: level_names[n].append('%r (%s)' % (n, var_name)) + if idx_level_names: + all_level_names.update(idx_level_names) for k, v in level_names.items(): if k in variables: @@ -1796,3 +1998,9 @@ conflict_str = '\n'.join([', '.join(v) for v in duplicate_names]) raise ValueError('conflicting MultiIndex level name(s):\n%s' % conflict_str) + # Check for conflicts between level names and dimensions GH:2299 + for k, v in variables.items(): + for d in v.dims: + if d in all_level_names: + raise ValueError('conflicting level / dimension names.
{} ' + 'already exists as a level name.'.format(d)) diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index fe2c604a89e..4b53b22243c 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from .plot import (plot, line, contourf, contour, +from .plot import (plot, line, step, contourf, contour, hist, imshow, pcolormesh) from .facetgrid import FacetGrid @@ -9,6 +9,7 @@ __all__ = [ 'plot', 'line', + 'step', 'contour', 'contourf', 'hist', diff --git a/xarray/plot/default_colormap.csv b/xarray/plot/default_colormap.csv deleted file mode 100644 index de9632e3f26..00000000000 --- a/xarray/plot/default_colormap.csv +++ /dev/null @@ -1,256 +0,0 @@ -0.26700401,0.00487433,0.32941519 -0.26851048,0.00960483,0.33542652 -0.26994384,0.01462494,0.34137895 -0.27130489,0.01994186,0.34726862 -0.27259384,0.02556309,0.35309303 -0.27380934,0.03149748,0.35885256 -0.27495242,0.03775181,0.36454323 -0.27602238,0.04416723,0.37016418 -0.2770184,0.05034437,0.37571452 -0.27794143,0.05632444,0.38119074 -0.27879067,0.06214536,0.38659204 -0.2795655,0.06783587,0.39191723 -0.28026658,0.07341724,0.39716349 -0.28089358,0.07890703,0.40232944 -0.28144581,0.0843197,0.40741404 -0.28192358,0.08966622,0.41241521 -0.28232739,0.09495545,0.41733086 -0.28265633,0.10019576,0.42216032 -0.28291049,0.10539345,0.42690202 -0.28309095,0.11055307,0.43155375 -0.28319704,0.11567966,0.43611482 -0.28322882,0.12077701,0.44058404 -0.28318684,0.12584799,0.44496 -0.283072,0.13089477,0.44924127 -0.28288389,0.13592005,0.45342734 -0.28262297,0.14092556,0.45751726 -0.28229037,0.14591233,0.46150995 -0.28188676,0.15088147,0.46540474 -0.28141228,0.15583425,0.46920128 -0.28086773,0.16077132,0.47289909 -0.28025468,0.16569272,0.47649762 -0.27957399,0.17059884,0.47999675 -0.27882618,0.1754902,0.48339654 -0.27801236,0.18036684,0.48669702 -0.27713437,0.18522836,0.48989831 -0.27619376,0.19007447,0.49300074 -0.27519116,0.1949054,0.49600488 -0.27412802,0.19972086,0.49891131 -0.27300596,0.20452049,0.50172076 -0.27182812,0.20930306,0.50443413 -0.27059473,0.21406899,0.50705243 -0.26930756,0.21881782,0.50957678 -0.26796846,0.22354911,0.5120084 -0.26657984,0.2282621,0.5143487 -0.2651445,0.23295593,0.5165993 -0.2636632,0.23763078,0.51876163 -0.26213801,0.24228619,0.52083736 -0.26057103,0.2469217,0.52282822 -0.25896451,0.25153685,0.52473609 -0.25732244,0.2561304,0.52656332 -0.25564519,0.26070284,0.52831152 -0.25393498,0.26525384,0.52998273 -0.25219404,0.26978306,0.53157905 -0.25042462,0.27429024,0.53310261 -0.24862899,0.27877509,0.53455561 -0.2468114,0.28323662,0.53594093 -0.24497208,0.28767547,0.53726018 -0.24311324,0.29209154,0.53851561 -0.24123708,0.29648471,0.53970946 -0.23934575,0.30085494,0.54084398 -0.23744138,0.30520222,0.5419214 -0.23552606,0.30952657,0.54294396 -0.23360277,0.31382773,0.54391424 -0.2316735,0.3181058,0.54483444 -0.22973926,0.32236127,0.54570633 -0.22780192,0.32659432,0.546532 -0.2258633,0.33080515,0.54731353 -0.22392515,0.334994,0.54805291 -0.22198915,0.33916114,0.54875211 -0.22005691,0.34330688,0.54941304 -0.21812995,0.34743154,0.55003755 -0.21620971,0.35153548,0.55062743 -0.21429757,0.35561907,0.5511844 -0.21239477,0.35968273,0.55171011 -0.2105031,0.36372671,0.55220646 -0.20862342,0.36775151,0.55267486 -0.20675628,0.37175775,0.55311653 -0.20490257,0.37574589,0.55353282 -0.20306309,0.37971644,0.55392505 -0.20123854,0.38366989,0.55429441 -0.1994295,0.38760678,0.55464205 -0.1976365,0.39152762,0.55496905 
-0.19585993,0.39543297,0.55527637 -0.19410009,0.39932336,0.55556494 -0.19235719,0.40319934,0.55583559 -0.19063135,0.40706148,0.55608907 -0.18892259,0.41091033,0.55632606 -0.18723083,0.41474645,0.55654717 -0.18555593,0.4185704,0.55675292 -0.18389763,0.42238275,0.55694377 -0.18225561,0.42618405,0.5571201 -0.18062949,0.42997486,0.55728221 -0.17901879,0.43375572,0.55743035 -0.17742298,0.4375272,0.55756466 -0.17584148,0.44128981,0.55768526 -0.17427363,0.4450441,0.55779216 -0.17271876,0.4487906,0.55788532 -0.17117615,0.4525298,0.55796464 -0.16964573,0.45626209,0.55803034 -0.16812641,0.45998802,0.55808199 -0.1666171,0.46370813,0.55811913 -0.16511703,0.4674229,0.55814141 -0.16362543,0.47113278,0.55814842 -0.16214155,0.47483821,0.55813967 -0.16066467,0.47853961,0.55811466 -0.15919413,0.4822374,0.5580728 -0.15772933,0.48593197,0.55801347 -0.15626973,0.4896237,0.557936 -0.15481488,0.49331293,0.55783967 -0.15336445,0.49700003,0.55772371 -0.1519182,0.50068529,0.55758733 -0.15047605,0.50436904,0.55742968 -0.14903918,0.50805136,0.5572505 -0.14760731,0.51173263,0.55704861 -0.14618026,0.51541316,0.55682271 -0.14475863,0.51909319,0.55657181 -0.14334327,0.52277292,0.55629491 -0.14193527,0.52645254,0.55599097 -0.14053599,0.53013219,0.55565893 -0.13914708,0.53381201,0.55529773 -0.13777048,0.53749213,0.55490625 -0.1364085,0.54117264,0.55448339 -0.13506561,0.54485335,0.55402906 -0.13374299,0.54853458,0.55354108 -0.13244401,0.55221637,0.55301828 -0.13117249,0.55589872,0.55245948 -0.1299327,0.55958162,0.55186354 -0.12872938,0.56326503,0.55122927 -0.12756771,0.56694891,0.55055551 -0.12645338,0.57063316,0.5498411 -0.12539383,0.57431754,0.54908564 -0.12439474,0.57800205,0.5482874 -0.12346281,0.58168661,0.54744498 -0.12260562,0.58537105,0.54655722 -0.12183122,0.58905521,0.54562298 -0.12114807,0.59273889,0.54464114 -0.12056501,0.59642187,0.54361058 -0.12009154,0.60010387,0.54253043 -0.11973756,0.60378459,0.54139999 -0.11951163,0.60746388,0.54021751 -0.11942341,0.61114146,0.53898192 -0.11948255,0.61481702,0.53769219 -0.11969858,0.61849025,0.53634733 -0.12008079,0.62216081,0.53494633 -0.12063824,0.62582833,0.53348834 -0.12137972,0.62949242,0.53197275 -0.12231244,0.63315277,0.53039808 -0.12344358,0.63680899,0.52876343 -0.12477953,0.64046069,0.52706792 -0.12632581,0.64410744,0.52531069 -0.12808703,0.64774881,0.52349092 -0.13006688,0.65138436,0.52160791 -0.13226797,0.65501363,0.51966086 -0.13469183,0.65863619,0.5176488 -0.13733921,0.66225157,0.51557101 -0.14020991,0.66585927,0.5134268 -0.14330291,0.66945881,0.51121549 -0.1466164,0.67304968,0.50893644 -0.15014782,0.67663139,0.5065889 -0.15389405,0.68020343,0.50417217 -0.15785146,0.68376525,0.50168574 -0.16201598,0.68731632,0.49912906 -0.1663832,0.69085611,0.49650163 -0.1709484,0.69438405,0.49380294 -0.17570671,0.6978996,0.49103252 -0.18065314,0.70140222,0.48818938 -0.18578266,0.70489133,0.48527326 -0.19109018,0.70836635,0.48228395 -0.19657063,0.71182668,0.47922108 -0.20221902,0.71527175,0.47608431 -0.20803045,0.71870095,0.4728733 -0.21400015,0.72211371,0.46958774 -0.22012381,0.72550945,0.46622638 -0.2263969,0.72888753,0.46278934 -0.23281498,0.73224735,0.45927675 -0.2393739,0.73558828,0.45568838 -0.24606968,0.73890972,0.45202405 -0.25289851,0.74221104,0.44828355 -0.25985676,0.74549162,0.44446673 -0.26694127,0.74875084,0.44057284 -0.27414922,0.75198807,0.4366009 -0.28147681,0.75520266,0.43255207 -0.28892102,0.75839399,0.42842626 -0.29647899,0.76156142,0.42422341 -0.30414796,0.76470433,0.41994346 -0.31192534,0.76782207,0.41558638 -0.3198086,0.77091403,0.41115215 
-0.3277958,0.77397953,0.40664011 -0.33588539,0.7770179,0.40204917 -0.34407411,0.78002855,0.39738103 -0.35235985,0.78301086,0.39263579 -0.36074053,0.78596419,0.38781353 -0.3692142,0.78888793,0.38291438 -0.37777892,0.79178146,0.3779385 -0.38643282,0.79464415,0.37288606 -0.39517408,0.79747541,0.36775726 -0.40400101,0.80027461,0.36255223 -0.4129135,0.80304099,0.35726893 -0.42190813,0.80577412,0.35191009 -0.43098317,0.80847343,0.34647607 -0.44013691,0.81113836,0.3409673 -0.44936763,0.81376835,0.33538426 -0.45867362,0.81636288,0.32972749 -0.46805314,0.81892143,0.32399761 -0.47750446,0.82144351,0.31819529 -0.4870258,0.82392862,0.31232133 -0.49661536,0.82637633,0.30637661 -0.5062713,0.82878621,0.30036211 -0.51599182,0.83115784,0.29427888 -0.52577622,0.83349064,0.2881265 -0.5356211,0.83578452,0.28190832 -0.5455244,0.83803918,0.27562602 -0.55548397,0.84025437,0.26928147 -0.5654976,0.8424299,0.26287683 -0.57556297,0.84456561,0.25641457 -0.58567772,0.84666139,0.24989748 -0.59583934,0.84871722,0.24332878 -0.60604528,0.8507331,0.23671214 -0.61629283,0.85270912,0.23005179 -0.62657923,0.85464543,0.22335258 -0.63690157,0.85654226,0.21662012 -0.64725685,0.85839991,0.20986086 -0.65764197,0.86021878,0.20308229 -0.66805369,0.86199932,0.19629307 -0.67848868,0.86374211,0.18950326 -0.68894351,0.86544779,0.18272455 -0.69941463,0.86711711,0.17597055 -0.70989842,0.86875092,0.16925712 -0.72039115,0.87035015,0.16260273 -0.73088902,0.87191584,0.15602894 -0.74138803,0.87344918,0.14956101 -0.75188414,0.87495143,0.14322828 -0.76237342,0.87642392,0.13706449 -0.77285183,0.87786808,0.13110864 -0.78331535,0.87928545,0.12540538 -0.79375994,0.88067763,0.12000532 -0.80418159,0.88204632,0.11496505 -0.81457634,0.88339329,0.11034678 -0.82494028,0.88472036,0.10621724 -0.83526959,0.88602943,0.1026459 -0.84556056,0.88732243,0.09970219 -0.8558096,0.88860134,0.09745186 -0.86601325,0.88986815,0.09595277 -0.87616824,0.89112487,0.09525046 -0.88627146,0.89237353,0.09537439 -0.89632002,0.89361614,0.09633538 -0.90631121,0.89485467,0.09812496 -0.91624212,0.89609127,0.1007168 -0.92610579,0.89732977,0.10407067 -0.93590444,0.8985704,0.10813094 -0.94563626,0.899815,0.11283773 -0.95529972,0.90106534,0.11812832 -0.96489353,0.90232311,0.12394051 -0.97441665,0.90358991,0.13021494 -0.98386829,0.90486726,0.13689671 -0.99324789,0.90615657,0.1439362 diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index badd44b25db..f133e7806a3 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -1,18 +1,16 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -import warnings -import itertools import functools +import itertools +import warnings import numpy as np -from ..core.pycompat import getargspec from ..core.formatting import format_item -from .utils import (_determine_cmap_params, _infer_xy_labels, - import_matplotlib_pyplot) - +from ..core.pycompat import getargspec +from .utils import ( + _determine_cmap_params, _infer_xy_labels, import_matplotlib_pyplot, + label_from_attrs) # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams @@ -191,6 +189,7 @@ def __init__(self, data, col=None, row=None, col_wrap=None, self._y_var = None self._cmap_extend = None self._mappables = [] + self._finalized = False @property def _left_axes(self): @@ -221,6 +220,19 @@ def map_dataarray(self, func, x, y, **kwargs): self : FacetGrid object """ + + cmapkw = kwargs.get('cmap') + colorskw = 
kwargs.get('colors') + cbar_kwargs = kwargs.pop('cbar_kwargs', {}) + cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) + + if kwargs.get('cbar_ax', None) is not None: + raise ValueError('cbar_ax not supported by FacetGrid.') + + # colors is mutually exclusive with cmap + if cmapkw and colorskw: + raise ValueError("Can't specify both cmap and colors.") + # These should be consistent with xarray.plot._plot2d cmap_kwargs = {'plot_data': self.data.values, # MPL default @@ -233,6 +245,9 @@ def map_dataarray(self, func, x, y, **kwargs): cmap_params = _determine_cmap_params(**cmap_kwargs) + if colorskw is not None: + cmap_params['cmap'] = None + # Order is important func_kwargs = kwargs.copy() func_kwargs.update(cmap_params) @@ -254,19 +269,90 @@ def map_dataarray(self, func, x, y, **kwargs): self._finalize_grid(x, y) if kwargs.get('add_colorbar', True): - self.add_colorbar() + self.add_colorbar(**cbar_kwargs) + + return self + + def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs): + """ + Apply a line plot to a 2d facet subset of the data. + + Parameters + ---------- + x, y, hue: string + dimension names for the axes and hues of each facet + + Returns + ------- + self : FacetGrid object + + """ + from .plot import line, _infer_line_data + + add_legend = kwargs.pop('add_legend', True) + kwargs['add_legend'] = False + + for d, ax in zip(self.name_dicts.flat, self.axes.flat): + # None is the sentinel value + if d is not None: + subset = self.data.loc[d] + mappable = line(subset, x=x, y=y, hue=hue, + ax=ax, _labels=False, + **kwargs) + self._mappables.append(mappable) + _, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data( + darray=self.data.loc[self.name_dicts.flat[0]], + x=x, y=y, hue=hue) + + self._hue_var = hueplt + self._hue_label = huelabel + self._finalize_grid(xlabel, ylabel) + + if add_legend and hueplt is not None and huelabel is not None: + self.add_legend() return self def _finalize_grid(self, *axlabels): """Finalize the annotations and layout.""" - self.set_axis_labels(*axlabels) - self.set_titles() - self.fig.tight_layout() + if not self._finalized: + self.set_axis_labels(*axlabels) + self.set_titles() + self.fig.tight_layout() - for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): - if namedict is None: - ax.set_visible(False) + for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): + if namedict is None: + ax.set_visible(False) + + self._finalized = True + + def add_legend(self, **kwargs): + figlegend = self.fig.legend( + handles=self._mappables[-1], + labels=list(self._hue_var.values), + title=self._hue_label, + loc="center right", **kwargs) + + # Draw the plot to set the bounding boxes correctly + self.fig.draw(self.fig.canvas.get_renderer()) + + # Calculate and set the new width of the figure so the legend fits + legend_width = figlegend.get_window_extent().width / self.fig.dpi + figure_width = self.fig.get_figwidth() + self.fig.set_figwidth(figure_width + legend_width) + + # Draw the plot again to get the new transformations + self.fig.draw(self.fig.canvas.get_renderer()) + + # Now calculate how much space we need on the right side + legend_width = figlegend.get_window_extent().width / self.fig.dpi + space_needed = legend_width / (figure_width + legend_width) + 0.02 + # margin = .01 + # _space_needed = margin + space_needed + right = 1 - space_needed + + # Place the subplot axes to give space for the legend + self.fig.subplots_adjust(right=right) def add_colorbar(self, **kwargs): """Draw a colorbar @@ -274,8 +360,8 @@ def 
add_colorbar(self, **kwargs): kwargs = kwargs.copy() if self._cmap_extend is not None: kwargs.setdefault('extend', self._cmap_extend) - if getattr(self.data, 'name', None) is not None: - kwargs.setdefault('label', self.data.name) + if 'label' not in kwargs: + kwargs.setdefault('label', label_from_attrs(self.data)) self.cbar = self.fig.colorbar(self._mappables[-1], ax=list(self.axes.flat), **kwargs) @@ -284,17 +370,25 @@ def add_colorbar(self, **kwargs): def set_axis_labels(self, x_var=None, y_var=None): """Set axis labels on the left column and bottom row of the grid.""" if x_var is not None: - self._x_var = x_var - self.set_xlabels(x_var) + if x_var in self.data.coords: + self._x_var = x_var + self.set_xlabels(label_from_attrs(self.data[x_var])) + else: + # x_var is a string + self.set_xlabels(x_var) + if y_var is not None: - self._y_var = y_var - self.set_ylabels(y_var) + if y_var in self.data.coords: + self._y_var = y_var + self.set_ylabels(label_from_attrs(self.data[y_var])) + else: + self.set_ylabels(y_var) return self def set_xlabels(self, label=None, **kwargs): """Label the x axis on the bottom row of the grid.""" if label is None: - label = self._x_var + label = label_from_attrs(self.data[self._x_var]) for ax in self._bottom_axes: ax.set_xlabel(label, **kwargs) return self @@ -302,7 +396,7 @@ def set_xlabels(self, label=None, **kwargs): def set_ylabels(self, label=None, **kwargs): """Label the y axis on the left column of the grid.""" if label is None: - label = self._y_var + label = label_from_attrs(self.data[self._y_var]) for ax in self._left_axes: ax.set_ylabel(label, **kwargs) return self @@ -418,9 +512,12 @@ def map(self, func, *args, **kwargs): data = self.data.loc[namedict] plt.sca(ax) innerargs = [data[a].values for a in args] - # TODO: is it possible to verify that an artist is mappable? - mappable = func(*innerargs, **kwargs) - self._mappables.append(mappable) + maybe_mappable = func(*innerargs, **kwargs) + # TODO: better way to verify that an artist is mappable? 
+ # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 + if (maybe_mappable and + hasattr(maybe_mappable, 'autoscale_None')): + self._mappables.append(maybe_mappable) self._finalize_grid(*args[:2]) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index d17ceb84e16..8d21e084946 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -5,21 +5,26 @@ Or use the methods on a DataArray: DataArray.plot._____ """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import functools import warnings +from datetime import datetime import numpy as np import pandas as pd -from datetime import datetime -from .utils import (ROBUST_PERCENTILE, _determine_cmap_params, - _infer_xy_labels, get_axis, import_matplotlib_pyplot) -from .facetgrid import FacetGrid +from xarray.core.alignment import align +from xarray.core.common import contains_cftime_datetimes from xarray.core.pycompat import basestring +from .facetgrid import FacetGrid +from .utils import ( + ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, + _interval_to_double_bound_points, _interval_to_mid_points, + _resolve_intervals_2dplot, _valid_other_type, get_axis, + import_matplotlib_pyplot, label_from_attrs) + def _valid_numpy_subdtype(x, numpy_types): """ @@ -33,26 +38,20 @@ def _valid_numpy_subdtype(x, numpy_types): return any(np.issubdtype(x.dtype, t) for t in numpy_types) -def _valid_other_type(x, types): - """ - Do all elements of x have a type from types? - """ - return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) - - def _ensure_plottable(*args): """ Raise exception if there is anything in args that can't be plotted on an - axis. + axis by matplotlib. 
""" numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] other_types = [datetime] for x in args: - if not (_valid_numpy_subdtype(np.array(x), numpy_types) or - _valid_other_type(np.array(x), other_types)): + if not (_valid_numpy_subdtype(np.array(x), numpy_types) + or _valid_other_type(np.array(x), other_types)): raise TypeError('Plotting requires coordinates to be numeric ' - 'or dates.') + 'or dates of type np.datetime64 or ' + 'datetime.datetime or pd.Interval.') def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, @@ -80,8 +79,32 @@ def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, return g.map_dataarray(plotfunc, x, y, **kwargs) -def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, - subplot_kws=None, **kwargs): +def _line_facetgrid(darray, row=None, col=None, hue=None, + col_wrap=None, sharex=True, sharey=True, aspect=None, + size=None, subplot_kws=None, **kwargs): + """ + Convenience method to call xarray.plot.FacetGrid for line plots + kwargs are the arguments to pyplot.plot() + """ + ax = kwargs.pop('ax', None) + figsize = kwargs.pop('figsize', None) + if ax is not None: + raise ValueError("Can't use axes when making faceted plots.") + if aspect is None: + aspect = 1 + if size is None: + size = 3 + elif figsize is not None: + raise ValueError('cannot provide both `figsize` and `size` arguments') + + g = FacetGrid(data=darray, col=col, row=row, col_wrap=col_wrap, + sharex=sharex, sharey=sharey, figsize=figsize, + aspect=aspect, size=size, subplot_kws=subplot_kws) + return g.map_dataarray_line(hue=hue, **kwargs) + + +def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, + rtol=0.01, subplot_kws=None, **kwargs): """ Default plot of DataArray using matplotlib.pyplot. @@ -103,6 +126,8 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, If passed, make row faceted plots on this dimension name col : string, optional If passed, make column faceted plots on this dimension name + hue : string, optional + If passed, make faceted line plots with hue on this dimension name col_wrap : integer, optional Use together with ``col`` to wrap faceted plots ax : matplotlib axes, optional @@ -119,29 +144,42 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, """ darray = darray.squeeze() + if contains_cftime_datetimes(darray): + raise NotImplementedError( + 'Built-in plotting of arrays of cftime.datetime objects or arrays ' + 'indexed by cftime.datetime objects is currently not implemented ' + 'within xarray. A possible workaround is to use the ' + 'nc-time-axis package ' + '(https://github.com/SciTools/nc-time-axis) to convert the dates ' + 'to a plottable type and plot your data directly with matplotlib.') + plot_dims = set(darray.dims) plot_dims.discard(row) plot_dims.discard(col) + plot_dims.discard(hue) ndims = len(plot_dims) - error_msg = ('Only 2d plots are supported for facets in xarray. ' + error_msg = ('Only 1d and 2d plots are supported for facets in xarray. 
' 'See the package `Seaborn` for more options.') - if ndims == 1: + if ndims in [1, 2]: if row or col: - raise ValueError(error_msg) - plotfunc = line - elif ndims == 2: - # Only 2d can FacetGrid - kwargs['row'] = row - kwargs['col'] = col - kwargs['col_wrap'] = col_wrap - kwargs['subplot_kws'] = subplot_kws - - plotfunc = pcolormesh + kwargs['row'] = row + kwargs['col'] = col + kwargs['col_wrap'] = col_wrap + kwargs['subplot_kws'] = subplot_kws + if ndims == 1: + plotfunc = line + kwargs['hue'] = hue + elif ndims == 2: + if hue: + plotfunc = line + kwargs['hue'] = hue + else: + plotfunc = pcolormesh else: - if row or col: + if row or col or hue: raise ValueError(error_msg) plotfunc = hist @@ -150,6 +188,80 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, return plotfunc(darray, **kwargs) +def _infer_line_data(darray, x, y, hue): + error_msg = ('must be either None or one of ({0:s})' + .format(', '.join([repr(dd) for dd in darray.dims]))) + ndims = len(darray.dims) + + if x is not None and x not in darray.dims and x not in darray.coords: + raise ValueError('x ' + error_msg) + + if y is not None and y not in darray.dims and y not in darray.coords: + raise ValueError('y ' + error_msg) + + if x is not None and y is not None: + raise ValueError('You cannot specify both x and y kwargs' + 'for line plots.') + + if ndims == 1: + dim, = darray.dims # get the only dimension name + huename = None + hueplt = None + huelabel = '' + + if (x is None and y is None) or x == dim: + xplt = darray[dim] + yplt = darray + + else: + yplt = darray[dim] + xplt = darray + + else: + if x is None and y is None and hue is None: + raise ValueError('For 2D inputs, please' + 'specify either hue, x or y.') + + if y is None: + xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) + xplt = darray[xname] + if xplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(otherdim, huename) + xplt = xplt.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) + + else: + yplt = darray.transpose(xname, huename) + + else: + yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) + yplt = darray[yname] + if yplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + xplt = darray.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) + + else: + xplt = darray.transpose(yname, huename) + + huelabel = label_from_attrs(darray[huename]) + hueplt = darray[huename] + + + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) + + return xplt, yplt, hueplt, xlabel, ylabel, huelabel + + # This function signature should not change so that it can use # matplotlib format strings def line(darray, *args, **kwargs): @@ -175,9 +287,23 @@ def line(darray, *args, **kwargs): Axis on which to plot this figure. By default, use the current axis. Mutually exclusive with ``size`` and ``figsize``. hue : string, optional - Coordinate for which you want multiple lines plotted (2D inputs only). - x : string, optional - Coordinate for x axis. + Dimension or coordinate for which you want multiple lines plotted. + If plotting against a 2D coordinate, ``hue`` must be a dimension. + x, y : string, optional + Dimensions or coordinates for x, y axis. + Only one of these may be specified. 
+ The other coordinate plots values from the DataArray on which this + plot method is called. + xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : Specify x- and y-axes limits + xincrease : None, True, or False, optional + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : None, True, or False, optional + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. add_legend : boolean, optional Add legend with y axis coordinates (2D inputs only). *args, **kwargs : optional @@ -185,6 +311,14 @@ def line(darray, *args, **kwargs): """ + # Handle facetgrids first + row = kwargs.pop('row', None) + col = kwargs.pop('col', None) + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop('kwargs')) + return _line_facetgrid(**allargs) + ndims = len(darray.dims) if ndims > 2: raise ValueError('Line plots are for 1- or 2-dimensional DataArrays. ' @@ -198,48 +332,118 @@ def line(darray, *args, **kwargs): ax = kwargs.pop('ax', None) hue = kwargs.pop('hue', None) x = kwargs.pop('x', None) + y = kwargs.pop('y', None) + xincrease = kwargs.pop('xincrease', None) # default needs to be None + yincrease = kwargs.pop('yincrease', None) + xscale = kwargs.pop('xscale', None) # default needs to be None + yscale = kwargs.pop('yscale', None) + xticks = kwargs.pop('xticks', None) + yticks = kwargs.pop('yticks', None) + xlim = kwargs.pop('xlim', None) + ylim = kwargs.pop('ylim', None) add_legend = kwargs.pop('add_legend', True) + _labels = kwargs.pop('_labels', True) + if args is (): + args = kwargs.pop('args', ()) ax = get_axis(figsize, size, aspect, ax) - - if ndims == 1: - xlabel, = darray.dims - if x is not None and xlabel != x: - raise ValueError('Input does not have specified dimension' - ' {!r}'.format(x)) - - x = darray.coords[xlabel] - + xplt, yplt, hueplt, xlabel, ylabel, huelabel = \ + _infer_line_data(darray, x, y, hue) + + # Remove pd.Intervals if contained in xplt.values. + if _valid_other_type(xplt.values, [pd.Interval]): + # Is it a step plot? 
(see matplotlib.Axes.step) + if kwargs.get('linestyle', '').startswith('steps-'): + xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values, + yplt.values) + # Remove steps-* to be sure that matplotlib is not confused + kwargs['linestyle'] = (kwargs['linestyle'] + .replace('steps-pre', '') + .replace('steps-post', '') + .replace('steps-mid', '')) + if kwargs['linestyle'] == '': + kwargs.pop('linestyle') + else: + xplt_val = _interval_to_mid_points(xplt.values) + yplt_val = yplt.values + xlabel += '_center' else: - if x is None and hue is None: - raise ValueError('For 2D inputs, please specify either hue or x.') + xplt_val = xplt.values + yplt_val = yplt.values - xlabel, huelabel = _infer_xy_labels(darray=darray, x=x, y=hue) - x = darray.coords[xlabel] - darray = darray.transpose(xlabel, huelabel) + _ensure_plottable(xplt_val, yplt_val) - _ensure_plottable(x) + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) - primitive = ax.plot(x, darray, *args, **kwargs) + if _labels: + if xlabel is not None: + ax.set_xlabel(xlabel) - ax.set_xlabel(xlabel) - ax.set_title(darray._title_for_slice()) + if ylabel is not None: + ax.set_ylabel(ylabel) - if darray.name is not None: - ax.set_ylabel(darray.name) + ax.set_title(darray._title_for_slice()) if darray.ndim == 2 and add_legend: ax.legend(handles=primitive, - labels=list(darray.coords[huelabel].values), + labels=list(hueplt.values), title=huelabel) # Rotate dates on xlabels - if np.issubdtype(x.dtype, np.datetime64): - ax.get_figure().autofmt_xdate() + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(xplt.dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha('right') + + _update_axes(ax, xincrease, yincrease, xscale, yscale, + xticks, yticks, xlim, ylim) return primitive +def step(darray, *args, **kwargs): + """ + Step plot of DataArray index against values + + Similar to :func:`matplotlib:matplotlib.pyplot.step` + + Parameters + ---------- + where : {'pre', 'post', 'mid'}, optional, default 'pre' + Define where the steps should be placed: + - 'pre': The y value is continued constantly to the left from + every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the + value ``y[i]``. + - 'post': The y value is continued constantly to the right from + every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the + value ``y[i]``. + - 'mid': Steps occur half-way between the *x* positions. + Note that this parameter is ignored if the x coordinate consists of + :py:func:`pandas.Interval` values, e.g. as a result of + :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual + boundaries of the interval are used. 
+ + *args, **kwargs : optional + Additional arguments following :py:func:`xarray.plot.line` + + """ + if ('ls' in kwargs.keys()) and ('linestyle' not in kwargs.keys()): + kwargs['linestyle'] = kwargs.pop('ls') + + where = kwargs.pop('where', 'pre') + + if where not in ('pre', 'post', 'mid'): + raise ValueError("'where' argument to step must be " + "'pre', 'post' or 'mid'") + + kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '') + + return line(darray, *args, **kwargs) + + def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): """ Histogram of DataArray @@ -270,37 +474,69 @@ def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): """ ax = get_axis(figsize, size, aspect, ax) + xincrease = kwargs.pop('xincrease', None) # default needs to be None + yincrease = kwargs.pop('yincrease', None) + xscale = kwargs.pop('xscale', None) # default needs to be None + yscale = kwargs.pop('yscale', None) + xticks = kwargs.pop('xticks', None) + yticks = kwargs.pop('yticks', None) + xlim = kwargs.pop('xlim', None) + ylim = kwargs.pop('ylim', None) + no_nan = np.ravel(darray.values) no_nan = no_nan[pd.notnull(no_nan)] primitive = ax.hist(no_nan, **kwargs) - ax.set_ylabel('Count') + ax.set_title('Histogram') + ax.set_xlabel(label_from_attrs(darray)) - if darray.name is not None: - ax.set_title('Histogram of {0}'.format(darray.name)) + _update_axes(ax, xincrease, yincrease, xscale, yscale, + xticks, yticks, xlim, ylim) return primitive -def _update_axes_limits(ax, xincrease, yincrease): +def _update_axes(ax, xincrease, yincrease, + xscale=None, yscale=None, + xticks=None, yticks=None, + xlim=None, ylim=None): """ - Update axes in place to increase or decrease - For use in _plot2d + Update axes with provided parameters """ if xincrease is None: pass - elif xincrease: - ax.set_xlim(sorted(ax.get_xlim())) - elif not xincrease: - ax.set_xlim(sorted(ax.get_xlim(), reverse=True)) + elif xincrease and ax.xaxis_inverted(): + ax.invert_xaxis() + elif not xincrease and not ax.xaxis_inverted(): + ax.invert_xaxis() if yincrease is None: pass - elif yincrease: - ax.set_ylim(sorted(ax.get_ylim())) - elif not yincrease: - ax.set_ylim(sorted(ax.get_ylim(), reverse=True)) + elif yincrease and ax.yaxis_inverted(): + ax.invert_yaxis() + elif not yincrease and not ax.yaxis_inverted(): + ax.invert_yaxis() + + # The default xscale, yscale needs to be None. + # If we set a scale it resets the axes formatters, + # This means that set_xscale('linear') on a datetime axis + # will remove the date labels. So only set the scale when explicitly + # asked to. https://github.com/matplotlib/matplotlib/issues/8740 + if xscale is not None: + ax.set_xscale(xscale) + if yscale is not None: + ax.set_yscale(yscale) + + if xticks is not None: + ax.set_xticks(xticks) + if yticks is not None: + ax.set_yticks(yticks) + + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) # MUST run before any 2d plotting functions are defined since @@ -325,12 +561,18 @@ def hist(self, ax=None, **kwargs): def line(self, *args, **kwargs): return line(self._da, *args, **kwargs) + @functools.wraps(step) + def step(self, *args, **kwargs): + return step(self._da, *args, **kwargs) + def _rescale_imshow_rgb(darray, vmin, vmax, robust): assert robust or vmin is not None or vmax is not None + # TODO: remove when min numpy version is bumped to 1.13 # There's a cyclic dependency via DataArray, so we can't import from # xarray.ufuncs in global scope. 
from xarray.ufuncs import maximum, minimum + # Calculate vmin and vmax automatically for `robust=True` if robust: if vmax is None: @@ -356,7 +598,10 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): # After scaling, downcast to 32-bit float. This substantially reduces # memory usage after we hand `darray` off to matplotlib. darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') - return minimum(maximum(darray, 0), 1) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'xarray.ufuncs', + PendingDeprecationWarning) + return minimum(maximum(darray, 0), 1) def _plot2d(plotfunc): @@ -392,16 +637,23 @@ def _plot2d(plotfunc): If passed, make column faceted plots on this dimension name col_wrap : integer, optional Use together with ``col`` to wrap faceted plots + xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : Specify x- and y-axes limits xincrease : None, True, or False, optional Should the values on the x axes be increasing from left to right? - if None, use the default for the matplotlib function + if None, use the default for the matplotlib function. yincrease : None, True, or False, optional Should the values on the y axes be increasing from top to bottom? - if None, use the default for the matplotlib function + if None, use the default for the matplotlib function. add_colorbar : Boolean, optional Adds colorbar to axis add_labels : Boolean, optional Use xarray metadata to label axes + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. vmin, vmax : floats, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, @@ -469,7 +721,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, cmap=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, colors=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, - **kwargs): + xscale=None, yscale=None, xticks=None, yticks=None, + xlim=None, ylim=None, norm=None, **kwargs): # All 2d plots in xarray share this function signature. # Method signature below should be consistent. @@ -552,7 +805,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Pass the data as a masked ndarray too zval = darray.to_masked_array(copy=False) - _ensure_plottable(xval, yval) + # Replace pd.Intervals if contained in xval or yval. + xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) + yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__) + + _ensure_plottable(xplt, yplt) if 'contour' in plotfunc.__name__ and levels is None: levels = 7 # this is the matplotlib default @@ -566,6 +823,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, 'extend': extend, 'levels': levels, 'filled': plotfunc.__name__ != 'contour', + 'norm': norm, } cmap_params = _determine_cmap_params(**cmap_kwargs) @@ -576,28 +834,31 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # pcolormesh kwargs['extend'] = cmap_params['extend'] kwargs['levels'] = cmap_params['levels'] + # if colors == a single color, matplotlib draws dashed negative + # contours. 
we lose this feature if we pass cmap and not colors + if isinstance(colors, basestring): + cmap_params['cmap'] = None + kwargs['colors'] = colors if 'pcolormesh' == plotfunc.__name__: kwargs['infer_intervals'] = infer_intervals - # This allows the user to pass in a custom norm coming via kwargs - kwargs.setdefault('norm', cmap_params['norm']) - if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring): # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available " "in xarray") ax = get_axis(figsize, size, aspect, ax) - primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], + primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'], vmin=cmap_params['vmin'], vmax=cmap_params['vmax'], + norm=cmap_params['norm'], **kwargs) # Label the plot with metadata if add_labels: - ax.set_xlabel(xlab) - ax.set_ylabel(ylab) + ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) + ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if add_colorbar: @@ -608,18 +869,28 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, else: cbar_kwargs.setdefault('cax', cbar_ax) cbar = plt.colorbar(primitive, **cbar_kwargs) - if darray.name and add_labels and 'label' not in cbar_kwargs: - cbar.set_label(darray.name, rotation=90) + if add_labels and 'label' not in cbar_kwargs: + cbar.set_label(label_from_attrs(darray)) elif cbar_ax is not None or cbar_kwargs is not None: # inform the user about keywords which aren't used raise ValueError("cbar_ax and cbar_kwargs can't be used with " "add_colorbar=False.") - _update_axes_limits(ax, xincrease, yincrease) + # origin kwarg overrides yincrease + if 'origin' in kwargs: + yincrease = None + + _update_axes(ax, xincrease, yincrease, xscale, yscale, + xticks, yticks, xlim, ylim) # Rotate dates on xlabels - if np.issubdtype(xval.dtype, np.datetime64): - ax.get_figure().autofmt_xdate() + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(xplt.dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha('right') return primitive @@ -631,7 +902,9 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, add_labels=True, vmin=None, vmax=None, cmap=None, colors=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, subplot_kws=None, - cbar_ax=None, cbar_kwargs=None, **kwargs): + cbar_ax=None, cbar_kwargs=None, + xscale=None, yscale=None, xticks=None, yticks=None, + xlim=None, ylim=None, norm=None, **kwargs): """ The method should have the same signature as the function. 
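As a quick orientation to the plotting features these hunks introduce, here is a minimal sketch of how they might be exercised once the patch is applied: ``hue``-based faceted line plots, the new ``xarray.plot.step`` wrapper, and the axis keywords (``xscale``, ``xlim``, ``yincrease``, ...) routed through ``_update_axes``. The DataArray below is invented for illustration; its name, dims, coords and attrs are assumptions, not taken from xarray's test suite.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

# Fabricated sample data: 4 x 3 x 24 values with CF-style attrs so that
# label_from_attrs can build "long_name [units]" axis labels.
times = pd.date_range('2000-01-01', periods=24, freq='H')
da = xr.DataArray(
    np.random.rand(4, 3, 24),
    coords={'lon': np.arange(4), 'lat': np.arange(3), 'time': times},
    dims=('lon', 'lat', 'time'),
    name='tmin',
    attrs={'long_name': 'minimum temperature', 'units': 'degC'})

# 1d line plot: the y label is built from attrs
# ("minimum temperature [degC]") rather than the bare variable name.
da.isel(lon=0, lat=0).plot.line()

# Faceted line plots: ``hue`` picks the dimension drawn as multiple lines,
# ``col``/``col_wrap`` lay the remaining dimension out on a FacetGrid.
da.plot.line(x='time', hue='lat', col='lon', col_wrap=2)

# The new step() wrapper forwards to line() with linestyle='steps-mid';
# yincrease=False is handled by _update_axes and flips the y axis.
da.isel(lon=0, lat=0).plot.step(where='mid', yincrease=False)

plt.show()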
@@ -693,10 +966,8 @@ def imshow(x, y, z, ax, **kwargs): left, right = x[0] - xstep, x[-1] + xstep bottom, top = y[-1] + ystep, y[0] - ystep - defaults = {'extent': [left, right, bottom, top], - 'origin': 'upper', - 'interpolation': 'nearest', - } + defaults = {'origin': 'upper', + 'interpolation': 'nearest'} if not hasattr(ax, 'projection'): # not for cartopy geoaxes @@ -705,13 +976,22 @@ def imshow(x, y, z, ax, **kwargs): # Allow user to override these defaults defaults.update(kwargs) + if defaults['origin'] == 'upper': + defaults['extent'] = [left, right, bottom, top] + else: + defaults['extent'] = [left, right, top, bottom] + if z.ndim == 3: # matplotlib imshow uses black for missing data, but Xarray makes # missing data transparent. We therefore add an alpha channel if # there isn't one, and set it to transparent where data is masked. if z.shape[-1] == 3: - z = np.ma.concatenate((z, np.ma.ones(z.shape[:2] + (1,))), 2) - z = z.copy() + alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype) + if np.issubdtype(z.dtype, np.integer): + alpha *= 255 + z = np.ma.concatenate((z, alpha), axis=2) + else: + z = z.copy() z[np.any(z.mask, axis=-1), -1] = 0 primitive = ax.imshow(z, **defaults) @@ -741,7 +1021,27 @@ def contourf(x, y, z, ax, **kwargs): return primitive -def _infer_interval_breaks(coord, axis=0): +def _is_monotonic(coord, axis=0): + """ + >>> _is_monotonic(np.array([0, 1, 2])) + True + >>> _is_monotonic(np.array([2, 1, 0])) + True + >>> _is_monotonic(np.array([0, 2, 1])) + False + """ + if coord.shape[axis] < 3: + return True + else: + n = coord.shape[axis] + delta_pos = (coord.take(np.arange(1, n), axis=axis) >= + coord.take(np.arange(0, n - 1), axis=axis)) + delta_neg = (coord.take(np.arange(1, n), axis=axis) <= + coord.take(np.arange(0, n - 1), axis=axis)) + return np.all(delta_pos) or np.all(delta_neg) + + +def _infer_interval_breaks(coord, axis=0, check_monotonic=False): """ >>> _infer_interval_breaks(np.arange(5)) array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) @@ -750,6 +1050,15 @@ def _infer_interval_breaks(coord, axis=0): [ 2.5, 3.5, 4.5]]) """ coord = np.asarray(coord) + + if check_monotonic and not _is_monotonic(coord, axis=axis): + raise ValueError("The input coordinate is not sorted in increasing " + "order along axis %d. This can lead to unexpected " + "results. Consider calling the `sortby` method on " + "the input DataArray. To plot data with categorical " + "axes, consider using the `heatmap` function from " + "the `seaborn` statistical plotting library." 
% axis) + deltas = 0.5 * np.diff(coord, axis=axis) if deltas.size == 0: deltas = np.array(0.0) @@ -779,14 +1088,22 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): else: infer_intervals = True - if infer_intervals: + if (infer_intervals and + ((np.shape(x)[0] == np.shape(z)[1]) or + ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])))): if len(x.shape) == 1: - x = _infer_interval_breaks(x) - y = _infer_interval_breaks(y) + x = _infer_interval_breaks(x, check_monotonic=True) else: # we have to infer the intervals on both axes x = _infer_interval_breaks(x, axis=1) x = _infer_interval_breaks(x, axis=0) + + if (infer_intervals and + (np.shape(y)[0] == np.shape(z)[0])): + if len(y.shape) == 1: + y = _infer_interval_breaks(y, check_monotonic=True) + else: + # we have to infer the intervals on both axes y = _infer_interval_breaks(y, axis=1) y = _infer_interval_breaks(y, axis=0) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index c194b9dd8d8..41f61554739 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1,32 +1,19 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import pkg_resources +from __future__ import absolute_import, division, print_function + +import itertools +import textwrap import warnings import numpy as np import pandas as pd +from ..core.options import OPTIONS from ..core.pycompat import basestring from ..core.utils import is_scalar - ROBUST_PERCENTILE = 2.0 -def _load_default_cmap(fname='default_colormap.csv'): - """ - Returns viridis color map - """ - from matplotlib.colors import LinearSegmentedColormap - - # Not sure what the first arg here should be - f = pkg_resources.resource_stream(__name__, fname) - cm_data = pd.read_csv(f, header=None).values - - return LinearSegmentedColormap.from_list('viridis', cm_data) - - def import_seaborn(): '''import seaborn and handle deprecation of apionly module''' with warnings.catch_warnings(record=True) as w: @@ -115,32 +102,23 @@ def _color_palette(cmap, n_colors): colors_i = np.linspace(0, 1., n_colors) if isinstance(cmap, (list, tuple)): # we have a list of colors - try: - sns = import_seaborn() - except ImportError: - # if that fails, use matplotlib - # in this case, is there any difference between mpl and seaborn? - cmap = ListedColormap(cmap, N=n_colors) - pal = cmap(colors_i) - else: - # first try to turn it into a palette with seaborn - pal = sns.color_palette(cmap, n_colors=n_colors) + cmap = ListedColormap(cmap, N=n_colors) + pal = cmap(colors_i) elif isinstance(cmap, basestring): # we have some sort of named palette try: - # first try to turn it into a palette with seaborn - from seaborn.apionly import color_palette - pal = color_palette(cmap, n_colors=n_colors) - except (ImportError, ValueError): - # ValueError is raised when seaborn doesn't like a colormap - # (e.g. jet). If that fails, use matplotlib + # is this a matplotlib cmap? + cmap = plt.get_cmap(cmap) + pal = cmap(colors_i) + except ValueError: + # ValueError happens when mpl doesn't like a colormap, try seaborn try: - # is this a matplotlib cmap? - cmap = plt.get_cmap(cmap) - except ValueError: + from seaborn.apionly import color_palette + pal = color_palette(cmap, n_colors=n_colors) + except (ValueError, ImportError): # or maybe we just got a single color as a string cmap = ListedColormap([cmap], N=n_colors) - pal = cmap(colors_i) + pal = cmap(colors_i) else: # cmap better be a LinearSegmentedColormap (e.g. 
viridis) pal = cmap(colors_i) @@ -170,7 +148,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, """ import matplotlib as mpl - calc_data = np.ravel(plot_data[~pd.isnull(plot_data)]) + calc_data = np.ravel(plot_data[np.isfinite(plot_data)]) # Handle all-NaN input data gracefully if calc_data.size == 0: @@ -196,6 +174,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # vlim might be computed below vlim = None + # save state; needed later + vmin_was_none = vmin is None + vmax_was_none = vmax is None + if vmin is None: if robust: vmin = np.percentile(calc_data, ROBUST_PERCENTILE) @@ -228,22 +210,42 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, vmin += center vmax += center + # now check norm and harmonize with vmin, vmax + if norm is not None: + if norm.vmin is None: + norm.vmin = vmin + else: + if not vmin_was_none and vmin != norm.vmin: + raise ValueError('Cannot supply vmin and a norm' + + ' with a different vmin.') + vmin = norm.vmin + + if norm.vmax is None: + norm.vmax = vmax + else: + if not vmax_was_none and vmax != norm.vmax: + raise ValueError('Cannot supply vmax and a norm' + + ' with a different vmax.') + vmax = norm.vmax + + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries + # Choose default colormaps if not provided if cmap is None: if divergent: - cmap = "RdBu_r" + cmap = OPTIONS['cmap_divergent'] else: - cmap = "viridis" - - # Allow viridis before matplotlib 1.5 - if cmap == "viridis": - cmap = _load_default_cmap() + cmap = OPTIONS['cmap_sequential'] # Handle discrete levels - if levels is not None: + if levels is not None and norm is None: if is_scalar(levels): - if user_minmax or levels == 1: + if user_minmax: levels = np.linspace(vmin, vmax, levels) + elif levels == 1: + levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks ticker = mpl.ticker.MaxNLocator(levels - 1) @@ -253,8 +255,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, if extend is None: extend = _determine_extend(calc_data, vmin, vmax) - if levels is not None: - cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled) + if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm): + cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) + norm = newnorm if norm is None else norm return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm) @@ -307,7 +310,7 @@ def _infer_xy_labels_3d(darray, x, y, rgb): assert rgb is not None # Finally, we pick out the red slice and delegate to the 2D version: - return _infer_xy_labels(darray.isel(**{rgb: 0}).squeeze(), x, y) + return _infer_xy_labels(darray.isel(**{rgb: 0}), x, y) def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): @@ -325,11 +328,11 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): raise ValueError('DataArray must be 2d') y, x = darray.dims elif x is None: - if y not in darray.dims: + if y not in darray.dims and y not in darray.coords: raise ValueError('y must be a dimension name if x is not supplied') x = darray.dims[0] if y == darray.dims[1] else darray.dims[1] elif y is None: - if x not in darray.dims: + if x not in darray.dims and x not in darray.coords: raise ValueError('x must be a dimension name if y is not supplied') y = darray.dims[0] if x == darray.dims[1] else darray.dims[1] elif any(k not in darray.coords and k not in darray.dims for k in (x, y)): @@ -364,3 +367,86 @@ def 
get_axis(figsize, size, aspect, ax): ax = plt.gca() return ax + + +def label_from_attrs(da, extra=''): + ''' Makes informative labels if variable metadata (attrs) follows + CF conventions. ''' + + if da.attrs.get('long_name'): + name = da.attrs['long_name'] + elif da.attrs.get('standard_name'): + name = da.attrs['standard_name'] + elif da.name is not None: + name = da.name + else: + name = '' + + if da.attrs.get('units'): + units = ' [{}]'.format(da.attrs['units']) + else: + units = '' + + return '\n'.join(textwrap.wrap(name + extra + units, 30)) + + +def _interval_to_mid_points(array): + """ + Helper function which returns an array + with the Intervals' mid points. + """ + + return np.array([x.mid for x in array]) + + +def _interval_to_bound_points(array): + """ + Helper function which returns an array + with the Intervals' boundaries. + """ + + array_boundaries = np.array([x.left for x in array]) + array_boundaries = np.concatenate( + (array_boundaries, np.array([array[-1].right]))) + + return array_boundaries + + +def _interval_to_double_bound_points(xarray, yarray): + """ + Helper function to deal with a xarray consisting of pd.Intervals. Each + interval is replaced with both boundaries. I.e. the length of xarray + doubles. yarray is modified so it matches the new shape of xarray. + """ + + xarray1 = np.array([x.left for x in xarray]) + xarray2 = np.array([x.right for x in xarray]) + + xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2))) + yarray = list(itertools.chain.from_iterable(zip(yarray, yarray))) + + return xarray, yarray + + +def _resolve_intervals_2dplot(val, func_name): + """ + Helper function to replace the values of a coordinate array containing + pd.Interval with their mid-points or - for pcolormesh - boundaries which + increases length by 1. + """ + label_extra = '' + if _valid_other_type(val, [pd.Interval]): + if func_name == 'pcolormesh': + val = _interval_to_bound_points(val) + else: + val = _interval_to_mid_points(val) + label_extra = '_center' + + return val, label_extra + + +def _valid_other_type(x, types): + """ + Do all elements of x have a type from types? 
+ """ + return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) diff --git a/xarray/testing.py b/xarray/testing.py index f51e474405f..ee5a54cd7dc 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -1,7 +1,5 @@ """Testing functions exposed to the user API""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import numpy as np @@ -50,7 +48,7 @@ def assert_equal(a, b): """ import xarray as xr __tracebackhide__ = True # noqa: F841 - assert type(a) == type(b) + assert type(a) == type(b) # noqa if isinstance(a, (xr.Variable, xr.DataArray, xr.Dataset)): assert a.equals(b), '{}\n{}'.format(a, b) else: @@ -77,7 +75,7 @@ def assert_identical(a, b): """ import xarray as xr __tracebackhide__ = True # noqa: F841 - assert type(a) == type(b) + assert type(a) == type(b) # noqa if isinstance(a, xr.DataArray): assert a.name == b.name assert_identical(a._to_temp_dataset(), b._to_temp_dataset()) @@ -115,7 +113,7 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): """ import xarray as xr __tracebackhide__ = True # noqa: F841 - assert type(a) == type(b) + assert type(a) == type(b) # noqa kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes) if isinstance(a, xr.Variable): assert a.dims == b.dims diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index dadcdeff640..a45f71bbc3b 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -3,17 +3,16 @@ from __future__ import print_function import warnings from contextlib import contextmanager -from distutils.version import LooseVersion +from distutils import version import re import importlib import numpy as np from numpy.testing import assert_array_equal # noqa: F401 -from xarray.core.duck_array_ops import allclose_or_equiv +from xarray.core.duck_array_ops import allclose_or_equiv # noqa import pytest from xarray.core import utils -from xarray.core.pycompat import PY3 from xarray.core.indexing import ExplicitlyIndexed from xarray.testing import (assert_equal, assert_identical, # noqa: F401 assert_allclose) @@ -25,10 +24,6 @@ # old location, for pandas < 0.20 from pandas.util.testing import assert_frame_equal # noqa: F401 -try: - import unittest2 as unittest -except ImportError: - import unittest try: from unittest import mock @@ -54,42 +49,58 @@ def _importorskip(modname, minversion=None): raise ImportError('Minimum version not satisfied') except ImportError: has = False - # TODO: use pytest.skipif instead of unittest.skipUnless - # Using `unittest.skipUnless` is a temporary workaround for pytest#568, - # wherein class decorators stain inherited classes. - # xref: xarray#1531, implemented in xarray #1557. - func = unittest.skipUnless(has, reason='requires {}'.format(modname)) + func = pytest.mark.skipif(not has, reason='requires {}'.format(modname)) return has, func +def LooseVersion(vstring): + # Our development version is something like '0.10.9+aac7bfc' + # This function just ignored the git commit id. 
+ vstring = vstring.split('+')[0] + return version.LooseVersion(vstring) + + has_matplotlib, requires_matplotlib = _importorskip('matplotlib') +has_matplotlib2, requires_matplotlib2 = _importorskip('matplotlib', + minversion='2') has_scipy, requires_scipy = _importorskip('scipy') has_pydap, requires_pydap = _importorskip('pydap.client') has_netCDF4, requires_netCDF4 = _importorskip('netCDF4') has_h5netcdf, requires_h5netcdf = _importorskip('h5netcdf') has_pynio, requires_pynio = _importorskip('Nio') +has_pseudonetcdf, requires_pseudonetcdf = _importorskip('PseudoNetCDF') +has_cftime, requires_cftime = _importorskip('cftime') has_dask, requires_dask = _importorskip('dask') has_bottleneck, requires_bottleneck = _importorskip('bottleneck') has_rasterio, requires_rasterio = _importorskip('rasterio') has_pathlib, requires_pathlib = _importorskip('pathlib') has_zarr, requires_zarr = _importorskip('zarr', minversion='2.2') -has_np112, requires_np112 = _importorskip('numpy', minversion='1.12.0') +has_np113, requires_np113 = _importorskip('numpy', minversion='1.13.0') +has_iris, requires_iris = _importorskip('iris') +has_cfgrib, requires_cfgrib = _importorskip('cfgrib') # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 -requires_scipy_or_netCDF4 = unittest.skipUnless( - has_scipy_or_netCDF4, reason='requires scipy or netCDF4') +requires_scipy_or_netCDF4 = pytest.mark.skipif( + not has_scipy_or_netCDF4, reason='requires scipy or netCDF4') +has_cftime_or_netCDF4 = has_cftime or has_netCDF4 +requires_cftime_or_netCDF4 = pytest.mark.skipif( + not has_cftime_or_netCDF4, reason='requires cftime or netCDF4') if not has_pathlib: has_pathlib, requires_pathlib = _importorskip('pathlib2') if has_dask: import dask - dask.set_options(get=dask.get) + if LooseVersion(dask.__version__) < '0.18': + dask.set_options(get=dask.get) + else: + dask.config.set(scheduler='single-threaded') try: import_seaborn() has_seaborn = True except ImportError: has_seaborn = False -requires_seaborn = unittest.skipUnless(has_seaborn, reason='requires seaborn') +requires_seaborn = pytest.mark.skipif(not has_seaborn, + reason='requires seaborn') try: _SKIP_FLAKY = not pytest.config.getoption("--run-flaky") @@ -109,39 +120,6 @@ def _importorskip(modname, minversion=None): "internet connection") -class TestCase(unittest.TestCase): - """ - These functions are all deprecated. 
Instead, use functions in xr.testing - """ - if PY3: - # Python 3 assertCountEqual is roughly equivalent to Python 2 - # assertItemsEqual - def assertItemsEqual(self, first, second, msg=None): - __tracebackhide__ = True # noqa: F841 - return self.assertCountEqual(first, second, msg) - - @contextmanager - def assertWarns(self, message): - __tracebackhide__ = True # noqa: F841 - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings('always', message) - yield - assert len(w) > 0 - assert any(message in str(wi.message) for wi in w) - - def assertVariableNotEqual(self, v1, v2): - __tracebackhide__ = True # noqa: F841 - assert not v1.equals(v2) - - def assertEqual(self, a1, a2): - __tracebackhide__ = True # noqa: F841 - assert a1 == a2 or (a1 != a1 and a2 != a2) - - def assertAllClose(self, a1, a2, rtol=1e-05, atol=1e-8): - __tracebackhide__ = True # noqa: F841 - assert allclose_or_equiv(a1, a2, rtol=rtol, atol=atol) - - @contextmanager def raises_regex(error, pattern): __tracebackhide__ = True # noqa: F841 @@ -187,7 +165,10 @@ def source_ndarray(array): """Given an ndarray, return the base object which holds its memory, or the object itself. """ - base = getattr(array, 'base', np.asarray(array).base) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'DatetimeIndex.base') + warnings.filterwarnings('ignore', 'TimedeltaIndex.base') + base = getattr(array, 'base', np.asarray(array).base) if base is None: base = array return base diff --git a/xarray/tests/data/example.grib b/xarray/tests/data/example.grib new file mode 100644 index 00000000000..596a54d98a0 Binary files /dev/null and b/xarray/tests/data/example.grib differ diff --git a/xarray/tests/data/example.ict b/xarray/tests/data/example.ict new file mode 100644 index 00000000000..bc04888fb80 --- /dev/null +++ b/xarray/tests/data/example.ict @@ -0,0 +1,31 @@ +27, 1001 +Henderson, Barron +U.S. EPA +Example file with artificial data +JUST_A_TEST +1, 1 +2018, 04, 27, 2018, 04, 27 +0 +Start_UTC +7 +1, 1, 1, 1, 1 +-9999, -9999, -9999, -9999, -9999 +lat, degrees_north +lon, degrees_east +elev, meters +TEST_ppbv, ppbv +TESTM_ppbv, ppbv +0 +8 +ULOD_FLAG: -7777 +ULOD_VALUE: N/A +LLOD_FLAG: -8888 +LLOD_VALUE: N/A, N/A, N/A, N/A, 0.025 +OTHER_COMMENTS: www-air.larc.nasa.gov/missions/etc/IcarttDataFormat.htm +REVISION: R0 +R0: No comments for this revision. +Start_UTC, lat, lon, elev, TEST_ppbv, TESTM_ppbv +43200, 41.00000, -71.00000, 5, 1.2345, 2.220 +46800, 42.00000, -72.00000, 15, 2.3456, -9999 +50400, 42.00000, -73.00000, 20, 3.4567, -7777 +50400, 42.00000, -74.00000, 25, 4.5678, -8888 \ No newline at end of file diff --git a/xarray/tests/data/example.uamiv b/xarray/tests/data/example.uamiv new file mode 100644 index 00000000000..fcedcd53097 Binary files /dev/null and b/xarray/tests/data/example.uamiv differ diff --git a/xarray/tests/test_accessors.py b/xarray/tests/test_accessors.py index 30ea1a88c7a..38038fc8f65 100644 --- a/xarray/tests/test_accessors.py +++ b/xarray/tests/test_accessors.py @@ -1,17 +1,19 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -import xarray as xr import numpy as np import pandas as pd +import pytest + +import xarray as xr -from . import (TestCase, requires_dask, raises_regex, assert_equal, - assert_array_equal) +from . 
import ( + assert_array_equal, assert_equal, has_cftime, has_cftime_or_netCDF4, + has_dask, raises_regex, requires_dask) -class TestDatetimeAccessor(TestCase): - def setUp(self): +class TestDatetimeAccessor(object): + @pytest.fixture(autouse=True) + def setup(self): nt = 100 data = np.random.rand(10, 10, nt) lons = np.linspace(0, 11, 10) @@ -57,6 +59,9 @@ def test_dask_field_access(self): months = self.times_data.dt.month hours = self.times_data.dt.hour days = self.times_data.dt.day + floor = self.times_data.dt.floor('D') + ceil = self.times_data.dt.ceil('D') + round = self.times_data.dt.round('D') dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) dask_times_2d = xr.DataArray(dask_times_arr, @@ -67,6 +72,9 @@ def test_dask_field_access(self): dask_month = dask_times_2d.dt.month dask_day = dask_times_2d.dt.day dask_hour = dask_times_2d.dt.hour + dask_floor = dask_times_2d.dt.floor('D') + dask_ceil = dask_times_2d.dt.ceil('D') + dask_round = dask_times_2d.dt.round('D') # Test that the data isn't eagerly evaluated assert isinstance(dask_year.data, da.Array) @@ -86,6 +94,9 @@ def test_dask_field_access(self): assert_equal(months, dask_month.compute()) assert_equal(days, dask_day.compute()) assert_equal(hours, dask_hour.compute()) + assert_equal(floor, dask_floor.compute()) + assert_equal(ceil, dask_ceil.compute()) + assert_equal(round, dask_round.compute()) def test_seasons(self): dates = pd.date_range(start="2000/01/01", freq="M", periods=12) @@ -95,3 +106,128 @@ def test_seasons(self): seasons = xr.DataArray(seasons) assert_array_equal(seasons.values, dates.dt.season.values) + + def test_rounders(self): + dates = pd.date_range("2014-01-01", "2014-05-01", freq='H') + xdates = xr.DataArray(np.arange(len(dates)), + dims=['time'], coords=[dates]) + assert_array_equal(dates.floor('D').values, + xdates.time.dt.floor('D').values) + assert_array_equal(dates.ceil('D').values, + xdates.time.dt.ceil('D').values) + assert_array_equal(dates.round('D').values, + xdates.time.dt.round('D').values) + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian'] +_NT = 100 + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return request.param + + +@pytest.fixture() +def times(calendar): + import cftime + + return cftime.num2date( + np.arange(_NT), units='hours since 2000-01-01', calendar=calendar, + only_use_cftime_datetimes=True) + + +@pytest.fixture() +def data(times): + data = np.random.rand(10, 10, _NT) + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + return xr.DataArray(data, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], name='data') + + +@pytest.fixture() +def times_3d(times): + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + times_arr = np.random.choice(times, size=(10, 10, _NT)) + return xr.DataArray(times_arr, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], + name='data') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour']) +def test_field_access(data, field): + result = getattr(data.time.dt, field) + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), + name=field, coords=data.time.coords, dims=data.time.dims) + + assert_equal(result, expected) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('field', ['year', 
'month', 'day', 'hour']) +def test_dask_field_access_1d(data, field): + import dask.array as da + + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), + name=field, dims=['time']) + times = xr.DataArray(data.time.values, dims=['time']).chunk({'time': 50}) + result = getattr(times.dt, field) + assert isinstance(result.data, da.Array) + assert result.chunks == times.chunks + assert_equal(result.compute(), expected) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour']) +def test_dask_field_access(times_3d, data, field): + import dask.array as da + + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(times_3d.values.ravel()), + field).reshape(times_3d.shape), + name=field, coords=times_3d.coords, dims=times_3d.dims) + times_3d = times_3d.chunk({'lon': 5, 'lat': 5, 'time': 50}) + result = getattr(times_3d.dt, field) + assert isinstance(result.data, da.Array) + assert result.chunks == times_3d.chunks + assert_equal(result.compute(), expected) + + +@pytest.fixture() +def cftime_date_type(calendar): + from .test_coding_times import _all_cftime_date_types + + return _all_cftime_date_types()[calendar] + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_seasons(cftime_date_type): + dates = np.array([cftime_date_type(2000, month, 15) + for month in range(1, 13)]) + dates = xr.DataArray(dates) + seasons = ['DJF', 'DJF', 'MAM', 'MAM', 'MAM', 'JJA', + 'JJA', 'JJA', 'SON', 'SON', 'SON', 'DJF'] + seasons = xr.DataArray(seasons) + + assert_array_equal(seasons.values, dates.dt.season.values) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, + reason='cftime or netCDF4 not installed') +def test_dt_accessor_error_netCDF4(cftime_date_type): + da = xr.DataArray( + [cftime_date_type(1, 1, 1), cftime_date_type(2, 1, 1)], + dims=['time']) + if not has_cftime: + with pytest.raises(TypeError): + da.dt.month + else: + da.dt.month diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 73eb49b863b..fb9c43c0165 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1,42 +1,41 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from io import BytesIO +from __future__ import absolute_import, division, print_function + import contextlib import itertools +import math import os.path import pickle import shutil -import tempfile -import unittest import sys +import tempfile import warnings +from io import BytesIO import numpy as np import pandas as pd import pytest import xarray as xr -from xarray import (Dataset, DataArray, open_dataset, open_dataarray, - open_mfdataset, backends, save_mfdataset) +from xarray import ( + DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, + save_mfdataset) from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing -from xarray.core.pycompat import (iteritems, PY2, ExitStack, basestring, - dask_array_type) - -from . 
import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, - requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf, - requires_pynio, requires_pathlib, requires_zarr, - requires_rasterio, has_netCDF4, has_scipy, assert_allclose, - flaky, network, assert_identical, raises_regex, assert_equal, - assert_array_equal) +from xarray.core.pycompat import ( + ExitStack, basestring, dask_array_type, iteritems) +from xarray.core.options import set_options +from xarray.tests import mock +from . import ( + assert_allclose, assert_array_equal, assert_equal, assert_identical, + has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_cftime, + requires_dask, requires_h5netcdf, requires_netCDF4, requires_pathlib, + requires_pseudonetcdf, requires_pydap, requires_pynio, requires_rasterio, + requires_scipy, requires_scipy_or_netCDF4, requires_zarr, requires_cfgrib) from .test_dataset import create_test_data -from xarray.tests import mock - try: import netCDF4 as nc4 except ImportError: @@ -64,6 +63,13 @@ def open_example_dataset(name, *args, **kwargs): *args, **kwargs) +def open_example_mfdataset(names, *args, **kwargs): + return open_mfdataset( + [os.path.join(os.path.dirname(__file__), 'data', name) + for name in names], + *args, **kwargs) + + def create_masked_and_scaled_data(): x = np.array([np.nan, np.nan, 10, 10.1, 10.2], dtype=np.float32) encoding = {'_FillValue': -1, 'add_offset': 10, @@ -99,7 +105,7 @@ def create_boolean_data(): return Dataset({'x': ('t', [True, False, False, True], attributes)}) -class TestCommon(TestCase): +class TestCommon(object): def test_robust_getitem(self): class UnreliableArrayFailure(Exception): @@ -119,19 +125,18 @@ def __getitem__(self, key): array = UnreliableArray([0]) with pytest.raises(UnreliableArrayFailure): array[0] - self.assertEqual(array[0], 0) + assert array[0] == 0 actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, initial_delay=0) - self.assertEqual(actual, 0) + assert actual == 0 class NetCDF3Only(object): pass -class DatasetIOTestCases(object): - autoclose = False +class DatasetIOBase(object): engine = None file_format = None @@ -160,13 +165,12 @@ def roundtrip_append(self, data, save_kwargs={}, open_kwargs={}, # The save/open methods may be overwritten below def save(self, dataset, path, **kwargs): - dataset.to_netcdf(path, engine=self.engine, format=self.file_format, - **kwargs) + return dataset.to_netcdf(path, engine=self.engine, + format=self.file_format, **kwargs) @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine=self.engine, autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine=self.engine, **kwargs) as ds: yield ds def test_zero_dimensional_variable(self): @@ -215,11 +219,11 @@ def assert_loads(vars=None): with self.roundtrip(expected) as actual: for k, v in actual.variables.items(): # IndexVariables are eagerly loaded into memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) yield actual for k, v in actual.variables.items(): if k in vars: - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) with pytest.raises(AssertionError): @@ -245,14 +249,14 @@ def test_dataset_compute(self): # Test Dataset.compute() for k, v in actual.variables.items(): # IndexVariables are eagerly cached - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) computed = actual.compute() for k, v in actual.variables.items(): - 
self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) for v in computed.variables.values(): - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) assert_identical(expected, computed) @@ -336,14 +340,14 @@ def test_roundtrip_string_encoded_characters(self): expected['x'].encoding['dtype'] = 'S1' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'utf-8') + assert actual['x'].encoding['_Encoding'] == 'utf-8' expected['x'].encoding['_Encoding'] = 'ascii' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'ascii') + assert actual['x'].encoding['_Encoding'] == 'ascii' - def test_roundtrip_datetime_data(self): + def test_roundtrip_numpy_datetime_data(self): times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) expected = Dataset({'t': ('t', times), 't0': times[0]}) kwds = {'encoding': {'t0': {'units': 'days since 1950-01-01'}}} @@ -351,6 +355,39 @@ def test_roundtrip_datetime_data(self): assert_identical(expected, actual) assert actual.t0.encoding['units'] == 'days since 1950-01-01' + @requires_cftime + def test_roundtrip_cftime_datetime_data(self): + from .test_coding_times import _all_cftime_date_types + + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + times = [date_type(1, 1, 1), date_type(1, 1, 2)] + expected = Dataset({'t': ('t', times), 't0': times[0]}) + kwds = {'encoding': {'t0': {'units': 'days since 0001-01-01'}}} + expected_decoded_t = np.array(times) + expected_decoded_t0 = np.array([date_type(1, 1, 1)]) + expected_calendar = times[0].calendar + + with warnings.catch_warnings(): + if expected_calendar in {'proleptic_gregorian', 'gregorian'}: + warnings.filterwarnings( + 'ignore', 'Unable to decode time axis') + + with self.roundtrip(expected, save_kwargs=kwds) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t.encoding['units'] == + 'days since 0001-01-01 00:00:00.000000') + assert (actual.t.encoding['calendar'] == + expected_calendar) + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t0.encoding['units'] == + 'days since 0001-01-01') + assert (actual.t.encoding['calendar'] == + expected_calendar) + def test_roundtrip_timedelta_data(self): time_deltas = pd.to_timedelta(['1h', '2h', 'NaT']) expected = Dataset({'td': ('td', time_deltas), 'td0': time_deltas[0]}) @@ -393,10 +430,10 @@ def test_roundtrip_coordinates_with_space(self): def test_roundtrip_boolean_dtype(self): original = create_boolean_data() - self.assertEqual(original['x'].dtype, 'bool') + assert original['x'].dtype == 'bool' with self.roundtrip(original) as actual: assert_identical(original, actual) - self.assertEqual(actual['x'].dtype, 'bool') + assert actual['x'].dtype == 'bool' def test_orthogonal_indexing(self): in_memory = create_test_data() @@ -405,34 +442,86 @@ def test_orthogonal_indexing(self): 'dim3': np.arange(5)} expected = in_memory.isel(**indexers) actual = on_disk.isel(**indexers) + # make sure the array is not yet loaded into memory + assert not actual['var1'].variable._in_memory assert_identical(expected, actual) # do it twice, to make sure we're switched from orthogonal -> numpy # when we cached the values actual = on_disk.isel(**indexers) assert_identical(expected, actual) - 
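# Illustrative sketch (an editorial aside, not a line of the patch): the two
# indexing modes exercised by the indexing tests in this hunk, assuming only
# numpy and xarray; the names `ds`, `outer` and `points` are hypothetical.
import numpy as np
import xarray as xr

ds = xr.Dataset({'v': (('x', 'y'), np.arange(12).reshape(4, 3))})
# Plain lists index each dimension independently ("orthogonal"/outer indexing):
# the result keeps dims ('x', 'y') with shape (2, 2).
outer = ds.isel(x=[0, 2], y=[0, 1])
# DataArray indexers sharing a dimension are applied point-wise ("vectorized"
# indexing): the result has a single new dim ('points',) of length 2.
points = ds.isel(x=xr.DataArray([0, 2], dims='points'),
                 y=xr.DataArray([0, 1], dims='points'))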
def _test_vectorized_indexing(self, vindex_support=True): - # Make sure vectorized_indexing works or at least raises - # NotImplementedError + def test_vectorized_indexing(self): in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: indexers = {'dim1': DataArray([0, 2, 0], dims='a'), 'dim2': DataArray([0, 2, 3], dims='a')} expected = in_memory.isel(**indexers) - if vindex_support: - actual = on_disk.isel(**indexers) - assert_identical(expected, actual) - # do it twice, to make sure we're switched from - # orthogonal -> numpy when we cached the values - actual = on_disk.isel(**indexers) - assert_identical(expected, actual) - else: - with raises_regex(NotImplementedError, 'Vectorized indexing '): - actual = on_disk.isel(**indexers) + actual = on_disk.isel(**indexers) + # make sure the array is not yet loaded into memory + assert not actual['var1'].variable._in_memory + assert_identical(expected, actual.load()) + # do it twice, to make sure we're switched from + # vectorized -> numpy when we cached the values + actual = on_disk.isel(**indexers) + assert_identical(expected, actual) - def test_vectorized_indexing(self): - # This test should be overwritten if vindex is supported - self._test_vectorized_indexing(vindex_support=False) + def multiple_indexing(indexers): + # make sure a sequence of lazy indexings certainly works. + with self.roundtrip(in_memory) as on_disk: + actual = on_disk['var3'] + expected = in_memory['var3'] + for ind in indexers: + actual = actual.isel(**ind) + expected = expected.isel(**ind) + # make sure the array is not yet loaded into memory + assert not actual.variable._in_memory + assert_identical(expected, actual.load()) + + # two-staged vectorized-indexing + indexers = [ + {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), + 'dim3': DataArray([[0, 4], [1, 3], [2, 2]], dims=['a', 'b'])}, + {'a': DataArray([0, 1], dims=['c']), + 'b': DataArray([0, 1], dims=['c'])} + ] + multiple_indexing(indexers) + + # vectorized-slice mixed + indexers = [ + {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), + 'dim3': slice(None, 10)} + ] + multiple_indexing(indexers) + + # vectorized-integer mixed + indexers = [ + {'dim3': 0}, + {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b'])}, + {'a': slice(None, None, 2)} + ] + multiple_indexing(indexers) + + # vectorized-integer mixed + indexers = [ + {'dim3': 0}, + {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b'])}, + {'a': 1, 'b': 0} + ] + multiple_indexing(indexers) + + # with negative step slice. + indexers = [ + {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), + 'dim3': slice(-1, 1, -1)}, + ] + multiple_indexing(indexers) + + # with negative step slice. + indexers = [ + {'dim1': DataArray([[0, 7], [2, 6], [3, 5]], dims=['a', 'b']), + 'dim3': slice(-1, 1, -2)}, + ] + multiple_indexing(indexers) def test_isel_dataarray(self): # Make sure isel works lazily. 
GH:issue:1688 @@ -503,14 +592,13 @@ def test_ondisk_after_print(self): assert not on_disk['var1']._in_memory -class CFEncodedDataTest(DatasetIOTestCases): +class CFEncodedBase(DatasetIOBase): def test_roundtrip_bytes_with_fill_value(self): values = np.array([b'ab', b'cdef', np.nan], dtype=object) encoding = {'_FillValue': b'X', 'dtype': 'S1'} original = Dataset({'x': ('t', values, {}, encoding)}) expected = original.copy(deep=True) - print(original) with self.roundtrip(original) as actual: assert_identical(expected, actual) @@ -534,20 +622,20 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded = create_encoded_unsigned_masked_scaled_data() with self.roundtrip(decoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert (decoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) # make sure roundtrip encoding didn't change the # original dataset. @@ -555,14 +643,14 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded, create_encoded_unsigned_masked_scaled_data()) with self.roundtrip(encoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert decoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert encoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(encoded, actual, decode_bytes=False) def test_roundtrip_mask_and_scale(self): @@ -600,12 +688,11 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds['temp'].attrs['coordinates'])) - self.assertTrue( - equals_latlon(ds['precip'].attrs['coordinates'])) - self.assertNotIn('coordinates', ds.attrs) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds['temp'].attrs['coordinates']) + assert equals_latlon(ds['precip'].attrs['coordinates']) + assert 'coordinates' not in ds.attrs + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' not in ds['lon'].attrs modified = original.drop(['temp', 'precip']) with self.roundtrip(modified) as actual: @@ -613,9 +700,9 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: modified.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds.attrs['coordinates'])) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds.attrs['coordinates']) + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' 
not in ds['lon'].attrs def test_roundtrip_endian(self): ds = Dataset({'x': np.arange(3, 10, dtype='>i2'), @@ -630,7 +717,7 @@ def test_roundtrip_endian(self): # should still pass though. assert_identical(ds, actual) - if type(self) is NetCDF4DataTest: + if self.engine == 'netcdf4': ds['z'].encoding['endian'] = 'big' with pytest.raises(NotImplementedError): with self.roundtrip(ds) as actual: @@ -651,8 +738,8 @@ def test_encoding_kwarg(self): ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} kwargs = dict(encoding={'x': {'foo': 'bar'}}) with raises_regex(ValueError, 'unexpected encoding'): @@ -669,34 +756,49 @@ def test_encoding_kwarg(self): with self.roundtrip(ds, save_kwargs=kwargs) as actual: pass + def test_encoding_kwarg_dates(self): ds = Dataset({'t': pd.date_range('2000-01-01', periods=3)}) units = 'days since 1900-01-01' kwargs = dict(encoding={'t': {'units': units}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.t.encoding['units'], units) + assert actual.t.encoding['units'] == units assert_identical(actual, ds) + def test_encoding_kwarg_fixed_width_string(self): + # regression test for GH2149 + for strings in [ + [b'foo', b'bar', b'baz'], + [u'foo', u'bar', u'baz'], + ]: + ds = Dataset({'x': strings}) + kwargs = dict(encoding={'x': {'dtype': 'S1'}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual['x'].encoding['dtype'] == 'S1' + assert_identical(actual, ds) + def test_default_fill_value(self): # Test default encoding for float: ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['_FillValue'], - np.nan) - self.assertEqual(ds.x.encoding, {}) + assert math.isnan(actual.x.encoding['_FillValue']) + assert ds.x.encoding == {} # Test default encoding for int: ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'int16'}}) - with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', '.*floating point data as an integer') + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} # Test default encoding for implicit int: ds = Dataset({'x': ('y', np.arange(10, dtype='int16'))}) with self.roundtrip(ds) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} def test_explicitly_omit_fill_value(self): ds = Dataset({'x': ('y', [np.pi, -np.pi])}) @@ -704,12 +806,32 @@ def test_explicitly_omit_fill_value(self): with self.roundtrip(ds) as actual: assert '_FillValue' not in actual.x.encoding + def test_explicitly_omit_fill_value_via_encoding_kwarg(self): + ds = Dataset({'x': ('y', [np.pi, -np.pi])}) + kwargs = dict(encoding={'x': {'_FillValue': None}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert '_FillValue' not in actual.x.encoding + assert ds.y.encoding == {} + + def test_explicitly_omit_fill_value_in_coord(self): + ds = Dataset({'x': ('y', 
[np.pi, -np.pi])}, coords={'y': [0.0, 1.0]}) + ds.y.encoding['_FillValue'] = None + with self.roundtrip(ds) as actual: + assert '_FillValue' not in actual.y.encoding + + def test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg(self): + ds = Dataset({'x': ('y', [np.pi, -np.pi])}, coords={'y': [0.0, 1.0]}) + kwargs = dict(encoding={'y': {'_FillValue': None}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert '_FillValue' not in actual.y.encoding + assert ds.y.encoding == {} + def test_encoding_same_dtype(self): ds = Dataset({'x': ('y', np.arange(10.0, dtype='f4'))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} def test_append_write(self): # regression for GH1215 @@ -738,9 +860,6 @@ def test_append_with_invalid_dim_raises(self): 'Unable to update size for existing dimension'): self.save(data, tmp_file, mode='a') - def test_vectorized_indexing(self): - self._test_vectorized_indexing(vindex_support=False) - def test_multiindex_not_implemented(self): ds = (Dataset(coords={'y': ('x', [1, 2]), 'z': ('x', ['a', 'b'])}) .set_index(x=['y', 'z'])) @@ -775,8 +894,8 @@ def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False): yield files -@requires_netCDF4 -class BaseNetCDF4Test(CFEncodedDataTest): +class NetCDF4Base(CFEncodedBase): + """Tests for both netCDF4-python and h5netcdf.""" engine = 'netcdf4' @@ -796,7 +915,7 @@ def test_open_group(self): # check equivalent ways to specify group for group in 'foo', '/foo', 'foo/', '/foo/': - with open_dataset(tmp_file, group=group) as actual: + with self.open(tmp_file, group=group) as actual: assert_equal(actual['x'], expected['x']) # check that missing group raises appropriate exception @@ -806,7 +925,8 @@ def test_open_group(self): open_dataset(tmp_file, group=(1, 2, 3)) def test_open_subgroup(self): - # Create a netCDF file with a dataset within a group within a group + # Create a netCDF file with a dataset stored within a group within a + # group with create_tmp_file() as tmp_file: rootgrp = nc4.Dataset(tmp_file, 'w') foogrp = rootgrp.createGroup('foo') @@ -823,20 +943,32 @@ def test_open_subgroup(self): # check equivalent ways to specify group for group in 'foo/bar', '/foo/bar', 'foo/bar/', '/foo/bar/': - with open_dataset(tmp_file, group=group) as actual: + with self.open(tmp_file, group=group) as actual: assert_equal(actual['x'], expected['x']) def test_write_groups(self): data1 = create_test_data() data2 = data1 * 2 with create_tmp_file() as tmp_file: - data1.to_netcdf(tmp_file, group='data/1') - data2.to_netcdf(tmp_file, group='data/2', mode='a') - with open_dataset(tmp_file, group='data/1') as actual1: + self.save(data1, tmp_file, group='data/1') + self.save(data2, tmp_file, group='data/2', mode='a') + with self.open(tmp_file, group='data/1') as actual1: assert_identical(data1, actual1) - with open_dataset(tmp_file, group='data/2') as actual2: + with self.open(tmp_file, group='data/2') as actual2: assert_identical(data2, actual2) + def test_encoding_kwarg_vlen_string(self): + for input_strings in [ + [b'foo', b'bar', b'baz'], + [u'foo', u'bar', u'baz'], + ]: + original = Dataset({'x': input_strings}) + expected = Dataset({'x': [u'foo', u'bar', u'baz']}) + kwargs = dict(encoding={'x': {'dtype': str}}) + with self.roundtrip(original, save_kwargs=kwargs) as actual: + assert 
actual['x'].encoding['dtype'] is str + assert_identical(actual, expected) + def test_roundtrip_string_with_fill_value_vlen(self): values = np.array([u'ab', u'cdef', np.nan], dtype=object) expected = Dataset({'x': ('t', values)}) @@ -877,7 +1009,7 @@ def test_default_to_char_arrays(self): data = Dataset({'x': np.array(['foo', 'zzzz'], dtype='S')}) with self.roundtrip(data) as actual: assert_identical(data, actual) - self.assertEqual(actual['x'].dtype, np.dtype('S4')) + assert actual['x'].dtype == np.dtype('S4') def test_open_encodings(self): # Create a netCDF file with explicit time units @@ -902,15 +1034,15 @@ def test_open_encodings(self): actual_encoding = dict((k, v) for k, v in iteritems(actual['time'].encoding) if k in expected['time'].encoding) - self.assertDictEqual(actual_encoding, - expected['time'].encoding) + assert actual_encoding == \ + expected['time'].encoding def test_dump_encodings(self): # regression test for #709 ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'zlib': True}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue(actual.x.encoding['zlib']) + assert actual.x.encoding['zlib'] def test_dump_and_open_encodings(self): # Create a netCDF file with explicit time units @@ -928,8 +1060,7 @@ def test_dump_and_open_encodings(self): with create_tmp_file() as tmp_file2: xarray_dataset.to_netcdf(tmp_file2) with nc4.Dataset(tmp_file2, 'r') as ds: - self.assertEqual( - ds.variables['time'].getncattr('units'), units) + assert ds.variables['time'].getncattr('units') == units assert_array_equal( ds.variables['time'], np.arange(10) + 4) @@ -942,13 +1073,30 @@ def test_compression_encoding(self): 'original_shape': data.var2.shape}) with self.roundtrip(data) as actual: for k, v in iteritems(data['var2'].encoding): - self.assertEqual(v, actual['var2'].encoding[k]) + assert v == actual['var2'].encoding[k] # regression test for #156 expected = data.isel(dim1=0) with self.roundtrip(expected) as actual: assert_equal(expected, actual) + def test_encoding_kwarg_compression(self): + ds = Dataset({'x': np.arange(10.0)}) + encoding = dict(dtype='f4', zlib=True, complevel=9, fletcher32=True, + chunksizes=(5,), shuffle=True) + kwargs = dict(encoding=dict(x=encoding)) + + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert_equal(actual, ds) + assert actual.x.encoding['dtype'] == 'f4' + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 + assert actual.x.encoding['fletcher32'] + assert actual.x.encoding['chunksizes'] == (5,) + assert actual.x.encoding['shuffle'] + + assert ds.x.encoding == {} + def test_encoding_chunksizes_unlimited(self): # regression test for GH1225 ds = Dataset({'x': [1, 2, 3], 'y': ('x', [2, 3, 4])}) @@ -1007,12 +1155,12 @@ def test_already_open_dataset(self): v[...] 
= 42 nc = nc4.Dataset(tmp_file, mode='r') - with backends.NetCDF4DataStore(nc, autoclose=False) as store: - with open_dataset(store) as ds: - expected = Dataset({'x': ((), 42)}) - assert_identical(expected, ds) + store = backends.NetCDF4DataStore(nc) + with open_dataset(store) as ds: + expected = Dataset({'x': ((), 42)}) + assert_identical(expected, ds) - def test_variable_len_strings(self): + def test_read_variable_len_strings(self): with create_tmp_file() as tmp_file: values = np.array(['foo', 'bar', 'baz'], dtype=object) @@ -1028,8 +1176,7 @@ def test_variable_len_strings(self): @requires_netCDF4 -class NetCDF4DataTest(BaseNetCDF4Test, TestCase): - autoclose = False +class TestNetCDF4Data(NetCDF4Base): @contextlib.contextmanager def create_store(self): @@ -1046,7 +1193,7 @@ def test_variable_order(self): ds.coords['c'] = 4 with self.roundtrip(ds) as actual: - self.assertEqual(list(ds.variables), list(actual.variables)) + assert list(ds.variables) == list(actual.variables) def test_unsorted_index_raises(self): # should be fixed in netcdf4 v1.2.1 @@ -1065,29 +1212,52 @@ def test_unsorted_index_raises(self): try: ds2.randovar.values except IndexError as err: - self.assertIn('first by calling .load', str(err)) + assert 'first by calling .load' in str(err) def test_88_character_filename_segmentation_fault(self): # should be fixed in netcdf4 v1.3.1 with mock.patch('netCDF4.__version__', '1.2.4'): with warnings.catch_warnings(): - warnings.simplefilter("error") - with raises_regex(Warning, 'segmentation fault'): + message = ('A segmentation fault may occur when the ' + 'file path has exactly 88 characters') + warnings.filterwarnings('error', message) + with pytest.raises(Warning): # Need to construct 88 character filepath xr.Dataset().to_netcdf('a' * (88 - len(os.getcwd()) - 1)) + def test_setncattr_string(self): + list_of_strings = ['list', 'of', 'strings'] + one_element_list_of_strings = ['one element'] + one_string = 'one string' + attrs = {'foo': list_of_strings, + 'bar': one_element_list_of_strings, + 'baz': one_string} + ds = Dataset({'x': ('y', [1, 2, 3], attrs)}, + attrs=attrs) -class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest): - autoclose = True + with self.roundtrip(ds) as actual: + for totest in [actual, actual['x']]: + assert_array_equal(list_of_strings, totest.attrs['foo']) + assert_array_equal(one_element_list_of_strings, + totest.attrs['bar']) + assert one_string == totest.attrs['baz'] + + def test_autoclose_future_warning(self): + data = create_test_data() + with create_tmp_file() as tmp_file: + self.save(data, tmp_file) + with pytest.warns(FutureWarning): + with self.open(tmp_file, autoclose=True) as actual: + assert_identical(data, actual) @requires_netCDF4 @requires_dask -class NetCDF4ViaDaskDataTest(NetCDF4DataTest): +class TestNetCDF4ViaDaskData(TestNetCDF4Data): @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=False): - with NetCDF4DataTest.roundtrip( + with TestNetCDF4Data.roundtrip( self, data, save_kwargs, open_kwargs, allow_cleanup_failure) as ds: yield ds.chunk() @@ -1101,16 +1271,26 @@ def test_dataset_caching(self): # caching behavior differs for dask pass - def test_vectorized_indexing(self): - self._test_vectorized_indexing(vindex_support=True) - - -class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest): - autoclose = True + def test_write_inconsistent_chunks(self): + # Construct two variables with the same dimensions, but different + # chunk sizes. 
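+        # ('chunksizes' in encoding is the on-disk netCDF4/HDF5 chunk shape
+        # for a variable; it is independent of the dask chunks used to build
+        # the array, so each variable can round-trip its own value.)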
+ x = da.zeros((100, 100), dtype='f4', chunks=(50, 100)) + x = DataArray(data=x, dims=('lat', 'lon'), name='x') + x.encoding['chunksizes'] = (50, 100) + x.encoding['original_shape'] = (100, 100) + y = da.ones((100, 100), dtype='f4', chunks=(100, 50)) + y = DataArray(data=y, dims=('lat', 'lon'), name='y') + y.encoding['chunksizes'] = (100, 50) + y.encoding['original_shape'] = (100, 100) + # Put them both into the same dataset + ds = Dataset({'x': x, 'y': y}) + with self.roundtrip(ds) as actual: + assert actual['x'].encoding['chunksizes'] == (50, 100) + assert actual['y'].encoding['chunksizes'] == (100, 50) @requires_zarr -class BaseZarrTest(CFEncodedDataTest): +class ZarrBase(CFEncodedBase): DIMENSION_KEY = '_ARRAY_DIMENSIONS' @@ -1120,7 +1300,7 @@ def create_store(self): yield backends.ZarrStore.open_group(store_target, mode='w') def save(self, dataset, store_target, **kwargs): - dataset.to_zarr(store=store_target, **kwargs) + return dataset.to_zarr(store=store_target, **kwargs) @contextlib.contextmanager def open(self, store_target, **kwargs): @@ -1147,17 +1327,27 @@ def test_auto_chunk(self): original, open_kwargs={'auto_chunk': False}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # there should be no chunks - self.assertEqual(v.chunks, None) + assert v.chunks is None with self.roundtrip( original, open_kwargs={'auto_chunk': True}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # chunk size should be the same as original - self.assertEqual(v.chunks, original[k].chunks) + assert v.chunks == original[k].chunks + + def test_write_uneven_dask_chunks(self): + # regression for GH#2225 + original = create_test_data().chunk({'dim1': 3, 'dim2': 4, 'dim3': 3}) + + with self.roundtrip( + original, open_kwargs={'auto_chunk': True}) as actual: + for k, v in actual.data_vars.items(): + print(k) + assert v.chunks == actual[k].chunks def test_chunk_encoding(self): # These datasets have no dask chunks. All chunking specified in @@ -1167,7 +1357,7 @@ def test_chunk_encoding(self): data['var2'].encoding.update({'chunks': chunks}) with self.roundtrip(data) as actual: - self.assertEqual(chunks, actual['var2'].encoding['chunks']) + assert chunks == actual['var2'].encoding['chunks'] # expect an error with non-integer chunks data['var2'].encoding.update({'chunks': (5, 4.5)}) @@ -1184,7 +1374,7 @@ def test_chunk_encoding_with_dask(self): # zarr automatically gets chunk information from dask chunks ds_chunk4 = ds.chunk({'x': 4}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) + assert (4,) == actual['var1'].encoding['chunks'] # should fail if dask_chunks are irregular... ds_chunk_irreg = ds.chunk({'x': (5, 4, 3)}) @@ -1197,21 +1387,18 @@ def test_chunk_encoding_with_dask(self): # ... 
except if the last chunk is smaller than the first ds_chunk_irreg = ds.chunk({'x': (5, 5, 2)}) with self.roundtrip(ds_chunk_irreg) as actual: - self.assertEqual((5,), actual['var1'].encoding['chunks']) + assert (5,) == actual['var1'].encoding['chunks'] + # re-save Zarr arrays + with self.roundtrip(ds_chunk_irreg) as original: + with self.roundtrip(original) as actual: + assert_identical(original, actual) # - encoding specified - # specify compatible encodings for chunk_enc in 4, (4, ): ds_chunk4['var1'].encoding.update({'chunks': chunk_enc}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) - - # specify incompatible encoding - ds_chunk4['var1'].encoding.update({'chunks': (5, 5)}) - with pytest.raises(ValueError) as e_info: - with self.roundtrip(ds_chunk4) as actual: - pass - assert e_info.match('chunks') + assert (4,) == actual['var1'].encoding['chunks'] # TODO: remove this failure once synchronized overlapping writes are # supported by xarray @@ -1220,9 +1407,6 @@ def test_chunk_encoding_with_dask(self): with self.roundtrip(ds_chunk4) as actual: pass - def test_vectorized_indexing(self): - self._test_vectorized_indexing(vindex_support=True) - def test_hidden_zarr_keys(self): expected = create_test_data() with self.create_store() as store: @@ -1280,8 +1464,10 @@ def test_compressor_encoding(self): import zarr blosc_comp = zarr.Blosc(cname='zstd', clevel=3, shuffle=2) save_kwargs = dict(encoding={'var1': {'compressor': blosc_comp}}) - with self.roundtrip(original, save_kwargs=save_kwargs) as actual: - assert repr(actual.var1.encoding['compressor']) == repr(blosc_comp) + with self.roundtrip(original, save_kwargs=save_kwargs) as ds: + actual = ds['var1'].encoding['compressor'] + # get_config returns a dictionary of compressor attributes + assert actual.get_config() == blosc_comp.get_config() def test_group(self): original = create_test_data() @@ -1290,81 +1476,89 @@ def test_group(self): open_kwargs={'group': group}) as actual: assert_identical(original, actual) - # TODO: implement zarr object encoding and make these tests pass - @pytest.mark.xfail(reason="Zarr object encoding not implemented") - def test_multiindex_not_implemented(self): - super(CFEncodedDataTest, self).test_multiindex_not_implemented() - - @pytest.mark.xfail(reason="Zarr object encoding not implemented") - def test_roundtrip_bytes_with_fill_value(self): - super(CFEncodedDataTest, self).test_roundtrip_bytes_with_fill_value() - - @pytest.mark.xfail(reason="Zarr object encoding not implemented") - def test_roundtrip_object_dtype(self): - super(CFEncodedDataTest, self).test_roundtrip_object_dtype() - - @pytest.mark.xfail(reason="Zarr object encoding not implemented") - def test_roundtrip_string_encoded_characters(self): - super(CFEncodedDataTest, - self).test_roundtrip_string_encoded_characters() + def test_encoding_kwarg_fixed_width_string(self): + # not relevant for zarr, since we don't use EncodedStringCoder + pass # TODO: someone who understands caching should figure out whether caching # makes sense for Zarr backend @pytest.mark.xfail(reason="Zarr caching not implemented") def test_dataset_caching(self): - super(CFEncodedDataTest, self).test_dataset_caching() + super(CFEncodedBase, self).test_dataset_caching() @pytest.mark.xfail(reason="Zarr stores can not be appended to") def test_append_write(self): - super(CFEncodedDataTest, self).test_append_write() + super(CFEncodedBase, self).test_append_write() @pytest.mark.xfail(reason="Zarr stores can not be appended to") def
test_append_overwrite_values(self): - super(CFEncodedDataTest, self).test_append_overwrite_values() + super(CFEncodedBase, self).test_append_overwrite_values() @pytest.mark.xfail(reason="Zarr stores can not be appended to") def test_append_with_invalid_dim_raises(self): - super(CFEncodedDataTest, self).test_append_with_invalid_dim_raises() + super(CFEncodedBase, self).test_append_with_invalid_dim_raises() + + def test_to_zarr_compute_false_roundtrip(self): + from dask.delayed import Delayed + + original = create_test_data().chunk() + + with self.create_zarr_target() as store: + delayed_obj = self.save(original, store, compute=False) + assert isinstance(delayed_obj, Delayed) + delayed_obj.compute() + + with self.open(store) as actual: + assert_identical(original, actual) + + def test_encoding_chunksizes(self): + # regression test for GH2278 + # see also test_encoding_chunksizes_unlimited + nx, ny, nt = 4, 4, 5 + original = xr.Dataset({}, coords={'x': np.arange(nx), + 'y': np.arange(ny), + 't': np.arange(nt)}) + original['v'] = xr.Variable(('x', 'y', 't'), np.zeros((nx, ny, nt))) + original = original.chunk({'t': 1, 'x': 2, 'y': 2}) + + with self.roundtrip(original) as ds1: + assert_equal(ds1, original) + with self.roundtrip(ds1.isel(t=0)) as ds2: + assert_equal(ds2, original.isel(t=0)) @requires_zarr -class ZarrDictStoreTest(BaseZarrTest, TestCase): +class TestZarrDictStore(ZarrBase): @contextlib.contextmanager def create_zarr_target(self): yield {} @requires_zarr -class ZarrDirectoryStoreTest(BaseZarrTest, TestCase): +class TestZarrDirectoryStore(ZarrBase): @contextlib.contextmanager def create_zarr_target(self): with create_tmp_file(suffix='.zarr') as tmp: yield tmp -def test_replace_slices_with_arrays(): - (actual,) = xr.backends.zarr._replace_slices_with_arrays( - key=(slice(None),), shape=(5,)) - np.testing.assert_array_equal(actual, np.arange(5)) - - actual = xr.backends.zarr._replace_slices_with_arrays( - key=(np.arange(5),) * 3, shape=(8, 10, 12)) - expected = np.stack([np.arange(5)] * 3) - np.testing.assert_array_equal(np.stack(actual), expected) +class ScipyWriteBase(CFEncodedBase, NetCDF3Only): - a, b = xr.backends.zarr._replace_slices_with_arrays( - key=(np.arange(5), slice(None)), shape=(8, 10)) - np.testing.assert_array_equal(a, np.arange(5)[:, np.newaxis]) - np.testing.assert_array_equal(b, np.arange(10)[np.newaxis, :]) + def test_append_write(self): + import scipy + if scipy.__version__ == '1.0.1': + pytest.xfail('https://github.com/scipy/scipy/issues/8625') + super(ScipyWriteBase, self).test_append_write() - a, b = xr.backends.zarr._replace_slices_with_arrays( - key=(slice(None), np.arange(5)), shape=(8, 10)) - np.testing.assert_array_equal(a, np.arange(8)[np.newaxis, :]) - np.testing.assert_array_equal(b, np.arange(5)[:, np.newaxis]) + def test_append_overwrite_values(self): + import scipy + if scipy.__version__ == '1.0.1': + pytest.xfail('https://github.com/scipy/scipy/issues/8625') + super(ScipyWriteBase, self).test_append_overwrite_values() @requires_scipy -class ScipyInMemoryDataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestScipyInMemoryData(ScipyWriteBase): engine = 'scipy' @contextlib.contextmanager @@ -1376,21 +1570,16 @@ def test_to_netcdf_explicit_engine(self): # regression test for GH1321 Dataset({'foo': 42}).to_netcdf(engine='scipy') - @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') - def test_bytesio_pickle(self): + def test_bytes_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) - 
fobj = BytesIO(data.to_netcdf()) - with open_dataset(fobj, autoclose=self.autoclose) as ds: + fobj = data.to_netcdf() + with self.open(fobj) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) -class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): - autoclose = True - - @requires_scipy -class ScipyFileObjectTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestScipyFileObject(ScipyWriteBase): engine = 'scipy' @contextlib.contextmanager @@ -1418,7 +1607,7 @@ def test_pickle_dataarray(self): @requires_scipy -class ScipyFilePathTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestScipyFilePath(ScipyWriteBase): engine = 'scipy' @contextlib.contextmanager @@ -1442,7 +1631,7 @@ def test_netcdf3_endianness(self): # regression test for GH416 expected = open_example_dataset('bears.nc', engine='scipy') for var in expected.variables.values(): - self.assertTrue(var.dtype.isnative) + assert var.dtype.isnative @requires_netCDF4 def test_nc4_scipy(self): @@ -1454,12 +1643,8 @@ def test_nc4_scipy(self): open_dataset(tmp_file, engine='scipy') -class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest): - autoclose = True - - @requires_netCDF4 -class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestNetCDF3ViaNetCDF4Data(CFEncodedBase, NetCDF3Only): engine = 'netcdf4' file_format = 'NETCDF3_CLASSIC' @@ -1470,14 +1655,16 @@ def create_store(self): tmp_file, mode='w', format='NETCDF3_CLASSIC') as store: yield store - -class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): - autoclose = True + def test_encoding_kwarg_vlen_string(self): + original = Dataset({'x': [u'foo', u'bar', u'baz']}) + kwargs = dict(encoding={'x': {'dtype': str}}) + with raises_regex(ValueError, 'encoding dtype=str for vlen'): + with self.roundtrip(original, save_kwargs=kwargs): + pass @requires_netCDF4 -class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, - TestCase): +class TestNetCDF4ClassicViaNetCDF4Data(CFEncodedBase, NetCDF3Only): engine = 'netcdf4' file_format = 'NETCDF4_CLASSIC' @@ -1489,13 +1676,8 @@ def create_store(self): yield store -class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue( - NetCDF4ClassicViaNetCDF4DataTest): - autoclose = True - - @requires_scipy_or_netCDF4 -class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestGenericNetCDFData(CFEncodedBase, NetCDF3Only): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed file_format = 'netcdf3_64bit' @@ -1520,6 +1702,7 @@ def test_engine(self): with raises_regex(ValueError, 'can only read'): open_dataset(BytesIO(netcdf_bytes), engine='foobar') + @pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/2050') def test_cross_engine_read_write_netcdf3(self): data = create_test_data() valid_engines = set() @@ -1548,22 +1731,30 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') + assert_equal(ds, actual) + + # Regression test for https://github.com/pydata/xarray/issues/2134 + with self.roundtrip(ds, + save_kwargs=dict(unlimited_dims='y')) as actual: + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: - 
self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) - -class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): - autoclose = True + # Regression test for https://github.com/pydata/xarray/issues/2134 + ds.encoding = {'unlimited_dims': 'y'} + with self.roundtrip(ds) as actual: + assert actual.encoding['unlimited_dims'] == set('y') + assert_equal(ds, actual) @requires_h5netcdf @requires_netCDF4 -class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): +class TestH5NetCDFData(NetCDF4Base): engine = 'h5netcdf' @contextlib.contextmanager @@ -1571,23 +1762,14 @@ def create_store(self): with create_tmp_file() as tmp_file: yield backends.H5NetCDFStore(tmp_file, 'w') - def test_orthogonal_indexing(self): - # simplified version for h5netcdf - in_memory = create_test_data() - with self.roundtrip(in_memory) as on_disk: - indexers = {'dim3': np.arange(5)} - expected = in_memory.isel(**indexers) - actual = on_disk.isel(**indexers) - assert_identical(expected, actual.load()) - - def test_array_type_after_indexing(self): - # h5netcdf does not support multiple list-like indexers - pass - + @pytest.mark.filterwarnings('ignore:complex dtypes are supported by h5py') def test_complex(self): expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))}) - with self.roundtrip(expected) as actual: - assert_equal(expected, actual) + with pytest.warns(FutureWarning): + # TODO: make it possible to write invalid netCDF files from xarray + # without a warning + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) @pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/535') def test_cross_engine_read_write_netcdf4(self): @@ -1616,106 +1798,166 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) + def test_compression_encoding_h5py(self): + ENCODINGS = ( + # h5py style compression with gzip codec will be converted to + # NetCDF4-Python style on round-trip + ({'compression': 'gzip', 'compression_opts': 9}, + {'zlib': True, 'complevel': 9}), + # What can't be expressed in NetCDF4-Python style is + # round-tripped unaltered + ({'compression': 'lzf', 'compression_opts': None}, + {'compression': 'lzf', 'compression_opts': None}), + # If both styles are used together, h5py format takes precedence + ({'compression': 'lzf', 'compression_opts': None, + 'zlib': True, 'complevel': 9}, + {'compression': 'lzf', 'compression_opts': None})) + + for compr_in, compr_out in ENCODINGS: + data = create_test_data() + compr_common = { + 'chunksizes': (5, 5), + 'fletcher32': True, + 'shuffle': True, + 'original_shape': data.var2.shape + } + data['var2'].encoding.update(compr_in) + data['var2'].encoding.update(compr_common) + compr_out.update(compr_common) + with self.roundtrip(data) as actual: + for k, v in compr_out.items(): + assert v == actual['var2'].encoding[k] + + def test_compression_check_encoding_h5py(self): + """When mismatched h5py and NetCDF4-Python encodings are expressed + in to_netcdf(encoding=...), must 
raise ValueError + """ + data = Dataset({'x': ('y', np.arange(10.0))}) + # Compatible encodings are graciously supported + with create_tmp_file() as tmp_file: + data.to_netcdf( + tmp_file, engine='h5netcdf', + encoding={'x': {'compression': 'gzip', 'zlib': True, + 'compression_opts': 6, 'complevel': 6}}) + with open_dataset(tmp_file, engine='h5netcdf') as actual: + assert actual.x.encoding['zlib'] is True + assert actual.x.encoding['complevel'] == 6 + + # Incompatible encodings cause a crash + with create_tmp_file() as tmp_file: + with raises_regex(ValueError, + "'zlib' and 'compression' encodings mismatch"): + data.to_netcdf( + tmp_file, engine='h5netcdf', + encoding={'x': {'compression': 'lzf', 'zlib': True}}) -# tests pending h5netcdf fix -@unittest.skip -class H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest): - autoclose = True - - -class OpenMFDatasetManyFilesTest(TestCase): - def validate_open_mfdataset_autoclose(self, engine, nfiles=10): - randdata = np.random.randn(nfiles) - original = Dataset({'foo': ('x', randdata)}) - # test standard open_mfdataset approach with too many files - with create_tmp_files(nfiles) as tmpfiles: - for readengine in engine: - writeengine = (readengine if readengine != 'pynio' - else 'netcdf4') - # split into multiple sets of temp files - for ii in original.x.values: - subds = original.isel(x=slice(ii, ii + 1)) - subds.to_netcdf(tmpfiles[ii], engine=writeengine) - - # check that calculation on opened datasets works properly - ds = open_mfdataset(tmpfiles, engine=readengine, - autoclose=True) - self.assertAllClose(ds.x.sum().values, - (nfiles * (nfiles - 1)) / 2) - self.assertAllClose(ds.foo.sum().values, np.sum(randdata)) - self.assertAllClose(ds.sum().foo.values, np.sum(randdata)) - ds.close() - - def validate_open_mfdataset_large_num_files(self, engine): - self.validate_open_mfdataset_autoclose(engine, nfiles=2000) + with create_tmp_file() as tmp_file: + with raises_regex( + ValueError, + "'complevel' and 'compression_opts' encodings mismatch"): + data.to_netcdf( + tmp_file, engine='h5netcdf', + encoding={'x': {'compression': 'gzip', + 'compression_opts': 5, 'complevel': 6}}) + + def test_dump_encodings_h5py(self): + # regression test for #709 + ds = Dataset({'x': ('y', np.arange(10.0))}) - @requires_dask - @requires_netCDF4 - def test_1_autoclose_netcdf4(self): - self.validate_open_mfdataset_autoclose(engine=['netcdf4']) + kwargs = {'encoding': {'x': { + 'compression': 'gzip', 'compression_opts': 9}}} + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 - @requires_dask - @requires_scipy - def test_2_autoclose_scipy(self): - self.validate_open_mfdataset_autoclose(engine=['scipy']) + kwargs = {'encoding': {'x': { + 'compression': 'lzf', 'compression_opts': None}}} + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual.x.encoding['compression'] == 'lzf' + assert actual.x.encoding['compression_opts'] is None - @requires_dask - @requires_pynio - def test_3_autoclose_pynio(self): - self.validate_open_mfdataset_autoclose(engine=['pynio']) - # use of autoclose=True with h5netcdf broken because of - # probable h5netcdf error - @requires_dask - @requires_h5netcdf - @pytest.mark.xfail - def test_4_autoclose_h5netcdf(self): - self.validate_open_mfdataset_autoclose(engine=['h5netcdf']) +@pytest.fixture(params=['scipy', 'netcdf4', 'h5netcdf', 'pynio']) +def readengine(request): + return request.param - # These tests below are marked as flaky (and skipped by default) 
because - # they fail sometimes on Travis-CI, for no clear reason. - @requires_dask - @requires_netCDF4 - @flaky - @pytest.mark.slow - def test_1_open_large_num_files_netcdf4(self): - self.validate_open_mfdataset_large_num_files(engine=['netcdf4']) +@pytest.fixture(params=[1, 20]) +def nfiles(request): + return request.param - @requires_dask - @requires_scipy - @flaky - @pytest.mark.slow - def test_2_open_large_num_files_scipy(self): - self.validate_open_mfdataset_large_num_files(engine=['scipy']) - @requires_dask - @requires_pynio - @flaky - @pytest.mark.slow - def test_3_open_large_num_files_pynio(self): - self.validate_open_mfdataset_large_num_files(engine=['pynio']) - - # use of autoclose=True with h5netcdf broken because of - # probable h5netcdf error - @requires_dask - @requires_h5netcdf - @flaky - @pytest.mark.xfail - @pytest.mark.slow - def test_4_open_large_num_files_h5netcdf(self): - self.validate_open_mfdataset_large_num_files(engine=['h5netcdf']) +@pytest.fixture(params=[5, None]) +def file_cache_maxsize(request): + maxsize = request.param + if maxsize is not None: + with set_options(file_cache_maxsize=maxsize): + yield maxsize + else: + yield maxsize + + +@pytest.fixture(params=[True, False]) +def parallel(request): + return request.param + + +@pytest.fixture(params=[None, 5]) +def chunks(request): + return request.param + + +# using pytest.mark.skipif does not work so this a work around +def skip_if_not_engine(engine): + if engine == 'netcdf4': + pytest.importorskip('netCDF4') + elif engine == 'pynio': + pytest.importorskip('Nio') + else: + pytest.importorskip(engine) + + +def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks, + file_cache_maxsize): + + # skip certain combinations + skip_if_not_engine(readengine) + + if not has_dask and parallel: + pytest.skip('parallel requires dask') + + if ON_WINDOWS: + pytest.skip('Skipping on Windows') + + randdata = np.random.randn(nfiles) + original = Dataset({'foo': ('x', randdata)}) + # test standard open_mfdataset approach with too many files + with create_tmp_files(nfiles) as tmpfiles: + writeengine = (readengine if readengine != 'pynio' else 'netcdf4') + # split into multiple sets of temp files + for ii in original.x.values: + subds = original.isel(x=slice(ii, ii + 1)) + subds.to_netcdf(tmpfiles[ii], engine=writeengine) + + # check that calculation on opened datasets works properly + actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel, + chunks=chunks) + + # check that using open_mfdataset returns dask arrays for variables + assert isinstance(actual['foo'].data, dask_array_type) + + assert_identical(original, actual) @requires_scipy_or_netCDF4 -class OpenMFDatasetWithDataVarsAndCoordsKwTest(TestCase): +class TestOpenMFDatasetWithDataVarsAndCoordsKw(object): coord_name = 'lon' var_name = 'v1' @@ -1786,9 +2028,9 @@ def test_common_coord_when_datavars_all(self): var_shape = ds[self.var_name].shape - self.assertEqual(var_shape, coord_shape) - self.assertNotEqual(coord_shape1, coord_shape) - self.assertNotEqual(coord_shape2, coord_shape) + assert var_shape == coord_shape + assert coord_shape1 != coord_shape + assert coord_shape2 != coord_shape def test_common_coord_when_datavars_minimal(self): opt = 'minimal' @@ -1803,9 +2045,9 @@ def test_common_coord_when_datavars_minimal(self): var_shape = ds[self.var_name].shape - self.assertNotEqual(var_shape, coord_shape) - self.assertEqual(coord_shape1, coord_shape) - self.assertEqual(coord_shape2, coord_shape) + assert var_shape != coord_shape + assert 
coord_shape1 == coord_shape + assert coord_shape2 == coord_shape def test_invalid_data_vars_value_should_fail(self): @@ -1823,7 +2065,7 @@ def test_invalid_data_vars_value_should_fail(self): @requires_dask @requires_scipy @requires_netCDF4 -class DaskTest(TestCase, DatasetIOTestCases): +class TestDask(DatasetIOBase): @contextlib.contextmanager def create_store(self): yield Dataset() @@ -1833,23 +2075,42 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=False): yield data.chunk() - # Override methods in DatasetIOTestCases - not applicable to dask + # Override methods in DatasetIOBase - not applicable to dask def test_roundtrip_string_encoded_characters(self): pass def test_roundtrip_coordinates_with_space(self): pass - def test_roundtrip_datetime_data(self): - # Override method in DatasetIOTestCases - remove not applicable + def test_roundtrip_numpy_datetime_data(self): + # Override method in DatasetIOBase - remove not applicable # save_kwds times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) expected = Dataset({'t': ('t', times), 't0': times[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) + def test_roundtrip_cftime_datetime_data(self): + # Override method in DatasetIOBase - remove not applicable + # save_kwds + from .test_coding_times import _all_cftime_date_types + + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + times = [date_type(1, 1, 1), date_type(1, 1, 2)] + expected = Dataset({'t': ('t', times), 't0': times[0]}) + expected_decoded_t = np.array(times) + expected_decoded_t0 = np.array([date_type(1, 1, 1)]) + + with self.roundtrip(expected) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, 's')).all() + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, 's')).all() + def test_write_store(self): - # Override method in DatasetIOTestCases - not applicable to dask + # Override method in DatasetIOBase - not applicable to dask pass def test_dataset_caching(self): @@ -1865,19 +2126,20 @@ def test_open_mfdataset(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, - ((5, 5),)) + with open_mfdataset([tmp1, tmp2]) as actual: + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == \ + ((5, 5),) assert_identical(original, actual) - with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, - autoclose=self.autoclose) as actual: - self.assertEqual(actual.foo.variable.data.chunks, - ((3, 2, 3, 2),)) + with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual: + assert actual.foo.variable.data.chunks == \ + ((3, 2, 3, 2),) with raises_regex(IOError, 'no files to open'): - open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) + open_mfdataset('foo-bar-baz-*.nc') + + with raises_regex(ValueError, 'wild-card'): + open_mfdataset('http://some/remote/uri') @requires_pathlib def test_open_mfdataset_pathlib(self): @@ -1888,8 +2150,7 @@ def test_open_mfdataset_pathlib(self): tmp2 = Path(tmp2) original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: 
assert_identical(original, actual) def test_attrs_mfdataset(self): @@ -1905,7 +2166,7 @@ def test_attrs_mfdataset(self): with open_mfdataset([tmp1, tmp2]) as actual: # presumes that attributes inherited from # first dataset loaded - self.assertEqual(actual.test1, ds1.test1) + assert actual.test1 == ds1.test1 # attributes from ds2 are not retained, e.g., with raises_regex(AttributeError, 'no attribute'): @@ -1920,8 +2181,7 @@ def preprocess(ds): return ds.assign_coords(z=0) expected = preprocess(original) - with open_mfdataset(tmp, preprocess=preprocess, - autoclose=self.autoclose) as actual: + with open_mfdataset(tmp, preprocess=preprocess) as actual: assert_identical(expected, actual) def test_save_mfdataset_roundtrip(self): @@ -1931,8 +2191,7 @@ def test_save_mfdataset_roundtrip(self): with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_save_mfdataset_invalid(self): @@ -1958,15 +2217,14 @@ def test_save_mfdataset_pathlib_roundtrip(self): tmp1 = Path(tmp1) tmp2 = Path(tmp2) save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_open_and_do_math(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: actual = 1.0 * ds assert_allclose(original, actual, decode_bytes=False) @@ -1976,8 +2234,7 @@ def test_open_mfdataset_concat_dim_none(self): data = Dataset({'x': 0}) data.to_netcdf(tmp1) Dataset({'x': np.nan}).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], concat_dim=None, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual: assert_identical(data, actual) def test_open_dataset(self): @@ -1985,15 +2242,28 @@ def test_open_dataset(self): with create_tmp_file() as tmp: original.to_netcdf(tmp) with open_dataset(tmp, chunks={'x': 5}) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, ((5, 5),)) + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == ((5, 5),) assert_identical(original, actual) with open_dataset(tmp, chunks=5) as actual: assert_identical(original, actual) with open_dataset(tmp) as actual: - self.assertIsInstance(actual.foo.variable.data, np.ndarray) + assert isinstance(actual.foo.variable.data, np.ndarray) assert_identical(original, actual) + def test_open_single_dataset(self): + # Test for issue GH #1988. This makes sure that the + # concat_dim is utilized when specified in open_mfdataset(). 
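+        # (Passing a DataArray as concat_dim both names the new dimension and
+        # supplies its coordinate values, so the result gains the 'baz'
+        # coordinate [100] even though only a single file is opened.)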
+ rnddata = np.random.randn(10) + original = Dataset({'foo': ('x', rnddata)}) + dim = DataArray([100], name='baz', dims='baz') + expected = Dataset({'foo': (('baz', 'x'), rnddata[np.newaxis, :])}, + {'baz': [100]}) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + with open_mfdataset([tmp], concat_dim=dim) as actual: + assert_identical(expected, actual) + def test_dask_roundtrip(self): with create_tmp_file() as tmp: data = create_test_data() @@ -2010,38 +2280,46 @@ def test_deterministic_names(self): with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: original_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: repeat_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) for var_name, dask_name in original_names.items(): - self.assertIn(var_name, dask_name) - self.assertEqual(dask_name[:13], 'open_dataset-') - self.assertEqual(original_names, repeat_names) + assert var_name in dask_name + assert dask_name[:13] == 'open_dataset-' + assert original_names == repeat_names def test_dataarray_compute(self): # Test DataArray.compute() on dask backend. - # The test for Dataset.compute() is already in DatasetIOTestCases; + # The test for Dataset.compute() is already in DatasetIOBase; # however dask is the only tested backend which supports DataArrays actual = DataArray([1, 2]).chunk() computed = actual.compute() - self.assertFalse(actual._in_memory) - self.assertTrue(computed._in_memory) + assert not actual._in_memory + assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - def test_vectorized_indexing(self): - self._test_vectorized_indexing(vindex_support=True) - + def test_save_mfdataset_compute_false_roundtrip(self): + from dask.delayed import Delayed -class DaskTestAutocloseTrue(DaskTest): - autoclose = True + original = Dataset({'foo': ('x', np.random.randn(10))}).chunk() + datasets = [original.isel(x=slice(5)), + original.isel(x=slice(5, 10))] + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2: + delayed_obj = save_mfdataset(datasets, [tmp1, tmp2], + engine=self.engine, compute=False) + assert isinstance(delayed_obj, Delayed) + delayed_obj.compute() + with open_mfdataset([tmp1, tmp2]) as actual: + assert_identical(actual, original) @requires_scipy_or_netCDF4 @requires_pydap -class PydapTest(TestCase): +class TestPydap(object): def convert_to_pydap_dataset(self, original): from pydap.model import GridType, BaseType, DatasetType ds = DatasetType('bears', **original.attrs) @@ -2073,8 +2351,8 @@ def test_cmp_local_file(self): assert_equal(actual, expected) # global attributes should be global attributes on the dataset - self.assertNotIn('NC_GLOBAL', actual.attrs) - self.assertIn('history', actual.attrs) + assert 'NC_GLOBAL' not in actual.attrs + assert 'history' in actual.attrs # we don't check attributes exactly with assertDatasetIdentical() # because the test DAP server seems to insert some extra @@ -2082,8 +2360,7 @@ def test_cmp_local_file(self): assert actual.attrs.keys() == expected.attrs.keys() with self.create_datasets() as (actual, expected): - assert_equal( - actual.isel(l=2), expected.isel(l=2)) # noqa: E741 + assert_equal(actual.isel(l=2), expected.isel(l=2)) # noqa with self.create_datasets() as (actual, expected): 
assert_equal(actual.isel(i=0, j=-1), @@ -2093,6 +2370,17 @@ def test_cmp_local_file(self): assert_equal(actual.isel(j=slice(1, 2)), expected.isel(j=slice(1, 2))) + with self.create_datasets() as (actual, expected): + indexers = {'i': [1, 0, 0], 'j': [1, 2, 0, 1]} + assert_equal(actual.isel(**indexers), + expected.isel(**indexers)) + + with self.create_datasets() as (actual, expected): + indexers = {'i': DataArray([0, 1, 0], dims='a'), + 'j': DataArray([0, 2, 1], dims='a')} + assert_equal(actual.isel(**indexers), + expected.isel(**indexers)) + def test_compatible_to_netcdf(self): # make sure it can be saved as a netcdf with self.create_datasets() as (actual, expected): @@ -2111,7 +2399,7 @@ def test_dask(self): @network @requires_scipy_or_netCDF4 @requires_pydap -class PydapOnlineTest(PydapTest): +class TestPydapOnline(TestPydap): @contextlib.contextmanager def create_datasets(self, **kwargs): url = 'http://test.opendap.org/opendap/hyrax/data/nc/bears.nc' @@ -2132,32 +2420,18 @@ def test_session(self): @requires_scipy @requires_pynio -class TestPyNio(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestPyNio(ScipyWriteBase): def test_write_store(self): # pynio is read-only for now pass - def test_orthogonal_indexing(self): - # pynio also does not support list-like indexing - with raises_regex(NotImplementedError, 'Outer indexing'): - super(TestPyNio, self).test_orthogonal_indexing() - - def test_isel_dataarray(self): - with raises_regex(NotImplementedError, 'Outer indexing'): - super(TestPyNio, self).test_isel_dataarray() - - def test_array_type_after_indexing(self): - # pynio also does not support list-like indexing - pass - @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine='pynio', autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine='pynio', **kwargs) as ds: yield ds def save(self, dataset, path, **kwargs): - dataset.to_netcdf(path, engine='scipy', **kwargs) + return dataset.to_netcdf(path, engine='scipy', **kwargs) def test_weakrefs(self): example = Dataset({'foo': ('x', np.arange(5.0))}) @@ -2171,20 +2445,250 @@ def test_weakrefs(self): assert_identical(actual, expected) -class TestPyNioAutocloseTrue(TestPyNio): - autoclose = True +@requires_cfgrib +class TestCfGrib(object): + + def test_read(self): + expected = {'number': 2, 'time': 3, 'isobaricInhPa': 2, 'latitude': 3, + 'longitude': 4} + with open_example_dataset('example.grib', engine='cfgrib') as ds: + assert ds.dims == expected + assert list(ds.data_vars) == ['z', 't'] + assert ds['z'].min() == 12660. + + def test_read_filter_by_keys(self): + kwargs = {'filter_by_keys': {'shortName': 't'}} + expected = {'number': 2, 'time': 3, 'isobaricInhPa': 2, 'latitude': 3, + 'longitude': 4} + with open_example_dataset('example.grib', engine='cfgrib', + backend_kwargs=kwargs) as ds: + assert ds.dims == expected + assert list(ds.data_vars) == ['t'] + assert ds['t'].min() == 231. 
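
The cfgrib checks above reduce to opening a GRIB file with `engine='cfgrib'` and, optionally, forwarding cfgrib options through `backend_kwargs`. A minimal usage sketch, assuming cfgrib is installed and a local GRIB file at `example.grib` (the path and the variable names are illustrative, taken from the test data above):

```python
import xarray as xr

# Decode every GRIB message into a single Dataset.
ds = xr.open_dataset('example.grib', engine='cfgrib')
print(list(ds.data_vars))  # e.g. ['z', 't'] for the sample file above

# Forward cfgrib's filter_by_keys through backend_kwargs to decode
# only the temperature messages, mirroring test_read_filter_by_keys.
ds_t = xr.open_dataset(
    'example.grib', engine='cfgrib',
    backend_kwargs={'filter_by_keys': {'shortName': 't'}})
print(list(ds_t.data_vars))  # e.g. ['t']
```
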
+ + +@requires_pseudonetcdf +@pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000') +class TestPseudoNetCDFFormat(object): + + def open(self, path, **kwargs): + return open_dataset(path, engine='pseudonetcdf', **kwargs) + + @contextlib.contextmanager + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as path: + self.save(data, path, **save_kwargs) + with self.open(path, **open_kwargs) as ds: + yield ds + + def test_ict_format(self): + """ + Open a CAMx file and test data variables + """ + ictfile = open_example_dataset('example.ict', + engine='pseudonetcdf', + backend_kwargs={'format': 'ffi1001'}) + stdattr = { + 'fill_value': -9999.0, + 'missing_value': -9999, + 'scale': 1, + 'llod_flag': -8888, + 'llod_value': 'N/A', + 'ulod_flag': -7777, + 'ulod_value': 'N/A' + } + + def myatts(**attrs): + outattr = stdattr.copy() + outattr.update(attrs) + return outattr + + input = { + 'coords': {}, + 'attrs': { + 'fmt': '1001', 'n_header_lines': 27, + 'PI_NAME': 'Henderson, Barron', + 'ORGANIZATION_NAME': 'U.S. EPA', + 'SOURCE_DESCRIPTION': 'Example file with artificial data', + 'MISSION_NAME': 'JUST_A_TEST', + 'VOLUME_INFO': '1, 1', + 'SDATE': '2018, 04, 27', 'WDATE': '2018, 04, 27', + 'TIME_INTERVAL': '0', + 'INDEPENDENT_VARIABLE': 'Start_UTC', + 'ULOD_FLAG': '-7777', 'ULOD_VALUE': 'N/A', + 'LLOD_FLAG': '-8888', + 'LLOD_VALUE': ('N/A, N/A, N/A, N/A, 0.025'), + 'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/' + + 'IcarttDataFormat.htm'), + 'REVISION': 'R0', + 'R0': 'No comments for this revision.', + 'TFLAG': 'Start_UTC' + }, + 'dims': {'POINTS': 4}, + 'data_vars': { + 'Start_UTC': { + 'data': [43200.0, 46800.0, 50400.0, 50400.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='Start_UTC', + standard_name='Start_UTC', + ) + }, + 'lat': { + 'data': [41.0, 42.0, 42.0, 42.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='degrees_north', + standard_name='lat', + ) + }, + 'lon': { + 'data': [-71.0, -72.0, -73.0, -74.], + 'dims': ('POINTS',), + 'attrs': myatts( + units='degrees_east', + standard_name='lon', + ) + }, + 'elev': { + 'data': [5.0, 15.0, 20.0, 25.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='meters', + standard_name='elev', + ) + }, + 'TEST_ppbv': { + 'data': [1.2345, 2.3456, 3.4567, 4.5678], + 'dims': ('POINTS',), + 'attrs': myatts( + units='ppbv', + standard_name='TEST_ppbv', + ) + }, + 'TESTM_ppbv': { + 'data': [2.22, -9999.0, -7777.0, -8888.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='ppbv', + standard_name='TESTM_ppbv', + llod_value=0.025 + ) + } + } + } + chkfile = Dataset.from_dict(input) + assert_identical(ictfile, chkfile) + + def test_ict_format_write(self): + fmtkw = {'format': 'ffi1001'} + expected = open_example_dataset('example.ict', + engine='pseudonetcdf', + backend_kwargs=fmtkw) + with self.roundtrip(expected, save_kwargs=fmtkw, + open_kwargs={'backend_kwargs': fmtkw}) as actual: + assert_identical(expected, actual) + + def test_uamiv_format_read(self): + """ + Open a CAMx file and test data variables + """ + + camxfile = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + backend_kwargs={'format': 'uamiv'}) + data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) + expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, + dict(units='ppm', long_name='O3'.ljust(16), + var_desc='O3'.ljust(80))) + actual = camxfile.variables['O3'] + assert_allclose(expected, actual) + + data = np.array(['2002-06-03'], 
'datetime64[ns]') + expected = xr.Variable(('TSTEP',), data, + dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes'))) + actual = camxfile.variables['time'] + assert_allclose(expected, actual) + camxfile.close() + + def test_uamiv_format_mfread(self): + """ + Open a CAMx file and test data variables + """ + + camxfile = open_example_mfdataset( + ['example.uamiv', + 'example.uamiv'], + engine='pseudonetcdf', + concat_dim='TSTEP', + backend_kwargs={'format': 'uamiv'}) + + data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5) + data = np.concatenate([data1] * 2, axis=0) + expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, + dict(units='ppm', long_name='O3'.ljust(16), + var_desc='O3'.ljust(80))) + actual = camxfile.variables['O3'] + assert_allclose(expected, actual) + + data1 = np.array(['2002-06-03'], 'datetime64[ns]') + data = np.concatenate([data1] * 2, axis=0) + attrs = dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes')) + expected = xr.Variable(('TSTEP',), data, attrs) + actual = camxfile.variables['time'] + assert_allclose(expected, actual) + camxfile.close() + + def test_uamiv_format_write(self): + fmtkw = {'format': 'uamiv'} + + expected = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + backend_kwargs=fmtkw) + with self.roundtrip(expected, + save_kwargs=fmtkw, + open_kwargs={'backend_kwargs': fmtkw}) as actual: + assert_identical(expected, actual) + + expected.close() + + def save(self, dataset, path, **save_kwargs): + import PseudoNetCDF as pnc + pncf = pnc.PseudoNetCDFFile() + pncf.dimensions = {k: pnc.PseudoNetCDFDimension(pncf, k, v) + for k, v in dataset.dims.items()} + pncf.variables = {k: pnc.PseudoNetCDFVariable(pncf, k, v.dtype.char, + v.dims, + values=v.data[...], + **v.attrs) + for k, v in dataset.variables.items()} + for pk, pv in dataset.attrs.items(): + setattr(pncf, pk, pv) + + pnc.pncwrite(pncf, path, **save_kwargs) @requires_rasterio @contextlib.contextmanager def create_tmp_geotiff(nx=4, ny=3, nz=3, + transform=None, transform_args=[5000, 80000, 1000, 2000.], crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84', - 'proj': 'utm', 'zone': 18}): + 'proj': 'utm', 'zone': 18}, + open_kwargs={}): # yields a temporary geotiff file and a corresponding expected DataArray import rasterio from rasterio.transform import from_origin - with create_tmp_file(suffix='.tif') as tmp_file: + with create_tmp_file(suffix='.tif', + allow_cleanup_failure=ON_WINDOWS) as tmp_file: # allow 2d or 3d shapes if nz == 1: data_shape = ny, nx @@ -2192,15 +2696,19 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3, else: data_shape = nz, ny, nx write_kwargs = {} - data = np.arange(nz*ny*nx, - dtype=rasterio.float32).reshape(*data_shape) - transform = from_origin(*transform_args) + data = np.arange( + nz * ny * nx, + dtype=rasterio.float32).reshape( + *data_shape) + if transform is None: + transform = from_origin(*transform_args) with rasterio.open( tmp_file, 'w', driver='GTiff', height=ny, width=nx, count=nz, crs=crs, transform=transform, - dtype=rasterio.float32) as s: + dtype=rasterio.float32, + **open_kwargs) as s: s.write(data, **write_kwargs) dx, dy = s.res[0], -s.res[1] @@ -2208,15 +2716,15 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3, data = data[np.newaxis, ...] 
if nz == 1 else data expected = DataArray(data, dims=('band', 'y', 'x'), coords={ - 'band': np.arange(nz)+1, - 'y': -np.arange(ny) * d + b + dy/2, - 'x': np.arange(nx) * c + a + dx/2, - }) + 'band': np.arange(nz) + 1, + 'y': -np.arange(ny) * d + b + dy / 2, + 'x': np.arange(nx) * c + a + dx / 2, + }) yield tmp_file, expected @requires_rasterio -class TestRasterio(TestCase): +class TestRasterio(object): @requires_scipy_or_netCDF4 def test_serialization(self): @@ -2236,6 +2744,9 @@ def test_utm(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 + np.testing.assert_array_equal(rioda.attrs['nodatavals'], + [np.NaN, np.NaN, np.NaN]) # Check no parse coords with xr.open_rasterio(tmp_file, parse_coordinates=False) as rioda: @@ -2243,23 +2754,10 @@ def test_utm(self): assert 'y' not in rioda.coords def test_non_rectilinear(self): - import rasterio from rasterio.transform import from_origin - # Create a geotiff file with 2d coordinates - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 4, 3, 3 - data = np.arange(nx*ny*nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(0, 3, 1, 1).rotation(45) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - transform=transform, - dtype=rasterio.float32) as s: - s.write(data) - + with create_tmp_geotiff(transform=from_origin(0, 3, 1, 1).rotation(45), + crs=None) as (tmp_file, _): # Default is to not parse coords with xr.open_rasterio(tmp_file) as rioda: assert 'x' not in rioda.coords @@ -2268,9 +2766,11 @@ def test_non_rectilinear(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 # See if a warning is raised if we force it - with self.assertWarns("transformation isn't rectilinear"): + with pytest.warns(Warning, + match="transformation isn't rectilinear"): with xr.open_rasterio(tmp_file, parse_coordinates=True) as rioda: assert 'x' not in rioda.coords @@ -2278,7 +2778,8 @@ def test_non_rectilinear(self): def test_platecarree(self): with create_tmp_geotiff(8, 10, 1, transform_args=[1, 2, 0.5, 2.], - crs='+proj=latlong') \ + crs='+proj=latlong', + open_kwargs={'nodata': -9765}) \ as (tmp_file, expected): with xr.open_rasterio(tmp_file) as rioda: assert_allclose(rioda, expected) @@ -2286,6 +2787,9 @@ def test_platecarree(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 + np.testing.assert_array_equal(rioda.attrs['nodatavals'], + [-9765.]) def test_notransform(self): # regression test for https://github.com/pydata/xarray/issues/1686 @@ -2301,7 +2805,7 @@ def test_notransform(self): with create_tmp_file(suffix='.tif') as tmp_file: # data nx, ny, nz = 4, 3, 3 - data = np.arange(nx*ny*nz, + data = np.arange(nx * ny * nz, dtype=rasterio.float32).reshape(nz, ny, nx) with rasterio.open( tmp_file, 'w', @@ -2321,6 +2825,7 @@ def test_notransform(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 def test_indexing(self): with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.], 
@@ -2330,17 +2835,70 @@ def test_indexing(self): # tests # assert_allclose checks all data + coordinates assert_allclose(actual, expected) - - # Slicing - ex = expected.isel(x=slice(2, 5), y=slice(5, 7)) - ac = actual.isel(x=slice(2, 5), y=slice(5, 7)) - assert_allclose(ac, ex) - - ex = expected.isel(band=slice(1, 2), x=slice(2, 5), - y=slice(5, 7)) - ac = actual.isel(band=slice(1, 2), x=slice(2, 5), - y=slice(5, 7)) - assert_allclose(ac, ex) + assert not actual.variable._in_memory + + # Basic indexer + ind = {'x': slice(2, 5), 'y': slice(5, 7)} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + ind = {'band': slice(1, 2), 'x': slice(2, 5), 'y': slice(5, 7)} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + ind = {'band': slice(1, 2), 'x': slice(2, 5), 'y': 0} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + # orthogonal indexer + ind = {'band': np.array([2, 1, 0]), + 'x': np.array([1, 0]), 'y': np.array([0, 2])} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + ind = {'band': np.array([2, 1, 0]), + 'x': np.array([1, 0]), 'y': 0} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + ind = {'band': 0, 'x': np.array([0, 0]), 'y': np.array([1, 1, 1])} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + # minus-stepped slice + ind = {'band': np.array([2, 1, 0]), + 'x': slice(-1, None, -1), 'y': 0} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + ind = {'band': np.array([2, 1, 0]), + 'x': 1, 'y': slice(-1, 1, -2)} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + # empty selection + ind = {'band': np.array([2, 1, 0]), + 'x': 1, 'y': slice(2, 2, 1)} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + ind = {'band': slice(0, 0), 'x': 1, 'y': 2} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + # vectorized indexer + ind = {'band': DataArray([2, 1, 0], dims='a'), + 'x': DataArray([1, 0, 0], dims='a'), + 'y': np.array([0, 2])} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + + ind = { + 'band': DataArray([[2, 1, 0], [1, 0, 2]], dims=['a', 'b']), + 'x': DataArray([[1, 0, 0], [0, 1, 0]], dims=['a', 'b']), + 'y': 0} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory # Selecting lists of bands is fine ex = expected.isel(band=[1, 2]) @@ -2350,15 +2908,6 @@ def test_indexing(self): ac = actual.isel(band=[0, 2]) assert_allclose(ac, ex) - # but on x and y only windowed operations are allowed, more - # exotic slicing should raise an error - err_msg = 'not valid on rasterio' - with raises_regex(IndexError, err_msg): - actual.isel(x=[2, 4], y=[1, 3]).values - with raises_regex(IndexError, err_msg): - actual.isel(x=[4, 2]).values - with raises_regex(IndexError, err_msg): - actual.isel(x=slice(5, 2, -1)).values # Integer indexing ex = expected.isel(band=1) ac = actual.isel(band=1) @@ -2396,11 +2945,6 @@ def test_caching(self): # Cache is the default with xr.open_rasterio(tmp_file) as actual: - # Without cache an error is raised - err_msg = 'not valid on rasterio' - with 
raises_regex(IndexError, err_msg): - actual.isel(x=[2, 4]).values - # This should cache everything assert_allclose(actual, expected) @@ -2417,7 +2961,7 @@ def test_chunks(self): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert 'open_rasterio' in actual.data.name # do some arithmetic @@ -2429,6 +2973,14 @@ def test_chunks(self): ex = expected.sel(band=1).mean(dim='x') assert_allclose(ac, ex) + def test_pickle_rasterio(self): + # regression test for https://github.com/pydata/xarray/issues/2121 + with create_tmp_geotiff() as (tmp_file, expected): + with xr.open_rasterio(tmp_file) as rioda: + temp = pickle.dumps(rioda) + with pickle.loads(temp) as actual: + assert_equal(actual, rioda) + def test_ENVI_tags(self): rasterio = pytest.importorskip('rasterio', minversion='1.0a') from rasterio.transform import from_origin @@ -2472,6 +3024,7 @@ def test_ENVI_tags(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 # from ENVI tags assert isinstance(rioda.attrs['description'], basestring) assert isinstance(rioda.attrs['map_info'], basestring) @@ -2489,7 +3042,7 @@ def test_no_mftime(self): with mock.patch('os.path.getmtime', side_effect=OSError): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert_allclose(actual, expected) @network @@ -2502,10 +3055,10 @@ def test_http_url(self): # make sure chunking works with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) -class TestEncodingInvalid(TestCase): +class TestEncodingInvalid(object): def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'foo': 'bar'}) @@ -2514,12 +3067,12 @@ def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'chunking': (2, 1)}) encoding = _extract_nc4_variable_encoding(var) - self.assertEqual({}, encoding) + assert {} == encoding # regression test var = xr.Variable(('x',), [1, 2, 3], {}, {'shuffle': True}) encoding = _extract_nc4_variable_encoding(var, raise_on_invalid=True) - self.assertEqual({'shuffle': True}, encoding) + assert {'shuffle': True} == encoding def test_extract_h5nc_encoding(self): # not supported with h5netcdf (yet) @@ -2534,7 +3087,7 @@ class MiscObject: @requires_netCDF4 -class TestValidateAttrs(TestCase): +class TestValidateAttrs(object): def test_validating_attrs(self): def new_dataset(): return Dataset({'data': ('y', np.arange(10.0))}, @@ -2634,7 +3187,7 @@ def new_dataset_and_coord_attrs(): @requires_scipy_or_netCDF4 -class TestDataArrayToNetCDF(TestCase): +class TestDataArrayToNetCDF(object): def test_dataarray_to_netcdf_no_name(self): original_da = DataArray(np.arange(12).reshape((3, 4))) @@ -2693,3 +3246,12 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): with open_dataarray(tmp) as loaded_da: assert_identical(original_da, loaded_da) + + +@requires_scipy_or_netCDF4 +def test_no_warning_from_dask_effective_get(): + with create_tmp_file() as tmpfile: + with pytest.warns(None) as record: + ds = Dataset() + ds.to_netcdf(tmpfile) + assert len(record) == 0 diff --git 
a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py new file mode 100644 index 00000000000..ed49dd721d2 --- /dev/null +++ b/xarray/tests/test_backends_api.py @@ -0,0 +1,22 @@ + +import pytest + +from xarray.backends.api import _get_default_engine +from . import requires_netCDF4, requires_scipy + + +@requires_netCDF4 +@requires_scipy +def test__get_default_engine(): + engine_remote = _get_default_engine('http://example.org/test.nc', + allow_remote=True) + assert engine_remote == 'netcdf4' + + engine_gz = _get_default_engine('/example.gz') + assert engine_gz == 'scipy' + + with pytest.raises(ValueError): + _get_default_engine('/example.grib') + + engine_default = _get_default_engine('/example') + assert engine_default == 'netcdf4' diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py new file mode 100644 index 00000000000..591c981cd45 --- /dev/null +++ b/xarray/tests/test_backends_file_manager.py @@ -0,0 +1,114 @@ +import pickle +import threading +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.file_manager import CachingFileManager +from xarray.backends.lru_cache import LRUCache + + +@pytest.fixture(params=[1, 2, 3, None]) +def file_cache(request): + maxsize = request.param + if maxsize is None: + yield {} + else: + yield LRUCache(maxsize) + + +def test_file_manager_mock_write(file_cache): + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + lock = mock.MagicMock(spec=threading.Lock()) + + manager = CachingFileManager( + opener, 'filename', lock=lock, cache=file_cache) + f = manager.acquire() + f.write('contents') + manager.close() + + assert not file_cache + opener.assert_called_once_with('filename') + mock_file.write.assert_called_once_with('contents') + mock_file.close.assert_called_once_with() + lock.__enter__.assert_has_calls([mock.call(), mock.call()]) + + +def test_file_manager_write_consecutive(tmpdir, file_cache): + path1 = str(tmpdir.join('testing1.txt')) + path2 = str(tmpdir.join('testing2.txt')) + manager1 = CachingFileManager(open, path1, mode='w', cache=file_cache) + manager2 = CachingFileManager(open, path2, mode='w', cache=file_cache) + f1a = manager1.acquire() + f1a.write('foo') + f1a.flush() + f2 = manager2.acquire() + f2.write('bar') + f2.flush() + f1b = manager1.acquire() + f1b.write('baz') + assert (getattr(file_cache, 'maxsize', float('inf')) > 1) == (f1a is f1b) + manager1.close() + manager2.close() + + with open(path1, 'r') as f: + assert f.read() == 'foobaz' + with open(path2, 'r') as f: + assert f.read() == 'bar' + + +def test_file_manager_write_concurrent(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f1 = manager.acquire() + f2 = manager.acquire() + f3 = manager.acquire() + assert f1 is f2 + assert f2 is f3 + f1.write('foo') + f1.flush() + f2.write('bar') + f2.flush() + f3.write('baz') + f3.flush() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobarbaz' + + +def test_file_manager_write_pickle(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f = manager.acquire() + f.write('foo') + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write('bar') + manager2.close() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +def 
test_file_manager_read(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + + with open(path, 'w') as f: + f.write('foobar') + + manager = CachingFileManager(open, path, cache=file_cache) + f = manager.acquire() + assert f.read() == 'foobar' + manager.close() + + +def test_file_manager_invalid_kwargs(): + with pytest.raises(TypeError): + CachingFileManager(open, 'dummy', mode='w', invalid=True) diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py new file mode 100644 index 00000000000..5f83321802e --- /dev/null +++ b/xarray/tests/test_backends_locks.py @@ -0,0 +1,13 @@ +import threading + +from xarray.backends import locks + + +def test_threaded_lock(): + lock1 = locks._get_threaded_lock('foo') + assert isinstance(lock1, type(threading.Lock())) + lock2 = locks._get_threaded_lock('foo') + assert lock1 is lock2 + + lock3 = locks._get_threaded_lock('bar') + assert lock1 is not lock3 diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py new file mode 100644 index 00000000000..03eb6dcf208 --- /dev/null +++ b/xarray/tests/test_backends_lru_cache.py @@ -0,0 +1,91 @@ +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.lru_cache import LRUCache + + +def test_simple(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + + assert cache['x'] == 1 + assert cache['y'] == 2 + assert len(cache) == 2 + assert dict(cache) == {'x': 1, 'y': 2} + assert list(cache.keys()) == ['x', 'y'] + assert list(cache.items()) == [('x', 1), ('y', 2)] + + cache['z'] = 3 + assert len(cache) == 2 + assert list(cache.items()) == [('y', 2), ('z', 3)] + + +def test_trivial(): + cache = LRUCache(maxsize=0) + cache['x'] = 1 + assert len(cache) == 0 + + +def test_invalid(): + with pytest.raises(TypeError): + LRUCache(maxsize=None) + with pytest.raises(ValueError): + LRUCache(maxsize=-1) + + +def test_update_priority(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + assert list(cache) == ['x', 'y'] + assert 'x' in cache # contains + assert list(cache) == ['y', 'x'] + assert cache['y'] == 2 # getitem + assert list(cache) == ['x', 'y'] + cache['x'] = 3 # setitem + assert list(cache.items()) == [('y', 2), ('x', 3)] + + +def test_del(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + del cache['x'] + assert dict(cache) == {'y': 2} + + +def test_on_evict(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=1, on_evict=on_evict) + cache['x'] = 1 + cache['y'] = 2 + on_evict.assert_called_once_with('x', 1) + + +def test_on_evict_trivial(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=0, on_evict=on_evict) + cache['x'] = 1 + on_evict.assert_called_once_with('x', 1) + + +def test_resize(): + cache = LRUCache(maxsize=2) + assert cache.maxsize == 2 + cache['w'] = 0 + cache['x'] = 1 + cache['y'] = 2 + assert list(cache.items()) == [('x', 1), ('y', 2)] + cache.maxsize = 10 + cache['z'] = 3 + assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)] + cache.maxsize = 1 + assert list(cache.items()) == [('z', 3)] + + with pytest.raises(ValueError): + cache.maxsize = -1 diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py new file mode 100644 index 00000000000..7acd764cab3 --- /dev/null +++ b/xarray/tests/test_cftime_offsets.py @@ -0,0 +1,799 @@ +from itertools import product + +import numpy as np +import pytest + +from xarray import CFTimeIndex +from xarray.coding.cftime_offsets import ( + 
_MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, MonthBegin, + MonthEnd, Second, YearBegin, YearEnd, _days_in_month, cftime_range, + get_date_type, to_cftime_datetime, to_offset) + +cftime = pytest.importorskip('cftime') + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian', 'standard'] + + +def _id_func(param): + """Called on each parameter passed to pytest.mark.parametrize""" + return str(param) + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return request.param + + +@pytest.mark.parametrize( + ('offset', 'expected_n'), + [(BaseCFTimeOffset(), 1), + (YearBegin(), 1), + (YearEnd(), 1), + (BaseCFTimeOffset(n=2), 2), + (YearBegin(n=2), 2), + (YearEnd(n=2), 2)], + ids=_id_func +) +def test_cftime_offset_constructor_valid_n(offset, expected_n): + assert offset.n == expected_n + + +@pytest.mark.parametrize( + ('offset', 'invalid_n'), + [(BaseCFTimeOffset, 1.5), + (YearBegin, 1.5), + (YearEnd, 1.5)], + ids=_id_func +) +def test_cftime_offset_constructor_invalid_n(offset, invalid_n): + with pytest.raises(TypeError): + offset(n=invalid_n) + + +@pytest.mark.parametrize( + ('offset', 'expected_month'), + [(YearBegin(), 1), + (YearEnd(), 12), + (YearBegin(month=5), 5), + (YearEnd(month=5), 5)], + ids=_id_func +) +def test_year_offset_constructor_valid_month(offset, expected_month): + assert offset.month == expected_month + + +@pytest.mark.parametrize( + ('offset', 'invalid_month', 'exception'), + [(YearBegin, 0, ValueError), + (YearEnd, 0, ValueError), + (YearBegin, 13, ValueError,), + (YearEnd, 13, ValueError), + (YearBegin, 1.5, TypeError), + (YearEnd, 1.5, TypeError)], + ids=_id_func +) +def test_year_offset_constructor_invalid_month( + offset, invalid_month, exception): + with pytest.raises(exception): + offset(month=invalid_month) + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), None), + (MonthBegin(), 'MS'), + (YearBegin(), 'AS-JAN')], + ids=_id_func +) +def test_rule_code(offset, expected): + assert offset.rule_code() == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), ''), + (YearBegin(), '')], + ids=_id_func +) +def test_str_and_repr(offset, expected): + assert str(offset) == expected + assert repr(offset) == expected + + +@pytest.mark.parametrize( + 'offset', + [BaseCFTimeOffset(), MonthBegin(), YearBegin()], + ids=_id_func +) +def test_to_offset_offset_input(offset): + assert to_offset(offset) == offset + + +@pytest.mark.parametrize( + ('freq', 'expected'), + [('M', MonthEnd()), + ('2M', MonthEnd(n=2)), + ('MS', MonthBegin()), + ('2MS', MonthBegin(n=2)), + ('D', Day()), + ('2D', Day(n=2)), + ('H', Hour()), + ('2H', Hour(n=2)), + ('T', Minute()), + ('2T', Minute(n=2)), + ('min', Minute()), + ('2min', Minute(n=2)), + ('S', Second()), + ('2S', Second(n=2))], + ids=_id_func +) +def test_to_offset_sub_annual(freq, expected): + assert to_offset(freq) == expected + + +_ANNUAL_OFFSET_TYPES = { + 'A': YearEnd, + 'AS': YearBegin +} + + +@pytest.mark.parametrize(('month_int', 'month_label'), + list(_MONTH_ABBREVIATIONS.items()) + [('', '')]) +@pytest.mark.parametrize('multiple', [None, 2]) +@pytest.mark.parametrize('offset_str', ['AS', 'A']) +def test_to_offset_annual(month_label, month_int, multiple, offset_str): + freq = offset_str + offset_type = _ANNUAL_OFFSET_TYPES[offset_str] + if month_label: + freq = '-'.join([freq, month_label]) + if multiple: + freq = '{}'.format(multiple) + freq + result = to_offset(freq) + + if 
multiple and month_int: + expected = offset_type(n=multiple, month=month_int) + elif multiple: + expected = offset_type(n=multiple) + elif month_int: + expected = offset_type(month=month_int) + else: + expected = offset_type() + assert result == expected + + +@pytest.mark.parametrize('freq', ['Z', '7min2', 'AM', 'M-', 'AS-', '1H1min']) +def test_invalid_to_offset_str(freq): + with pytest.raises(ValueError): + to_offset(freq) + + +@pytest.mark.parametrize( + ('argument', 'expected_date_args'), + [('2000-01-01', (2000, 1, 1)), + ((2000, 1, 1), (2000, 1, 1))], + ids=_id_func +) +def test_to_cftime_datetime(calendar, argument, expected_date_args): + date_type = get_date_type(calendar) + expected = date_type(*expected_date_args) + if isinstance(argument, tuple): + argument = date_type(*argument) + result = to_cftime_datetime(argument, calendar=calendar) + assert result == expected + + +def test_to_cftime_datetime_error_no_calendar(): + with pytest.raises(ValueError): + to_cftime_datetime('2000') + + +def test_to_cftime_datetime_error_type_error(): + with pytest.raises(TypeError): + to_cftime_datetime(1) + + +_EQ_TESTS_A = [ + BaseCFTimeOffset(), YearBegin(), YearEnd(), YearBegin(month=2), + YearEnd(month=2), MonthBegin(), MonthEnd(), Day(), Hour(), Minute(), + Second() +] +_EQ_TESTS_B = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func +) +def test_neq(a, b): + assert a != b + + +_EQ_TESTS_B_COPY = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func +) +def test_eq(a, b): + assert a == b + + +_MUL_TESTS = [ + (BaseCFTimeOffset(), BaseCFTimeOffset(n=3)), + (YearEnd(), YearEnd(n=3)), + (YearBegin(), YearBegin(n=3)), + (MonthEnd(), MonthEnd(n=3)), + (MonthBegin(), MonthBegin(n=3)), + (Day(), Day(n=3)), + (Hour(), Hour(n=3)), + (Minute(), Minute(n=3)), + (Second(), Second(n=3)) +] + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_mul(offset, expected): + assert offset * 3 == expected + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_rmul(offset, expected): + assert 3 * offset == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)), + (YearEnd(), YearEnd(n=-1)), + (YearBegin(), YearBegin(n=-1)), + (MonthEnd(), MonthEnd(n=-1)), + (MonthBegin(), MonthBegin(n=-1)), + (Day(), Day(n=-1)), + (Hour(), Hour(n=-1)), + (Minute(), Minute(n=-1)), + (Second(), Second(n=-1))], + ids=_id_func) +def test_neg(offset, expected): + assert -offset == expected + + +_ADD_TESTS = [ + (Day(n=2), (1, 1, 3)), + (Hour(n=2), (1, 1, 1, 2)), + (Minute(n=2), (1, 1, 1, 0, 2)), + (Second(n=2), (1, 1, 1, 0, 0, 2)) +] + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + _ADD_TESTS, + ids=_id_func +) +def test_add_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = offset + initial + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + _ADD_TESTS, + ids=_id_func 
+) +def test_radd_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = initial + offset + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + [(Day(n=2), (1, 1, 1)), + (Hour(n=2), (1, 1, 2, 22)), + (Minute(n=2), (1, 1, 2, 23, 58)), + (Second(n=2), (1, 1, 2, 23, 59, 58))], + ids=_id_func +) +def test_rsub_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 3) + expected = date_type(*expected_date_args) + result = initial - offset + assert result == expected + + +@pytest.mark.parametrize('offset', _EQ_TESTS_A, ids=_id_func) +def test_sub_error(offset, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + with pytest.raises(TypeError): + offset - initial + + +@pytest.mark.parametrize( + ('a', 'b'), + zip(_EQ_TESTS_A, _EQ_TESTS_B), + ids=_id_func +) +def test_minus_offset(a, b): + result = b - a + expected = a + assert result == expected + + +@pytest.mark.parametrize( + ('a', 'b'), + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) + + [(YearEnd(month=1), YearEnd(month=2))], + ids=_id_func +) +def test_minus_offset_error(a, b): + with pytest.raises(TypeError): + b - a + + +def test_days_in_month_non_december(calendar): + date_type = get_date_type(calendar) + reference = date_type(1, 4, 1) + assert _days_in_month(reference) == 30 + + +def test_days_in_month_december(calendar): + if calendar == '360_day': + expected = 30 + else: + expected = 31 + date_type = get_date_type(calendar) + reference = date_type(1, 12, 5) + assert _days_in_month(reference) == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), MonthBegin(), (1, 2, 1)), + ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)), + ((1, 1, 7), MonthBegin(), (1, 2, 1)), + ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)), + ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)), + ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)), + ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)), + ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)), + ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)), + ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)), + ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_begin( + calendar, initial_date_args, offset, expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), MonthEnd(), (1, 1), ()), + ((1, 1, 1), MonthEnd(n=2), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()), + ((1, 2, 1), MonthEnd(n=14), (2, 3), ()), + ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)), + ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + 
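
The month-end cases in the surrounding parametrized tests all follow one pattern: adding `MonthEnd()` rolls a date to the last day of the target month, and the length of that month depends on the calendar. A minimal sketch of that arithmetic using the same (internal) helpers these tests import, assuming cftime is installed; the `noleap`/`all_leap` values are illustrative and mirror the expected cases in the tables above:

```python
from xarray.coding.cftime_offsets import MonthEnd, get_date_type

date_type = get_date_type('noleap')  # cftime.DatetimeNoLeap

# (1, 1, 1) + MonthEnd() lands on the last day of January.
assert date_type(1, 1, 1) + MonthEnd() == date_type(1, 1, 31)

# In a 'noleap' calendar February always has 28 days ...
assert date_type(1, 1, 1) + MonthEnd(n=2) == date_type(1, 2, 28)

# ... but in an 'all_leap' calendar it has 29, which is why the tests
# compute the expected day of month with _days_in_month(reference).
leap_type = get_date_type('all_leap')
assert leap_type(1, 1, 1) + MonthEnd(n=2) == leap_type(1, 2, 29)
```
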
+ # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1), (), MonthEnd(), (1, 2), ()), + ((1, 1), (), MonthEnd(n=2), (1, 3), ()), + ((1, 3), (), MonthEnd(n=-1), (1, 2), ()), + ((1, 3), (), MonthEnd(n=-2), (1, 1), ()), + ((1, 2), (), MonthEnd(n=14), (2, 4), ()), + ((2, 4), (), MonthEnd(n=-14), (1, 2), ()), + ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)), + ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), YearBegin(), (2, 1, 1)), + ((1, 1, 1), YearBegin(n=2), (3, 1, 1)), + ((1, 1, 1), YearBegin(month=2), (1, 2, 1)), + ((1, 1, 7), YearBegin(n=2), (3, 1, 1)), + ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)), + ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)), + ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_begin(calendar, initial_date_args, offset, + expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), YearEnd(), (1, 12), ()), + ((1, 1, 1), YearEnd(n=2), (2, 12), ()), + ((1, 1, 1), YearEnd(month=1), (1, 1), ()), + ((2, 3, 1), YearEnd(n=-1), (1, 12), ()), + ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 12), (), YearEnd(), (2, 12), ()), + ((1, 12), (), YearEnd(n=2), (3, 12), ()), + ((2, 12), 
(), YearEnd(n=-1), (1, 12), ()), + ((3, 12), (), YearEnd(n=-2), (1, 12), ()), + ((1, 1), (), YearEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)), + ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +# Note for all sub-monthly offsets, pandas always returns True for onOffset +@pytest.mark.parametrize( + ('date_args', 'offset', 'expected'), + [((1, 1, 1), MonthBegin(), True), + ((1, 1, 1, 1), MonthBegin(), True), + ((1, 1, 5), MonthBegin(), False), + ((1, 1, 5), MonthEnd(), False), + ((1, 1, 1), YearBegin(), True), + ((1, 1, 1, 1), YearBegin(), True), + ((1, 1, 5), YearBegin(), False), + ((1, 12, 1), YearEnd(), False), + ((1, 1, 1), Day(), True), + ((1, 1, 1, 1), Day(), True), + ((1, 1, 1), Hour(), True), + ((1, 1, 1), Minute(), True), + ((1, 1, 1), Second(), True)], + ids=_id_func +) +def test_onOffset(calendar, date_args, offset, expected): + date_type = get_date_type(calendar) + date = date_type(*date_args) + result = offset.onOffset(date) + assert result == expected + + +@pytest.mark.parametrize( + ('year_month_args', 'sub_day_args', 'offset'), + [((1, 1), (), MonthEnd()), + ((1, 1), (1,), MonthEnd()), + ((1, 12), (), YearEnd()), + ((1, 1), (), YearEnd(month=1))], + ids=_id_func +) +def test_onOffset_month_or_year_end( + calendar, year_month_args, sub_day_args, offset): + date_type = get_date_type(calendar) + reference_args = year_month_args + (1,) + reference = date_type(*reference_args) + date_args = year_month_args + (_days_in_month(reference),) + sub_day_args + date = date_type(*date_args) + result = offset.onOffset(date) + assert result + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + [(YearBegin(), (1, 3, 1), (2, 1)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (2, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(), (1, 3, 1), (1, 12)), + (YearEnd(n=2), (1, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 4)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 4)), + (MonthEnd(), (1, 3, 2), (1, 3)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (MonthEnd(n=2), (1, 3, 2), (1, 3)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollforward(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif 
isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollforward(initial) + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + [(YearBegin(), (1, 3, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)), + (YearEnd(), (2, 3, 1), (1, 12)), + (YearEnd(n=2), (2, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (YearEnd(month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 3)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthEnd(), (1, 3, 2), (1, 2)), + (MonthEnd(n=2), (1, 3, 2), (1, 2)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollback(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollback(initial) + assert result == expected + + +_CFTIME_RANGE_TESTS = [ + ('0001-01-01', '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', '0001-01-04', None, 'D', 'left', False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3)]), + ('0001-01-01', '0001-01-04', None, 'D', 'right', False, + [(1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, False, + [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, True, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', None, 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + (None, '0001-01-04', 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), (1, 1, 4), None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-30', '0011-02-01', None, '3AS-JUN', None, False, + [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)]), + ('0001-01-04', '0001-01-01', None, 'D', None, False, + []), + ('0010', None, 4, YearBegin(n=-2), None, False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)]), + ('0001-01-01', '0001-01-04', 4, None, None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]) +] + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed', 'normalize', + 'expected_date_args'), + _CFTIME_RANGE_TESTS, ids=_id_func +) +def test_cftime_range( + start, end, periods, freq, closed, normalize, calendar, + expected_date_args): + date_type = 
get_date_type(calendar) + expected_dates = [date_type(*args) for args in expected_date_args] + + if isinstance(start, tuple): + start = date_type(*start) + if isinstance(end, tuple): + end = date_type(*end) + + result = cftime_range( + start=start, end=end, periods=periods, freq=freq, closed=closed, + normalize=normalize, calendar=calendar) + resulting_dates = result.values + + assert isinstance(result, CFTimeIndex) + + if freq is not None: + np.testing.assert_equal(resulting_dates, expected_dates) + else: + # If we create a linear range of dates using cftime.num2date + # we will not get exact round number dates. This is because + # datetime arithmetic in cftime is accurate approximately to + # 1 millisecond (see https://unidata.github.io/cftime/api.html). + deltas = resulting_dates - expected_dates + deltas = np.array([delta.total_seconds() for delta in deltas]) + assert np.max(np.abs(deltas)) < 0.001 + + +def test_cftime_range_name(): + result = cftime_range(start='2000', periods=4, name='foo') + assert result.name == 'foo' + + result = cftime_range(start='2000', periods=4) + assert result.name is None + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed'), + [(None, None, 5, 'A', None), + ('2000', None, None, 'A', None), + (None, '2000', None, 'A', None), + ('2000', '2001', None, None, None), + (None, None, None, None, None), + ('2000', '2001', None, 'A', 'up'), + ('2000', '2001', 5, 'A', None)] +) +def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): + with pytest.raises(ValueError): + cftime_range(start, end, periods, freq, closed=closed) + + +_CALENDAR_SPECIFIC_MONTH_END_TESTS = [ + ('2M', 'noleap', + [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'all_leap', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', '360_day', + [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), + ('2M', 'standard', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'gregorian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'julian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]) +] + + +@pytest.mark.parametrize( + ('freq', 'calendar', 'expected_month_day'), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func +) +def test_calendar_specific_month_end(freq, calendar, expected_month_day): + year = 2000 # Use a leap-year to highlight calendar differences + result = cftime_range( + start='2000-02', end='2001', freq=freq, calendar=calendar).values + date_type = get_date_type(calendar) + expected = [date_type(year, *args) for args in expected_month_day] + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize( + ('calendar', 'start', 'end', 'expected_number_of_days'), + [('noleap', '2000', '2001', 365), + ('all_leap', '2000', '2001', 366), + ('360_day', '2000', '2001', 360), + ('standard', '2000', '2001', 366), + ('gregorian', '2000', '2001', 366), + ('julian', '2000', '2001', 366), + ('noleap', '2001', '2002', 365), + ('all_leap', '2001', '2002', 366), + ('360_day', '2001', '2002', 360), + ('standard', '2001', '2002', 365), + ('gregorian', '2001', '2002', 365), + ('julian', '2001', '2002', 365)] +) +def test_calendar_year_length( + calendar, start, end, expected_number_of_days): + result = cftime_range(start, end, freq='D', closed='left', + calendar=calendar) + assert len(result) == expected_number_of_days diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py new file mode 100644 index 00000000000..5e710827ff8 --- /dev/null 
+++ b/xarray/tests/test_cftimeindex.py @@ -0,0 +1,781 @@ +from __future__ import absolute_import + +from datetime import timedelta + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.coding.cftimeindex import ( + CFTimeIndex, _parse_array_of_cftime_strings, _parse_iso8601_with_reso, + _parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601) +from xarray.tests import assert_array_equal, assert_identical + +from . import has_cftime, has_cftime_or_netCDF4, requires_cftime +from .test_coding_times import (_all_cftime_date_types, _ALL_CALENDARS, + _NON_STANDARD_CALENDARS) + + +def date_dict(year=None, month=None, day=None, + hour=None, minute=None, second=None): + return dict(year=year, month=month, day=day, hour=hour, + minute=minute, second=second) + + +ISO8601_STRING_TESTS = { + 'year': ('1999', date_dict(year='1999')), + 'month': ('199901', date_dict(year='1999', month='01')), + 'month-dash': ('1999-01', date_dict(year='1999', month='01')), + 'day': ('19990101', date_dict(year='1999', month='01', day='01')), + 'day-dash': ('1999-01-01', date_dict(year='1999', month='01', day='01')), + 'hour': ('19990101T12', date_dict( + year='1999', month='01', day='01', hour='12')), + 'hour-dash': ('1999-01-01T12', date_dict( + year='1999', month='01', day='01', hour='12')), + 'minute': ('19990101T1234', date_dict( + year='1999', month='01', day='01', hour='12', minute='34')), + 'minute-dash': ('1999-01-01T12:34', date_dict( + year='1999', month='01', day='01', hour='12', minute='34')), + 'second': ('19990101T123456', date_dict( + year='1999', month='01', day='01', hour='12', minute='34', + second='56')), + 'second-dash': ('1999-01-01T12:34:56', date_dict( + year='1999', month='01', day='01', hour='12', minute='34', + second='56')) +} + + +@pytest.mark.parametrize(('string', 'expected'), + list(ISO8601_STRING_TESTS.values()), + ids=list(ISO8601_STRING_TESTS.keys())) +def test_parse_iso8601(string, expected): + result = parse_iso8601(string) + assert result == expected + + with pytest.raises(ValueError): + parse_iso8601(string + '3') + parse_iso8601(string + '.3') + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian'] + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def date_type(request): + return _all_cftime_date_types()[request.param] + + +@pytest.fixture +def index(date_type): + dates = [date_type(1, 1, 1), date_type(1, 2, 1), + date_type(2, 1, 1), date_type(2, 2, 1)] + return CFTimeIndex(dates) + + +@pytest.fixture +def monotonic_decreasing_index(date_type): + dates = [date_type(2, 2, 1), date_type(2, 1, 1), + date_type(1, 2, 1), date_type(1, 1, 1)] + return CFTimeIndex(dates) + + +@pytest.fixture +def length_one_index(date_type): + dates = [date_type(1, 1, 1)] + return CFTimeIndex(dates) + + +@pytest.fixture +def da(index): + return xr.DataArray([1, 2, 3, 4], coords=[index], + dims=['time']) + + +@pytest.fixture +def series(index): + return pd.Series([1, 2, 3, 4], index=index) + + +@pytest.fixture +def df(index): + return pd.DataFrame([1, 2, 3, 4], index=index) + + +@pytest.fixture +def feb_days(date_type): + import cftime + if date_type is cftime.DatetimeAllLeap: + return 29 + elif date_type is cftime.Datetime360Day: + return 30 + else: + return 28 + + +@pytest.fixture +def dec_days(date_type): + import cftime + if date_type is cftime.Datetime360Day: + return 30 + else: + return 31 + + +@pytest.fixture +def index_with_name(date_type): + dates = [date_type(1, 1, 1), date_type(1, 2, 1), 
+ date_type(2, 1, 1), date_type(2, 2, 1)] + return CFTimeIndex(dates, name='foo') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize( + ('name', 'expected_name'), + [('bar', 'bar'), + (None, 'foo')]) +def test_constructor_with_name(index_with_name, name, expected_name): + result = CFTimeIndex(index_with_name, name=name).name + assert result == expected_name + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_assert_all_valid_date_type(date_type, index): + import cftime + if date_type is cftime.DatetimeNoLeap: + mixed_date_types = np.array( + [date_type(1, 1, 1), + cftime.DatetimeAllLeap(1, 2, 1)]) + else: + mixed_date_types = np.array( + [date_type(1, 1, 1), + cftime.DatetimeNoLeap(1, 2, 1)]) + with pytest.raises(TypeError): + assert_all_valid_date_type(mixed_date_types) + + with pytest.raises(TypeError): + assert_all_valid_date_type(np.array([1, date_type(1, 1, 1)])) + + assert_all_valid_date_type( + np.array([date_type(1, 1, 1), date_type(1, 2, 1)])) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize(('field', 'expected'), [ + ('year', [1, 1, 2, 2]), + ('month', [1, 2, 1, 2]), + ('day', [1, 1, 1, 1]), + ('hour', [0, 0, 0, 0]), + ('minute', [0, 0, 0, 0]), + ('second', [0, 0, 0, 0]), + ('microsecond', [0, 0, 0, 0])]) +def test_cftimeindex_field_accessors(index, field, expected): + result = getattr(index, field) + assert_array_equal(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize(('string', 'date_args', 'reso'), [ + ('1999', (1999, 1, 1), 'year'), + ('199902', (1999, 2, 1), 'month'), + ('19990202', (1999, 2, 2), 'day'), + ('19990202T01', (1999, 2, 2, 1), 'hour'), + ('19990202T0101', (1999, 2, 2, 1, 1), 'minute'), + ('19990202T010156', (1999, 2, 2, 1, 1, 56), 'second')]) +def test_parse_iso8601_with_reso(date_type, string, date_args, reso): + expected_date = date_type(*date_args) + expected_reso = reso + result_date, result_reso = _parse_iso8601_with_reso(date_type, string) + assert result_date == expected_date + assert result_reso == expected_reso + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parse_string_to_bounds_year(date_type, dec_days): + parsed = date_type(2, 2, 10, 6, 2, 8, 1) + expected_start = date_type(2, 1, 1) + expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds( + date_type, 'year', parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parse_string_to_bounds_month_feb(date_type, feb_days): + parsed = date_type(2, 2, 10, 6, 2, 8, 1) + expected_start = date_type(2, 2, 1) + expected_end = date_type(2, 2, feb_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds( + date_type, 'month', parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parse_string_to_bounds_month_dec(date_type, dec_days): + parsed = date_type(2, 12, 1) + expected_start = date_type(2, 12, 1) + expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds( + date_type, 'month', parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not 
installed') +@pytest.mark.parametrize(('reso', 'ex_start_args', 'ex_end_args'), [ + ('day', (2, 2, 10), (2, 2, 10, 23, 59, 59, 999999)), + ('hour', (2, 2, 10, 6), (2, 2, 10, 6, 59, 59, 999999)), + ('minute', (2, 2, 10, 6, 2), (2, 2, 10, 6, 2, 59, 999999)), + ('second', (2, 2, 10, 6, 2, 8), (2, 2, 10, 6, 2, 8, 999999))]) +def test_parsed_string_to_bounds_sub_monthly(date_type, reso, + ex_start_args, ex_end_args): + parsed = date_type(2, 2, 10, 6, 2, 8, 123456) + expected_start = date_type(*ex_start_args) + expected_end = date_type(*ex_end_args) + + result_start, result_end = _parsed_string_to_bounds( + date_type, reso, parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parsed_string_to_bounds_raises(date_type): + with pytest.raises(KeyError): + _parsed_string_to_bounds(date_type, 'a', date_type(1, 1, 1)) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_get_loc(date_type, index): + result = index.get_loc('0001') + expected = [0, 1] + assert_array_equal(result, expected) + + result = index.get_loc(date_type(1, 2, 1)) + expected = 1 + assert result == expected + + result = index.get_loc('0001-02-01') + expected = 1 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('kind', ['loc', 'getitem']) +def test_get_slice_bound(date_type, index, kind): + result = index.get_slice_bound('0001', 'left', kind) + expected = 0 + assert result == expected + + result = index.get_slice_bound('0001', 'right', kind) + expected = 2 + assert result == expected + + result = index.get_slice_bound( + date_type(1, 3, 1), 'left', kind) + expected = 2 + assert result == expected + + result = index.get_slice_bound( + date_type(1, 3, 1), 'right', kind) + expected = 2 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('kind', ['loc', 'getitem']) +def test_get_slice_bound_decreasing_index( + date_type, monotonic_decreasing_index, kind): + result = monotonic_decreasing_index.get_slice_bound('0001', 'left', kind) + expected = 2 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound('0001', 'right', kind) + expected = 4 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound( + date_type(1, 3, 1), 'left', kind) + expected = 2 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound( + date_type(1, 3, 1), 'right', kind) + expected = 2 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('kind', ['loc', 'getitem']) +def test_get_slice_bound_length_one_index( + date_type, length_one_index, kind): + result = length_one_index.get_slice_bound('0001', 'left', kind) + expected = 0 + assert result == expected + + result = length_one_index.get_slice_bound('0001', 'right', kind) + expected = 1 + assert result == expected + + result = length_one_index.get_slice_bound( + date_type(1, 3, 1), 'left', kind) + expected = 1 + assert result == expected + + result = length_one_index.get_slice_bound( + date_type(1, 3, 1), 'right', kind) + expected = 1 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_string_slice_length_one_index(length_one_index): + da = xr.DataArray([1], coords=[length_one_index], dims=['time']) + result = 
da.sel(time=slice('0001', '0001')) + assert_identical(result, da) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_date_type_property(date_type, index): + assert index.date_type is date_type + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains(date_type, index): + assert '0001-01-01' in index + assert '0001' in index + assert '0003' not in index + assert date_type(1, 1, 1) in index + assert date_type(3, 1, 1) not in index + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_groupby(da): + result = da.groupby('time.month').sum('time') + expected = xr.DataArray([4, 6], coords=[[1, 2]], dims=['month']) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_resample_error(da): + with pytest.raises(NotImplementedError, match='to_datetimeindex'): + da.resample(time='Y') + + +SEL_STRING_OR_LIST_TESTS = { + 'string': '0001', + 'string-slice': slice('0001-01-01', '0001-12-30'), + 'bool-list': [True, True, False, False] +} + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_arg', list(SEL_STRING_OR_LIST_TESTS.values()), + ids=list(SEL_STRING_OR_LIST_TESTS.keys())) +def test_sel_string_or_list(da, index, sel_arg): + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + result = da.sel(time=sel_arg) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_sel_date_slice_or_list(da, index, date_type): + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + result = da.sel(time=slice(date_type(1, 1, 1), date_type(1, 12, 30))) + assert_identical(result, expected) + + result = da.sel(time=[date_type(1, 1, 1), date_type(1, 2, 1)]) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_sel_date_scalar(da, date_type, index): + expected = xr.DataArray(1).assign_coords(time=index[0]) + result = da.sel(time=date_type(1, 1, 1)) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'nearest'}, + {'method': 'nearest', 'tolerance': timedelta(days=70)} +]) +def test_sel_date_scalar_nearest(da, date_type, index, sel_kwargs): + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad'}, + {'method': 'pad', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_scalar_pad(da, date_type, index, sel_kwargs): + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'backfill'}, + {'method': 'backfill', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_scalar_backfill(da, date_type, index, 
sel_kwargs): + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad', 'tolerance': timedelta(days=20)}, + {'method': 'backfill', 'tolerance': timedelta(days=20)}, + {'method': 'nearest', 'tolerance': timedelta(days=20)}, +]) +def test_sel_date_scalar_tolerance_raises(da, date_type, sel_kwargs): + with pytest.raises(KeyError): + da.sel(time=date_type(1, 5, 1), **sel_kwargs) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'nearest'}, + {'method': 'nearest', 'tolerance': timedelta(days=70)} +]) +def test_sel_date_list_nearest(da, date_type, index, sel_kwargs): + expected = xr.DataArray( + [2, 2], coords=[[index[1], index[1]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray( + [2, 3], coords=[[index[1], index[2]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 12, 1)], **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray( + [3, 3], coords=[[index[2], index[2]]], dims=['time']) + result = da.sel( + time=[date_type(1, 11, 1), date_type(1, 12, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad'}, + {'method': 'pad', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_list_pad(da, date_type, index, sel_kwargs): + expected = xr.DataArray( + [2, 2], coords=[[index[1], index[1]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'backfill'}, + {'method': 'backfill', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_list_backfill(da, date_type, index, sel_kwargs): + expected = xr.DataArray( + [3, 3], coords=[[index[2], index[2]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad', 'tolerance': timedelta(days=20)}, + {'method': 'backfill', 'tolerance': timedelta(days=20)}, + {'method': 'nearest', 'tolerance': timedelta(days=20)}, +]) +def test_sel_date_list_tolerance_raises(da, date_type, sel_kwargs): + with pytest.raises(KeyError): + da.sel(time=[date_type(1, 2, 1), date_type(1, 5, 1)], **sel_kwargs) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_isel(da, index): + expected = xr.DataArray(1).assign_coords(time=index[0]) + result = da.isel(time=0) + assert_identical(result, expected) + + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + result = da.isel(time=[0, 1]) + assert_identical(result, expected) + + +@pytest.fixture +def scalar_args(date_type): + return [date_type(1, 1, 1)] + + +@pytest.fixture +def range_args(date_type): + return ['0001', 
slice('0001-01-01', '0001-12-30'),
+            slice(None, '0001-12-30'),
+            slice(date_type(1, 1, 1), date_type(1, 12, 30)),
+            slice(None, date_type(1, 12, 30))]
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_indexing_in_series_getitem(series, index, scalar_args, range_args):
+    for arg in scalar_args:
+        assert series[arg] == 1
+
+    expected = pd.Series([1, 2], index=index[:2])
+    for arg in range_args:
+        assert series[arg].equals(expected)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_indexing_in_series_loc(series, index, scalar_args, range_args):
+    for arg in scalar_args:
+        assert series.loc[arg] == 1
+
+    expected = pd.Series([1, 2], index=index[:2])
+    for arg in range_args:
+        assert series.loc[arg].equals(expected)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_indexing_in_series_iloc(series, index):
+    expected = 1
+    assert series.iloc[0] == expected
+
+    expected = pd.Series([1, 2], index=index[:2])
+    assert series.iloc[:2].equals(expected)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args):
+    expected = pd.Series([1], name=index[0])
+    for arg in scalar_args:
+        result = df.loc[arg]
+        assert result.equals(expected)
+
+    expected = pd.DataFrame([1, 2], index=index[:2])
+    for arg in range_args:
+        result = df.loc[arg]
+        assert result.equals(expected)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_indexing_in_dataframe_iloc(df, index):
+    expected = pd.Series([1], name=index[0])
+    result = df.iloc[0]
+    assert result.equals(expected)
+
+    expected = pd.DataFrame([1, 2], index=index[:2])
+    result = df.iloc[:2]
+    assert result.equals(expected)
+
+
+@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed')
+def test_concat_cftimeindex(date_type):
+    da1 = xr.DataArray(
+        [1., 2.], coords=[[date_type(1, 1, 1), date_type(1, 2, 1)]],
+        dims=['time'])
+    da2 = xr.DataArray(
+        [3., 4.], coords=[[date_type(1, 3, 1), date_type(1, 4, 1)]],
+        dims=['time'])
+    da = xr.concat([da1, da2], dim='time')
+
+    if has_cftime:
+        assert isinstance(da.indexes['time'], CFTimeIndex)
+    else:
+        assert isinstance(da.indexes['time'], pd.Index)
+        assert not isinstance(da.indexes['time'], CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_empty_cftimeindex():
+    index = CFTimeIndex([])
+    assert index.date_type is None
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_add(index):
+    date_type = index.date_type
+    expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2),
+                      date_type(2, 1, 2), date_type(2, 2, 2)]
+    expected = CFTimeIndex(expected_dates)
+    result = index + timedelta(days=1)
+    assert result.equals(expected)
+    assert isinstance(result, CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS)
+def test_cftimeindex_add_timedeltaindex(calendar):
+    a = xr.cftime_range('2000', periods=5, calendar=calendar)
+    deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)])
+    result = a + deltas
+    expected = a.shift(2, 'D')
+    assert result.equals(expected)
+    assert isinstance(result, CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_radd(index):
+    date_type = index.date_type
+    expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2),
+ date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = timedelta(days=1) + index + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +def test_timedeltaindex_add_cftimeindex(calendar): + a = xr.cftime_range('2000', periods=5, calendar=calendar) + deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = deltas + a + expected = a.shift(2, 'D') + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_sub(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=2) + result = result - timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +def test_cftimeindex_sub_cftimeindex(calendar): + a = xr.cftime_range('2000', periods=5, calendar=calendar) + b = a.shift(2, 'D') + result = b - a + expected = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + assert result.equals(expected) + assert isinstance(result, pd.TimedeltaIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +def test_cftimeindex_sub_timedeltaindex(calendar): + a = xr.cftime_range('2000', periods=5, calendar=calendar) + deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = a - deltas + expected = a.shift(-2, 'D') + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_rsub(index): + with pytest.raises(TypeError): + timedelta(days=1) - index + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('freq', ['D', timedelta(days=1)]) +def test_cftimeindex_shift(index, freq): + date_type = index.date_type + expected_dates = [date_type(1, 1, 3), date_type(1, 2, 3), + date_type(2, 1, 3), date_type(2, 2, 3)] + expected = CFTimeIndex(expected_dates) + result = index.shift(2, freq) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_n(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift('a', 'D') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_freq(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift(1, 1) + + +@requires_cftime +def test_parse_array_of_cftime_strings(): + from cftime import DatetimeNoLeap + + strings = np.array([['2000-01-01', '2000-01-02'], + ['2000-01-03', '2000-01-04']]) + expected = np.array( + [[DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)], + [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)]]) + + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + # Test scalar array case + strings = np.array('2000-01-01') + expected = np.array(DatetimeNoLeap(2000, 1, 1)) + result = 
_parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +@pytest.mark.parametrize('unsafe', [False, True]) +def test_to_datetimeindex(calendar, unsafe): + index = xr.cftime_range('2000', periods=5, calendar=calendar) + expected = pd.date_range('2000', periods=5) + + if calendar in _NON_STANDARD_CALENDARS and not unsafe: + with pytest.warns(RuntimeWarning, match='non-standard'): + result = index.to_datetimeindex() + else: + result = index.to_datetimeindex() + + assert result.equals(expected) + np.testing.assert_array_equal(result, expected) + assert isinstance(result, pd.DatetimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +def test_to_datetimeindex_out_of_range(calendar): + index = xr.cftime_range('0001', periods=5, calendar=calendar) + with pytest.raises(ValueError, match='0001'): + index.to_datetimeindex() + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', ['all_leap', '360_day']) +def test_to_datetimeindex_feb_29(calendar): + index = xr.cftime_range('2001-02-28', periods=2, calendar=calendar) + with pytest.raises(ValueError, match='29'): + index.to_datetimeindex() diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index a6faea8749b..6300a1957f8 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -1,12 +1,11 @@ import numpy as np - import pytest import xarray as xr -from xarray.core.pycompat import suppress from xarray.coding import variables +from xarray.core.pycompat import suppress -from . import requires_dask, assert_identical +from . import assert_identical, requires_dask with suppress(ImportError): import dask.array as da diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py new file mode 100644 index 00000000000..ca138ca8362 --- /dev/null +++ b/xarray/tests/test_coding_strings.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, division, print_function + +import numpy as np +import pytest + +from xarray import Variable +from xarray.coding import strings +from xarray.core import indexing +from xarray.core.pycompat import bytes_type, suppress, unicode_type + +from . 
import ( + IndexerMaker, assert_array_equal, assert_identical, raises_regex, + requires_dask) + +with suppress(ImportError): + import dask.array as da + + +def test_vlen_dtype(): + dtype = strings.create_vlen_dtype(unicode_type) + assert dtype.metadata['element_type'] == unicode_type + assert strings.is_unicode_dtype(dtype) + assert not strings.is_bytes_dtype(dtype) + assert strings.check_vlen_dtype(dtype) is unicode_type + + dtype = strings.create_vlen_dtype(bytes_type) + assert dtype.metadata['element_type'] == bytes_type + assert not strings.is_unicode_dtype(dtype) + assert strings.is_bytes_dtype(dtype) + assert strings.check_vlen_dtype(dtype) is bytes_type + + assert strings.check_vlen_dtype(np.dtype(object)) is None + + +def test_EncodedStringCoder_decode(): + coder = strings.EncodedStringCoder() + + raw_data = np.array([b'abc', u'ß∂µ∆'.encode('utf-8')]) + raw = Variable(('x',), raw_data, {'_Encoding': 'utf-8'}) + actual = coder.decode(raw) + + expected = Variable( + ('x',), np.array([u'abc', u'ß∂µ∆'], dtype=object)) + assert_identical(actual, expected) + + assert_identical(coder.decode(actual[0]), expected[0]) + + +@requires_dask +def test_EncodedStringCoder_decode_dask(): + coder = strings.EncodedStringCoder() + + raw_data = np.array([b'abc', u'ß∂µ∆'.encode('utf-8')]) + raw = Variable(('x',), raw_data, {'_Encoding': 'utf-8'}).chunk() + actual = coder.decode(raw) + assert isinstance(actual.data, da.Array) + + expected = Variable(('x',), np.array([u'abc', u'ß∂µ∆'], dtype=object)) + assert_identical(actual, expected) + + actual_indexed = coder.decode(actual[0]) + assert isinstance(actual_indexed.data, da.Array) + assert_identical(actual_indexed, expected[0]) + + +def test_EncodedStringCoder_encode(): + dtype = strings.create_vlen_dtype(unicode_type) + raw_data = np.array([u'abc', u'ß∂µ∆'], dtype=dtype) + expected_data = np.array([r.encode('utf-8') for r in raw_data], + dtype=object) + + coder = strings.EncodedStringCoder(allows_unicode=True) + raw = Variable(('x',), raw_data, encoding={'dtype': 'S1'}) + actual = coder.encode(raw) + expected = Variable(('x',), expected_data, attrs={'_Encoding': 'utf-8'}) + assert_identical(actual, expected) + + raw = Variable(('x',), raw_data) + assert_identical(coder.encode(raw), raw) + + coder = strings.EncodedStringCoder(allows_unicode=False) + assert_identical(coder.encode(raw), expected) + + +@pytest.mark.parametrize('original', [ + Variable(('x',), [b'ab', b'cdef']), + Variable((), b'ab'), + Variable(('x',), [b'a', b'b']), + Variable((), b'a'), +]) +def test_CharacterArrayCoder_roundtrip(original): + coder = strings.CharacterArrayCoder() + roundtripped = coder.decode(coder.encode(original)) + assert_identical(original, roundtripped) + + +@pytest.mark.parametrize('data', [ + np.array([b'a', b'bc']), + np.array([b'a', b'bc'], dtype=strings.create_vlen_dtype(bytes_type)), +]) +def test_CharacterArrayCoder_encode(data): + coder = strings.CharacterArrayCoder() + raw = Variable(('x',), data) + actual = coder.encode(raw) + expected = Variable(('x', 'string2'), + np.array([[b'a', b''], [b'b', b'c']])) + assert_identical(actual, expected) + + +def test_StackedBytesArray(): + array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']], dtype='S') + actual = strings.StackedBytesArray(array) + expected = np.array([b'abc', b'def'], dtype='S') + assert actual.dtype == expected.dtype + assert actual.shape == expected.shape + assert actual.size == expected.size + assert actual.ndim == expected.ndim + assert len(actual) == len(expected) + 
assert_array_equal(expected, actual) + + B = IndexerMaker(indexing.BasicIndexer) + assert_array_equal(expected[:1], actual[B[:1]]) + with pytest.raises(IndexError): + actual[B[:, :2]] + + +def test_StackedBytesArray_scalar(): + array = np.array([b'a', b'b', b'c'], dtype='S') + actual = strings.StackedBytesArray(array) + + expected = np.array(b'abc') + assert actual.dtype == expected.dtype + assert actual.shape == expected.shape + assert actual.size == expected.size + assert actual.ndim == expected.ndim + with pytest.raises(TypeError): + len(actual) + np.testing.assert_array_equal(expected, actual) + + B = IndexerMaker(indexing.BasicIndexer) + with pytest.raises(IndexError): + actual[B[:2]] + + +def test_StackedBytesArray_vectorized_indexing(): + array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']], dtype='S') + stacked = strings.StackedBytesArray(array) + expected = np.array([[b'abc', b'def'], [b'def', b'abc']]) + + V = IndexerMaker(indexing.VectorizedIndexer) + indexer = V[np.array([[0, 1], [1, 0]])] + actual = stacked[indexer] + assert_array_equal(actual, expected) + + +def test_char_to_bytes(): + array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']]) + expected = np.array([b'abc', b'def']) + actual = strings.char_to_bytes(array) + assert_array_equal(actual, expected) + + expected = np.array([b'ad', b'be', b'cf']) + actual = strings.char_to_bytes(array.T) # non-contiguous + assert_array_equal(actual, expected) + + +def test_char_to_bytes_ndim_zero(): + expected = np.array(b'a') + actual = strings.char_to_bytes(expected) + assert_array_equal(actual, expected) + + +def test_char_to_bytes_size_zero(): + array = np.zeros((3, 0), dtype='S1') + expected = np.array([b'', b'', b'']) + actual = strings.char_to_bytes(array) + assert_array_equal(actual, expected) + + +@requires_dask +def test_char_to_bytes_dask(): + numpy_array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']]) + array = da.from_array(numpy_array, ((2,), (3,))) + expected = np.array([b'abc', b'def']) + actual = strings.char_to_bytes(array) + assert isinstance(actual, da.Array) + assert actual.chunks == ((2,),) + assert actual.dtype == 'S3' + assert_array_equal(np.array(actual), expected) + + with raises_regex(ValueError, 'stacked dask character array'): + strings.char_to_bytes(array.rechunk(1)) + + +def test_bytes_to_char(): + array = np.array([[b'ab', b'cd'], [b'ef', b'gh']]) + expected = np.array([[[b'a', b'b'], [b'c', b'd']], + [[b'e', b'f'], [b'g', b'h']]]) + actual = strings.bytes_to_char(array) + assert_array_equal(actual, expected) + + expected = np.array([[[b'a', b'b'], [b'e', b'f']], + [[b'c', b'd'], [b'g', b'h']]]) + actual = strings.bytes_to_char(array.T) # non-contiguous + assert_array_equal(actual, expected) + + +@requires_dask +def test_bytes_to_char_dask(): + numpy_array = np.array([b'ab', b'cd']) + array = da.from_array(numpy_array, ((1, 1),)) + expected = np.array([[b'a', b'b'], [b'c', b'd']]) + actual = strings.bytes_to_char(array) + assert isinstance(actual, da.Array) + assert actual.chunks == ((1, 1), ((2,))) + assert actual.dtype == 'S1' + assert_array_equal(np.array(actual), expected) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index f8ac3d3b58b..0ca57f98a6d 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1,323 +1,703 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import warnings +from 
itertools import product import numpy as np import pandas as pd - -from xarray import Variable, coding -from . import ( - TestCase, requires_netCDF4, assert_array_equal) import pytest +from xarray import DataArray, Variable, coding, decode_cf +from xarray.coding.times import (_import_cftime, cftime_to_nptime, + decode_cf_datetime, encode_cf_datetime) +from xarray.core.common import contains_cftime_datetimes -@np.vectorize -def _ensure_naive_tz(dt): - if hasattr(dt, 'tzinfo'): - return dt.replace(tzinfo=None) +from . import ( + assert_array_equal, has_cftime, has_cftime_or_netCDF4, has_dask, + requires_cftime_or_netCDF4) + +_NON_STANDARD_CALENDARS_SET = {'noleap', '365_day', '360_day', + 'julian', 'all_leap', '366_day'} +_ALL_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET.union( + coding.times._STANDARD_CALENDARS)) +_NON_STANDARD_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET) +_STANDARD_CALENDARS = sorted(coding.times._STANDARD_CALENDARS) +_CF_DATETIME_NUM_DATES_UNITS = [ + (np.arange(10), 'days since 2000-01-01'), + (np.arange(10).astype('float64'), 'days since 2000-01-01'), + (np.arange(10).astype('float32'), 'days since 2000-01-01'), + (np.arange(10).reshape(2, 5), 'days since 2000-01-01'), + (12300 + np.arange(5), 'hours since 1680-01-01 00:00:00'), + # here we add a couple minor formatting errors to test + # the robustness of the parsing algorithm. + (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00'), + (12300 + np.arange(5), u'Hour since 1680-01-01 00:00:00'), + (12300 + np.arange(5), ' Hour since 1680-01-01 00:00:00 '), + (10, 'days since 2000-01-01'), + ([10], 'daYs since 2000-01-01'), + ([[10]], 'days since 2000-01-01'), + ([10, 10], 'days since 2000-01-01'), + (np.array(10), 'days since 2000-01-01'), + (0, 'days since 1000-01-01'), + ([0], 'days since 1000-01-01'), + ([[0]], 'days since 1000-01-01'), + (np.arange(2), 'days since 1000-01-01'), + (np.arange(0, 100000, 20000), 'days since 1900-01-01'), + (17093352.0, 'hours since 1-1-1 00:00:0.0'), + ([0.5, 1.5], 'hours since 1900-01-01T00:00:00'), + (0, 'milliseconds since 2000-01-01T00:00:00'), + (0, 'microseconds since 2000-01-01T00:00:00'), + (np.int32(788961600), 'seconds since 1981-01-01'), # GH2002 + (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00.500000') +] +_CF_DATETIME_TESTS = [num_dates_units + (calendar,) for num_dates_units, + calendar in product(_CF_DATETIME_NUM_DATES_UNITS, + _STANDARD_CALENDARS)] + + +def _all_cftime_date_types(): + try: + import cftime + except ImportError: + import netcdftime as cftime + return {'noleap': cftime.DatetimeNoLeap, + '365_day': cftime.DatetimeNoLeap, + '360_day': cftime.Datetime360Day, + 'julian': cftime.DatetimeJulian, + 'all_leap': cftime.DatetimeAllLeap, + '366_day': cftime.DatetimeAllLeap, + 'gregorian': cftime.DatetimeGregorian, + 'proleptic_gregorian': cftime.DatetimeProlepticGregorian} + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize(['num_dates', 'units', 'calendar'], + _CF_DATETIME_TESTS) +def test_cf_datetime(num_dates, units, calendar): + cftime = _import_cftime() + if cftime.__name__ == 'cftime': + expected = cftime.num2date(num_dates, units, calendar, + only_use_cftime_datetimes=True) else: - return dt - - -class TestDatetime(TestCase): - @requires_netCDF4 - def test_cf_datetime(self): - import netCDF4 as nc4 - for num_dates, units in [ - (np.arange(10), 'days since 2000-01-01'), - (np.arange(10).reshape(2, 5), 'days since 2000-01-01'), - (12300 + np.arange(5), 'hours since 1680-01-01 00:00:00'), - # 
here we add a couple minor formatting errors to test - # the robustness of the parsing algorithm. - (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00'), - (12300 + np.arange(5), u'Hour since 1680-01-01 00:00:00'), - (12300 + np.arange(5), ' Hour since 1680-01-01 00:00:00 '), - (10, 'days since 2000-01-01'), - ([10], 'daYs since 2000-01-01'), - ([[10]], 'days since 2000-01-01'), - ([10, 10], 'days since 2000-01-01'), - (np.array(10), 'days since 2000-01-01'), - (0, 'days since 1000-01-01'), - ([0], 'days since 1000-01-01'), - ([[0]], 'days since 1000-01-01'), - (np.arange(2), 'days since 1000-01-01'), - (np.arange(0, 100000, 20000), 'days since 1900-01-01'), - (17093352.0, 'hours since 1-1-1 00:00:0.0'), - ([0.5, 1.5], 'hours since 1900-01-01T00:00:00'), - (0, 'milliseconds since 2000-01-01T00:00:00'), - (0, 'microseconds since 2000-01-01T00:00:00'), - ]: - for calendar in ['standard', 'gregorian', 'proleptic_gregorian']: - expected = _ensure_naive_tz( - nc4.num2date(num_dates, units, calendar)) - print(num_dates, units, calendar) - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(num_dates, units, - calendar) - if (isinstance(actual, np.ndarray) and - np.issubdtype(actual.dtype, np.datetime64)): - # self.assertEqual(actual.dtype.kind, 'M') - # For some reason, numpy 1.8 does not compare ns precision - # datetime64 arrays as equal to arrays of datetime objects, - # but it works for us precision. Thus, convert to us - # precision for the actual array equal comparison... - actual_cmp = actual.astype('M8[us]') - else: - actual_cmp = actual - assert_array_equal(expected, actual_cmp) - encoded, _, _ = coding.times.encode_cf_datetime(actual, units, - calendar) - if '1-1-1' not in units: - # pandas parses this date very strangely, so the original - # units/encoding cannot be preserved in this case: - # (Pdb) pd.to_datetime('1-1-1 00:00:0.0') - # Timestamp('2001-01-01 00:00:00') - assert_array_equal(num_dates, np.around(encoded, 1)) - if (hasattr(num_dates, 'ndim') and num_dates.ndim == 1 and - '1000' not in units): - # verify that wrapping with a pandas.Index works - # note that it *does not* currently work to even put - # non-datetime64 compatible dates into a pandas.Index - encoded, _, _ = coding.times.encode_cf_datetime( - pd.Index(actual), units, calendar) - assert_array_equal(num_dates, np.around(encoded, 1)) - - @requires_netCDF4 - def test_decode_cf_datetime_overflow(self): - # checks for - # https://github.com/pydata/pandas/issues/14068 - # https://github.com/pydata/xarray/issues/975 - - from datetime import datetime - units = 'days since 2000-01-01 00:00:00' - - # date after 2262 and before 1678 - days = (-117608, 95795) - expected = (datetime(1677, 12, 31), datetime(2262, 4, 12)) - - for i, day in enumerate(days): - result = coding.times.decode_cf_datetime(day, units) - assert result == expected[i] - - def test_decode_cf_datetime_non_standard_units(self): - expected = pd.date_range(periods=100, start='1970-01-01', freq='h') - # netCDFs from madis.noaa.gov use this format for their time units - # they cannot be parsed by netcdftime, but pd.Timestamp works - units = 'hours since 1-1-1970' - actual = coding.times.decode_cf_datetime(np.arange(100), units) - assert_array_equal(actual, expected) - - @requires_netCDF4 - def test_decode_cf_datetime_non_iso_strings(self): - # datetime strings that are _almost_ ISO compliant but not quite, - # but which 
netCDF4.num2date can still parse correctly - expected = pd.date_range(periods=100, start='2000-01-01', freq='h') - cases = [(np.arange(100), 'hours since 2000-01-01 0'), - (np.arange(100), 'hours since 2000-1-1 0'), - (np.arange(100), 'hours since 2000-01-01 0:00')] - for num_dates, units in cases: - actual = coding.times.decode_cf_datetime(num_dates, units) - assert_array_equal(actual, expected) - - @requires_netCDF4 - def test_decode_non_standard_calendar(self): - import netCDF4 as nc4 - - for calendar in ['noleap', '365_day', '360_day', 'julian', 'all_leap', - '366_day']: - units = 'days since 0001-01-01' - times = pd.date_range('2001-04-01-00', end='2001-04-30-23', - freq='H') - noleap_time = nc4.date2num(times.to_pydatetime(), units, - calendar=calendar) - expected = times.values - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(noleap_time, units, - calendar=calendar) - assert actual.dtype == np.dtype('M8[ns]') - abs_diff = abs(actual - expected) - # once we no longer support versions of netCDF4 older than 1.1.5, - # we could do this check with near microsecond accuracy: - # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff <= np.timedelta64(1, 's')).all() - - @requires_netCDF4 - def test_decode_non_standard_calendar_single_element(self): - units = 'days since 0001-01-01' - for calendar in ['noleap', '365_day', '360_day', 'julian', 'all_leap', - '366_day']: - for num_time in [735368, [735368], [[735368]]]: - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(num_time, units, - calendar=calendar) - assert actual.dtype == np.dtype('M8[ns]') - - @requires_netCDF4 - def test_decode_non_standard_calendar_single_element_fallback(self): - import netCDF4 as nc4 - - units = 'days since 0001-01-01' - dt = nc4.netcdftime.datetime(2001, 2, 29) - for calendar in ['360_day', 'all_leap', '366_day']: - num_time = nc4.date2num(dt, units, calendar) - with pytest.warns(Warning, match='Unable to decode time axis'): - actual = coding.times.decode_cf_datetime(num_time, units, - calendar=calendar) - expected = np.asarray(nc4.num2date(num_time, units, calendar)) - print(num_time, calendar, actual, expected) - assert actual.dtype == np.dtype('O') - assert expected == actual - - @requires_netCDF4 - def test_decode_non_standard_calendar_multidim_time(self): - import netCDF4 as nc4 - - calendar = 'noleap' - units = 'days since 0001-01-01' - times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D') - times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') - noleap_time1 = nc4.date2num(times1.to_pydatetime(), units, - calendar=calendar) - noleap_time2 = nc4.date2num(times2.to_pydatetime(), units, - calendar=calendar) - mdim_time = np.empty((len(noleap_time1), 2), ) - mdim_time[:, 0] = noleap_time1 - mdim_time[:, 1] = noleap_time2 - - expected1 = times1.values - expected2 = times2.values + expected = cftime.num2date(num_dates, units, calendar) + min_y = np.ravel(np.atleast_1d(expected))[np.nanargmin(num_dates)].year + max_y = np.ravel(np.atleast_1d(expected))[np.nanargmax(num_dates)].year + if min_y >= 1678 and max_y < 2262: + expected = cftime_to_nptime(expected) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime(num_dates, units, + calendar) + + abs_diff = np.atleast_1d(abs(actual 
- expected)).astype(np.timedelta64) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + encoded, _, _ = coding.times.encode_cf_datetime(actual, units, + calendar) + if '1-1-1' not in units: + # pandas parses this date very strangely, so the original + # units/encoding cannot be preserved in this case: + # (Pdb) pd.to_datetime('1-1-1 00:00:0.0') + # Timestamp('2001-01-01 00:00:00') + assert_array_equal(num_dates, np.around(encoded, 1)) + if (hasattr(num_dates, 'ndim') and num_dates.ndim == 1 and + '1000' not in units): + # verify that wrapping with a pandas.Index works + # note that it *does not* currently work to even put + # non-datetime64 compatible dates into a pandas.Index + encoded, _, _ = coding.times.encode_cf_datetime( + pd.Index(actual), units, calendar) + assert_array_equal(num_dates, np.around(encoded, 1)) + + +@requires_cftime_or_netCDF4 +def test_decode_cf_datetime_overflow(): + # checks for + # https://github.com/pydata/pandas/issues/14068 + # https://github.com/pydata/xarray/issues/975 + try: + from cftime import DatetimeGregorian + except ImportError: + from netcdftime import DatetimeGregorian + + datetime = DatetimeGregorian + units = 'days since 2000-01-01 00:00:00' + + # date after 2262 and before 1678 + days = (-117608, 95795) + expected = (datetime(1677, 12, 31), datetime(2262, 4, 12)) + + for i, day in enumerate(days): with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(mdim_time, units, - calendar=calendar) + result = coding.times.decode_cf_datetime(day, units) + assert result == expected[i] + + +def test_decode_cf_datetime_non_standard_units(): + expected = pd.date_range(periods=100, start='1970-01-01', freq='h') + # netCDFs from madis.noaa.gov use this format for their time units + # they cannot be parsed by cftime, but pd.Timestamp works + units = 'hours since 1-1-1970' + actual = coding.times.decode_cf_datetime(np.arange(100), units) + assert_array_equal(actual, expected) + + +@requires_cftime_or_netCDF4 +def test_decode_cf_datetime_non_iso_strings(): + # datetime strings that are _almost_ ISO compliant but not quite, + # but which cftime.num2date can still parse correctly + expected = pd.date_range(periods=100, start='2000-01-01', freq='h') + cases = [(np.arange(100), 'hours since 2000-01-01 0'), + (np.arange(100), 'hours since 2000-1-1 0'), + (np.arange(100), 'hours since 2000-01-01 0:00')] + for num_dates, units in cases: + actual = coding.times.decode_cf_datetime(num_dates, units) + abs_diff = abs(actual - expected.values) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) +def test_decode_standard_calendar_inside_timestamp_range(calendar): + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times = pd.date_range('2001-04-01-00', end='2001-04-30-23', freq='H') + time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) + expected = times.values + expected_dtype = 
np.dtype('M8[ns]') + + actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) + assert actual.dtype == expected_dtype + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) +def test_decode_non_standard_calendar_inside_timestamp_range( + calendar): + cftime = _import_cftime() + units = 'days since 0001-01-01' + times = pd.date_range('2001-04-01-00', end='2001-04-30-23', + freq='H') + non_standard_time = cftime.date2num( + times.to_pydatetime(), units, calendar=calendar) + + if cftime.__name__ == 'cftime': + expected = cftime.num2date( + non_standard_time, units, calendar=calendar, + only_use_cftime_datetimes=True) + else: + expected = cftime.num2date(non_standard_time, units, + calendar=calendar) + + expected_dtype = np.dtype('O') + + actual = coding.times.decode_cf_datetime( + non_standard_time, units, calendar=calendar) + assert actual.dtype == expected_dtype + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +def test_decode_dates_outside_timestamp_range(calendar): + from datetime import datetime + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times = [datetime(1, 4, 1, h) for h in range(1, 5)] + time = cftime.date2num(times, units, calendar=calendar) + + if cftime.__name__ == 'cftime': + expected = cftime.num2date(time, units, calendar=calendar, + only_use_cftime_datetimes=True) + else: + expected = cftime.num2date(time, units, calendar=calendar) + + expected_date_type = type(expected[0]) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + time, units, calendar=calendar) + assert all(isinstance(value, expected_date_type) for value in actual) + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) +def test_decode_standard_calendar_single_element_inside_timestamp_range( + calendar): + units = 'days since 0001-01-01' + for num_time in [735368, [735368], [[735368]]]: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar) assert actual.dtype == np.dtype('M8[ns]') - assert_array_equal(actual[:, 0], expected1) - assert_array_equal(actual[:, 1], expected2) - - @requires_netCDF4 - def test_decode_non_standard_calendar_fallback(self): - import netCDF4 as nc4 - # ensure leap year doesn't matter - for year in [2010, 2011, 
2012, 2013, 2014]: - for calendar in ['360_day', '366_day', 'all_leap']: - calendar = '360_day' - units = 'days since {0}-01-01'.format(year) - num_times = np.arange(100) - expected = nc4.num2date(num_times, units, calendar) - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - actual = coding.times.decode_cf_datetime(num_times, units, - calendar=calendar) - assert len(w) == 1 - assert 'Unable to decode time axis' in \ - str(w[0].message) - - assert actual.dtype == np.dtype('O') - assert_array_equal(actual, expected) - - @requires_netCDF4 - def test_cf_datetime_nan(self): - for num_dates, units, expected_list in [ - ([np.nan], 'days since 2000-01-01', ['NaT']), - ([np.nan, 0], 'days since 2000-01-01', - ['NaT', '2000-01-01T00:00:00Z']), - ([np.nan, 0, 1], 'days since 2000-01-01', - ['NaT', '2000-01-01T00:00:00Z', '2000-01-02T00:00:00Z']), - ]: + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) +def test_decode_non_standard_calendar_single_element_inside_timestamp_range( + calendar): + units = 'days since 0001-01-01' + for num_time in [735368, [735368], [[735368]]]: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar) + assert actual.dtype == np.dtype('O') + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) +def test_decode_single_element_outside_timestamp_range( + calendar): + cftime = _import_cftime() + units = 'days since 0001-01-01' + for days in [1, 1470376]: + for num_time in [days, [days], [[days]]]: with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'All-NaN') - actual = coding.times.decode_cf_datetime(num_dates, units) - expected = np.array(expected_list, dtype='datetime64[ns]') - assert_array_equal(expected, actual) - - @requires_netCDF4 - def test_decoded_cf_datetime_array_2d(self): - # regression test for GH1229 - variable = Variable(('x', 'y'), np.array([[0, 1], [2, 3]]), - {'units': 'days since 2000-01-01'}) - result = coding.times.CFDatetimeCoder().decode(variable) - assert result.dtype == 'datetime64[ns]' - expected = pd.date_range('2000-01-01', periods=4).values.reshape(2, 2) - assert_array_equal(np.asarray(result), expected) - - def test_infer_datetime_units(self): - for dates, expected in [(pd.date_range('1900-01-01', periods=5), - 'days since 1900-01-01 00:00:00'), - (pd.date_range('1900-01-01 12:00:00', freq='H', - periods=2), - 'hours since 1900-01-01 12:00:00'), - (['1900-01-01', '1900-01-02', - '1900-01-02 00:00:01'], - 'seconds since 1900-01-01 00:00:00'), - (pd.to_datetime( - ['1900-01-01', '1900-01-02', 'NaT']), - 'days since 1900-01-01 00:00:00'), - (pd.to_datetime(['1900-01-01', - '1900-01-02T00:00:00.005']), - 'seconds since 1900-01-01 00:00:00'), - (pd.to_datetime(['NaT', '1900-01-01']), - 'days since 1900-01-01 00:00:00'), - (pd.to_datetime(['NaT']), - 'days since 1970-01-01 00:00:00'), - ]: - assert expected == coding.times.infer_datetime_units(dates) - - def test_cf_timedelta(self): - examples = [ - ('1D', 'days', np.int64(1)), - (['1D', '2D', '3D'], 'days', np.array([1, 2, 3], 'int64')), - ('1h', 'hours', np.int64(1)), - ('1ms', 'milliseconds', np.int64(1)), - ('1us', 'microseconds', np.int64(1)), - (['NaT', '0s', '1s'], None, [np.nan, 0, 1]), - (['30m', '60m'], 'hours', [0.5, 1.0]), - 
(np.timedelta64('NaT', 'ns'), 'days', np.nan), - (['NaT', 'NaT'], 'days', [np.nan, np.nan]), - ] - - for timedeltas, units, numbers in examples: - timedeltas = pd.to_timedelta(timedeltas, box=False) - numbers = np.array(numbers) - - expected = numbers - actual, _ = coding.times.encode_cf_timedelta(timedeltas, units) - assert_array_equal(expected, actual) - assert expected.dtype == actual.dtype - - if units is not None: - expected = timedeltas - actual = coding.times.decode_cf_timedelta(numbers, units) - assert_array_equal(expected, actual) - assert expected.dtype == actual.dtype - - expected = np.timedelta64('NaT', 'ns') - actual = coding.times.decode_cf_timedelta(np.array(np.nan), 'days') - assert_array_equal(expected, actual) + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar) + + if cftime.__name__ == 'cftime': + expected = cftime.num2date(days, units, calendar, + only_use_cftime_datetimes=True) + else: + expected = cftime.num2date(days, units, calendar) + + assert isinstance(actual.item(), type(expected)) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) +def test_decode_standard_calendar_multidim_time_inside_timestamp_range( + calendar): + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D') + times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') + time1 = cftime.date2num(times1.to_pydatetime(), + units, calendar=calendar) + time2 = cftime.date2num(times2.to_pydatetime(), + units, calendar=calendar) + mdim_time = np.empty((len(time1), 2), ) + mdim_time[:, 0] = time1 + mdim_time[:, 1] = time2 + + expected1 = times1.values + expected2 = times2.values + + actual = coding.times.decode_cf_datetime( + mdim_time, units, calendar=calendar) + assert actual.dtype == np.dtype('M8[ns]') + + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, 's')).all() + assert (abs_diff2 <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS) +def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( + calendar): + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D') + times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') + time1 = cftime.date2num(times1.to_pydatetime(), + units, calendar=calendar) + time2 = cftime.date2num(times2.to_pydatetime(), + units, calendar=calendar) + mdim_time = np.empty((len(time1), 2), ) + mdim_time[:, 0] = time1 + mdim_time[:, 1] = time2 + + if cftime.__name__ == 'cftime': + expected1 = cftime.num2date(time1, units, calendar, + only_use_cftime_datetimes=True) + expected2 = cftime.num2date(time2, units, calendar, + only_use_cftime_datetimes=True) + else: + expected1 = cftime.num2date(time1, units, calendar) + expected2 = cftime.num2date(time2, units, calendar) + + expected_dtype = np.dtype('O') + + actual = coding.times.decode_cf_datetime( + mdim_time, units, calendar=calendar) + + assert actual.dtype == 
expected_dtype + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, 's')).all() + assert (abs_diff2 <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +def test_decode_multidim_time_outside_timestamp_range( + calendar): + from datetime import datetime + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times1 = [datetime(1, 4, day) for day in range(1, 6)] + times2 = [datetime(1, 5, day) for day in range(1, 6)] + time1 = cftime.date2num(times1, units, calendar=calendar) + time2 = cftime.date2num(times2, units, calendar=calendar) + mdim_time = np.empty((len(time1), 2), ) + mdim_time[:, 0] = time1 + mdim_time[:, 1] = time2 + + if cftime.__name__ == 'cftime': + expected1 = cftime.num2date(time1, units, calendar, + only_use_cftime_datetimes=True) + expected2 = cftime.num2date(time2, units, calendar, + only_use_cftime_datetimes=True) + else: + expected1 = cftime.num2date(time1, units, calendar) + expected2 = cftime.num2date(time2, units, calendar) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + mdim_time, units, calendar=calendar) + + assert actual.dtype == np.dtype('O') + + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, 's')).all() + assert (abs_diff2 <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('calendar', ['360_day', 'all_leap', '366_day']) +def test_decode_non_standard_calendar_single_element( + calendar): + cftime = _import_cftime() + units = 'days since 0001-01-01' + + try: + dt = cftime.netcdftime.datetime(2001, 2, 29) + except AttributeError: + # Must be using the standalone cftime library + dt = cftime.datetime(2001, 2, 29) + + num_time = cftime.date2num(dt, units, calendar) + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar) + + if cftime.__name__ == 'cftime': + expected = np.asarray(cftime.num2date( + num_time, units, calendar, only_use_cftime_datetimes=True)) + else: + expected = np.asarray(cftime.num2date(num_time, units, calendar)) + assert actual.dtype == np.dtype('O') + assert expected == actual + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +def test_decode_360_day_calendar(): + cftime = _import_cftime() + calendar = '360_day' + # ensure leap year doesn't matter + for year in [2010, 2011, 2012, 2013, 2014]: + units = 'days since {0}-01-01'.format(year) + num_times = np.arange(100) + + if cftime.__name__ == 'cftime': + expected = cftime.num2date(num_times, units, calendar, + only_use_cftime_datetimes=True) + else: + expected = cftime.num2date(num_times, units, calendar) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + actual = coding.times.decode_cf_datetime( + num_times, units, calendar=calendar) + 
assert len(w) == 0 + + assert actual.dtype == np.dtype('O') + assert_array_equal(actual, expected) - def test_cf_timedelta_2d(self): - timedeltas = ['1D', '2D', '3D'] - units = 'days' - numbers = np.atleast_2d([1, 2, 3]) - timedeltas = np.atleast_2d(pd.to_timedelta(timedeltas, box=False)) +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['num_dates', 'units', 'expected_list'], + [([np.nan], 'days since 2000-01-01', ['NaT']), + ([np.nan, 0], 'days since 2000-01-01', + ['NaT', '2000-01-01T00:00:00Z']), + ([np.nan, 0, 1], 'days since 2000-01-01', + ['NaT', '2000-01-01T00:00:00Z', '2000-01-02T00:00:00Z'])]) +def test_cf_datetime_nan(num_dates, units, expected_list): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'All-NaN') + actual = coding.times.decode_cf_datetime(num_dates, units) + # use pandas because numpy will deprecate timezone-aware conversions + expected = pd.to_datetime(expected_list) + assert_array_equal(expected, actual) + + +@requires_cftime_or_netCDF4 +def test_decoded_cf_datetime_array_2d(): + # regression test for GH1229 + variable = Variable(('x', 'y'), np.array([[0, 1], [2, 3]]), + {'units': 'days since 2000-01-01'}) + result = coding.times.CFDatetimeCoder().decode(variable) + assert result.dtype == 'datetime64[ns]' + expected = pd.date_range('2000-01-01', periods=4).values.reshape(2, 2) + assert_array_equal(np.asarray(result), expected) + + +@pytest.mark.parametrize( + ['dates', 'expected'], + [(pd.date_range('1900-01-01', periods=5), + 'days since 1900-01-01 00:00:00'), + (pd.date_range('1900-01-01 12:00:00', freq='H', + periods=2), + 'hours since 1900-01-01 12:00:00'), + (pd.to_datetime( + ['1900-01-01', '1900-01-02', 'NaT']), + 'days since 1900-01-01 00:00:00'), + (pd.to_datetime(['1900-01-01', + '1900-01-02T00:00:00.005']), + 'seconds since 1900-01-01 00:00:00'), + (pd.to_datetime(['NaT', '1900-01-01']), + 'days since 1900-01-01 00:00:00'), + (pd.to_datetime(['NaT']), + 'days since 1970-01-01 00:00:00')]) +def test_infer_datetime_units(dates, expected): + assert expected == coding.times.infer_datetime_units(dates) + + +_CFTIME_DATETIME_UNITS_TESTS = [ + ([(1900, 1, 1), (1900, 1, 1)], 'days since 1900-01-01 00:00:00.000000'), + ([(1900, 1, 1), (1900, 1, 2), (1900, 1, 2, 0, 0, 1)], + 'seconds since 1900-01-01 00:00:00.000000'), + ([(1900, 1, 1), (1900, 1, 8), (1900, 1, 16)], + 'days since 1900-01-01 00:00:00.000000') +] + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + 'calendar', _NON_STANDARD_CALENDARS + ['gregorian', 'proleptic_gregorian']) +@pytest.mark.parametrize(('date_args', 'expected'), + _CFTIME_DATETIME_UNITS_TESTS) +def test_infer_cftime_datetime_units(calendar, date_args, expected): + date_type = _all_cftime_date_types()[calendar] + dates = [date_type(*args) for args in date_args] + assert expected == coding.times.infer_datetime_units(dates) + + +@pytest.mark.parametrize( + ['timedeltas', 'units', 'numbers'], + [('1D', 'days', np.int64(1)), + (['1D', '2D', '3D'], 'days', np.array([1, 2, 3], 'int64')), + ('1h', 'hours', np.int64(1)), + ('1ms', 'milliseconds', np.int64(1)), + ('1us', 'microseconds', np.int64(1)), + (['NaT', '0s', '1s'], None, [np.nan, 0, 1]), + (['30m', '60m'], 'hours', [0.5, 1.0]), + (np.timedelta64('NaT', 'ns'), 'days', np.nan), + (['NaT', 'NaT'], 'days', [np.nan, np.nan])]) +def test_cf_timedelta(timedeltas, units, numbers): + timedeltas = pd.to_timedelta(timedeltas, box=False) + numbers = 
np.array(numbers) + + expected = numbers + actual, _ = coding.times.encode_cf_timedelta(timedeltas, units) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + if units is not None: expected = timedeltas - actual = coding.times.decode_cf_timedelta(numbers, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype - def test_infer_timedelta_units(self): - for deltas, expected in [ - (pd.to_timedelta(['1 day', '2 days']), 'days'), - (pd.to_timedelta(['1h', '1 day 1 hour']), 'hours'), - (pd.to_timedelta(['1m', '2m', np.nan]), 'minutes'), - (pd.to_timedelta(['1m3s', '1m4s']), 'seconds')]: - assert expected == coding.times.infer_timedelta_units(deltas) + expected = np.timedelta64('NaT', 'ns') + actual = coding.times.decode_cf_timedelta(np.array(np.nan), 'days') + assert_array_equal(expected, actual) + + +def test_cf_timedelta_2d(): + timedeltas = ['1D', '2D', '3D'] + units = 'days' + numbers = np.atleast_2d([1, 2, 3]) + + timedeltas = np.atleast_2d(pd.to_timedelta(timedeltas, box=False)) + expected = timedeltas + + actual = coding.times.decode_cf_timedelta(numbers, units) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +@pytest.mark.parametrize( + ['deltas', 'expected'], + [(pd.to_timedelta(['1 day', '2 days']), 'days'), + (pd.to_timedelta(['1h', '1 day 1 hour']), 'hours'), + (pd.to_timedelta(['1m', '2m', np.nan]), 'minutes'), + (pd.to_timedelta(['1m3s', '1m4s']), 'seconds')]) +def test_infer_timedelta_units(deltas, expected): + assert expected == coding.times.infer_timedelta_units(deltas) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize(['date_args', 'expected'], + [((1, 2, 3, 4, 5, 6), + '0001-02-03 04:05:06.000000'), + ((10, 2, 3, 4, 5, 6), + '0010-02-03 04:05:06.000000'), + ((100, 2, 3, 4, 5, 6), + '0100-02-03 04:05:06.000000'), + ((1000, 2, 3, 4, 5, 6), + '1000-02-03 04:05:06.000000')]) +def test_format_cftime_datetime(date_args, expected): + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + result = coding.times.format_cftime_datetime(date_type(*date_args)) + assert result == expected + + +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +def test_decode_cf(calendar): + days = [1., 2., 3.] 
+ da = DataArray(days, coords=[days], dims=['time'], name='test') + ds = da.to_dataset() + + for v in ['test', 'time']: + ds[v].attrs['units'] = 'days since 2001-01-01' + ds[v].attrs['calendar'] = calendar + + if not has_cftime_or_netCDF4 and calendar not in _STANDARD_CALENDARS: + with pytest.raises(ValueError): + ds = decode_cf(ds) + else: + ds = decode_cf(ds) + + if calendar not in _STANDARD_CALENDARS: + assert ds.test.dtype == np.dtype('O') + else: + assert ds.test.dtype == np.dtype('M8[ns]') + + +@pytest.fixture(params=_ALL_CALENDARS) +def calendar(request): + return request.param + + +@pytest.fixture() +def times(calendar): + cftime = _import_cftime() + + return cftime.num2date( + np.arange(4), units='hours since 2000-01-01', calendar=calendar, + only_use_cftime_datetimes=True) + + +@pytest.fixture() +def data(times): + data = np.random.rand(2, 2, 4) + lons = np.linspace(0, 11, 2) + lats = np.linspace(0, 20, 2) + return DataArray(data, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], name='data') + + +@pytest.fixture() +def times_3d(times): + lons = np.linspace(0, 11, 2) + lats = np.linspace(0, 20, 2) + times_arr = np.random.choice(times, size=(2, 2, 4)) + return DataArray(times_arr, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], + name='data') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_1d(data): + assert contains_cftime_datetimes(data.time) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_dask_1d(data): + assert contains_cftime_datetimes(data.time.chunk()) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_3d(times_3d): + assert contains_cftime_datetimes(times_3d) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_dask_3d(times_3d): + assert contains_cftime_datetimes(times_3d.chunk()) + + +@pytest.mark.parametrize('non_cftime_data', [DataArray([]), DataArray([1, 2])]) +def test_contains_cftime_datetimes_non_cftimes(non_cftime_data): + assert not contains_cftime_datetimes(non_cftime_data) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.parametrize('non_cftime_data', [DataArray([]), DataArray([1, 2])]) +def test_contains_cftime_datetimes_non_cftimes_dask(non_cftime_data): + assert not contains_cftime_datetimes(non_cftime_data.chunk()) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('shape', [(24,), (8, 3), (2, 4, 3)]) +def test_encode_cf_datetime_overflow(shape): + # Test for fix to GH 2272 + dates = pd.date_range('2100', periods=24).values.reshape(shape) + units = 'days since 1800-01-01' + calendar = 'standard' + + num, _, _ = encode_cf_datetime(dates, units, calendar) + roundtrip = decode_cf_datetime(num, units, calendar) + np.testing.assert_array_equal(dates, roundtrip) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 365e274a191..2004b1e660f 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -1,22 +1,21 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from copy import deepcopy +from __future__ import absolute_import, division, print_function -import pytest +from copy import deepcopy 
import numpy as np import pandas as pd +import pytest -from xarray import Dataset, DataArray, auto_combine, concat, Variable -from xarray.core.pycompat import iteritems, OrderedDict +from xarray import DataArray, Dataset, Variable, auto_combine, concat +from xarray.core.pycompat import OrderedDict, iteritems -from . import (TestCase, InaccessibleArray, requires_dask, raises_regex, - assert_equal, assert_identical, assert_array_equal) +from . import ( + InaccessibleArray, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask) from .test_dataset import create_test_data -class TestConcatDataset(TestCase): +class TestConcatDataset(object): def test_concat(self): # TODO: simplify and split this test case @@ -236,7 +235,7 @@ def test_concat_multiindex(self): assert isinstance(actual.x.to_index(), pd.MultiIndex) -class TestConcatDataArray(TestCase): +class TestConcatDataArray(object): def test_concat(self): ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))), 'bar': (['x', 'y'], np.random.random((2, 3)))}, @@ -296,7 +295,7 @@ def test_concat_lazy(self): assert combined.dims == ('z', 'x', 'y') -class TestAutoCombine(TestCase): +class TestAutoCombine(object): @requires_dask # only for toolz def test_auto_combine(self): @@ -378,3 +377,22 @@ def test_auto_combine_no_concat(self): data = Dataset({'x': 0}) actual = auto_combine([data, data, data], concat_dim=None) assert_identical(data, actual) + + # Single object, with a concat_dim explicitly provided + # Test the issue reported in GH #1988 + objs = [Dataset({'x': 0, 'y': 1})] + dim = DataArray([100], name='baz', dims='baz') + actual = auto_combine(objs, concat_dim=dim) + expected = Dataset({'x': ('baz', [0]), 'y': ('baz', [1])}, + {'baz': [100]}) + assert_identical(expected, actual) + + # Just making sure that auto_combine is doing what is + # expected for non-scalar values, too. + objs = [Dataset({'x': ('z', [0, 1]), 'y': ('z', [1, 2])})] + dim = DataArray([100], name='baz', dims='baz') + actual = auto_combine(objs, concat_dim=dim) + expected = Dataset({'x': (('baz', 'z'), [[0, 1]]), + 'y': (('baz', 'z'), [[1, 2]])}, + {'baz': [100]}) + assert_identical(expected, actual) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 23e77b83455..1003c531018 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1,21 +1,21 @@ import functools import operator +import pickle from collections import OrderedDict - from distutils.version import LooseVersion + import numpy as np -from numpy.testing import assert_array_equal import pandas as pd - import pytest +from numpy.testing import assert_array_equal import xarray as xr from xarray.core.computation import ( - _UFuncSignature, result_name, broadcast_compat_data, collect_dict_values, - join_dict_keys, ordered_set_intersection, ordered_set_union, - unified_dim_sizes, apply_ufunc) + _UFuncSignature, apply_ufunc, broadcast_compat_data, collect_dict_values, + join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, + unified_dim_sizes) -from . import requires_dask, raises_regex +from . 
import has_dask, raises_regex, requires_dask def assert_identical(a, b): @@ -274,6 +274,22 @@ def func(x): assert_identical(expected_dataset_x, first_element(dataset.groupby('y'), 'x')) + def multiply(*args): + val = args[0] + for arg in args[1:]: + val = val * arg + return val + + # regression test for GH:2341 + with pytest.raises(ValueError): + apply_ufunc(multiply, data_array, data_array['y'].values, + input_core_dims=[['y']], output_core_dims=[['y']]) + expected = xr.DataArray(multiply(data_array, data_array['y']), + dims=['x', 'y'], coords=data_array.coords) + actual = apply_ufunc(multiply, data_array, data_array['y'].values, + input_core_dims=[['y'], []], output_core_dims=[['y']]) + assert_identical(expected, actual) + def test_apply_output_core_dimension(): @@ -481,12 +497,19 @@ def add(a, b, keep_attrs): a = xr.DataArray([0, 1], [('x', [0, 1])]) a.attrs['attr'] = 'da' + a['x'].attrs['attr'] = 'da_coord' b = xr.DataArray([1, 2], [('x', [0, 1])]) actual = add(a, b, keep_attrs=False) assert not actual.attrs actual = add(a, b, keep_attrs=True) assert_identical(actual.attrs, a.attrs) + assert_identical(actual['x'].attrs, a['x'].attrs) + + actual = add(a.variable, b.variable, keep_attrs=False) + assert not actual.attrs + actual = add(a.variable, b.variable, keep_attrs=True) + assert_identical(actual.attrs, a.attrs) a = xr.Dataset({'x': [0, 1]}) a.attrs['attr'] = 'ds' @@ -546,7 +569,7 @@ def test_apply_dask(): array = da.ones((2,), chunks=2) variable = xr.Variable('x', array) coords = xr.DataArray(variable).coords.variables - data_array = xr.DataArray(variable, coords, fastpath=True) + data_array = xr.DataArray(variable, dims=['x'], coords=coords) dataset = xr.Dataset({'y': variable}) # encountered dask array, but did not set dask='allowed' @@ -719,9 +742,6 @@ def pandas_median(x): def test_vectorize(): - if LooseVersion(np.__version__) < LooseVersion('1.12.0'): - pytest.skip('numpy 1.12 or later to support vectorize=True.') - data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=('x', 'y')) expected = xr.DataArray([1, 2], dims=['x']) actual = apply_ufunc(pandas_median, data_array, @@ -732,9 +752,6 @@ def test_vectorize(): @requires_dask def test_vectorize_dask(): - if LooseVersion(np.__version__) < LooseVersion('1.12.0'): - pytest.skip('numpy 1.12 or later to support vectorize=True.') - data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=('x', 'y')) expected = xr.DataArray([1, 2], dims=['x']) actual = apply_ufunc(pandas_median, data_array.chunk({'x': 1}), @@ -745,6 +762,215 @@ def test_vectorize_dask(): assert_identical(expected, actual) +def test_output_wrong_number(): + variable = xr.Variable('x', np.arange(10)) + + def identity(x): + return x + + def tuple3x(x): + return (x, x, x) + + with raises_regex(ValueError, 'number of outputs'): + apply_ufunc(identity, variable, output_core_dims=[(), ()]) + + with raises_regex(ValueError, 'number of outputs'): + apply_ufunc(tuple3x, variable, output_core_dims=[(), ()]) + + +def test_output_wrong_dims(): + variable = xr.Variable('x', np.arange(10)) + + def add_dim(x): + return x[..., np.newaxis] + + def remove_dim(x): + return x[..., 0] + + with raises_regex(ValueError, 'unexpected number of dimensions'): + apply_ufunc(add_dim, variable, output_core_dims=[('y', 'z')]) + + with raises_regex(ValueError, 'unexpected number of dimensions'): + apply_ufunc(add_dim, variable) + + with raises_regex(ValueError, 'unexpected number of dimensions'): + apply_ufunc(remove_dim, variable) + + +def test_output_wrong_dim_size(): + array = np.arange(10) + 
variable = xr.Variable('x', array) + data_array = xr.DataArray(variable, [('x', -array)]) + dataset = xr.Dataset({'y': variable}, {'x': -array}) + + def truncate(array): + return array[:5] + + def apply_truncate_broadcast_invalid(obj): + return apply_ufunc(truncate, obj) + + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_broadcast_invalid(variable) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_broadcast_invalid(data_array) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_broadcast_invalid(dataset) + + def apply_truncate_x_x_invalid(obj): + return apply_ufunc(truncate, obj, input_core_dims=[['x']], + output_core_dims=[['x']]) + + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_x_x_invalid(variable) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_x_x_invalid(data_array) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_x_x_invalid(dataset) + + def apply_truncate_x_z(obj): + return apply_ufunc(truncate, obj, input_core_dims=[['x']], + output_core_dims=[['z']]) + + assert_identical(xr.Variable('z', array[:5]), + apply_truncate_x_z(variable)) + assert_identical(xr.DataArray(array[:5], dims=['z']), + apply_truncate_x_z(data_array)) + assert_identical(xr.Dataset({'y': ('z', array[:5])}), + apply_truncate_x_z(dataset)) + + def apply_truncate_x_x_valid(obj): + return apply_ufunc(truncate, obj, input_core_dims=[['x']], + output_core_dims=[['x']], exclude_dims={'x'}) + + assert_identical(xr.Variable('x', array[:5]), + apply_truncate_x_x_valid(variable)) + assert_identical(xr.DataArray(array[:5], dims=['x']), + apply_truncate_x_x_valid(data_array)) + assert_identical(xr.Dataset({'y': ('x', array[:5])}), + apply_truncate_x_x_valid(dataset)) + + +@pytest.mark.parametrize('use_dask', [True, False]) +def test_dot(use_dask): + if use_dask: + if not has_dask: + pytest.skip('test for dask.') + + a = np.arange(30 * 4).reshape(30, 4) + b = np.arange(30 * 4 * 5).reshape(30, 4, 5) + c = np.arange(5 * 60).reshape(5, 60) + da_a = xr.DataArray(a, dims=['a', 'b'], + coords={'a': np.linspace(0, 1, 30)}) + da_b = xr.DataArray(b, dims=['a', 'b', 'c'], + coords={'a': np.linspace(0, 1, 30)}) + da_c = xr.DataArray(c, dims=['c', 'e']) + if use_dask: + da_a = da_a.chunk({'a': 3}) + da_b = da_b.chunk({'a': 3}) + da_c = da_c.chunk({'c': 3}) + + actual = xr.dot(da_a, da_b, dims=['a', 'b']) + assert actual.dims == ('c', ) + assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) + + actual = xr.dot(da_a, da_b) + assert actual.dims == ('c', ) + assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) + + if use_dask: + import dask + if LooseVersion(dask.__version__) < LooseVersion('0.17.3'): + pytest.skip("needs dask.array.einsum") + + # for only a single array is passed without dims argument, just return + # as is + actual = xr.dot(da_a) + assert da_a.identical(actual) + + # test for variable + actual = xr.dot(da_a.variable, da_b.variable) + assert actual.dims == ('c', ) + assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() + assert isinstance(actual.data, type(da_a.variable.data)) + + if use_dask: + da_a = da_a.chunk({'a': 3}) + da_b = da_b.chunk({'a': 3}) + actual = xr.dot(da_a, da_b, dims=['b']) + assert actual.dims == ('a', 'c') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + assert isinstance(actual.variable.data, 
type(da_a.variable.data)) + + actual = xr.dot(da_a, da_b, dims=['b']) + assert actual.dims == ('a', 'c') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + + actual = xr.dot(da_a, da_b, dims='b') + assert actual.dims == ('a', 'c') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + + actual = xr.dot(da_a, da_b, dims='a') + assert actual.dims == ('b', 'c') + assert (actual.data == np.einsum('ij,ijk->jk', a, b)).all() + + actual = xr.dot(da_a, da_b, dims='c') + assert actual.dims == ('a', 'b') + assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all() + + actual = xr.dot(da_a, da_b, da_c, dims=['a', 'b']) + assert actual.dims == ('c', 'e') + assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all() + + # should work with tuple + actual = xr.dot(da_a, da_b, dims=('c', )) + assert actual.dims == ('a', 'b') + assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all() + + # default dims + actual = xr.dot(da_a, da_b, da_c) + assert actual.dims == ('e', ) + assert (actual.data == np.einsum('ij,ijk,kl->l ', a, b, c)).all() + + # 1 array summation + actual = xr.dot(da_a, dims='a') + assert actual.dims == ('b', ) + assert (actual.data == np.einsum('ij->j ', a)).all() + + # empty dim + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims='a') + assert actual.dims == ('b', ) + assert (actual.data == np.zeros(actual.shape)).all() + + # Invalid cases + if not use_dask or LooseVersion(dask.__version__) > LooseVersion('0.17.4'): + with pytest.raises(TypeError): + xr.dot(da_a, dims='a', invalid=None) + with pytest.raises(TypeError): + xr.dot(da_a.to_dataset(name='da'), dims='a') + with pytest.raises(TypeError): + xr.dot(dims='a') + + # einsum parameters + actual = xr.dot(da_a, da_b, dims=['b'], order='C') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + assert actual.values.flags['C_CONTIGUOUS'] + assert not actual.values.flags['F_CONTIGUOUS'] + actual = xr.dot(da_a, da_b, dims=['b'], order='F') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + # dask converts Fortran arrays to C order when merging the final array + if not use_dask: + assert not actual.values.flags['C_CONTIGUOUS'] + assert actual.values.flags['F_CONTIGUOUS'] + + # einsum has a constant string as of the first parameter, which makes + # it hard to pass to xarray.apply_ufunc. + # make sure dot() uses functools.partial(einsum, subscripts), which + # can be pickled, and not a lambda, which can't. + pickle.loads(pickle.dumps(xr.dot(da_a))) + + def test_where(): cond = xr.DataArray([True, False], dims='x') actual = xr.where(cond, 1, 0) diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 6a509368017..5fa518f5112 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -1,143 +1,27 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import contextlib import warnings + import numpy as np import pandas as pd import pytest -from xarray import conventions, Variable, Dataset, open_dataset -from xarray.core import utils, indexing -from xarray.testing import assert_identical -from . 
import ( - TestCase, requires_netCDF4, unittest, raises_regex, IndexerMaker, - assert_array_equal) -from .test_backends import CFEncodedDataTest -from xarray.core.pycompat import iteritems -from xarray.backends.memory import InMemoryDataStore +from xarray import ( + Dataset, SerializationWarning, Variable, coding, conventions, open_dataset) from xarray.backends.common import WritableCFDataStore +from xarray.backends.memory import InMemoryDataStore from xarray.conventions import decode_cf +from xarray.testing import assert_identical + +from . import ( + assert_array_equal, raises_regex, requires_cftime_or_netCDF4, + requires_dask, requires_netCDF4) +from .test_backends import CFEncodedBase -B = IndexerMaker(indexing.BasicIndexer) -V = IndexerMaker(indexing.VectorizedIndexer) - - -class TestStackedBytesArray(TestCase): - def test_wrapper_class(self): - array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']], dtype='S') - actual = conventions.StackedBytesArray(array) - expected = np.array([b'abc', b'def'], dtype='S') - assert actual.dtype == expected.dtype - assert actual.shape == expected.shape - assert actual.size == expected.size - assert actual.ndim == expected.ndim - assert len(actual) == len(expected) - assert_array_equal(expected, actual) - assert_array_equal(expected[:1], actual[B[:1]]) - with pytest.raises(IndexError): - actual[B[:, :2]] - - def test_scalar(self): - array = np.array([b'a', b'b', b'c'], dtype='S') - actual = conventions.StackedBytesArray(array) - - expected = np.array(b'abc') - assert actual.dtype == expected.dtype - assert actual.shape == expected.shape - assert actual.size == expected.size - assert actual.ndim == expected.ndim - with pytest.raises(TypeError): - len(actual) - np.testing.assert_array_equal(expected, actual) - with pytest.raises(IndexError): - actual[B[:2]] - assert str(actual) == str(expected) - - def test_char_to_bytes(self): - array = np.array([['a', 'b', 'c'], ['d', 'e', 'f']]) - expected = np.array(['abc', 'def']) - actual = conventions.char_to_bytes(array) - assert_array_equal(actual, expected) - - expected = np.array(['ad', 'be', 'cf']) - actual = conventions.char_to_bytes(array.T) # non-contiguous - assert_array_equal(actual, expected) - - def test_char_to_bytes_ndim_zero(self): - expected = np.array('a') - actual = conventions.char_to_bytes(expected) - assert_array_equal(actual, expected) - - def test_char_to_bytes_size_zero(self): - array = np.zeros((3, 0), dtype='S1') - expected = np.array([b'', b'', b'']) - actual = conventions.char_to_bytes(array) - assert_array_equal(actual, expected) - - def test_bytes_to_char(self): - array = np.array([['ab', 'cd'], ['ef', 'gh']]) - expected = np.array([[['a', 'b'], ['c', 'd']], - [['e', 'f'], ['g', 'h']]]) - actual = conventions.bytes_to_char(array) - assert_array_equal(actual, expected) - - expected = np.array([[['a', 'b'], ['e', 'f']], - [['c', 'd'], ['g', 'h']]]) - actual = conventions.bytes_to_char(array.T) - assert_array_equal(actual, expected) - - def test_vectorized_indexing(self): - array = np.array([[b'a', b'b', b'c'], [b'd', b'e', b'f']], dtype='S') - stacked = conventions.StackedBytesArray(array) - expected = np.array([[b'abc', b'def'], [b'def', b'abc']]) - indexer = V[np.array([[0, 1], [1, 0]])] - actual = stacked[indexer] - assert_array_equal(actual, expected) - - -class TestBytesToStringArray(TestCase): - - def test_encoding(self): - encoding = 'utf-8' - raw_array = np.array([b'abc', u'ß∂µ∆'.encode(encoding)]) - actual = conventions.BytesToStringArray(raw_array, encoding=encoding) - 
expected = np.array([u'abc', u'ß∂µ∆'], dtype=object) - - assert actual.dtype == expected.dtype - assert actual.shape == expected.shape - assert actual.size == expected.size - assert actual.ndim == expected.ndim - assert_array_equal(expected, actual) - assert_array_equal(expected[0], actual[B[0]]) - - def test_scalar(self): - expected = np.array(u'abc', dtype=object) - actual = conventions.BytesToStringArray( - np.array(b'abc'), encoding='utf-8') - assert actual.dtype == expected.dtype - assert actual.shape == expected.shape - assert actual.size == expected.size - assert actual.ndim == expected.ndim - with pytest.raises(TypeError): - len(actual) - np.testing.assert_array_equal(expected, actual) - with pytest.raises(IndexError): - actual[B[:2]] - assert str(actual) == str(expected) - - def test_decode_bytes_array(self): - encoding = 'utf-8' - raw_array = np.array([b'abc', u'ß∂µ∆'.encode(encoding)]) - expected = np.array([u'abc', u'ß∂µ∆'], dtype=object) - actual = conventions.decode_bytes_array(raw_array, encoding) - np.testing.assert_array_equal(actual, expected) - - -class TestBoolTypeArray(TestCase): +class TestBoolTypeArray(object): def test_booltype_array(self): x = np.array([1, 0, 1, 1, 0], dtype='i1') bx = conventions.BoolTypeArray(x) @@ -146,7 +30,7 @@ def test_booltype_array(self): dtype=np.bool)) -class TestNativeEndiannessArray(TestCase): +class TestNativeEndiannessArray(object): def test(self): x = np.arange(5, dtype='>i8') expected = np.arange(5, dtype='int64') @@ -157,12 +41,15 @@ def test(self): def test_decode_cf_with_conflicting_fill_missing_value(): - var = Variable(['t'], np.arange(10), + expected = Variable(['t'], [np.nan, np.nan, 2], {'units': 'foobar'}) + var = Variable(['t'], np.arange(3), {'units': 'foobar', 'missing_value': 0, '_FillValue': 1}) - with raises_regex(ValueError, "_FillValue and missing_value"): - conventions.decode_cf_variable('t', var) + with warnings.catch_warnings(record=True) as w: + actual = conventions.decode_cf_variable('t', var) + assert_identical(actual, expected) + assert 'has multiple fill' in str(w[0].message) expected = Variable(['t'], np.arange(10), {'units': 'foobar'}) @@ -181,8 +68,8 @@ def test_decode_cf_with_conflicting_fill_missing_value(): assert_identical(actual, expected) -@requires_netCDF4 -class TestEncodeCFVariable(TestCase): +@requires_cftime_or_netCDF4 +class TestEncodeCFVariable(object): def test_incompatible_attributes(self): invalid_vars = [ Variable(['t'], pd.date_range('2000-01-01', periods=3), @@ -236,9 +123,18 @@ def test_multidimensional_coordinates(self): # Should not have any global coordinates. 
assert 'coordinates' not in attrs + @requires_dask + def test_string_object_warning(self): + original = Variable( + ('x',), np.array([u'foo', u'bar'], dtype=object)).chunk() + with pytest.warns(SerializationWarning, + match='dask array with dtype=object'): + encoded = conventions.encode_cf_variable(original) + assert_identical(original, encoded) -@requires_netCDF4 -class TestDecodeCF(TestCase): + +@requires_cftime_or_netCDF4 +class TestDecodeCF(object): def test_dataset(self): original = Dataset({ 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), @@ -303,7 +199,7 @@ def test_invalid_time_units_raises_eagerly(self): with raises_regex(ValueError, 'unable to decode time'): decode_cf(ds) - @requires_netCDF4 + @requires_cftime_or_netCDF4 def test_dataset_repr_with_netcdf4_datetimes(self): # regression test for #347 attrs = {'units': 'days since 0001-01-01', 'calendar': 'noleap'} @@ -316,50 +212,50 @@ def test_dataset_repr_with_netcdf4_datetimes(self): ds = decode_cf(Dataset({'time': ('time', [0, 1], attrs)})) assert '(time) datetime64[ns]' in repr(ds) - @requires_netCDF4 + @requires_cftime_or_netCDF4 def test_decode_cf_datetime_transition_to_invalid(self): # manually create dataset with not-decoded date from datetime import datetime ds = Dataset(coords={'time': [0, 266 * 365]}) units = 'days since 2000-01-01 00:00:00' ds.time.attrs = dict(units=units) - ds_decoded = conventions.decode_cf(ds) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'unable to decode time') + ds_decoded = conventions.decode_cf(ds) expected = [datetime(2000, 1, 1, 0, 0), datetime(2265, 10, 28, 0, 0)] assert_array_equal(ds_decoded.time.values, expected) - -class CFEncodedInMemoryStore(WritableCFDataStore, InMemoryDataStore): - pass - - -class NullWrapper(utils.NDArrayMixin): - """ - Just for testing, this lets us create a numpy array directly - but make it look like its not in memory yet. - """ - - def __init__(self, array): - self.array = array - - def __getitem__(self, key): - return self.array[indexing.orthogonal_indexer(key, self.shape)] + @requires_dask + def test_decode_cf_with_dask(self): + import dask.array as da + original = Dataset({ + 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), + 'foo': ('t', [0, 0, 0], {'coordinates': 'y', 'units': 'bar'}), + 'bar': ('string2', [b'a', b'b']), + 'baz': (('x'), [b'abc'], {'_Encoding': 'utf-8'}), + 'y': ('t', [5, 10, -999], {'_FillValue': -999}) + }).chunk() + decoded = conventions.decode_cf(original) + print(decoded) + assert all(isinstance(var.data, da.Array) + for name, var in decoded.variables.items() + if name not in decoded.indexes) + assert_identical(decoded, conventions.decode_cf(original).compute()) -def null_wrap(ds): - """ - Given a data store this wraps each variable in a NullWrapper so that - it appears to be out of memory. 
- """ - variables = dict((k, Variable(v.dims, NullWrapper(v.values), v.attrs)) - for k, v in iteritems(ds)) - return InMemoryDataStore(variables=variables, attributes=ds.attrs) +class CFEncodedInMemoryStore(WritableCFDataStore, InMemoryDataStore): + def encode_variable(self, var): + """encode one variable""" + coder = coding.strings.EncodedStringCoder(allows_unicode=True) + var = coder.encode(var) + return var @requires_netCDF4 -class TestCFEncodedDataStore(CFEncodedDataTest, TestCase): +class TestCFEncodedDataStore(CFEncodedBase): @contextlib.contextmanager def create_store(self): yield CFEncodedInMemoryStore() @@ -371,12 +267,20 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs) + @pytest.mark.skip('cannot roundtrip coordinates yet for ' + 'CFEncodedInMemoryStore') def test_roundtrip_coordinates(self): - raise unittest.SkipTest('cannot roundtrip coordinates yet for ' - 'CFEncodedInMemoryStore') + pass def test_invalid_dataarray_names_raise(self): + # only relevant for on-disk file formats pass def test_encoding_kwarg(self): + # we haven't bothered to raise errors yet for unexpected encodings in + # this test dummy + pass + + def test_encoding_kwarg_fixed_width_string(self): + # CFEncodedInMemoryStore doesn't support explicit string encodings. pass diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 7833a43a894..62ce7d074fa 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1,34 +1,36 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import pickle +from distutils.version import LooseVersion from textwrap import dedent -from distutils.version import LooseVersion import numpy as np import pandas as pd import pytest import xarray as xr -from xarray import Variable, DataArray, Dataset import xarray.ufuncs as xu -from xarray.core.pycompat import suppress, OrderedDict -from . import ( - TestCase, assert_frame_equal, raises_regex, assert_equal, assert_identical, - assert_array_equal, assert_allclose) - +from xarray import DataArray, Dataset, Variable +from xarray.core.pycompat import OrderedDict, suppress from xarray.tests import mock +from . 
import ( + assert_allclose, assert_array_equal, assert_equal, assert_frame_equal, + assert_identical, raises_regex) + dask = pytest.importorskip('dask') -import dask.array as da # noqa: E402 # allow importorskip call above this -import dask.dataframe as dd # noqa: E402 +da = pytest.importorskip('dask.array') +dd = pytest.importorskip('dask.dataframe') -class DaskTestCase(TestCase): +class DaskTestCase(object): def assertLazyAnd(self, expected, actual, test): - with dask.set_options(get=dask.get): + + with (dask.config.set(scheduler='single-threaded') + if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=dask.get)): test(actual, expected) + if isinstance(actual, Dataset): for k, v in actual.variables.items(): if k in actual.dims: @@ -55,6 +57,7 @@ def assertLazyAndIdentical(self, expected, actual): def assertLazyAndAllClose(self, expected, actual): self.assertLazyAnd(expected, actual, assert_allclose) + @pytest.fixture(autouse=True) def setUp(self): self.values = np.random.RandomState(0).randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -199,19 +202,19 @@ def test_missing_methods(self): except NotImplementedError as err: assert 'dask' in str(err) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_univariate_ufunc(self): u = self.eager_var v = self.lazy_var self.assertLazyAndAllClose(np.sin(u), xu.sin(v)) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_bivariate_ufunc(self): u = self.eager_var v = self.lazy_var self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0)) self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v)) - @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', - reason='Need dask 0.16 for new interface') def test_compute(self): u = self.eager_var v = self.lazy_var @@ -222,8 +225,6 @@ def test_compute(self): assert ((u + 1).data == v2.data).all() - @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', - reason='Need dask 0.16 for new interface') def test_persist(self): u = self.eager_var v = self.lazy_var + 1 @@ -249,6 +250,7 @@ def assertLazyAndAllClose(self, expected, actual): def assertLazyAndEqual(self, expected, actual): self.assertLazyAnd(expected, actual, assert_equal) + @pytest.fixture(autouse=True) def setUp(self): self.values = np.random.randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -283,8 +285,6 @@ def test_lazy_array(self): actual = xr.concat([v[:2], v[2:]], 'x') self.assertLazyAndAllClose(u, actual) - @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', - reason='Need dask 0.16 for new interface') def test_compute(self): u = self.eager_array v = self.lazy_array @@ -295,8 +295,6 @@ def test_compute(self): assert ((u + 1).data == v2.data).all() - @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', - reason='Need dask 0.16 for new interface') def test_persist(self): u = self.eager_array v = self.lazy_array + 1 @@ -386,15 +384,11 @@ def test_concat_loads_variables(self): assert ds3['c'].data is c3 def test_groupby(self): - if LooseVersion(dask.__version__) == LooseVersion('0.15.3'): - pytest.xfail('upstream bug in dask: ' - 'https://github.com/dask/dask/issues/2718') - u = self.eager_array v = self.lazy_array - expected = u.groupby('x').mean() - actual = v.groupby('x').mean() + expected = u.groupby('x').mean(xr.ALL_DIMS) + actual = v.groupby('x').mean(xr.ALL_DIMS) self.assertLazyAndAllClose(expected, actual) def test_groupby_first(self): @@ -436,6 +430,7 @@ def 
duplicate_and_merge(array): actual = duplicate_and_merge(self.lazy_array) self.assertLazyAndEqual(expected, actual) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_ufuncs(self): u = self.eager_array v = self.lazy_array @@ -461,8 +456,11 @@ def counting_get(*args, **kwargs): count[0] += 1 return dask.get(*args, **kwargs) - with dask.set_options(get=counting_get): - ds.load() + if dask.__version__ < '0.19.4': + ds.load(get=counting_get) + else: + ds.load(scheduler=counting_get) + assert count[0] == 1 def test_stack(self): @@ -589,7 +587,7 @@ def test_from_dask_variable(self): self.assertLazyAndIdentical(self.lazy_array, a) -class TestToDaskDataFrame(TestCase): +class TestToDaskDataFrame(object): def test_to_dask_dataframe(self): # Test conversion of Datasets to dask DataFrames @@ -782,12 +780,8 @@ def build_dask_array(name): # test both the perist method and the dask.persist function # the dask.persist function requires a new version of dask -@pytest.mark.parametrize('persist', [ - lambda x: x.persist(), - pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', - lambda x: dask.persist(x)[0], - reason='Need Dask 0.16') -]) +@pytest.mark.parametrize('persist', [lambda x: x.persist(), + lambda x: dask.persist(x)[0]]) def test_persist_Dataset(persist): ds = Dataset({'foo': ('x', range(5)), 'bar': ('x', range(5))}).chunk() @@ -800,12 +794,8 @@ def test_persist_Dataset(persist): assert len(ds.foo.data.dask) == n # doesn't mutate in place -@pytest.mark.parametrize('persist', [ - lambda x: x.persist(), - pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', - lambda x: dask.persist(x)[0], - reason='Need Dask 0.16') -]) +@pytest.mark.parametrize('persist', [lambda x: x.persist(), + lambda x: dask.persist(x)[0]]) def test_persist_DataArray(persist): x = da.arange(10, chunks=(5,)) y = DataArray(x) @@ -818,8 +808,6 @@ def test_persist_DataArray(persist): assert len(zz.data.dask) == zz.data.npartitions -@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', - reason='Need dask 0.16 for new interface') def test_dataarray_with_dask_coords(): import toolz x = xr.Variable('x', da.arange(8, chunks=(4,))) @@ -847,7 +835,11 @@ def test_basic_compute(): dask.multiprocessing.get, dask.local.get_sync, None]: - with dask.set_options(get=get): + with (dask.config.set(scheduler=get) + if LooseVersion(dask.__version__) >= LooseVersion('0.19.4') + else dask.config.set(scheduler=get) + if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=get)): ds.compute() ds.foo.compute() ds.foo.variable.compute() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 39b3109c295..87ee60715a1 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1,29 +1,31 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import numpy as np -import pandas as pd +from __future__ import absolute_import, division, print_function + import pickle -import pytest +import warnings from copy import deepcopy from textwrap import dedent -from distutils.version import LooseVersion -import xarray as xr +import numpy as np +import pandas as pd +import pytest -from xarray import (align, broadcast, Dataset, DataArray, - IndexVariable, Variable) -from xarray.coding.times import CFDatetimeCoder -from xarray.core.pycompat import iteritems, OrderedDict -from xarray.core.common import full_like +import xarray as xr +from xarray import ( + DataArray, Dataset, 
IndexVariable, Variable, align, broadcast, set_options) +from xarray.coding.times import CFDatetimeCoder, _import_cftime +from xarray.convert import from_cdms2 +from xarray.core.common import ALL_DIMS, full_like +from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( - TestCase, ReturnItem, source_ndarray, unittest, requires_dask, - assert_identical, assert_equal, assert_allclose, assert_array_equal, - raises_regex, requires_scipy, requires_bottleneck) + LooseVersion, ReturnItem, assert_allclose, assert_array_equal, + assert_equal, assert_identical, raises_regex, requires_bottleneck, + requires_cftime, requires_dask, requires_iris, requires_np113, + requires_scipy, source_ndarray) -class TestDataArray(TestCase): - def setUp(self): +class TestDataArray(object): + @pytest.fixture(autouse=True) + def setup(self): self.attrs = {'attr1': 'value1', 'attr2': 2929} self.x = np.random.random((10, 20)) self.v = Variable(['x', 'y'], self.x) @@ -322,11 +324,14 @@ def test_constructor_from_self_described(self): actual = DataArray(series) assert_equal(expected[0].reset_coords('x', drop=True), actual) - panel = pd.Panel({0: frame}) - actual = DataArray(panel) - expected = DataArray([data], expected.coords, ['dim_0', 'x', 'y']) - expected['dim_0'] = [0] - assert_identical(expected, actual) + if hasattr(pd, 'Panel'): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + panel = pd.Panel({0: frame}) + actual = DataArray(panel) + expected = DataArray([data], expected.coords, ['dim_0', 'x', 'y']) + expected['dim_0'] = [0] + assert_identical(expected, actual) expected = DataArray(data, coords={'x': ['a', 'b'], 'y': [-1, -2], @@ -436,7 +441,7 @@ def test_getitem(self): assert_identical(self.ds['x'], x) assert_identical(self.ds['y'], y) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[x.values], I[x.variable], I[x], I[x, y], I[x.values > -1], I[x.variable > -1], I[x > -1], I[x > -1, y > -1]]: @@ -613,9 +618,9 @@ def get_data(): da[dict(x=ind)] = value # should not raise def test_contains(self): - data_array = DataArray(1, coords={'x': 2}) - with pytest.warns(FutureWarning): - assert 'x' in data_array + data_array = DataArray([1, 2]) + assert 1 in data_array + assert 3 not in data_array def test_attr_sources_multiindex(self): # make sure attr-style access for multi-index levels @@ -668,6 +673,7 @@ def test_isel_types(self): assert_identical(da.isel(x=np.array([0], dtype="int64")), da.isel(x=np.array([0]))) + @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_isel_fancy(self): shape = (10, 7, 6) np_array = np.random.random(shape) @@ -780,6 +786,22 @@ def test_sel_dataarray(self): assert 'new_dim' in actual.coords assert_equal(actual['new_dim'].drop('x'), ind['new_dim']) + def test_sel_invalid_slice(self): + array = DataArray(np.arange(10), [('x', np.arange(10))]) + with raises_regex(ValueError, 'cannot use non-scalar arrays'): + array.sel(x=slice(array.x)) + + def test_sel_dataarray_datetime(self): + # regression test for GH1240 + times = pd.date_range('2000-01-01', freq='D', periods=365) + array = DataArray(np.arange(365), [('time', times)]) + result = array.sel(time=slice(array.time[0], array.time[-1])) + assert_equal(result, array) + + array = DataArray(np.arange(365), [('delta', times - times[0])]) + result = array.sel(delta=slice(array.delta[0], array.delta[-1])) + assert_equal(result, array) + def test_sel_no_index(self): array = DataArray(np.arange(10), dims='x') 
assert_identical(array[0], array.sel(x=0)) @@ -825,6 +847,7 @@ def test_isel_drop(self): selected = data.isel(x=0, drop=False) assert_identical(expected, selected) + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def test_isel_points(self): shape = (10, 5, 6) np_array = np.random.random(shape) @@ -979,7 +1002,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, assert da.dims[0] == renamed_dim da = da.rename({renamed_dim: 'x'}) assert_identical(da.variable, expected_da.variable) - self.assertVariableNotEqual(da['x'], expected_da['x']) + assert not da['x'].equals(expected_da['x']) test_sel(('a', 1, -1), 0) test_sel(('b', 2, -2), -1) @@ -1132,7 +1155,7 @@ def test_reset_coords(self): assert_identical(actual, expected) actual = data.copy() - actual.reset_coords(drop=True, inplace=True) + actual = actual.reset_coords(drop=True) assert_identical(actual, expected) actual = data.reset_coords('bar', drop=True) @@ -1141,8 +1164,9 @@ def test_reset_coords(self): dims=['x', 'y'], name='foo') assert_identical(actual, expected) - with raises_regex(ValueError, 'cannot reset coord'): - data.reset_coords(inplace=True) + with pytest.warns(FutureWarning, message='The inplace argument'): + with raises_regex(ValueError, 'cannot reset coord'): + data = data.reset_coords(inplace=True) with raises_regex(ValueError, 'cannot be found'): data.reset_coords('foo', drop=True) with raises_regex(ValueError, 'cannot be found'): @@ -1165,6 +1189,13 @@ def test_assign_coords(self): with raises_regex(ValueError, 'conflicting MultiIndex'): self.mda.assign_coords(level_1=range(4)) + # GH: 2112 + da = xr.DataArray([0, 1, 2], dims='x') + with pytest.raises(ValueError): + da['x'] = [0, 1, 2, 3] # size conflict + with pytest.raises(ValueError): + da.coords['x'] = [0, 1, 2, 3] # size conflict + def test_coords_alignment(self): lhs = DataArray([1, 2, 3], [('x', [0, 1, 2])]) rhs = DataArray([2, 3, 4], [('x', [1, 2, 3])]) @@ -1210,6 +1241,7 @@ def test_reindex_like_no_index(self): ValueError, 'different size for unlabeled'): foo.reindex_like(bar) + @pytest.mark.filterwarnings('ignore:Indexer has dimensions') def test_reindex_regressions(self): # regression test for #279 expected = DataArray(np.random.randn(5), coords=[("time", range(5))]) @@ -1248,6 +1280,9 @@ def test_rename(self): assert renamed.name == 'z' assert renamed.dims == ('z',) + renamed_kwargs = self.dv.x.rename(x='z').rename('z') + assert_identical(renamed, renamed_kwargs) + def test_swap_dims(self): array = DataArray(np.random.randn(3), {'y': ('x', list('abc'))}, 'x') expected = DataArray(array.values, {'y': list('abc')}, dims='y') @@ -1256,7 +1291,7 @@ def test_swap_dims(self): def test_expand_dims_error(self): array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3.0)}, + coords={'x': np.linspace(0.0, 1.0, 3)}, attrs={'key': 'entry'}) with raises_regex(ValueError, 'dim should be str or'): @@ -1364,7 +1399,7 @@ def test_set_index(self): expected = array.set_index(x=['level_1', 'level_2', 'level_3']) assert_identical(obj, expected) - array.set_index(x=['level_1', 'level_2', 'level_3'], inplace=True) + array = array.set_index(x=['level_1', 'level_2', 'level_3']) assert_identical(array, expected) array2d = DataArray(np.random.rand(2, 2), @@ -1397,7 +1432,7 @@ def test_reset_index(self): assert_identical(obj, expected) array = self.mda.copy() - array.reset_index(['x'], drop=True, inplace=True) + array = array.reset_index(['x'], drop=True) assert_identical(array, expected) # single index @@ -1413,9 +1448,10 
@@ def test_reorder_levels(self): obj = self.mda.reorder_levels(x=['level_2', 'level_1']) assert_identical(obj, expected) - array = self.mda.copy() - array.reorder_levels(x=['level_2', 'level_1'], inplace=True) - assert_identical(array, expected) + with pytest.warns(FutureWarning, message='The inplace argument'): + array = self.mda.copy() + array.reorder_levels(x=['level_2', 'level_1'], inplace=True) + assert_identical(array, expected) array = DataArray([1, 2], dims='x') with pytest.raises(KeyError): @@ -1630,9 +1666,23 @@ def test_dataset_math(self): def test_stack_unstack(self): orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], attrs={'foo': 2}) + assert_identical(orig, orig.unstack()) + actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y']) assert_identical(orig, actual) + dims = ['a', 'b', 'c', 'd', 'e'] + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) + stacked = orig.stack(ab=['a', 'b'], cd=['c', 'd']) + + unstacked = stacked.unstack(['ab', 'cd']) + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + + unstacked = stacked.unstack() + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + def test_stack_unstack_decreasing_coordinate(self): # regression test for GH980 orig = DataArray(np.random.rand(3, 4), dims=('y', 'x'), @@ -1651,6 +1701,13 @@ def test_unstack_pandas_consistency(self): actual = DataArray(s, dims='z').unstack('z') assert_identical(expected, actual) + def test_stack_nonunique_consistency(self): + orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], + coords={'x': [0, 1], 'y': [0, 0]}) + actual = orig.stack(z=['x', 'y']) + expected = DataArray(orig.to_pandas().stack(), dims='z') + assert_identical(expected, actual) + def test_transpose(self): assert_equal(self.dv.variable.transpose(), self.dv.transpose().variable) @@ -1668,6 +1725,21 @@ def test_squeeze_drop(self): actual = array.squeeze(drop=False) assert_identical(expected, actual) + array = DataArray([[[0., 1.]]], dims=['dim_0', 'dim_1', 'dim_2']) + expected = DataArray([[0., 1.]], dims=['dim_1', 'dim_2']) + actual = array.squeeze(axis=0) + assert_identical(expected, actual) + + array = DataArray([[[[0., 1.]]]], dims=[ + 'dim_0', 'dim_1', 'dim_2', 'dim_3']) + expected = DataArray([[0., 1.]], dims=['dim_1', 'dim_3']) + actual = array.squeeze(axis=(0, 2)) + assert_identical(expected, actual) + + array = DataArray([[[0., 1.]]], dims=['dim_0', 'dim_1', 'dim_2']) + with pytest.raises(ValueError): + array.squeeze(axis=0, dim='dim_1') + def test_drop_coordinates(self): expected = DataArray(np.random.randn(2, 3), dims=['x', 'y']) arr = expected.copy() @@ -1717,6 +1789,12 @@ def test_where(self): actual = arr.where(arr.x < 2, drop=True) assert_identical(actual, expected) + def test_where_string(self): + array = DataArray(['a', 'b']) + expected = DataArray(np.array(['a', np.nan], dtype=object)) + actual = array.where([True, False]) + assert_identical(actual, expected) + def test_cumops(self): coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), @@ -1724,6 +1802,11 @@ def test_cumops(self): orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + actual = orig.cumsum() + expected = DataArray([[-1, -1, 0], [-4, -4, 0]], coords, + dims=['x', 'y']) + assert_identical(expected, actual) + actual = orig.cumsum('x') expected = DataArray([[-1, 0, 1], [-4, 0, 4]], coords, dims=['x', 'y']) @@ -1920,15 +2003,15 @@ def test_groupby_sum(self): 
self.x[:, 10:].sum(), self.x[:, 9:10].sum()]).T), 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] - assert_allclose(expected_sum_all, grouped.reduce(np.sum)) - assert_allclose(expected_sum_all, grouped.sum()) + assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=ALL_DIMS)) + assert_allclose(expected_sum_all, grouped.sum(ALL_DIMS)) expected = DataArray([array['y'].values[idx].sum() for idx in [slice(9), slice(10, None), slice(9, 10)]], [['a', 'b', 'c']], ['abc']) actual = array['y'].groupby('abc').apply(np.sum) assert_allclose(expected, actual) - actual = array['y'].groupby('abc').sum() + actual = array['y'].groupby('abc').sum(ALL_DIMS) assert_allclose(expected, actual) expected_sum_axis1 = Dataset( @@ -1939,6 +2022,27 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, 'y')) assert_allclose(expected_sum_axis1, grouped.sum('y')) + def test_groupby_warning(self): + array = self.make_groupby_example_array() + grouped = array.groupby('y') + with pytest.warns(FutureWarning): + grouped.sum() + + @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.12'), + reason="not to forget the behavior change") + def test_groupby_sum_default(self): + array = self.make_groupby_example_array() + grouped = array.groupby('abc') + + expected_sum_all = Dataset( + {'foo': Variable(['x', 'abc'], + np.array([self.x[:, :9].sum(axis=-1), + self.x[:, 10:].sum(axis=-1), + self.x[:, 9:10].sum(axis=-1)]).T), + 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] + + assert_allclose(expected_sum_all, grouped.sum()) + def test_groupby_count(self): array = DataArray( [0, 0, np.nan, np.nan, 0, 0], @@ -1948,7 +2052,7 @@ def test_groupby_count(self): expected = DataArray([1, 1, 2], coords=[('cat', ['a', 'b', 'c'])]) assert_identical(actual, expected) - @unittest.skip('needs to be fixed for shortcut=False, keep_attrs=False') + @pytest.mark.skip('needs to be fixed for shortcut=False, keep_attrs=False') def test_groupby_reduce_attrs(self): array = self.make_groupby_example_array() array.attrs['foo'] = 'bar' @@ -2010,7 +2114,7 @@ def test_groupby_math(self): actual = array.coords['x'] + grouped assert_identical(expected, actual) - ds = array.coords['x'].to_dataset('X') + ds = array.coords['x'].to_dataset(name='X') expected = array + ds actual = grouped + ds assert_identical(expected, actual) @@ -2019,9 +2123,9 @@ def test_groupby_math(self): assert_identical(expected, actual) grouped = array.groupby('abc') - expected_agg = (grouped.mean() - np.arange(3)).rename(None) + expected_agg = (grouped.mean(ALL_DIMS) - np.arange(3)).rename(None) actual = grouped - DataArray(range(3), [('abc', ['a', 'b', 'c'])]) - actual_agg = actual.groupby('abc').mean() + actual_agg = actual.groupby('abc').mean(ALL_DIMS) assert_allclose(expected_agg, actual_agg) with raises_regex(TypeError, 'only support binary ops'): @@ -2095,7 +2199,7 @@ def test_groupby_multidim(self): ('lon', DataArray([5, 28, 23], coords=[('lon', [30., 40., 50.])])), ('lat', DataArray([16, 40], coords=[('lat', [10., 20.])]))]: - actual_sum = array.groupby(dim).sum() + actual_sum = array.groupby(dim).sum(ALL_DIMS) assert_identical(expected_sum, actual_sum) def test_groupby_multidim_apply(self): @@ -2172,6 +2276,16 @@ def test_resample(self): with raises_regex(ValueError, 'index must be monotonic'): array[[2, 0, 1]].resample(time='1D') + @requires_cftime + def test_resample_cftimeindex(self): + cftime = _import_cftime() + times = cftime.num2date(np.arange(12), units='hours since 0001-01-01', + 
calendar='noleap') + array = DataArray(np.arange(12), [('time', times)]) + + with raises_regex(NotImplementedError, 'to_datetimeindex'): + array.resample(time='6H').mean() + def test_resample_first(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) array = DataArray(np.arange(10), [('time', times)]) @@ -2244,53 +2358,24 @@ def test_resample_drop_nondim_coords(self): actual = array.resample(time="1H").interpolate('linear') assert 'tc' not in actual.coords - def test_resample_old_vs_new_api(self): + def test_resample_keep_attrs(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) array = DataArray(np.ones(10), [('time', times)]) + array.attrs['meta'] = 'data' - # Simple mean - with pytest.warns(DeprecationWarning): - old_mean = array.resample('1D', 'time', how='mean') - new_mean = array.resample(time='1D').mean() - assert_identical(old_mean, new_mean) - - # Mean, while keeping attributes - attr_array = array.copy() - attr_array.attrs['meta'] = 'data' - - with pytest.warns(DeprecationWarning): - old_mean = attr_array.resample('1D', dim='time', how='mean', - keep_attrs=True) - new_mean = attr_array.resample(time='1D').mean(keep_attrs=True) - assert old_mean.attrs == new_mean.attrs - assert_identical(old_mean, new_mean) - - # Mean, with NaN to skip - nan_array = array.copy() - nan_array[1] = np.nan - - with pytest.warns(DeprecationWarning): - old_mean = nan_array.resample('1D', 'time', how='mean', - skipna=False) - new_mean = nan_array.resample(time='1D').mean(skipna=False) + result = array.resample(time='1D').mean(keep_attrs=True) + expected = DataArray([1, 1, 1], [('time', times[::4])], + attrs=array.attrs) + assert_identical(result, expected) + + def test_resample_skipna(self): + times = pd.date_range('2000-01-01', freq='6H', periods=10) + array = DataArray(np.ones(10), [('time', times)]) + array[1] = np.nan + + result = array.resample(time='1D').mean(skipna=False) expected = DataArray([np.nan, 1, 1], [('time', times[::4])]) - assert_identical(old_mean, expected) - assert_identical(new_mean, expected) - - # Try other common resampling methods - resampler = array.resample(time='1D') - for method in ['mean', 'median', 'sum', 'first', 'last', 'count']: - # Discard attributes on the call using the new api to match - # convention from old api - new_api = getattr(resampler, method)(keep_attrs=False) - with pytest.warns(DeprecationWarning): - old_api = array.resample('1D', dim='time', how=method) - assert_identical(new_api, old_api) - for method in [np.mean, np.sum, np.max, np.min]: - new_api = resampler.reduce(method) - with pytest.warns(DeprecationWarning): - old_api = array.resample('1D', dim='time', how=method) - assert_identical(new_api, old_api) + assert_identical(result, expected) def test_upsample(self): times = pd.date_range('2000-01-01', freq='6H', periods=5) @@ -2418,6 +2503,7 @@ def test_upsample_interpolate_regression_1605(self): assert_allclose(actual, expected, rtol=1e-16) @requires_dask + @requires_scipy def test_upsample_interpolate_dask(self): import dask.array as da @@ -2642,9 +2728,13 @@ def test_to_pandas(self): # roundtrips for shape in [(3,), (3, 4), (3, 4, 5)]: + if len(shape) > 2 and not hasattr(pd, 'Panel'): + continue dims = list('abc')[:len(shape)] da = DataArray(np.random.randn(*shape), dims=dims) - roundtripped = DataArray(da.to_pandas()).drop(dims) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + roundtripped = DataArray(da.to_pandas()).drop(dims) assert_identical(da, roundtripped) with 
raises_regex(ValueError, 'cannot convert'): @@ -2707,9 +2797,9 @@ def test_to_and_from_series(self): def test_series_categorical_index(self): # regression test for GH700 if not hasattr(pd, 'CategoricalIndex'): - raise unittest.SkipTest('requires pandas with CategoricalIndex') + pytest.skip('requires pandas with CategoricalIndex') - s = pd.Series(range(5), index=pd.CategoricalIndex(list('aabbc'))) + s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list('aabbc'))) arr = DataArray(s) assert "'a'" in repr(arr) # should not error @@ -2835,7 +2925,8 @@ def test_to_masked_array(self): ma = da.to_masked_array() assert len(ma.mask) == N - def test_to_and_from_cdms2(self): + def test_to_and_from_cdms2_classic(self): + """Classic with 1D axes""" pytest.importorskip('cdms2') original = DataArray( @@ -2846,9 +2937,9 @@ def test_to_and_from_cdms2(self): expected_coords = [IndexVariable('distance', [-2, 2]), IndexVariable('time', [0, 1, 2])] actual = original.to_cdms2() - assert_array_equal(actual, original) + assert_array_equal(actual.asma(), original) assert actual.id == original.name - self.assertItemsEqual(actual.getAxisIds(), original.dims) + assert tuple(actual.getAxisIds()) == original.dims for axis, coord in zip(actual.getAxisList(), expected_coords): assert axis.id == coord.name assert_array_equal(axis, coord.values) @@ -2861,148 +2952,62 @@ def test_to_and_from_cdms2(self): roundtripped = DataArray.from_cdms2(actual) assert_identical(original, roundtripped) - def test_to_and_from_iris(self): - try: - import iris - import cf_units - except ImportError: - raise unittest.SkipTest('iris not installed') - - coord_dict = OrderedDict() - coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) - coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) - coord_dict['height'] = 10 - coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) - coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) - - original = DataArray(np.arange(6, dtype='float').reshape(2, 3), - coord_dict, name='Temperature', - attrs={'baz': 123, 'units': 'Kelvin', - 'standard_name': 'fire_temperature', - 'long_name': 'Fire Temperature'}, - dims=('distance', 'time')) - - # Set a bad value to test the masking logic - original.data[0, 2] = np.NaN + back = from_cdms2(actual) + assert original.dims == back.dims + assert original.coords.keys() == back.coords.keys() + for coord_name in original.coords.keys(): + assert_array_equal(original.coords[coord_name], + back.coords[coord_name]) - original.attrs['cell_methods'] = \ - 'height: mean (comment: A cell method)' - actual = original.to_iris() - assert_array_equal(actual.data, original.data) - assert actual.var_name == original.name - self.assertItemsEqual([d.var_name for d in actual.dim_coords], - original.dims) - assert (actual.cell_methods == (iris.coords.CellMethod( - method='mean', - coords=('height', ), - intervals=(), - comments=('A cell method', )), )) + def test_to_and_from_cdms2_sgrid(self): + """Curvilinear (structured) grid - for coord, orginal_key in zip((actual.coords()), original.coords): - original_coord = original.coords[orginal_key] - assert coord.var_name == original_coord.name - assert_array_equal( - coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) == - original.get_axis_num( - original.coords[coord.var_name].dims)) - - assert (actual.coord('distance2').attributes['foo'] == - original.coords['distance2'].attrs['foo']) - assert (actual.coord('distance').units == - 
cf_units.Unit(original.coords['distance'].units)) - assert actual.attributes['baz'] == original.attrs['baz'] - assert actual.standard_name == original.attrs['standard_name'] - - roundtripped = DataArray.from_iris(actual) - assert_identical(original, roundtripped) - - actual.remove_coord('time') - auto_time_dimension = DataArray.from_iris(actual) - assert auto_time_dimension.dims == ('distance', 'dim_1') - - actual.coord('distance').var_name = None - with raises_regex(ValueError, 'no var_name attribute'): - DataArray.from_iris(actual) - - @requires_dask - def test_to_and_from_iris_dask(self): - import dask.array as da - try: - import iris - import cf_units - except ImportError: - raise unittest.SkipTest('iris not installed') - - coord_dict = OrderedDict() - coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) - coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) - coord_dict['height'] = 10 - coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) - coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) - - original = DataArray( - da.from_array(np.arange(-1, 5, dtype='float').reshape(2, 3), 3), - coord_dict, - name='Temperature', - attrs=dict(baz=123, units='Kelvin', - standard_name='fire_temperature', - long_name='Fire Temperature'), - dims=('distance', 'time')) - - # Set a bad value to test the masking logic - original.data = da.ma.masked_less(original.data, 0) - - original.attrs['cell_methods'] = \ - 'height: mean (comment: A cell method)' - actual = original.to_iris() - - # Be careful not to trigger the loading of the iris data - actual_data = actual.core_data() if \ - hasattr(actual, 'core_data') else actual.data - assert_array_equal(actual_data, original.data) - assert actual.var_name == original.name - self.assertItemsEqual([d.var_name for d in actual.dim_coords], - original.dims) - assert (actual.cell_methods == (iris.coords.CellMethod( - method='mean', - coords=('height', ), - intervals=(), - comments=('A cell method', )), )) - - for coord, orginal_key in zip((actual.coords()), original.coords): - original_coord = original.coords[orginal_key] - assert coord.var_name == original_coord.name - assert_array_equal( - coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) == - original.get_axis_num( - original.coords[coord.var_name].dims)) - - assert (actual.coord('distance2').attributes['foo'] == original.coords[ - 'distance2'].attrs['foo']) - assert (actual.coord('distance').units == - cf_units.Unit(original.coords['distance'].units)) - assert actual.attributes['baz'] == original.attrs['baz'] - assert actual.standard_name == original.attrs['standard_name'] - - roundtripped = DataArray.from_iris(actual) - assert_identical(original, roundtripped) - - # If the Iris version supports it then we should get a dask array back - if hasattr(actual, 'core_data'): - pass - # TODO This currently fails due to the decoding loading - # the data (#1372) - # self.assertEqual(type(original.data), type(roundtripped.data)) + The rectangular grid case is covered by the classic case + """ + pytest.importorskip('cdms2') - actual.remove_coord('time') - auto_time_dimension = DataArray.from_iris(actual) - assert auto_time_dimension.dims == ('distance', 'dim_1') + lonlat = np.mgrid[:3, :4] + lon = DataArray(lonlat[1], dims=['y', 'x'], name='lon') + lat = DataArray(lonlat[0], dims=['y', 'x'], name='lat') + x = DataArray(np.arange(lon.shape[1]), dims=['x'], name='x') + y = DataArray(np.arange(lon.shape[0]), dims=['y'], 
name='y') + original = DataArray(lonlat.sum(axis=0), dims=['y', 'x'], + coords=OrderedDict(x=x, y=y, lon=lon, lat=lat), + name='sst') + actual = original.to_cdms2() + assert tuple(actual.getAxisIds()) == original.dims + assert_array_equal(original.coords['lon'], + actual.getLongitude().asma()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().asma()) + + back = from_cdms2(actual) + assert original.dims == back.dims + assert set(original.coords.keys()) == set(back.coords.keys()) + assert_array_equal(original.coords['lat'], back.coords['lat']) + assert_array_equal(original.coords['lon'], back.coords['lon']) + + def test_to_and_from_cdms2_ugrid(self): + """Unstructured grid""" + pytest.importorskip('cdms2') - actual.coord('distance').var_name = None - with raises_regex(ValueError, 'no var_name attribute'): - DataArray.from_iris(actual) + lon = DataArray(np.random.uniform(size=5), dims=['cell'], name='lon') + lat = DataArray(np.random.uniform(size=5), dims=['cell'], name='lat') + cell = DataArray(np.arange(5), dims=['cell'], name='cell') + original = DataArray(np.arange(5), dims=['cell'], + coords={'lon': lon, 'lat': lat, 'cell': cell}) + actual = original.to_cdms2() + assert tuple(actual.getAxisIds()) == original.dims + assert_array_equal(original.coords['lon'], + actual.getLongitude().getValue()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().getValue()) + + back = from_cdms2(actual) + assert set(original.dims) == set(back.dims) + assert set(original.coords.keys()) == set(back.coords.keys()) + assert_array_equal(original.coords['lat'], back.coords['lat']) + assert_array_equal(original.coords['lon'], back.coords['lon']) def test_to_dataset_whole(self): unnamed = DataArray([1, 2], dims='x') @@ -3093,24 +3098,51 @@ def test_coordinate_diff(self): actual = lon.diff('lon') assert_equal(expected, actual) - def test_shift(self): + @pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5]) + def test_shift(self, offset): arr = DataArray([1, 2, 3], dims='x') actual = arr.shift(x=1) expected = DataArray([np.nan, 1, 2], dims='x') assert_identical(expected, actual) arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])]) - for offset in [-5, -2, -1, 0, 1, 2, 5]: - expected = DataArray(arr.to_pandas().shift(offset)) - actual = arr.shift(x=offset) - assert_identical(expected, actual) + expected = DataArray(arr.to_pandas().shift(offset)) + actual = arr.shift(x=offset) + assert_identical(expected, actual) + + def test_roll_coords(self): + arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + actual = arr.roll(x=1, roll_coords=True) + expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) + assert_identical(expected, actual) + + def test_roll_no_coords(self): + arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + actual = arr.roll(x=1, roll_coords=False) + expected = DataArray([3, 1, 2], coords=[('x', [0, 1, 2])]) + assert_identical(expected, actual) - def test_roll(self): + def test_roll_coords_none(self): arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') - actual = arr.roll(x=1) + + with pytest.warns(FutureWarning): + actual = arr.roll(x=1, roll_coords=None) + expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) assert_identical(expected, actual) + def test_copy_with_data(self): + orig = DataArray(np.random.random(size=(2, 2)), + dims=('x', 'y'), + attrs={'attr1': 'value1'}, + coords={'x': [4, 3]}, + name='helloworld') + new_data = np.arange(4).reshape(2, 2) + actual = orig.copy(data=new_data) + expected = orig.copy() + 
expected.data = new_data + assert_identical(expected, actual) + def test_real_and_imag(self): array = DataArray(1 + 2j) assert_identical(array.real, DataArray(1)) @@ -3180,8 +3212,6 @@ def test_dot(self): da.dot(dm.to_dataset(name='dm')) with pytest.raises(TypeError): da.dot(dm.values) - with raises_regex(ValueError, 'no shared dimensions'): - da.dot(DataArray(1)) def test_binary_op_join_setting(self): dim = 'x' @@ -3248,9 +3278,6 @@ def test_sortby(self): actual = da.sortby([day, dax]) assert_equal(actual, expected) - if LooseVersion(np.__version__) < LooseVersion('1.11.0'): - pytest.skip('numpy 1.11.0 or later to support object data-type.') - expected = sorted1d actual = da.sortby('x') assert_equal(actual, expected) @@ -3293,6 +3320,14 @@ def da(request): [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims='time') + if request.param == 'repeating_ints': + return DataArray( + np.tile(np.arange(12), 5).reshape(5, 4, 3), + coords={'x': list('abc'), + 'y': list('defg')}, + dims=list('zyx') + ) + @pytest.fixture def da_dask(seed=123): @@ -3305,9 +3340,36 @@ def da_dask(seed=123): return da +@pytest.mark.parametrize('da', ('repeating_ints', ), indirect=True) +def test_isin(da): + + expected = DataArray( + np.asarray([[0, 0, 0], [1, 0, 0]]), + dims=list('yx'), + coords={'x': list('abc'), + 'y': list('de')}, + ).astype('bool') + + result = da.isin([3]).sel(y=list('de'), z=0) + assert_equal(result, expected) + + expected = DataArray( + np.asarray([[0, 0, 1], [1, 0, 0]]), + dims=list('yx'), + coords={'x': list('abc'), + 'y': list('de')}, + ).astype('bool') + result = da.isin([2, 3]).sel(y=list('de'), z=0) + assert_equal(result, expected) + + +@pytest.mark.parametrize('da', (1, 2), indirect=True) def test_rolling_iter(da): rolling_obj = da.rolling(time=7) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + rolling_obj_mean = rolling_obj.mean() assert len(rolling_obj.window_labels) == len(da['time']) assert_identical(rolling_obj.window_labels, da['time']) @@ -3315,6 +3377,18 @@ def test_rolling_iter(da): for i, (label, window_da) in enumerate(rolling_obj): assert label == da['time'].isel(time=i) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + actual = rolling_obj_mean.isel(time=i) + expected = window_da.mean('time') + + # TODO add assert_allclose_with_nan, which compares nan position + # as well as the closeness of the values. 
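# Editor's note: illustrative sketch only, not part of the patch.  It shows the
# pattern the surrounding test exercises -- iterating a Rolling object yields
# (label, window) pairs whose per-window means line up with rolling(...).mean().
# The demo array and its name are made up for this aside.
import numpy as np
import pandas as pd
import xarray as xr

demo = xr.DataArray(np.arange(10.0),
                    coords=[('time', pd.date_range('2000-01-01', periods=10))])
rolling = demo.rolling(time=3, min_periods=1)
rolling_mean = rolling.mean()
for i, (label, window) in enumerate(rolling):
    # each yielded window is itself a DataArray, NaN-padded at the start
    assert label == demo['time'][i]
    np.testing.assert_allclose(window.mean('time'), rolling_mean.isel(time=i))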
+ assert_array_equal(actual.isnull(), expected.isnull()) + if (~actual.isnull()).sum() > 0: + np.allclose(actual.values[actual.values.nonzero()], + expected.values[expected.values.nonzero()]) + def test_rolling_doc(da): rolling_obj = da.rolling(time=7) @@ -3362,29 +3436,49 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods): assert_equal(actual, da['time']) -@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'min', 'max', - 'median')) +@pytest.mark.parametrize('name', ('mean', 'count')) @pytest.mark.parametrize('center', (True, False, None)) @pytest.mark.parametrize('min_periods', (1, None)) -def test_rolling_wrapped_bottleneck_dask(da_dask, name, center, min_periods): +@pytest.mark.parametrize('window', (7, 8)) +def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): pytest.importorskip('dask.array') # dask version - rolling_obj = da_dask.rolling(time=7, min_periods=min_periods) + rolling_obj = da_dask.rolling(time=window, min_periods=min_periods, + center=center) actual = getattr(rolling_obj, name)().load() # numpy version - rolling_obj = da_dask.load().rolling(time=7, min_periods=min_periods) + rolling_obj = da_dask.load().rolling(time=window, min_periods=min_periods, + center=center) expected = getattr(rolling_obj, name)() # using all-close because rolling over ghost cells introduces some # precision errors assert_allclose(actual, expected) + # with zero chunked array GH:2113 + rolling_obj = da_dask.chunk().rolling(time=window, min_periods=min_periods, + center=center) + actual = getattr(rolling_obj, name)().load() + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('center', (True, None)) +def test_rolling_wrapped_dask_nochunk(center): + # GH:2113 + pytest.importorskip('dask.array') + + da_day_clim = xr.DataArray(np.arange(1, 367), + coords=[np.arange(1, 367)], dims='dayofyear') + expected = da_day_clim.rolling(dayofyear=31, center=center).mean() + actual = da_day_clim.chunk().rolling(dayofyear=31, center=center).mean() + assert_allclose(actual, expected) + @pytest.mark.parametrize('center', (True, False)) @pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) @pytest.mark.parametrize('window', (1, 2, 3, 4)) -def test_rolling_pandas_compat(da, center, window, min_periods): - s = pd.Series(range(10)) +def test_rolling_pandas_compat(center, window, min_periods): + s = pd.Series(np.arange(10)) da = DataArray.from_series(s) if min_periods is not None and window < min_periods: @@ -3394,12 +3488,39 @@ def test_rolling_pandas_compat(da, center, window, min_periods): min_periods=min_periods).mean() da_rolling = da.rolling(index=window, center=center, min_periods=min_periods).mean() - # pandas does some fancy stuff in the last position, - # we're not going to do that yet! 
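# Editor's note: illustrative sketch only, not part of the patch (needs dask
# installed).  It mirrors the GH 2113 regression check above: rolling on a
# chunked array should match the in-memory result, including the default
# single-chunk case produced by .chunk() with no arguments.
import numpy as np
import xarray as xr

clim = xr.DataArray(np.arange(1.0, 367.0),
                    coords=[np.arange(1, 367)], dims='dayofyear')
expected = clim.rolling(dayofyear=31, center=True).mean()
actual = clim.chunk().rolling(dayofyear=31, center=True).mean()
xr.testing.assert_allclose(actual, expected)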
- np.testing.assert_allclose(s_rolling.values[:-1], - da_rolling.values[:-1]) - np.testing.assert_allclose(s_rolling.index, - da_rolling['index']) + da_rolling_np = da.rolling(index=window, center=center, + min_periods=min_periods).reduce(np.nanmean) + + np.testing.assert_allclose(s_rolling.values, da_rolling.values) + np.testing.assert_allclose(s_rolling.index, da_rolling['index']) + np.testing.assert_allclose(s_rolling.values, da_rolling_np.values) + np.testing.assert_allclose(s_rolling.index, da_rolling_np['index']) + + +@pytest.mark.parametrize('center', (True, False)) +@pytest.mark.parametrize('window', (1, 2, 3, 4)) +def test_rolling_construct(center, window): + s = pd.Series(np.arange(10)) + da = DataArray.from_series(s) + + s_rolling = s.rolling(window, center=center, min_periods=1).mean() + da_rolling = da.rolling(index=window, center=center, min_periods=1) + + da_rolling_mean = da_rolling.construct('window').mean('window') + np.testing.assert_allclose(s_rolling.values, da_rolling_mean.values) + np.testing.assert_allclose(s_rolling.index, da_rolling_mean['index']) + + # with stride + da_rolling_mean = da_rolling.construct('window', + stride=2).mean('window') + np.testing.assert_allclose(s_rolling.values[::2], da_rolling_mean.values) + np.testing.assert_allclose(s_rolling.index[::2], da_rolling_mean['index']) + + # with fill_value + da_rolling_mean = da_rolling.construct( + 'window', stride=2, fill_value=0.0).mean('window') + assert da_rolling_mean.isnull().sum() == 0 + assert (da_rolling_mean == 0.0).sum() >= 0 @pytest.mark.parametrize('da', (1, 2), indirect=True) @@ -3412,6 +3533,10 @@ def test_rolling_reduce(da, center, min_periods, window, name): if min_periods is not None and window < min_periods: min_periods = window + if da.isnull().sum() > 1 and window == 1: + # this causes all nan slices + window = 2 + rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods) @@ -3422,29 +3547,264 @@ def test_rolling_reduce(da, center, min_periods, window, name): assert actual.dims == expected.dims +@requires_np113 +@pytest.mark.parametrize('center', (True, False)) +@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) +@pytest.mark.parametrize('window', (1, 2, 3, 4)) +@pytest.mark.parametrize('name', ('sum', 'max')) +def test_rolling_reduce_nonnumeric(center, min_periods, window, name): + da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], + dims='time').isnull() + + if min_periods is not None and window < min_periods: + min_periods = window + + rolling_obj = da.rolling(time=window, center=center, + min_periods=min_periods) + + # add nan prefix to numpy methods to get similar behavior as bottleneck + actual = rolling_obj.reduce(getattr(np, 'nan%s' % name)) + expected = getattr(rolling_obj, name)() + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + def test_rolling_count_correct(): da = DataArray( [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims='time') - result = da.rolling(time=11, min_periods=1).count() - expected = DataArray( - [1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims='time') - assert_equal(result, expected) - - result = da.rolling(time=11, min_periods=None).count() - expected = DataArray( + kwargs = [{'time': 11, 'min_periods': 1}, + {'time': 11, 'min_periods': None}, + {'time': 7, 'min_periods': 2}] + expecteds = [DataArray( + [1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims='time'), + DataArray( [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, 8], dims='time') - assert_equal(result, 
expected) + np.nan, np.nan, np.nan, np.nan, np.nan], dims='time'), + DataArray( + [np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims='time')] - result = da.rolling(time=7, min_periods=2).count() - expected = DataArray( - [np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims='time') - assert_equal(result, expected) + for kwarg, expected in zip(kwargs, expecteds): + result = da.rolling(**kwarg).count() + assert_equal(result, expected) + + result = da.to_dataset(name='var1').rolling(**kwarg).count()['var1'] + assert_equal(result, expected) def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 assert len(record) == 0 + + +class TestIrisConversion(object): + @requires_iris + def test_to_and_from_iris(self): + import iris + import cf_units # iris requirement + + # to iris + coord_dict = OrderedDict() + coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) + coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) + coord_dict['height'] = 10 + coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) + coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) + + original = DataArray(np.arange(6, dtype='float').reshape(2, 3), + coord_dict, name='Temperature', + attrs={'baz': 123, 'units': 'Kelvin', + 'standard_name': 'fire_temperature', + 'long_name': 'Fire Temperature'}, + dims=('distance', 'time')) + + # Set a bad value to test the masking logic + original.data[0, 2] = np.NaN + + original.attrs['cell_methods'] = \ + 'height: mean (comment: A cell method)' + actual = original.to_iris() + assert_array_equal(actual.data, original.data) + assert actual.var_name == original.name + assert tuple(d.var_name for d in actual.dim_coords) == original.dims + assert (actual.cell_methods == (iris.coords.CellMethod( + method='mean', + coords=('height', ), + intervals=(), + comments=('A cell method', )), )) + + for coord, orginal_key in zip((actual.coords()), original.coords): + original_coord = original.coords[orginal_key] + assert coord.var_name == original_coord.name + assert_array_equal( + coord.points, CFDatetimeCoder().encode(original_coord).values) + assert (actual.coord_dims(coord) == + original.get_axis_num( + original.coords[coord.var_name].dims)) + + assert (actual.coord('distance2').attributes['foo'] == + original.coords['distance2'].attrs['foo']) + assert (actual.coord('distance').units == + cf_units.Unit(original.coords['distance'].units)) + assert actual.attributes['baz'] == original.attrs['baz'] + assert actual.standard_name == original.attrs['standard_name'] + + roundtripped = DataArray.from_iris(actual) + assert_identical(original, roundtripped) + + actual.remove_coord('time') + auto_time_dimension = DataArray.from_iris(actual) + assert auto_time_dimension.dims == ('distance', 'dim_1') + + @requires_iris + @requires_dask + def test_to_and_from_iris_dask(self): + import dask.array as da + import iris + import cf_units # iris requirement + + coord_dict = OrderedDict() + coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) + coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) + coord_dict['height'] = 10 + coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) + coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) + + original = DataArray( + da.from_array(np.arange(-1, 5, dtype='float').reshape(2, 3), 3), + coord_dict, + name='Temperature', + attrs=dict(baz=123, units='Kelvin', + standard_name='fire_temperature', + long_name='Fire 
Temperature'), + dims=('distance', 'time')) + + # Set a bad value to test the masking logic + original.data = da.ma.masked_less(original.data, 0) + + original.attrs['cell_methods'] = \ + 'height: mean (comment: A cell method)' + actual = original.to_iris() + + # Be careful not to trigger the loading of the iris data + actual_data = actual.core_data() if \ + hasattr(actual, 'core_data') else actual.data + assert_array_equal(actual_data, original.data) + assert actual.var_name == original.name + assert tuple(d.var_name for d in actual.dim_coords) == original.dims + assert (actual.cell_methods == (iris.coords.CellMethod( + method='mean', + coords=('height', ), + intervals=(), + comments=('A cell method', )), )) + + for coord, orginal_key in zip((actual.coords()), original.coords): + original_coord = original.coords[orginal_key] + assert coord.var_name == original_coord.name + assert_array_equal( + coord.points, CFDatetimeCoder().encode(original_coord).values) + assert (actual.coord_dims(coord) == + original.get_axis_num( + original.coords[coord.var_name].dims)) + + assert (actual.coord('distance2').attributes['foo'] == original.coords[ + 'distance2'].attrs['foo']) + assert (actual.coord('distance').units == + cf_units.Unit(original.coords['distance'].units)) + assert actual.attributes['baz'] == original.attrs['baz'] + assert actual.standard_name == original.attrs['standard_name'] + + roundtripped = DataArray.from_iris(actual) + assert_identical(original, roundtripped) + + # If the Iris version supports it then we should have a dask array + # at each stage of the conversion + if hasattr(actual, 'core_data'): + assert isinstance(original.data, type(actual.core_data())) + assert isinstance(original.data, type(roundtripped.data)) + + actual.remove_coord('time') + auto_time_dimension = DataArray.from_iris(actual) + assert auto_time_dimension.dims == ('distance', 'dim_1') + + @requires_iris + @pytest.mark.parametrize('var_name, std_name, long_name, name, attrs', [ + ('var_name', 'height', 'Height', + 'var_name', {'standard_name': 'height', 'long_name': 'Height'}), + (None, 'height', 'Height', + 'height', {'standard_name': 'height', 'long_name': 'Height'}), + (None, None, 'Height', + 'Height', {'long_name': 'Height'}), + (None, None, None, + None, {}), + ]) + def test_da_name_from_cube(self, std_name, long_name, var_name, name, + attrs): + from iris.cube import Cube + + data = [] + cube = Cube(data, var_name=var_name, standard_name=std_name, + long_name=long_name) + result = xr.DataArray.from_iris(cube) + expected = xr.DataArray(data, name=name, attrs=attrs) + xr.testing.assert_identical(result, expected) + + @requires_iris + @pytest.mark.parametrize('var_name, std_name, long_name, name, attrs', [ + ('var_name', 'height', 'Height', + 'var_name', {'standard_name': 'height', 'long_name': 'Height'}), + (None, 'height', 'Height', + 'height', {'standard_name': 'height', 'long_name': 'Height'}), + (None, None, 'Height', + 'Height', {'long_name': 'Height'}), + (None, None, None, + 'unknown', {}), + ]) + def test_da_coord_name_from_cube(self, std_name, long_name, var_name, + name, attrs): + from iris.cube import Cube + from iris.coords import DimCoord + + latitude = DimCoord([-90, 0, 90], standard_name=std_name, + var_name=var_name, long_name=long_name) + data = [0, 0, 0] + cube = Cube(data, dim_coords_and_dims=[(latitude, 0)]) + result = xr.DataArray.from_iris(cube) + expected = xr.DataArray(data, coords=[(name, [-90, 0, 90], attrs)]) + xr.testing.assert_identical(result, expected) + + @requires_iris + def 
test_prevent_duplicate_coord_names(self): + from iris.cube import Cube + from iris.coords import DimCoord + + # Iris enforces unique coordinate names. Because we use a different + # name resolution order a valid iris Cube with coords that have the + # same var_name would lead to duplicate dimension names in the + # DataArray + longitude = DimCoord([0, 360], standard_name='longitude', + var_name='duplicate') + latitude = DimCoord([-90, 0, 90], standard_name='latitude', + var_name='duplicate') + data = [[0, 0, 0], [0, 0, 0]] + cube = Cube(data, dim_coords_and_dims=[(longitude, 0), (latitude, 1)]) + with pytest.raises(ValueError): + xr.DataArray.from_iris(cube) + + @requires_iris + @pytest.mark.parametrize('coord_values', [ + ['IA', 'IL', 'IN'], # non-numeric values + [0, 2, 1], # non-monotonic values + ]) + def test_fallback_to_iris_AuxCoord(self, coord_values): + from iris.cube import Cube + from iris.coords import AuxCoord + + data = [0, 0, 0] + da = xr.DataArray(data, coords=[coord_values], dims=['space']) + result = xr.DataArray.to_iris(da) + expected = Cube(data, aux_coords_and_dims=[ + (AuxCoord(coord_values, var_name='space'), 0)]) + assert result == expected diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 09d67613007..89ea3ba78a0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1,9 +1,31 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import sys +import warnings from copy import copy, deepcopy +from io import StringIO from textwrap import dedent + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import ( + ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align, + backends, broadcast, open_dataset, set_options) +from xarray.core import indexing, npcompat, utils +from xarray.core.common import full_like +from xarray.core.pycompat import ( + OrderedDict, integer_types, iteritems, unicode_type) + +from . import ( + InaccessibleArray, UnexpectedDataAccess, assert_allclose, + assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask, + raises_regex, requires_bottleneck, requires_dask, requires_scipy, + source_ndarray) + try: import cPickle as pickle except ImportError: @@ -12,25 +34,6 @@ import dask.array as da except ImportError: pass -from io import StringIO -from distutils.version import LooseVersion - -import numpy as np -import pandas as pd -import xarray as xr -import pytest - -from xarray import (align, broadcast, backends, Dataset, DataArray, Variable, - IndexVariable, open_dataset, set_options, MergeError) -from xarray.core import indexing, utils -from xarray.core.pycompat import (iteritems, OrderedDict, unicode_type, - integer_types) -from xarray.core.common import full_like - -from . 
import (TestCase, raises_regex, InaccessibleArray, UnexpectedDataAccess, - requires_dask, source_ndarray, assert_array_equal, assert_equal, - assert_allclose, assert_identical, requires_bottleneck, - requires_scipy) def create_test_data(seed=None): @@ -61,8 +64,8 @@ def create_test_multiindex(): class InaccessibleVariableDataStore(backends.InMemoryDataStore): - def __init__(self, writer=None): - super(InaccessibleVariableDataStore, self).__init__(writer) + def __init__(self): + super(InaccessibleVariableDataStore, self).__init__() self._indexvars = set() def store(self, variables, *args, **kwargs): @@ -76,13 +79,14 @@ def get_variables(self): def lazy_inaccessible(k, v): if k in self._indexvars: return v - data = indexing.LazilyIndexedArray(InaccessibleArray(v.values)) + data = indexing.LazilyOuterIndexedArray( + InaccessibleArray(v.values)) return Variable(v.dims, data, v.attrs) return dict((k, lazy_inaccessible(k, v)) for k, v in iteritems(self._variables)) -class TestDataset(TestCase): +class TestDataset(object): def test_repr(self): data = create_test_data(seed=123) data.attrs['foo'] = 'bar' @@ -91,15 +95,15 @@ def test_repr(self): Dimensions: (dim1: 8, dim2: 9, dim3: 10, time: 20) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20 * dim2 (dim2) float64 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 * dim3 (dim3) %s 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' numbers (dim3) int64 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 Data variables: - var1 (dim1, dim2) float64 -1.086 0.9973 0.283 -1.506 -0.5786 1.651 ... - var2 (dim1, dim2) float64 1.162 -1.097 -2.123 1.04 -0.4034 -0.126 ... - var3 (dim3, dim1) float64 0.5565 -0.2121 0.4563 1.545 -0.2397 0.1433 ... + var1 (dim1, dim2) float64 -1.086 0.9973 0.283 ... 0.1995 0.4684 -0.8312 + var2 (dim1, dim2) float64 1.162 -1.097 -2.123 ... 0.1302 1.267 0.3328 + var3 (dim3, dim1) float64 0.5565 -0.2121 0.4563 ... -0.2452 -0.3616 Attributes: foo: bar""") % data['dim3'].dtype # noqa: E501 actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) @@ -180,15 +184,16 @@ def test_unicode_data(self): data = Dataset({u'foø': [u'ba®']}, attrs={u'å': u'∑'}) repr(data) # should not raise + byteorder = '<' if sys.byteorder == 'little' else '>' expected = dedent(u"""\ Dimensions: (foø: 1) Coordinates: - * foø (foø) KeyError + # in pandas 0.23. 
+ with pytest.raises((ValueError, KeyError)): # not contained in axis data.drop(['c'], dim='x') @@ -1871,6 +1889,27 @@ def test_copy(self): v1 = copied.variables[k] assert v0 is not v1 + def test_copy_with_data(self): + orig = create_test_data() + new_data = {k: np.random.randn(*v.shape) + for k, v in iteritems(orig.data_vars)} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + expected[k].data = v + assert_identical(expected, actual) + + def test_copy_with_data_errors(self): + orig = create_test_data() + new_var1 = np.arange(orig['var1'].size).reshape(orig['var1'].shape) + with raises_regex(ValueError, 'Data must be dict-like'): + orig.copy(data=new_var1) + with raises_regex(ValueError, 'only contain variables in original'): + orig.copy(data={'not_in_original': new_var1}) + with raises_regex(ValueError, 'contain all variables in original'): + orig.copy(data={'var1': new_var1}) + def test_rename(self): data = create_test_data() newnames = {'var1': 'renamed_var1', 'dim2': 'renamed_dim2'} @@ -1907,6 +1946,9 @@ def test_rename(self): with pytest.raises(UnexpectedDataAccess): renamed['renamed_var1'].values + renamed_kwargs = data.rename(**newnames) + assert_identical(renamed, renamed_kwargs) + def test_rename_old_name(self): # regtest for GH1477 data = create_test_data() @@ -1923,6 +1965,7 @@ def test_rename_same_name(self): renamed = data.rename(newnames) assert_identical(renamed, data) + @pytest.mark.filterwarnings('ignore:The inplace argument') def test_rename_inplace(self): times = pd.date_range('2000-01-01', periods=3) data = Dataset({'z': ('x', [2, 3, 4]), 't': ('t', times)}) @@ -1948,7 +1991,7 @@ def test_swap_dims(self): assert_identical(original.set_coords('y'), roundtripped) actual = original.copy() - actual.swap_dims({'x': 'y'}, inplace=True) + actual = actual.swap_dims({'x': 'y'}) assert_identical(expected, actual) with raises_regex(ValueError, 'cannot swap'): @@ -1970,7 +2013,7 @@ def test_expand_dims_error(self): # Make sure it raises true error also for non-dimensional coordinates # which has dimension. 
- original.set_coords('z', inplace=True) + original = original.set_coords('z') with raises_regex(ValueError, 'already exists'): original.expand_dims(dim=['z']) @@ -2017,8 +2060,9 @@ def test_set_index(self): obj = ds.set_index(x=mindex.names) assert_identical(obj, expected) - ds.set_index(x=mindex.names, inplace=True) - assert_identical(ds, expected) + with pytest.warns(FutureWarning, message='The inplace argument'): + ds.set_index(x=mindex.names, inplace=True) + assert_identical(ds, expected) # ensure set_index with no existing index and a single data var given # doesn't return multi-index @@ -2036,8 +2080,9 @@ def test_reset_index(self): obj = ds.reset_index('x') assert_identical(obj, expected) - ds.reset_index('x', inplace=True) - assert_identical(ds, expected) + with pytest.warns(FutureWarning, message='The inplace argument'): + ds.reset_index('x', inplace=True) + assert_identical(ds, expected) def test_reorder_levels(self): ds = create_test_multiindex() @@ -2048,8 +2093,9 @@ def test_reorder_levels(self): reindexed = ds.reorder_levels(x=['level_2', 'level_1']) assert_identical(reindexed, expected) - ds.reorder_levels(x=['level_2', 'level_1'], inplace=True) - assert_identical(ds, expected) + with pytest.warns(FutureWarning, message='The inplace argument'): + ds.reorder_levels(x=['level_2', 'level_1'], inplace=True) + assert_identical(ds, expected) ds = Dataset({}, coords={'x': [1, 2]}) with raises_regex(ValueError, 'has no MultiIndex'): @@ -2083,17 +2129,18 @@ def test_unstack(self): expected = Dataset({'b': (('x', 'y'), [[0, 1], [2, 3]]), 'x': [0, 1], 'y': ['a', 'b']}) - actual = ds.unstack('z') - assert_identical(actual, expected) + for dim in ['z', ['z'], None]: + actual = ds.unstack(dim) + assert_identical(actual, expected) def test_unstack_errors(self): ds = Dataset({'x': [1, 2, 3]}) - with raises_regex(ValueError, 'invalid dimension'): + with raises_regex(ValueError, 'does not contain the dimensions'): ds.unstack('foo') - with raises_regex(ValueError, 'does not have a MultiIndex'): + with raises_regex(ValueError, 'do not have a MultiIndex'): ds.unstack('x') - def test_stack_unstack(self): + def test_stack_unstack_fast(self): ds = Dataset({'a': ('x', [0, 1]), 'b': (('x', 'y'), [[0, 1], [2, 3]]), 'x': [0, 1], @@ -2104,6 +2151,19 @@ def test_stack_unstack(self): actual = ds[['b']].stack(z=['x', 'y']).unstack('z') assert actual.identical(ds[['b']]) + def test_stack_unstack_slow(self): + ds = Dataset({'a': ('x', [0, 1]), + 'b': (('x', 'y'), [[0, 1], [2, 3]]), + 'x': [0, 1], + 'y': ['a', 'b']}) + stacked = ds.stack(z=['x', 'y']) + actual = stacked.isel(z=slice(None, None, -1)).unstack('z') + assert actual.broadcast_equals(ds) + + stacked = ds[['b']].stack(z=['x', 'y']) + actual = stacked.isel(z=slice(None, None, -1)).unstack('z') + assert actual.identical(ds[['b']]) + def test_update(self): data = create_test_data(seed=0) expected = data.copy() @@ -2113,20 +2173,37 @@ def test_update(self): assert_identical(expected, actual) actual = data.copy() - actual_result = actual.update(data, inplace=True) + actual_result = actual.update(data) assert actual_result is actual assert_identical(expected, actual) - actual = data.update(data, inplace=False) - expected = data - assert actual is not expected - assert_identical(expected, actual) + with pytest.warns(FutureWarning, message='The inplace argument'): + actual = data.update(data, inplace=False) + expected = data + assert actual is not expected + assert_identical(expected, actual) other = Dataset(attrs={'new': 'attr'}) actual = data.copy() 
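# Editor's note: illustrative sketch only, not part of the patch.  The edits in
# this hunk reflect the deprecation of the inplace= keyword (it now raises a
# FutureWarning); methods such as swap_dims, set_index and update return the
# modified Dataset, so reassign the result instead.
import xarray as xr

ds = xr.Dataset({'z': ('x', [2, 3, 4])}, coords={'y': ('x', ['a', 'b', 'c'])})
ds = ds.swap_dims({'x': 'y'})  # assign the return value rather than mutating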
actual.update(other) assert_identical(expected, actual) + def test_update_overwrite_coords(self): + data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) + data.update(Dataset(coords={'b': 4})) + expected = Dataset({'a': ('x', [1, 2])}, {'b': 4}) + assert_identical(data, expected) + + data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) + data.update(Dataset({'c': 5}, coords={'b': 4})) + expected = Dataset({'a': ('x', [1, 2]), 'c': 5}, {'b': 4}) + assert_identical(data, expected) + + data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) + data.update({'c': DataArray(5, coords={'b': 4})}) + expected = Dataset({'a': ('x', [1, 2]), 'c': 5}, {'b': 3}) + assert_identical(data, expected) + def test_update_auto_align(self): ds = Dataset({'x': ('t', [3, 4])}, {'t': [0, 1]}) @@ -2320,6 +2397,52 @@ def test_setitem_auto_align(self): expected = Dataset({'x': ('y', [4, 5, 6])}, {'y': range(3)}) assert_identical(ds, expected) + def test_setitem_with_coords(self): + # Regression test for GH:2068 + ds = create_test_data() + + other = DataArray(np.arange(10), dims='dim3', + coords={'numbers': ('dim3', np.arange(10))}) + expected = ds.copy() + expected['var3'] = other.drop('numbers') + actual = ds.copy() + actual['var3'] = other + assert_identical(expected, actual) + assert 'numbers' in other.coords # should not change other + + # with alignment + other = ds['var3'].isel(dim3=slice(1, -1)) + other['numbers'] = ('dim3', np.arange(8)) + actual = ds.copy() + actual['var3'] = other + assert 'numbers' in other.coords # should not change other + expected = ds.copy() + expected['var3'] = ds['var3'].isel(dim3=slice(1, -1)) + assert_identical(expected, actual) + + # with non-duplicate coords + other = ds['var3'].isel(dim3=slice(1, -1)) + other['numbers'] = ('dim3', np.arange(8)) + other['position'] = ('dim3', np.arange(8)) + actual = ds.copy() + actual['var3'] = other + assert 'position' in actual + assert 'position' in other.coords + + # assigning a coordinate-only dataarray + actual = ds.copy() + other = actual['numbers'] + other[0] = 10 + actual['numbers'] = other + assert actual['numbers'][0] == 10 + + # GH: 2099 + ds = Dataset({'var': ('x', [1, 2, 3])}, + coords={'x': [0, 1, 2], 'z1': ('x', [1, 2, 3]), + 'z2': ('x', [1, 2, 3])}) + ds['var'] = ds['var'] * 2 + assert np.allclose(ds['var'], [2, 4, 6]) + def test_setitem_align_new_indexes(self): ds = Dataset({'foo': ('x', [1, 2, 3])}, {'x': [0, 1, 2]}) ds['bar'] = DataArray([2, 3, 4], [('x', [1, 2, 3])]) @@ -2374,6 +2497,18 @@ def test_assign_multiindex_level(self): with raises_regex(ValueError, 'conflicting MultiIndex'): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) + # raise an Error when any level name is used as dimension GH:2299 + with pytest.raises(ValueError): + data['y'] = ('level_1', [0, 1]) + + def test_merge_multiindex_level(self): + data = create_test_multiindex() + other = Dataset({'z': ('level_1', [0, 1])}) # conflict dimension + with pytest.raises(ValueError): + data.merge(other) + other = Dataset({'level_1': ('x', [0, 1])}) # conflict variable name + with pytest.raises(ValueError): + data.merge(other) def test_setitem_original_non_unique_index(self): # regression test for GH943 @@ -2412,12 +2547,11 @@ def test_setitem_multiindex_level(self): def test_delitem(self): data = create_test_data() all_items = set(data.variables) - self.assertItemsEqual(data.variables, all_items) + assert set(data.variables) == all_items del data['var1'] - self.assertItemsEqual(data.variables, all_items - set(['var1'])) + assert set(data.variables) == all_items - 
set(['var1']) del data['numbers'] - self.assertItemsEqual(data.variables, - all_items - set(['var1', 'numbers'])) + assert set(data.variables) == all_items - set(['var1', 'numbers']) assert 'numbers' not in data.coords def test_squeeze(self): @@ -2427,7 +2561,7 @@ def get_args(v): return [set(args[0]) & set(v.dims)] if args else [] expected = Dataset(dict((k, v.squeeze(*get_args(v))) for k, v in iteritems(data.variables))) - expected.set_coords(data.coords, inplace=True) + expected = expected.set_coords(data.coords) assert_identical(expected, data.squeeze(*args)) # invalid squeeze with raises_regex(ValueError, 'cannot select a dimension'): @@ -2515,20 +2649,28 @@ def test_groupby_reduce(self): expected = data.mean('y') expected['yonly'] = expected['yonly'].variable.set_dims({'x': 3}) - actual = data.groupby('x').mean() + actual = data.groupby('x').mean(ALL_DIMS) assert_allclose(expected, actual) actual = data.groupby('x').mean('y') assert_allclose(expected, actual) letters = data['letters'] - expected = Dataset({'xy': data['xy'].groupby(letters).mean(), + expected = Dataset({'xy': data['xy'].groupby(letters).mean(ALL_DIMS), 'xonly': (data['xonly'].mean().variable .set_dims({'letters': 2})), 'yonly': data['yonly'].groupby(letters).mean()}) - actual = data.groupby('letters').mean() + actual = data.groupby('letters').mean(ALL_DIMS) assert_allclose(expected, actual) + def test_groupby_warn(self): + data = Dataset({'xy': (['x', 'y'], np.random.randn(3, 4)), + 'xonly': ('x', np.random.randn(3)), + 'yonly': ('y', np.random.randn(4)), + 'letters': ('y', ['a', 'a', 'b', 'b'])}) + with pytest.warns(FutureWarning): + data.groupby('x').mean() + def test_groupby_math(self): def reorder_dims(x): return x.transpose('dim1', 'dim2', 'dim3', 'time') @@ -2583,7 +2725,7 @@ def test_groupby_math_virtual(self): ds = Dataset({'x': ('t', [1, 2, 3])}, {'t': pd.date_range('20100101', periods=3)}) grouped = ds.groupby('t.day') - actual = grouped - grouped.mean() + actual = grouped - grouped.mean(ALL_DIMS) expected = Dataset({'x': ('t', [0, 0, 0])}, ds[['t', 't.day']]) assert_identical(actual, expected) @@ -2592,18 +2734,17 @@ def test_groupby_nan(self): # nan should be excluded from groupby ds = Dataset({'foo': ('x', [1, 2, 3, 4])}, {'bar': ('x', [1, 1, 2, np.nan])}) - actual = ds.groupby('bar').mean() + actual = ds.groupby('bar').mean(ALL_DIMS) expected = Dataset({'foo': ('bar', [1.5, 3]), 'bar': [1, 2]}) assert_identical(actual, expected) def test_groupby_order(self): # groupby should preserve variables order - ds = Dataset() for vn in ['a', 'b', 'c']: ds[vn] = DataArray(np.arange(10), dims=['t']) data_vars_ref = list(ds.data_vars.keys()) - ds = ds.groupby('t').mean() + ds = ds.groupby('t').mean(ALL_DIMS) data_vars = list(ds.data_vars.keys()) assert data_vars == data_vars_ref # coords are now at the end of the list, so the test below fails @@ -2633,6 +2774,20 @@ def test_resample_and_first(self): result = actual.reduce(method) assert_equal(expected, result) + def test_resample_min_count(self): + times = pd.date_range('2000-01-01', freq='6H', periods=10) + ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), + 'bar': ('time', np.random.randn(10), {'meta': 'data'}), + 'time': times}) + # inject nan + ds['foo'] = xr.where(ds['foo'] > 2.0, np.nan, ds['foo']) + + actual = ds.resample(time='1D').sum(min_count=1) + expected = xr.concat([ + ds.isel(time=slice(i * 4, (i + 1) * 4)).sum('time', min_count=1) + for i in range(3)], dim=actual['time']) + assert_equal(expected, actual) + def 
test_resample_by_mean_with_keep_attrs(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), @@ -2703,22 +2858,21 @@ def test_resample_drop_nondim_coords(self): actual = ds.resample(time="1H").interpolate('linear') assert 'tc' not in actual.coords - def test_resample_old_vs_new_api(self): + def test_resample_old_api(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), 'bar': ('time', np.random.randn(10), {'meta': 'data'}), 'time': times}) - ds.attrs['dsmeta'] = 'dsdata' - for method in ['mean', 'sum', 'count', 'first', 'last']: - resampler = ds.resample(time='1D') - # Discard attributes on the call using the new api to match - # convention from old api - new_api = getattr(resampler, method)(keep_attrs=False) - with pytest.warns(DeprecationWarning): - old_api = ds.resample('1D', dim='time', how=method) - assert_identical(new_api, old_api) + with raises_regex(TypeError, r'resample\(\) no longer supports'): + ds.resample('1D', 'time') + + with raises_regex(TypeError, r'resample\(\) no longer supports'): + ds.resample('1D', dim='time', how='mean') + + with raises_regex(TypeError, r'resample\(\) no longer supports'): + ds.resample('1D', dim='time') def test_to_array(self): ds = Dataset(OrderedDict([('a', 1), ('b', ('x', [1, 2, 3]))]), @@ -3270,13 +3424,23 @@ def test_reduce(self): (['dim2', 'time'], ['dim1', 'dim3']), (('dim2', 'time'), ['dim1', 'dim3']), ((), ['dim1', 'dim2', 'dim3', 'time'])]: - actual = data.min(dim=reduct).dims - print(reduct, actual, expected) - self.assertItemsEqual(actual, expected) + actual = list(data.min(dim=reduct).dims) + assert actual == expected assert_equal(data.mean(dim=[]), data) - # uint support + def test_reduce_coords(self): + # regression test for GH1470 + data = xr.Dataset({'a': ('x', [1, 2, 3])}, coords={'b': 4}) + expected = xr.Dataset({'a': 2}, coords={'b': 4}) + actual = data.mean('x') + assert_identical(actual, expected) + + # should be consistent + actual = data['a'].mean('x').to_dataset() + assert_identical(actual, expected) + + def test_mean_uint_dtype(self): data = xr.Dataset({'a': (('x', 'y'), np.arange(6).reshape(3, 2).astype('uint')), 'b': (('x', ), np.array([0.1, 0.2, np.nan]))}) @@ -3290,15 +3454,20 @@ def test_reduce_bad_dim(self): with raises_regex(ValueError, 'Dataset does not contain'): data.mean(dim='bad_dim') + def test_reduce_cumsum(self): + data = xr.Dataset({'a': 1, + 'b': ('x', [1, 2]), + 'c': (('x', 'y'), [[np.nan, 3], [0, 4]])}) + assert_identical(data.fillna(0), data.cumsum('y')) + + expected = xr.Dataset({'a': 1, + 'b': ('x', [1, 3]), + 'c': (('x', 'y'), [[0, 3], [0, 7]])}) + assert_identical(expected, data.cumsum()) + def test_reduce_cumsum_test_dims(self): data = create_test_data() for cumfunc in ['cumsum', 'cumprod']: - with raises_regex(ValueError, - "must supply either single 'dim' or 'axis'"): - getattr(data, cumfunc)() - with raises_regex(ValueError, - "must supply either single 'dim' or 'axis'"): - getattr(data, cumfunc)(dim=['dim1', 'dim2']) with raises_regex(ValueError, 'Dataset does not contain'): getattr(data, cumfunc)(dim='bad_dim') @@ -3310,8 +3479,7 @@ def test_reduce_cumsum_test_dims(self): ('time', ['dim1', 'dim2', 'dim3']) ]: actual = getattr(data, cumfunc)(dim=reduct).dims - print(reduct, actual, expected) - self.assertItemsEqual(actual, expected) + assert list(actual) == expected def test_reduce_non_numeric(self): data1 = create_test_data(seed=44) 
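# Editor's note: illustrative migration sketch, not part of the patch.  As the
# test_resample_old_api changes above show, the positional/how= resample call
# now raises TypeError; only the groupby-style keyword form remains.
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range('2000-01-01', freq='6H', periods=10)
ds = xr.Dataset({'foo': ('time', np.arange(10.0))}, coords={'time': times})
daily = ds.resample(time='1D').mean()
# ds.resample('1D', dim='time', how='mean')  # now raises TypeError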
@@ -3405,6 +3573,10 @@ def test_reduce_scalars(self): actual = ds.var() assert_identical(expected, actual) + expected = Dataset({'x': 0, 'y': 0, 'z': ('b', [0])}) + actual = ds.var('a') + assert_identical(expected, actual) + def test_reduce_only_one_axis(self): def mean_only_one_axis(x, axis): @@ -3445,14 +3617,14 @@ def test_rank(self): ds = create_test_data(seed=1234) # only ds.var3 depends on dim3 z = ds.rank('dim3') - self.assertItemsEqual(['var3'], list(z.data_vars)) + assert ['var3'] == list(z.data_vars) # same as dataarray version x = z.var3 y = ds.var3.rank('dim3') assert_equal(x, y) # coordinates stick - self.assertItemsEqual(list(z.coords), list(ds.coords)) - self.assertItemsEqual(list(x.coords), list(y.coords)) + assert list(z.coords) == list(ds.coords) + assert list(x.coords) == list(y.coords) # invalid dim with raises_regex(ValueError, 'does not contain'): x.rank('invalid_dim') @@ -3633,10 +3805,6 @@ def test_dataset_transpose(self): expected = ds.apply(lambda x: x.transpose()) assert_identical(expected, actual) - with pytest.warns(FutureWarning): - actual = ds.T - assert_identical(expected, actual) - actual = ds.transpose('x', 'y') expected = ds.apply(lambda x: x.transpose('x', 'y')) assert_identical(expected, actual) @@ -3735,18 +3903,52 @@ def test_shift(self): with raises_regex(ValueError, 'dimensions'): ds.shift(foo=123) - def test_roll(self): + def test_roll_coords(self): coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} attrs = {'meta': 'data'} ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) - actual = ds.roll(x=1) + actual = ds.roll(x=1, roll_coords=True) ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) with raises_regex(ValueError, 'dimensions'): - ds.roll(foo=123) + ds.roll(foo=123, roll_coords=True) + + def test_roll_no_coords(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + actual = ds.roll(x=1, roll_coords=False) + + expected = Dataset({'foo': ('x', [3, 1, 2])}, coords, attrs) + assert_identical(expected, actual) + + with raises_regex(ValueError, 'dimensions'): + ds.roll(abc=321, roll_coords=False) + + def test_roll_coords_none(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + + with pytest.warns(FutureWarning): + actual = ds.roll(x=1, roll_coords=None) + + ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} + expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) + assert_identical(expected, actual) + + def test_roll_multidim(self): + # regression test for 2445 + arr = xr.DataArray( + [[1, 2, 3], [4, 5, 6]], coords={'x': range(3), 'y': range(2)}, + dims=('y', 'x')) + actual = arr.roll(x=1, roll_coords=True) + expected = xr.DataArray([[3, 1, 2], [6, 4, 5]], + coords=[('y', [0, 1]), ('x', [2, 0, 1])]) + assert_identical(expected, actual) def test_real_and_imag(self): attrs = {'foo': 'bar'} @@ -3806,6 +4008,26 @@ def test_filter_by_attrs(self): for var in new_ds.data_vars: assert new_ds[var].height == '10 m' + # Test return empty Dataset due to conflicting filters + new_ds = ds.filter_by_attrs( + standard_name='convective_precipitation_flux', + height='0 m') + assert not bool(new_ds.data_vars) + + # Test return one DataArray with two filter conditions + new_ds = ds.filter_by_attrs( + standard_name='air_potential_temperature', + height='0 m') 
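# Editor's note: illustrative sketch only, not part of the patch; the variable
# names below are made up.  filter_by_attrs keeps only the data variables whose
# attributes satisfy *all* of the given conditions, which is why conflicting
# filters in the test above return an empty Dataset.
import xarray as xr

ds = xr.Dataset({
    'temp': ('x', [280.0, 281.0],
             {'standard_name': 'air_potential_temperature', 'height': '0 m'}),
    'precip': ('x', [0.1, 0.2],
               {'standard_name': 'convective_precipitation_flux'}),
})
matched = ds.filter_by_attrs(standard_name='air_potential_temperature',
                             height='0 m')
assert list(matched.data_vars) == ['temp']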
+ for var in new_ds.data_vars: + assert new_ds[var].standard_name == 'air_potential_temperature' + assert new_ds[var].height == '0 m' + assert new_ds[var].height != '10 m' + + # Test return empty Dataset due to conflicting callables + new_ds = ds.filter_by_attrs(standard_name=lambda v: False, + height=lambda v: True) + assert not bool(new_ds.data_vars) + def test_binary_op_join_setting(self): # arithmetic_join applies to data array coordinates missing_2 = xr.Dataset({'x': [0, 1]}) @@ -3935,9 +4157,6 @@ def test_sortby(self): actual = ds.sortby(ds['A']) assert "DataArray is not 1-D" in str(excinfo.value) - if LooseVersion(np.__version__) < LooseVersion('1.11.0'): - pytest.skip('numpy 1.11.0 or later to support object data-type.') - expected = sorted1d actual = ds.sortby('x') assert_equal(actual, expected) @@ -4030,9 +4249,110 @@ def test_ipython_key_completion(self): # Py.test tests -@pytest.fixture() -def data_set(seed=None): - return create_test_data(seed) +@pytest.fixture(params=[None]) +def data_set(request): + return create_test_data(request.param) + + +@pytest.mark.parametrize('test_elements', ( + [1, 2], + np.array([1, 2]), + DataArray([1, 2]), +)) +def test_isin(test_elements): + expected = Dataset( + data_vars={ + 'var1': (('dim1',), [0, 1]), + 'var2': (('dim1',), [1, 1]), + 'var3': (('dim1',), [0, 1]), + } + ).astype('bool') + + result = Dataset( + data_vars={ + 'var1': (('dim1',), [0, 1]), + 'var2': (('dim1',), [1, 2]), + 'var3': (('dim1',), [0, 1]), + } + ).isin(test_elements) + + assert_equal(result, expected) + + +@pytest.mark.skipif(not has_dask, reason='requires dask') +@pytest.mark.parametrize('test_elements', ( + [1, 2], + np.array([1, 2]), + DataArray([1, 2]), +)) +def test_isin_dask(test_elements): + expected = Dataset( + data_vars={ + 'var1': (('dim1',), [0, 1]), + 'var2': (('dim1',), [1, 1]), + 'var3': (('dim1',), [0, 1]), + } + ).astype('bool') + + result = Dataset( + data_vars={ + 'var1': (('dim1',), [0, 1]), + 'var2': (('dim1',), [1, 2]), + 'var3': (('dim1',), [0, 1]), + } + ).chunk(1).isin(test_elements).compute() + + assert_equal(result, expected) + + +def test_isin_dataset(): + ds = Dataset({'x': [1, 2]}) + with pytest.raises(TypeError): + ds.isin(ds) + + +@pytest.mark.parametrize('unaligned_coords', ( + {'x': [2, 1, 0]}, + {'x': (['x'], np.asarray([2, 1, 0]))}, + {'x': (['x'], np.asarray([1, 2, 0]))}, + {'x': pd.Index([2, 1, 0])}, + {'x': Variable(dims='x', data=[0, 2, 1])}, + {'x': IndexVariable(dims='x', data=[0, 1, 2])}, + {'y': 42}, + {'y': ('x', [2, 1, 0])}, + {'y': ('x', np.asarray([2, 1, 0]))}, + {'y': (['x'], np.asarray([2, 1, 0]))}, +)) +@pytest.mark.parametrize('coords', ( + {'x': ('x', [0, 1, 2])}, + {'x': [0, 1, 2]}, +)) +def test_dataset_constructor_aligns_to_explicit_coords( + unaligned_coords, coords): + + a = xr.DataArray([1, 2, 3], dims=['x'], coords=unaligned_coords) + + expected = xr.Dataset(coords=coords) + expected['a'] = a + + result = xr.Dataset({'a': a}, coords=coords) + + assert_equal(expected, result) + + +def test_error_message_on_set_supplied(): + with pytest.raises(TypeError, message='has invalid type set'): + xr.Dataset(dict(date=[1, 2, 3], sec={4})) + + +@pytest.mark.parametrize('unaligned_coords', ( + {'y': ('b', np.asarray([2, 1, 0]))}, +)) +def test_constructor_raises_with_invalid_coords(unaligned_coords): + + with pytest.raises(ValueError, + message='not a subset of the DataArray dimensions'): + xr.DataArray([1, 2, 3], dims=['x'], coords=unaligned_coords) def test_dir_expected_attrs(data_set): @@ -4047,7 +4367,13 @@ def 
test_dir_non_string(data_set): # add a numbered key to ensure this doesn't break dir data_set[5] = 'foo' result = dir(data_set) - assert not (5 in result) + assert 5 not in result + + # GH2172 + sample_data = np.random.uniform(size=[2, 2000, 10000]) + x = xr.Dataset({"sample_data": (sample_data.shape, sample_data)}) + x2 = x["sample_data"] + dir(x2) def test_dir_unicode(data_set): @@ -4133,12 +4459,36 @@ def test_rolling_pandas_compat(center, window, min_periods): min_periods=min_periods).mean() ds_rolling = ds.rolling(index=window, center=center, min_periods=min_periods).mean() - # pandas does some fancy stuff in the last position, - # we're not going to do that yet! - np.testing.assert_allclose(df_rolling['x'].values[:-1], - ds_rolling['x'].values[:-1]) - np.testing.assert_allclose(df_rolling.index, - ds_rolling['index']) + + np.testing.assert_allclose(df_rolling['x'].values, ds_rolling['x'].values) + np.testing.assert_allclose(df_rolling.index, ds_rolling['index']) + + +@pytest.mark.parametrize('center', (True, False)) +@pytest.mark.parametrize('window', (1, 2, 3, 4)) +def test_rolling_construct(center, window): + df = pd.DataFrame({'x': np.random.randn(20), 'y': np.random.randn(20), + 'time': np.linspace(0, 1, 20)}) + + ds = Dataset.from_dataframe(df) + df_rolling = df.rolling(window, center=center, min_periods=1).mean() + ds_rolling = ds.rolling(index=window, center=center) + + ds_rolling_mean = ds_rolling.construct('window').mean('window') + np.testing.assert_allclose(df_rolling['x'].values, + ds_rolling_mean['x'].values) + np.testing.assert_allclose(df_rolling.index, ds_rolling_mean['index']) + + # with stride + ds_rolling_mean = ds_rolling.construct('window', stride=2).mean('window') + np.testing.assert_allclose(df_rolling['x'][::2].values, + ds_rolling_mean['x'].values) + np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean['index']) + # with fill_value + ds_rolling_mean = ds_rolling.construct( + 'window', stride=2, fill_value=0.0).mean('window') + assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim='vars').all() + assert (ds_rolling_mean['x'] == 0.0).sum() >= 0 @pytest.mark.slow @@ -4176,3 +4526,107 @@ def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: Dataset(data_vars={'x': ('y', [1, 2, np.NaN])}) > 0 assert len(record) == 0 + + +@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.parametrize('edge_order', [1, 2]) +def test_differentiate(dask, edge_order): + rs = np.random.RandomState(42) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + ds = xr.Dataset({'var': da}) + + # along x + actual = da.differentiate('x', edge_order) + expected_x = xr.DataArray( + npcompat.gradient(da, da['x'], axis=0, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + assert_equal(ds['var'].differentiate('x', edge_order=edge_order), + ds.differentiate('x', edge_order=edge_order)['var']) + # coordinate should not change + assert_equal(da['x'], actual['x']) + + # along y + actual = da.differentiate('y', edge_order) + expected_y = xr.DataArray( + npcompat.gradient(da, da['y'], axis=1, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_y, actual) + assert_equal(actual, ds.differentiate('y', edge_order=edge_order)['var']) + assert_equal(ds['var'].differentiate('y', edge_order=edge_order), + 
ds.differentiate('y', edge_order=edge_order)['var']) + + with pytest.raises(ValueError): + da.differentiate('x2d') + + +@pytest.mark.parametrize('dask', [True, False]) +def test_differentiate_datetime(dask): + rs = np.random.RandomState(42) + coord = np.array( + ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13', + '2010-10-11', '2010-12-13', '2011-02-13', '2012-08-13'], + dtype='datetime64') + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + # along x + actual = da.differentiate('x', edge_order=1, datetime_unit='D') + expected_x = xr.DataArray( + npcompat.gradient( + da, utils.datetime_to_numeric(da['x'], datetime_unit='D'), + axis=0, edge_order=1), dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + + actual2 = da.differentiate('x', edge_order=1, datetime_unit='h') + assert np.allclose(actual, actual2 * 24) + + # for datetime variable + actual = da['x'].differentiate('x', edge_order=1, datetime_unit='D') + assert np.allclose(actual, 1.0) + + # with different date unit + da = xr.DataArray(coord.astype('datetime64[ms]'), dims=['x'], + coords={'x': coord}) + actual = da.differentiate('x', edge_order=1) + assert np.allclose(actual, 1.0) + + +@pytest.mark.skipif(not has_cftime, reason='Test requires cftime.') +@pytest.mark.parametrize('dask', [True, False]) +def test_differentiate_cftime(dask): + rs = np.random.RandomState(42) + coord = xr.cftime_range('2000', periods=8, freq='2M') + + da = xr.DataArray( + rs.randn(8, 6), + coords={'time': coord, 'z': 3, 't2d': (('time', 'y'), rs.randn(8, 6))}, + dims=['time', 'y']) + + if dask and has_dask: + da = da.chunk({'time': 4}) + + actual = da.differentiate('time', edge_order=1, datetime_unit='D') + expected_data = npcompat.gradient( + da, utils.datetime_to_numeric(da['time'], datetime_unit='D'), + axis=0, edge_order=1) + expected = xr.DataArray(expected_data, coords=da.coords, dims=da.dims) + assert_equal(expected, actual) + + actual2 = da.differentiate('time', edge_order=1, datetime_unit='h') + assert_allclose(actual, actual2 * 24) + + # Test the differentiation of datetimes themselves + actual = da['time'].differentiate('time', edge_order=1, datetime_unit='D') + assert_allclose(actual, xr.ones_like(da['time']).astype(float)) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 1d450ff51d4..1837a0fe4ef 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -1,20 +1,43 @@ +""" isort:skip_file """ +from __future__ import absolute_import, division, print_function +import os import sys +import pickle +import tempfile import pytest -import xarray as xr -distributed = pytest.importorskip('distributed') -da = pytest.importorskip('dask.array') -import dask +dask = pytest.importorskip('dask', minversion='0.18') # isort:skip +distributed = pytest.importorskip('distributed', minversion='1.21') # isort:skip + +from dask import array +from dask.distributed import Client, Lock from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import loop # flake8: noqa from distributed.client import futures_of +import numpy as np -from xarray.tests.test_backends import create_tmp_file, ON_WINDOWS +import xarray as xr +from xarray.backends.locks import HDF5_LOCK, CombinedLock +from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, + create_tmp_geotiff, + open_example_dataset) from xarray.tests.test_dataset import 
create_test_data -from . import (assert_allclose, has_scipy, has_netCDF4, has_h5netcdf, - requires_zarr) +from . import ( + assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, + requires_zarr, requires_cfgrib, raises_regex) + +# this is to stop isort throwing errors. May have been easier to just use +# `isort:skip` in retrospect + + +da = pytest.importorskip('dask.array') + + +@pytest.fixture +def tmp_netcdf_filename(tmpdir): + return str(tmpdir.join('testfile.nc')) ENGINES = [] @@ -25,28 +48,83 @@ if has_h5netcdf: ENGINES.append('h5netcdf') +NC_FORMATS = {'netcdf4': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT_OFFSET', + 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], + 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], + 'h5netcdf': ['NETCDF4']} + +ENGINES_AND_FORMATS = [ + ('netcdf4', 'NETCDF3_CLASSIC'), + ('netcdf4', 'NETCDF4_CLASSIC'), + ('netcdf4', 'NETCDF4'), + ('h5netcdf', 'NETCDF4'), + ('scipy', 'NETCDF3_64BIT'), +] + + +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_netcdf_roundtrip( + loop, tmp_netcdf_filename, engine, nc_format): + + if engine not in ENGINES: + pytest.skip('engine not available') + + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + + original = create_test_data().chunk(chunks) + + if engine == 'scipy': + with pytest.raises(NotImplementedError): + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + return + + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) + + +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_read_netcdf_integration_test( + loop, tmp_netcdf_filename, engine, nc_format): + + if engine not in ENGINES: + pytest.skip('engine not available') + + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ENGINES) -def test_dask_distributed_netcdf_integration_test(loop, engine): - with cluster() as (s, _): - with distributed.Client(s['address'], loop=loop): original = create_test_data() - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - original.to_netcdf(filename, engine=engine) - with xr.open_dataset(filename, chunks=3, engine=engine) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, + engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) + @requires_zarr def test_dask_distributed_zarr_integration_test(loop): - with cluster() as (s, _): - with distributed.Client(s['address'], loop=loop): - original = create_test_data() - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 5} + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + original = create_test_data().chunk(chunks) + with 
create_tmp_file(allow_cleanup_failure=ON_WINDOWS, + suffix='.zarr') as filename: original.to_zarr(filename) with xr.open_zarr(filename) as restored: assert isinstance(restored.var1.data, da.Array) @@ -54,6 +132,31 @@ def test_dask_distributed_zarr_integration_test(loop): assert_allclose(original, computed) +@requires_rasterio +def test_dask_distributed_rasterio_integration_test(loop): + with create_tmp_geotiff() as (tmp_file, expected): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + da_tiff = xr.open_rasterio(tmp_file, chunks={'band': 1}) + assert isinstance(da_tiff.data, da.Array) + actual = da_tiff.compute() + assert_allclose(actual, expected) + + +@requires_cfgrib +def test_dask_distributed_cfgrib_integration_test(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + with open_example_dataset('example.grib', + engine='cfgrib', + chunks={'time': 1}) as ds: + with open_example_dataset('example.grib', + engine='cfgrib') as expected: + assert isinstance(ds['t'].data, da.Array) + actual = ds.compute() + assert_allclose(actual, expected) + + @pytest.mark.skipif(distributed.__version__ <= '1.19.3', reason='Need recent distributed version to clean up get') @gen_cluster(client=True, timeout=None) @@ -81,4 +184,26 @@ def test_async(c, s, a, b): assert not dask.is_dask_collection(w) assert_allclose(x + 10, w) - assert s.task_state + assert s.tasks + + +def test_hdf5_lock(): + assert isinstance(HDF5_LOCK, dask.utils.SerializableLock) + + +@gen_cluster(client=True) +def test_serializable_locks(c, s, a, b): + def f(x, lock=None): + with lock: + return x + 1 + + # note, the creation of Lock needs to be done inside a cluster + for lock in [HDF5_LOCK, Lock(), Lock('filename.nc'), + CombinedLock([HDF5_LOCK]), + CombinedLock([HDF5_LOCK, Lock('filename.nc')])]: + + futures = c.map(f, list(range(10)), lock=lock) + yield c.gather(futures) + + lock2 = pickle.loads(pickle.dumps(lock)) + assert type(lock) == type(lock2) diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py new file mode 100644 index 00000000000..292c60b4d05 --- /dev/null +++ b/xarray/tests/test_dtypes.py @@ -0,0 +1,88 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import pytest + +from xarray.core import dtypes + + +@pytest.mark.parametrize("args, expected", [ + ([np.bool], np.bool), + ([np.bool, np.string_], np.object_), + ([np.float32, np.float64], np.float64), + ([np.float32, np.string_], np.object_), + ([np.unicode_, np.int64], np.object_), + ([np.unicode_, np.unicode_], np.unicode_), + ([np.bytes_, np.unicode_], np.object_), +]) +def test_result_type(args, expected): + actual = dtypes.result_type(*args) + assert actual == expected + + +def test_result_type_scalar(): + actual = dtypes.result_type(np.arange(3, dtype=np.float32), np.nan) + assert actual == np.float32 + + +def test_result_type_dask_array(): + # verify it works without evaluating dask arrays + da = pytest.importorskip('dask.array') + dask = pytest.importorskip('dask') + + def error(): + raise RuntimeError + + array = da.from_delayed(dask.delayed(error)(), (), np.float64) + with pytest.raises(RuntimeError): + array.compute() + + actual = dtypes.result_type(array) + assert actual == np.float64 + + # note that this differs from the behavior for scalar numpy arrays, which + # would get promoted to float32 + actual = dtypes.result_type(array, np.array([0.5, 1.0], dtype=np.float32)) + assert actual == np.float64 + + +@pytest.mark.parametrize('obj', [1.0, 
np.inf, 'ab', 1.0 + 1.0j, True]) +def test_inf(obj): + assert dtypes.INF > obj + assert dtypes.NINF < obj + + +@pytest.mark.parametrize("kind, expected", [ + ('a', (np.dtype('O'), 'nan')), # dtype('S') + ('b', (np.float32, 'nan')), # dtype('int8') + ('B', (np.float32, 'nan')), # dtype('uint8') + ('c', (np.dtype('O'), 'nan')), # dtype('S1') + ('D', (np.complex128, '(nan+nanj)')), # dtype('complex128') + ('d', (np.float64, 'nan')), # dtype('float64') + ('e', (np.float16, 'nan')), # dtype('float16') + ('F', (np.complex64, '(nan+nanj)')), # dtype('complex64') + ('f', (np.float32, 'nan')), # dtype('float32') + ('h', (np.float32, 'nan')), # dtype('int16') + ('H', (np.float32, 'nan')), # dtype('uint16') + ('i', (np.float64, 'nan')), # dtype('int32') + ('I', (np.float64, 'nan')), # dtype('uint32') + ('l', (np.float64, 'nan')), # dtype('int64') + ('L', (np.float64, 'nan')), # dtype('uint64') + ('m', (np.timedelta64, 'NaT')), # dtype(' 0: + assert isinstance(da.data, dask_array_type) + + +@pytest.mark.parametrize('dim_num', [1, 2]) +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'min', 'max', 'mean', 'var']) +# TODO test cumsum, cumprod +@pytest.mark.parametrize('skipna', [False, True]) +@pytest.mark.parametrize('aggdim', [None, 'x']) +def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): + + if aggdim == 'y' and dim_num < 2: + pytest.skip('dim not in this test') + + if dtype == np.bool_ and func == 'mean': + pytest.skip('numpy does not support this') + + if dask and not has_dask: + pytest.skip('requires dask') + + if dask and skipna is False and dtype in [np.bool_]: + pytest.skip('dask does not compute object-typed array') + + rtol = 1e-04 if dtype == np.float32 else 1e-05 + + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + axis = None if aggdim is None else da.get_axis_num(aggdim) + + # TODO: remove these after resolving + # https://github.com/dask/dask/issues/3245 + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + warnings.filterwarnings('ignore', 'All-NaN slice') + warnings.filterwarnings('ignore', 'invalid value encountered in') + + if has_np113 and da.dtype.kind == 'O' and skipna: + # Numpy < 1.13 does not handle object-type array. + try: + if skipna: + expected = getattr(np, 'nan{}'.format(func))(da.values, + axis=axis) + else: + expected = getattr(np, func)(da.values, axis=axis) + + actual = getattr(da, func)(skipna=skipna, dim=aggdim) + assert_dask_array(actual, dask) + assert np.allclose(actual.values, np.array(expected), + rtol=1.0e-4, equal_nan=True) + except (TypeError, AttributeError, ZeroDivisionError): + # TODO currently, numpy does not support some methods such as + # nanmean for object dtype + pass + + actual = getattr(da, func)(skipna=skipna, dim=aggdim) + + # for dask case, make sure the result is the same for numpy backend + expected = getattr(da.compute(), func)(skipna=skipna, dim=aggdim) + assert_allclose(actual, expected, rtol=rtol) + + # make sure the compatiblility with pandas' results. 
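
The reduction cases above compare xarray aggregations against numpy's nan-functions and pandas. A minimal sketch of the user-facing behaviour they exercise, assuming only numpy and xarray (`min_count` is the keyword documented further down in this file's `test_docs`):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, 3.0], dims='x')

# skipna defaults to True for float dtypes, mirroring numpy's nan-aggregations
assert float(da.sum()) == np.nansum(da.values)

# min_count turns the result into NaN when fewer valid values are present
assert np.isnan(float(da.sum(min_count=3)))
```
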
+ if func in ['var', 'std']: + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, + ddof=0) + assert_allclose(actual, expected, rtol=rtol) + # also check ddof!=0 case + actual = getattr(da, func)(skipna=skipna, dim=aggdim, ddof=5) + if dask: + assert isinstance(da.data, dask_array_type) + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, + ddof=5) + assert_allclose(actual, expected, rtol=rtol) + else: + expected = series_reduce(da, func, skipna=skipna, dim=aggdim) + assert_allclose(actual, expected, rtol=rtol) + + # make sure the dtype argument + if func not in ['max', 'min']: + actual = getattr(da, func)(skipna=skipna, dim=aggdim, dtype=float) + assert_dask_array(actual, dask) + assert actual.dtype == float + + # without nan + da = construct_dataarray(dim_num, dtype, contains_nan=False, dask=dask) + actual = getattr(da, func)(skipna=skipna) + if dask: + assert isinstance(da.data, dask_array_type) + expected = getattr(np, 'nan{}'.format(func))(da.values) + if actual.dtype == object: + assert actual.values == np.array(expected) + else: + assert np.allclose(actual.values, np.array(expected), rtol=rtol) + + +@pytest.mark.parametrize('dim_num', [1, 2]) +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_, str]) +@pytest.mark.parametrize('contains_nan', [True, False]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['min', 'max']) +@pytest.mark.parametrize('skipna', [False, True]) +@pytest.mark.parametrize('aggdim', ['x', 'y']) +def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): + # pandas-dev/pandas#16830, we do not check consistency with pandas but + # just make sure da[da.argmin()] == da.min() + + if aggdim == 'y' and dim_num < 2: + pytest.skip('dim not in this test') + + if dask and not has_dask: + pytest.skip('requires dask') + + if contains_nan: + if not skipna: + pytest.skip("numpy's argmin (not nanargmin) does not handle " + "object-dtype") + if skipna and np.dtype(dtype).kind in 'iufc': + pytest.skip("numpy's nanargmin raises ValueError for all nan axis") + da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, + dask=dask) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'All-NaN slice') + + actual = da.isel(**{aggdim: getattr(da, 'arg' + func) + (dim=aggdim, skipna=skipna).compute()}) + expected = getattr(da, func)(dim=aggdim, skipna=skipna) + assert_allclose(actual.drop(actual.coords), + expected.drop(expected.coords)) + + +def test_argmin_max_error(): + da = construct_dataarray(2, np.bool_, contains_nan=True, dask=False) + da[0] = np.nan + with pytest.raises(ValueError): + da.argmin(dim='y') + + +@requires_dask +def test_isnull_with_dask(): + da = construct_dataarray(2, np.float32, contains_nan=True, dask=True) + assert isinstance(da.isnull().data, dask_array_type) + assert_equal(da.isnull().load(), da.load().isnull()) + + +@pytest.mark.skipif(not has_dask, reason='This is for dask.') +@pytest.mark.parametrize('axis', [0, -1]) +@pytest.mark.parametrize('window', [3, 8, 11]) +@pytest.mark.parametrize('center', [True, False]) +def test_dask_rolling(axis, window, center): + import dask.array as da + + x = np.array(np.random.randn(100, 40), dtype=float) + dx = da.from_array(x, chunks=[(6, 30, 30, 20, 14), 8]) + + expected = rolling_window(x, axis=axis, window=window, center=center, + fill_value=np.nan) + actual = rolling_window(dx, axis=axis, window=window, center=center, + fill_value=np.nan) + assert isinstance(actual, da.Array) + 
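
`rolling_window` exercised here appears to be the low-level counterpart of the user-facing `DataArray.rolling(...).construct(...)`. A sketch of the equivalent high-level call, assuming a small in-memory array (names below are illustrative only):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10.0), dims='x')

# expose the moving window as an extra dimension instead of reducing at once
windowed = da.rolling(x=3, center=True).construct('window')
assert windowed.dims == ('x', 'window')

# reducing over that dimension reproduces the usual rolling mean
assert np.allclose(windowed.mean('window'),
                   da.rolling(x=3, center=True, min_periods=1).mean())
```
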
assert_array_equal(actual, expected) + assert actual.shape == expected.shape + + # we need to take care of window size if chunk size is small + # window/2 should be smaller than the smallest chunk size. + with pytest.raises(ValueError): + rolling_window(dx, axis=axis, window=100, center=center, + fill_value=np.nan) + + +@pytest.mark.skipif(not has_dask, reason='This is for dask.') +@pytest.mark.parametrize('axis', [0, -1, 1]) +@pytest.mark.parametrize('edge_order', [1, 2]) +def test_dask_gradient(axis, edge_order): + import dask.array as da + + array = np.array(np.random.randn(100, 5, 40)) + x = np.exp(np.linspace(0, 1, array.shape[axis])) + + darray = da.from_array(array, chunks=[(6, 30, 30, 20, 14), 5, 8]) + expected = gradient(array, x, axis=axis, edge_order=edge_order) + actual = gradient(darray, x, axis=axis, edge_order=edge_order) + + assert isinstance(actual, da.Array) + assert_array_equal(actual, expected) + + +@pytest.mark.parametrize('dim_num', [1, 2]) +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +@pytest.mark.parametrize('aggdim', [None, 'x']) +def test_min_count(dim_num, dtype, dask, func, aggdim): + if dask and not has_dask: + pytest.skip('requires dask') + + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + min_count = 3 + + actual = getattr(da, func)(dim=aggdim, skipna=True, min_count=min_count) + + if LooseVersion(pd.__version__) >= LooseVersion('0.22.0'): + # min_count is only implenented in pandas > 0.22 + expected = series_reduce(da, func, skipna=True, dim=aggdim, + min_count=min_count) + assert_allclose(actual, expected) + + assert_dask_array(actual, dask) + + +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_min_count_dataset(func): + da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False) + ds = Dataset({'var1': da}, coords={'scalar': 0}) + actual = getattr(ds, func)(dim='x', skipna=True, min_count=3)['var1'] + expected = getattr(ds['var1'], func)(dim='x', skipna=True, min_count=3) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_multiple_dims(dtype, dask, func): + if dask and not has_dask: + pytest.skip('requires dask') + da = construct_dataarray(3, dtype, contains_nan=True, dask=dask) + + actual = getattr(da, func)(('x', 'y')) + expected = getattr(getattr(da, func)('x'), func)('y') + assert_allclose(actual, expected) + + +def test_docs(): + # with min_count + actual = DataArray.sum.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `sum` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `sum`. + axis : int or sequence of int, optional + Axis(es) over which to apply `sum`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `sum` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default None + The required number of valid values to perform the operation. 
+ If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `sum` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `sum` applied to its data and the + indicated dimension(s) removed. + """) + assert actual == expected + + # without min_count + actual = DataArray.std.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `std` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `std`. + axis : int or sequence of int, optional + Axis(es) over which to apply `std`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `std` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `std` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `std` applied to its data and the + indicated dimension(s) removed. + """) + assert actual == expected diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 9456f335572..ffefa78aa34 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -1,16 +1,16 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import pytest + +import xarray as xr + +from . import raises_regex + try: import cPickle as pickle except ImportError: import pickle -import xarray as xr - -from . import TestCase, raises_regex -import pytest - @xr.register_dataset_accessor('example_accessor') @xr.register_dataarray_accessor('example_accessor') @@ -21,7 +21,7 @@ def __init__(self, xarray_obj): self.obj = xarray_obj -class TestAccessor(TestCase): +class TestAccessor(object): def test_register(self): @xr.register_dataset_accessor('demo') diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 53342825dcd..024c669bed9 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -1,41 +1,61 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import numpy as np import pandas as pd from xarray.core import formatting from xarray.core.pycompat import PY3 -from . import TestCase, raises_regex +from . 
import raises_regex -class TestFormatting(TestCase): +class TestFormatting(object): def test_get_indexer_at_least_n_items(self): cases = [ - ((20,), (slice(10),)), - ((3, 20,), (0, slice(10))), - ((2, 10,), (0, slice(10))), - ((2, 5,), (slice(2), slice(None))), - ((1, 2, 5,), (0, slice(2), slice(None))), - ((2, 3, 5,), (0, slice(2), slice(None))), - ((1, 10, 1,), (0, slice(10), slice(None))), - ((2, 5, 1,), (slice(2), slice(None), slice(None))), - ((2, 5, 3,), (0, slice(4), slice(None))), - ((2, 3, 3,), (slice(2), slice(None), slice(None))), + ((20,), (slice(10),), (slice(-10, None),)), + ((3, 20,), (0, slice(10)), (-1, slice(-10, None))), + ((2, 10,), (0, slice(10)), (-1, slice(-10, None))), + ((2, 5,), (slice(2), slice(None)), + (slice(-2, None), slice(None))), + ((1, 2, 5,), (0, slice(2), slice(None)), + (-1, slice(-2, None), slice(None))), + ((2, 3, 5,), (0, slice(2), slice(None)), + (-1, slice(-2, None), slice(None))), + ((1, 10, 1,), (0, slice(10), slice(None)), + (-1, slice(-10, None), slice(None))), + ((2, 5, 1,), (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None))), + ((2, 5, 3,), (0, slice(4), slice(None)), + (-1, slice(-4, None), slice(None))), + ((2, 3, 3,), (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None))), ] - for shape, expected in cases: - actual = formatting._get_indexer_at_least_n_items(shape, 10) - assert expected == actual + for shape, start_expected, end_expected in cases: + actual = formatting._get_indexer_at_least_n_items(shape, 10, + from_end=False) + assert start_expected == actual + actual = formatting._get_indexer_at_least_n_items(shape, 10, + from_end=True) + assert end_expected == actual def test_first_n_items(self): array = np.arange(100).reshape(10, 5, 2) for n in [3, 10, 13, 100, 200]: actual = formatting.first_n_items(array, n) expected = array.flat[:n] - self.assertItemsEqual(expected, actual) + assert (expected == actual).all() + + with raises_regex(ValueError, 'at least one item'): + formatting.first_n_items(array, 0) + + def test_last_n_items(self): + array = np.arange(100).reshape(10, 5, 2) + for n in [3, 10, 13, 100, 200]: + actual = formatting.last_n_items(array, n) + expected = array.flat[-n:] + assert (expected == actual).all() with raises_regex(ValueError, 'at least one item'): formatting.first_n_items(array, 0) @@ -88,16 +108,32 @@ def test_format_items(self): assert expected == actual def test_format_array_flat(self): + actual = formatting.format_array_flat(np.arange(100), 2) + expected = '0 ... 99' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 9) + expected = '0 ... 99' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 10) + expected = '0 1 ... 99' + assert expected == actual + actual = formatting.format_array_flat(np.arange(100), 13) - expected = '0 1 2 3 4 ...' + expected = '0 1 ... 98 99' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 15) + expected = '0 1 2 ... 98 99' assert expected == actual actual = formatting.format_array_flat(np.arange(100.0), 11) - expected = '0.0 1.0 ...' + expected = '0.0 ... 99.0' assert expected == actual actual = formatting.format_array_flat(np.arange(100.0), 1) - expected = '0.0 ...' + expected = '0.0 ... 
99.0' assert expected == actual actual = formatting.format_array_flat(np.arange(3), 5) @@ -105,11 +141,23 @@ def test_format_array_flat(self): assert expected == actual actual = formatting.format_array_flat(np.arange(4.0), 11) - expected = '0.0 1.0 ...' + expected = '0.0 ... 3.0' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(0), 0) + expected = '' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(1), 0) + expected = '0' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(2), 0) + expected = '0 1' assert expected == actual actual = formatting.format_array_flat(np.arange(4), 0) - expected = '0 ...' + expected = '0 ... 3' assert expected == actual def test_pretty_print(self): diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index f1d80954295..8ace55be66b 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1,12 +1,13 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import numpy as np import pandas as pd +import pytest + import xarray as xr from xarray.core.groupby import _consolidate_slices -import pytest +from . import assert_identical def test_consolidate_slices(): @@ -74,4 +75,14 @@ def test_groupby_duplicate_coordinate_labels(): assert expected.equals(actual) +def test_groupby_input_mutation(): + # regression test for GH2153 + array = xr.DataArray([1, 2, 3], [('x', [2, 2, 1])]) + array_copy = array.copy() + expected = xr.DataArray([3, 3], [('x', [1, 2])]) + actual = array.groupby('x').sum() + assert_identical(expected, actual) + assert_identical(array, array_copy) # should not modify inputs + + # TODO: move other groupby tests from test_dataset and test_dataarray over here diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 3d93afb26d4..701eefcb462 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -1,25 +1,21 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import itertools +from __future__ import absolute_import, division, print_function -import pytest +import itertools import numpy as np import pandas as pd +import pytest -from xarray import Dataset, DataArray, Variable -from xarray.core import indexing -from xarray.core import nputils +from xarray import DataArray, Dataset, Variable +from xarray.core import indexing, nputils from xarray.core.pycompat import native_int_types -from . import ( - TestCase, ReturnItem, raises_regex, IndexerMaker, assert_array_equal) +from . 
import IndexerMaker, ReturnItem, assert_array_equal, raises_regex B = IndexerMaker(indexing.BasicIndexer) -class TestIndexers(TestCase): +class TestIndexers(object): def set_to_zero(self, x, i): x = x.copy() x[i] = 0 @@ -28,7 +24,7 @@ def set_to_zero(self, x, i): def test_expanded_indexer(self): x = np.random.randn(10, 11, 12, 13, 14) y = np.arange(5) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[0, :, 10], I[..., 10], I[:5, ..., 0], I[..., 0, :], I[y], I[y, y], I[..., y, y], I[..., 0, 1, 2, 3, 4]]: @@ -114,47 +110,49 @@ def test_indexer(data, x, expected_pos, expected_idx=None): test_indexer(data, Variable([], 1), 0) test_indexer(mdata, ('a', 1, -1), 0) test_indexer(mdata, ('a', 1), - [True, True, False, False, False, False, False, False], + [True, True, False, False, False, False, False, False], [-1, -2]) test_indexer(mdata, 'a', slice(0, 4, None), pd.MultiIndex.from_product([[1, 2], [-1, -2]])) test_indexer(mdata, ('a',), - [True, True, True, True, False, False, False, False], + [True, True, True, True, False, False, False, False], pd.MultiIndex.from_product([[1, 2], [-1, -2]])) test_indexer(mdata, [('a', 1, -1), ('b', 2, -2)], [0, 7]) test_indexer(mdata, slice('a', 'b'), slice(0, 8, None)) test_indexer(mdata, slice(('a', 1), ('b', 1)), slice(0, 6, None)) test_indexer(mdata, {'one': 'a', 'two': 1, 'three': -1}, 0) test_indexer(mdata, {'one': 'a', 'two': 1}, - [True, True, False, False, False, False, False, False], + [True, True, False, False, False, False, False, False], [-1, -2]) test_indexer(mdata, {'one': 'a', 'three': -1}, - [True, False, True, False, False, False, False, False], + [True, False, True, False, False, False, False, False], [1, 2]) test_indexer(mdata, {'one': 'a'}, - [True, True, True, True, False, False, False, False], + [True, True, True, True, False, False, False, False], pd.MultiIndex.from_product([[1, 2], [-1, -2]])) -class TestLazyArray(TestCase): +class TestLazyArray(object): def test_slice_slice(self): I = ReturnItem() # noqa: E741 # allow ambiguous name - x = np.arange(100) - slices = [I[:3], I[:4], I[2:4], I[:1], I[:-1], I[5:-1], I[-5:-1], - I[::-1], I[5::-1], I[:3:-1], I[:30:-1], I[10:4:], I[::4], - I[4:4:4], I[:4:-4]] - for i in slices: - for j in slices: - expected = x[i][j] - new_slice = indexing.slice_slice(i, j, size=100) - actual = x[new_slice] - assert_array_equal(expected, actual) + for size in [100, 99]: + # We test even/odd size cases + x = np.arange(size) + slices = [I[:3], I[:4], I[2:4], I[:1], I[:-1], I[5:-1], I[-5:-1], + I[::-1], I[5::-1], I[:3:-1], I[:30:-1], I[10:4:], I[::4], + I[4:4:4], I[:4:-4], I[::-2]] + for i in slices: + for j in slices: + expected = x[i][j] + new_slice = indexing.slice_slice(i, j, size=size) + actual = x[new_slice] + assert_array_equal(expected, actual) def test_lazily_indexed_array(self): original = np.random.rand(10, 20, 30) x = indexing.NumpyIndexingAdapter(original) v = Variable(['i', 'j', 'k'], original) - lazy = indexing.LazilyIndexedArray(x) + lazy = indexing.LazilyOuterIndexedArray(x) v_lazy = Variable(['i', 'j', 'k'], lazy) I = ReturnItem() # noqa: E741 # allow ambiguous name # test orthogonally applied indexers @@ -173,7 +171,7 @@ def test_lazily_indexed_array(self): assert expected.shape == actual.shape assert_array_equal(expected, actual) assert isinstance(actual._data, - indexing.LazilyIndexedArray) + indexing.LazilyOuterIndexedArray) # make sure actual.key is appropriate type if all(isinstance(k, native_int_types + (slice, )) @@ -188,16 +186,68 
@@ def test_lazily_indexed_array(self): indexers = [(3, 2), (I[:], 0), (I[:2], -1), (I[:4], [0]), ([4, 5], 0), ([0, 1, 2], [0, 1]), ([0, 3, 5], I[:2])] for i, j in indexers: - expected = np.asarray(v[i][j]) + expected = v[i][j] actual = v_lazy[i][j] assert expected.shape == actual.shape assert_array_equal(expected, actual) - assert isinstance(actual._data, indexing.LazilyIndexedArray) + + # test transpose + if actual.ndim > 1: + order = np.random.choice(actual.ndim, actual.ndim) + order = np.array(actual.dims) + transposed = actual.transpose(*order) + assert_array_equal(expected.transpose(*order), transposed) + assert isinstance( + actual._data, (indexing.LazilyVectorizedIndexedArray, + indexing.LazilyOuterIndexedArray)) + + assert isinstance(actual._data, indexing.LazilyOuterIndexedArray) assert isinstance(actual._data.array, indexing.NumpyIndexingAdapter) + def test_vectorized_lazily_indexed_array(self): + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + v_eager = Variable(['i', 'j', 'k'], x) + lazy = indexing.LazilyOuterIndexedArray(x) + v_lazy = Variable(['i', 'j', 'k'], lazy) + I = ReturnItem() # noqa: E741 # allow ambiguous name + + def check_indexing(v_eager, v_lazy, indexers): + for indexer in indexers: + actual = v_lazy[indexer] + expected = v_eager[indexer] + assert expected.shape == actual.shape + assert isinstance(actual._data, + (indexing.LazilyVectorizedIndexedArray, + indexing.LazilyOuterIndexedArray)) + assert_array_equal(expected, actual) + v_eager = expected + v_lazy = actual + + # test orthogonal indexing + indexers = [(I[:], 0, 1), (Variable('i', [0, 1]), )] + check_indexing(v_eager, v_lazy, indexers) + + # vectorized indexing + indexers = [ + (Variable('i', [0, 1]), Variable('i', [0, 1]), slice(None)), + (slice(1, 3, 2), 0)] + check_indexing(v_eager, v_lazy, indexers) + + indexers = [ + (slice(None, None, 2), 0, slice(None, 10)), + (Variable('i', [3, 2, 4, 3]), Variable('i', [3, 2, 1, 0])), + (Variable(['i', 'j'], [[0, 1], [1, 2]]), )] + check_indexing(v_eager, v_lazy, indexers) + + indexers = [ + (Variable('i', [3, 2, 4, 3]), Variable('i', [3, 2, 1, 0])), + (Variable(['i', 'j'], [[0, 1], [1, 2]]), )] + check_indexing(v_eager, v_lazy, indexers) -class TestCopyOnWriteArray(TestCase): + +class TestCopyOnWriteArray(object): def test_setitem(self): original = np.arange(10) wrapped = indexing.CopyOnWriteArray(original) @@ -221,21 +271,21 @@ def test_index_scalar(self): assert np.array(x[B[0]][B[()]]) == 'foo' -class TestMemoryCachedArray(TestCase): +class TestMemoryCachedArray(object): def test_wrapper(self): - original = indexing.LazilyIndexedArray(np.arange(10)) + original = indexing.LazilyOuterIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) assert_array_equal(wrapped, np.arange(10)) assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter) def test_sub_array(self): - original = indexing.LazilyIndexedArray(np.arange(10)) + original = indexing.LazilyOuterIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) child = wrapped[B[:5]] assert isinstance(child, indexing.MemoryCachedArray) assert_array_equal(child, np.arange(5)) assert isinstance(child.array, indexing.NumpyIndexingAdapter) - assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + assert isinstance(wrapped.array, indexing.LazilyOuterIndexedArray) def test_setitem(self): original = np.arange(10) @@ -334,21 +384,127 @@ def test_vectorized_indexer(): np.arange(5, dtype=np.int64))) -def test_unwrap_explicit_indexer(): - 
indexer = indexing.BasicIndexer((1, 2)) - target = None - - unwrapped = indexing.unwrap_explicit_indexer( - indexer, target, allow=indexing.BasicIndexer) - assert unwrapped == (1, 2) - - with raises_regex(NotImplementedError, 'Load your data'): - indexing.unwrap_explicit_indexer( - indexer, target, allow=indexing.OuterIndexer) - - with raises_regex(TypeError, 'unexpected key type'): - indexing.unwrap_explicit_indexer( - indexer.tuple, target, allow=indexing.OuterIndexer) +class Test_vectorized_indexer(object): + @pytest.fixture(autouse=True) + def setup(self): + self.data = indexing.NumpyIndexingAdapter(np.random.randn(10, 12, 13)) + self.indexers = [np.array([[0, 3, 2], ]), + np.array([[0, 3, 3], [4, 6, 7]]), + slice(2, -2, 2), slice(2, -2, 3), slice(None)] + + def test_arrayize_vectorized_indexer(self): + for i, j, k in itertools.product(self.indexers, repeat=3): + vindex = indexing.VectorizedIndexer((i, j, k)) + vindex_array = indexing._arrayize_vectorized_indexer( + vindex, self.data.shape) + np.testing.assert_array_equal( + self.data[vindex], self.data[vindex_array],) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((slice(None),)), shape=(5,)) + np.testing.assert_array_equal(actual.tuple, [np.arange(5)]) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((np.arange(5),) * 3), shape=(8, 10, 12)) + expected = np.stack([np.arange(5)] * 3) + np.testing.assert_array_equal(np.stack(actual.tuple), expected) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((np.arange(5), slice(None))), + shape=(8, 10)) + a, b = actual.tuple + np.testing.assert_array_equal(a, np.arange(5)[:, np.newaxis]) + np.testing.assert_array_equal(b, np.arange(10)[np.newaxis, :]) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((slice(None), np.arange(5))), + shape=(8, 10)) + a, b = actual.tuple + np.testing.assert_array_equal(a, np.arange(8)[np.newaxis, :]) + np.testing.assert_array_equal(b, np.arange(5)[:, np.newaxis]) + + +def get_indexers(shape, mode): + if mode == 'vectorized': + indexed_shape = (3, 4) + indexer = tuple(np.random.randint(0, s, size=indexed_shape) + for s in shape) + return indexing.VectorizedIndexer(indexer) + + elif mode == 'outer': + indexer = tuple(np.random.randint(0, s, s + 2) for s in shape) + return indexing.OuterIndexer(indexer) + + elif mode == 'outer_scalar': + indexer = (np.random.randint(0, 3, 4), 0, slice(None, None, 2)) + return indexing.OuterIndexer(indexer[:len(shape)]) + + elif mode == 'outer_scalar2': + indexer = (np.random.randint(0, 3, 4), -2, slice(None, None, 2)) + return indexing.OuterIndexer(indexer[:len(shape)]) + + elif mode == 'outer1vec': + indexer = [slice(2, -3) for s in shape] + indexer[1] = np.random.randint(0, shape[1], shape[1] + 2) + return indexing.OuterIndexer(tuple(indexer)) + + elif mode == 'basic': # basic indexer + indexer = [slice(2, -3) for s in shape] + indexer[0] = 3 + return indexing.BasicIndexer(tuple(indexer)) + + elif mode == 'basic1': # basic indexer + return indexing.BasicIndexer((3, )) + + elif mode == 'basic2': # basic indexer + indexer = [0, 2, 4] + return indexing.BasicIndexer(tuple(indexer[:len(shape)])) + + elif mode == 'basic3': # basic indexer + indexer = [slice(None) for s in shape] + indexer[0] = slice(-2, 2, -2) + indexer[1] = slice(1, -1, 2) + return indexing.BasicIndexer(tuple(indexer[:len(shape)])) + + +@pytest.mark.parametrize('size', [100, 99]) +@pytest.mark.parametrize('sl', [slice(1, -1, 1), slice(None, -1, 
2), + slice(-1, 1, -1), slice(-1, 1, -2)]) +def test_decompose_slice(size, sl): + x = np.arange(size) + slice1, slice2 = indexing._decompose_slice(sl, size) + expected = x[sl] + actual = x[slice1][slice2] + assert_array_equal(expected, actual) + + +@pytest.mark.parametrize('shape', [(10, 5, 8), (10, 3)]) +@pytest.mark.parametrize('indexer_mode', + ['vectorized', 'outer', 'outer_scalar', + 'outer_scalar2', 'outer1vec', + 'basic', 'basic1', 'basic2', 'basic3']) +@pytest.mark.parametrize('indexing_support', + [indexing.IndexingSupport.BASIC, + indexing.IndexingSupport.OUTER, + indexing.IndexingSupport.OUTER_1VECTOR, + indexing.IndexingSupport.VECTORIZED]) +def test_decompose_indexers(shape, indexer_mode, indexing_support): + data = np.random.randn(*shape) + indexer = get_indexers(shape, indexer_mode) + + backend_ind, np_ind = indexing.decompose_indexer( + indexer, shape, indexing_support) + + expected = indexing.NumpyIndexingAdapter(data)[indexer] + array = indexing.NumpyIndexingAdapter(data)[backend_ind] + if len(np_ind.tuple) > 0: + array = indexing.NumpyIndexingAdapter(array)[np_ind] + np.testing.assert_array_equal(expected, array) + + if not all(isinstance(k, indexing.integer_types) for k in np_ind.tuple): + combined_ind = indexing._combine_indexers(backend_ind, shape, np_ind) + array = indexing.NumpyIndexingAdapter(data)[combined_ind] + np.testing.assert_array_equal(expected, array) def test_implicit_indexing_adapter(): @@ -385,7 +541,8 @@ def nonzero(x): expected_data = np.moveaxis(expected_data, old_order, new_order) - outer_index = (nonzero(i), nonzero(j), nonzero(k)) + outer_index = indexing.OuterIndexer((nonzero(i), nonzero(j), + nonzero(k))) actual = indexing._outer_to_numpy_indexer(outer_index, v.shape) actual_data = v.data[actual] np.testing.assert_array_equal(actual_data, expected_data) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py new file mode 100644 index 00000000000..624879cce1f --- /dev/null +++ b/xarray/tests/test_interp.py @@ -0,0 +1,575 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.tests import ( + assert_allclose, assert_equal, requires_cftime, requires_scipy) + +from . 
import has_dask, has_scipy +from ..coding.cftimeindex import _parse_array_of_cftime_strings +from .test_dataset import create_test_data + +try: + import scipy +except ImportError: + pass + + +def get_example_data(case): + x = np.linspace(0, 1, 100) + y = np.linspace(0, 0.1, 30) + data = xr.DataArray( + np.sin(x[:, np.newaxis]) * np.cos(y), dims=['x', 'y'], + coords={'x': x, 'y': y, 'x2': ('x', x**2)}) + + if case == 0: + return data + elif case == 1: + return data.chunk({'y': 3}) + elif case == 2: + return data.chunk({'x': 25, 'y': 3}) + elif case == 3: + x = np.linspace(0, 1, 100) + y = np.linspace(0, 0.1, 30) + z = np.linspace(0.1, 0.2, 10) + return xr.DataArray( + np.sin(x[:, np.newaxis, np.newaxis]) * np.cos( + y[:, np.newaxis]) * z, + dims=['x', 'y', 'z'], + coords={'x': x, 'y': y, 'x2': ('x', x**2), 'z': z}) + elif case == 4: + return get_example_data(3).chunk({'z': 5}) + + +def test_keywargs(): + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = get_example_data(0) + assert_equal(da.interp(x=[0.5, 0.8]), da.interp({'x': [0.5, 0.8]})) + + +@pytest.mark.parametrize('method', ['linear', 'cubic']) +@pytest.mark.parametrize('dim', ['x', 'y']) +@pytest.mark.parametrize('case', [0, 1]) +def test_interpolate_1d(method, dim, case): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and case in [1]: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + xdest = np.linspace(0.0, 0.9, 80) + + if dim == 'y' and case == 1: + with pytest.raises(NotImplementedError): + actual = da.interp(method=method, **{dim: xdest}) + pytest.skip('interpolation along chunked dimension is ' + 'not yet supported') + + actual = da.interp(method=method, **{dim: xdest}) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da[dim], obj.data, axis=obj.get_axis_num(dim), bounds_error=False, + fill_value=np.nan, kind=method)(new_x) + + if dim == 'x': + coords = {'x': xdest, 'y': da['y'], 'x2': ('x', func(da['x2'], xdest))} + else: # y + coords = {'x': da['x'], 'y': xdest, 'x2': da['x2']} + + expected = xr.DataArray(func(da, xdest), dims=['x', 'y'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('method', ['cubic', 'zero']) +def test_interpolate_1d_methods(method): + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = get_example_data(0) + dim = 'x' + xdest = np.linspace(0.0, 0.9, 80) + + actual = da.interp(method=method, **{dim: xdest}) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da[dim], obj.data, axis=obj.get_axis_num(dim), bounds_error=False, + fill_value=np.nan, kind=method)(new_x) + + coords = {'x': xdest, 'y': da['y'], 'x2': ('x', func(da['x2'], xdest))} + expected = xr.DataArray(func(da, xdest), dims=['x', 'y'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('use_dask', [False, True]) +def test_interpolate_vectorize(use_dask): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and use_dask: + pytest.skip('dask is not installed in the environment.') + + # scipy interpolation for the reference + def func(obj, dim, new_x): + shape = [s for i, s in enumerate(obj.shape) + if i != obj.get_axis_num(dim)] + for s in new_x.shape[::-1]: + shape.insert(obj.get_axis_num(dim), s) + + return scipy.interpolate.interp1d( + da[dim], obj.data, axis=obj.get_axis_num(dim), + bounds_error=False, 
fill_value=np.nan)(new_x).reshape(shape) + + da = get_example_data(0) + if use_dask: + da = da.chunk({'y': 5}) + + # xdest is 1d but has different dimension + xdest = xr.DataArray(np.linspace(0.1, 0.9, 30), dims='z', + coords={'z': np.random.randn(30), + 'z2': ('z', np.random.randn(30))}) + + actual = da.interp(x=xdest, method='linear') + + expected = xr.DataArray(func(da, 'x', xdest), dims=['z', 'y'], + coords={'z': xdest['z'], 'z2': xdest['z2'], + 'y': da['y'], + 'x': ('z', xdest.values), + 'x2': ('z', func(da['x2'], 'x', xdest))}) + assert_allclose(actual, expected.transpose('z', 'y')) + + # xdest is 2d + xdest = xr.DataArray(np.linspace(0.1, 0.9, 30).reshape(6, 5), + dims=['z', 'w'], + coords={'z': np.random.randn(6), + 'w': np.random.randn(5), + 'z2': ('z', np.random.randn(6))}) + + actual = da.interp(x=xdest, method='linear') + + expected = xr.DataArray( + func(da, 'x', xdest), + dims=['z', 'w', 'y'], + coords={'z': xdest['z'], 'w': xdest['w'], 'z2': xdest['z2'], + 'y': da['y'], 'x': (('z', 'w'), xdest), + 'x2': (('z', 'w'), func(da['x2'], 'x', xdest))}) + assert_allclose(actual, expected.transpose('z', 'w', 'y')) + + +@pytest.mark.parametrize('case', [3, 4]) +def test_interpolate_nd(case): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and case == 4: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + + # grid -> grid + xdest = np.linspace(0.1, 1.0, 11) + ydest = np.linspace(0.0, 0.2, 10) + actual = da.interp(x=xdest, y=ydest, method='linear') + + # linear interpolation is separateable + expected = da.interp(x=xdest, method='linear') + expected = expected.interp(y=ydest, method='linear') + assert_allclose(actual.transpose('x', 'y', 'z'), + expected.transpose('x', 'y', 'z')) + + # grid -> 1d-sample + xdest = xr.DataArray(np.linspace(0.1, 1.0, 11), dims='y') + ydest = xr.DataArray(np.linspace(0.0, 0.2, 11), dims='y') + actual = da.interp(x=xdest, y=ydest, method='linear') + + # linear interpolation is separateable + expected_data = scipy.interpolate.RegularGridInterpolator( + (da['x'], da['y']), da.transpose('x', 'y', 'z').values, + method='linear', bounds_error=False, + fill_value=np.nan)(np.stack([xdest, ydest], axis=-1)) + expected = xr.DataArray( + expected_data, dims=['y', 'z'], + coords={'z': da['z'], 'y': ydest, 'x': ('y', xdest.values), + 'x2': da['x2'].interp(x=xdest)}) + assert_allclose(actual.transpose('y', 'z'), expected) + + # reversed order + actual = da.interp(y=ydest, x=xdest, method='linear') + assert_allclose(actual.transpose('y', 'z'), expected) + + +@pytest.mark.parametrize('method', ['linear']) +@pytest.mark.parametrize('case', [0, 1]) +def test_interpolate_scalar(method, case): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and case in [1]: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + xdest = 0.4 + + actual = da.interp(x=xdest, method=method) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da['x'], obj.data, axis=obj.get_axis_num('x'), bounds_error=False, + fill_value=np.nan)(new_x) + + coords = {'x': xdest, 'y': da['y'], 'x2': func(da['x2'], xdest)} + expected = xr.DataArray(func(da, xdest), dims=['y'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('method', ['linear']) +@pytest.mark.parametrize('case', [3, 4]) +def test_interpolate_nd_scalar(method, case): + if not has_scipy: + pytest.skip('scipy is not 
installed.') + + if not has_dask and case in [4]: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + xdest = 0.4 + ydest = 0.05 + + actual = da.interp(x=xdest, y=ydest, method=method) + # scipy interpolation for the reference + expected_data = scipy.interpolate.RegularGridInterpolator( + (da['x'], da['y']), da.transpose('x', 'y', 'z').values, + method='linear', bounds_error=False, + fill_value=np.nan)(np.stack([xdest, ydest], axis=-1)) + + coords = {'x': xdest, 'y': ydest, 'x2': da['x2'].interp(x=xdest), + 'z': da['z']} + expected = xr.DataArray(expected_data[0], dims=['z'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('use_dask', [True, False]) +def test_nans(use_dask): + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = xr.DataArray([0, 1, np.nan, 2], dims='x', coords={'x': range(4)}) + + if not has_dask and use_dask: + pytest.skip('dask is not installed in the environment.') + da = da.chunk() + + actual = da.interp(x=[0.5, 1.5]) + # not all values are nan + assert actual.count() > 0 + + +@pytest.mark.parametrize('use_dask', [True, False]) +def test_errors(use_dask): + if not has_scipy: + pytest.skip('scipy is not installed.') + + # akima and spline are unavailable + da = xr.DataArray([0, 1, np.nan, 2], dims='x', coords={'x': range(4)}) + if not has_dask and use_dask: + pytest.skip('dask is not installed in the environment.') + da = da.chunk() + + for method in ['akima', 'spline']: + with pytest.raises(ValueError): + da.interp(x=[0.5, 1.5], method=method) + + # not sorted + if use_dask: + da = get_example_data(3) + else: + da = get_example_data(1) + + result = da.interp(x=[-1, 1, 3], kwargs={'fill_value': 0.0}) + assert not np.isnan(result.values).any() + result = da.interp(x=[-1, 1, 3]) + assert np.isnan(result.values).any() + + # invalid method + with pytest.raises(ValueError): + da.interp(x=[2, 0], method='boo') + with pytest.raises(ValueError): + da.interp(x=[2, 0], y=2, method='cubic') + with pytest.raises(ValueError): + da.interp(y=[2, 0], method='boo') + + # object-type DataArray cannot be interpolated + da = xr.DataArray(['a', 'b', 'c'], dims='x', coords={'x': [0, 1, 2]}) + with pytest.raises(TypeError): + da.interp(x=0) + + +@requires_scipy +def test_dtype(): + ds = xr.Dataset({'var1': ('x', [0, 1, 2]), 'var2': ('x', ['a', 'b', 'c'])}, + coords={'x': [0.1, 0.2, 0.3], 'z': ('x', ['a', 'b', 'c'])}) + actual = ds.interp(x=[0.15, 0.25]) + assert 'var1' in actual + assert 'var2' not in actual + # object array should be dropped + assert 'z' not in actual.coords + + +@requires_scipy +def test_sorted(): + # unsorted non-uniform gridded data + x = np.random.randn(100) + y = np.random.randn(30) + z = np.linspace(0.1, 0.2, 10) * 3.0 + da = xr.DataArray( + np.cos(x[:, np.newaxis, np.newaxis]) * np.cos( + y[:, np.newaxis]) * z, + dims=['x', 'y', 'z'], + coords={'x': x, 'y': y, 'x2': ('x', x**2), 'z': z}) + + x_new = np.linspace(0, 1, 30) + y_new = np.linspace(0, 1, 20) + + da_sorted = da.sortby('x') + assert_allclose(da.interp(x=x_new), + da_sorted.interp(x=x_new, assume_sorted=True)) + da_sorted = da.sortby(['x', 'y']) + assert_allclose(da.interp(x=x_new, y=y_new), + da_sorted.interp(x=x_new, y=y_new, assume_sorted=True)) + + with pytest.raises(ValueError): + da.interp(x=[0, 1, 2], assume_sorted=True) + + +@requires_scipy +def test_dimension_wo_coords(): + da = xr.DataArray(np.arange(12).reshape(3, 4), dims=['x', 'y'], + coords={'y': [0, 1, 2, 3]}) + da_w_coord = da.copy() + da_w_coord['x'] = 
np.arange(3) + + assert_equal(da.interp(x=[0.1, 0.2, 0.3]), + da_w_coord.interp(x=[0.1, 0.2, 0.3])) + assert_equal(da.interp(x=[0.1, 0.2, 0.3], y=[0.5]), + da_w_coord.interp(x=[0.1, 0.2, 0.3], y=[0.5])) + + +@requires_scipy +def test_dataset(): + ds = create_test_data() + ds.attrs['foo'] = 'var' + ds['var1'].attrs['buz'] = 'var2' + new_dim2 = xr.DataArray([0.11, 0.21, 0.31], dims='z') + interpolated = ds.interp(dim2=new_dim2) + + assert_allclose(interpolated['var1'], ds['var1'].interp(dim2=new_dim2)) + assert interpolated['var3'].equals(ds['var3']) + + # make sure modifying interpolated does not affect the original dataset + interpolated['var1'][:, 1] = 1.0 + interpolated['var2'][:, 1] = 1.0 + interpolated['var3'][:, 1] = 1.0 + + assert not interpolated['var1'].equals(ds['var1']) + assert not interpolated['var2'].equals(ds['var2']) + assert not interpolated['var3'].equals(ds['var3']) + # attrs should be kept + assert interpolated.attrs['foo'] == 'var' + assert interpolated['var1'].attrs['buz'] == 'var2' + + +@pytest.mark.parametrize('case', [0, 3]) +def test_interpolate_dimorder(case): + """ Make sure the resultant dimension order is consistent with .sel() """ + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = get_example_data(case) + + new_x = xr.DataArray([0, 1, 2], dims='x') + assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims + + new_y = xr.DataArray([0, 1, 2], dims='y') + actual = da.interp(x=new_x, y=new_y).dims + expected = da.sel(x=new_x, y=new_y, method='nearest').dims + assert actual == expected + # reversed order + actual = da.interp(y=new_y, x=new_x).dims + expected = da.sel(y=new_y, x=new_x, method='nearest').dims + assert actual == expected + + new_x = xr.DataArray([0, 1, 2], dims='a') + assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method='nearest').dims + new_y = xr.DataArray([0, 1, 2], dims='a') + actual = da.interp(x=new_x, y=new_y).dims + expected = da.sel(x=new_x, y=new_y, method='nearest').dims + assert actual == expected + + new_x = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) + assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method='nearest').dims + + if case == 3: + new_x = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) + new_z = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) + actual = da.interp(x=new_x, z=new_z).dims + expected = da.sel(x=new_x, z=new_z, method='nearest').dims + assert actual == expected + + actual = da.interp(z=new_z, x=new_x).dims + expected = da.sel(z=new_z, x=new_x, method='nearest').dims + assert actual == expected + + actual = da.interp(x=0.5, z=new_z).dims + expected = da.sel(x=0.5, z=new_z, method='nearest').dims + assert actual == expected + + +@requires_scipy +def test_interp_like(): + ds = create_test_data() + ds.attrs['foo'] = 'var' + ds['var1'].attrs['buz'] = 'var2' + + other = xr.DataArray(np.random.randn(3), dims=['dim2'], + coords={'dim2': [0, 1, 2]}) + interpolated = ds.interp_like(other) + + assert_allclose(interpolated['var1'], + ds['var1'].interp(dim2=other['dim2'])) + assert_allclose(interpolated['var1'], + ds['var1'].interp_like(other)) + assert interpolated['var3'].equals(ds['var3']) + + # attrs should be kept + assert interpolated.attrs['foo'] == 'var' + assert interpolated['var1'].attrs['buz'] == 'var2' + + other = xr.DataArray(np.random.randn(3), dims=['dim3'], + coords={'dim3': ['a', 'b', 'c']}) + + actual = 
ds.interp_like(other) + expected = ds.reindex_like(other) + assert_allclose(actual, expected) + + +@requires_scipy +@pytest.mark.parametrize('x_new, expected', [ + (pd.date_range('2000-01-02', periods=3), [1, 2, 3]), + (np.array([np.datetime64('2000-01-01T12:00'), + np.datetime64('2000-01-02T12:00')]), [0.5, 1.5]), + (['2000-01-01T12:00', '2000-01-02T12:00'], [0.5, 1.5]), + (['2000-01-01T12:00'], 0.5), + pytest.param('2000-01-01T12:00', 0.5, marks=pytest.mark.xfail) +]) +def test_datetime(x_new, expected): + da = xr.DataArray(np.arange(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + + actual = da.interp(time=x_new) + expected_da = xr.DataArray(np.atleast_1d(expected), dims=['time'], + coords={'time': (np.atleast_1d(x_new) + .astype('datetime64[ns]'))}) + + assert_allclose(actual, expected_da) + + +@requires_scipy +def test_datetime_single_string(): + da = xr.DataArray(np.arange(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + actual = da.interp(time='2000-01-01T12:00') + expected = xr.DataArray(0.5) + + assert_allclose(actual.drop('time'), expected) + + +@requires_cftime +@requires_scipy +def test_cftime(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D') + actual = da.interp(time=times_new) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=['time']) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_type_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D', + calendar='noleap') + with pytest.raises(TypeError): + da.interp(time=times_new) + + +@requires_cftime +@requires_scipy +def test_cftime_list_of_strings(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = ['2000-01-01T12:00', '2000-01-02T12:00', '2000-01-03T12:00'] + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array], + dims=['time']) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_single_string(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = '2000-01-01T12:00' + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray(0.5, coords={'time': times_new_array}) + + assert_allclose(actual, expected) + + +@requires_scipy +def test_datetime_to_non_datetime_error(): + da = xr.DataArray(np.arange(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + with pytest.raises(TypeError): + da.interp(time=0.5) + + +@requires_cftime +@requires_scipy +def test_cftime_to_non_cftime_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + with pytest.raises(TypeError): + da.interp(time=0.5) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 
409ad86c1e9..300c490cff6 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -1,18 +1,16 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import numpy as np -import xarray as xr +from __future__ import absolute_import, division, print_function +import numpy as np import pytest -from . import TestCase, raises_regex -from .test_dataset import create_test_data - +import xarray as xr from xarray.core import merge +from . import raises_regex +from .test_dataset import create_test_data + -class TestMergeInternals(TestCase): +class TestMergeInternals(object): def test_broadcast_dimension_size(self): actual = merge.broadcast_dimension_size( [xr.Variable('x', [1]), xr.Variable('y', [2, 1])]) @@ -27,7 +25,7 @@ def test_broadcast_dimension_size(self): [xr.Variable(('x', 'y'), [[1, 2]]), xr.Variable('y', [2])]) -class TestMergeFunction(TestCase): +class TestMergeFunction(object): def test_merge_arrays(self): data = create_test_data() actual = xr.merge([data.var1, data.var2]) @@ -132,7 +130,7 @@ def test_merge_no_conflicts_broadcast(self): assert expected.identical(actual) -class TestMergeMethod(TestCase): +class TestMergeMethod(object): def test_merge(self): data = create_test_data() @@ -197,7 +195,7 @@ def test_merge_compat(self): with pytest.raises(xr.MergeError): ds1.merge(ds2, compat='identical') - with raises_regex(ValueError, 'compat=\S+ invalid'): + with raises_regex(ValueError, 'compat=.* invalid'): ds1.merge(ds2, compat='foobar') def test_merge_auto_align(self): diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index ce735d720d0..47224e55473 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -1,20 +1,18 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import itertools + import numpy as np import pandas as pd import pytest -import itertools import xarray as xr - -from xarray.core.missing import (NumpyInterpolator, ScipyInterpolator, - SplineInterpolator) +from xarray.core.missing import ( + NumpyInterpolator, ScipyInterpolator, SplineInterpolator) from xarray.core.pycompat import dask_array_type - -from xarray.tests import (assert_equal, assert_array_equal, raises_regex, - requires_scipy, requires_bottleneck, requires_dask, - requires_np112) +from xarray.tests import ( + assert_array_equal, assert_equal, raises_regex, requires_bottleneck, + requires_dask, requires_scipy) @pytest.fixture @@ -69,7 +67,6 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, return da, df -@requires_np112 @requires_scipy def test_interpolate_pd_compat(): shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] @@ -95,19 +92,17 @@ def test_interpolate_pd_compat(): np.testing.assert_allclose(actual.values, expected.values) -@requires_np112 @requires_scipy -def test_scipy_methods_function(): - for method in ['barycentric', 'krog', 'pchip', 'spline', 'akima']: - kwargs = {} - # Note: Pandas does some wacky things with these methods and the full - # integration tests wont work. 
- da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) - actual = da.interpolate_na(method=method, dim='time', **kwargs) - assert (da.count('time') <= actual.count('time')).all() +@pytest.mark.parametrize('method', ['barycentric', 'krog', + 'pchip', 'spline', 'akima']) +def test_scipy_methods_function(method): + # Note: Pandas does some wacky things with these methods and the full + # integration tests wont work. + da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) + actual = da.interpolate_na(method=method, dim='time') + assert (da.count('time') <= actual.count('time')).all() -@requires_np112 @requires_scipy def test_interpolate_pd_compat_non_uniform_index(): shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] @@ -136,7 +131,6 @@ def test_interpolate_pd_compat_non_uniform_index(): np.testing.assert_allclose(actual.values, expected.values) -@requires_np112 @requires_scipy def test_interpolate_pd_compat_polynomial(): shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] @@ -156,7 +150,6 @@ def test_interpolate_pd_compat_polynomial(): np.testing.assert_allclose(actual.values, expected.values) -@requires_np112 @requires_scipy def test_interpolate_unsorted_index_raises(): vals = np.array([1, 2, 3], dtype=np.float64) @@ -197,7 +190,6 @@ def test_interpolate_2d_coord_raises(): da.interpolate_na(dim='a', use_coordinate='x') -@requires_np112 @requires_scipy def test_interpolate_kwargs(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims='x') @@ -210,7 +202,6 @@ def test_interpolate_kwargs(): assert_equal(actual, expected) -@requires_np112 def test_interpolate(): vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) @@ -224,7 +215,6 @@ def test_interpolate(): assert_equal(actual, expected) -@requires_np112 def test_interpolate_nonans(): vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) @@ -233,7 +223,6 @@ def test_interpolate_nonans(): assert_equal(actual, expected) -@requires_np112 @requires_scipy def test_interpolate_allnans(): vals = np.full(6, np.nan, dtype=np.float64) @@ -243,7 +232,6 @@ def test_interpolate_allnans(): assert_equal(actual, expected) -@requires_np112 @requires_bottleneck def test_interpolate_limits(): da = xr.DataArray(np.array([1, 2, np.nan, np.nan, np.nan, 6], @@ -259,7 +247,6 @@ def test_interpolate_limits(): assert_equal(actual, expected) -@requires_np112 @requires_scipy def test_interpolate_methods(): for method in ['linear', 'nearest', 'zero', 'slinear', 'quadratic', @@ -275,7 +262,6 @@ def test_interpolate_methods(): @requires_scipy -@requires_np112 def test_interpolators(): for method, interpolator in [('linear', NumpyInterpolator), ('linear', ScipyInterpolator), @@ -289,7 +275,6 @@ def test_interpolators(): assert pd.isnull(out).sum() == 0 -@requires_np112 def test_interpolate_use_coordinate(): xc = xr.Variable('x', [100, 200, 300, 400, 500, 600]) da = xr.DataArray(np.array([1, 2, np.nan, np.nan, np.nan, 6], @@ -312,7 +297,6 @@ def test_interpolate_use_coordinate(): assert_equal(actual, expected) -@requires_np112 @requires_dask def test_interpolate_dask(): da, _ = make_interpolate_example_data((40, 40), 0.5) @@ -330,7 +314,6 @@ def test_interpolate_dask(): assert_equal(actual, expected) -@requires_np112 @requires_dask def test_interpolate_dask_raises_for_invalid_chunk_dim(): da, _ = make_interpolate_example_data((40, 40), 0.5) @@ -339,7 +322,6 @@ def test_interpolate_dask_raises_for_invalid_chunk_dim(): da.interpolate_na('time') -@requires_np112 @requires_bottleneck def test_ffill(): da = xr.DataArray(np.array([4, 5, 
np.nan], dtype=np.float64), dims='x') @@ -348,7 +330,6 @@ def test_ffill(): assert_equal(actual, expected) -@requires_np112 @requires_bottleneck @requires_dask def test_ffill_dask(): @@ -386,7 +367,6 @@ def test_bfill_dask(): @requires_bottleneck -@requires_np112 def test_ffill_bfill_nonans(): vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) @@ -400,7 +380,6 @@ def test_ffill_bfill_nonans(): @requires_bottleneck -@requires_np112 def test_ffill_bfill_allnans(): vals = np.full(6, np.nan, dtype=np.float64) @@ -414,14 +393,12 @@ def test_ffill_bfill_allnans(): @requires_bottleneck -@requires_np112 def test_ffill_functions(da): result = da.ffill('time') assert result.isnull().sum() == 0 @requires_bottleneck -@requires_np112 def test_ffill_limit(): da = xr.DataArray( [0, np.nan, np.nan, np.nan, np.nan, 3, 4, 5, np.nan, 6, 7], @@ -435,7 +412,6 @@ def test_ffill_limit(): [0, 0, np.nan, np.nan, np.nan, 3, 4, 5, 5, 6, 7], dims='time') -@requires_np112 def test_interpolate_dataset(ds): actual = ds.interpolate_na(dim='time') # no missing values in var1 @@ -446,12 +422,10 @@ def test_interpolate_dataset(ds): @requires_bottleneck -@requires_np112 def test_ffill_dataset(ds): ds.ffill(dim='time') @requires_bottleneck -@requires_np112 def test_bfill_dataset(ds): ds.ffill(dim='time') diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index 83445e4639f..d3ad87d0d28 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -1,7 +1,8 @@ import numpy as np from numpy.testing import assert_array_equal -from xarray.core.nputils import _is_contiguous, NumpyVIndexAdapter +from xarray.core.nputils import ( + NumpyVIndexAdapter, _is_contiguous, rolling_window) def test_is_contiguous(): @@ -28,3 +29,27 @@ def test_vindex(): vindex[[0, 1], [0, 1], :] = vindex[[0, 1], [0, 1], :] vindex[[0, 1], :, [0, 1]] = vindex[[0, 1], :, [0, 1]] vindex[:, [0, 1], [0, 1]] = vindex[:, [0, 1], [0, 1]] + + +def test_rolling(): + x = np.array([1, 2, 3, 4], dtype=float) + + actual = rolling_window(x, axis=-1, window=3, center=True, + fill_value=np.nan) + expected = np.array([[np.nan, 1, 2], + [1, 2, 3], + [2, 3, 4], + [3, 4, np.nan]], dtype=float) + assert_array_equal(actual, expected) + + actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0) + expected = np.array([[0, 0, 1], + [0, 1, 2], + [1, 2, 3], + [2, 3, 4]], dtype=float) + assert_array_equal(actual, expected) + + x = np.stack([x, x * 1.1]) + actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0) + expected = np.stack([expected, expected * 1.1], axis=0) + assert_array_equal(actual, expected) diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index 498f0354086..d594e1dcd18 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -1,10 +1,12 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import xarray +from __future__ import absolute_import, division, print_function + import pytest -from xarray.core.options import OPTIONS +import xarray +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.backends.file_manager import FILE_CACHE +from xarray.tests.test_dataset import create_test_data +from xarray import concat, merge def test_invalid_option_raises(): @@ -12,6 +14,51 @@ def test_invalid_option_raises(): xarray.set_options(not_a_valid_options=True) +def test_display_width(): + with pytest.raises(ValueError): + xarray.set_options(display_width=0) + with 
pytest.raises(ValueError): + xarray.set_options(display_width=-10) + with pytest.raises(ValueError): + xarray.set_options(display_width=3.5) + + +def test_arithmetic_join(): + with pytest.raises(ValueError): + xarray.set_options(arithmetic_join='invalid') + with xarray.set_options(arithmetic_join='exact'): + assert OPTIONS['arithmetic_join'] == 'exact' + + +def test_enable_cftimeindex(): + with pytest.raises(ValueError): + xarray.set_options(enable_cftimeindex=None) + with pytest.warns(FutureWarning, match='no-op'): + with xarray.set_options(enable_cftimeindex=True): + assert OPTIONS['enable_cftimeindex'] + + +def test_file_cache_maxsize(): + with pytest.raises(ValueError): + xarray.set_options(file_cache_maxsize=0) + original_size = FILE_CACHE.maxsize + with xarray.set_options(file_cache_maxsize=123): + assert FILE_CACHE.maxsize == 123 + assert FILE_CACHE.maxsize == original_size + + +def test_keep_attrs(): + with pytest.raises(ValueError): + xarray.set_options(keep_attrs='invalid_str') + with xarray.set_options(keep_attrs=True): + assert OPTIONS['keep_attrs'] + with xarray.set_options(keep_attrs=False): + assert not OPTIONS['keep_attrs'] + with xarray.set_options(keep_attrs='default'): + assert _get_keep_attrs(default=True) + assert not _get_keep_attrs(default=False) + + def test_nested_options(): original = OPTIONS['display_width'] with xarray.set_options(display_width=1): @@ -20,3 +67,105 @@ def test_nested_options(): assert OPTIONS['display_width'] == 2 assert OPTIONS['display_width'] == 1 assert OPTIONS['display_width'] == original + + +def create_test_dataset_attrs(seed=0): + ds = create_test_data(seed) + ds.attrs = {'attr1': 5, 'attr2': 'history', + 'attr3': {'nested': 'more_info'}} + return ds + + +def create_test_dataarray_attrs(seed=0, var='var1'): + da = create_test_data(seed)[var] + da.attrs = {'attr1': 5, 'attr2': 'history', + 'attr3': {'nested': 'more_info'}} + return da + + +class TestAttrRetention(object): + def test_dataset_attr_retention(self): + # Use .mean() for all tests: a typical reduction operation + ds = create_test_dataset_attrs() + original_attrs = ds.attrs + + # Test default behaviour + result = ds.mean() + assert result.attrs == {} + with xarray.set_options(keep_attrs='default'): + result = ds.mean() + assert result.attrs == {} + + with xarray.set_options(keep_attrs=True): + result = ds.mean() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = ds.mean() + assert result.attrs == {} + + def test_dataarray_attr_retention(self): + # Use .mean() for all tests: a typical reduction operation + da = create_test_dataarray_attrs() + original_attrs = da.attrs + + # Test default behaviour + result = da.mean() + assert result.attrs == {} + with xarray.set_options(keep_attrs='default'): + result = da.mean() + assert result.attrs == {} + + with xarray.set_options(keep_attrs=True): + result = da.mean() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = da.mean() + assert result.attrs == {} + + def test_groupby_attr_retention(self): + da = xarray.DataArray([1, 2, 3], [('x', [1, 1, 2])]) + da.attrs = {'attr1': 5, 'attr2': 'history', + 'attr3': {'nested': 'more_info'}} + original_attrs = da.attrs + + # Test default behaviour + result = da.groupby('x').sum(keep_attrs=True) + assert result.attrs == original_attrs + with xarray.set_options(keep_attrs='default'): + result = da.groupby('x').sum(keep_attrs=True) + assert result.attrs == original_attrs + + with 
xarray.set_options(keep_attrs=True): + result1 = da.groupby('x') + result = result1.sum() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = da.groupby('x').sum() + assert result.attrs == {} + + def test_concat_attr_retention(self): + ds1 = create_test_dataset_attrs() + ds2 = create_test_dataset_attrs() + ds2.attrs = {'wrong': 'attributes'} + original_attrs = ds1.attrs + + # Test default behaviour of keeping the attrs of the first + # dataset in the supplied list + # global keep_attrs option current doesn't affect concat + result = concat([ds1, ds2], dim='dim1') + assert result.attrs == original_attrs + + @pytest.mark.xfail + def test_merge_attr_retention(self): + da1 = create_test_dataarray_attrs(var='var1') + da2 = create_test_dataarray_attrs(var='var2') + da2.attrs = {'wrong': 'attributes'} + original_attrs = da1.attrs + + # merge currently discards attrs, and the global keep_attrs + # option doesn't affect this + result = merge([da1, da2]) + assert result.attrs == original_attrs diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 1573577a092..10c4283032d 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1,30 +1,31 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# import mpl and change the backend before other mpl imports -try: - import matplotlib as mpl - import matplotlib.pyplot as plt -except ImportError: - pass +from __future__ import absolute_import, division, print_function import inspect +from datetime import datetime import numpy as np import pandas as pd -from datetime import datetime import pytest -from xarray import DataArray - +import xarray as xr import xarray.plot as xplt +from xarray import DataArray +from xarray.coding.times import _import_cftime from xarray.plot.plot import _infer_interval_breaks -from xarray.plot.utils import (_determine_cmap_params, _build_discrete_cmap, - _color_palette, import_seaborn) +from xarray.plot.utils import ( + _build_discrete_cmap, _color_palette, _determine_cmap_params, + import_seaborn, label_from_attrs) -from . import (TestCase, requires_matplotlib, requires_seaborn, raises_regex, - assert_equal, assert_array_equal) +from . 
import ( + assert_array_equal, assert_equal, raises_regex, requires_cftime, + requires_matplotlib, requires_matplotlib2, requires_seaborn) + +# import mpl and change the backend before other mpl imports +try: + import matplotlib as mpl + import matplotlib.pyplot as plt +except ImportError: + pass @pytest.mark.flaky @@ -63,8 +64,10 @@ def easy_array(shape, start=0, stop=1): @requires_matplotlib -class PlotTestCase(TestCase): - def tearDown(self): +class PlotTestCase(object): + @pytest.fixture(autouse=True) + def setup(self): + yield # Remove all matplotlib figures plt.close('all') @@ -86,20 +89,71 @@ def contourf_called(self, plotmethod): class TestPlot(PlotTestCase): - def setUp(self): + @pytest.fixture(autouse=True) + def setup_array(self): self.darray = DataArray(easy_array((2, 3, 4))) + def test_label_from_attrs(self): + da = self.darray.copy() + assert '' == label_from_attrs(da) + + da.name = 'a' + da.attrs['units'] = 'a_units' + da.attrs['long_name'] = 'a_long_name' + da.attrs['standard_name'] = 'a_standard_name' + assert 'a_long_name [a_units]' == label_from_attrs(da) + + da.attrs.pop('long_name') + assert 'a_standard_name [a_units]' == label_from_attrs(da) + da.attrs.pop('units') + assert 'a_standard_name' == label_from_attrs(da) + + da.attrs['units'] = 'a_units' + da.attrs.pop('standard_name') + assert 'a [a_units]' == label_from_attrs(da) + + da.attrs.pop('units') + assert 'a' == label_from_attrs(da) + def test1d(self): self.darray[:, 0, 0].plot() - with raises_regex(ValueError, 'dimension'): + with raises_regex(ValueError, 'None'): self.darray[:, 0, 0].plot(x='dim_1') + def test_1d_x_y_kw(self): + z = np.arange(10) + da = DataArray(np.cos(z), dims=['z'], coords=[z], name='f') + + xy = [[None, None], + [None, 'z'], + ['z', None]] + + f, ax = plt.subplots(3, 1) + for aa, (x, y) in enumerate(xy): + da.plot(x=x, y=y, ax=ax.flat[aa]) + + with raises_regex(ValueError, 'cannot'): + da.plot(x='z', y='z') + + with raises_regex(ValueError, 'None'): + da.plot(x='f', y='z') + + with raises_regex(ValueError, 'None'): + da.plot(x='z', y='f') + def test_2d_line(self): with raises_regex(ValueError, 'hue'): self.darray[:, :, 0].plot.line() self.darray[:, :, 0].plot.line(hue='dim_1') + self.darray[:, :, 0].plot.line(x='dim_1') + self.darray[:, :, 0].plot.line(y='dim_1') + self.darray[:, :, 0].plot.line(x='dim_0', hue='dim_1') + self.darray[:, :, 0].plot.line(y='dim_0', hue='dim_1') + + with raises_regex(ValueError, 'cannot'): + self.darray[:, :, 0].plot.line(x='dim_1', y='dim_0', hue='dim_1') def test_2d_line_accepts_legend_kw(self): self.darray[:, :, 0].plot.line(x='dim_0', add_legend=False) @@ -108,8 +162,8 @@ def test_2d_line_accepts_legend_kw(self): self.darray[:, :, 0].plot.line(x='dim_0', add_legend=True) assert plt.gca().get_legend() # check whether legend title is set - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') def test_2d_line_accepts_x_kw(self): self.darray[:, :, 0].plot.line(x='dim_0') @@ -120,12 +174,31 @@ def test_2d_line_accepts_x_kw(self): def test_2d_line_accepts_hue_kw(self): self.darray[:, :, 0].plot.line(hue='dim_0') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_0' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_0') plt.cla() self.darray[:, :, 0].plot.line(hue='dim_1') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') + + def test_2d_coords_line_plot(self): 
+ lon, lat = np.meshgrid(np.linspace(-20, 20, 5), + np.linspace(0, 30, 4)) + lon += lat / 10 + lat += lon / 10 + da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'], + coords={'lat': (('y', 'x'), lat), + 'lon': (('y', 'x'), lon)}) + + hdl = da.plot.line(x='lon', hue='x') + assert len(hdl) == 5 + + plt.clf() + hdl = da.plot.line(x='lon', hue='y') + assert len(hdl) == 4 + + with pytest.raises(ValueError, message='If x or y are 2D '): + da.plot.line(x='lon', hue='lat') def test_2d_before_squeeze(self): a = DataArray(easy_array((1, 5))) @@ -178,6 +251,32 @@ def test__infer_interval_breaks(self): np.testing.assert_allclose(xref, x) np.testing.assert_allclose(yref, y) + # test that ValueError is raised for non-monotonic 1D inputs + with pytest.raises(ValueError): + _infer_interval_breaks(np.array([0, 2, 1]), check_monotonic=True) + + def test_geo_data(self): + # Regression test for gh2250 + # Realistic coordinates taken from the example dataset + lat = np.array([[16.28, 18.48, 19.58, 19.54, 18.35], + [28.07, 30.52, 31.73, 31.68, 30.37], + [39.65, 42.27, 43.56, 43.51, 42.11], + [50.52, 53.22, 54.55, 54.50, 53.06]]) + lon = np.array([[-126.13, -113.69, -100.92, -88.04, -75.29], + [-129.27, -115.62, -101.54, -87.32, -73.26], + [-133.10, -118.00, -102.31, -86.42, -70.76], + [-137.85, -120.99, -103.28, -85.28, -67.62]]) + data = np.sqrt(lon ** 2 + lat ** 2) + da = DataArray(data, dims=('y', 'x'), + coords={'lon': (('y', 'x'), lon), + 'lat': (('y', 'x'), lat)}) + da.plot(x='lon', y='lat') + ax = plt.gca() + assert ax.has_data() + da.plot(x='lat', y='lon') + ax = plt.gca() + assert ax.has_data() + def test_datetime_dimension(self): nrow = 3 ncol = 4 @@ -190,6 +289,7 @@ def test_datetime_dimension(self): assert ax.has_data() @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -207,6 +307,7 @@ def test_convenient_facetgrid(self): d[0].plot(x='x', y='y', col='z', ax=plt.gca()) @pytest.mark.slow + @requires_matplotlib2 def test_subplot_kws(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -219,12 +320,9 @@ def test_subplot_kws(self): cmap='cool', subplot_kws=dict(facecolor='r')) for ax in g.axes.flat: - try: - # mpl V2 - assert ax.get_facecolor()[0:3] == \ - mpl.colors.to_rgb('r') - except AttributeError: - assert ax.get_axis_bgcolor() == 'r' + # mpl V2 + assert ax.get_facecolor()[0:3] == \ + mpl.colors.to_rgb('r') @pytest.mark.slow def test_plot_size(self): @@ -253,6 +351,7 @@ def test_plot_size(self): self.darray.plot(aspect=1) @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) @@ -265,16 +364,26 @@ def test_convenient_facetgrid_4d(self): with raises_regex(ValueError, '[Ff]acet'): d.plot(x='x', y='y', col='columns', ax=plt.gca()) + def test_coord_with_interval(self): + bins = [-1, 0, 1, 2] + self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot() + class TestPlot1D(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = [0, 1.1, 0, 2] self.darray = DataArray( d, coords={'period': range(len(d))}, dims='period') + self.darray.period.attrs['units'] = 's' def test_xlabel_is_index_name(self): self.darray.plot() - assert 'period' == plt.gca().get_xlabel() + assert 'period [s]' == plt.gca().get_xlabel() + + def test_no_label_name_on_x_axis(self): + 
self.darray.plot(y='period') + assert '' == plt.gca().get_xlabel() def test_no_label_name_on_y_axis(self): self.darray.plot() @@ -282,8 +391,15 @@ def test_no_label_name_on_y_axis(self): def test_ylabel_is_data_name(self): self.darray.name = 'temperature' + self.darray.attrs['units'] = 'degrees_Celsius' self.darray.plot() - assert self.darray.name == plt.gca().get_ylabel() + assert 'temperature [degrees_Celsius]' == plt.gca().get_ylabel() + + def test_xlabel_is_data_name(self): + self.darray.name = 'temperature' + self.darray.attrs['units'] = 'degrees_Celsius' + self.darray.plot(y='period') + assert 'temperature [degrees_Celsius]' == plt.gca().get_xlabel() def test_format_string(self): self.darray.plot.line('ro') @@ -312,6 +428,13 @@ def test_x_ticks_are_rotated_for_time(self): rotation = plt.gca().get_xticklabels()[0].get_rotation() assert rotation != 0 + def test_xyincrease_false_changes_axes(self): + self.darray.plot.line(xincrease=False, yincrease=False) + xlim = plt.gca().get_xlim() + ylim = plt.gca().get_ylim() + diffs = xlim[1] - xlim[0], ylim[1] - ylim[0] + assert all(x < 0 for x in diffs) + def test_slice_in_title(self): self.darray.coords['d'] = 10 self.darray.plot.line() @@ -319,25 +442,37 @@ def test_slice_in_title(self): assert 'd = 10' == title +class TestPlotStep(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + self.darray = DataArray(easy_array((2, 3, 4))) + + def test_step(self): + self.darray[0, 0].plot.step() + + def test_coord_with_interval_step(self): + bins = [-1, 0, 1, 2] + self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot.step() + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + + class TestPlotHistogram(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): self.darray = DataArray(easy_array((2, 3, 4))) def test_3d_array(self): self.darray.plot.hist() - def test_title_no_name(self): - self.darray.plot.hist() - assert '' == plt.gca().get_title() - - def test_title_uses_name(self): + def test_xlabel_uses_name(self): self.darray.name = 'testpoints' + self.darray.attrs['units'] = 'testunits' self.darray.plot.hist() - assert self.darray.name in plt.gca().get_title() + assert 'testpoints [testunits]' == plt.gca().get_xlabel() - def test_ylabel_is_count(self): + def test_title_is_histogram(self): self.darray.plot.hist() - assert 'Count' == plt.gca().get_ylabel() + assert 'Histogram' == plt.gca().get_title() def test_can_pass_in_kwargs(self): nbins = 5 @@ -356,9 +491,14 @@ def test_plot_nans(self): self.darray[0, 0, 0] = np.nan self.darray.plot.hist() + def test_hist_coord_with_interval(self): + (self.darray.groupby_bins('dim_0', [-1, 0, 1, 2]).mean(xr.ALL_DIMS) + .plot.hist(range=(-1, 2))) + @requires_matplotlib -class TestDetermineCmapParams(TestCase): +class TestDetermineCmapParams(object): + @pytest.fixture(autouse=True) def setUp(self): self.data = np.linspace(0, 1, num=100) @@ -366,7 +506,7 @@ def test_robust(self): cmap_params = _determine_cmap_params(self.data, robust=True) assert cmap_params['vmin'] == np.percentile(self.data, 2) assert cmap_params['vmax'] == np.percentile(self.data, 98) - assert cmap_params['cmap'].name == 'viridis' + assert cmap_params['cmap'] == 'viridis' assert cmap_params['extend'] == 'both' assert cmap_params['levels'] is None assert cmap_params['norm'] is None @@ -379,6 +519,30 @@ def test_center(self): assert cmap_params['levels'] is None assert cmap_params['norm'] is None + def test_cmap_sequential_option(self): + with xr.set_options(cmap_sequential='magma'): + cmap_params = 
_determine_cmap_params(self.data) + assert cmap_params['cmap'] == 'magma' + + def test_cmap_sequential_explicit_option(self): + with xr.set_options(cmap_sequential=mpl.cm.magma): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params['cmap'] == mpl.cm.magma + + def test_cmap_divergent_option(self): + with xr.set_options(cmap_divergent='magma'): + cmap_params = _determine_cmap_params(self.data, center=0.5) + assert cmap_params['cmap'] == 'magma' + + def test_nan_inf_are_ignored(self): + cmap_params1 = _determine_cmap_params(self.data) + data = self.data + data[50:55] = np.nan + data[56:60] = np.inf + cmap_params2 = _determine_cmap_params(data) + assert cmap_params1['vmin'] == cmap_params2['vmin'] + assert cmap_params1['vmax'] == cmap_params2['vmax'] + @pytest.mark.slow def test_integer_levels(self): data = self.data + 1 @@ -441,7 +605,7 @@ def test_divergentcontrol(self): cmap_params = _determine_cmap_params(pos) assert cmap_params['vmin'] == 0 assert cmap_params['vmax'] == 1 - assert cmap_params['cmap'].name == "viridis" + assert cmap_params['cmap'] == "viridis" # Default with negative data will be a divergent cmap cmap_params = _determine_cmap_params(neg) @@ -453,17 +617,17 @@ def test_divergentcontrol(self): cmap_params = _determine_cmap_params(neg, vmin=-0.1, center=False) assert cmap_params['vmin'] == -0.1 assert cmap_params['vmax'] == 0.9 - assert cmap_params['cmap'].name == "viridis" + assert cmap_params['cmap'] == "viridis" cmap_params = _determine_cmap_params(neg, vmax=0.5, center=False) assert cmap_params['vmin'] == -0.1 assert cmap_params['vmax'] == 0.5 - assert cmap_params['cmap'].name == "viridis" + assert cmap_params['cmap'] == "viridis" # Setting center=False too cmap_params = _determine_cmap_params(neg, center=False) assert cmap_params['vmin'] == -0.1 assert cmap_params['vmax'] == 0.9 - assert cmap_params['cmap'].name == "viridis" + assert cmap_params['cmap'] == "viridis" # However, I should still be able to set center and have a div cmap cmap_params = _determine_cmap_params(neg, center=0) @@ -493,21 +657,42 @@ def test_divergentcontrol(self): cmap_params = _determine_cmap_params(pos, vmin=0.1) assert cmap_params['vmin'] == 0.1 assert cmap_params['vmax'] == 1 - assert cmap_params['cmap'].name == "viridis" + assert cmap_params['cmap'] == "viridis" cmap_params = _determine_cmap_params(pos, vmax=0.5) assert cmap_params['vmin'] == 0 assert cmap_params['vmax'] == 0.5 - assert cmap_params['cmap'].name == "viridis" + assert cmap_params['cmap'] == "viridis" # If both vmin and vmax are provided, output is non-divergent cmap_params = _determine_cmap_params(neg, vmin=-0.2, vmax=0.6) assert cmap_params['vmin'] == -0.2 assert cmap_params['vmax'] == 0.6 - assert cmap_params['cmap'].name == "viridis" + assert cmap_params['cmap'] == "viridis" + + def test_norm_sets_vmin_vmax(self): + vmin = self.data.min() + vmax = self.data.max() + + for norm, extend in zip([mpl.colors.LogNorm(), + mpl.colors.LogNorm(vmin + 1, vmax - 1), + mpl.colors.LogNorm(None, vmax - 1), + mpl.colors.LogNorm(vmin + 1, None)], + ['neither', 'both', 'max', 'min']): + + test_min = vmin if norm.vmin is None else norm.vmin + test_max = vmax if norm.vmax is None else norm.vmax + + cmap_params = _determine_cmap_params(self.data, norm=norm) + + assert cmap_params['vmin'] == test_min + assert cmap_params['vmax'] == test_max + assert cmap_params['extend'] == extend + assert cmap_params['norm'] == norm @requires_matplotlib -class TestDiscreteColorMap(TestCase): +class TestDiscreteColorMap(object): + 
@pytest.fixture(autouse=True) def setUp(self): x = np.arange(start=0, stop=10, step=2) y = np.arange(start=9, stop=-7, step=-3) @@ -541,10 +726,10 @@ def test_build_discrete_cmap(self): @pytest.mark.slow def test_discrete_colormap_list_of_levels(self): - for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both', - [2, 5, 10, 11]), - ('neither', [0, 5, 10, 15]), ('min', - [2, 5, 10, 15])]: + for extend, levels in [('max', [-1, 2, 4, 8, 10]), + ('both', [2, 5, 10, 11]), + ('neither', [0, 5, 10, 15]), + ('min', [2, 5, 10, 15])]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)(levels=levels) assert_array_equal(levels, primitive.norm.boundaries) @@ -558,10 +743,10 @@ def test_discrete_colormap_list_of_levels(self): @pytest.mark.slow def test_discrete_colormap_int_levels(self): - for extend, levels, vmin, vmax in [('neither', 7, None, - None), ('neither', 7, None, 20), - ('both', 7, 4, 8), ('min', 10, 4, - 15)]: + for extend, levels, vmin, vmax in [('neither', 7, None, None), + ('neither', 7, None, 20), + ('both', 7, 4, 8), + ('min', 10, 4, 15)]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)( levels=levels, vmin=vmin, vmax=vmax) @@ -587,8 +772,13 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): assert primitive.norm.vmax == max(levels) assert primitive.norm.vmin == min(levels) + def test_discrete_colormap_provided_boundary_norm(self): + norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) + primitive = self.darray.plot.contourf(norm=norm) + np.testing.assert_allclose(primitive.levels, norm.boundaries) -class Common2dMixin: + +class Common2dMixin(object): """ Common tests for 2d plotting go here. @@ -596,22 +786,35 @@ class Common2dMixin: Should have the same name as the method. """ + @pytest.fixture(autouse=True) def setUp(self): - da = DataArray(easy_array((10, 15), start=-1), dims=['y', 'x']) + da = DataArray(easy_array((10, 15), start=-1), + dims=['y', 'x'], + coords={'y': np.arange(10), + 'x': np.arange(15)}) # add 2d coords ds = da.to_dataset(name='testvar') x, y = np.meshgrid(da.x.values, da.y.values) ds['x2d'] = DataArray(x, dims=['y', 'x']) ds['y2d'] = DataArray(y, dims=['y', 'x']) - ds.set_coords(['x2d', 'y2d'], inplace=True) + ds = ds.set_coords(['x2d', 'y2d']) # set darray and plot method self.darray = ds.testvar + + # Add CF-compliant metadata + self.darray.attrs['long_name'] = 'a_long_name' + self.darray.attrs['units'] = 'a_units' + self.darray.x.attrs['long_name'] = 'x_long_name' + self.darray.x.attrs['units'] = 'x_units' + self.darray.y.attrs['long_name'] = 'y_long_name' + self.darray.y.attrs['units'] = 'y_units' + self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__) def test_label_names(self): self.plotmethod() - assert 'x' == plt.gca().get_xlabel() - assert 'y' == plt.gca().get_ylabel() + assert 'x_long_name [x_units]' == plt.gca().get_xlabel() + assert 'y_long_name [y_units]' == plt.gca().get_ylabel() def test_1d_raises_valueerror(self): with raises_regex(ValueError, r'DataArray must be 2d'): @@ -632,6 +835,24 @@ def test_nonnumeric_index_raises_typeerror(self): def test_can_pass_in_axis(self): self.pass_in_axis(self.plotmethod) + def test_xyincrease_defaults(self): + + # With default settings the axis must be ordered regardless + # of the coords order. 
+ self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3], + [1, 2]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + # Inverted coords + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1], + [2, 1]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + def test_xyincrease_false_changes_axes(self): self.plotmethod(xincrease=False, yincrease=False) xlim = plt.gca().get_xlim() @@ -663,10 +884,13 @@ def test_plot_nans(self): clim2 = self.plotfunc(x2).get_clim() assert clim1 == clim2 + @pytest.mark.filterwarnings('ignore::UserWarning') + @pytest.mark.filterwarnings('ignore:invalid value encountered') def test_can_plot_all_nans(self): # regression test for issue #1780 self.plotfunc(DataArray(np.full((2, 2), np.nan))) + @pytest.mark.filterwarnings('ignore: Attempting to set') def test_can_plot_axis_size_one(self): if self.plotfunc.__name__ not in ('contour', 'contourf'): self.plotfunc(DataArray(np.ones((1, 1)))) @@ -704,19 +928,19 @@ def test_diverging_color_limits(self): def test_xy_strings(self): self.plotmethod('y', 'x') ax = plt.gca() - assert 'y' == ax.get_xlabel() - assert 'x' == ax.get_ylabel() + assert 'y_long_name [y_units]' == ax.get_xlabel() + assert 'x_long_name [x_units]' == ax.get_ylabel() def test_positional_coord_string(self): self.plotmethod(y='x') ax = plt.gca() - assert 'x' == ax.get_ylabel() - assert 'y' == ax.get_xlabel() + assert 'x_long_name [x_units]' == ax.get_ylabel() + assert 'y_long_name [y_units]' == ax.get_xlabel() self.plotmethod(x='x') ax = plt.gca() - assert 'x' == ax.get_xlabel() - assert 'y' == ax.get_ylabel() + assert 'x_long_name [x_units]' == ax.get_xlabel() + assert 'y_long_name [y_units]' == ax.get_ylabel() def test_bad_x_string_exception(self): with raises_regex(ValueError, 'x and y must be coordinate variables'): @@ -740,7 +964,7 @@ def test_non_linked_coords(self): # Normal case, without transpose self.plotfunc(self.darray, x='x', y='newy') ax = plt.gca() - assert 'x' == ax.get_xlabel() + assert 'x_long_name [x_units]' == ax.get_xlabel() assert 'newy' == ax.get_ylabel() # ax limits might change between plotfuncs # simply ensure that these high coords were passed over @@ -755,7 +979,7 @@ def test_non_linked_coords_transpose(self): self.plotfunc(self.darray, x='newy', y='x') ax = plt.gca() assert 'newy' == ax.get_xlabel() - assert 'x' == ax.get_ylabel() + assert 'x_long_name [x_units]' == ax.get_ylabel() # ax limits might change between plotfuncs # simply ensure that these high coords were passed over assert np.min(ax.get_xlim()) > 100. 
@@ -769,19 +993,29 @@ def test_default_title(self): assert 'c = 1, d = foo' == title or 'd = foo, c = 1' == title def test_colorbar_default_label(self): - self.darray.name = 'testvar' self.plotmethod(add_colorbar=True) - assert self.darray.name in text_in_fig() + assert ('a_long_name [a_units]' in text_in_fig()) def test_no_labels(self): self.darray.name = 'testvar' + self.darray.attrs['units'] = 'test_units' self.plotmethod(add_labels=False) alltxt = text_in_fig() - for string in ['x', 'y', 'testvar']: + for string in ['x_long_name [x_units]', + 'y_long_name [y_units]', + 'testvar [test_units]']: assert string not in alltxt def test_colorbar_kwargs(self): # replace label + self.darray.attrs.pop('long_name') + self.darray.attrs['units'] = 'test_units' + # check default colorbar label + self.plotmethod(add_colorbar=True) + alltxt = text_in_fig() + assert 'testvar [test_units]' in alltxt + self.darray.attrs.pop('units') + self.darray.name = 'testvar' self.plotmethod(add_colorbar=True, cbar_kwargs={'label': 'MyLabel'}) alltxt = text_in_fig() @@ -848,6 +1082,7 @@ def test_2d_function_and_method_signature_same(self): del func_sig['darray'] assert func_sig == method_sig + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -879,6 +1114,7 @@ def test_convenient_facetgrid(self): else: assert '' == ax.get_xlabel() + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) @@ -888,6 +1124,19 @@ def test_convenient_facetgrid_4d(self): for ax in g.axes.flat: assert ax.has_data() + @pytest.mark.filterwarnings('ignore:This figure includes') + def test_facetgrid_map_only_appends_mappables(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows') + + expected = g._mappables + + g.map(lambda: plt.plot(1, 1)) + actual = g._mappables + + assert expected == actual + def test_facetgrid_cmap(self): # Regression test for GH592 data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12)) @@ -898,6 +1147,42 @@ def test_facetgrid_cmap(self): # check that all colormaps are the same assert len(set(m.get_cmap().name for m in fg._mappables)) == 1 + def test_facetgrid_cbar_kwargs(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows', + cbar_kwargs={'label': 'test_label'}) + + # catch contour case + if hasattr(g, 'cbar'): + assert g.cbar._label == 'test_label' + + def test_facetgrid_no_cbar_ax(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + with pytest.raises(ValueError): + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows', + cbar_ax=1) + + def test_cmap_and_color_both(self): + with pytest.raises(ValueError): + self.plotmethod(colors='k', cmap='RdBu') + + def test_2d_coord_with_interval(self): + for dim in self.darray.dims: + gp = self.darray.groupby_bins(dim, range(15)).mean(dim) + for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: + getattr(gp.plot, kind)() + + def test_colormap_error_norm_and_vmin_vmax(self): + norm = mpl.colors.LogNorm(0.1, 1e1) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmin=2) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmax=2) + 
@pytest.mark.slow class TestContourf(Common2dMixin, PlotTestCase): @@ -960,23 +1245,23 @@ def test_colors(self): def _color_as_tuple(c): return tuple(c[:3]) + # with single color, we don't want rgb array artist = self.plotmethod(colors='k') - assert _color_as_tuple(artist.cmap.colors[0]) == \ - (0.0, 0.0, 0.0) + assert artist.cmap.colors[0] == 'k' artist = self.plotmethod(colors=['k', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (0.0, 0.0, 1.0) + assert (_color_as_tuple(artist.cmap.colors[1]) == + (0.0, 0.0, 1.0)) artist = self.darray.plot.contour( levels=[-0.5, 0., 0.5, 1.], colors=['k', 'r', 'w', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (1.0, 0.0, 0.0) - assert _color_as_tuple(artist.cmap.colors[2]) == \ - (1.0, 1.0, 1.0) + assert (_color_as_tuple(artist.cmap.colors[1]) == + (1.0, 0.0, 0.0)) + assert (_color_as_tuple(artist.cmap.colors[2]) == + (1.0, 1.0, 1.0)) # the last color is now under "over" - assert _color_as_tuple(artist.cmap._rgba_over) == \ - (0.0, 0.0, 1.0) + assert (_color_as_tuple(artist.cmap._rgba_over) == + (0.0, 0.0, 1.0)) def test_cmap_and_color_both(self): with pytest.raises(ValueError): @@ -1146,8 +1431,33 @@ def test_normalize_rgb_one_arg_error(self): for kwds in [dict(vmax=-1, vmin=-1.2), dict(vmin=2, vmax=2.1)]: da.plot.imshow(**kwds) + def test_imshow_rgb_values_in_valid_range(self): + da = DataArray(np.arange(75, dtype='uint8').reshape((5, 5, 3))) + _, ax = plt.subplots() + out = da.plot.imshow(ax=ax).get_array() + assert out.dtype == np.uint8 + assert (out[..., :3] == da.values).all() # Compare without added alpha + + @pytest.mark.filterwarnings('ignore:Several dimensions of this array') + def test_regression_rgb_imshow_dim_size_one(self): + # Regression: https://github.com/pydata/xarray/issues/1966 + da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) + da.plot.imshow() + + def test_origin_overrides_xyincrease(self): + da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]]) + da.plot.imshow(origin='upper') + assert plt.xlim()[0] < 0 + assert plt.ylim()[1] < 0 + + plt.clf() + da.plot.imshow(origin='lower') + assert plt.xlim()[0] < 0 + assert plt.ylim()[0] < 0 + class TestFacetGrid(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = easy_array((10, 15, 3)) self.darray = DataArray( @@ -1319,7 +1629,9 @@ def test_num_ticks(self): @pytest.mark.slow def test_map(self): + assert self.g._finalized is False self.g.map(plt.contourf, 'x', 'y', Ellipsis) + assert self.g._finalized is True self.g.map(lambda: None) @pytest.mark.slow @@ -1373,7 +1685,9 @@ def test_facetgrid_polar(self): sharey=False) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetGrid4d(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): a = easy_array((10, 15, 3, 2)) darray = DataArray(a, dims=['y', 'x', 'col', 'row']) @@ -1400,7 +1714,90 @@ def test_default_labels(self): assert substring_in_axes(label, ax) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') +class TestFacetedLinePlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + self.darray = DataArray(np.random.randn(10, 6, 3, 4), + dims=['hue', 'x', 'col', 'row'], + coords=[range(10), range(6), + range(3), ['A', 'B', 'C', 'C++']], + name='Cornelius Ortega the 1st') + + self.darray.hue.name = 'huename' + self.darray.hue.attrs['units'] = 'hunits' + self.darray.x.attrs['units'] = 'xunits' + self.darray.col.attrs['units'] = 'colunits' + self.darray.row.attrs['units'] = 'rowunits' + + def 
test_facetgrid_shape(self): + g = self.darray.plot(row='row', col='col', hue='hue') + assert g.axes.shape == (len(self.darray.row), len(self.darray.col)) + + g = self.darray.plot(row='col', col='row', hue='hue') + assert g.axes.shape == (len(self.darray.col), len(self.darray.row)) + + def test_unnamed_args(self): + g = self.darray.plot.line('o--', row='row', col='col', hue='hue') + lines = [q for q in g.axes.flat[0].get_children() + if isinstance(q, mpl.lines.Line2D)] + # passing 'o--' as argument should set marker and linestyle + assert lines[0].get_marker() == 'o' + assert lines[0].get_linestyle() == '--' + + def test_default_labels(self): + g = self.darray.plot(row='row', col='col', hue='hue') + # Rightmost column should be labeled + for label, ax in zip(self.darray.coords['row'].values, g.axes[:, -1]): + assert substring_in_axes(label, ax) + + # Top row should be labeled + for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]): + assert substring_in_axes(str(label), ax) + + # Leftmost column should have array name + for ax in g.axes[:, 0]: + assert substring_in_axes(self.darray.name, ax) + + def test_test_empty_cell(self): + g = self.darray.isel(row=1).drop('row').plot(col='col', + hue='hue', + col_wrap=2) + bottomright = g.axes[-1, -1] + assert not bottomright.has_data() + assert not bottomright.get_visible() + + def test_set_axis_labels(self): + g = self.darray.plot(row='row', col='col', hue='hue') + g.set_axis_labels('longitude', 'latitude') + alltxt = text_in_fig() + + assert 'longitude' in alltxt + assert 'latitude' in alltxt + + def test_both_x_and_y(self): + with pytest.raises(ValueError): + self.darray.plot.line(row='row', col='col', + x='x', y='hue') + + def test_axes_in_faceted_plot(self): + with pytest.raises(ValueError): + self.darray.plot.line(row='row', col='col', + x='x', ax=plt.axes()) + + def test_figsize_and_size(self): + with pytest.raises(ValueError): + self.darray.plot.line(row='row', col='col', + x='x', size=3, figsize=4) + + def test_wrong_num_of_dimensions(self): + with pytest.raises(ValueError): + self.darray.plot(row='row', hue='hue') + self.darray.plot.line(row='row', hue='hue') + + class TestDatetimePlot(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): ''' Create a DataArray with a time-axis that contains datetime objects. 
@@ -1432,3 +1829,89 @@ def test_plot_seaborn_no_import_warning(): with pytest.warns(None) as record: _color_palette('Blues', 4) assert len(record) == 0 + + +@requires_cftime +def test_plot_cftime_coordinate_error(): + cftime = _import_cftime() + time = cftime.num2date(np.arange(5), units='days since 0001-01-01', + calendar='noleap') + data = DataArray(np.arange(5), coords=[time], dims=['time']) + with raises_regex(TypeError, + 'requires coordinates to be numeric or dates'): + data.plot() + + +@requires_cftime +def test_plot_cftime_data_error(): + cftime = _import_cftime() + data = cftime.num2date(np.arange(5), units='days since 0001-01-01', + calendar='noleap') + data = DataArray(data, coords=[np.arange(5)], dims=['x']) + with raises_regex(NotImplementedError, 'cftime.datetime'): + data.plot() + + +test_da_list = [DataArray(easy_array((10, ))), + DataArray(easy_array((10, 3))), + DataArray(easy_array((10, 3, 2)))] + + +@requires_matplotlib +class TestAxesKwargs(object): + @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize('xincrease', [True, False]) + def test_xincrease_kwarg(self, da, xincrease): + plt.clf() + da.plot(xincrease=xincrease) + assert plt.gca().xaxis_inverted() == (not xincrease) + + @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize('yincrease', [True, False]) + def test_yincrease_kwarg(self, da, yincrease): + plt.clf() + da.plot(yincrease=yincrease) + assert plt.gca().yaxis_inverted() == (not yincrease) + + @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize('xscale', ['linear', 'log', 'logit', 'symlog']) + def test_xscale_kwarg(self, da, xscale): + plt.clf() + da.plot(xscale=xscale) + assert plt.gca().get_xscale() == xscale + + @pytest.mark.parametrize('da', [DataArray(easy_array((10, ))), + DataArray(easy_array((10, 3)))]) + @pytest.mark.parametrize('yscale', ['linear', 'log', 'logit', 'symlog']) + def test_yscale_kwarg(self, da, yscale): + plt.clf() + da.plot(yscale=yscale) + assert plt.gca().get_yscale() == yscale + + @pytest.mark.parametrize('da', test_da_list) + def test_xlim_kwarg(self, da): + plt.clf() + expected = (0.0, 1000.0) + da.plot(xlim=[0, 1000]) + assert plt.gca().get_xlim() == expected + + @pytest.mark.parametrize('da', test_da_list) + def test_ylim_kwarg(self, da): + plt.clf() + da.plot(ylim=[0, 1000]) + expected = (0.0, 1000.0) + assert plt.gca().get_ylim() == expected + + @pytest.mark.parametrize('da', test_da_list) + def test_xticks_kwarg(self, da): + plt.clf() + da.plot(xticks=np.arange(5)) + expected = np.arange(5).tolist() + assert np.all(plt.gca().get_xticks() == expected) + + @pytest.mark.parametrize('da', test_da_list) + def test_yticks_kwarg(self, da): + plt.clf() + da.plot(yticks=np.arange(5)) + expected = np.arange(5) + assert np.all(plt.gca().get_yticks() == expected) diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 02390ac277a..8a0fa5f6e48 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -1,6 +1,4 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import xarray as xr diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index 9ad797a9ac9..6547311aa2f 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -1,18 +1,18 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ 
import absolute_import, division, print_function import os -from xarray import tutorial, DataArray +import pytest + +from xarray import DataArray, tutorial from xarray.core.pycompat import suppress -from . import TestCase, network, assert_identical +from . import assert_identical, network @network -class TestLoadDataset(TestCase): - +class TestLoadDataset(object): + @pytest.fixture(autouse=True) def setUp(self): self.testfile = 'tiny' self.testfilepath = os.path.expanduser(os.sep.join( @@ -23,6 +23,11 @@ def setUp(self): os.remove('{}.md5'.format(self.testfilepath)) def test_download_from_github(self): - ds = tutorial.load_dataset(self.testfile) + ds = tutorial.open_dataset(self.testfile).load() tiny = DataArray(range(5), name='tiny').to_dataset() assert_identical(ds, tiny) + + def test_download_from_github_load_without_cache(self): + ds_nocache = tutorial.open_dataset(self.testfile, cache=False).load() + ds_cache = tutorial.open_dataset(self.testfile).load() + assert_identical(ds_cache, ds_nocache) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index a42819605fa..6941efb1c6e 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -1,69 +1,179 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import pickle import numpy as np +import pytest -import xarray.ufuncs as xu import xarray as xr +import xarray.ufuncs as xu + +from . import assert_array_equal +from . import assert_identical as assert_identical_ +from . import mock, raises_regex, requires_np113 + + +def assert_identical(a, b): + assert type(a) is type(b) or (float(a) == float(b)) # noqa + if isinstance(a, (xr.DataArray, xr.Dataset, xr.Variable)): + assert_identical_(a, b) + else: + assert_array_equal(a, b) + + +@requires_np113 +def test_unary(): + args = [0, + np.zeros(2), + xr.Variable(['x'], [0, 0]), + xr.DataArray([0, 0], dims='x'), + xr.Dataset({'y': ('x', [0, 0])})] + for a in args: + assert_identical(a + 1, np.cos(a)) + + +@requires_np113 +def test_binary(): + args = [0, + np.zeros(2), + xr.Variable(['x'], [0, 0]), + xr.DataArray([0, 0], dims='x'), + xr.Dataset({'y': ('x', [0, 0])})] + for n, t1 in enumerate(args): + for t2 in args[n:]: + assert_identical(t2 + 1, np.maximum(t1, t2 + 1)) + assert_identical(t2 + 1, np.maximum(t2, t1 + 1)) + assert_identical(t2 + 1, np.maximum(t1 + 1, t2)) + assert_identical(t2 + 1, np.maximum(t2 + 1, t1)) + + +@requires_np113 +def test_binary_out(): + args = [1, + np.ones(2), + xr.Variable(['x'], [1, 1]), + xr.DataArray([1, 1], dims='x'), + xr.Dataset({'y': ('x', [1, 1])})] + for arg in args: + actual_mantissa, actual_exponent = np.frexp(arg) + assert_identical(actual_mantissa, 0.5 * arg) + assert_identical(actual_exponent, arg) + + +@requires_np113 +def test_groupby(): + ds = xr.Dataset({'a': ('x', [0, 0, 0])}, {'c': ('x', [0, 0, 1])}) + ds_grouped = ds.groupby('c') + group_mean = ds_grouped.mean('x') + arr_grouped = ds['a'].groupby('c') + + assert_identical(ds, np.maximum(ds_grouped, group_mean)) + assert_identical(ds, np.maximum(group_mean, ds_grouped)) + + assert_identical(ds, np.maximum(arr_grouped, group_mean)) + assert_identical(ds, np.maximum(group_mean, arr_grouped)) + + assert_identical(ds, np.maximum(ds_grouped, group_mean['a'])) + assert_identical(ds, np.maximum(group_mean['a'], ds_grouped)) + + assert_identical(ds.a, np.maximum(arr_grouped, group_mean.a)) + assert_identical(ds.a, np.maximum(group_mean.a, 
arr_grouped)) + + with raises_regex(ValueError, 'mismatched lengths for dimension'): + np.maximum(ds.a.variable, ds_grouped) + + +@requires_np113 +def test_alignment(): + ds1 = xr.Dataset({'a': ('x', [1, 2])}, {'x': [0, 1]}) + ds2 = xr.Dataset({'a': ('x', [2, 3]), 'b': 4}, {'x': [1, 2]}) + + actual = np.add(ds1, ds2) + expected = xr.Dataset({'a': ('x', [4])}, {'x': [1]}) + assert_identical_(actual, expected) + + with xr.set_options(arithmetic_join='outer'): + actual = np.add(ds1, ds2) + expected = xr.Dataset({'a': ('x', [np.nan, 4, np.nan]), 'b': np.nan}, + coords={'x': [0, 1, 2]}) + assert_identical_(actual, expected) + + +@requires_np113 +def test_kwargs(): + x = xr.DataArray(0) + result = np.add(x, 1, dtype=np.float64) + assert result.dtype == np.float64 + + +@requires_np113 +def test_xarray_defers_to_unrecognized_type(): + + class Other(object): + def __array_ufunc__(self, *args, **kwargs): + return 'other' + + xarray_obj = xr.DataArray([1, 2, 3]) + other = Other() + assert np.maximum(xarray_obj, other) == 'other' + assert np.sin(xarray_obj, out=other) == 'other' + + +@requires_np113 +def test_xarray_handles_dask(): + da = pytest.importorskip('dask.array') + x = xr.DataArray(np.ones((2, 2)), dims=['x', 'y']) + y = da.ones((2, 2), chunks=(2, 2)) + result = np.add(x, y) + assert result.chunks == ((2,), (2,)) + assert isinstance(result, xr.DataArray) + + +@requires_np113 +def test_dask_defers_to_xarray(): + da = pytest.importorskip('dask.array') + x = xr.DataArray(np.ones((2, 2)), dims=['x', 'y']) + y = da.ones((2, 2), chunks=(2, 2)) + result = np.add(y, x) + assert result.chunks == ((2,), (2,)) + assert isinstance(result, xr.DataArray) + + +@requires_np113 +def test_gufunc_methods(): + xarray_obj = xr.DataArray([1, 2, 3]) + with raises_regex(NotImplementedError, 'reduce method'): + np.add.reduce(xarray_obj, 1) + + +@requires_np113 +def test_out(): + xarray_obj = xr.DataArray([1, 2, 3]) + + # xarray out arguments should raise + with raises_regex(NotImplementedError, '`out` argument'): + np.add(xarray_obj, 1, out=xarray_obj) + + # but non-xarray should be OK + other = np.zeros((3,)) + np.add(other, xarray_obj, out=other) + assert_identical(other, np.array([1, 2, 3])) + + +@requires_np113 +def test_gufuncs(): + xarray_obj = xr.DataArray([1, 2, 3]) + fake_gufunc = mock.Mock(signature='(n)->()', autospec=np.sin) + with raises_regex(NotImplementedError, 'generalized ufuncs'): + xarray_obj.__array_ufunc__(fake_gufunc, '__call__', xarray_obj) + + +def test_xarray_ufuncs_deprecation(): + with pytest.warns(PendingDeprecationWarning, match='xarray.ufuncs'): + xu.cos(xr.DataArray([0, 1])) + -from . 
import ( - TestCase, raises_regex, assert_identical, assert_array_equal) - - -class TestOps(TestCase): - def assert_identical(self, a, b): - assert type(a) is type(b) or (float(a) == float(b)) - try: - assert a.identical(b), (a, b) - except AttributeError: - assert_array_equal(a, b) - - def test_unary(self): - args = [0, - np.zeros(2), - xr.Variable(['x'], [0, 0]), - xr.DataArray([0, 0], dims='x'), - xr.Dataset({'y': ('x', [0, 0])})] - for a in args: - self.assert_identical(a + 1, xu.cos(a)) - - def test_binary(self): - args = [0, - np.zeros(2), - xr.Variable(['x'], [0, 0]), - xr.DataArray([0, 0], dims='x'), - xr.Dataset({'y': ('x', [0, 0])})] - for n, t1 in enumerate(args): - for t2 in args[n:]: - self.assert_identical(t2 + 1, xu.maximum(t1, t2 + 1)) - self.assert_identical(t2 + 1, xu.maximum(t2, t1 + 1)) - self.assert_identical(t2 + 1, xu.maximum(t1 + 1, t2)) - self.assert_identical(t2 + 1, xu.maximum(t2 + 1, t1)) - - def test_groupby(self): - ds = xr.Dataset({'a': ('x', [0, 0, 0])}, {'c': ('x', [0, 0, 1])}) - ds_grouped = ds.groupby('c') - group_mean = ds_grouped.mean('x') - arr_grouped = ds['a'].groupby('c') - - assert_identical(ds, xu.maximum(ds_grouped, group_mean)) - assert_identical(ds, xu.maximum(group_mean, ds_grouped)) - - assert_identical(ds, xu.maximum(arr_grouped, group_mean)) - assert_identical(ds, xu.maximum(group_mean, arr_grouped)) - - assert_identical(ds, xu.maximum(ds_grouped, group_mean['a'])) - assert_identical(ds, xu.maximum(group_mean['a'], ds_grouped)) - - assert_identical(ds.a, xu.maximum(arr_grouped, group_mean.a)) - assert_identical(ds.a, xu.maximum(group_mean.a, arr_grouped)) - - with raises_regex(TypeError, 'only support binary ops'): - xu.maximum(ds.a.variable, ds_grouped) - - def test_pickle(self): - a = 1.0 - cos_pickled = pickle.loads(pickle.dumps(xu.cos)) - self.assert_identical(cos_pickled(a), xu.cos(a)) +def test_xarray_ufuncs_pickle(): + a = 1.0 + cos_pickled = pickle.loads(pickle.dumps(xu.cos)) + assert_identical(cos_pickled(a), xu.cos(a)) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 1a008eff180..ed07af0d7bb 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -1,17 +1,26 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import pytest +from __future__ import absolute_import, division, print_function + +from datetime import datetime import numpy as np import pandas as pd +import pytest +import xarray as xr +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils +from xarray.core.options import set_options from xarray.core.pycompat import OrderedDict -from . import TestCase, requires_dask, assert_array_equal +from xarray.core.utils import either_dict_or_kwargs +from xarray.testing import assert_identical +from . 
import ( + assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_cftime, + requires_dask) +from .test_coding_times import _all_cftime_date_types -class TestAlias(TestCase): + +class TestAlias(object): def test(self): def new_method(): pass @@ -21,24 +30,52 @@ def new_method(): old_method() -class TestSafeCastToIndex(TestCase): - def test(self): - dates = pd.date_range('2000-01-01', periods=10) - x = np.arange(5) - td = x * np.timedelta64(1, 'D') - for expected, array in [ - (dates, dates.values), - (pd.Index(x, dtype=object), x.astype(object)), - (pd.Index(td), td), - (pd.Index(td, dtype=object), td.astype(object)), - ]: - actual = utils.safe_cast_to_index(array) - assert_array_equal(expected, actual) - assert expected.dtype == actual.dtype +def test_safe_cast_to_index(): + dates = pd.date_range('2000-01-01', periods=10) + x = np.arange(5) + td = x * np.timedelta64(1, 'D') + for expected, array in [ + (dates, dates.values), + (pd.Index(x, dtype=object), x.astype(object)), + (pd.Index(td), td), + (pd.Index(td, dtype=object), td.astype(object)), + ]: + actual = utils.safe_cast_to_index(array) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +def test_safe_cast_to_index_cftimeindex(): + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + dates = [date_type(1, 1, day) for day in range(1, 20)] + + if has_cftime: + expected = CFTimeIndex(dates) + else: + expected = pd.Index(dates) + + actual = utils.safe_cast_to_index(np.array(dates)) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + assert isinstance(actual, type(expected)) + + +# Test that datetime.datetime objects are never used in a CFTimeIndex +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +def test_safe_cast_to_index_datetime_datetime(): + dates = [datetime(1, 1, day) for day in range(1, 20)] + + expected = pd.Index(dates) + actual = utils.safe_cast_to_index(np.array(dates)) + assert_array_equal(expected, actual) + assert isinstance(actual, pd.Index) def test_multiindex_from_product_levels(): - result = utils.multiindex_from_product_levels([['b', 'a'], [1, 3, 2]]) + result = utils.multiindex_from_product_levels( + [pd.Index(['b', 'a']), pd.Index([1, 3, 2])]) np.testing.assert_array_equal( result.labels, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]]) np.testing.assert_array_equal(result.levels[0], ['b', 'a']) @@ -48,7 +85,16 @@ def test_multiindex_from_product_levels(): np.testing.assert_array_equal(result.values, other.values) -class TestArrayEquiv(TestCase): +def test_multiindex_from_product_levels_non_unique(): + result = utils.multiindex_from_product_levels( + [pd.Index(['b', 'a']), pd.Index([1, 1, 2])]) + np.testing.assert_array_equal( + result.labels, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]]) + np.testing.assert_array_equal(result.levels[0], ['b', 'a']) + np.testing.assert_array_equal(result.levels[1], [1, 2]) + + +class TestArrayEquiv(object): def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional # object arrays @@ -58,8 +104,9 @@ def test_0d(self): assert not duck_array_ops.array_equiv(0, np.array(1, dtype=object)) -class TestDictionaries(TestCase): - def setUp(self): +class TestDictionaries(object): + @pytest.fixture(autouse=True) + def setup(self): self.x = {'a': 'A', 'b': 'B'} self.y = {'c': 'C', 'b': 'B'} self.z = {'a': 'Z'} @@ -126,7 +173,7 @@ def test_frozen(self): def 
test_sorted_keys_dict(self): x = {'a': 1, 'b': 2, 'c': 3} y = utils.SortedKeysDict(x) - self.assertItemsEqual(y, ['a', 'b', 'c']) + assert list(y) == ['a', 'b', 'c'] assert repr(utils.SortedKeysDict()) == \ "SortedKeysDict({})" @@ -141,7 +188,7 @@ def test_chain_map(self): m['x'] = 100 assert m['x'] == 100 assert m.maps[0]['x'] == 100 - self.assertItemsEqual(['x', 'y', 'z'], m) + assert set(m) == {'x', 'y', 'z'} def test_repr_object(): @@ -149,7 +196,23 @@ def test_repr_object(): assert repr(obj) == 'foo' -class Test_is_uniform_and_sorted(TestCase): +def test_is_remote_uri(): + assert utils.is_remote_uri('http://example.com') + assert utils.is_remote_uri('https://example.com') + assert not utils.is_remote_uri(' http://example.com') + assert not utils.is_remote_uri('example.nc') + + +def test_is_grib_path(): + assert not utils.is_grib_path('example.nc') + assert not utils.is_grib_path('example.grib ') + assert utils.is_grib_path('example.grib') + assert utils.is_grib_path('example.grib2') + assert utils.is_grib_path('example.grb') + assert utils.is_grib_path('example.grb2') + + +class Test_is_uniform_and_sorted(object): def test_sorted_uniform(self): assert utils.is_uniform_spaced(np.arange(5)) @@ -170,7 +233,7 @@ def test_relative_tolerance(self): assert utils.is_uniform_spaced([0, 0.97, 2], rtol=0.1) -class Test_hashable(TestCase): +class Test_hashable(object): def test_hashable(self): for v in [False, 1, (2, ), (3, 4), 'four']: @@ -201,3 +264,56 @@ def test_hidden_key_dict(): hkd[hidden_key] with pytest.raises(KeyError): del hkd[hidden_key] + + +def test_either_dict_or_kwargs(): + + result = either_dict_or_kwargs(dict(a=1), None, 'foo') + expected = dict(a=1) + assert result == expected + + result = either_dict_or_kwargs(None, dict(a=1), 'foo') + expected = dict(a=1) + assert result == expected + + with pytest.raises(ValueError, match=r'foo'): + result = either_dict_or_kwargs(dict(a=1), dict(a=1), 'foo') + + +def test_datetime_to_numeric_datetime64(): + times = pd.date_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), coords=da.coords).astype(dtype) + assert_identical(result, expected) + + +@requires_cftime +def test_datetime_to_numeric_cftime(): + times = xr.cftime_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), coords=da.coords).astype(dtype) + assert_identical(result, expected) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 
5a89627a0f9..0bd440781ac 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,38 +1,35 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from collections import namedtuple +from __future__ import absolute_import, division, print_function + +import warnings from copy import copy, deepcopy from datetime import datetime, timedelta +from distutils.version import LooseVersion from textwrap import dedent -import pytest -from distutils.version import LooseVersion import numpy as np -import pytz import pandas as pd +import pytest +import pytz -from xarray import Variable, IndexVariable, Coordinate, Dataset +from xarray import Coordinate, Dataset, IndexVariable, Variable from xarray.core import indexing -from xarray.core.variable import as_variable, as_compatible_data -from xarray.core.indexing import (PandasIndexAdapter, LazilyIndexedArray, - BasicIndexer, OuterIndexer, - VectorizedIndexer, NumpyIndexingAdapter, - CopyOnWriteArray, MemoryCachedArray, - DaskIndexingAdapter) +from xarray.core.common import full_like, ones_like, zeros_like +from xarray.core.indexing import ( + BasicIndexer, CopyOnWriteArray, DaskIndexingAdapter, + LazilyOuterIndexedArray, MemoryCachedArray, NumpyIndexingAdapter, + OuterIndexer, PandasIndexAdapter, VectorizedIndexer) from xarray.core.pycompat import PY3, OrderedDict -from xarray.core.common import full_like, zeros_like, ones_like from xarray.core.utils import NDArrayMixin +from xarray.core.variable import as_compatible_data, as_variable +from xarray.tests import requires_bottleneck from . import ( - TestCase, source_ndarray, requires_dask, raises_regex, assert_identical, - assert_array_equal, assert_equal, assert_allclose) - -from xarray.tests import requires_bottleneck + assert_allclose, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask, source_ndarray) -class VariableSubclassTestCases(object): +class VariableSubclassobjects(object): def test_properties(self): data = 0.5 * np.arange(10) v = self.cls(['time'], data, {'foo': 'bar'}) @@ -141,8 +138,10 @@ def _assertIndexedLikeNDArray(self, variable, expected_value0, assert variable.equals(variable.copy()) assert variable.identical(variable.copy()) # check value is equal for both ndarray and Variable - assert variable.values[0] == expected_value0 - assert variable[0].values == expected_value0 + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', "In the future, 'NAT == x'") + assert variable.values[0] == expected_value0 + assert variable[0].values == expected_value0 # check type or dtype is consistent for both ndarray and Variable if expected_dtype is None: # check output type instead of array dtype @@ -468,25 +467,32 @@ def test_concat_number_strings(self): a = self.cls('x', ['0', '1', '2']) b = self.cls('x', ['3', '4']) actual = Variable.concat([a, b], dim='x') - expected = Variable('x', np.arange(5).astype(str).astype(object)) + expected = Variable('x', np.arange(5).astype(str)) + assert_identical(expected, actual) + assert actual.dtype.kind == expected.dtype.kind + + def test_concat_mixed_dtypes(self): + a = self.cls('x', [0, 1]) + b = self.cls('x', ['two']) + actual = Variable.concat([a, b], dim='x') + expected = Variable('x', np.array([0, 1, 'two'], dtype=object)) assert_identical(expected, actual) - assert expected.dtype == object - assert type(expected.values[0]) == str + assert actual.dtype == object - def test_copy(self): + @pytest.mark.parametrize('deep', [True, 
False]) + def test_copy(self, deep): v = self.cls('x', 0.5 * np.arange(10), {'foo': 'bar'}) - for deep in [True, False]: - w = v.copy(deep=deep) - assert type(v) is type(w) - assert_identical(v, w) - assert v.dtype == w.dtype - if self.cls is Variable: - if deep: - assert source_ndarray(v.values) is not \ - source_ndarray(w.values) - else: - assert source_ndarray(v.values) is \ - source_ndarray(w.values) + w = v.copy(deep=deep) + assert type(v) is type(w) + assert_identical(v, w) + assert v.dtype == w.dtype + if self.cls is Variable: + if deep: + assert (source_ndarray(v.values) is not + source_ndarray(w.values)) + else: + assert (source_ndarray(v.values) is + source_ndarray(w.values)) assert_identical(v, copy(v)) def test_copy_index(self): @@ -499,6 +505,34 @@ def test_copy_index(self): assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) + def test_copy_with_data(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = np.array([[2.5, 5.0], [7.1, 43]]) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_with_data_errors(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = [2.5, 5.0] + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + + def test_copy_index_with_data(self): + orig = IndexVariable('x', np.arange(5)) + new_data = np.arange(5, 10) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_index_with_data_errors(self): + orig = IndexVariable('x', np.arange(5)) + new_data = np.arange(5, 20) + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + def test_real_and_imag(self): v = self.cls('x', np.arange(3) - 1j * np.arange(3), {'foo': 'bar'}) expected_re = self.cls('x', np.arange(3), {'foo': 'bar'}) @@ -620,6 +654,12 @@ def test_getitem_0d_array(self): v_new = v[np.array([0])[0]] assert_array_equal(v_new, v_data[0]) + v_new = v[np.array(0)] + assert_array_equal(v_new, v_data[0]) + + v_new = v[Variable((), np.array(0))] + assert_array_equal(v_new, v_data[0]) + def test_getitem_fancy(self): v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]]) v_data = v.compute().data @@ -727,11 +767,58 @@ def test_getitem_error(self): with raises_regex(IndexError, 'Dimensions of indexers mis'): v[:, ind] - -class TestVariable(TestCase, VariableSubclassTestCases): + def test_pad(self): + data = np.arange(4 * 3 * 2).reshape(4, 3, 2) + v = self.cls(['x', 'y', 'z'], data) + + xr_args = [{'x': (2, 1)}, {'y': (0, 3)}, {'x': (3, 1), 'z': (2, 0)}] + np_args = [((2, 1), (0, 0), (0, 0)), ((0, 0), (0, 3), (0, 0)), + ((3, 1), (0, 0), (2, 0))] + for xr_arg, np_arg in zip(xr_args, np_args): + actual = v.pad_with_fill_value(**xr_arg) + expected = np.pad(np.array(v.data.astype(float)), np_arg, + mode='constant', constant_values=np.nan) + assert_array_equal(actual, expected) + assert isinstance(actual._data, type(v._data)) + + # for the boolean array, we pad False + data = np.full_like(data, False, dtype=bool).reshape(4, 3, 2) + v = self.cls(['x', 'y', 'z'], data) + for xr_arg, np_arg in zip(xr_args, np_args): + actual = v.pad_with_fill_value(fill_value=False, **xr_arg) + expected = np.pad(np.array(v.data), np_arg, + mode='constant', constant_values=False) + assert_array_equal(actual, expected) + + def test_rolling_window(self): + # Just a working test. 
See test_nputils for the algorithm validation + v = self.cls(['x', 'y', 'z'], + np.arange(40 * 30 * 2).reshape(40, 30, 2)) + for (d, w) in [('x', 3), ('y', 5)]: + v_rolling = v.rolling_window(d, w, d + '_window') + assert v_rolling.dims == ('x', 'y', 'z', d + '_window') + assert v_rolling.shape == v.shape + (w, ) + + v_rolling = v.rolling_window(d, w, d + '_window', center=True) + assert v_rolling.dims == ('x', 'y', 'z', d + '_window') + assert v_rolling.shape == v.shape + (w, ) + + # dask and numpy result should be the same + v_loaded = v.load().rolling_window(d, w, d + '_window', + center=True) + assert_array_equal(v_rolling, v_loaded) + + # numpy backend should not be over-written + if isinstance(v._data, np.ndarray): + with pytest.raises(ValueError): + v_loaded[0] = 1.0 + + +class TestVariable(VariableSubclassobjects): cls = staticmethod(Variable) - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self): self.d = np.random.random((10, 3)).astype(np.float64) def test_data_and_values(self): @@ -879,27 +966,14 @@ def test_as_variable(self): assert not isinstance(ds['x'], Variable) assert isinstance(as_variable(ds['x']), Variable) - FakeVariable = namedtuple('FakeVariable', 'values dims') - fake_xarray = FakeVariable(expected.values, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', 'data dims') - fake_xarray = FakeVariable(expected.data, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', - 'data values dims attrs encoding') - fake_xarray = FakeVariable(expected_extra.data, expected_extra.values, - expected_extra.dims, expected_extra.attrs, - expected_extra.encoding) - assert_identical(expected_extra, as_variable(fake_xarray)) - xarray_tuple = (expected_extra.dims, expected_extra.values, expected_extra.attrs, expected_extra.encoding) assert_identical(expected_extra, as_variable(xarray_tuple)) - with raises_regex(TypeError, 'tuples to convert'): + with raises_regex(TypeError, 'tuple of form'): as_variable(tuple(data)) + with raises_regex(ValueError, 'tuple of form'): # GH1016 + as_variable(('five', 'six', 'seven')) with raises_regex( TypeError, 'without an explicit list of dimensions'): as_variable(data) @@ -920,6 +994,13 @@ def test_as_variable(self): ValueError, 'has more than 1-dimension'): as_variable(expected, name='x') + # test datetime, timedelta conversion + dt = np.array([datetime(1999, 1, 1) + timedelta(days=x) + for x in range(10)]) + assert as_variable(dt, 'time').dtype.kind == 'M' + td = np.array([timedelta(days=x) for x in range(10)]) + assert as_variable(td, 'time').dtype.kind == 'm' + def test_repr(self): v = Variable(['time', 'x'], [[1, 2, 3], [4, 5, 6]], {'foo': 'bar'}) expected = dedent(""" @@ -932,9 +1013,9 @@ def test_repr(self): assert expected == repr(v) def test_repr_lazy_data(self): - v = Variable('x', LazilyIndexedArray(np.arange(2e5))) + v = Variable('x', LazilyOuterIndexedArray(np.arange(2e5))) assert '200000 values with dtype' in repr(v) - assert isinstance(v._data, LazilyIndexedArray) + assert isinstance(v._data, LazilyOuterIndexedArray) def test_detect_indexer_type(self): """ Tests indexer type was correctly detected. 
""" @@ -1366,8 +1447,6 @@ def test_reduce(self): with raises_regex(ValueError, 'cannot supply both'): v.mean(dim='x', axis=0) - @pytest.mark.skipif(LooseVersion(np.__version__) < LooseVersion('1.10.0'), - reason='requires numpy version 1.10.0 or later') def test_quantile(self): v = Variable(['x', 'y'], self.d) for q in [0.25, [0.50], [0.25, 0.75]]: @@ -1439,20 +1518,15 @@ def test_reduce_funcs(self): assert_identical(v.cumprod(axis=0), Variable('x', np.array([1, 1, 2, 6]))) assert_identical(v.var(), Variable([], 2.0 / 3)) - - if LooseVersion(np.__version__) < '1.9': - with pytest.raises(NotImplementedError): - v.median() - else: - assert_identical(v.median(), Variable([], 2)) + assert_identical(v.median(), Variable([], 2)) v = Variable('x', [True, False, False]) assert_identical(v.any(), Variable([], True)) assert_identical(v.all(dim='x'), Variable([], False)) v = Variable('t', pd.date_range('2000-01-01', periods=3)) - with pytest.raises(NotImplementedError): - v.max(skipna=True) + assert v.argmax(skipna=True) == 2 + assert_identical( v.max(), Variable([], pd.Timestamp('2000-01-03'))) @@ -1587,7 +1661,7 @@ def assert_assigned_2d(array, key_x, key_y, values): @requires_dask -class TestVariableWithDask(TestCase, VariableSubclassTestCases): +class TestVariableWithDask(VariableSubclassobjects): cls = staticmethod(lambda *args: Variable(*args).chunk()) @pytest.mark.xfail @@ -1608,17 +1682,17 @@ def test_eq_all_dtypes(self): super(TestVariableWithDask, self).test_eq_all_dtypes() def test_getitem_fancy(self): - import dask - if LooseVersion(dask.__version__) <= LooseVersion('0.15.1'): - pytest.xfail("vindex from latest dask is required") super(TestVariableWithDask, self).test_getitem_fancy() def test_getitem_1d_fancy(self): - import dask - if LooseVersion(dask.__version__) <= LooseVersion('0.15.1'): - pytest.xfail("vindex from latest dask is required") super(TestVariableWithDask, self).test_getitem_1d_fancy() + def test_equals_all_dtypes(self): + import dask + if '0.18.2' <= LooseVersion(dask.__version__) < '0.19.1': + pytest.xfail('https://github.com/pydata/xarray/issues/2318') + super(TestVariableWithDask, self).test_equals_all_dtypes() + def test_getitem_with_mask_nd_indexer(self): import dask.array as da v = Variable(['x'], da.arange(3, chunks=3)) @@ -1627,7 +1701,7 @@ def test_getitem_with_mask_nd_indexer(self): self.cls(('x', 'y'), [[0, -1], [-1, 2]])) -class TestIndexVariable(TestCase, VariableSubclassTestCases): +class TestIndexVariable(VariableSubclassobjects): cls = staticmethod(IndexVariable) def test_init(self): @@ -1706,6 +1780,13 @@ def test_coordinate_alias(self): x = Coordinate('x', [1, 2, 3]) assert isinstance(x, IndexVariable) + def test_datetime64(self): + # GH:1932 Make sure indexing keeps precision + t = np.array([1518418799999986560, 1518418799999996560], + dtype='datetime64[ns]') + v = IndexVariable('t', t) + assert v[0].data == t[0] + # These tests make use of multi-dimensional variables, which are not valid # IndexVariable objects: @pytest.mark.xfail @@ -1724,10 +1805,18 @@ def test_getitem_fancy(self): def test_getitem_uint(self): super(TestIndexVariable, self).test_getitem_fancy() + @pytest.mark.xfail + def test_pad(self): + super(TestIndexVariable, self).test_rolling_window() + + @pytest.mark.xfail + def test_rolling_window(self): + super(TestIndexVariable, self).test_rolling_window() + -class TestAsCompatibleData(TestCase): +class TestAsCompatibleData(object): def test_unchanged_types(self): - types = (np.asarray, PandasIndexAdapter, 
indexing.LazilyIndexedArray) + types = (np.asarray, PandasIndexAdapter, LazilyOuterIndexedArray) for t in types: for data in [np.arange(3), pd.date_range('2000-01-01', periods=3), @@ -1866,9 +1955,10 @@ def test_raise_no_warning_for_nan_in_binary_ops(): assert len(record) == 0 -class TestBackendIndexing(TestCase): +class TestBackendIndexing(object): """ Make sure all the array wrappers can be indexed. """ + @pytest.fixture(autouse=True) def setUp(self): self.d = np.random.random((10, 3)).astype(np.float64) @@ -1890,18 +1980,19 @@ def test_NumpyIndexingAdapter(self): v = Variable(dims=('x', 'y'), data=NumpyIndexingAdapter( NumpyIndexingAdapter(self.d))) - def test_LazilyIndexedArray(self): - v = Variable(dims=('x', 'y'), data=LazilyIndexedArray(self.d)) + def test_LazilyOuterIndexedArray(self): + v = Variable(dims=('x', 'y'), data=LazilyOuterIndexedArray(self.d)) self.check_orthogonal_indexing(v) - with raises_regex(NotImplementedError, 'Vectorized indexing for '): - self.check_vectorized_indexing(v) + self.check_vectorized_indexing(v) # doubly wrapping - v = Variable(dims=('x', 'y'), - data=LazilyIndexedArray(LazilyIndexedArray(self.d))) + v = Variable( + dims=('x', 'y'), + data=LazilyOuterIndexedArray(LazilyOuterIndexedArray(self.d))) self.check_orthogonal_indexing(v) # hierarchical wrapping - v = Variable(dims=('x', 'y'), - data=LazilyIndexedArray(NumpyIndexingAdapter(self.d))) + v = Variable( + dims=('x', 'y'), + data=LazilyOuterIndexedArray(NumpyIndexingAdapter(self.d))) self.check_orthogonal_indexing(v) def test_CopyOnWriteArray(self): @@ -1909,11 +2000,11 @@ def test_CopyOnWriteArray(self): self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # doubly wrapping - v = Variable(dims=('x', 'y'), - data=CopyOnWriteArray(LazilyIndexedArray(self.d))) + v = Variable( + dims=('x', 'y'), + data=CopyOnWriteArray(LazilyOuterIndexedArray(self.d))) self.check_orthogonal_indexing(v) - with raises_regex(NotImplementedError, 'Vectorized indexing for '): - self.check_vectorized_indexing(v) + self.check_vectorized_indexing(v) def test_MemoryCachedArray(self): v = Variable(dims=('x', 'y'), data=MemoryCachedArray(self.d)) diff --git a/xarray/tutorial.py b/xarray/tutorial.py index d7da63a328e..064eed330cc 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -5,18 +5,15 @@ * building tutorials in the documentation. 
''' -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import hashlib - import os as _os +import warnings from .backends.api import open_dataset as _open_dataset from .core.pycompat import urlretrieve as _urlretrieve - _default_cache_dir = _os.sep.join(('~', '.xarray_tutorial_data')) @@ -28,7 +25,7 @@ def file_md5_checksum(fname): # idea borrowed from Seaborn -def load_dataset(name, cache=True, cache_dir=_default_cache_dir, +def open_dataset(name, cache=True, cache_dir=_default_cache_dir, github_url='https://github.com/pydata/xarray-data', branch='master', **kws): """ @@ -52,6 +49,10 @@ def load_dataset(name, cache=True, cache_dir=_default_cache_dir, kws : dict, optional Passed to xarray.open_dataset + See Also + -------- + xarray.open_dataset + """ longdir = _os.path.expanduser(cache_dir) fullname = name + '.nc' @@ -81,9 +82,27 @@ def load_dataset(name, cache=True, cache_dir=_default_cache_dir, """ raise IOError(msg) - ds = _open_dataset(localfile, **kws).load() + ds = _open_dataset(localfile, **kws) if not cache: + ds = ds.load() _os.remove(localfile) return ds + + +def load_dataset(*args, **kwargs): + """ + `load_dataset` will be removed in version 0.12. The current behavior of + this function can be achieved by using `tutorial.open_dataset(...).load()`. + + See Also + -------- + open_dataset + """ + warnings.warn( + "`load_dataset` will be removed in xarray version 0.12. The current " + "behavior of this function can be achieved by using " + "`tutorial.open_dataset(...).load()`.", + DeprecationWarning, stacklevel=2) + return open_dataset(*args, **kwargs).load() diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index 1990ac5b765..628f8568a6d 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -13,20 +13,18 @@ Once NumPy 1.10 comes out with support for overriding ufuncs, this module will hopefully no longer be necessary. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + +import warnings as _warnings import numpy as _np -from .core.variable import Variable as _Variable -from .core.dataset import Dataset as _Dataset from .core.dataarray import DataArray as _DataArray +from .core.dataset import Dataset as _Dataset +from .core.duck_array_ops import _dask_or_eager_func from .core.groupby import GroupBy as _GroupBy - from .core.pycompat import dask_array_type as _dask_array_type -from .core.duck_array_ops import _dask_or_eager_func - +from .core.variable import Variable as _Variable _xarray_types = (_Variable, _DataArray, _Dataset, _GroupBy) _dispatch_order = (_np.ndarray, _dask_array_type) + _xarray_types @@ -46,8 +44,13 @@ def __init__(self, name): self._name = name def __call__(self, *args, **kwargs): + _warnings.warn( + 'xarray.ufuncs will be deprecated when xarray no longer supports ' + 'versions of numpy older than v1.13. 
Instead, use numpy ufuncs ' + 'directly.', PendingDeprecationWarning, stacklevel=2) + new_args = args - f = _dask_or_eager_func(self._name, n_array_args=len(args)) + f = _dask_or_eager_func(self._name, array_args=slice(len(args))) if len(args) > 2 or len(args) == 0: raise TypeError('cannot handle %s arguments for %r' % (len(args), self._name)) diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index b9bd6e88547..5459e67e603 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -2,14 +2,16 @@ see pandas/pandas/util/_print_versions.py''' +from __future__ import absolute_import + +import codecs +import importlib +import locale import os import platform -import sys import struct import subprocess -import codecs -import locale -import importlib +import sys def get_sys_info(): @@ -42,7 +44,7 @@ def get_sys_info(): (sysname, nodename, release, version, machine, processor) = platform.uname() blob.extend([ - ("python", "%d.%d.%d.%s.%s" % sys.version_info[:]), + ("python", sys.version), ("python-bits", struct.calcsize("P") * 8), ("OS", "%s" % (sysname)), ("OS-release", "%s" % (release)), @@ -61,9 +63,27 @@ def get_sys_info(): return blob +def netcdf_and_hdf5_versions(): + libhdf5_version = None + libnetcdf_version = None + try: + import netCDF4 + libhdf5_version = netCDF4.__hdf5libversion__ + libnetcdf_version = netCDF4.__netcdf4libversion__ + except ImportError: + try: + import h5py + libhdf5_version = h5py.__hdf5libversion__ + except ImportError: + pass + return [('libhdf5', libhdf5_version), ('libnetcdf', libnetcdf_version)] + + def show_versions(as_json=False): sys_info = get_sys_info() + sys_info.extend(netcdf_and_hdf5_versions()) + deps = [ # (MODULE_NAME, f(mod) -> mod version) ("xarray", lambda mod: mod.__version__), @@ -72,11 +92,16 @@ def show_versions(as_json=False): ("scipy", lambda mod: mod.__version__), # xarray optionals ("netCDF4", lambda mod: mod.__version__), - # ("pydap", lambda mod: mod.version.version), + ("pydap", lambda mod: mod.__version__), ("h5netcdf", lambda mod: mod.__version__), ("h5py", lambda mod: mod.__version__), ("Nio", lambda mod: mod.__version__), ("zarr", lambda mod: mod.__version__), + ("cftime", lambda mod: mod.__version__), + ("PseudonetCDF", lambda mod: mod.__version__), + ("rasterio", lambda mod: mod.__version__), + ("cfgrib", lambda mod: mod.__version__), + ("iris", lambda mod: mod.__version__), ("bottleneck", lambda mod: mod.__version__), ("cyordereddict", lambda mod: mod.__version__), ("dask", lambda mod: mod.__version__), @@ -101,10 +126,14 @@ def show_versions(as_json=False): mod = sys.modules[modname] else: mod = importlib.import_module(modname) - ver = ver_f(mod) - deps_blob.append((modname, ver)) except Exception: deps_blob.append((modname, None)) + else: + try: + ver = ver_f(mod) + deps_blob.append((modname, ver)) + except Exception: + deps_blob.append((modname, 'installed')) if (as_json): try:
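As context for the new test_ufuncs.py tests and the PendingDeprecationWarning added above, a minimal sketch of the replacement pattern the warning points to, assuming NumPy >= 1.13; the example array and names below are illustrative only, not code from the patch:

import numpy as np
import xarray as xr

da = xr.DataArray([0.0, 1.0, 2.0], dims='x')  # arbitrary example data

# previously: import xarray.ufuncs as xu; xu.cos(da)
result = np.cos(da)  # dispatches via DataArray.__array_ufunc__ on NumPy >= 1.13
assert isinstance(result, xr.DataArray)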
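Likewise, a minimal sketch of the tutorial.load_dataset migration described in the deprecation message above, using 'air_temperature' purely as an illustrative dataset name:

import xarray as xr

# previously: ds = xr.tutorial.load_dataset('air_temperature')
ds = xr.tutorial.open_dataset('air_temperature').load()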