diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c1c46add..0548fbff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ default_stages: minimum_pre_commit_version: 2.16.0 repos: - repo: https://github.com/psf/black - rev: 22.12.0 + rev: 23.1.0 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-prettier @@ -19,11 +19,11 @@ repos: hooks: - id: blacken-docs - repo: https://github.com/PyCQA/isort - rev: 5.11.4 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.991 + rev: v1.0.1 hooks: - id: mypy additional_dependencies: [numpy==1.24.0] @@ -50,7 +50,7 @@ repos: - id: trailing-whitespace - id: check-case-conflict - repo: https://github.com/PyCQA/autoflake - rev: v2.0.0 + rev: v2.0.1 hooks: - id: autoflake args: diff --git a/pyproject.toml b/pyproject.toml index 6aa35ef7..aad9280f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "joblib", "imagecodecs", "dask-image", + "pyarrow", ] [project.optional-dependencies] diff --git a/src/spatialdata_io/__init__.py b/src/spatialdata_io/__init__.py index decc2e82..297c4435 100644 --- a/src/spatialdata_io/__init__.py +++ b/src/spatialdata_io/__init__.py @@ -1,6 +1,8 @@ from importlib.metadata import version from spatialdata_io.readers.cosmx import cosmx +from spatialdata_io.readers.mcmicro import mcmicro +from spatialdata_io.readers.steinbock import steinbock from spatialdata_io.readers.visium import visium from spatialdata_io.readers.xenium import xenium @@ -8,6 +10,8 @@ "visium", "xenium", "cosmx", + "mcmicro", + "steinbock", ] __version__ = version("spatialdata-io") diff --git a/src/spatialdata_io/_constants/_constants.py b/src/spatialdata_io/_constants/_constants.py index 7992c2bf..aee085d5 100644 --- a/src/spatialdata_io/_constants/_constants.py +++ b/src/spatialdata_io/_constants/_constants.py @@ -57,7 +57,8 @@ class XeniumKeys(ModeEnum): CELL_METADATA_FILE = "cells.parquet" CELL_X = "x_centroid" CELL_Y = "y_centroid" - CELL_AREA = 'cell_area' + CELL_AREA = "cell_area" + CELL_NUCLEUS_AREA = "nucleus_area" # morphology iamges MORPHOLOGY_MIP_FILE = "morphology_mip.ome.tif" @@ -85,3 +86,37 @@ class VisiumKeys(ModeEnum): SPOTS_FILE = "spatial/tissue_positions.csv" SPOTS_X = "pxl_row_in_fullres" SPOTS_Y = "pxl_col_in_fullres" + + +@unique +class SteinbockKeys(ModeEnum): + """Keys for *Steinbock* formatted dataset.""" + + # files and directories + CELLS_FILE = "cells.h5ad" + DEEPCELL_MASKS_DIR = "masks_deepcell" + ILASTIK_MASKS_DIR = "masks_ilastik" + IMAGES_DIR = "ome" + + # suffixes for images and labels + IMAGE_SUFFIX = ".ome.tiff" + LABEL_SUFFIX = ".tiff" + + +@unique +class McmicroKeys(ModeEnum): + """Keys for *Mcmicro* formatted dataset.""" + + # files and directories + CELL_FEATURES_SUFFIX = "--unmicst_cell.csv" + QUANTIFICATION_DIR = "quantification" + MARKERS_FILE = "markers.csv" + IMAGES_DIR = "registration" + IMAGE_SUFFIX = ".ome.tif" + LABELS_DIR = "segmentation" + LABELS_PREFIX = "unmicst-" + + # metadata + COORDS_X = "X_centroid" + COORDS_Y = "Y_centroid" + INSTANCE_KEY = "CellID" diff --git a/src/spatialdata_io/readers/cosmx.py b/src/spatialdata_io/readers/cosmx.py index 7083631f..3a7d6ca3 100644 --- a/src/spatialdata_io/readers/cosmx.py +++ b/src/spatialdata_io/readers/cosmx.py @@ -2,26 +2,29 @@ import os import re -import tempfile from collections.abc import Mapping from pathlib import Path from types import MappingProxyType from typing import Any, Optional -import pyarrow as pa -import pyarrow.parquet as pq +import dask.array as da import numpy as np import pandas as pd +import pyarrow as pa from anndata import AnnData -from dask_image.imread import imread -import dask.array as da from dask.dataframe.core import DataFrame as DaskDataFrame +from dask_image.imread import imread from scipy.sparse import csr_matrix # from spatialdata._core.core_utils import xy_cs from skimage.transform import estimate_transform from spatialdata import SpatialData -from spatialdata._core.models import Image2DModel, Labels2DModel, TableModel, PointsModel +from spatialdata._core.models import ( + Image2DModel, + Labels2DModel, + PointsModel, + TableModel, +) # from spatialdata._core.ngff.ngff_coordinate_system import NgffAxis # , CoordinateSystem from spatialdata._core.transformations import Affine, Identity @@ -41,7 +44,6 @@ def cosmx( path: str | Path, dataset_id: Optional[str] = None, - # shape_size: float | int = 1, transcripts: bool = True, imread_kwargs: Mapping[str, Any] = MappingProxyType({}), image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), @@ -67,8 +69,8 @@ def cosmx( Path to the root directory containing *Nanostring* files. dataset_id Name of the dataset. - shape_size - Size of the shape to be used for the centroids of the labels. + transcripts + Whether to also read in transcripts information. imread_kwargs Keyword arguments passed to :func:`dask_image.imread.imread`. image_models_kwargs @@ -118,7 +120,7 @@ def cosmx( obs = pd.read_csv(path / meta_file, header=0, index_col=CosmxKeys.INSTANCE_KEY) obs[CosmxKeys.FOV] = pd.Categorical(obs[CosmxKeys.FOV].astype(str)) - obs[CosmxKeys.REGION_KEY] = pd.Categorical(obs[CosmxKeys.FOV].astype(str).apply(lambda s: "/labels/" + s)) + obs[CosmxKeys.REGION_KEY] = pd.Categorical(obs[CosmxKeys.FOV].astype(str).apply(lambda s: s + "_labels")) obs[CosmxKeys.INSTANCE_KEY] = obs.index.astype(np.int64) obs.rename_axis(None, inplace=True) obs.index = obs.index.astype(str).str.cat(obs[CosmxKeys.FOV].values, sep="_") @@ -141,12 +143,6 @@ def cosmx( fovs_counts = list(map(str, adata.obs.fov.astype(int).unique())) - # TODO(giovp): uncomment once transform is ready - # input_cs = CoordinateSystem("cxy", axes=[c_axis, y_axis, x_axis]) - # input_cs_labels = CoordinateSystem("cxy", axes=[y_axis, x_axis]) - # output_cs = CoordinateSystem("global", axes=[c_axis, y_axis, x_axis]) - # output_cs_labels = CoordinateSystem("global", axes=[y_axis, x_axis]) - affine_transforms_to_global = {} for fov in fovs_counts: @@ -163,7 +159,10 @@ def cosmx( table.obsm["global"] = table.obs[[CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL]].to_numpy() table.obsm["spatial"] = table.obs[[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL]].to_numpy() - table.obs.drop(columns=[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL, CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL], inplace=True) + table.obs.drop( + columns=[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL, CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL], + inplace=True, + ) # prepare to read images and labels file_extensions = (".jpg", ".png", ".jpeg", ".tif", ".tiff") @@ -200,7 +199,6 @@ def cosmx( flipped_im = da.flip(im, axis=0) parsed_im = Image2DModel.parse( flipped_im, - name=fov, transformations={ fov: Identity(), "global": aff, @@ -209,7 +207,7 @@ def cosmx( dims=("y", "x", "c"), **image_models_kwargs, ) - images[fov] = parsed_im + images[f"{fov}_image"] = parsed_im else: logger.warning(f"FOV {fov} not found in counts file. Skipping image {fname}.") @@ -224,7 +222,6 @@ def cosmx( flipped_la = da.flip(la, axis=0) parsed_la = Labels2DModel.parse( flipped_la, - name=fov, transformations={ fov: Identity(), "global": aff, @@ -233,15 +230,40 @@ def cosmx( dims=("y", "x"), **image_models_kwargs, ) - labels[fov] = parsed_la + labels[f"{fov}_labels"] = parsed_la else: logger.warning(f"FOV {fov} not found in counts file. Skipping labels {fname}.") points: dict[str, DaskDataFrame] = {} if transcripts: + # assert transcripts_file is not None + # from pyarrow.csv import read_csv + # + # ptable = read_csv(path / transcripts_file) # , header=0) + # for fov in fovs_counts: + # aff = affine_transforms_to_global[fov] + # sub_table = ptable.filter(pa.compute.equal(ptable.column(CosmxKeys.FOV), int(fov))).to_pandas() + # sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype("category") + # # we rename z because we want to treat the data as 2d + # sub_table.rename(columns={"z": "z_raw"}, inplace=True) + # points[fov] = PointsModel.parse( + # sub_table, + # coordinates={"x": CosmxKeys.X_LOCAL_TRANSCRIPT, "y": CosmxKeys.Y_LOCAL_TRANSCRIPT}, + # feature_key=CosmxKeys.TARGET_OF_TRANSCRIPT, + # instance_key=CosmxKeys.INSTANCE_KEY, + # transformations={ + # fov: Identity(), + # "global": aff, + # "global_only_labels": aff, + # }, + # ) # let's convert the .csv to .parquet and let's read it with pyarrow.parquet for faster subsetting + import tempfile + + import pyarrow.parquet as pq + with tempfile.TemporaryDirectory() as tmpdir: - print("converting .csv to .parquet... ", end="") + print("converting .csv to .parquet to improve the speed of the slicing operations... ", end="") assert transcripts_file is not None transcripts_data = pd.read_csv(path / transcripts_file, header=0) transcripts_data.to_parquet(Path(tmpdir) / "transcripts.parquet") @@ -251,10 +273,10 @@ def cosmx( for fov in fovs_counts: aff = affine_transforms_to_global[fov] sub_table = ptable.filter(pa.compute.equal(ptable.column(CosmxKeys.FOV), int(fov))).to_pandas() - sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype('category') + sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype("category") # we rename z because we want to treat the data as 2d - sub_table.rename(columns={'z': 'z_raw'}, inplace=True) - points[fov] = PointsModel.parse( + sub_table.rename(columns={"z": "z_raw"}, inplace=True) + points[f"{fov}_points"] = PointsModel.parse( sub_table, coordinates={"x": CosmxKeys.X_LOCAL_TRANSCRIPT, "y": CosmxKeys.Y_LOCAL_TRANSCRIPT}, feature_key=CosmxKeys.TARGET_OF_TRANSCRIPT, @@ -266,7 +288,6 @@ def cosmx( }, ) - # TODO: what to do with fov file? # if fov_file is not None: # fov_positions = pd.read_csv(path / fov_file, header=0, index_col=CosmxKeys.FOV) diff --git a/src/spatialdata_io/readers/mcmicro.py b/src/spatialdata_io/readers/mcmicro.py new file mode 100644 index 00000000..98f7647b --- /dev/null +++ b/src/spatialdata_io/readers/mcmicro.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import os +from collections.abc import Mapping +from pathlib import Path +from types import MappingProxyType +from typing import Any, Union + +import numpy as np +import pandas as pd +from anndata import AnnData +from dask_image.imread import imread +from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage +from spatialdata import Image2DModel, Labels2DModel, SpatialData, TableModel + +from spatialdata_io._constants._constants import McmicroKeys + +__all__ = ["mcmicro"] + + +def mcmicro( + path: str | Path, + dataset_id: str, + imread_kwargs: Mapping[str, Any] = MappingProxyType({}), + image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> SpatialData: + """ + Read a *Mcmicro* output into a SpatialData object. + + .. seealso:: + + - `Mcmicro pipeline `_. + + Parameters + ---------- + path + Path to the dataset. + dataset_id + Dataset identifier. + imread_kwargs + Keyword arguments to pass to the image reader. + image_models_kwargs + Keyword arguments to pass to the image models. + + Returns + ------- + :class:`spatialdata.SpatialData` + """ + path = Path(path) + + samples = os.listdir(path / McmicroKeys.IMAGES_DIR) + if len(samples) > 1: + raise ValueError("Only one sample per dataset is supported.") + if (dataset_id + McmicroKeys.IMAGE_SUFFIX) not in samples: + raise ValueError("Dataset id is not consistent with sample name.") + + images = {} + images[f"{dataset_id}_image"] = _get_images( + path, + dataset_id, + imread_kwargs, + image_models_kwargs, + ) + labels = {} + labels[f"{dataset_id}_cells"] = _get_labels( + path, + dataset_id, + "cell", + imread_kwargs, + image_models_kwargs, + ) + labels[f"{dataset_id}_nuclei"] = _get_labels( + path, + dataset_id, + "nuclei", + imread_kwargs, + image_models_kwargs, + ) + + table = _get_table(path, dataset_id) + + return SpatialData(images=images, labels=labels, table=table) + + +def _get_images( + path: Path, + sample: str, + imread_kwargs: Mapping[str, Any] = MappingProxyType({}), + image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> Union[SpatialImage, MultiscaleSpatialImage]: + image = imread(path / McmicroKeys.IMAGES_DIR / f"{sample}{McmicroKeys.IMAGE_SUFFIX}", **imread_kwargs) + return Image2DModel.parse(image, **image_models_kwargs) + + +def _get_labels( + path: Path, + sample: str, + labels_kind: str, + imread_kwargs: Mapping[str, Any] = MappingProxyType({}), + image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> Union[SpatialImage, MultiscaleSpatialImage]: + image = imread( + path + / McmicroKeys.LABELS_DIR + / f"{McmicroKeys.LABELS_PREFIX}{sample}" + / f"{labels_kind}{McmicroKeys.IMAGE_SUFFIX}", + **imread_kwargs, + ).squeeze() + return Labels2DModel.parse(image, **image_models_kwargs) + + +def _get_table( + path: Path, + sample: str, +) -> AnnData: + table = pd.read_csv(path / McmicroKeys.QUANTIFICATION_DIR / f"{sample}{McmicroKeys.CELL_FEATURES_SUFFIX}") + markers = pd.read_csv(path / McmicroKeys.MARKERS_FILE) + markers.index = markers.marker_name + var = markers.marker_name.tolist() + coords = [McmicroKeys.COORDS_X.value, McmicroKeys.COORDS_Y.value] + adata = AnnData( + table[var].to_numpy(), + obs=table.drop(columns=var + coords), + var=markers, + obsm={"spatial": table[coords].to_numpy()}, + dtype=np.float_, + ) + adata.obs["region"] = f"{sample}_cells" + + return TableModel.parse( + adata, region=f"{sample}_cells", region_key="region", instance_key=McmicroKeys.INSTANCE_KEY.value + ) diff --git a/src/spatialdata_io/readers/steinbock.py b/src/spatialdata_io/readers/steinbock.py new file mode 100644 index 00000000..70074736 --- /dev/null +++ b/src/spatialdata_io/readers/steinbock.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import os +from collections.abc import Mapping +from pathlib import Path +from types import MappingProxyType +from typing import Any, Literal, Union + +import anndata as ad +from dask_image.imread import imread +from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage +from spatialdata import Image2DModel, Labels2DModel, SpatialData, TableModel +from spatialdata._core.transformations import Identity +from spatialdata._logging import logger + +from spatialdata_io._constants._constants import SteinbockKeys + +__all__ = ["steinbock"] + + +def steinbock( + path: str | Path, + labels_kind: Literal["deepcell", "ilastik"] = "deepcell", + imread_kwargs: Mapping[str, Any] = MappingProxyType({}), + image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> SpatialData: + """ + Read a *Steinbock* output into a SpatialData object. + + .. seealso:: + + - `Steinbock pipeline `_. + + Parameters + ---------- + path + Path to the dataset. + labels_kind + Kind of labels to use. Either ``deepcell`` or ``ilastik``. + imread_kwargs + Keyword arguments to pass to the image reader. + image_models_kwargs + Keyword arguments to pass to the image models. + + Returns + ------- + :class:`spatialdata.SpatialData` + """ + path = Path(path) + + labels_kind = SteinbockKeys(f"masks_{labels_kind}") # type: ignore[assignment] + + samples = [i.replace(SteinbockKeys.IMAGE_SUFFIX, "") for i in os.listdir(path / SteinbockKeys.IMAGES_DIR)] + samples_labels = [i.replace(SteinbockKeys.LABEL_SUFFIX, "") for i in os.listdir(path / labels_kind)] + images = {} + labels = {} + if len(set(samples).difference(set(samples_labels))): + logger.warning( + f"Samples {set(samples).difference(set(samples_labels))} have images but no labels. " + "They will be ignored." + ) + for sample in samples: + images[f"{sample}_image"] = _get_images( + path, + sample, + imread_kwargs, + image_models_kwargs, + ) + labels[f"{sample}_labels"] = _get_labels( + path, + sample, + labels_kind, + imread_kwargs, + image_models_kwargs, + ) + + adata = ad.read(path / SteinbockKeys.CELLS_FILE) + idx = adata.obs.index.str.split(" ").map(lambda x: int(x[1])) + regions = adata.obs.image.str.replace(".tiff", "", regex=False) + regions = regions.apply(lambda x: f"{x}_labels") + adata.obs["cell_id"] = idx + adata.obs["region"] = regions + adata.obsm["spatial"] = adata.obs[["centroid-0", "centroid-1"]].to_numpy() + if len({f"{s}_labels" for s in samples}.difference(set(regions.unique()))): + raise ValueError("Samples in table and images are inconsistent, please check.") + table = TableModel.parse(adata, region=regions.unique().tolist(), region_key="region", instance_key="cell_id") + + return SpatialData(images=images, labels=labels, table=table) + + +def _get_images( + path: Path, + sample: str, + imread_kwargs: Mapping[str, Any] = MappingProxyType({}), + image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> Union[SpatialImage, MultiscaleSpatialImage]: + image = imread(path / SteinbockKeys.IMAGES_DIR / f"{sample}{SteinbockKeys.IMAGE_SUFFIX}", **imread_kwargs) + return Image2DModel.parse(data=image, transformations={sample: Identity()}, **image_models_kwargs) + + +def _get_labels( + path: Path, + sample: str, + labels_kind: str, + imread_kwargs: Mapping[str, Any] = MappingProxyType({}), + image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), +) -> Union[SpatialImage, MultiscaleSpatialImage]: + image = imread(path / labels_kind / f"{sample}{SteinbockKeys.LABEL_SUFFIX}", **imread_kwargs).squeeze() + return Labels2DModel.parse(data=image, transformations={sample: Identity()}, **image_models_kwargs) diff --git a/src/spatialdata_io/readers/visium.py b/src/spatialdata_io/readers/visium.py index 1d891005..64aac258 100644 --- a/src/spatialdata_io/readers/visium.py +++ b/src/spatialdata_io/readers/visium.py @@ -95,20 +95,12 @@ def visium( adata.obs = pd.merge(adata.obs, coords, how="left", left_index=True, right_index=True) coords = adata.obs[[VisiumKeys.SPOTS_X, VisiumKeys.SPOTS_Y]].values + adata.obsm["spatial"] = coords adata.obs.drop(columns=[VisiumKeys.SPOTS_X, VisiumKeys.SPOTS_Y], inplace=True) - adata.obs["visium_spot_id"] = adata.obs_names + adata.obs["spot_id"] = np.arange(len(adata)) + adata.var_names_make_unique() scalefactors = json.loads((path / VisiumKeys.SCALEFACTORS_FILE).read_bytes()) - shapes = {} - circles = ShapesModel.parse( - coords, - shape_type="Circle", - shape_size=scalefactors["spot_diameter_fullres"], - index=adata.obs_names, - transformations={"global": Identity()}, - ) - shapes[dataset_id] = circles - table = TableModel.parse(adata, region=f"/shapes/{dataset_id}", region_key=None, instance_key="visium_spot_id") transform_original = Identity() transform_lowres = Scale( @@ -120,6 +112,22 @@ def visium( axes=("y", "x"), ) + shapes = {} + circles = ShapesModel.parse( + coords, + geometry=0, + radius=scalefactors["spot_diameter_fullres"] / 2.0, + index=adata.obs["spot_id"].copy(), + transformations={ + "global": Identity(), + "downscaled_hires": transform_hires, + "downscaled_lowres": transform_lowres, + }, + ) + shapes[dataset_id] = circles + adata.obs["region"] = dataset_id + table = TableModel.parse(adata, region=dataset_id, region_key="region", instance_key="spot_id") + full_image = ( imread(path / f"{dataset_id}{VisiumKeys.IMAGE_TIF_SUFFIX}", **imread_kwargs).squeeze().transpose(2, 0, 1) ) @@ -133,16 +141,12 @@ def visium( full_image_parsed = Image2DModel.parse( full_image, - multiscale_factors=[2, 2, 2, 2], + scale_factors=[2, 2, 2, 2], transformations={"global": transform_original}, **image_models_kwargs, ) - image_hires_parsed = Image2DModel.parse( - image_hires, transformations={"downscaled": transform_hires} - ) - image_lowres_parsed = Image2DModel.parse( - image_lowres, transformations={"downscaled": transform_lowres} - ) + image_hires_parsed = Image2DModel.parse(image_hires, transformations={"downscaled_hires": transform_hires}) + image_lowres_parsed = Image2DModel.parse(image_lowres, transformations={"downscaled_lowres": transform_lowres}) images = { dataset_id + "_full_image": full_image_parsed, @@ -150,4 +154,4 @@ def visium( dataset_id + "_lowres_image": image_lowres_parsed, } - return SpatialData(table=table, shapes=shapes, images=images) + return SpatialData(images=images, shapes=shapes, table=table) diff --git a/src/spatialdata_io/readers/xenium.py b/src/spatialdata_io/readers/xenium.py index ec919bdb..18dc70dd 100644 --- a/src/spatialdata_io/readers/xenium.py +++ b/src/spatialdata_io/readers/xenium.py @@ -4,29 +4,21 @@ from collections.abc import Mapping from pathlib import Path from types import MappingProxyType -from typing import Any +from typing import Any, Optional import numpy as np import pandas as pd import pyarrow.parquet as pq from anndata import AnnData -from dask_image.imread import imread -from dask.dataframe.core import DataFrame as DaskDataFrame from dask.dataframe import read_parquet +from dask_image.imread import imread from geopandas import GeoDataFrame from joblib import Parallel, delayed from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from pyarrow import Table from shapely import Polygon from spatial_image import SpatialImage -from spatialdata import ( - Image2DModel, - PointsModel, - PolygonsModel, - ShapesModel, - SpatialData, - TableModel, -) +from spatialdata import Image2DModel, PointsModel, ShapesModel, SpatialData, TableModel from spatialdata._core.transformations import Identity, Scale from spatialdata._types import ArrayLike @@ -40,10 +32,9 @@ @inject_docs(xx=XeniumKeys) def xenium( path: str | Path, - # dataset_id: str, n_jobs: int = 1, + cells_as_shapes: bool = False, nucleus_boundaries: bool = True, - cell_boundaries: bool = True, transcripts: bool = True, morphology_mip: bool = True, morphology_focus: bool = True, @@ -74,10 +65,10 @@ def xenium( Path to the dataset. n_jobs Number of jobs to use for parallel processing. + cells_as_shapes + Whether to read cells also as shapes. Useful for visualization. nucleus_boundaries Whether to read nucleus boundaries. - cell_boundaries - Whether to read cell boundaries. transcripts Whether to read transcripts. morphology_mip @@ -98,12 +89,20 @@ def xenium( image_models_kwargs = {} assert isinstance(image_models_kwargs, dict) image_models_kwargs["chunks"] = (1, 4096, 4096) - image_models_kwargs["multiscale_factors"] = [2, 2, 2, 2] + if "scale_factors" not in image_models_kwargs: + if isinstance(image_models_kwargs, MappingProxyType): + image_models_kwargs = {} + assert isinstance(image_models_kwargs, dict) + image_models_kwargs["scale_factors"] = [2, 2, 2, 2] + path = Path(path) with open(path / XeniumKeys.XENIUM_SPECS) as f: specs = json.load(f) + specs["region"] = "cell_circles" if cells_as_shapes else "cell_boundaries" + table, circles = _get_tables_and_circles(path, cells_as_shapes, specs) polygons = {} + if nucleus_boundaries: polygons["nucleus_boundaries"] = _get_polygons( path, @@ -111,13 +110,11 @@ def xenium( specs, n_jobs, ) - if cell_boundaries: - polygons["cell_boundaries"] = _get_polygons( - path, - XeniumKeys.CELL_BOUNDARIES_FILE, - specs, - n_jobs, - ) + + polygons["cell_boundaries"] = _get_polygons( + path, XeniumKeys.CELL_BOUNDARIES_FILE, specs, n_jobs, idx=table.obs[str(XeniumKeys.CELL_ID)].copy() + ) + points = {} if transcripts: points["transcripts"] = _get_points(path, specs) @@ -139,14 +136,14 @@ def xenium( imread_kwargs, image_models_kwargs, ) - - circles = {} - table, circles["circles"] = _get_tables(path, specs) - - return SpatialData(images=images, polygons=polygons, points=points, shapes=circles, table=table) + if cells_as_shapes: + return SpatialData(images=images, shapes=polygons | {specs["region"]: circles}, points=points, table=table) + return SpatialData(images=images, shapes=polygons, points=points, table=table) -def _get_polygons(path: Path, file: str, specs: dict[str, Any], n_jobs: int) -> GeoDataFrame: +def _get_polygons( + path: Path, file: str, specs: dict[str, Any], n_jobs: int, idx: Optional[ArrayLike] = None +) -> GeoDataFrame: def _poly(arr: ArrayLike) -> Polygon: return Polygon(arr[:-1]) @@ -158,28 +155,19 @@ def _poly(arr: ArrayLike) -> Polygon: for _, i in df.groupby(XeniumKeys.CELL_ID)[[XeniumKeys.BOUNDARIES_VERTEX_X, XeniumKeys.BOUNDARIES_VERTEX_Y]] ) geo_df = GeoDataFrame({"geometry": out}) + if idx is not None: + geo_df.index = idx scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y")) - return PolygonsModel.parse(geo_df, transformations={"global": scale}) + return ShapesModel.parse(geo_df, transformations={"global": scale}) def _get_points(path: Path, specs: dict[str, Any]) -> Table: table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE) - # table = pq.read_table(path / XeniumKeys.TRANSCRIPTS_FILE) - # arr = ( - # table.select([XeniumKeys.TRANSCRIPTS_X, XeniumKeys.TRANSCRIPTS_Y, XeniumKeys.TRANSCRIPTS_Z]) - # .to_pandas() - # .to_numpy() - # ) - # annotations = table.select((XeniumKeys.OVERLAPS_NUCLEUS, XeniumKeys.QUALITY_VALUE, XeniumKeys.CELL_ID)) - # annotations = annotations.add_column( - # 3, XeniumKeys.FEATURE_NAME, table.column(XeniumKeys.FEATURE_NAME).cast("string").dictionary_encode() - # ) transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y")) - # points = PointsModel.parse(coords=arr, annotations=annotations, transformations={"global": transform}) points = PointsModel.parse( table, - coordinates={"x": XeniumKeys.TRANSCRIPTS_X, "y": XeniumKeys.TRANSCRIPTS_Y}, + coordinates={"x": XeniumKeys.TRANSCRIPTS_X, "y": XeniumKeys.TRANSCRIPTS_Y, "z": XeniumKeys.TRANSCRIPTS_Y}, feature_key=XeniumKeys.FEATURE_NAME, instance_key=XeniumKeys.CELL_ID, transformations={"global": transform}, @@ -187,26 +175,30 @@ def _get_points(path: Path, specs: dict[str, Any]) -> Table: return points -def _get_tables(path: Path, specs: dict[str, Any]) -> tuple[AnnData, AnnData]: +def _get_tables_and_circles( + path: Path, cells_as_shapes: bool, specs: dict[str, Any] +) -> AnnData | tuple[AnnData, AnnData]: adata = _read_10x_h5(path / XeniumKeys.CELL_FEATURE_MATRIX_FILE) metadata = pd.read_parquet(path / XeniumKeys.CELL_METADATA_FILE) np.testing.assert_array_equal(metadata.cell_id.astype(str).values, adata.obs_names.values) - circ = metadata[[XeniumKeys.CELL_X, XeniumKeys.CELL_Y]].to_numpy() + adata.obsm["spatial"] = circ metadata.drop([XeniumKeys.CELL_X, XeniumKeys.CELL_Y], axis=1, inplace=True) - metadata[XeniumKeys.CELL_ID] = metadata[XeniumKeys.CELL_ID].astype(str) adata.obs = metadata - transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y")) - diameters = 2 * np.sqrt(adata.obs[XeniumKeys.CELL_AREA].to_numpy() / np.pi) / specs["pixel_size"] - circles = ShapesModel.parse( - circ, - shape_type="Circle", - shape_size=diameters, - transformations={"global": transform}, - index=adata.obs[XeniumKeys.CELL_ID], - ) - table = TableModel.parse(adata, region="/shapes/circles", instance_key=str(XeniumKeys.CELL_ID)) - return table, circles + adata.obs["region"] = specs["region"] + table = TableModel.parse(adata, region=specs["region"], region_key="region", instance_key=str(XeniumKeys.CELL_ID)) + if cells_as_shapes: + transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y")) + radii = np.sqrt(adata.obs[XeniumKeys.CELL_NUCLEUS_AREA].to_numpy() / np.pi) + circles = ShapesModel.parse( + circ, + geometry=0, + radius=radii, + transformations={"global": transform}, + index=adata.obs[XeniumKeys.CELL_ID].copy(), + ) + return table, circles + return table def _get_images(