Skip to content

Commit 8724f8a

Browse files
authored
cleanup format (#154)
1 parent 38c661c commit 8724f8a

File tree

6 files changed

+110
-53
lines changed

6 files changed

+110
-53
lines changed

.editorconfig

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ end_of_line = lf
77
charset = utf-8
88
trim_trailing_whitespace = true
99
insert_final_newline = true
10+
line_length = 120
1011

1112
[Makefile]
1213
indent_style = tab

.flake8

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# Can't yet be moved to the pyproject.toml due to https://github.com/PyCQA/flake8/issues/234
22
[flake8]
3-
max-line-length = 88
3+
max-line-length = 120
44
ignore =
55
# line break before a binary operator -> black does not adhere to PEP8
66
W503

spatialdata/_io/format.py

Lines changed: 46 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -17,31 +17,11 @@
1717

1818

1919
class SpatialDataFormatV01(CurrentFormat):
20-
"""
21-
SpatialDataFormat defines the format of the spatialdata
22-
package.
23-
"""
20+
"""SpatialDataFormat defines the format of the spatialdata package."""
2421

25-
@property
26-
def spatialdata_version(self) -> str:
27-
return "0.1"
2822

29-
def validate_table(
30-
self,
31-
table: AnnData,
32-
region_key: Optional[str] = None,
33-
instance_key: Optional[str] = None,
34-
) -> None:
35-
if not isinstance(table, AnnData):
36-
raise TypeError(f"`tables` must be `anndata.AnnData`, was {type(table)}.")
37-
if region_key is not None:
38-
if not is_categorical_dtype(table.obs[region_key]):
39-
raise ValueError(
40-
f"`tables.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`."
41-
)
42-
if instance_key is not None:
43-
if table.obs[instance_key].isnull().values.any():
44-
raise ValueError("`tables.obs[instance_key]` must not contain null values, but it does.")
23+
class RasterFormatV01(SpatialDataFormatV01):
24+
"""Formatter for raster data."""
4525

4626
def generate_coordinate_transformations(self, shapes: list[tuple[Any]]) -> Optional[list[list[dict[str, Any]]]]:
4727
data_shape = shapes[0]
@@ -114,9 +94,13 @@ def channels_from_metadata(self, omero_metadata: dict[str, Any]) -> list[Any]:
11494
return [d["labels"] for d in omero_metadata["channels"]]
11595

11696

117-
class ShapesFormat(SpatialDataFormatV01):
97+
class ShapesFormatV01(SpatialDataFormatV01):
11898
"""Formatter for shapes."""
11999

100+
@property
101+
def version(self) -> str:
102+
return "0.1"
103+
120104
def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType:
121105
if Shapes_s.ATTRS_KEY not in metadata:
122106
raise KeyError(f"Missing key {Shapes_s.ATTRS_KEY} in shapes metadata.")
@@ -129,21 +113,25 @@ def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType:
129113

130114
typ = GeometryType(metadata_[Shapes_s.GEOS_KEY][Shapes_s.TYPE_KEY])
131115
assert typ.name == metadata_[Shapes_s.GEOS_KEY][Shapes_s.NAME_KEY]
132-
assert self.spatialdata_version == metadata_["version"]
116+
assert self.version == metadata_["version"]
133117
return typ
134118

135119
def attrs_to_dict(self, geometry: GeometryType) -> dict[str, Union[str, dict[str, Any]]]:
136120
return {Shapes_s.GEOS_KEY: {Shapes_s.NAME_KEY: geometry.name, Shapes_s.TYPE_KEY: geometry.value}}
137121

138122

139-
class PointsFormat(SpatialDataFormatV01):
123+
class PointsFormatV01(SpatialDataFormatV01):
140124
"""Formatter for points."""
141125

126+
@property
127+
def version(self) -> str:
128+
return "0.1"
129+
142130
def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]:
143131
if Points_s.ATTRS_KEY not in metadata:
144132
raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.")
145133
metadata_ = metadata[Points_s.ATTRS_KEY]
146-
assert self.spatialdata_version == metadata_["version"]
134+
assert self.version == metadata_["version"]
147135
d = {}
148136
if Points_s.FEATURE_KEY in metadata_:
149137
d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY]
@@ -159,3 +147,34 @@ def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]:
159147
if Points_s.FEATURE_KEY in data[Points_s.ATTRS_KEY]:
160148
d[Points_s.FEATURE_KEY] = data[Points_s.ATTRS_KEY][Points_s.FEATURE_KEY]
161149
return d
150+
151+
152+
class TablesFormatV01(SpatialDataFormatV01):
153+
"""Formatter for tables."""
154+
155+
@property
156+
def version(self) -> str:
157+
return "0.1"
158+
159+
def validate_table(
160+
self,
161+
table: AnnData,
162+
region_key: Optional[str] = None,
163+
instance_key: Optional[str] = None,
164+
) -> None:
165+
if not isinstance(table, AnnData):
166+
raise TypeError(f"`tables` must be `anndata.AnnData`, was {type(table)}.")
167+
if region_key is not None:
168+
if not is_categorical_dtype(table.obs[region_key]):
169+
raise ValueError(
170+
f"`tables.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`."
171+
)
172+
if instance_key is not None:
173+
if table.obs[instance_key].isnull().values.any():
174+
raise ValueError("`tables.obs[instance_key]` must not contain null values, but it does.")
175+
176+
177+
CurrentRasterFormat = RasterFormatV01
178+
CurrentShapesFormat = ShapesFormatV01
179+
CurrentPointsFormat = PointsFormatV01
180+
CurrentTablesFormat = TablesFormatV01

