diff --git a/docs/api.rst b/docs/api.rst index e902c069..488ffe55 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,5 +1,3 @@ -.. currentmodule:: xarray - API reference ============= This page contains a auto-generated summary of ``pint-xarray``'s API. @@ -11,11 +9,11 @@ Dataset :toctree: generated/ :template: autosummary/accessor_method.rst - Dataset.pint.quantify - Dataset.pint.dequantify - Dataset.pint.to - Dataset.pint.to_base_units - Dataset.pint.to_system + xarray.Dataset.pint.quantify + xarray.Dataset.pint.dequantify + xarray.Dataset.pint.to + xarray.Dataset.pint.to_base_units + xarray.Dataset.pint.to_system DataArray --------- @@ -23,16 +21,24 @@ DataArray :toctree: generated/ :template: autosummary/accessor_attribute.rst - DataArray.pint.magnitude - DataArray.pint.units - DataArray.pint.dimensionality - DataArray.pint.registry + xarray.DataArray.pint.magnitude + xarray.DataArray.pint.units + xarray.DataArray.pint.dimensionality + xarray.DataArray.pint.registry .. autosummary:: :toctree: generated/ :template: autosummary/accessor_method.rst - DataArray.pint.quantify - DataArray.pint.dequantify - DataArray.pint.to - DataArray.pint.to_base_units + xarray.DataArray.pint.quantify + xarray.DataArray.pint.dequantify + xarray.DataArray.pint.to + xarray.DataArray.pint.to_base_units + +Testing +------- + +.. autosummary:: + :toctree: generated/ + + pint_xarray.testing.assert_units_equal diff --git a/pint_xarray/__init__.py b/pint_xarray/__init__.py index 5d5791ac..204ba4dc 100644 --- a/pint_xarray/__init__.py +++ b/pint_xarray/__init__.py @@ -3,6 +3,7 @@ except ImportError: from importlib_metadata import version +from . import testing # noqa from .accessors import PintDataArrayAccessor, PintDatasetAccessor # noqa try: diff --git a/pint_xarray/testing.py b/pint_xarray/testing.py new file mode 100644 index 00000000..2ad19020 --- /dev/null +++ b/pint_xarray/testing.py @@ -0,0 +1,22 @@ +from . import conversion + + +def assert_units_equal(a, b): + """ assert that the units of two xarray objects are equal + + Raises an :py:exc:`AssertionError` if the units of both objects are not + equal. ``units`` attributes and attached unit objects are compared + separately. + + Parameters + ---------- + a, b : DataArray or Dataset + The objects to compare + """ + + __tracebackhide__ = True + + assert conversion.extract_units(a) == conversion.extract_units(b) + assert conversion.extract_unit_attributes(a) == conversion.extract_unit_attributes( + b + ) diff --git a/pint_xarray/tests/test_testing.py b/pint_xarray/tests/test_testing.py new file mode 100644 index 00000000..bc2e9a3e --- /dev/null +++ b/pint_xarray/tests/test_testing.py @@ -0,0 +1,58 @@ +import pint +import pytest +import xarray as xr + +from pint_xarray import testing + +unit_registry = pint.UnitRegistry(force_ndarray_like=True) + + +@pytest.mark.parametrize( + ("a", "b", "error"), + ( + pytest.param( + xr.DataArray(attrs={"units": "K"}), + xr.DataArray(attrs={"units": "K"}), + None, + id="equal attrs", + ), + pytest.param( + xr.DataArray(attrs={"units": "m"}), + xr.DataArray(attrs={"units": "K"}), + AssertionError, + id="different attrs", + ), + pytest.param( + xr.DataArray([10, 20] * unit_registry.K), + xr.DataArray([50, 80] * unit_registry.K), + None, + id="equal units", + ), + pytest.param( + xr.DataArray([10, 20] * unit_registry.K), + xr.DataArray([50, 80] * unit_registry.dimensionless), + AssertionError, + id="different units", + ), + pytest.param( + xr.Dataset({"a": ("x", [0, 10], {"units": "K"})}), + xr.Dataset({"a": ("x", [20, 40], {"units": "K"})}), + None, + id="matching variables", + ), + pytest.param( + xr.Dataset({"a": ("x", [0, 10], {"units": "K"})}), + xr.Dataset({"b": ("x", [20, 40], {"units": "K"})}), + AssertionError, + id="mismatching variables", + ), + ), +) +def test_assert_units_equal(a, b, error): + if error is not None: + with pytest.raises(error): + testing.assert_units_equal(a, b) + + return + + testing.assert_units_equal(a, b) diff --git a/pyproject.toml b/pyproject.toml index 0e472c79..83576b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,4 +2,5 @@ requires = ["setuptools >= 42", "wheel", "setuptools_scm[toml] >= 3.4"] build-backend = "setuptools.build_meta" -[tool.setuptools_scm] \ No newline at end of file +[tool.setuptools_scm] +fallback_version = "999"