Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
line_length = 120

[Makefile]
indent_style = tab
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Can't yet be moved to the pyproject.toml due to https://github.com/PyCQA/flake8/issues/234
[flake8]
max-line-length = 88
max-line-length = 120
ignore =
# line break before a binary operator -> black does not adhere to PEP8
W503
Expand Down
30 changes: 19 additions & 11 deletions spatialdata/_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def parse(cls, data: Any, **kwargs: Any) -> GeoDataFrame:
In the case of (Multi)`Polygons` shapes, the offsets of the polygons must be provided.
radius
Array of size of the `Circles`. It must be provided if the shapes are `Circles`.
index
Index of the shapes, must be of type `str`. If None, it's generated automatically.
transform
Transform of points.
kwargs
Expand All @@ -376,6 +378,7 @@ def _(
geometry: Literal[0, 3, 6], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON]
offsets: Optional[tuple[ArrayLike, ...]] = None,
radius: Optional[ArrayLike] = None,
index: Optional[ArrayLike] = None,
transformations: Optional[MappingToCoordinateSystem_t] = None,
) -> GeoDataFrame:
geometry = GeometryType(geometry)
Expand All @@ -385,6 +388,8 @@ def _(
if radius is None:
raise ValueError("If `geometry` is `Circles`, `radius` must be provided.")
geo_df[cls.RADIUS_KEY] = radius
if index is not None:
geo_df.index = index
_parse_transformations(geo_df, transformations)
cls.validate(geo_df)
return geo_df
Expand All @@ -396,6 +401,7 @@ def _(
cls,
data: Union[str, Path],
radius: Optional[ArrayLike] = None,
index: Optional[ArrayLike] = None,
transformations: Optional[Any] = None,
**kwargs: Any,
) -> GeoDataFrame:
Expand All @@ -411,6 +417,8 @@ def _(
if radius is None:
raise ValueError("If `geometry` is `Circles`, `radius` must be provided.")
geo_df[cls.RADIUS_KEY] = radius
if index is not None:
geo_df.index = index
_parse_transformations(geo_df, transformations)
cls.validate(geo_df)
return geo_df
Expand Down Expand Up @@ -457,17 +465,6 @@ def validate(cls, data: DaskDataFrame) -> None:
logger.info(
f"Instance key `{instance_key}` could be of type `pd.Categorical`. Consider casting it."
)
# commented out to address this issue: https://github.com/scverse/spatialdata/issues/140
# for c in data.columns:
# # this is not strictly a validation since we are explicitly importing the categories
# # but it is a convenient way to ensure that the categories are known. It also just changes the state of the
# # series, so it is not a big deal.
# if is_categorical_dtype(data[c]):
# if not data[c].cat.known:
# try:
# data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
# except ValueError:
# logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")

@singledispatchmethod
@classmethod
Expand Down Expand Up @@ -593,6 +590,17 @@ def _add_metadata_and_validate(
assert instance_key in data.columns
data.attrs[cls.ATTRS_KEY][cls.INSTANCE_KEY] = instance_key

for c in data.columns:
# Here we are explicitly importing the categories
# but it is a convenient way to ensure that the categories are known.
# It also just changes the state of the series, so it is not a big deal.
if is_categorical_dtype(data[c]):
if not data[c].cat.known:
try:
data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
except ValueError:
logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")

_parse_transformations(data, transformations)
cls.validate(data)
# false positive with the PyCharm mypy plugin
Expand Down
73 changes: 46 additions & 27 deletions spatialdata/_io/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,11 @@


class SpatialDataFormatV01(CurrentFormat):
"""
SpatialDataFormat defines the format of the spatialdata
package.
"""
"""SpatialDataFormat defines the format of the spatialdata package."""

@property
def spatialdata_version(self) -> str:
return "0.1"

def validate_table(
self,
table: AnnData,
region_key: Optional[str] = None,
instance_key: Optional[str] = None,
) -> None:
if not isinstance(table, AnnData):
raise TypeError(f"`tables` must be `anndata.AnnData`, was {type(table)}.")
if region_key is not None:
if not is_categorical_dtype(table.obs[region_key]):
raise ValueError(
f"`tables.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`."
)
if instance_key is not None:
if table.obs[instance_key].isnull().values.any():
raise ValueError("`tables.obs[instance_key]` must not contain null values, but it does.")
class RasterFormatV01(SpatialDataFormatV01):
"""Formatter for raster data."""

def generate_coordinate_transformations(self, shapes: list[tuple[Any]]) -> Optional[list[list[dict[str, Any]]]]:
data_shape = shapes[0]
Expand Down Expand Up @@ -114,9 +94,13 @@ def channels_from_metadata(self, omero_metadata: dict[str, Any]) -> list[Any]:
return [d["labels"] for d in omero_metadata["channels"]]


class ShapesFormat(SpatialDataFormatV01):
class ShapesFormatV01(SpatialDataFormatV01):
"""Formatter for shapes."""

    @property
    def version(self) -> str:
        # Format version string; compared against the "version" entry of the
        # on-disk shapes metadata in attrs_from_dict.
        return "0.1"

def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType:
if Shapes_s.ATTRS_KEY not in metadata:
raise KeyError(f"Missing key {Shapes_s.ATTRS_KEY} in shapes metadata.")
Expand All @@ -129,21 +113,25 @@ def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType:

typ = GeometryType(metadata_[Shapes_s.GEOS_KEY][Shapes_s.TYPE_KEY])
assert typ.name == metadata_[Shapes_s.GEOS_KEY][Shapes_s.NAME_KEY]
assert self.spatialdata_version == metadata_["version"]
assert self.version == metadata_["version"]
return typ

    def attrs_to_dict(self, geometry: GeometryType) -> dict[str, Union[str, dict[str, Any]]]:
        """Serialize a geometry type into the nested attrs mapping stored alongside shapes."""
        return {Shapes_s.GEOS_KEY: {Shapes_s.NAME_KEY: geometry.name, Shapes_s.TYPE_KEY: geometry.value}}


class PointsFormat(SpatialDataFormatV01):
class PointsFormatV01(SpatialDataFormatV01):
"""Formatter for points."""

    @property
    def version(self) -> str:
        # Format version string; compared against the "version" entry of the
        # on-disk points metadata in attrs_from_dict.
        return "0.1"

def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]:
if Points_s.ATTRS_KEY not in metadata:
raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.")
metadata_ = metadata[Points_s.ATTRS_KEY]
assert self.spatialdata_version == metadata_["version"]
assert self.version == metadata_["version"]
d = {}
if Points_s.FEATURE_KEY in metadata_:
d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY]
Expand All @@ -159,3 +147,34 @@ def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]:
if Points_s.FEATURE_KEY in data[Points_s.ATTRS_KEY]:
d[Points_s.FEATURE_KEY] = data[Points_s.ATTRS_KEY][Points_s.FEATURE_KEY]
return d


