diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c1c46add..0548fbff 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,7 +7,7 @@ default_stages:
minimum_pre_commit_version: 2.16.0
repos:
- repo: https://github.com/psf/black
- rev: 22.12.0
+ rev: 23.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-prettier
@@ -19,11 +19,11 @@ repos:
hooks:
- id: blacken-docs
- repo: https://github.com/PyCQA/isort
- rev: 5.11.4
+ rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v0.991
+ rev: v1.0.1
hooks:
- id: mypy
additional_dependencies: [numpy==1.24.0]
@@ -50,7 +50,7 @@ repos:
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/PyCQA/autoflake
- rev: v2.0.0
+ rev: v2.0.1
hooks:
- id: autoflake
args:
diff --git a/pyproject.toml b/pyproject.toml
index 6aa35ef7..aad9280f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"joblib",
"imagecodecs",
"dask-image",
+ "pyarrow",
]
[project.optional-dependencies]
diff --git a/src/spatialdata_io/__init__.py b/src/spatialdata_io/__init__.py
index decc2e82..297c4435 100644
--- a/src/spatialdata_io/__init__.py
+++ b/src/spatialdata_io/__init__.py
@@ -1,6 +1,8 @@
from importlib.metadata import version
from spatialdata_io.readers.cosmx import cosmx
+from spatialdata_io.readers.mcmicro import mcmicro
+from spatialdata_io.readers.steinbock import steinbock
from spatialdata_io.readers.visium import visium
from spatialdata_io.readers.xenium import xenium
@@ -8,6 +10,8 @@
"visium",
"xenium",
"cosmx",
+ "mcmicro",
+ "steinbock",
]
__version__ = version("spatialdata-io")
diff --git a/src/spatialdata_io/_constants/_constants.py b/src/spatialdata_io/_constants/_constants.py
index 7992c2bf..aee085d5 100644
--- a/src/spatialdata_io/_constants/_constants.py
+++ b/src/spatialdata_io/_constants/_constants.py
@@ -57,7 +57,8 @@ class XeniumKeys(ModeEnum):
CELL_METADATA_FILE = "cells.parquet"
CELL_X = "x_centroid"
CELL_Y = "y_centroid"
- CELL_AREA = 'cell_area'
+ CELL_AREA = "cell_area"
+ CELL_NUCLEUS_AREA = "nucleus_area"
# morphology iamges
MORPHOLOGY_MIP_FILE = "morphology_mip.ome.tif"
@@ -85,3 +86,37 @@ class VisiumKeys(ModeEnum):
SPOTS_FILE = "spatial/tissue_positions.csv"
SPOTS_X = "pxl_row_in_fullres"
SPOTS_Y = "pxl_col_in_fullres"
+
+
+@unique
+class SteinbockKeys(ModeEnum):
+ """Keys for *Steinbock* formatted dataset."""
+
+ # files and directories
+ CELLS_FILE = "cells.h5ad"
+ DEEPCELL_MASKS_DIR = "masks_deepcell"
+ ILASTIK_MASKS_DIR = "masks_ilastik"
+ IMAGES_DIR = "ome"
+
+ # suffixes for images and labels
+ IMAGE_SUFFIX = ".ome.tiff"
+ LABEL_SUFFIX = ".tiff"
+
+
+@unique
+class McmicroKeys(ModeEnum):
+ """Keys for *Mcmicro* formatted dataset."""
+
+ # files and directories
+ CELL_FEATURES_SUFFIX = "--unmicst_cell.csv"
+ QUANTIFICATION_DIR = "quantification"
+ MARKERS_FILE = "markers.csv"
+ IMAGES_DIR = "registration"
+ IMAGE_SUFFIX = ".ome.tif"
+ LABELS_DIR = "segmentation"
+ LABELS_PREFIX = "unmicst-"
+
+ # metadata
+ COORDS_X = "X_centroid"
+ COORDS_Y = "Y_centroid"
+ INSTANCE_KEY = "CellID"
diff --git a/src/spatialdata_io/readers/cosmx.py b/src/spatialdata_io/readers/cosmx.py
index 7083631f..3a7d6ca3 100644
--- a/src/spatialdata_io/readers/cosmx.py
+++ b/src/spatialdata_io/readers/cosmx.py
@@ -2,26 +2,29 @@
import os
import re
-import tempfile
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any, Optional
-import pyarrow as pa
-import pyarrow.parquet as pq
+import dask.array as da
import numpy as np
import pandas as pd
+import pyarrow as pa
from anndata import AnnData
-from dask_image.imread import imread
-import dask.array as da
from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask_image.imread import imread
from scipy.sparse import csr_matrix
# from spatialdata._core.core_utils import xy_cs
from skimage.transform import estimate_transform
from spatialdata import SpatialData
-from spatialdata._core.models import Image2DModel, Labels2DModel, TableModel, PointsModel
+from spatialdata._core.models import (
+ Image2DModel,
+ Labels2DModel,
+ PointsModel,
+ TableModel,
+)
# from spatialdata._core.ngff.ngff_coordinate_system import NgffAxis # , CoordinateSystem
from spatialdata._core.transformations import Affine, Identity
@@ -41,7 +44,6 @@
def cosmx(
path: str | Path,
dataset_id: Optional[str] = None,
- # shape_size: float | int = 1,
transcripts: bool = True,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
@@ -67,8 +69,8 @@ def cosmx(
Path to the root directory containing *Nanostring* files.
dataset_id
Name of the dataset.
- shape_size
- Size of the shape to be used for the centroids of the labels.
+ transcripts
+ Whether to also read in transcripts information.
imread_kwargs
Keyword arguments passed to :func:`dask_image.imread.imread`.
image_models_kwargs
@@ -118,7 +120,7 @@ def cosmx(
obs = pd.read_csv(path / meta_file, header=0, index_col=CosmxKeys.INSTANCE_KEY)
obs[CosmxKeys.FOV] = pd.Categorical(obs[CosmxKeys.FOV].astype(str))
- obs[CosmxKeys.REGION_KEY] = pd.Categorical(obs[CosmxKeys.FOV].astype(str).apply(lambda s: "/labels/" + s))
+ obs[CosmxKeys.REGION_KEY] = pd.Categorical(obs[CosmxKeys.FOV].astype(str).apply(lambda s: s + "_labels"))
obs[CosmxKeys.INSTANCE_KEY] = obs.index.astype(np.int64)
obs.rename_axis(None, inplace=True)
obs.index = obs.index.astype(str).str.cat(obs[CosmxKeys.FOV].values, sep="_")
@@ -141,12 +143,6 @@ def cosmx(
fovs_counts = list(map(str, adata.obs.fov.astype(int).unique()))
- # TODO(giovp): uncomment once transform is ready
- # input_cs = CoordinateSystem("cxy", axes=[c_axis, y_axis, x_axis])
- # input_cs_labels = CoordinateSystem("cxy", axes=[y_axis, x_axis])
- # output_cs = CoordinateSystem("global", axes=[c_axis, y_axis, x_axis])
- # output_cs_labels = CoordinateSystem("global", axes=[y_axis, x_axis])
-
affine_transforms_to_global = {}
for fov in fovs_counts:
@@ -163,7 +159,10 @@ def cosmx(
table.obsm["global"] = table.obs[[CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL]].to_numpy()
table.obsm["spatial"] = table.obs[[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL]].to_numpy()
- table.obs.drop(columns=[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL, CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL], inplace=True)
+ table.obs.drop(
+ columns=[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL, CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL],
+ inplace=True,
+ )
# prepare to read images and labels
file_extensions = (".jpg", ".png", ".jpeg", ".tif", ".tiff")
@@ -200,7 +199,6 @@ def cosmx(
flipped_im = da.flip(im, axis=0)
parsed_im = Image2DModel.parse(
flipped_im,
- name=fov,
transformations={
fov: Identity(),
"global": aff,
@@ -209,7 +207,7 @@ def cosmx(
dims=("y", "x", "c"),
**image_models_kwargs,
)
- images[fov] = parsed_im
+ images[f"{fov}_image"] = parsed_im
else:
logger.warning(f"FOV {fov} not found in counts file. Skipping image {fname}.")
@@ -224,7 +222,6 @@ def cosmx(
flipped_la = da.flip(la, axis=0)
parsed_la = Labels2DModel.parse(
flipped_la,
- name=fov,
transformations={
fov: Identity(),
"global": aff,
@@ -233,15 +230,40 @@ def cosmx(
dims=("y", "x"),
**image_models_kwargs,
)
- labels[fov] = parsed_la
+ labels[f"{fov}_labels"] = parsed_la
else:
logger.warning(f"FOV {fov} not found in counts file. Skipping labels {fname}.")
points: dict[str, DaskDataFrame] = {}
if transcripts:
+ # assert transcripts_file is not None
+ # from pyarrow.csv import read_csv
+ #
+ # ptable = read_csv(path / transcripts_file) # , header=0)
+ # for fov in fovs_counts:
+ # aff = affine_transforms_to_global[fov]
+ # sub_table = ptable.filter(pa.compute.equal(ptable.column(CosmxKeys.FOV), int(fov))).to_pandas()
+ # sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype("category")
+ # # we rename z because we want to treat the data as 2d
+ # sub_table.rename(columns={"z": "z_raw"}, inplace=True)
+ # points[fov] = PointsModel.parse(
+ # sub_table,
+ # coordinates={"x": CosmxKeys.X_LOCAL_TRANSCRIPT, "y": CosmxKeys.Y_LOCAL_TRANSCRIPT},
+ # feature_key=CosmxKeys.TARGET_OF_TRANSCRIPT,
+ # instance_key=CosmxKeys.INSTANCE_KEY,
+ # transformations={
+ # fov: Identity(),
+ # "global": aff,
+ # "global_only_labels": aff,
+ # },
+ # )
# let's convert the .csv to .parquet and let's read it with pyarrow.parquet for faster subsetting
+ import tempfile
+
+ import pyarrow.parquet as pq
+
with tempfile.TemporaryDirectory() as tmpdir:
- print("converting .csv to .parquet... ", end="")
+ print("converting .csv to .parquet to improve the speed of the slicing operations... ", end="")
assert transcripts_file is not None
transcripts_data = pd.read_csv(path / transcripts_file, header=0)
transcripts_data.to_parquet(Path(tmpdir) / "transcripts.parquet")
@@ -251,10 +273,10 @@ def cosmx(
for fov in fovs_counts:
aff = affine_transforms_to_global[fov]
sub_table = ptable.filter(pa.compute.equal(ptable.column(CosmxKeys.FOV), int(fov))).to_pandas()
- sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype('category')
+ sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype("category")
# we rename z because we want to treat the data as 2d
- sub_table.rename(columns={'z': 'z_raw'}, inplace=True)
- points[fov] = PointsModel.parse(
+ sub_table.rename(columns={"z": "z_raw"}, inplace=True)
+ points[f"{fov}_points"] = PointsModel.parse(
sub_table,
coordinates={"x": CosmxKeys.X_LOCAL_TRANSCRIPT, "y": CosmxKeys.Y_LOCAL_TRANSCRIPT},
feature_key=CosmxKeys.TARGET_OF_TRANSCRIPT,
@@ -266,7 +288,6 @@ def cosmx(
},
)
-
# TODO: what to do with fov file?
# if fov_file is not None:
# fov_positions = pd.read_csv(path / fov_file, header=0, index_col=CosmxKeys.FOV)
diff --git a/src/spatialdata_io/readers/mcmicro.py b/src/spatialdata_io/readers/mcmicro.py
new file mode 100644
index 00000000..98f7647b
--- /dev/null
+++ b/src/spatialdata_io/readers/mcmicro.py
@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+import os
+from collections.abc import Mapping
+from pathlib import Path
+from types import MappingProxyType
+from typing import Any, Union
+
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+from dask_image.imread import imread
+from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
+from spatial_image import SpatialImage
+from spatialdata import Image2DModel, Labels2DModel, SpatialData, TableModel
+
+from spatialdata_io._constants._constants import McmicroKeys
+
+__all__ = ["mcmicro"]
+
+
+def mcmicro(
+ path: str | Path,
+ dataset_id: str,
+ imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
+ image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
+) -> SpatialData:
+ """
+ Read a *Mcmicro* output into a SpatialData object.
+
+ .. seealso::
+
+ - `Mcmicro pipeline `_.
+
+ Parameters
+ ----------
+ path
+ Path to the dataset.
+ dataset_id
+ Dataset identifier.
+ imread_kwargs
+ Keyword arguments to pass to the image reader.
+ image_models_kwargs
+ Keyword arguments to pass to the image models.
+
+ Returns
+ -------
+ :class:`spatialdata.SpatialData`
+ """
+ path = Path(path)
+
+ samples = os.listdir(path / McmicroKeys.IMAGES_DIR)
+ if len(samples) > 1:
+ raise ValueError("Only one sample per dataset is supported.")
+ if (dataset_id + McmicroKeys.IMAGE_SUFFIX) not in samples:
+ raise ValueError("Dataset id is not consistent with sample name.")
+
+ images = {}
+ images[f"{dataset_id}_image"] = _get_images(
+ path,
+ dataset_id,
+ imread_kwargs,
+ image_models_kwargs,
+ )
+ labels = {}
+ labels[f"{dataset_id}_cells"] = _get_labels(
+ path,
+ dataset_id,
+ "cell",
+ imread_kwargs,
+ image_models_kwargs,
+ )
+ labels[f"{dataset_id}_nuclei"] = _get_labels(
+ path,
+ dataset_id,
+ "nuclei",
+ imread_kwargs,
+ image_models_kwargs,
+ )
+
+ table = _get_table(path, dataset_id)
+
+ return SpatialData(images=images, labels=labels, table=table)
+
+
+def _get_images(
+ path: Path,
+ sample: str,
+ imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
+ image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
+) -> Union[SpatialImage, MultiscaleSpatialImage]:
+ image = imread(path / McmicroKeys.IMAGES_DIR / f"{sample}{McmicroKeys.IMAGE_SUFFIX}", **imread_kwargs)
+ return Image2DModel.parse(image, **image_models_kwargs)
+
+
+def _get_labels(
+ path: Path,
+ sample: str,
+ labels_kind: str,
+ imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
+ image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
+) -> Union[SpatialImage, MultiscaleSpatialImage]:
+ image = imread(
+ path
+ / McmicroKeys.LABELS_DIR
+ / f"{McmicroKeys.LABELS_PREFIX}{sample}"
+ / f"{labels_kind}{McmicroKeys.IMAGE_SUFFIX}",
+ **imread_kwargs,
+ ).squeeze()
+ return Labels2DModel.parse(image, **image_models_kwargs)
+
+
+def _get_table(
+ path: Path,
+ sample: str,
+) -> AnnData:
+ table = pd.read_csv(path / McmicroKeys.QUANTIFICATION_DIR / f"{sample}{McmicroKeys.CELL_FEATURES_SUFFIX}")
+ markers = pd.read_csv(path / McmicroKeys.MARKERS_FILE)
+ markers.index = markers.marker_name
+ var = markers.marker_name.tolist()
+ coords = [McmicroKeys.COORDS_X.value, McmicroKeys.COORDS_Y.value]
+ adata = AnnData(
+ table[var].to_numpy(),
+ obs=table.drop(columns=var + coords),
+ var=markers,
+ obsm={"spatial": table[coords].to_numpy()},
+ dtype=np.float_,
+ )
+ adata.obs["region"] = f"{sample}_cells"
+
+ return TableModel.parse(
+ adata, region=f"{sample}_cells", region_key="region", instance_key=McmicroKeys.INSTANCE_KEY.value
+ )
diff --git a/src/spatialdata_io/readers/steinbock.py b/src/spatialdata_io/readers/steinbock.py
new file mode 100644
index 00000000..70074736
--- /dev/null
+++ b/src/spatialdata_io/readers/steinbock.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import os
+from collections.abc import Mapping
+from pathlib import Path
+from types import MappingProxyType
+from typing import Any, Literal, Union
+
+import anndata as ad
+from dask_image.imread import imread
+from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
+from spatial_image import SpatialImage
+from spatialdata import Image2DModel, Labels2DModel, SpatialData, TableModel
+from spatialdata._core.transformations import Identity
+from spatialdata._logging import logger
+
+from spatialdata_io._constants._constants import SteinbockKeys
+
+__all__ = ["steinbock"]
+
+
+def steinbock(
+ path: str | Path,
+ labels_kind: Literal["deepcell", "ilastik"] = "deepcell",
+ imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
+ image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
+) -> SpatialData:
+ """
+ Read a *Steinbock* output into a SpatialData object.
+
+ .. seealso::
+
+ - `Steinbock pipeline `_.
+
+ Parameters
+ ----------
+ path
+ Path to the dataset.
+ labels_kind
+ Kind of labels to use. Either ``deepcell`` or ``ilastik``.
+ imread_kwargs
+ Keyword arguments to pass to the image reader.
+ image_models_kwargs
+ Keyword arguments to pass to the image models.
+
+ Returns
+ -------
+ :class:`spatialdata.SpatialData`
+ """
+ path = Path(path)
+
+ labels_kind = SteinbockKeys(f"masks_{labels_kind}") # type: ignore[assignment]
+
+ samples = [i.replace(SteinbockKeys.IMAGE_SUFFIX, "") for i in os.listdir(path / SteinbockKeys.IMAGES_DIR)]
+ samples_labels = [i.replace(SteinbockKeys.LABEL_SUFFIX, "") for i in os.listdir(path / labels_kind)]
+ images = {}
+ labels = {}
+ if len(set(samples).difference(set(samples_labels))):
+ logger.warning(
+ f"Samples {set(samples).difference(set(samples_labels))} have images but no labels. "
+ "They will be ignored."
+ )
+ for sample in samples:
+ images[f"{sample}_image"] = _get_images(
+ path,
+ sample,
+ imread_kwargs,
+ image_models_kwargs,
+ )
+ labels[f"{sample}_labels"] = _get_labels(
+ path,
+ sample,
+ labels_kind,
+ imread_kwargs,
+ image_models_kwargs,
+ )
+
+ adata = ad.read(path / SteinbockKeys.CELLS_FILE)
+ idx = adata.obs.index.str.split(" ").map(lambda x: int(x[1]))
+ regions = adata.obs.image.str.replace(".tiff", "", regex=False)
+ regions = regions.apply(lambda x: f"{x}_labels")
+ adata.obs["cell_id"] = idx
+ adata.obs["region"] = regions
+ adata.obsm["spatial"] = adata.obs[["centroid-0", "centroid-1"]].to_numpy()
+ if len({f"{s}_labels" for s in samples}.difference(set(regions.unique()))):
+ raise ValueError("Samples in table and images are inconsistent, please check.")
+ table = TableModel.parse(adata, region=regions.unique().tolist(), region_key="region", instance_key="cell_id")
+
+ return SpatialData(images=images, labels=labels, table=table)
+
+
+def _get_images(
+ path: Path,
+ sample: str,
+ imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
+ image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
+) -> Union[SpatialImage, MultiscaleSpatialImage]:
+ image = imread(path / SteinbockKeys.IMAGES_DIR / f"{sample}{SteinbockKeys.IMAGE_SUFFIX}", **imread_kwargs)
+ return Image2DModel.parse(data=image, transformations={sample: Identity()}, **image_models_kwargs)
+
+
+def _get_labels(
+ path: Path,
+ sample: str,
+ labels_kind: str,
+ imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
+ image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
+) -> Union[SpatialImage, MultiscaleSpatialImage]:
+ image = imread(path / labels_kind / f"{sample}{SteinbockKeys.LABEL_SUFFIX}", **imread_kwargs).squeeze()
+ return Labels2DModel.parse(data=image, transformations={sample: Identity()}, **image_models_kwargs)
diff --git a/src/spatialdata_io/readers/visium.py b/src/spatialdata_io/readers/visium.py
index 1d891005..64aac258 100644
--- a/src/spatialdata_io/readers/visium.py
+++ b/src/spatialdata_io/readers/visium.py
@@ -95,20 +95,12 @@ def visium(
adata.obs = pd.merge(adata.obs, coords, how="left", left_index=True, right_index=True)
coords = adata.obs[[VisiumKeys.SPOTS_X, VisiumKeys.SPOTS_Y]].values
+ adata.obsm["spatial"] = coords
adata.obs.drop(columns=[VisiumKeys.SPOTS_X, VisiumKeys.SPOTS_Y], inplace=True)
- adata.obs["visium_spot_id"] = adata.obs_names
+ adata.obs["spot_id"] = np.arange(len(adata))
+ adata.var_names_make_unique()
scalefactors = json.loads((path / VisiumKeys.SCALEFACTORS_FILE).read_bytes())
- shapes = {}
- circles = ShapesModel.parse(
- coords,
- shape_type="Circle",
- shape_size=scalefactors["spot_diameter_fullres"],
- index=adata.obs_names,
- transformations={"global": Identity()},
- )
- shapes[dataset_id] = circles
- table = TableModel.parse(adata, region=f"/shapes/{dataset_id}", region_key=None, instance_key="visium_spot_id")
transform_original = Identity()
transform_lowres = Scale(
@@ -120,6 +112,22 @@ def visium(
axes=("y", "x"),
)
+ shapes = {}
+ circles = ShapesModel.parse(
+ coords,
+ geometry=0,
+ radius=scalefactors["spot_diameter_fullres"] / 2.0,
+ index=adata.obs["spot_id"].copy(),
+ transformations={
+ "global": Identity(),
+ "downscaled_hires": transform_hires,
+ "downscaled_lowres": transform_lowres,
+ },
+ )
+ shapes[dataset_id] = circles
+ adata.obs["region"] = dataset_id
+ table = TableModel.parse(adata, region=dataset_id, region_key="region", instance_key="spot_id")
+
full_image = (
imread(path / f"{dataset_id}{VisiumKeys.IMAGE_TIF_SUFFIX}", **imread_kwargs).squeeze().transpose(2, 0, 1)
)
@@ -133,16 +141,12 @@ def visium(
full_image_parsed = Image2DModel.parse(
full_image,
- multiscale_factors=[2, 2, 2, 2],
+ scale_factors=[2, 2, 2, 2],
transformations={"global": transform_original},
**image_models_kwargs,
)
- image_hires_parsed = Image2DModel.parse(
- image_hires, transformations={"downscaled": transform_hires}
- )
- image_lowres_parsed = Image2DModel.parse(
- image_lowres, transformations={"downscaled": transform_lowres}
- )
+ image_hires_parsed = Image2DModel.parse(image_hires, transformations={"downscaled_hires": transform_hires})
+ image_lowres_parsed = Image2DModel.parse(image_lowres, transformations={"downscaled_lowres": transform_lowres})
images = {
dataset_id + "_full_image": full_image_parsed,
@@ -150,4 +154,4 @@ def visium(
dataset_id + "_lowres_image": image_lowres_parsed,
}
- return SpatialData(table=table, shapes=shapes, images=images)
+ return SpatialData(images=images, shapes=shapes, table=table)
diff --git a/src/spatialdata_io/readers/xenium.py b/src/spatialdata_io/readers/xenium.py
index ec919bdb..18dc70dd 100644
--- a/src/spatialdata_io/readers/xenium.py
+++ b/src/spatialdata_io/readers/xenium.py
@@ -4,29 +4,21 @@
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
-from typing import Any
+from typing import Any, Optional
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from anndata import AnnData
-from dask_image.imread import imread
-from dask.dataframe.core import DataFrame as DaskDataFrame
from dask.dataframe import read_parquet
+from dask_image.imread import imread
from geopandas import GeoDataFrame
from joblib import Parallel, delayed
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from pyarrow import Table
from shapely import Polygon
from spatial_image import SpatialImage
-from spatialdata import (
- Image2DModel,
- PointsModel,
- PolygonsModel,
- ShapesModel,
- SpatialData,
- TableModel,
-)
+from spatialdata import Image2DModel, PointsModel, ShapesModel, SpatialData, TableModel
from spatialdata._core.transformations import Identity, Scale
from spatialdata._types import ArrayLike
@@ -40,10 +32,9 @@
@inject_docs(xx=XeniumKeys)
def xenium(
path: str | Path,
- # dataset_id: str,
n_jobs: int = 1,
+ cells_as_shapes: bool = False,
nucleus_boundaries: bool = True,
- cell_boundaries: bool = True,
transcripts: bool = True,
morphology_mip: bool = True,
morphology_focus: bool = True,
@@ -74,10 +65,10 @@ def xenium(
Path to the dataset.
n_jobs
Number of jobs to use for parallel processing.
+ cells_as_shapes
+ Whether to read cells also as shapes. Useful for visualization.
nucleus_boundaries
Whether to read nucleus boundaries.
- cell_boundaries
- Whether to read cell boundaries.
transcripts
Whether to read transcripts.
morphology_mip
@@ -98,12 +89,20 @@ def xenium(
image_models_kwargs = {}
assert isinstance(image_models_kwargs, dict)
image_models_kwargs["chunks"] = (1, 4096, 4096)
- image_models_kwargs["multiscale_factors"] = [2, 2, 2, 2]
+ if "scale_factors" not in image_models_kwargs:
+ if isinstance(image_models_kwargs, MappingProxyType):
+ image_models_kwargs = {}
+ assert isinstance(image_models_kwargs, dict)
+ image_models_kwargs["scale_factors"] = [2, 2, 2, 2]
+
path = Path(path)
with open(path / XeniumKeys.XENIUM_SPECS) as f:
specs = json.load(f)
+ specs["region"] = "cell_circles" if cells_as_shapes else "cell_boundaries"
+ table, circles = _get_tables_and_circles(path, cells_as_shapes, specs)
polygons = {}
+
if nucleus_boundaries:
polygons["nucleus_boundaries"] = _get_polygons(
path,
@@ -111,13 +110,11 @@ def xenium(
specs,
n_jobs,
)
- if cell_boundaries:
- polygons["cell_boundaries"] = _get_polygons(
- path,
- XeniumKeys.CELL_BOUNDARIES_FILE,
- specs,
- n_jobs,
- )
+
+ polygons["cell_boundaries"] = _get_polygons(
+ path, XeniumKeys.CELL_BOUNDARIES_FILE, specs, n_jobs, idx=table.obs[str(XeniumKeys.CELL_ID)].copy()
+ )
+
points = {}
if transcripts:
points["transcripts"] = _get_points(path, specs)
@@ -139,14 +136,14 @@ def xenium(
imread_kwargs,
image_models_kwargs,
)
-
- circles = {}
- table, circles["circles"] = _get_tables(path, specs)
-
- return SpatialData(images=images, polygons=polygons, points=points, shapes=circles, table=table)
+ if cells_as_shapes:
+ return SpatialData(images=images, shapes=polygons | {specs["region"]: circles}, points=points, table=table)
+ return SpatialData(images=images, shapes=polygons, points=points, table=table)
-def _get_polygons(path: Path, file: str, specs: dict[str, Any], n_jobs: int) -> GeoDataFrame:
+def _get_polygons(
+ path: Path, file: str, specs: dict[str, Any], n_jobs: int, idx: Optional[ArrayLike] = None
+) -> GeoDataFrame:
def _poly(arr: ArrayLike) -> Polygon:
return Polygon(arr[:-1])
@@ -158,28 +155,19 @@ def _poly(arr: ArrayLike) -> Polygon:
for _, i in df.groupby(XeniumKeys.CELL_ID)[[XeniumKeys.BOUNDARIES_VERTEX_X, XeniumKeys.BOUNDARIES_VERTEX_Y]]
)
geo_df = GeoDataFrame({"geometry": out})
+ if idx is not None:
+ geo_df.index = idx
scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
- return PolygonsModel.parse(geo_df, transformations={"global": scale})
+ return ShapesModel.parse(geo_df, transformations={"global": scale})
def _get_points(path: Path, specs: dict[str, Any]) -> Table:
table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE)
- # table = pq.read_table(path / XeniumKeys.TRANSCRIPTS_FILE)
- # arr = (
- # table.select([XeniumKeys.TRANSCRIPTS_X, XeniumKeys.TRANSCRIPTS_Y, XeniumKeys.TRANSCRIPTS_Z])
- # .to_pandas()
- # .to_numpy()
- # )
- # annotations = table.select((XeniumKeys.OVERLAPS_NUCLEUS, XeniumKeys.QUALITY_VALUE, XeniumKeys.CELL_ID))
- # annotations = annotations.add_column(
- # 3, XeniumKeys.FEATURE_NAME, table.column(XeniumKeys.FEATURE_NAME).cast("string").dictionary_encode()
- # )
transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
- # points = PointsModel.parse(coords=arr, annotations=annotations, transformations={"global": transform})
points = PointsModel.parse(
table,
- coordinates={"x": XeniumKeys.TRANSCRIPTS_X, "y": XeniumKeys.TRANSCRIPTS_Y},
+ coordinates={"x": XeniumKeys.TRANSCRIPTS_X, "y": XeniumKeys.TRANSCRIPTS_Y, "z": XeniumKeys.TRANSCRIPTS_Y},
feature_key=XeniumKeys.FEATURE_NAME,
instance_key=XeniumKeys.CELL_ID,
transformations={"global": transform},
@@ -187,26 +175,30 @@ def _get_points(path: Path, specs: dict[str, Any]) -> Table:
return points
-def _get_tables(path: Path, specs: dict[str, Any]) -> tuple[AnnData, AnnData]:
+def _get_tables_and_circles(
+ path: Path, cells_as_shapes: bool, specs: dict[str, Any]
+) -> AnnData | tuple[AnnData, AnnData]:
adata = _read_10x_h5(path / XeniumKeys.CELL_FEATURE_MATRIX_FILE)
metadata = pd.read_parquet(path / XeniumKeys.CELL_METADATA_FILE)
np.testing.assert_array_equal(metadata.cell_id.astype(str).values, adata.obs_names.values)
-
circ = metadata[[XeniumKeys.CELL_X, XeniumKeys.CELL_Y]].to_numpy()
+ adata.obsm["spatial"] = circ
metadata.drop([XeniumKeys.CELL_X, XeniumKeys.CELL_Y], axis=1, inplace=True)
- metadata[XeniumKeys.CELL_ID] = metadata[XeniumKeys.CELL_ID].astype(str)
adata.obs = metadata
- transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
- diameters = 2 * np.sqrt(adata.obs[XeniumKeys.CELL_AREA].to_numpy() / np.pi) / specs["pixel_size"]
- circles = ShapesModel.parse(
- circ,
- shape_type="Circle",
- shape_size=diameters,
- transformations={"global": transform},
- index=adata.obs[XeniumKeys.CELL_ID],
- )
- table = TableModel.parse(adata, region="/shapes/circles", instance_key=str(XeniumKeys.CELL_ID))
- return table, circles
+ adata.obs["region"] = specs["region"]
+ table = TableModel.parse(adata, region=specs["region"], region_key="region", instance_key=str(XeniumKeys.CELL_ID))
+ if cells_as_shapes:
+ transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
+ radii = np.sqrt(adata.obs[XeniumKeys.CELL_NUCLEUS_AREA].to_numpy() / np.pi)
+ circles = ShapesModel.parse(
+ circ,
+ geometry=0,
+ radius=radii,
+ transformations={"global": transform},
+ index=adata.obs[XeniumKeys.CELL_ID].copy(),
+ )
+ return table, circles
+ return table
def _get_images(