Skip to content

Add asynchronous load method #10327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 57 commits into
base: main
Choose a base branch
from
Draft

Conversation

TomNicholas
Copy link
Member

@TomNicholas TomNicholas commented May 16, 2025

Adds an .async_load() method to Variable, which works by plumbing async get_duck_array all the way down until it finally gets to the async methods zarr v3 exposes.

Needs a lot of refactoring before it could be merged, but it works.

API:

  • Variable.load_async
  • DataArray.load_async
  • Dataset.load_async
  • DataTree.load_async
  • load_dataset?
  • load_dataarray?

TomNicholas and others added 21 commits October 24, 2024 17:48
@@ -267,13 +268,23 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
time.sleep(1e-3 * next_delay)


class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
class BackendArray(ABC, NdimSizeLenMixin, indexing.ExplicitlyIndexed):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As __getitem__ is required, I feel like BackendArray should always have been an ABC.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is public API and this is a backwards incompatible change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is technically, but only if someone is using this class in a way counter to what the docs explicitly tell you to do (i.e. subclass it).

Regardless this is orthogonal to the rest of the PR, I can remove it, I was just trying to clean up bad things I found.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted in 6c47e3f

Comment on lines +277 to +278
async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
raise NotImplementedError("Backend does not not support asynchronous loading")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented this for the ZarrArray class but in theory it could be supported by other backends too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not be the desired behaviour though - this currently means if you opened a dataset from netCDF and called ds.load_async you would get a NotImplementedError. Would it be better to quietly just block instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes absolutely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay I can do that. But can you explain why you feel that this would be better behaviour? Asking for something to be done async and it quietly blocking also seems not great...

Comment on lines +574 to +578
# load everything else concurrently
coros = [
v.load_async() for k, v in self.variables.items() if k not in chunked_data
]
await asyncio.gather(*coros)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could actually do this same thing inside of the synchronous ds.load() too, but it would require:

  1. Xarray to decide how to call the async code, e.g. with a ThreadPool or similar (see Support concurrent loading of variables #8965)
  2. The backend to support async_getitem (it could fall back to synchronous loading if it's not supported)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should rate-limite all gather calls with a Semaphore using something like this:

async def async_gather(*coros, concurrency: Optional[int] = None, return_exceptions: bool = False) -> list[Any]:
    """Execute a gather while limiting the number of concurrent tasks.

    Args:
        coros: coroutines
            list of coroutines to execute
        concurrency: int
            concurrency limit
            if None, defaults to config_obj.get('async.concurrency', 4)
            if <= 0, no concurrency limit

    """
    if concurrency is None:
        concurrency = int(config_obj.get("async.concurrency", 4))

    if concurrency > 0:
        # if concurrency > 0, we use a semaphore to limit the number of concurrent coroutines
        semaphore = asyncio.Semaphore(concurrency)

        async def sem_coro(coro):
            async with semaphore:
                return await coro

        results = await asyncio.gather(*(sem_coro(c) for c in coros), return_exceptions=return_exceptions)
    else:
        results = await asyncio.gather(*coros, return_exceptions=return_exceptions)

    return results

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguably that should be left to the underlying storage layer. Zarr already has its own rate limiting. Why introduce this additional complexity and configuration parameter in Xarray?

Copy link
Member Author

@TomNicholas TomNicholas May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does zarr rate-limit per call or globally though? If it's rate-limited per call, and we make lots of concurrent calls from the xarray API, it will exceed the intended rate set in zarr...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% on what Zarr will do but this will rate limit across Xarray variables. We will undoubtedly want to offer control here, even if the default is None for a start.

@ianhi
Copy link
Contributor

ianhi commented May 22, 2025

There is something funky going on when using .sel

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "arraylake",
#     "yappi",
#     "zarr==3.0.8",
#     "xarray",
#     "icechunk"
# ]
#
# [tool.uv.sources]
# xarray = { git = "https://github.com/TomNicholas/xarray", rev = "async.load" }
# ///

import asyncio
from collections.abc import Iterable
from typing import TypeVar

import numpy as np

import xarray as xr

import zarr
from zarr.abc.store import ByteRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.storage._wrapper import WrapperStore

T_Store = TypeVar("T_Store", bound=Store)


class LatencyStore(WrapperStore[T_Store]):
    """Works the same way as the zarr LoggingStore"""

    latency: float

    def __init__(
        self,
        store: T_Store,
        latency: float = 0.0,
    ) -> None:
        """
        Store wrapper that adds artificial latency to each get call.

        Parameters
        ----------
        store : Store
            Store to wrap
        latency : float
            Amount of artificial latency to add to each get call, in seconds.
        """
        super().__init__(store)
        self.latency = latency

    def __str__(self) -> str:
        return f"latency-{self._store}"

    def __repr__(self) -> str:
        return f"LatencyStore({self._store.__class__.__name__}, '{self._store}', latency={self.latency})"

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        await asyncio.sleep(self.latency)
        return await self._store.get(
            key=key, prototype=prototype, byte_range=byte_range
        )

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRequest | None]],
    ) -> list[Buffer | None]:
        await asyncio.sleep(self.latency)
        return await self._store.get_partial_values(
            prototype=prototype, key_ranges=key_ranges
        )


memorystore = zarr.storage.MemoryStore({})

shape = 5
X = np.arange(5) * 10
ds = xr.Dataset(
    {
        "data": xr.DataArray(
            np.zeros(shape),
            coords={"x": X},
        )
    }
)

ds.to_zarr(memorystore)


latencystore = LatencyStore(memorystore, latency=0.1)
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)

# no problem for any of these
asyncio.run(ds["data"][0].load_async())
asyncio.run(ds["data"].sel(x=10).load_async())
asyncio.run(ds["data"].sel(x=11, method="nearest").load_async())