spatialdata/_io/read.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,12 @@
2828
from spatialdata._core.ngff.ngff_transformations import NgffBaseTransformation
2929
from spatialdata._core.transformations import BaseTransformation
3030
from spatialdata._io._utils import ome_zarr_logger
31-
from spatialdata._io.format import PointsFormat, ShapesFormat, SpatialDataFormatV01
31+
from spatialdata._io.format import (
32+
CurrentPointsFormat,
33+
CurrentRasterFormat,
34+
CurrentShapesFormat,
35+
SpatialDataFormatV01,
36+
)
3237

3338

3439
def read_zarr(store: Union[str, Path, zarr.Group]) -> SpatialData:
@@ -122,7 +127,7 @@ def _get_transformations_from_ngff_dict(
122127

123128

124129
def _read_multiscale(
125-
store: str, raster_type: Literal["image", "labels"], fmt: SpatialDataFormatV01 = SpatialDataFormatV01()
130+
store: str, raster_type: Literal["image", "labels"], fmt: SpatialDataFormatV01 = CurrentRasterFormat()
126131
) -> Union[SpatialImage, MultiscaleSpatialImage]:
127132
assert isinstance(store, str)
128133
assert raster_type in ["image", "labels"]
@@ -159,7 +164,7 @@ def _read_multiscale(
159164
# if image, read channels metadata
160165
if raster_type == "image":
161166
omero = multiscales[0]["omero"]
162-
channels = fmt.channels_from_metadata(omero)
167+
channels: list[Any] = fmt.channels_from_metadata(omero)
163168
axes = [i["name"] for i in node.metadata["axes"]]
164169
if len(datasets) > 1:
165170
multiscale_image = {}
@@ -188,7 +193,7 @@ def _read_multiscale(
188193
return compute_coordinates(si)
189194

190195

191-
def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = ShapesFormat()) -> GeoDataFrame: # type: ignore[type-arg]
196+
def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = CurrentShapesFormat()) -> GeoDataFrame: # type: ignore[type-arg]
192197
"""Read shapes from a zarr store."""
193198
assert isinstance(store, str)
194199
f = zarr.open(store, mode="r")
@@ -212,7 +217,7 @@ def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: Spati
212217

213218

214219
def _read_points(
215-
store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = PointsFormat() # type: ignore[type-arg]
220+
store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = CurrentPointsFormat() # type: ignore[type-arg]
216221
) -> DaskDataFrame:
217222
"""Read points from a zarr store."""
218223
assert isinstance(store, str)

spatialdata/_io/write.py

Lines changed: 18 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,12 @@
2828
)
2929
from spatialdata._core.models import ShapesModel
3030
from spatialdata._core.transformations import _get_current_output_axes
31-
from spatialdata._io.format import PointsFormat, ShapesFormat, SpatialDataFormatV01
31+
from spatialdata._io.format import (
32+
CurrentPointsFormat,
33+
CurrentRasterFormat,
34+
CurrentShapesFormat,
35+
CurrentTablesFormat,
36+
)
3237