class TablesFormatV01(SpatialDataFormatV01):
    """Formatter for tables."""

    @property
    def version(self) -> str:
        # Format version string recorded with tables on disk.
        return "0.1"

    def validate_table(
        self,
        table: AnnData,
        region_key: Optional[str] = None,
        instance_key: Optional[str] = None,
    ) -> None:
        """Validate an AnnData table against this format's expectations.

        Raises
        ------
        TypeError
            If ``table`` is not an :class:`anndata.AnnData`.
        ValueError
            If ``region_key`` names a non-categorical ``obs`` column, or if
            ``instance_key`` names an ``obs`` column containing null values.
        """
        if not isinstance(table, AnnData):
            raise TypeError(f"`tables` must be `anndata.AnnData`, was {type(table)}.")
        if region_key is not None and not is_categorical_dtype(table.obs[region_key]):
            raise ValueError(
                f"`tables.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`."
            )
        if instance_key is not None and table.obs[instance_key].isnull().values.any():
            raise ValueError("`tables.obs[instance_key]` must not contain null values, but it does.")


# Aliases resolving to the most recent format version for each element type;
# readers and writers use these as their default `fmt` so callers automatically
# track the current format.
CurrentRasterFormat = RasterFormatV01
CurrentShapesFormat = ShapesFormatV01
CurrentPointsFormat = PointsFormatV01
CurrentTablesFormat = TablesFormatV01
15 changes: 10 additions & 5 deletions spatialdata/_io/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
from spatialdata._core.ngff.ngff_transformations import NgffBaseTransformation
from spatialdata._core.transformations import BaseTransformation
from spatialdata._io._utils import ome_zarr_logger
from spatialdata._io.format import PointsFormat, ShapesFormat, SpatialDataFormatV01
from spatialdata._io.format import (
CurrentPointsFormat,
CurrentRasterFormat,
CurrentShapesFormat,
SpatialDataFormatV01,
)


