Skip to content

[v3] Sync with futures #1804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 24, 2024
Merged
16 changes: 10 additions & 6 deletions src/zarr/v3/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,15 +527,19 @@ def __setitem__(self, selection: Selection, value: np.ndarray) -> None:
)

def resize(self, new_shape: ChunkCoords) -> Array:
return sync(
self._async_array.resize(new_shape),
self._async_array.runtime_configuration.asyncio_loop,
return type(self)(
sync(
self._async_array.resize(new_shape),
self._async_array.runtime_configuration.asyncio_loop,
)
)

def update_attributes(self, new_attributes: Dict[str, Any]) -> Array:
return sync(
self._async_array.update_attributes(new_attributes),
self._async_array.runtime_configuration.asyncio_loop,
return type(self)(
sync(
self._async_array.update_attributes(new_attributes),
self._async_array.runtime_configuration.asyncio_loop,
)
)

def __repr__(self):
Expand Down
1 change: 1 addition & 0 deletions src/zarr/v3/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class SyncConfiguration:
concurrency: Optional[int] = None
asyncio_loop: Optional[AbstractEventLoop] = None
timeout: float | None = None


def parse_indexing_order(data: Any) -> Literal["C", "F"]:
Expand Down
24 changes: 18 additions & 6 deletions src/zarr/v3/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,24 +415,36 @@ def nchildren(self) -> int:

@property
def children(self) -> List[Union[Array, Group]]:
_children = self._sync_iter(self._async_group.children())
return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]
raise NotImplementedError
# Uncomment when AsyncGroup implements this method
# _children: List[Union[AsyncArray, AsyncGroup]] = self._sync_iter(
# self._async_group.children()
# )
# return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all of these changes were to make mypy happy.


def __contains__(self, child) -> bool:
return self._sync(self._async_group.contains(child))

def group_keys(self) -> List[str]:
return self._sync_iter(self._async_group.group_keys())
raise NotImplementedError
# uncomment when AsyncGroup implements this method
# return self._sync_iter(self._async_group.group_keys())

def groups(self) -> List[Group]:
# TODO: in v2 this was a generator that return key: Group
return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]
raise NotImplementedError
# uncomment when AsyncGroup implements this method
# return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]

def array_keys(self) -> List[str]:
return self._sync_iter(self._async_group.array_keys())
# uncomment when AsyncGroup implements this method
# return self._sync_iter(self._async_group.array_keys())
raise NotImplementedError

def arrays(self) -> List[Array]:
return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]
raise NotImplementedError
# uncomment when AsyncGroup implements this method
# return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]

def tree(self, expand=False, level=None) -> Any:
return self._sync(self._async_group.tree(expand=expand, level=level))
Expand Down
89 changes: 50 additions & 39 deletions src/zarr/v3/sync.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
from typing import Any, AsyncIterator, Coroutine

import asyncio
from concurrent.futures import wait
import threading
from typing import (
Any,
AsyncIterator,
Coroutine,
List,
Optional,
TypeVar,
)

from typing_extensions import ParamSpec

from zarr.v3.config import SyncConfiguration

P = ParamSpec("P")
T = TypeVar("T")

# From https://github.com/fsspec/filesystem_spec/blob/master/fsspec/asyn.py

