diff --git a/docs/api/dataset.md b/docs/api/dataset.md index ec1087e6..2b3cb7c4 100644 --- a/docs/api/dataset.md +++ b/docs/api/dataset.md @@ -2,6 +2,12 @@ options: filters: ["!^_"] +--- + +::: polaris.dataset._base.BaseDataset + options: + filters: ["!^_"] + --- ::: polaris.dataset.ColumnAnnotation diff --git a/docs/quickstart.md b/docs/quickstart.md index f2faa180..90c6a9e8 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -82,8 +82,8 @@ dataset.get_data( # Or, similarly: dataset[dataset.rows[0], dataset.columns[0]] -# Get the first 10 rows in memory -dataset[:10] +# Get an entire row +dataset[0] ``` ## Core concepts diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 45f9cc15..e73d7f35 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -1,6 +1,6 @@ -from itertools import chain import json from hashlib import md5 +from itertools import chain from typing import Any, Callable, Optional, Union import fsspec @@ -18,11 +18,11 @@ from sklearn.utils.multiclass import type_of_target from polaris._artifact import BaseArtifactModel -from polaris.mixins import ChecksumMixin -from polaris.dataset import Dataset, Subset, CompetitionDataset +from polaris.dataset import CompetitionDataset, DatasetV1, Subset from polaris.evaluate import BenchmarkResults, Metric from polaris.evaluate.utils import evaluate_benchmark from polaris.hub.settings import PolarisHubSettings +from polaris.mixins import ChecksumMixin from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidBenchmarkError from polaris.utils.misc import listit @@ -96,7 +96,7 @@ class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): # Public attributes # Data - dataset: Union[Dataset, CompetitionDataset, str, dict[str, Any]] + dataset: Union[DatasetV1, CompetitionDataset, str, dict[str, Any]] target_cols: ColumnsType input_cols: ColumnsType split: SplitType @@ -111,12 +111,11 @@ class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): def _validate_dataset(cls, v): """ Allows either passing a Dataset object or the kwargs to create one - TODO (cwognum): Allow multiple datasets to be used as part of a benchmark """ if isinstance(v, dict): - v = Dataset(**v) + v = DatasetV1(**v) elif isinstance(v, str): - v = Dataset.from_json(v) + v = DatasetV1.from_json(v) return v @field_validator("target_cols", "input_cols") @@ -162,7 +161,7 @@ def _validate_main_metric(cls, v): return v @model_validator(mode="after") - def _validate_split(cls, m: "BenchmarkSpecification"): + def _validate_split(self): """ Verifies that: 1) There are no empty test partitions @@ -171,7 +170,7 @@ def _validate_split(cls, m: "BenchmarkSpecification"): 4) There is no overlap between the train and test set 5) No row exists in the test set where all labels are missing/empty """ - split = m.split + split = self.split # Train partition can be empty (zero-shot) # Test partitions cannot be empty @@ -214,13 +213,13 @@ def _validate_split(cls, m: "BenchmarkSpecification"): raise InvalidBenchmarkError("The test set contains duplicate indices") # All indices are valid given the dataset - dataset = m.dataset + dataset = self.dataset if dataset is not None: max_i = len(dataset) if any(i < 0 or i >= max_i for i in chain(train_idx_list, full_test_idx_set)): raise InvalidBenchmarkError("The predefined split contains invalid indices") - return m + return self @field_validator("target_types") def _validate_target_types(cls, v, info: ValidationInfo): @@ -234,11 +233,20 @@ def _validate_target_types(cls, 
v, info: ValidationInfo): for target in target_cols: if target not in v: - val = dataset[:, target] + # Skip inferring the target type for pointer columns. + # This would be complex to implement properly. + # For these columns, dataset creators can still manually specify the target type. + anno = dataset.annotations.get(target) + if anno is not None and anno.is_pointer: + v[target] = None + continue + + val = dataset.table.loc[:, target] # Non numeric columns can be targets (e.g. prediction molecular reactions), # but in that case we currently don't infer the target type. if not np.issubdtype(val.dtype, np.number): + v[target] = None continue # remove the nans for mutiple task dataset when the table is sparse @@ -254,15 +262,14 @@ def _validate_target_types(cls, v, info: ValidationInfo): return v @model_validator(mode="after") - @classmethod - def _validate_model(cls, m: "BenchmarkSpecification"): + def _validate_model(self): """ Sets a default metric if missing. """ # Set a default main metric if not set yet - if m.main_metric is None: - m.main_metric = m.metrics[0] - return m + if self.main_metric is None: + self.main_metric = self.metrics[0] + return self @field_serializer("metrics", "main_metric") def _serialize_metrics(self, v): @@ -342,9 +349,10 @@ def n_classes(self) -> dict[str, int]: """The number of classes for each of the target columns.""" n_classes = {} for target in self.target_cols: - target_type = self.target_types[target] + target_type = self.target_types.get(target) if target_type is None or target_type == TargetType.REGRESSION: continue + # TODO: Don't use table attribute n_classes[target] = self.dataset.table.loc[:, target].nunique() return n_classes diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py index 3f861d54..e82249cc 100644 --- a/polaris/dataset/__init__.py +++ b/polaris/dataset/__init__.py @@ -1,8 +1,9 @@ -from polaris.dataset._column import ColumnAnnotation, Modality, KnownContentType -from polaris.dataset._dataset import Dataset +from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality +from polaris.dataset._competition_dataset import CompetitionDataset +from polaris.dataset._dataset import DatasetV1 +from polaris.dataset._dataset import DatasetV1 as Dataset from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files from polaris.dataset._subset import Subset -from polaris.dataset._competition_dataset import CompetitionDataset __all__ = [ "ColumnAnnotation", @@ -14,4 +15,5 @@ "DatasetFactory", "create_dataset_from_file", "create_dataset_from_files", + "DatasetV1", ] diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py new file mode 100644 index 00000000..7e5e6133 --- /dev/null +++ b/polaris/dataset/_base.py @@ -0,0 +1,356 @@ +import abc +import json +from pathlib import Path +from typing import Any, Dict, MutableMapping, Optional, Union + +import numpy as np +import zarr +from loguru import logger +from pydantic import ( + Field, + PrivateAttr, + computed_field, + field_serializer, + field_validator, + model_validator, +) + +from polaris._artifact import BaseArtifactModel +from polaris.dataset._adapters import Adapter +from polaris.dataset._column import ColumnAnnotation +from polaris.dataset.zarr import MemoryMappedDirectoryStore +from polaris.dataset.zarr._utils import load_zarr_group_to_memory +from polaris.hub.polarisfs import PolarisFileSystem +from polaris.utils.dict2html import dict2html +from polaris.utils.errors import InvalidDatasetError +from 
polaris.utils.types import (
+    AccessType,
+    DatasetIndex,
+    HttpUrlString,
+    HubOwner,
+    SupportedLicenseType,
+    ZarrConflictResolution,
+)
+
+# Constants
+_CACHE_SUBDIR = "datasets"
+
+
+class BaseDataset(BaseArtifactModel, abc.ABC):
+    """Base data-model for a Polaris dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model.
+
+    At its core, a dataset in Polaris can _conceptually_ be thought of as a tabular data structure that stores data-points
+    in a row-wise manner, where each column corresponds to a variable associated with that datapoint.
+
+    A Dataset can have multiple modalities or targets, can be sparse and can be part of one or multiple
+    [`BenchmarkSpecification`][polaris.benchmark.BenchmarkSpecification] objects.
+
+    Attributes:
+        default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data
+            for specific columns.
+        zarr_root_path: The data for any pointer column should be saved in the Zarr archive this path points to.
+        readme: Markdown text that can be used to provide a formatted description of the dataset.
+            If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI
+            as it provides a rich text editor for writing markdown.
+        annotations: Each column _can be_ annotated with a [`ColumnAnnotation`][polaris.dataset.ColumnAnnotation] object.
+            Importantly, this is used to annotate whether a column is a pointer column.
+        source: The data source, e.g. a DOI, Github repo or URI.
+        license: The dataset license. Polaris only supports some Creative Commons licenses. See [`SupportedLicenseType`][polaris.utils.types.SupportedLicenseType] for accepted ID values.
+        curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI.
+        cache_dir: Where the dataset will be cached if you call the `cache()` method.
+    For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class.
+
+    Raises:
+        InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification.
+ """ + + # Public attributes + # Data + default_adapters: Dict[str, Adapter] = Field(default_factory=dict) + zarr_root_path: Optional[str] = None + + # Additional meta-data + readme: str = "" + annotations: Dict[str, ColumnAnnotation] = Field(default_factory=dict) + source: Optional[HttpUrlString] = None + license: Optional[SupportedLicenseType] = None + curation_reference: Optional[HttpUrlString] = None + + # Config + cache_dir: Optional[Path] = None + + # Private attributes + _zarr_root: Optional[zarr.Group] = PrivateAttr(None) + _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) + _client = PrivateAttr(None) # Optional[PolarisHubClient] + _warn_about_remote_zarr: bool = PrivateAttr(True) + + @field_validator("default_adapters", mode="before") + def _validate_adapters(cls, value): + """Validate the adapters""" + return {k: Adapter[v] if isinstance(v, str) else v for k, v in value.items()} + + @field_serializer("default_adapters") + def _serialize_adapters(self, value: dict[str, Adapter]): + """Serializes the adapters""" + return {k: v.name for k, v in value.items()} + + @field_serializer("cache_dir", "zarr_root_path") + def _serialize_paths(value): + """Serialize the paths""" + if value is not None: + value = str(value) + return value + + @model_validator(mode="after") + def _validate_base_dataset_model(self): + # Verify that all annotations are for columns that exist + if any(k not in self.columns for k in self.annotations): + raise InvalidDatasetError( + f"There are annotations for columns that do not exist. Columns: {self.columns}. Annotations: {list(self.annotations.keys())}" + ) + + # Verify that all adapters are for columns that exist + if any(k not in self.columns for k in self.default_adapters.keys()): + raise InvalidDatasetError( + f"There are default adapters for columns that do not exist. Columns: {self.columns}. Adapters: {list(self.annotations.keys())}" + ) + + # Set a default for missing annotations and convert strings to Modality + for c in self.columns: + if c not in self.annotations: + self.annotations[c] = ColumnAnnotation() + self.annotations[c].dtype = self.dtypes[c] + + return self + + @property + def client(self): + """The Polaris Hub client used to interact with the Polaris Hub.""" + + # Import it here to prevent circular imports + from polaris.hub.client import PolarisHubClient + + if self._client is None: + self._client = PolarisHubClient() + return self._client + + @property + def uses_zarr(self) -> bool: + """Whether any of the data in this dataset is stored in a Zarr Archive.""" + return self.zarr_root_path is not None + + @property + def zarr_data(self): + """Get the Zarr data. + + This is different from the Zarr Root, because to optimize the efficiency of + data loading, a user can choose to load the data into memory as a numpy array + + Note: General purpose dataloader. + The goal with Polaris is to provide general purpose datasets that serve as good + options for a _wide variety of use cases_. This also implies you should be able to + optimize things further for a specific use case if needed. + """ + if self._zarr_data is not None: + return self._zarr_data + return self.zarr_root + + @property + def zarr_root(self) -> zarr.Group | None: + """Get the zarr Group object corresponding to the root. + + Opens the zarr archive in read-write mode if it is not already open. + + Note: Different to `zarr_data` + The `zarr_data` attribute references either to the Zarr archive or to a in-memory copy of the data. 
+ See also [`Dataset.load_to_memory`][polaris.dataset.Dataset.load_to_memory]. + """ + + if self._zarr_root is not None: + return self._zarr_root + + if self.zarr_root_path is None: + return None + + # We open the archive in read-only mode if it is saved on the Hub + saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) + + if self._warn_about_remote_zarr: + saved_remote = saved_on_hub or not Path(self.zarr_root_path).exists() + + if saved_remote: + logger.warning( + f"You're loading data from a remote location. " + f"To speed up this process, consider caching the dataset first " + f"using {self.__class__.__name__}.cache()" + ) + self._warn_about_remote_zarr = False + + try: + if saved_on_hub: + self._zarr_root = self.client.open_zarr_file(self.owner, self.name, self.zarr_root_path, "r+") + else: + # We use memory mapping by default because our experiments show that it's consistently faster + store = MemoryMappedDirectoryStore(self.zarr_root_path) + self._zarr_root = zarr.open_consolidated(store, mode="r+") + except KeyError as error: + raise InvalidDatasetError( + "A Zarr archive associated with a Polaris dataset has to be consolidated." + ) from error + return self._zarr_root + + @computed_field + @property + def n_rows(self) -> int: + """The number of rows in the dataset.""" + return len(self.rows) + + @computed_field + @property + def n_columns(self) -> int: + """The number of columns in the dataset.""" + return len(self.columns) + + @property + @abc.abstractmethod + def rows(self) -> list[str | int]: + """Return all row indices for the dataset""" + raise NotImplementedError + + @property + @abc.abstractmethod + def columns(self) -> list[str]: + """Return all columns for the dataset""" + raise NotImplementedError + + @property + @abc.abstractmethod + def dtypes(self) -> dict[str, np.dtype]: + """Return the dtype for each of the columns for the dataset""" + raise NotImplementedError + + def load_to_memory(self): + """ + Load data from zarr files to memeory + + Warning: Make sure the uncompressed dataset fits in-memory. + This method will load the **uncompressed** dataset into memory. Make + sure you actually have enough memory to store the dataset. + """ + data = self.zarr_data + + if not isinstance(data, zarr.Group): + raise TypeError( + "The dataset zarr_root is not a valid Zarr archive. " + "Did you call Dataset.load_to_memory() twice?" + ) + + # NOTE (cwognum): If the dataset fits in memory, the most performant is to use plain NumPy arrays. + # Even if we disable chunking and compression in Zarr. + # For more information, see https://github.com/zarr-developers/zarr-python/issues/1395 + self._zarr_data = load_zarr_group_to_memory(data) + + @abc.abstractmethod + def get_data( + self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None + ) -> np.ndarray | Any: + """Since the dataset might contain pointers to external files, data retrieval is more complicated + than just indexing the `table` attribute. This method provides an end-point for seamlessly + accessing the underlying data. + + Args: + row: The row index in the `Dataset.table` attribute + col: The column index in the `Dataset.table` attribute + adapters: The adapters to apply to the data before returning it. + If None, will use the default adapters specified for the dataset. + + Returns: + A numpy array with the data at the specified indices. If the column is a pointer column, + the content of the referenced file is loaded to memory. 
+ """ + raise NotImplementedError + + @abc.abstractmethod + def upload_to_hub(self, access: AccessType = "private", owner: Union[HubOwner, str, None] = None): + """Uploads the dataset to the Polaris Hub.""" + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def from_json(cls, path: str): + """ + Loads a dataset from a JSON file. + + Args: + path: The path to the JSON file to load the dataset from. + """ + raise NotImplementedError + + @abc.abstractmethod + def to_json( + self, + destination: str, + if_exists: ZarrConflictResolution = "replace", + load_zarr_from_new_location: bool = False, + ) -> str: + """ + Save the dataset to a destination directory as a JSON file. + + Args: + destination: The _directory_ to save the associated data to. + if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw + an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. + load_zarr_from_new_location: Whether to update the current instance to load data from the location + the data is saved to. Only relevant for Zarr-datasets. + + Returns: + The path to the JSON file. + """ + raise NotImplementedError + + def cache(self) -> str: + """Caches the dataset by downloading all additional data for pointer columns to a local directory. + + Returns: + The path to the cache directory. + """ + self.to_json(self.cache_dir, load_zarr_from_new_location=True) + return self.cache_dir + + def size(self) -> tuple[int, int]: + return self.n_rows, self.n_columns + + def __getitem__(self, item: DatasetIndex) -> Any | np.ndarray | dict[str, np.ndarray]: + """Allows for indexing the dataset directly""" + + # If a tuple, we assume it's the row and column index pair + if isinstance(item, tuple): + row, col = item + return self.get_data(row, col) + + # Otherwise, we assume you're indexing the row + return {col: self.get_data(item, col) for col in self.columns} + + @abc.abstractmethod + def _repr_dict_(self) -> dict: + """Utility function for pretty-printing to the command line and jupyter notebooks""" + raise NotImplementedError + + def _repr_html_(self): + """For pretty-printing in Jupyter Notebooks""" + return dict2html(self._repr_dict_()) + + def __len__(self): + return self.n_rows + + def __repr__(self): + return json.dumps(self._repr_dict_(), indent=2) + + def __str__(self): + return self.__repr__() + + def __del__(self): + """Close the connection of the client""" + if self._client is not None: + self._client.close() diff --git a/polaris/dataset/_column.py b/polaris/dataset/_column.py index 71d34e45..3ece7f96 100644 --- a/polaris/dataset/_column.py +++ b/polaris/dataset/_column.py @@ -1,5 +1,5 @@ import enum -from typing import Dict, Optional, Union +from typing import Dict, Literal, Optional, TypeAlias import numpy as np from numpy.typing import DTypeLike @@ -18,11 +18,7 @@ class Modality(enum.Enum): IMAGE = "image" -class KnownContentType(enum.Enum): - """Used to specify column's IANA content type in a dataset.""" - - SMILES = "chemical/x-smiles" - PDB = "chemical/x-pdb" +KnownContentType: TypeAlias = Literal["chemical/x-smiles", "chemical/x-pdb"] class ColumnAnnotation(BaseModel): @@ -42,29 +38,22 @@ class ColumnAnnotation(BaseModel): """ is_pointer: bool = False - modality: Union[str, Modality] = Modality.UNKNOWN + modality: Modality = Modality.UNKNOWN description: Optional[str] = None user_attributes: Dict[str, str] = Field(default_factory=dict) - dtype: Union[np.dtype, str, None] = None - content_type: Union[KnownContentType, str, 
None] = None + dtype: np.dtype | None = None + content_type: KnownContentType | str | None = None model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel, populate_by_name=True) - @field_validator("modality") - def _validate_modality(cls, v, values): + @field_validator("modality", mode="before") + def _validate_modality(cls, v): """Tries to convert a string to the Enum""" if isinstance(v, str): v = Modality[v.upper()] return v - @field_validator("content_type") - def _validate_content_type(cls, v, values): - """Tries to convert a string to the Enum""" - if isinstance(v, str): - v = KnownContentType[v.upper()] - return v - - @field_validator("dtype") + @field_validator("dtype", mode="before") def _validate_dtype(cls, v): """Tries to convert a string to the Enum""" if isinstance(v, str): @@ -76,13 +65,6 @@ def _serialize_modality(self, v: Modality): """Return the modality as a string, keeping it serializable""" return v.name - @field_serializer("content_type") - def _serialize_content_type(self, v: KnownContentType): - """Return the content_type as a string, keeping it serializable""" - if v is not None: - v = v.name - return v - @field_serializer("dtype") def _serialize_dtype(self, v: Optional[DTypeLike]): """Return the dtype as a string, keeping it serializable""" diff --git a/polaris/dataset/_competition_dataset.py b/polaris/dataset/_competition_dataset.py index 2f224c22..a7a9a532 100644 --- a/polaris/dataset/_competition_dataset.py +++ b/polaris/dataset/_competition_dataset.py @@ -1,11 +1,10 @@ from pydantic import model_validator -from polaris.dataset import Dataset -from polaris.utils.errors import InvalidCompetitionError -_CACHE_SUBDIR = "datasets" +from polaris.dataset._dataset import DatasetV1 +from polaris.utils.errors import InvalidCompetitionError -class CompetitionDataset(Dataset): +class CompetitionDataset(DatasetV1): """Dataset subclass for Polaris competitions. 
In addition to the data model and logic of the base Dataset class, @@ -14,8 +13,8 @@ class CompetitionDataset(Dataset): """ @model_validator(mode="after") - def _validate_model(cls, m: "CompetitionDataset"): + def _validate_model(self): """We reject the instantiation of competition datasets which leverage Zarr for the time being""" - if m.uses_zarr: + if self.uses_zarr: raise InvalidCompetitionError("Pointer columns are not currently supported in competitions.") diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 7887c84c..299edba5 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -2,7 +2,7 @@ import uuid from hashlib import md5 from pathlib import Path -from typing import Dict, List, MutableMapping, Optional, Tuple, Union +from typing import Any, ClassVar, List, Literal, Optional, Union import fsspec import numpy as np @@ -10,67 +10,41 @@ import zarr from datamol.utils import fs as dmfs from loguru import logger -from pydantic import ( - Field, - PrivateAttr, - computed_field, - field_serializer, - field_validator, - model_validator, -) +from pydantic import PrivateAttr, computed_field, field_validator, model_validator -from polaris._artifact import BaseArtifactModel from polaris.dataset._adapters import Adapter -from polaris.dataset._column import ColumnAnnotation -from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum, compute_zarr_checksum -from polaris.dataset.zarr._utils import load_zarr_group_to_memory -from polaris.hub.polarisfs import PolarisFileSystem -from polaris.mixins import ChecksumMixin +from polaris.dataset._base import _CACHE_SUBDIR, BaseDataset +from polaris.dataset.zarr import ZarrFileChecksum, compute_zarr_checksum +from polaris.mixins._checksum import ChecksumMixin from polaris.utils.constants import DEFAULT_CACHE_DIR -from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import ( AccessType, - HttpUrlString, HubOwner, - SupportedLicenseType, - ZarrConflictResolution, TimeoutTypes, + ZarrConflictResolution, ) # Constants _SUPPORTED_TABLE_EXTENSIONS = ["parquet"] -_CACHE_SUBDIR = "datasets" _INDEX_SEP = "#" -class Dataset(BaseArtifactModel, ChecksumMixin): - """Basic data-model for a Polaris dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. +class DatasetV1(BaseDataset, ChecksumMixin): + """First version of a Polaris Dataset. - At its core, a dataset in Polaris is a tabular data structure that stores data-points in a row-wise manner. - A Dataset can have multiple modalities or targets, can be sparse and can be part of one or multiple - [`BenchmarkSpecification`][polaris.benchmark.BenchmarkSpecification] objects. + Stores datapoints in a Pandas DataFrame and implements _pointer columns_ to support the storage of XXL data + outside of the DataFrame in a Zarr archive. Info: Pointer columns - Whereas a `Dataset` contains all information required to construct a dataset, it is not ready yet. For complex data, such as images, we support storing the content in external blobs of data. In that case, the table contains _pointers_ to these blobs that are dynamically loaded when needed. Attributes: table: The core data-structure, storing data-points in a row-wise manner. Can be specified as either a path to a `.parquet` file or a `pandas.DataFrame`. - default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data - for specific columns. 
- zarr_root_path: The data for any pointer column should be saved in the Zarr archive this path points to. - readme: Markdown text that can be used to provide a formatted description of the dataset. - If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI - as it provides a rich text editor for writing markdown. - annotations: Each column _can be_ annotated with a [`ColumnAnnotation`][polaris.dataset.ColumnAnnotation] object. - Importantly, this is used to annotate whether a column is a pointer column. - source: The data source, e.g. a DOI, Github repo or URI. - license: The dataset license. Polaris only supports some Creative Commons licenses. See [`SupportedLicenseType`][polaris.utils.types.SupportedLicenseType] for accepted ID values. - curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI. - For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. + + For additional meta-data attributes, see the [`BaseDataset`][polaris.dataset._base.BaseDataset] class. Raises: InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. @@ -78,29 +52,12 @@ class Dataset(BaseArtifactModel, ChecksumMixin): # Public attributes # Data - table: Union[pd.DataFrame, str] - default_adapters: Dict[str, Adapter] = Field(default_factory=dict) - zarr_root_path: Optional[str] = None - - # Additional meta-data - readme: str = "" - annotations: Dict[str, ColumnAnnotation] = Field(default_factory=dict) - source: Optional[HttpUrlString] = None - license: Optional[SupportedLicenseType] = None - curation_reference: Optional[HttpUrlString] = None - - # Config - cache_dir: Optional[Path] = None # Where to cache the data to if cache() is called. - - # Private attributes - _zarr_root: Optional[zarr.Group] = PrivateAttr(None) - _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) - _md5sum: Optional[str] = PrivateAttr(None) + table: pd.DataFrame + + version: ClassVar[Literal[1]] = 1 _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) - _client = PrivateAttr(None) # Optional[PolarisHubClient] - _warn_about_remote_zarr: bool = PrivateAttr(True) - @field_validator("table") + @field_validator("table", mode="before") def _validate_table(cls, v): """ If the table is not a dataframe yet, assume it's a path and try load it. 
@@ -121,56 +78,26 @@ def _validate_table(cls, v): return v @model_validator(mode="after") - @classmethod - def _validate_model(cls, m: "Dataset"): + def _validate_v1_dataset_model(self): """Verifies some dependencies between properties""" - # Verify that all annotations are for columns that exist - if any(k not in m.table.columns for k in m.annotations): - raise InvalidDatasetError("There are annotations for columns that do not exist") - - # Verify that all adapters are for columns that exist - if any(k not in m.table.columns for k in m.default_adapters.keys()): - raise InvalidDatasetError("There are default adapters for columns that do not exist") - - has_pointers = any(anno.is_pointer for anno in m.annotations.values()) - if has_pointers and m.zarr_root_path is None: + has_pointers = any(anno.is_pointer for anno in self.annotations.values()) + if has_pointers and self.zarr_root_path is None: raise InvalidDatasetError("A zarr_root_path needs to be specified when there are pointer columns") - if not has_pointers and m.zarr_root_path is not None: + if not has_pointers and self.zarr_root_path is not None: raise InvalidDatasetError( "The zarr_root_path should only be specified when there are pointer columns" ) - # Set a default for missing annotations and convert strings to Modality - for c in m.table.columns: - if c not in m.annotations: - m.annotations[c] = ColumnAnnotation() - m.annotations[c].dtype = m.table[c].dtype - # Set the default cache dir if none and make sure it exists - if m.cache_dir is None: - dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) - m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + if self.cache_dir is None: + dataset_id = self._md5sum if self.has_md5sum else str(uuid.uuid4()) + self.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + self.cache_dir.mkdir(parents=True, exist_ok=True) - m.cache_dir.mkdir(parents=True, exist_ok=True) - return m + return self - @field_validator("default_adapters", mode="before") - def _validate_adapters(cls, value): - """Validate the adapters""" - return {k: Adapter[v] if isinstance(v, str) else v for k, v in value.items()} - - @field_serializer("default_adapters") - def _serialize_adapters(self, value: List[Adapter]): - """Serializes the adapters""" - return {k: v.name for k, v in value.items()} - - @field_serializer("cache_dir") - def _serialize_cache_dir(value): - """Serialize the cacha_dir""" - return str(value) - - def _compute_checksum(self): + def _compute_checksum(self) -> str: """Computes a hash of the dataset. This is meant to uniquely identify the dataset and can be used to verify the version. @@ -192,7 +119,7 @@ def _compute_checksum(self): # If the Zarr archive exists, we hash its contents too. 
if self.uses_zarr: zarr_hash, self._zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) - hash_fn.update(zarr_hash.encode()) + hash_fn.update(zarr_hash.digest.encode()) checksum = hash_fn.hexdigest() return checksum @@ -211,125 +138,23 @@ def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: return self._zarr_md5sum_manifest @property - def client(self): - """The Polaris Hub client used to interact with the Polaris Hub.""" - - # Import it here to prevent circular imports - from polaris.hub.client import PolarisHubClient - - if self._client is None: - self._client = PolarisHubClient() - return self._client - - @property - def uses_zarr(self) -> bool: - """Whether any of the data in this dataset is stored in a Zarr Archive.""" - return self.zarr_root_path is not None - - @property - def zarr_data(self): - """Get the Zarr data. - - This is different from the Zarr Root, because to optimize the efficiency of - data loading, a user can choose to load the data into memory as a numpy array - - Note: General purpose dataloader. - The goal with Polaris is to provide general purpose datasets that serve as good - options for a _wide variety of use cases_. This also implies you should be able to - optimize things further for a specific use case if needed. - """ - if self._zarr_data is not None: - return self._zarr_data - return self.zarr_root - - @property - def zarr_root(self): - """Get the zarr Group object corresponding to the root. - - Opens the zarr archive in read-write mode if it is not already open. - - Note: Different to `zarr_data` - The `zarr_data` attribute references either to the Zarr archive or to a in-memory copy of the data. - See also [`Dataset.load_to_memory`][polaris.dataset.Dataset.load_to_memory]. - """ - - if self._zarr_root is not None: - return self._zarr_root - - if self.zarr_root_path is None: - return None - - # We open the archive in read-only mode if it is saved on the Hub - saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) - - if self._warn_about_remote_zarr: - saved_remote = saved_on_hub or not Path(self.zarr_root_path).exists() - - if saved_remote: - logger.warning( - f"You're loading data from a remote location. " - f"To speed up this process, consider caching the dataset first " - f"using {self.__class__.__name__}.cache()" - ) - self._warn_about_remote_zarr = False - - try: - if saved_on_hub: - self._zarr_root = self.client.open_zarr_file(self.owner, self.name, self.zarr_root_path, "r+") - else: - # We use memory mapping by default because our experiments show that it's consistently faster - store = MemoryMappedDirectoryStore(self.zarr_root_path) - self._zarr_root = zarr.open_consolidated(store, mode="r+") - except KeyError as error: - raise InvalidDatasetError( - "A Zarr archive associated with a Polaris dataset has to be consolidated." 
- ) from error - return self._zarr_root - - @computed_field - @property - def n_rows(self) -> int: - """The number of rows in the dataset.""" - return len(self.rows) - - @computed_field - @property - def n_columns(self) -> int: - """The number of columns in the dataset.""" - return len(self.columns) - - @property - def rows(self) -> list: + def rows(self) -> list[str | int]: """Return all row indices for the dataset""" return self.table.index.tolist() @property - def columns(self) -> list: + def columns(self) -> list[str]: """Return all columns for the dataset""" return self.table.columns.tolist() - def load_to_memory(self): - """ - Load data from zarr files to memeory - - Warning: Make sure the uncompressed dataset fits in-memory. - This method will load the **uncompressed** dataset into memory. Make - sure you actually have enough memory to store the dataset. - """ - data = self.zarr_data - - if not isinstance(data, zarr.Group): - raise TypeError( - "The dataset zarr_root is not a valid Zarr archive. " - "Did you call Dataset.load_to_memory() twice?" - ) - - # NOTE (cwognum): If the dataset fits in memory, the most performant is to use plain NumPy arrays. - # Even if we disable chunking and compression in Zarr. - # For more information, see https://github.com/zarr-developers/zarr-python/issues/1395 - self._zarr_data = load_zarr_group_to_memory(data) + @property + def dtypes(self) -> dict[str, np.dtype]: + """Return the dtype for each of the columns for the dataset""" + return {col: self.table[col].dtype for col in self.columns} - def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = None) -> np.ndarray: + def get_data( + self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None + ) -> np.ndarray | Any: """Since the dataset might contain pointers to external files, data retrieval is more complicated than just indexing the `table` attribute. This method provides an end-point for seamlessly accessing the underlying data. @@ -346,7 +171,8 @@ def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = """ # Fetch adapters for dataset and a given column - adapters = adapters or self.default_adapters + # Partially override if the adapters parameter is specified. + adapters = {**self.default_adapters, **(adapters or {})} adapter = adapters.get(col) # If not a pointer, return it here. Apply adapter if specified. @@ -384,12 +210,11 @@ def upload_to_hub( @classmethod def from_json(cls, path: str): - """Loads a benchmark from a JSON file. - Overrides the method from the base class to remove the caching dir from the file to load from, - as that should be user dependent. + """ + Loads a dataset from a JSON file. Args: - path: Loads a benchmark specification from a JSON file. + path: The path to the JSON file to load the dataset from. """ with fsspec.open(path, "r") as f: data = json.load(f) @@ -400,6 +225,7 @@ def to_json( self, destination: str, if_exists: ZarrConflictResolution = "replace", + load_zarr_from_new_location: bool = False, ) -> str: """ Save the dataset to a destination directory as a JSON file. @@ -407,8 +233,7 @@ def to_json( Warning: Multiple files Perhaps unintuitive, this method creates multiple files. - 1. `/path/to/destination/dataset.json`: This file can be loaded with - [`Dataset.from_json`][polaris.dataset.Dataset.from_json]. + 1. `/path/to/destination/dataset.json`: This file can be loaded with `Dataset.from_json`. 2. 
`/path/to/destination/table.parquet`: The `Dataset.table` attribute is saved here. 3. _(Optional)_ `/path/to/destination/data/*`: Any additional blobs of data referenced by the pointer columns will be stored here. @@ -417,14 +242,18 @@ def to_json( destination: The _directory_ to save the associated data to. if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. + load_zarr_from_new_location: Whether to update the current instance to load data from the location + the data is saved to. Only relevant for Zarr-datasets. Returns: The path to the JSON file. """ - dmfs.mkdir(destination, exist_ok=True) - table_path = dmfs.join(destination, "table.parquet") - dataset_path = dmfs.join(destination, "dataset.json") - new_zarr_root_path = dmfs.join(destination, "data.zarr") + destination = Path(destination) + destination.mkdir(exist_ok=True, parents=True) + + table_path = str(destination / "table.parquet") + dataset_path = str(destination / "dataset.json") + new_zarr_root_path = str(destination / "data.zarr") # Lu: Avoid serilizing and sending None to hub app. serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) @@ -432,6 +261,8 @@ def to_json( # Copy over Zarr data to the destination if self.uses_zarr: + serialized["zarrRootPath"] = new_zarr_root_path + self._warn_about_remote_zarr = False logger.info(f"Copying Zarr archive to {new_zarr_root_path}. This may take a while.") @@ -445,42 +276,29 @@ def to_json( if_exists=if_exists, ) + if load_zarr_from_new_location: + self.zarr_root_path = new_zarr_root_path + self._zarr_root = None + self._zarr_data = None + self.table.to_parquet(table_path) with fsspec.open(dataset_path, "w") as f: json.dump(serialized, f) return dataset_path - def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) -> str: - """Caches the dataset by downloading all additional data for pointer columns to a local directory. + def cache(self, verify_checksum: bool = False): + """Cache the dataset to the cache directory. Args: - cache_dir: The directory to cache the data to. If not provided, - this will fall back to the `Dataset.cache_dir` attribute verify_checksum: Whether to verify the checksum of the dataset after caching. - - Returns: - The path to the cache directory. """ - - if cache_dir is not None: - self.cache_dir = cache_dir - - self.to_json(self.cache_dir) - - if self.uses_zarr: - self.zarr_root_path = dmfs.join(self.cache_dir, "data.zarr") - self._zarr_root = None - + dst = super().cache() if verify_checksum: self.verify_checksum() + return dst - return self.cache_dir - - def size(self): - return self.rows, self.n_columns - - def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: + def _split_index_from_path(self, path: str) -> tuple[str, int | None]: """ Paths can have an additional index appended to them. This extracts that index from the path. 
@@ -498,58 +316,13 @@ def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: raise ValueError(f"Invalid index format: {index}") return path, index - def __getitem__(self, item): - """Allows for indexing the dataset directly""" - ret = self.table.loc[item] - if isinstance(ret, pd.Series): - # Load the data from the pointer columns - - if ret.name in self.table.columns: - # Returning a column, the indices are rows - if self.annotations[ret.name].is_pointer: - ret = np.array([self.get_data(k, ret.name) for k in ret.index]) - - elif len(ret) == self.n_rows: - # Returning a row, the indices are columns - ret = { - k: self.get_data(k, ret.name) if self.annotations[ret.name].is_pointer else ret[k] - for k in ret.index - } - - # Returning a dataframe - if isinstance(ret, pd.DataFrame): - for c in ret.columns: - if self.annotations[c].is_pointer: - ret[c] = [self.get_data(item, c) for item in ret.index] - return ret - - return ret - def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"table", "zarr_md5sum_manifest"}) return repr_dict - def _repr_html_(self): - """For pretty-printing in Jupyter Notebooks""" - return dict2html(self._repr_dict_()) - - def __len__(self): - return self.n_rows - - def __repr__(self): - return json.dumps(self._repr_dict_(), indent=2) - - def __str__(self): - return self.__repr__() - def __eq__(self, other): """Whether two datasets are equal is solely determined by the checksum""" - if not isinstance(other, Dataset): + if not isinstance(other, DatasetV1): return False return self.md5sum == other.md5sum - - def __del__(self): - """Close the connection of the client""" - if self._client is not None: - self._client.close() diff --git a/polaris/dataset/_factory.py b/polaris/dataset/_factory.py index c69cdba8..dfd550b0 100644 --- a/polaris/dataset/_factory.py +++ b/polaris/dataset/_factory.py @@ -6,12 +6,12 @@ import zarr from loguru import logger -from polaris.dataset import ColumnAnnotation, Dataset +from polaris.dataset import ColumnAnnotation, DatasetV1 from polaris.dataset._adapters import Adapter from polaris.dataset.converters import Converter, PDBConverter, SDFConverter, ZarrConverter -def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> Dataset: +def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> DatasetV1: """ This function is a convenience function to create a dataset from a file. @@ -29,7 +29,7 @@ def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> def create_dataset_from_files( paths: List[str], zarr_root_path: Optional[str] = None, axis: Literal[0, 1, "index", "columns"] = 0 -) -> Dataset: +) -> DatasetV1: """ This function is a convenience function to create a dataset from multiple files. 
@@ -265,10 +265,10 @@ def add_from_files(self, paths: List[str], axis: Literal[0, 1, "index", "columns for path in paths: self.add_from_file(path) - def build(self) -> Dataset: + def build(self) -> DatasetV1: """Returns a Dataset based on the current state of the factory.""" zarr.consolidate_metadata(self.zarr_root.store) - return Dataset( + return DatasetV1( table=self._table, annotations=self._annotations, default_adapters=self._adapters, diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 448abbff..dfb95869 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -2,7 +2,7 @@ import numpy as np -from polaris.dataset import Dataset +from polaris.dataset import DatasetV1 from polaris.dataset._adapters import Adapter from polaris.utils.errors import TestAccessError from polaris.utils.types import DatapointType @@ -61,12 +61,12 @@ class Subset: def __init__( self, - dataset: Dataset, - indices: List[Union[int, Sequence[int]]], - input_cols: Union[List[str], str], - target_cols: Union[List[str], str], - adapters: Optional[List[Adapter]] = None, - featurization_fn: Optional[Callable] = None, + dataset: DatasetV1, + indices: List[int | Sequence[int]], + input_cols: List[str] | str, + target_cols: List[str] | str, + adapters: dict[str, Adapter] | None = None, + featurization_fn: Callable | None = None, hide_targets: bool = False, ): self.dataset = dataset @@ -77,10 +77,13 @@ def __init__( self._adapters = adapters self._featurization_fn = featurization_fn - # NOTE (cwognum): Note to future self. As we're starting to think about competition-style benchmarks, - # we will likely split up datasets. In that case, this default iloc_to_loc mapping won't work. - # By that time, we should probably be able to overwrite this mapping. - self._iloc_to_loc = self.dataset.table.index + # Storing all indices in memory can be memory consuming for XXL datasets. + # This is why we constrain the iloc to loc mapping to be the identity function for Dataset V2. + match self.dataset: + case DatasetV1(): + self._iloc_to_loc = lambda idx: self.dataset.rows[idx] + case _: + self._iloc_to_loc = lambda idx: idx # For the iterator implementation self._pointer = 0 @@ -167,9 +170,9 @@ def as_array(self, data_type: Union[Literal["x"], Literal["y"], Literal["xy"]]): # We reset the index of the Pandas Table during Dataset class validation. # We can thus always assume that .iloc[idx] is the same as .loc[idx]. if data_type == "x": - ret = [self._get_single_input(self._iloc_to_loc[idx]) for idx in self.indices] + ret = [self._get_single_input(self._iloc_to_loc(idx)) for idx in self.indices] else: - ret = [self._get_single_output(self._iloc_to_loc[idx]) for idx in self.indices] + ret = [self._get_single_output(self._iloc_to_loc(idx)) for idx in self.indices] if not ((self.is_multi_input and data_type == "x") or (self.is_multi_task and data_type == "y")): # If the target format is not a dict, we can just create the array directly. @@ -202,9 +205,10 @@ def __getitem__(self, item) -> DatapointType: """ idx = self.indices[item] + idx = self._iloc_to_loc(idx) # Load the input modalities - ins = self._get_single_input(self._iloc_to_loc[idx]) + ins = self._get_single_input(idx) if self._hide_targets: # If we are not allowed to access the targets, we return the inputs only. 
@@ -212,7 +216,7 @@ def __getitem__(self, item) -> DatapointType: return ins # Retrieve the targets - outs = self._get_single_output(self._iloc_to_loc[idx]) + outs = self._get_single_output(idx) return ins, outs def __iter__(self): diff --git a/polaris/dataset/converters/_pdb.py b/polaris/dataset/converters/_pdb.py index ac0a0e15..5694bd52 100644 --- a/polaris/dataset/converters/_pdb.py +++ b/polaris/dataset/converters/_pdb.py @@ -7,7 +7,7 @@ import zarr from fastpdb import struc -from polaris.dataset import ColumnAnnotation, Modality, KnownContentType +from polaris.dataset import ColumnAnnotation, Modality from polaris.dataset._adapters import Adapter from polaris.dataset.converters._base import Converter, FactoryProduct @@ -190,7 +190,7 @@ def convert(self, path, factory: "DatasetFactory", append: bool = False) -> Fact # Set the annotations annotations = { self.pdb_column: ColumnAnnotation( - is_pointer=True, modality=Modality.PROTEIN_3D, content_type=KnownContentType.PDB + is_pointer=True, modality=Modality.PROTEIN_3D, content_type="chemical/x-pdb" ) } diff --git a/polaris/dataset/converters/_sdf.py b/polaris/dataset/converters/_sdf.py index 5a993fb7..2cde7acb 100644 --- a/polaris/dataset/converters/_sdf.py +++ b/polaris/dataset/converters/_sdf.py @@ -5,7 +5,7 @@ import pandas as pd from rdkit import Chem -from polaris.dataset import ColumnAnnotation, Modality, KnownContentType +from polaris.dataset import ColumnAnnotation, Modality from polaris.dataset._adapters import Adapter from polaris.dataset.converters._base import Converter, FactoryProduct @@ -149,7 +149,7 @@ def _get_name(mol: dm.Mol): annotations = {self.mol_column: ColumnAnnotation(is_pointer=True, modality=Modality.MOLECULE_3D)} if self.smiles_column is not None: annotations[self.smiles_column] = ColumnAnnotation( - modality=Modality.MOLECULE, content_type=KnownContentType.SMILES + modality=Modality.MOLECULE, content_type="chemical/x-smiles" ) # Return the dataframe and the annotations diff --git a/polaris/dataset/zarr/__init__.py b/polaris/dataset/zarr/__init__.py index 57f500ed..b936b607 100644 --- a/polaris/dataset/zarr/__init__.py +++ b/polaris/dataset/zarr/__init__.py @@ -1,4 +1,10 @@ from ._checksum import ZarrFileChecksum, compute_zarr_checksum +from ._manifest import generate_zarr_manifest from ._memmap import MemoryMappedDirectoryStore -__all__ = ["MemoryMappedDirectoryStore", "compute_zarr_checksum", "ZarrFileChecksum"] +__all__ = [ + "MemoryMappedDirectoryStore", + "compute_zarr_checksum", + "ZarrFileChecksum", + "generate_zarr_manifest", +] diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index a06f9491..8a1a8348 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -52,7 +52,7 @@ ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)" -def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, List["ZarrFileChecksum"]]: +def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest", List["ZarrFileChecksum"]]: r""" Implements an algorithm to compute the Zarr checksum. 
@@ -145,7 +145,7 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, List["ZarrFileCheck
         zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))
 
     # Compute digest
-    return tree.process().digest, zarr_md5sum_manifest
+    return tree.process(), zarr_md5sum_manifest
 
 
 class ZarrFileChecksum(BaseModel):
diff --git a/polaris/dataset/zarr/_manifest.py b/polaris/dataset/zarr/_manifest.py
new file mode 100644
index 00000000..b7578ce4
--- /dev/null
+++ b/polaris/dataset/zarr/_manifest.py
@@ -0,0 +1,72 @@
+import os
+from hashlib import md5
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+# PyArrow table schema for the V2 Zarr manifest file
+ZARR_MANIFEST_SCHEMA = pa.schema([("path", pa.string()), ("checksum", pa.string())])
+
+
+def generate_zarr_manifest(zarr_root_path: str, output_dir: str):
+    """
+    Entry point function which triggers the creation of a Zarr manifest for a V2 dataset.
+
+    Parameters:
+        zarr_root_path: The path to the root of a Zarr archive
+        output_dir: The path to the directory which will hold the generated manifest
+    """
+
+    zarr_manifest_path = f"{output_dir}/zarr_manifest.parquet"
+
+    with pq.ParquetWriter(zarr_manifest_path, ZARR_MANIFEST_SCHEMA) as writer:
+        recursively_build_manifest(zarr_root_path, writer, zarr_root_path)
+
+    return zarr_manifest_path
+
+
+def recursively_build_manifest(dir_path: str, writer: pq.ParquetWriter, zarr_root_path: str) -> None:
+    """
+    Recursive function that traverses a Zarr archive to build a V2 manifest file.
+
+    Parameters:
+        dir_path: The path to the current directory being processed in the archive
+        writer: Writer object for incrementally adding rows to the manifest Parquet file
+        zarr_root_path: The root path which triggered the first recursive call
+    """
+
+    # Get iterator of items located in the directory at `dir_path`
+    with os.scandir(dir_path) as it:
+        #
+        # Loop through directory items in iterator
+        for entry in it:
+            if entry.is_dir():
+                #
+                # If item is a directory, recurse into that directory
+                recursively_build_manifest(entry.path, writer, zarr_root_path)
+            elif entry.is_file():
+                #
+                # If item is a file, calculate its relative path and chunk checksum. Then, append that
+                # to the Zarr manifest parquet.
+                table = pa.Table.from_pydict(
+                    {
+                        "path": [os.path.relpath(entry.path, zarr_root_path)],
+                        "checksum": [calculate_file_md5(entry.path)],
+                    },
+                    schema=ZARR_MANIFEST_SCHEMA,
+                )
+                writer.write_table(table)
+
+
+def calculate_file_md5(file_path: str):
+    """Calculates the md5 hash for a file at a given path"""
+
+    md5_hash = md5()
+    with open(file_path, "rb") as file:
+        #
+        # Read the file in chunks to avoid using too much memory for large files
+        for chunk in iter(lambda: file.read(4096), b""):
+            md5_hash.update(chunk)
+
+    # Return the hex representation of the digest
+    return md5_hash.hexdigest()
diff --git a/polaris/experimental/__init__.py b/polaris/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py
new file mode 100644
index 00000000..dc34aff5
--- /dev/null
+++ b/polaris/experimental/_dataset_v2.py
@@ -0,0 +1,261 @@
+import json
+import re
+import uuid
+from pathlib import Path
+from typing import Any, ClassVar, Literal
+
+import fsspec
+import numpy as np
+import zarr
+from loguru import logger
+from pydantic import PrivateAttr, computed_field, model_validator
+
+from polaris.dataset._adapters import Adapter
+from polaris.dataset._base import _CACHE_SUBDIR, BaseDataset
+from polaris.dataset.zarr._manifest import calculate_file_md5, generate_zarr_manifest
+from polaris.utils.constants import DEFAULT_CACHE_DIR
+from polaris.utils.errors import InvalidDatasetError
+from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution
+
+_INDEX_ARRAY_KEY = "__index__"
+
+
+class DatasetV2(BaseDataset):
+    """Second version of a Polaris Dataset.
+
+    This version gets rid of the DataFrame and stores all data in a Zarr archive.
+
+    V1 stored all datapoints in a Pandas DataFrame. Because a DataFrame is always loaded to memory,
+    this was a bottleneck when the number of data points grew large. Even with the pointer columns, you still
+    need to load all pointers into memory. V2 therefore switches to a Zarr-only format.
+
+    Info: This feature is still experimental
+        The DatasetV2 is in active development and will likely undergo breaking changes before release.
+
+    Attributes:
+        zarr_root_path: The path to the Zarr archive. Unlike in V1, this is now required.
+
+    For additional meta-data attributes, see the [`BaseDataset`][polaris.dataset._base.BaseDataset] class.
+
+    Raises:
+        InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification.
+    """
+
+    version: ClassVar[Literal[2]] = 2
+    _zarr_manifest_path: str | None = PrivateAttr(None)
+    _zarr_manifest_md5sum: str | None = PrivateAttr(None)
+
+    # Redefine this to make it a required field
+    zarr_root_path: str
+
+    @model_validator(mode="after")
+    def _validate_v2_dataset_model(self):
+        """Verifies some dependencies between properties"""
+
+        # Since the keys for subgroups are not ordered, we have no easy way to index these groups.
+        # Any subgroup should therefore have a special array that defines the index for that group.
+        for group in self.zarr_root.group_keys():
+            if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys():
+                raise InvalidDatasetError(f"Group {group} does not have an index array.")
+
+            index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY]
+            if len(index_arr) != len(self.zarr_root[group]) - 1:
+                raise InvalidDatasetError(
+                    f"Length of index array for group {group} does not match the size of the group."
+ ) + if any(x not in self.zarr_root[group] for x in index_arr): + raise InvalidDatasetError( + f"Keys of index array for group {group} does not match the group members." + ) + + # Check the structure of the Zarr archive + # All arrays or groups in the root should have the same length. + lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()} + lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()}) + if len(lengths) > 1: + raise InvalidDatasetError( + f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" + ) + + # Set the default cache dir if none and make sure it exists + if self.cache_dir is None: + dataset_id = self._zarr_manifest_md5sum if self.has_zarr_manifest_md5sum else str(uuid.uuid4()) + self.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + self.cache_dir.mkdir(parents=True, exist_ok=True) + + return self + + @property + def n_rows(self) -> int: + """Return all row indices for the dataset""" + example = self.zarr_root[self.columns[0]] + if isinstance(example, zarr.Group): + return len(example[_INDEX_ARRAY_KEY]) + return len(example) + + @property + def rows(self) -> np.ndarray[int]: + """Return all row indices for the dataset + + Warning: Memory consumption + This feature is added for completeness sake, but when datasets get large could consume a lot of memory. + E.g. storing a billion indices with np.in64 would consume 8GB of memory. Use with caution. + """ + return np.arange(len(self), dtype=int) + + @property + def columns(self) -> list[str]: + """Return all columns for the dataset""" + return list(self.zarr_root.keys()) + + @property + def dtypes(self) -> dict[str, np.dtype]: + """Return the dtype for each of the columns for the dataset""" + dtypes = {} + for arr in self.zarr_root.array_keys(): + dtypes[arr] = self.zarr_root[arr].dtype + for group in self.zarr_root.group_keys(): + dtypes[group] = np.dtype(object) + return dtypes + + @property + def zarr_manifest_path(self) -> str: + if self._zarr_manifest_path is None: + zarr_manifest_path = generate_zarr_manifest(self.zarr_root_path, self.cache_dir) + self._zarr_manifest_path = zarr_manifest_path + + return self._zarr_manifest_path + + @computed_field + @property + def zarr_manifest_md5sum(self) -> str: + """ + Lazily compute the checksum once needed. + + The checksum of the DatasetV2 is computed from the Zarr Manifest and is _not_ deterministic. + """ + if not self.has_zarr_manifest_md5sum: + logger.info("Computing the checksum. This can be slow for large datasets.") + self.zarr_manifest_md5sum = calculate_file_md5(self.zarr_manifest_path) + return self._zarr_manifest_md5sum + + @zarr_manifest_md5sum.setter + def zarr_manifest_md5sum(self, value: str): + """Set the checksum.""" + if not re.fullmatch(r"^[a-f0-9]{32}$", value): + raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.") + self._zarr_manifest_md5sum = value + + @property + def has_zarr_manifest_md5sum(self) -> bool: + """Whether the md5sum for this dataset's zarr manifest has been computed and stored.""" + return self._zarr_manifest_md5sum is not None + + def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray | Any: + """Indexes the Zarr archive. + + Args: + row: The index of the data to fetch. + col: The label of a direct child of the Zarr root. + adapters: The adapters to apply to the data before returning it. 
+    def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray | Any:
+        """Indexes the Zarr archive.
+
+        Args:
+            row: The index of the data to fetch.
+            col: The label of a direct child of the Zarr root.
+            adapters: The adapters to apply to the data before returning it.
+                If None, will use the default adapters specified for the dataset.
+
+        Returns:
+            A numpy array with the data at the specified row and column. If the column is a Zarr group,
+            the group's index array is used to look up the corresponding member array.
+        """
+        # Fetch the adapters for the dataset and the given column.
+        # Partially override if the adapters parameter is specified.
+        adapters = {**self.default_adapters, **(adapters or {})}
+        adapter = adapters.get(col)
+
+        # Get the data
+        group_or_array = self.zarr_data[col]
+
+        # If it is a group, there is no deterministic order for the child keys.
+        # We therefore use the special index array to translate the row index into a key.
+        if isinstance(group_or_array, zarr.Group):
+            row = group_or_array[_INDEX_ARRAY_KEY][row]
+        arr = group_or_array[row]
+
+        # Adapt the input to the specified format
+        if adapter is not None:
+            arr = adapter(arr)
+
+        return arr
+
+    def upload_to_hub(self, access: AccessType = "private", owner: HubOwner | str | None = None):
+        """Uploads the dataset to the Polaris Hub."""
+
+        # NOTE (cwognum): Leaving this for a later PR, because I want
+        # to do it simultaneously with a PR on the Hub side.
+        raise NotImplementedError
+
+    @classmethod
+    def from_json(cls, path: str):
+        """
+        Loads a dataset from a JSON file.
+
+        Args:
+            path: The path to the JSON file to load the dataset from.
+        """
+        with fsspec.open(path, "r") as f:
+            data = json.load(f)
+
+        data.pop("cache_dir", None)
+        return cls.model_validate(data)
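As an aside, the intended round trip between `from_json` above and the `to_json` method defined just below looks roughly as follows. This is a usage sketch only: the dataset name and paths are hypothetical, and the Zarr archive is assumed to already satisfy the validation rules above.

```python
from polaris.experimental._dataset_v2 import DatasetV2

# Hypothetical archive that follows the layout required by the validator.
dataset = DatasetV2(name="my-dataset", zarr_root_path="data/my_archive.zarr")

# Writes dataset.json into the directory and copies the Zarr archive next to it as data.zarr.
json_path = dataset.to_json("exported/", load_zarr_from_new_location=True)

# Reconstructs the dataset from the JSON metadata, which now points at the copied archive.
reloaded = DatasetV2.from_json(json_path)
assert reloaded.columns == dataset.columns
```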
+    def to_json(
+        self,
+        destination: str,
+        if_exists: ZarrConflictResolution = "replace",
+        load_zarr_from_new_location: bool = False,
+    ) -> str:
+        """
+        Save the dataset to a destination directory as a JSON file.
+
+        Args:
+            destination: The _directory_ to save the associated data to.
+            if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw
+                an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files.
+            load_zarr_from_new_location: Whether to update the current instance to load data from the
+                location the data is saved to.
+
+        Returns:
+            The path to the JSON file.
+        """
+        destination = Path(destination)
+        destination.mkdir(exist_ok=True, parents=True)
+
+        dataset_path = str(destination / "dataset.json")
+        new_zarr_root_path = str(destination / "data.zarr")
+
+        # Lu: Avoid serializing and sending None to the Hub app.
+        serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True)
+        serialized["zarrRootPath"] = new_zarr_root_path
+
+        # Copy over the Zarr data to the destination
+        self._warn_about_remote_zarr = False
+
+        logger.info(f"Copying Zarr archive to {new_zarr_root_path}. This may take a while.")
+        dest = zarr.open(new_zarr_root_path, "w")
+
+        zarr.copy_store(
+            source=self.zarr_root.store.store,
+            dest=dest.store,
+            log=logger.debug,
+            if_exists=if_exists,
+        )
+
+        if load_zarr_from_new_location:
+            self.zarr_root_path = new_zarr_root_path
+            self._zarr_root = None
+            self._zarr_data = None
+
+        with fsspec.open(dataset_path, "w") as f:
+            json.dump(serialized, f)
+        return dataset_path
+
+    def _repr_dict_(self) -> dict:
+        """Utility function for pretty-printing to the command line and jupyter notebooks"""
+        # Exclude the lazily computed checksum so that printing the dataset
+        # does not trigger (potentially slow) manifest generation.
+        repr_dict = self.model_dump(exclude={"zarr_manifest_md5sum"})
+        return repr_dict
diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index 7d2839c5..f2f9e796 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -2,8 +2,7 @@
 import ssl
 from hashlib import md5
 from io import BytesIO
-from typing import Callable, get_args
-from typing import Union
+from typing import Callable, Union, get_args
 from urllib.parse import urljoin

 import certifi
@@ -23,14 +22,12 @@
     MultiTaskBenchmarkSpecification,
     SingleTaskBenchmarkSpecification,
 )
-from polaris.dataset import Dataset
-from polaris.evaluate import BenchmarkResults
+from polaris.competition import CompetitionSpecification
+from polaris.dataset import CompetitionDataset, DatasetV1
+from polaris.evaluate import BenchmarkResults, CompetitionResults
 from polaris.evaluate._results import CompetitionPredictions
 from polaris.hub.external_auth_client import ExternalAuthClient
 from polaris.hub.oauth import CachedTokenAuth
-from polaris.dataset import CompetitionDataset
-from polaris.evaluate import CompetitionResults
-from polaris.competition import CompetitionSpecification
 from polaris.hub.polarisfs import PolarisFileSystem
 from polaris.hub.settings import PolarisHubSettings
 from polaris.utils.context import ProgressIndicator, tmp_attribute_change
@@ -296,7 +293,7 @@ def get_dataset(
         owner: str | HubOwner,
         name: str,
         verify_checksum: ChecksumStrategy = "verify_unless_zarr",
-    ) -> Dataset:
+    ) -> DatasetV1:
         """Load a standard dataset from the Polaris Hub.
Args: @@ -316,7 +313,7 @@ def _get_dataset( name: str, artifact_type: ArtifactSubtype, verify_checksum: bool = True, - ) -> Dataset: + ) -> DatasetV1: """Loads either a standard or competition dataset from Polaris Hub Args: @@ -360,7 +357,7 @@ def _get_dataset( dataset = CompetitionDataset(**response) md5Sum = response["maskedMd5Sum"] else: - dataset = Dataset(**response) + dataset = DatasetV1(**response) md5Sum = response["md5Sum"] if should_verify_checksum(verify_checksum, dataset): @@ -536,7 +533,7 @@ def upload_results( def upload_dataset( self, - dataset: Dataset, + dataset: DatasetV1, access: AccessType = "private", timeout: TimeoutTypes = (10, 200), owner: HubOwner | str | None = None, @@ -549,7 +546,7 @@ def upload_dataset( def _upload_dataset( self, - dataset: Dataset, + dataset: DatasetV1, artifact_type: ArtifactSubtype, access: AccessType = "private", timeout: TimeoutTypes = (10, 200), diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 6e152f68..797f7b78 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -7,13 +7,13 @@ MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification, ) -from polaris.dataset import Dataset, create_dataset_from_file +from polaris.dataset import DatasetV1, create_dataset_from_file from polaris.hub.client import PolarisHubClient from polaris.utils.misc import should_verify_checksum from polaris.utils.types import ChecksumStrategy -def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr") -> Dataset: +def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr") -> DatasetV1: """ Loads a Polaris dataset. @@ -41,7 +41,7 @@ def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_z # Load from local file if extension == "json": - dataset = Dataset.from_json(path) + dataset = DatasetV1.from_json(path) else: dataset = create_dataset_from_file(path) diff --git a/polaris/utils/errors.py b/polaris/utils/errors.py index 3e800847..ad7726bf 100644 --- a/polaris/utils/errors.py +++ b/polaris/utils/errors.py @@ -1,7 +1,5 @@ from httpx import Response -from polaris.mixins._format_text import FormattingMixin # Imported with full path to avoid circular import - class InvalidDatasetError(ValueError): pass @@ -34,7 +32,7 @@ class InvalidZarrChecksum(Exception): pass -class PolarisHubError(Exception, FormattingMixin): +class PolarisHubError(Exception): def __init__(self, message: str = "", response: Response | None = None): prefix = "The request to the Polaris Hub failed." @@ -50,7 +48,6 @@ def __init__(self, response: Response | None = None): "You are not logged in to Polaris or your login has expired. " "You can use the Polaris CLI to easily authenticate yourself again with `polaris login --overwrite`." ) - message = self.format(message, [self.BOLD, self.YELLOW]) super().__init__(message, response) @@ -60,7 +57,6 @@ def __init__(self, response: Response | None = None): "Note: If you can confirm that you are authorized to perform this action, " "please call 'polaris login --overwrite' and try again. If the issue persists, please reach out to the Polaris team for support." ) - message = self.format(message, [self.BOLD, self.YELLOW]) super().__init__(message, response) @@ -70,5 +66,4 @@ def __init__(self, response: Response | None = None): "Note: If this artifact exists and you can confirm that you are authorized to retrieve it, " "please call 'polaris login --overwrite' and try again. If the issue persists, please reach out to the Polaris team for support." 
) - message = self.format(message, [self.BOLD, self.YELLOW]) super().__init__(message, response) diff --git a/polaris/utils/misc.py b/polaris/utils/misc.py index 9a8199eb..2622fde4 100644 --- a/polaris/utils/misc.py +++ b/polaris/utils/misc.py @@ -3,7 +3,7 @@ from polaris.utils.types import ChecksumStrategy, SlugCompatibleStringType if TYPE_CHECKING: - from polaris.dataset import Dataset + from polaris.dataset import DatasetV1 def listit(t: Any): @@ -21,7 +21,7 @@ def sluggify(sluggable: SlugCompatibleStringType): return sluggable.lower().replace("_", "-") -def should_verify_checksum(strategy: ChecksumStrategy, dataset: "Dataset") -> bool: +def should_verify_checksum(strategy: ChecksumStrategy, dataset: "DatasetV1") -> bool: """ Determines whether a checksum should be verified. """ diff --git a/polaris/utils/types.py b/polaris/utils/types.py index e1cac444..b34b6e55 100644 --- a/polaris/utils/types.py +++ b/polaris/utils/types.py @@ -118,6 +118,19 @@ Type to specify which action to take to verify the data integrity of an artifact through a checksum. """ +RowIndex: TypeAlias = int | str +ColumnIndex: TypeAlias = str +DatasetIndex: TypeAlias = RowIndex | tuple[RowIndex, ColumnIndex] +""" +To index a dataset using square brackets, we have a few options: + +- A single row, e.g. dataset[0] +- Specify a specific value, e.g. dataset[0, "col1"] + +There are more exciting options we could implement, such as slicing, +but this gets complex. +""" + class HubOwner(BaseModel): """An owner of an artifact on the Polaris Hub diff --git a/tests/conftest.py b/tests/conftest.py index d9190f58..a0e19c23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,8 @@ SingleTaskBenchmarkSpecification, ) from polaris.competition import CompetitionSpecification -from polaris.dataset import ColumnAnnotation, Dataset, CompetitionDataset +from polaris.dataset import ColumnAnnotation, CompetitionDataset, DatasetV1 +from polaris.experimental._dataset_v2 import DatasetV2 from polaris.utils.types import HubOwner @@ -108,8 +109,8 @@ def test_user_owner(): @pytest.fixture(scope="function") -def test_dataset(test_data, test_org_owner): - dataset = Dataset( +def test_dataset(test_data, test_org_owner) -> DatasetV1: + dataset = DatasetV1( table=test_data, name="test-dataset", source="https://www.example.com", @@ -124,6 +125,23 @@ def test_dataset(test_data, test_org_owner): return dataset +@pytest.fixture(scope="function") +def test_dataset_v2(zarr_archive, test_org_owner) -> DatasetV2: + dataset = DatasetV2( + name="test-dataset-v2", + source="https://www.example.com", + annotations={"A": ColumnAnnotation(user_attributes={"unit": "kcal/mol"})}, + tags=["tagA", "tagB"], + user_attributes={"attributeA": "valueA", "attributeB": "valueB"}, + owner=test_org_owner, + license="CC-BY-4.0", + curation_reference="https://www.example.com", + zarr_root_path=zarr_archive, + ) + check_version(dataset) + return dataset + + @pytest.fixture(scope="function") def test_competition_dataset(test_data, test_org_owner): dataset = CompetitionDataset( diff --git a/tests/test_dataset.py b/tests/test_dataset.py index db4336e1..5319bb7e 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -6,7 +6,7 @@ import zarr from datamol.utils import fs -from polaris.dataset import Dataset, Subset, create_dataset_from_file +from polaris.dataset import DatasetV1, Subset, create_dataset_from_file from polaris.loader import load_dataset @@ -27,10 +27,11 @@ def test_load_data(tmp_path, with_slice, with_caching): path = "A#0:5" if with_slice else 
"A#0" table = pd.DataFrame({"A": [path]}, index=[0]) - dataset = Dataset(table=table, annotations={"A": {"is_pointer": True}}, zarr_root_path=zarr_path) + dataset = DatasetV1(table=table, annotations={"A": {"is_pointer": True}}, zarr_root_path=zarr_path) if with_caching: - dataset.cache(fs.join(tmpdir, "cache")) + dataset.cache_dir = fs.join(tmpdir, "cache") + dataset.cache() data = dataset.get_data(row=0, col="A") @@ -56,28 +57,28 @@ def test_dataset_checksum(test_dataset): # Without any changes, same hash kwargs = test_dataset.model_dump() - assert Dataset(**kwargs) == test_dataset + assert DatasetV1(**kwargs) == test_dataset # With unimportant changes, same hash kwargs["name"] = "changed" kwargs["description"] = "changed" kwargs["source"] = "https://changed.com" - assert Dataset(**kwargs) == test_dataset + assert DatasetV1(**kwargs) == test_dataset # Check sensitivity to the row and column ordering kwargs["table"] = kwargs["table"].iloc[::-1] kwargs["table"] = kwargs["table"][kwargs["table"].columns[::-1]] - assert Dataset(**kwargs) == test_dataset + assert DatasetV1(**kwargs) == test_dataset # Without any changes, but different hash - dataset = Dataset(**kwargs) + dataset = DatasetV1(**kwargs) dataset._md5sum = "invalid" assert dataset != test_dataset # With changes, but same hash kwargs["md5sum"] = test_dataset.md5sum kwargs["table"] = kwargs["table"].iloc[:-1] - assert Dataset(**kwargs) != test_dataset + assert DatasetV1(**kwargs) != test_dataset def test_dataset_from_zarr(zarr_archive, tmpdir): @@ -97,7 +98,7 @@ def test_dataset_from_json(test_dataset, tmpdir): path = fs.join(str(tmpdir), "dataset.json") - new_dataset = Dataset.from_json(path) + new_dataset = DatasetV1.from_json(path) assert test_dataset == new_dataset new_dataset = load_dataset(path) @@ -117,7 +118,7 @@ def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmpdir): dataset = create_dataset_from_file(archive, zarr_dir) path = dataset.to_json(json_dir) - new_dataset = Dataset.from_json(path) + new_dataset = DatasetV1.from_json(path) assert dataset == new_dataset new_dataset = load_dataset(path) @@ -131,7 +132,8 @@ def test_dataset_caching(zarr_archive, tmpdir): cached_dataset = create_dataset_from_file(zarr_archive, tmpdir.join("original2")) assert original_dataset == cached_dataset - cache_dir = cached_dataset.cache(tmpdir.join("cached").strpath) + cached_dataset.cache_dir = tmpdir.join("cached").strpath + cache_dir = cached_dataset.cache(verify_checksum=True) assert cached_dataset.zarr_root_path.startswith(cache_dir) assert cached_dataset == original_dataset @@ -140,7 +142,7 @@ def test_dataset_caching(zarr_archive, tmpdir): def test_dataset_index(): """Small test to check whether the dataset resets its index.""" df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["X", "Y", "Z"]) - dataset = Dataset(table=df) + dataset = DatasetV1(table=df) subset = Subset(dataset=dataset, indices=[1], input_cols=["A"], target_cols=["B"]) assert next(iter(subset)) == (np.array([2]), np.array([5])) @@ -177,3 +179,42 @@ def test_checksum_verification(test_dataset): test_dataset.md5sum = "0" * 32 with pytest.raises(ValueError): test_dataset.verify_checksum() + + +def test_dataset__get_item__(): + """Test the __getitem__() interface for the dataset.""" + + table = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}, index=["X", "Y", "Z"]) + dataset = DatasetV1(table=table) + + # Get a specific cell + assert dataset["X", "A"] == 1 + assert dataset["X", "B"] == 4 + assert dataset["Y", "A"] == 2 + assert dataset["Y", 
"B"] == 5 + assert dataset["Z", "A"] == 3 + assert dataset["Z", "B"] == 6 + + # Get a row + assert dataset["X"] == {"A": 1, "B": 4, "C": 7} + assert dataset["Y"] == {"A": 2, "B": 5, "C": 8} + assert dataset["Z"] == {"A": 3, "B": 6, "C": 9} + + +def test_dataset__get_item__with_pointer_columns(zarr_archive, tmpdir): + """Test the __getitem__() interface for a dataset with pointer columns (i.e. part of the data stored in Zarr).""" + + dataset = create_dataset_from_file(zarr_archive, tmpdir.join("data")) + root = zarr.open(zarr_archive) + + # Get a specific cell + assert np.array_equal(dataset[0, "A"], root["A"][0, :]) + + # Get a specific row + def _check_row_equality(d1, d2): + assert len(d1) == len(d2) + for k in d1: + assert np.array_equal(d1[k], d2[k]) + + _check_row_equality(dataset[0], {"A": root["A"][0, :], "B": root["B"][0, :]}) + _check_row_equality(dataset[10], {"A": root["A"][10, :], "B": root["B"][10, :]}) diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py new file mode 100644 index 00000000..d0141152 --- /dev/null +++ b/tests/test_dataset_v2.py @@ -0,0 +1,271 @@ +import os +from copy import deepcopy +from time import perf_counter + +import numcodecs +import numpy as np +import pandas as pd +import pytest +import zarr +from pydantic import ValidationError + +from polaris.dataset import Subset +from polaris.dataset._factory import DatasetFactory +from polaris.dataset.converters._pdb import PDBConverter +from polaris.dataset.zarr._manifest import generate_zarr_manifest +from polaris.experimental._dataset_v2 import _INDEX_ARRAY_KEY, DatasetV2 + + +def test_dataset_v2_get_columns(test_dataset_v2): + assert set(test_dataset_v2.columns) == {"A", "B"} + + +def test_dataset_v2_get_rows(test_dataset_v2): + assert set(test_dataset_v2.rows) == set(range(100)) + + +def test_dataset_v2_get_data(test_dataset_v2, zarr_archive): + root = zarr.open(zarr_archive, "r") + indices = np.random.randint(0, len(test_dataset_v2), 5) + for idx in indices: + assert np.array_equal(test_dataset_v2.get_data(row=idx, col="A"), root["A"][idx]) + assert np.array_equal(test_dataset_v2.get_data(row=idx, col="B"), root["B"][idx]) + + +def test_dataset_v2_with_subset(test_dataset_v2, zarr_archive): + root = zarr.open(zarr_archive, "r") + indices = np.random.randint(0, len(test_dataset_v2), 5) + subset = Subset(test_dataset_v2, indices, "A", "B") + for i, (x, y) in enumerate(subset): + idx = indices[i] + assert np.array_equal(x, root["A"][idx]) + assert np.array_equal(y, root["B"][idx]) + + +def test_dataset_v2_load_to_memory(test_dataset_v2): + subset = Subset( + dataset=test_dataset_v2, + indices=range(100), + input_cols=["A"], + target_cols=["B"], + ) + + t1 = perf_counter() + for x in subset: + pass + d1 = perf_counter() - t1 + + test_dataset_v2.load_to_memory() + + t2 = perf_counter() + for x in subset: + pass + d2 = perf_counter() - t2 + + assert d2 < d1 + + +def test_dataset_v2_serialization(test_dataset_v2, tmpdir): + save_dir = tmpdir.join("save_dir") + path = test_dataset_v2.to_json(save_dir) + new_dataset = DatasetV2.from_json(path) + for i in range(5): + assert np.array_equal(new_dataset.get_data(i, "A"), test_dataset_v2.get_data(i, "A")) + assert np.array_equal(new_dataset.get_data(i, "B"), test_dataset_v2.get_data(i, "B")) + + +def test_dataset_v2_caching(test_dataset_v2, tmpdir): + cache_dir = tmpdir.join("cache").strpath + test_dataset_v2.cache_dir = cache_dir + test_dataset_v2.cache() + assert str(test_dataset_v2.zarr_root_path).startswith(cache_dir) + + +def 
test_dataset_v1_v2_compatibility(test_dataset, tmpdir):
+    # A DataFrame is ultimately a collection of labeled numpy arrays.
+    # We can thus also save these same arrays to a Zarr archive.
+    df = test_dataset.table
+
+    path = tmpdir.join("data/v1v2.zarr")
+
+    root = zarr.open(path, "w")
+    root.array("smiles", data=df["smiles"].values, dtype=object, object_codec=numcodecs.VLenUTF8())
+    root.array("iupac", data=df["iupac"].values, dtype=object, object_codec=numcodecs.VLenUTF8())
+    for col in set(df.columns) - {"smiles", "iupac"}:
+        root.array(col, data=df[col].values)
+    zarr.consolidate_metadata(path)
+
+    kwargs = test_dataset.model_dump(exclude=["table", "zarr_root_path"])
+    dataset = DatasetV2(**kwargs, zarr_root_path=str(path))
+
+    subset_1 = Subset(dataset=test_dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])
+    subset_2 = Subset(dataset=dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])
+
+    for idx in range(5):
+        x1, y1 = subset_1[idx]
+        x2, y2 = subset_2[idx]
+        assert x1 == x2
+        assert y1 == y2
+
+
+def test_dataset_v2_with_pdbs(pdb_paths, tmpdir):
+    # The PDB example is interesting because it creates a more complex Zarr archive
+    # that includes subgroups
+    zarr_root_path = str(tmpdir.join("pdbs.zarr"))
+    factory = DatasetFactory(zarr_root_path)
+
+    # Build a V1 dataset
+    converter = PDBConverter()
+    factory.register_converter("pdb", converter)
+    factory.add_from_files(pdb_paths, axis=0)
+    dataset_v1 = factory.build()
+
+    # Build a V2 dataset based on the V1 dataset
+
+    # Add the magic index array to the Zarr subgroup
+    root = zarr.open(zarr_root_path, "a")
+    ordered_keys = [v.split("/")[-1] for v in dataset_v1.table["pdb"].values]
+    root["pdb"].array(_INDEX_ARRAY_KEY, data=ordered_keys, dtype=object, object_codec=numcodecs.VLenUTF8())
+    zarr.consolidate_metadata(zarr_root_path)
+
+    # Update annotations to no longer have pointer columns
+    annotations = deepcopy(dataset_v1.annotations)
+    for anno in annotations.values():
+        anno.is_pointer = False
+
+    # Create the dataset
+    dataset_v2 = DatasetV2(
+        zarr_root_path=zarr_root_path,
+        annotations=annotations,
+        default_adapters=dataset_v1.default_adapters,
+    )
+
+    assert len(dataset_v1) == len(dataset_v2)
+    for idx in range(len(dataset_v1)):
+        pdb_1 = dataset_v1.get_data(idx, "pdb")
+        pdb_2 = dataset_v2.get_data(idx, "pdb")
+        assert pdb_1 == pdb_2
+
+
+def test_dataset_v2_indexing(zarr_archive):
+    # Create a subgroup with 100 arrays
+    root = zarr.open(zarr_archive, "a")
+    subgroup = root.create_group("X")
+    for i in range(100):
+        subgroup.array(f"{i}", data=np.array([i] * 100))
+
+    # Index it in reverse (element 0 is the last element in the array)
+    indices = [f"{idx}" for idx in range(100)][::-1]
+    subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8())
+    zarr.consolidate_metadata(zarr_archive)
+
+    # Create the dataset
+    dataset = DatasetV2(zarr_root_path=zarr_archive)
+
+    # Check that the dataset is indexed correctly
+    for idx in range(100):
+        assert np.array_equal(dataset.get_data(row=idx, col="X"), np.array([99 - idx] * 100))
+
+
+def test_dataset_v2_validation_index_array(zarr_archive):
+    root = zarr.open(zarr_archive, "a")
+
+    # Create a subgroup that lacks the index array
+    subgroup = root.create_group("X")
+    zarr.consolidate_metadata(zarr_archive)
+
+    with pytest.raises(ValidationError, match="does not have an index array"):
+        DatasetV2(zarr_root_path=zarr_archive)
+
+    indices = [f"{idx}" for idx in range(100)]
+    indices[-1] = "invalid"
+
+    # Create an index array whose length doesn't match the group (zero member arrays in the group, but 100 indices)
+    subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8())
+    zarr.consolidate_metadata(zarr_archive)
+
+    with pytest.raises(ValidationError, match="Length of index array"):
+        DatasetV2(zarr_root_path=zarr_archive)
+
+    for i in range(100):
+        subgroup.array(f"{i}", data=np.random.random(100))
+    zarr.consolidate_metadata(zarr_archive)
+
+    # The index array now has an invalid key (its last key is 'invalid' rather than '99')
+    with pytest.raises(ValidationError, match="Keys of index array"):
+        DatasetV2(zarr_root_path=zarr_archive)
+
+
+def test_dataset_v2_validation_consistent_lengths(zarr_archive):
+    root = zarr.open(zarr_archive, "a")
+
+    # Change the length of one of the root arrays
+    root["A"].append(np.random.random((1, 2048)))
+    zarr.consolidate_metadata(zarr_archive)
+
+    # The root arrays now have different lengths
+    with pytest.raises(ValidationError, match="should have the same length"):
+        DatasetV2(zarr_root_path=zarr_archive)
+
+    # Make the lengths of the two arrays equal again; this shouldn't crash
+    root["B"].append(np.random.random((1, 2048)))
+    zarr.consolidate_metadata(zarr_archive)
+    DatasetV2(zarr_root_path=zarr_archive)
+
+    # Create a subgroup with an inconsistent length
+    subgroup = root.create_group("X")
+    for i in range(100):
+        subgroup.array(f"{i}", data=np.random.random(100))
+    indices = [f"{idx}" for idx in range(100)]
+    subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8())
+    zarr.consolidate_metadata(zarr_archive)
+
+    # The subgroup's length does not match the other root entries
+    with pytest.raises(ValidationError, match="should have the same length"):
+        DatasetV2(zarr_root_path=zarr_archive)
+
+
+def test_zarr_manifest(test_dataset_v2):
+    # Assert the manifest Parquet file is created
+    assert test_dataset_v2.zarr_manifest_path is not None
+    assert os.path.isfile(test_dataset_v2.zarr_manifest_path)
+
+    # Assert the manifest contains 204 rows (the number 204 is chosen because
+    # the Zarr archive defined in `conftest.py` contains 204 unique files)
+    df = pd.read_parquet(test_dataset_v2.zarr_manifest_path)
+    assert len(df) == 204
+
+    # Assert the manifest hash is calculated
+    assert test_dataset_v2.zarr_manifest_md5sum is not None
+
+    # Add an array to the Zarr archive to change the number of chunks in the dataset
+    root = zarr.open(test_dataset_v2.zarr_root_path, "a")
+    root.array("C", data=np.random.random((100, 2048)), chunks=(1, None))
+
+    generate_zarr_manifest(test_dataset_v2.zarr_root_path, test_dataset_v2.cache_dir)
+
+    # Get the length of the updated manifest file
+    post_change_manifest_length = len(pd.read_parquet(test_dataset_v2.zarr_manifest_path))
+
+    # Ensure the Zarr manifest has an additional 100 chunks + 1 array metadata file
+    assert post_change_manifest_length == 305
+
+
+def test_dataset_v2__get_item__(test_dataset_v2, zarr_archive):
+    """Test the __getitem__() interface for the dataset V2."""
+
+    # Ground truth
+    root = zarr.open(zarr_archive)
+
+    # Get a specific cell
+    assert np.array_equal(test_dataset_v2[0, "A"], root["A"][0, :])
+
+    # Get a specific row
+    def _check_row_equality(d1, d2):
+        assert len(d1) == len(d2)
+        for k in d1:
+            assert np.array_equal(d1[k], d2[k])
+
+    _check_row_equality(test_dataset_v2[0], {"A": root["A"][0, :], "B": root["B"][0, :]})
+    _check_row_equality(test_dataset_v2[10], {"A": root["A"][10, :], "B": root["B"][10, :]})
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 6bb8be46..f1e76eee
100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -1,17 +1,18 @@ import os -import pytest + import numpy as np import pandas as pd +import pytest import polaris as po from polaris.benchmark import ( MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification, ) +from polaris.dataset import DatasetV1 from polaris.evaluate._metric import Metric from polaris.evaluate._results import BenchmarkResults from polaris.utils.types import HubOwner -from polaris.dataset import Dataset def test_result_to_json(tmpdir: str, test_user_owner: HubOwner): @@ -44,7 +45,7 @@ def test_result_to_json(tmpdir: str, test_user_owner: HubOwner): assert po.__version__ == result.polaris_version -def test_metrics_singletask_reg(tmpdir: str, test_single_task_benchmark: SingleTaskBenchmarkSpecification): +def test_metrics_singletask_reg(test_single_task_benchmark: SingleTaskBenchmarkSpecification): _, test = test_single_task_benchmark.get_train_test_split() predictions = np.random.random(size=test.inputs.shape[0]) result = test_single_task_benchmark.evaluate(predictions) @@ -59,7 +60,7 @@ def test_metrics_singletask_reg(tmpdir: str, test_single_task_benchmark: SingleT assert metric in result.results.Metric.tolist() -def test_metrics_multitask_reg(tmpdir: str, test_multi_task_benchmark: MultiTaskBenchmarkSpecification): +def test_metrics_multitask_reg(test_multi_task_benchmark: MultiTaskBenchmarkSpecification): train, test = test_multi_task_benchmark.get_train_test_split() predictions = { target_col: np.random.random(size=test.inputs.shape[0]) for target_col in train.target_cols @@ -69,9 +70,7 @@ def test_metrics_multitask_reg(tmpdir: str, test_multi_task_benchmark: MultiTask assert metric in result.results.Metric.tolist() -def test_metrics_singletask_clf( - tmpdir: str, test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification -): +def test_metrics_singletask_clf(test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification): _, test = test_single_task_benchmark_clf.get_train_test_split() predictions = np.random.randint(2, size=test.inputs.shape[0]) probabilities = np.random.uniform(size=test.inputs.shape[0]) @@ -81,7 +80,7 @@ def test_metrics_singletask_clf( def test_metrics_singletask_multicls_clf( - tmpdir: str, test_single_task_benchmark_multi_clf: SingleTaskBenchmarkSpecification + test_single_task_benchmark_multi_clf: SingleTaskBenchmarkSpecification, ): _, test = test_single_task_benchmark_multi_clf.get_train_test_split() predictions = np.random.randint(3, size=test.inputs.shape[0]) @@ -92,7 +91,7 @@ def test_metrics_singletask_multicls_clf( assert metric in result.results.Metric.tolist() -def test_metrics_multitask_clf(tmpdir: str, test_multi_task_benchmark_clf: MultiTaskBenchmarkSpecification): +def test_metrics_multitask_clf(test_multi_task_benchmark_clf: MultiTaskBenchmarkSpecification): train, test = test_multi_task_benchmark_clf.get_train_test_split() predictions = { target_col: np.random.randint(2, size=test.inputs.shape[0]) for target_col in train.target_cols @@ -150,7 +149,7 @@ def test_absolute_average_fold_error(): def test_metric_y_types( - tmpdir: str, test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification, test_data: Dataset + test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification, test_data: DatasetV1 ): # here we use train split for testing purpose. _, test = test_single_task_benchmark_clf.get_train_test_split()