Skip to content

Commit 4fb3433

Browse files
Merge pull request #157 from scverse/io/fixes
minor fix for points model parser
2 parents 8724f8a + b30886e commit 4fb3433

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

spatialdata/_core/models.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,8 @@ def parse(cls, data: Any, **kwargs: Any) -> GeoDataFrame:
357357
In the case of (Multi)`Polygons` shapes, the offsets of the polygons must be provided.
358358
radius
359359
Array of size of the `Circles`. It must be provided if the shapes are `Circles`.
360+
index
361+
Index of the shapes, must be of type `str`. If None, it's generated automatically.
360362
transform
361363
Transform of points.
362364
kwargs
@@ -376,6 +378,7 @@ def _(
376378
geometry: Literal[0, 3, 6], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON]
377379
offsets: Optional[tuple[ArrayLike, ...]] = None,
378380
radius: Optional[ArrayLike] = None,
381+
index: Optional[ArrayLike] = None,
379382
transformations: Optional[MappingToCoordinateSystem_t] = None,
380383
) -> GeoDataFrame:
381384
geometry = GeometryType(geometry)
@@ -385,6 +388,8 @@ def _(
385388
if radius is None:
386389
raise ValueError("If `geometry` is `Circles`, `radius` must be provided.")
387390
geo_df[cls.RADIUS_KEY] = radius
391+
if index is not None:
392+
geo_df.index = index
388393
_parse_transformations(geo_df, transformations)
389394
cls.validate(geo_df)
390395
return geo_df
@@ -396,6 +401,7 @@ def _(
396401
cls,
397402
data: Union[str, Path],
398403
radius: Optional[ArrayLike] = None,
404+
index: Optional[ArrayLike] = None,
399405
transformations: Optional[Any] = None,
400406
**kwargs: Any,
401407
) -> GeoDataFrame:
@@ -411,6 +417,8 @@ def _(
411417
if radius is None:
412418
raise ValueError("If `geometry` is `Circles`, `radius` must be provided.")
413419
geo_df[cls.RADIUS_KEY] = radius
420+
if index is not None:
421+
geo_df.index = index
414422
_parse_transformations(geo_df, transformations)
415423
cls.validate(geo_df)
416424
return geo_df
@@ -457,17 +465,6 @@ def validate(cls, data: DaskDataFrame) -> None:
457465
logger.info(
458466
f"Instance key `{instance_key}` could be of type `pd.Categorical`. Consider casting it."
459467
)
460-
# commented out to address this issue: https://github.com/scverse/spatialdata/issues/140
461-
# for c in data.columns:
462-
# # this is not strictly a validation since we are explicitly importing the categories
463-
# # but it is a convenient way to ensure that the categories are known. It also just changes the state of the
464-
# # series, so it is not a big deal.
465-
# if is_categorical_dtype(data[c]):
466-
# if not data[c].cat.known:
467-
# try:
468-
# data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
469-
# except ValueError:
470-
# logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")
471468

472469
@singledispatchmethod
473470
@classmethod
@@ -593,6 +590,17 @@ def _add_metadata_and_validate(
593590
assert instance_key in data.columns
594591
data.attrs[cls.ATTRS_KEY][cls.INSTANCE_KEY] = instance_key
595592

593+
for c in data.columns:
594+
# Here we are explicitly importing the categories
595+
# but it is a convenient way to ensure that the categories are known.
596+
# It also just changes the state of the series, so it is not a big deal.
597+
if is_categorical_dtype(data[c]):
598+
if not data[c].cat.known:
599+
try:
600+
data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
601+
except ValueError:
602+
logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")
603+
596604
_parse_transformations(data, transformations)
597605
cls.validate(data)
598606
# false positive with the PyCharm mypy plugin

spatialdata/_io/write.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def write_shapes(
267267
shapes_group.create_dataset(name="coords", data=coords)
268268
for i, o in enumerate(offsets):
269269
shapes_group.create_dataset(name=f"offset{i}", data=o)
270+
# index cannot be string
271+
# https://github.com/zarr-developers/zarr-python/issues/1090
270272
shapes_group.create_dataset(name="Index", data=shapes.index.values)
271273
if geometry.name == "POINT":
272274
shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values)

0 commit comments

Comments
 (0)