-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add asynchronous load method #10327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add asynchronous load method #10327
Conversation
for more information, see https://pre-commit.ci
xarray/backends/common.py
Outdated
@@ -267,13 +268,23 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 | |||
time.sleep(1e-3 * next_delay) | |||
|
|||
|
|||
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): | |||
class BackendArray(ABC, NdimSizeLenMixin, indexing.ExplicitlyIndexed): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As __getitem__
is required, I feel like BackendArray
should always have been an ABC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class is public API and this is a backwards incompatible change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is technically, but only if someone is using this class in a way counter to what the docs explicitly tell you to do (i.e. subclass it).
Regardless this is orthogonal to the rest of the PR, I can remove it, I was just trying to clean up bad things I found.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reverted in 6c47e3f
async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: | ||
raise NotImplementedError("Backend does not not support asynchronous loading") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've implemented this for the ZarrArray
class but in theory it could be supported by other backends too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might not be the desired behaviour though - this currently means if you opened a dataset from netCDF and called ds.load_async
you would get a NotImplementedError
. Would it be better to quietly just block instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes absolutely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay I can do that. But can you explain why you feel that this would be better behaviour? Asking for something to be done async and it quietly blocking also seems not great...
for more information, see https://pre-commit.ci
# load everything else concurrently | ||
coros = [ | ||
v.load_async() for k, v in self.variables.items() if k not in chunked_data | ||
] | ||
await asyncio.gather(*coros) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could actually do this same thing inside of the synchronous ds.load()
too, but it would require:
- Xarray to decide how to call the async code, e.g. with a
ThreadPool
or similar (see Support concurrent loading of variables #8965) - The backend to support
async_getitem
(it could fall back to synchronous loading if it's not supported)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should rate-limite all gather
calls with a Semaphore using something like this:
async def async_gather(*coros, concurrency: Optional[int] = None, return_exceptions: bool = False) -> list[Any]:
"""Execute a gather while limiting the number of concurrent tasks.
Args:
coros: coroutines
list of coroutines to execute
concurrency: int
concurrency limit
if None, defaults to config_obj.get('async.concurrency', 4)
if <= 0, no concurrency limit
"""
if concurrency is None:
concurrency = int(config_obj.get("async.concurrency", 4))
if concurrency > 0:
# if concurrency > 0, we use a semaphore to limit the number of concurrent coroutines
semaphore = asyncio.Semaphore(concurrency)
async def sem_coro(coro):
async with semaphore:
return await coro
results = await asyncio.gather(*(sem_coro(c) for c in coros), return_exceptions=return_exceptions)
else:
results = await asyncio.gather(*coros, return_exceptions=return_exceptions)
return results
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Arguably that should be left to the underlying storage layer. Zarr already has its own rate limiting. Why introduce this additional complexity and configuration parameter in Xarray?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does zarr rate-limit per call or globally though? If it's rate-limited per call, and we make lots of concurrent calls from the xarray API, it will exceed the intended rate set in zarr...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not 100% on what Zarr will do but this will rate limit across Xarray variables. We will undoubtedly want to offer control here, even if the default is None for a start.
There is something funky going on when using # /// script
# requires-python = ">=3.12"
# dependencies = [
# "arraylake",
# "yappi",
# "zarr==3.0.8",
# "xarray",
# "icechunk"
# ]
#
# [tool.uv.sources]
# xarray = { git = "https://github.com/TomNicholas/xarray", rev = "async.load" }
# ///
import asyncio
from collections.abc import Iterable
from typing import TypeVar
import numpy as np
import xarray as xr
import zarr
from zarr.abc.store import ByteRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.storage._wrapper import WrapperStore
T_Store = TypeVar("T_Store", bound=Store)
class LatencyStore(WrapperStore[T_Store]):
"""Works the same way as the zarr LoggingStore"""
latency: float
def __init__(
self,
store: T_Store,
latency: float = 0.0,
) -> None:
"""
Store wrapper that adds artificial latency to each get call.
Parameters
----------
store : Store
Store to wrap
latency : float
Amount of artificial latency to add to each get call, in seconds.
"""
super().__init__(store)
self.latency = latency
def __str__(self) -> str:
return f"latency-{self._store}"
def __repr__(self) -> str:
return f"LatencyStore({self._store.__class__.__name__}, '{self._store}', latency={self.latency})"
async def get(
self,
key: str,
prototype: BufferPrototype,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
await asyncio.sleep(self.latency)
return await self._store.get(
key=key, prototype=prototype, byte_range=byte_range
)
async def get_partial_values(
self,
prototype: BufferPrototype,
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
await asyncio.sleep(self.latency)
return await self._store.get_partial_values(
prototype=prototype, key_ranges=key_ranges
)
memorystore = zarr.storage.MemoryStore({})
shape = 5
X = np.arange(5) * 10
ds = xr.Dataset(
{
"data": xr.DataArray(
np.zeros(shape),
coords={"x": X},
)
}
)
ds.to_zarr(memorystore)
latencystore = LatencyStore(memorystore, latency=0.1)
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
# no problem for any of these
asyncio.run(ds["data"][0].load_async())
asyncio.run(ds["data"].sel(x=10).load_async())
asyncio.run(ds["data"].sel(x=11, method="nearest").load_async())
# also fine
ds["data"].sel(x=[30, 40]).load()
# broken!
asyncio.run(ds["data"].sel(x=[30, 40]).load_async())
|
xarray/backends/zarr.py
Outdated
elif isinstance(key, indexing.VectorizedIndexer): | ||
# TODO | ||
method = self._vindex | ||
elif isinstance(key, indexing.OuterIndexer): | ||
# TODO | ||
method = self._oindex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ianhi almost certainly these need to become async to fix your bug
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Outer (also known as "Orthogonal") indexing support added in 5eacdb0, but requires changes to zarr-python: zarr-developers/zarr-python#3083
for more information, see https://pre-commit.ci
xarray/tests/test_async.py
Outdated
# test vectorized indexing | ||
# TODO this shouldn't pass! I haven't implemented async vectorized indexing yet... | ||
indexer = xr.DataArray([2, 3], dims=["x"]) | ||
result = await ds.foo[indexer].load_async() | ||
xrt.assert_identical(result, ds.foo[indexer].load()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This currently passes, even though it shouldn't, because I haven't added support for async vectorized indexing yet!
I think this means that my test is wrong, and what I'm doing here is apparently not vectorized indexing. I'm unsure what my test would have to look like though 😕
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is an outer indexer. Try xr.DataArray([[2, 3]], dims=["y", "x"])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I since worked this out, but apparently haven't pushed those changes. Note that it requires changes in Zarr too to make async lazy vectorized indexing work
async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: | ||
raise NotImplementedError("Backend does not not support asynchronous loading") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes absolutely.
xarray/backends/common.py
Outdated
@@ -267,13 +268,23 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 | |||
time.sleep(1e-3 * next_delay) | |||
|
|||
|
|||
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): | |||
class BackendArray(ABC, NdimSizeLenMixin, indexing.ExplicitlyIndexed): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class is public API and this is a backwards incompatible change.
# load everything else concurrently | ||
coros = [ | ||
v.load_async() for k, v in self.variables.items() if k not in chunked_data | ||
] | ||
await asyncio.gather(*coros) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should rate-limite all gather
calls with a Semaphore using something like this:
async def async_gather(*coros, concurrency: Optional[int] = None, return_exceptions: bool = False) -> list[Any]:
"""Execute a gather while limiting the number of concurrent tasks.
Args:
coros: coroutines
list of coroutines to execute
concurrency: int
concurrency limit
if None, defaults to config_obj.get('async.concurrency', 4)
if <= 0, no concurrency limit
"""
if concurrency is None:
concurrency = int(config_obj.get("async.concurrency", 4))
if concurrency > 0:
# if concurrency > 0, we use a semaphore to limit the number of concurrent coroutines
semaphore = asyncio.Semaphore(concurrency)
async def sem_coro(coro):
async with semaphore:
return await coro
results = await asyncio.gather(*(sem_coro(c) for c in coros), return_exceptions=return_exceptions)
else:
results = await asyncio.gather(*coros, return_exceptions=return_exceptions)
return results
case "ds": | ||
return ds | ||
|
||
def assert_time_as_expected( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's instead use mocks to assert the async methods were called. Xarray's job is to do that only
xarray/tests/test_async.py
Outdated
# test vectorized indexing | ||
# TODO this shouldn't pass! I haven't implemented async vectorized indexing yet... | ||
indexer = xr.DataArray([2, 3], dims=["x"]) | ||
result = await ds.foo[indexer].load_async() | ||
xrt.assert_identical(result, ds.foo[indexer].load()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is an outer indexer. Try xr.DataArray([[2, 3]], dims=["y", "x"])
xarray/core/indexing.py
Outdated
async def _async_ensure_cached(self): | ||
duck_array = await self.array.async_get_duck_array() | ||
self.array = as_indexable(duck_array) | ||
|
||
def get_duck_array(self): | ||
self._ensure_cached() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_ensure_cached
seems like pointless indirection, it is only used once. let's consolidate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed in 884ce13, but I still feel like it could be simplified further. Does it really need to have the side-effect of re-assigning to self.array
?
return self | ||
|
||
async def load_async(self, **kwargs) -> Self: | ||
# TODO refactor this to pull out the common chunked_data codepath |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's instead just have the sync methods issue a blocking call to the async versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that would solve the use case in xpublish though? You need to be able to asynchronously trigger loading for a bunch of separate dataset objects, which requires an async load api to be exposed, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I understand what you mean now, you're not talking about the API, you're just talking about my comment about internal refactoring. You're proposing we do what zarr does internally, which makes sense.
Adds an
.async_load()
method toVariable
, which works by plumbing asyncget_duck_array
all the way down until it finally gets to the async methods zarr v3 exposes.Needs a lot of refactoring before it could be merged, but it works.
whats-new.rst
api.rst
API:
Variable.load_async
DataArray.load_async
Dataset.load_async
DataTree.load_async
load_dataset
?load_dataarray
?