# also fine
ds["data"].sel(x=[30, 40]).load()

# broken!
asyncio.run(ds["data"].sel(x=[30, 40]).load_async())

uv run that script gives:

Traceback (most recent call last):
  File "/Users/ian/tmp/async_error.py", line 109, in <module>
    asyncio.run(ds["data"].sel(x=[30, 40]).load_async())
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/dataarray.py", line 1165, in load_async
    ds = await temp_ds.load_async(**kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/dataset.py", line 578, in load_async
    await asyncio.gather(*coros)
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/variable.py", line 963, in load_async
    self._data = await async_to_duck_array(self._data, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/namedarray/pycompat.py", line 168, in async_to_duck_array
    return await data.async_get_duck_array()  # type: ignore[no-untyped-call, no-any-return]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 875, in async_get_duck_array
    await self._async_ensure_cached()
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 867, in _async_ensure_cached
    duck_array = await self.array.async_get_duck_array()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 821, in async_get_duck_array
    return await self.array.async_get_duck_array()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 674, in async_get_duck_array
    array = await self.array.async_getitem(self.key)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/backends/zarr.py", line 248, in async_getitem
    return await indexing.async_explicit_indexing_adapter(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 1068, in async_explicit_indexing_adapter
    result = await raw_indexing_method(raw_key.tuple)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: object numpy.ndarray can't be used in 'await' expression

Comment on lines 240 to 245
elif isinstance(key, indexing.VectorizedIndexer):
# TODO
method = self._vindex
elif isinstance(key, indexing.OuterIndexer):
# TODO
method = self._oindex
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ianhi almost certainly these need to become async to fix your bug

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outer (also known as "Orthogonal") indexing support added in 5eacdb0, but requires changes to zarr-python: zarr-developers/zarr-python#3083

Comment on lines 192 to 196
# test vectorized indexing
# TODO this shouldn't pass! I haven't implemented async vectorized indexing yet...
indexer = xr.DataArray([2, 3], dims=["x"])
result = await ds.foo[indexer].load_async()
xrt.assert_identical(result, ds.foo[indexer].load())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This currently passes, even though it shouldn't, because I haven't added support for async vectorized indexing yet!

I think this means that my test is wrong, and what I'm doing here is apparently not vectorized indexing. I'm unsure what my test would have to look like though 😕

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an outer indexer. Try xr.DataArray([[2, 3]], dims=["y", "x"])

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I since worked this out, but apparently haven't pushed those changes. Note that it requires changes in Zarr too to make async lazy vectorized indexing work

zarr-developers/zarr-python#3083

Comment on lines +277 to +278
async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
raise NotImplementedError("Backend does not not support asynchronous loading")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes absolutely.

@@ -267,13 +268,23 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
time.sleep(1e-3 * next_delay)


class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
class BackendArray(ABC, NdimSizeLenMixin, indexing.ExplicitlyIndexed):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is public API and this is a backwards incompatible change.

Comment on lines +574 to +578
# load everything else concurrently
coros = [
v.load_async() for k, v in self.variables.items() if k not in chunked_data
]
await asyncio.gather(*coros)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should rate-limite all gather calls with a Semaphore using something like this:

async def async_gather(*coros, concurrency: Optional[int] = None, return_exceptions: bool = False) -> list[Any]:
    """Execute a gather while limiting the number of concurrent tasks.

    Args:
        coros: coroutines
            list of coroutines to execute
        concurrency: int
            concurrency limit
            if None, defaults to config_obj.get('async.concurrency', 4)
            if <= 0, no concurrency limit

    """
    if concurrency is None:
        concurrency = int(config_obj.get("async.concurrency", 4))

    if concurrency > 0:
        # if concurrency > 0, we use a semaphore to limit the number of concurrent coroutines
        semaphore = asyncio.Semaphore(concurrency)

        async def sem_coro(coro):
            async with semaphore:
                return await coro

        results = await asyncio.gather(*(sem_coro(c) for c in coros), return_exceptions=return_exceptions)
    else:
        results = await asyncio.gather(*coros, return_exceptions=return_exceptions)

    return results

case "ds":
return ds

def assert_time_as_expected(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's instead use mocks to assert the async methods were called. Xarray's job is to do that only

Comment on lines 192 to 196
# test vectorized indexing
# TODO this shouldn't pass! I haven't implemented async vectorized indexing yet...
indexer = xr.DataArray([2, 3], dims=["x"])
result = await ds.foo[indexer].load_async()
xrt.assert_identical(result, ds.foo[indexer].load())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an outer indexer. Try xr.DataArray([[2, 3]], dims=["y", "x"])

async def _async_ensure_cached(self):
duck_array = await self.array.async_get_duck_array()
self.array = as_indexable(duck_array)

def get_duck_array(self):
self._ensure_cached()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_ensure_cached seems like pointless indirection, it is only used once. let's consolidate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in 884ce13, but I still feel like it could be simplified further. Does it really need to have the side-effect of re-assigning to self.array?

return self

async def load_async(self, **kwargs) -> Self:
# TODO refactor this to pull out the common chunked_data codepath
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's instead just have the sync methods issue a blocking call to the async versions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that would solve the use case in xpublish though? You need to be able to asynchronously trigger loading for a bunch of separate dataset objects, which requires an async load api to be exposed, no?

Copy link
Member Author

@TomNicholas TomNicholas May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I understand what you mean now, you're not talking about the API, you're just talking about my comment about internal refactoring. You're proposing we do what zarr does internally, which makes sense.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CI Continuous Integration tools dependencies Pull requests that update a dependency file enhancement io topic-backends topic-documentation topic-indexing topic-NamedArray Lightweight version of Variable topic-zarr Related to zarr storage library
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add an asynchronous load method?
4 participants