def read_zarr(store: Union[str, Path, zarr.Group]) -> SpatialData:
Expand Down Expand Up @@ -122,7 +127,7 @@ def _get_transformations_from_ngff_dict(


def _read_multiscale(
store: str, raster_type: Literal["image", "labels"], fmt: SpatialDataFormatV01 = SpatialDataFormatV01()
store: str, raster_type: Literal["image", "labels"], fmt: SpatialDataFormatV01 = CurrentRasterFormat()
) -> Union[SpatialImage, MultiscaleSpatialImage]:
assert isinstance(store, str)
assert raster_type in ["image", "labels"]
Expand Down Expand Up @@ -159,7 +164,7 @@ def _read_multiscale(
# if image, read channels metadata
if raster_type == "image":
omero = multiscales[0]["omero"]
channels = fmt.channels_from_metadata(omero)
channels: list[Any] = fmt.channels_from_metadata(omero)
axes = [i["name"] for i in node.metadata["axes"]]
if len(datasets) > 1:
multiscale_image = {}
Expand Down Expand Up @@ -188,7 +193,7 @@ def _read_multiscale(
return compute_coordinates(si)


def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = ShapesFormat()) -> GeoDataFrame: # type: ignore[type-arg]
def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = CurrentShapesFormat()) -> GeoDataFrame: # type: ignore[type-arg]
"""Read shapes from a zarr store."""
assert isinstance(store, str)
f = zarr.open(store, mode="r")
Expand All @@ -212,7 +217,7 @@ def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: Spati


def _read_points(
store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = PointsFormat() # type: ignore[type-arg]
store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = CurrentPointsFormat() # type: ignore[type-arg]
) -> DaskDataFrame:
"""Read points from a zarr store."""
assert isinstance(store, str)
Expand Down
36 changes: 20 additions & 16 deletions spatialdata/_io/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
)
from spatialdata._core.models import ShapesModel
from spatialdata._core.transformations import _get_current_output_axes
from spatialdata._io.format import PointsFormat, ShapesFormat, SpatialDataFormatV01
from spatialdata._io.format import (
CurrentPointsFormat,
CurrentRasterFormat,
CurrentShapesFormat,
CurrentTablesFormat,
)