3338
__all__ = [
3439
"write_image",
@@ -89,15 +94,14 @@ def overwrite_coordinate_transformations_raster(
8994
def _write_metadata(
9095
group: zarr.Group,
9196
group_type: str,
92-
# coordinate_transformations: list[dict[str, Any]],
97+
fmt: Format,
9398
axes: Optional[Union[str, list[str], list[dict[str, str]]]] = None,
9499
attrs: Optional[Mapping[str, Any]] = None,
95-
fmt: Format = SpatialDataFormatV01(),
96100
) -> None:
97101
"""Write metadata to a group."""
98102
axes = _get_valid_axes(axes=axes, fmt=fmt)
99103

100-
group.attrs["@type"] = group_type
104+
group.attrs["encoding-type"] = group_type
101105
group.attrs["axes"] = axes
102106
# we write empty coordinateTransformations and then overwrite them with overwrite_coordinate_transformations_non_raster()
103107
group.attrs["coordinateTransformations"] = []
@@ -110,7 +114,7 @@ def _write_raster(
110114
raster_data: Union[SpatialImage, MultiscaleSpatialImage],
111115
group: zarr.Group,
112116
name: str,
113-
fmt: Format = SpatialDataFormatV01(),
117+
fmt: Format = CurrentRasterFormat(),
114118
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
115119
label_metadata: Optional[JSONDict] = None,
116120
channels_metadata: Optional[JSONDict] = None,
@@ -212,7 +216,7 @@ def write_image(
212216
image: Union[SpatialImage, MultiscaleSpatialImage],
213217
group: zarr.Group,
214218
name: str,
215-
fmt: Format = SpatialDataFormatV01(),
219+
fmt: Format = CurrentRasterFormat(),
216220
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
217221
**metadata: Union[str, JSONDict, list[JSONDict]],
218222
) -> None:
@@ -231,7 +235,7 @@ def write_labels(
231235
labels: Union[SpatialImage, MultiscaleSpatialImage],
232236
group: zarr.Group,
233237
name: str,
234-
fmt: Format = SpatialDataFormatV01(),
238+
fmt: Format = CurrentRasterFormat(),
235239
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
236240
label_metadata: Optional[JSONDict] = None,
237241
**metadata: JSONDict,
@@ -253,7 +257,7 @@ def write_shapes(
253257
group: zarr.Group,
254258
name: str,
255259
group_type: str = "ngff:shapes",
256-
fmt: Format = ShapesFormat(),
260+
fmt: Format = CurrentShapesFormat(),
257261
) -> None:
258262
axes = get_dims(shapes)
259263
t = _get_transformations(shapes)
@@ -268,12 +272,11 @@ def write_shapes(
268272
shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values)
269273

270274
attrs = fmt.attrs_to_dict(geometry)
271-
attrs["version"] = fmt.spatialdata_version
275+
attrs["version"] = fmt.version
272276

273277
_write_metadata(
274278
shapes_group,
275279
group_type=group_type,
276-
# coordinate_transformations=coordinate_transformations,
277280
axes=list(axes),
278281
attrs=attrs,
279282
fmt=fmt,
@@ -287,7 +290,7 @@ def write_points(
287290
group: zarr.Group,
288291
name: str,
289292
group_type: str = "ngff:points",
290-
fmt: Format = PointsFormat(),
293+
fmt: Format = CurrentPointsFormat(),
291294
) -> None:
292295
axes = get_dims(points)
293296
t = _get_transformations(points)
@@ -297,12 +300,11 @@ def write_points(
297300
points.to_parquet(path)
298301

299302
attrs = fmt.attrs_to_dict(points.attrs)
300-
attrs["version"] = fmt.spatialdata_version
303+
attrs["version"] = fmt.version
301304

302305
_write_metadata(
303306
points_groups,
304307
group_type=group_type,
305-
# coordinate_transformations=coordinate_transformations,
306308
axes=list(axes),
307309
attrs=attrs,
308310
fmt=fmt,
@@ -316,19 +318,19 @@ def write_table(
316318
group: zarr.Group,
317319
name: str,
318320
group_type: str = "ngff:regions_table",
319-
fmt: Format = SpatialDataFormatV01(),
321+
fmt: Format = CurrentTablesFormat(),
320322
) -> None:
321323
region = table.uns["spatialdata_attrs"]["region"]
322324
region_key = table.uns["spatialdata_attrs"].get("region_key", None)
323325
instance_key = table.uns["spatialdata_attrs"].get("instance_key", None)
324326
fmt.validate_table(table, region_key, instance_key)
325327
write_adata(group, name, table) # creates group[name]
326328
tables_group = group[name]
327-
tables_group.attrs["@type"] = group_type
329+
tables_group.attrs["spatialdata-encoding-type"] = group_type
328330
tables_group.attrs["region"] = region
329331
tables_group.attrs["region_key"] = region_key
330332
tables_group.attrs["instance_key"] = instance_key
331-
tables_group.attrs["version"] = fmt.spatialdata_version
333+
tables_group.attrs["version"] = fmt.version
332334

333335

334336
def _iter_multiscale(

tests/_io/test_format.py

Lines changed: 34 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
from typing import Any, Optional
22

33
import pytest
4+
from shapely import GeometryType
45

5-
from spatialdata._core.models import PointsModel
6-
from spatialdata._io.format import PointsFormat
6+
from spatialdata._core.models import PointsModel, ShapesModel
7+
from spatialdata._io.format import CurrentPointsFormat, CurrentShapesFormat
78

8-
Points_f = PointsFormat()
9+
Points_f = CurrentPointsFormat()
10+
Shapes_f = CurrentShapesFormat()
911

1012

1113
class TestFormat:
@@ -20,7 +22,7 @@ def test_format_points(
2022
feature_key: Optional[str],
2123
instance_key: Optional[str],
2224
) -> None:
23-
metadata: dict[str, Any] = {attrs_key: {"version": Points_f.spatialdata_version}}
25+
metadata: dict[str, Any] = {attrs_key: {"version": Points_f.version}}
2426
format_metadata: dict[str, Any] = {attrs_key: {}}
2527
if feature_key is not None:
2628
metadata[attrs_key][feature_key] = "target"
@@ -31,3 +33,31 @@ def test_format_points(
3133
assert metadata[attrs_key] == Points_f.attrs_to_dict(format_metadata)
3234
if feature_key is None and instance_key is None:
3335
assert len(format_metadata[attrs_key]) == len(metadata[attrs_key]) == 0
36+
37+
@pytest.mark.parametrize("attrs_key", [ShapesModel.ATTRS_KEY])
38+
@pytest.mark.parametrize("geos_key", [ShapesModel.GEOS_KEY])
39+
@pytest.mark.parametrize("type_key", [ShapesModel.TYPE_KEY])
40+
@pytest.mark.parametrize("name_key", [ShapesModel.NAME_KEY])
41+
@pytest.mark.parametrize("shapes_type", [0, 3, 6])
42+
def test_format_shapes(
43+
self,
44+
attrs_key: str,
45+
geos_key: str,
46+
type_key: str,
47+
name_key: str,
48+
shapes_type: int,
49+
) -> None:
50+
shapes_dict = {
51+
0: "POINT",
52+
3: "POLYGON",
53+
6: "MULTIPOLYGON",
54+
}
55+
metadata: dict[str, Any] = {attrs_key: {"version": Shapes_f.version}}
56+
format_metadata: dict[str, Any] = {attrs_key: {}}
57+
metadata[attrs_key][geos_key] = {}
58+
metadata[attrs_key][geos_key][type_key] = shapes_type
59+
metadata[attrs_key][geos_key][name_key] = shapes_dict[shapes_type]
60+
format_metadata[attrs_key] = Shapes_f.attrs_from_dict(metadata)
61+
metadata[attrs_key].pop("version")
62+
geometry = GeometryType(metadata[attrs_key][geos_key][type_key])
63+
assert metadata[attrs_key] == Shapes_f.attrs_to_dict(geometry)

0 commit comments

Comments
 (0)