Skip to content

fix sync group class methods #1652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zarr/v3/codecs/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BytesCodecConfigurationMetadata:
@frozen
class BytesCodecMetadata:
configuration: BytesCodecConfigurationMetadata
name: Literal["bytes"] = field(default="bytes", init=False)
name: Literal["bytes"] = field(default="bytes", init=True)


@frozen
Expand Down
17 changes: 10 additions & 7 deletions src/zarr/v3/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
from zarr.v3.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON, ZGROUP_JSON, make_cattr
from zarr.v3.config import RuntimeConfiguration, SyncConfiguration
from zarr.v3.store import StoreLike, StorePath, make_store_path
from zarr.v3.sync import SyncMixin
from zarr.v3.sync import SyncMixin, sync

logger = logging.getLogger("zarr.group")


@frozen
class GroupMetadata:
attributes: Dict[str, Any] = field(factory=dict)
zarr_format: Literal[2, 3] = 3 # field(default=3, validator=validators.in_([2, 3]))
node_type: Literal["group"] = field(default="group", init=False)
zarr_format: Literal[2, 3] = 3
node_type: Literal["group"] = field(default="group", init=True)

def to_bytes(self) -> Dict[str, bytes]:
if self.zarr_format == 3:
Expand Down Expand Up @@ -52,7 +52,7 @@ async def create(
*,
attributes: Optional[Dict[str, Any]] = None,
exists_ok: bool = False,
zarr_format: Literal[2, 3] = 3, # field(default=3, validator=validators.in_([2, 3])),
zarr_format: Literal[2, 3] = 3,
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> AsyncGroup:
store_path = make_store_path(store)
Expand Down Expand Up @@ -305,13 +305,14 @@ def create(
exists_ok: bool = False,
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> Group:
obj = cls._sync(
obj = sync(
AsyncGroup.create(
store,
attributes=attributes,
exists_ok=exists_ok,
runtime_configuration=runtime_configuration,
)
),
loop=runtime_configuration.asyncio_loop,
)

return cls(obj)
Expand All @@ -322,7 +323,9 @@ def open(
store: StoreLike,
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> Group:
obj = cls._sync(AsyncGroup.open(store, runtime_configuration))
obj = sync(
AsyncGroup.open(store, runtime_configuration), loop=runtime_configuration.asyncio_loop
)
return cls(obj)

def __getitem__(self, path: str) -> Union[Array, Group]:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_group_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,14 @@ def test_group(store_path) -> None:
# and the attrs were modified in the store
bar3 = foo["bar"]
assert dict(bar3.attrs) == {"baz": "qux", "name": "bar"}


def test_group_sync_constructor(store_path) -> None:

group = Group.create(
store=store_path,
attributes={"title": "test 123"},
runtime_configuration=RuntimeConfiguration(),
)

assert group._async_group.metadata.attributes["title"] == "test 123"