__all__ = [
"write_image",
Expand Down Expand Up @@ -89,15 +94,14 @@ def overwrite_coordinate_transformations_raster(
def _write_metadata(
group: zarr.Group,
group_type: str,
# coordinate_transformations: list[dict[str, Any]],
fmt: Format,
axes: Optional[Union[str, list[str], list[dict[str, str]]]] = None,
attrs: Optional[Mapping[str, Any]] = None,
fmt: Format = SpatialDataFormatV01(),
) -> None:
"""Write metdata to a group."""
axes = _get_valid_axes(axes=axes, fmt=fmt)

group.attrs["@type"] = group_type
group.attrs["encoding-type"] = group_type
group.attrs["axes"] = axes
# we write empty coordinateTransformations and then overwrite them with overwrite_coordinate_transformations_non_raster()
group.attrs["coordinateTransformations"] = []
Expand All @@ -110,7 +114,7 @@ def _write_raster(
raster_data: Union[SpatialImage, MultiscaleSpatialImage],
group: zarr.Group,
name: str,
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentRasterFormat(),
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
label_metadata: Optional[JSONDict] = None,
channels_metadata: Optional[JSONDict] = None,
Expand Down Expand Up @@ -212,7 +216,7 @@ def write_image(
image: Union[SpatialImage, MultiscaleSpatialImage],
group: zarr.Group,
name: str,
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentRasterFormat(),
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
**metadata: Union[str, JSONDict, list[JSONDict]],
) -> None:
Expand All @@ -231,7 +235,7 @@ def write_labels(
labels: Union[SpatialImage, MultiscaleSpatialImage],
group: zarr.Group,
name: str,
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentRasterFormat(),
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
label_metadata: Optional[JSONDict] = None,
**metadata: JSONDict,
Expand All @@ -253,7 +257,7 @@ def write_shapes(
group: zarr.Group,
name: str,
group_type: str = "ngff:shapes",
fmt: Format = ShapesFormat(),
fmt: Format = CurrentShapesFormat(),
) -> None:
axes = get_dims(shapes)
t = _get_transformations(shapes)
Expand All @@ -263,17 +267,18 @@ def write_shapes(
shapes_group.create_dataset(name="coords", data=coords)
for i, o in enumerate(offsets):
shapes_group.create_dataset(name=f"offset{i}", data=o)
# index cannot be string
# https://github.com/zarr-developers/zarr-python/issues/1090
shapes_group.create_dataset(name="Index", data=shapes.index.values)
if geometry.name == "POINT":
shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values)

attrs = fmt.attrs_to_dict(geometry)
attrs["version"] = fmt.spatialdata_version
attrs["version"] = fmt.version

_write_metadata(
shapes_group,
group_type=group_type,
# coordinate_transformations=coordinate_transformations,
axes=list(axes),
attrs=attrs,
fmt=fmt,
Expand All @@ -287,7 +292,7 @@ def write_points(
group: zarr.Group,
name: str,
group_type: str = "ngff:points",
fmt: Format = PointsFormat(),
fmt: Format = CurrentPointsFormat(),
) -> None:
axes = get_dims(points)
t = _get_transformations(points)
Expand All @@ -297,12 +302,11 @@ def write_points(
points.to_parquet(path)

attrs = fmt.attrs_to_dict(points.attrs)
attrs["version"] = fmt.spatialdata_version
attrs["version"] = fmt.version

_write_metadata(
points_groups,
group_type=group_type,
# coordinate_transformations=coordinate_transformations,
axes=list(axes),
attrs=attrs,
fmt=fmt,
Expand All @@ -316,19 +320,19 @@ def write_table(
group: zarr.Group,
name: str,
group_type: str = "ngff:regions_table",
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentTablesFormat(),
) -> None:
region = table.uns["spatialdata_attrs"]["region"]
region_key = table.uns["spatialdata_attrs"].get("region_key", None)
instance_key = table.uns["spatialdata_attrs"].get("instance_key", None)
fmt.validate_table(table, region_key, instance_key)
write_adata(group, name, table) # creates group[name]
tables_group = group[name]
tables_group.attrs["@type"] = group_type
tables_group.attrs["spatialdata-encoding-type"] = group_type
tables_group.attrs["region"] = region
tables_group.attrs["region_key"] = region_key
tables_group.attrs["instance_key"] = instance_key
tables_group.attrs["version"] = fmt.spatialdata_version
tables_group.attrs["version"] = fmt.version


def _iter_multiscale(
Expand Down
Loading