iothread: List[Optional[threading.Thread]] = [None] # dedicated IO thread
loop: List[Optional[asyncio.AbstractEventLoop]] = [
iothread: list[threading.Thread | None] = [None] # dedicated IO thread
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we are targeting py >=3.9?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.10+ actually!

loop: list[asyncio.AbstractEventLoop | None] = [
None
] # global event loop for any non-async instance
_lock: Optional[threading.Lock] = None # global lock placeholder
_lock: threading.Lock | None = None # global lock placeholder
get_running_loop = asyncio.get_running_loop


class SyncError(Exception):
pass


def _get_lock() -> threading.Lock:
"""Allocate or return a threading lock.

Expand All @@ -36,16 +40,22 @@ def _get_lock() -> threading.Lock:
return _lock


async def _runner(event: threading.Event, coro: Coroutine, result_box: List[Optional[Any]]):
async def _runner(coro: Coroutine[Any, Any, T]) -> T | BaseException:
"""
Await a coroutine and return the result of running it. If awaiting the coroutine raises an
exception, the exception will be returned.
"""
try:
result_box[0] = await coro
return await coro
except Exception as ex:
result_box[0] = ex
finally:
event.set()
return ex


def sync(coro: Coroutine, loop: Optional[asyncio.AbstractEventLoop] = None):
def sync(
coro: Coroutine[Any, Any, T],
loop: asyncio.AbstractEventLoop | None = None,
timeout: float | None = None,
) -> T:
"""
Make loop run coroutine until it returns. Runs in other thread

Expand All @@ -57,30 +67,32 @@ def sync(coro: Coroutine, loop: Optional[asyncio.AbstractEventLoop] = None):
# NB: if the loop is not running *yet*, it is OK to submit work
# and we will wait for it
loop = _get_loop()
if loop is None or loop.is_closed():
if not isinstance(loop, asyncio.AbstractEventLoop):
raise TypeError(f"loop cannot be of type {type(loop)}")
if loop.is_closed():
raise RuntimeError("Loop is not running")
try:
loop0 = asyncio.events.get_running_loop()
if loop0 is loop:
raise NotImplementedError("Calling sync() from within a running loop")
raise SyncError("Calling sync() from within a running loop")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this was a bug before: NotImplementedError inherits from RuntimeError, so it was swallowed by the `except RuntimeError` clause just below and this exception was never actually raised to the caller.

except RuntimeError:
pass
result_box: List[Optional[Any]] = [None]
event = threading.Event()
asyncio.run_coroutine_threadsafe(_runner(event, coro, result_box), loop)
while True:
# this loops allows thread to get interrupted
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment here explains why the loop was written this way. Does the changed code allow for an exception in the main thread (timeout, interrupt and other signals)? Does the GC run while waiting?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the changed code allow for an exception in the main thread (timeout, interrupt and other signals)? Does the GC run while waiting?

We don't test for these things at present, so I have no idea! Tests are needed, regardless of the efforts in this PR, and I see that fsspec basically has this covered.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we open a ticket for finding a test that exercises this concern. I don't think we should hold back this PR though.

if event.wait(1):
break

return_result = result_box[0]

future = asyncio.run_coroutine_threadsafe(_runner(coro), loop)

finished, unfinished = wait([future], return_when=asyncio.ALL_COMPLETED, timeout=timeout)
if len(unfinished) > 0:
raise asyncio.TimeoutError(f"Coroutine {coro} failed to finish in within {timeout}s")
assert len(finished) == 1
return_result = list(finished)[0].result()

if isinstance(return_result, BaseException):
raise return_result
else:
return return_result


def _get_loop():
def _get_loop() -> asyncio.AbstractEventLoop:
"""Create or return the default fsspec IO loop

The loop will be running on a separate thread.
Expand All @@ -96,25 +108,24 @@ def _get_loop():
th.daemon = True
th.start()
iothread[0] = th
assert loop[0] is not None
return loop[0]


P = ParamSpec("P")
T = TypeVar("T")


class SyncMixin:
_sync_configuration: SyncConfiguration

def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
# TODO: refactor this to take *args and **kwargs and pass those to the method
# this should allow us to better type the sync wrapper
return sync(coroutine, loop=self._sync_configuration.asyncio_loop)

def _sync_iter(self, coroutine: Coroutine[Any, Any, AsyncIterator[T]]) -> List[T]:
async def iter_to_list() -> List[T]:
# TODO: replace with generators so we don't materialize the entire iterator at once
async_iterator = await coroutine
return sync(
coroutine,
loop=self._sync_configuration.asyncio_loop,
timeout=self._sync_configuration.timeout,
)

def _sync_iter(self, async_iterator: AsyncIterator[T]) -> list[T]:
async def iter_to_list() -> list[T]:
return [item async for item in async_iterator]

return self._sync(iter_to_list())
129 changes: 129 additions & 0 deletions tests/v3/test_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from collections.abc import AsyncGenerator
import asyncio
import time
from unittest.mock import patch

from zarr.v3.sync import sync, _get_loop, _get_lock, SyncError, SyncMixin
from zarr.v3.config import SyncConfiguration

import pytest


@pytest.fixture(params=[True, False])
def sync_loop(request) -> asyncio.AbstractEventLoop | None:
    """Parametrized fixture yielding either an explicit loop or ``None``.

    For the ``True`` parameter the loop returned by ``_get_loop()`` is
    provided; for ``False`` the fixture value is ``None``, so tests pass
    ``loop=None`` to the code under test.
    """
    use_explicit_loop = request.param
    return _get_loop() if use_explicit_loop is True else None


def test_get_loop() -> None:
    """Two consecutive ``_get_loop()`` calls must return the very same loop object."""
    first, second = _get_loop(), _get_loop()
    assert first is second


def test_get_lock() -> None:
    """Two consecutive ``_get_lock()`` calls must return the very same lock object."""
    first, second = _get_lock(), _get_lock()
    assert first is second


def test_sync(sync_loop: asyncio.AbstractEventLoop | None) -> None:
    """``sync`` returns the value produced by the coroutine it runs."""

    async def foo() -> str:
        return "foo"

    result = sync(foo(), loop=sync_loop)
    assert result == "foo"


def test_sync_raises(sync_loop: asyncio.AbstractEventLoop | None) -> None:
    """An exception raised inside the coroutine is re-raised by ``sync`` in the caller."""

    async def foo() -> str:
        raise ValueError("foo")

    failing = foo()
    with pytest.raises(ValueError):
        sync(failing, loop=sync_loop)


def test_sync_timeout() -> None:
    """``sync`` raises ``asyncio.TimeoutError`` when the coroutine outlives ``timeout``."""
    duration = 0.002

    async def foo() -> None:
        # Suspend cooperatively: a blocking time.sleep() inside a coroutine would
        # stall the event-loop thread (and everything else scheduled on it) for
        # the whole duration.
        await asyncio.sleep(duration)

    # The timeout is half of the coroutine's run time, so the wait must expire first.
    with pytest.raises(asyncio.TimeoutError):
        sync(foo(), timeout=duration / 2)


def test_sync_raises_if_no_coroutine(sync_loop: asyncio.AbstractEventLoop | None) -> None:
    """Handing ``sync`` a plain value instead of a coroutine must raise ``TypeError``."""

    def foo() -> str:
        return "foo"

    not_a_coroutine = foo()
    with pytest.raises(TypeError):
        sync(not_a_coroutine, loop=sync_loop)


@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited")
def test_sync_raises_if_loop_is_closed() -> None:
    """``sync`` raises ``RuntimeError`` when the target loop reports itself as closed."""
    loop = _get_loop()

    async def foo() -> str:
        return "foo"

    coro = foo()
    try:
        # Pretend the (real, still running) loop is closed for the duration of the call.
        with patch.object(loop, "is_closed", return_value=True):
            with pytest.raises(RuntimeError):
                sync(coro, loop=loop)
    finally:
        # sync() rejects the closed loop before the coroutine is ever scheduled.
        # Close it explicitly so no "coroutine ... was never awaited" warning is
        # emitted when it is garbage collected.
        coro.close()


@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

despite this, we're still seeing the below warning. Why?

sys:1: RuntimeWarning: coroutine 'test_sync_raises_if_calling_sync_from_within_a_running_loop.<locals>.foo' was never awaited
RuntimeWarning: Enable tracemalloc to get the object allocation traceback

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning comes after the test has finished, during garbage collection. Maybe doing del on the objects (and closure) would fix it. Or run the coroutine after the check.

def test_sync_raises_if_calling_sync_from_within_a_running_loop(
    sync_loop: asyncio.AbstractEventLoop | None,
) -> None:
    """Calling ``sync`` from a coroutine already running on the target loop raises ``SyncError``."""

    async def foo() -> str:
        return "foo"

    async def bar() -> str:
        inner = foo()
        try:
            # This nested call runs on the loop thread itself and must be rejected.
            return sync(inner)
        finally:
            # The nested sync() fails before `inner` is ever scheduled. Closing it
            # here avoids the "coroutine ... was never awaited" RuntimeWarning that
            # would otherwise surface at garbage collection, after the test has
            # finished.
            inner.close()

    with pytest.raises(SyncError):
        sync(bar(), loop=sync_loop)


@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited")
def test_sync_raises_if_loop_is_invalid_type() -> None:
    """``sync`` raises ``TypeError`` when ``loop`` is not an event loop."""

    async def foo() -> str:
        return "foo"

    coro = foo()
    try:
        with pytest.raises(TypeError):
            sync(coro, loop=1)
    finally:
        # The type check fails before the coroutine is scheduled; close it so it
        # does not trigger a "coroutine ... was never awaited" warning at
        # garbage collection.
        coro.close()


def test_sync_mixin(sync_loop) -> None:
    """``SyncMixin._sync`` and ``SyncMixin._sync_iter`` expose async methods as blocking calls."""

    class AsyncFoo:
        """Async backend with one coroutine method and one async generator method."""

        async def foo(self) -> str:
            return "foo"

        async def bar(self) -> AsyncGenerator:
            for value in range(10):
                yield value

    class SyncFoo(SyncMixin):
        """Blocking facade over ``AsyncFoo`` built on the mixin helpers."""

        def __init__(self, backend: AsyncFoo) -> None:
            self._async_foo = backend
            self._sync_configuration = SyncConfiguration(asyncio_loop=sync_loop)

        def foo(self) -> str:
            pending = self._async_foo.foo()
            return self._sync(pending)

        def bar(self) -> list[int]:
            stream = self._async_foo.bar()
            return self._sync_iter(stream)

    wrapper = SyncFoo(AsyncFoo())
    assert wrapper.foo() == "foo"
    assert wrapper.bar() == list(range(10))