diff --git a/LICENSE b/LICENSE index f048b6a9..7dd1e2be 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2021 Valence + Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/NOTICE b/NOTICE new file mode 100644 index 00000000..564a6a7c --- /dev/null +++ b/NOTICE @@ -0,0 +1,13 @@ +Copyright 2023 Valence Labs + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/docs/api/dataset.md b/docs/api/dataset.md index ae107961..ec1087e6 100644 --- a/docs/api/dataset.md +++ b/docs/api/dataset.md @@ -8,4 +8,10 @@ options: filters: ["!^_"] +--- + +::: polaris.dataset.zarr + options: + filters: ["!^_"] + --- \ No newline at end of file diff --git a/polaris/__init__.py b/polaris/__init__.py index a9266859..ddb0f44a 100644 --- a/polaris/__init__.py +++ b/polaris/__init__.py @@ -1,4 +1,14 @@ +import os +import sys + +from loguru import logger + from ._version import __version__ from .loader import load_benchmark, load_dataset __all__ = ["load_dataset", "load_benchmark", "__version__"] + +# Configure the default logging level +os.environ["LOGURU_LEVEL"] = os.environ.get("LOGURU_LEVEL", "INFO") +logger.remove() +logger.add(sys.stderr, level=os.environ["LOGURU_LEVEL"]) diff --git a/polaris/_mixins.py b/polaris/_mixins.py new file mode 100644 index 00000000..8fccac35 --- /dev/null +++ b/polaris/_mixins.py @@ -0,0 +1,71 @@ +import abc +import re + +from loguru import logger +from pydantic import BaseModel, PrivateAttr, computed_field + +from polaris.utils.errors import PolarisChecksumError + + +class ChecksumMixin(BaseModel, abc.ABC): + """ + Mixin class to add checksum functionality to a class. + """ + + _md5sum: str | None = PrivateAttr(None) + + @abc.abstractmethod + def _compute_checksum(self) -> str: + """Compute the checksum of the dataset.""" + raise NotImplementedError + + @computed_field + @property + def md5sum(self) -> str: + """Lazily compute the checksum once needed.""" + if not self.has_md5sum: + logger.info("Computing the checksum. This can be slow for large datasets.") + self.md5sum = self._compute_checksum() + return self._md5sum + + @md5sum.setter + def md5sum(self, value: str): + """Set the checksum.""" + if not re.fullmatch(r"^[a-f0-9]{32}$", value): + raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.") + self._md5sum = value + + @property + def has_md5sum(self) -> bool: + """Whether the md5sum for this class has been computed and stored.""" + return self._md5sum is not None + + def verify_checksum(self, md5sum: str | None = None): + """ + Recomputes the checksum and verifies whether it matches the stored checksum. + + Warning: Slow operation + This operation can be slow for large datasets. 
+ + Info: Only works for locally stored datasets + The checksum verification only works for datasets that are stored locally in its entirety. + We don't have to verify the checksum for datasets stored on the Hub, as the Hub will do this on upload. + And if you're streaming the data from the Hub, we will check the checksum of each chunk on download. + """ + if md5sum is None: + md5sum = self._md5sum + if md5sum is None: + logger.warning( + "No checksum to verify against. Specify either the md5sum parameter or " + "store the checksum in the dataset.md5sum attribute." + ) + return + + # Recompute the checksum + logger.info("To verify the checksum, we need to recompute it. This can be slow for large datasets.") + self.md5sum = self._compute_checksum() + + if self.md5sum != md5sum: + raise PolarisChecksumError( + f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" + ) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 90f65284..7af669fb 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -17,12 +17,13 @@ from sklearn.utils.multiclass import type_of_target from polaris._artifact import BaseArtifactModel +from polaris._mixins import ChecksumMixin from polaris.dataset import Dataset, Subset from polaris.evaluate import BenchmarkResults, Metric, ResultsType from polaris.hub.settings import PolarisHubSettings from polaris.utils.context import tmp_attribute_change from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError +from polaris.utils.errors import InvalidBenchmarkError from polaris.utils.misc import listit from polaris.utils.types import ( AccessType, @@ -36,7 +37,7 @@ ColumnsType = Union[str, list[str]] -class BenchmarkSpecification(BaseArtifactModel): +class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): """This class wraps a [`Dataset`][polaris.dataset.Dataset] with additional data to specify the evaluation logic. @@ -85,8 +86,6 @@ class BenchmarkSpecification(BaseArtifactModel): split: The predefined train-test split to use for evaluation. metrics: The metrics to use for evaluating performance main_metric: The main metric used to rank methods. If `None`, the first of the `metrics` field. - md5sum: The checksum is used to verify the version of the dataset specification. If specified, it will - raise an error if the specified checksum doesn't match the computed checksum. readme: Markdown text that can be used to provide a formatted description of the benchmark. If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI as it provides a rich text editor for writing markdown. @@ -102,7 +101,6 @@ class BenchmarkSpecification(BaseArtifactModel): split: SplitType metrics: Union[str, Metric, list[Union[str, Metric]]] main_metric: Optional[Union[str, Metric]] = None - md5sum: Optional[str] = None # Additional meta-data readme: str = "" @@ -214,6 +212,12 @@ def _validate_target_types(cls, v, info: ValidationInfo): for target in target_cols: if target not in v: val = dataset[:, target] + + # Non numeric columns can be targets (e.g. prediction molecular reactions), + # but in that case we currently don't infer the target type. 
+ if not np.issubdtype(val.dtype, np.number): + continue + # remove the nans for mutiple task dataset when the table is sparse target_type = type_of_target(val[~np.isnan(val)]) if target_type == "continuous": @@ -230,34 +234,11 @@ def _validate_target_types(cls, v, info: ValidationInfo): @classmethod def _validate_model(cls, m: "BenchmarkSpecification"): """ - If a checksum is provided, verify it matches what the checksum should be. - If no checksum is provided, make sure it is set. - Also sets a default metric if missing. + Sets a default metric if missing. """ - - # Validate checksum - checksum = m.md5sum - - expected = cls._compute_checksum( - dataset=m.dataset, - target_cols=m.target_cols, - input_cols=m.input_cols, - split=m.split, - metrics=m.metrics, - ) - - if checksum is None: - m.md5sum = expected - elif checksum != expected: - raise PolarisChecksumError( - "The dataset checksum does not match what was specified in the meta-data. " - f"{checksum} != {expected}" - ) - # Set a default main metric if not set yet if m.main_metric is None: m.main_metric = m.metrics[0] - return m @field_serializer("metrics", "main_metric") @@ -277,8 +258,7 @@ def _serialize_target_types(self, v): """Convert from enum to string to make sure it's serializable""" return {k: v.value for k, v in self.target_types.items()} - @staticmethod - def _compute_checksum(dataset, target_cols, input_cols, split, metrics): + def _compute_checksum(self): """ Computes a hash of the benchmark. @@ -286,16 +266,16 @@ def _compute_checksum(dataset, target_cols, input_cols, split, metrics): """ hash_fn = md5() - hash_fn.update(dataset.md5sum.encode("utf-8")) - for c in sorted(target_cols): + hash_fn.update(self.dataset.md5sum.encode("utf-8")) + for c in sorted(self.target_cols): hash_fn.update(c.encode("utf-8")) - for c in sorted(input_cols): + for c in sorted(self.input_cols): hash_fn.update(c.encode("utf-8")) - for m in sorted(metrics, key=lambda k: k.name): + for m in sorted(self.metrics, key=lambda k: k.name): hash_fn.update(m.name.encode("utf-8")) - if not isinstance(split[1], dict): - split = split[0], {"test": split[1]} + if not isinstance(self.split[1], dict): + split = self.split[0], {"test": self.split[1]} # Train set s = json.dumps(sorted(split[0])) diff --git a/polaris/dataset/_adapters.py b/polaris/dataset/_adapters.py index b97b55af..89fcee14 100644 --- a/polaris/dataset/_adapters.py +++ b/polaris/dataset/_adapters.py @@ -1,4 +1,5 @@ from enum import Enum, auto, unique + import datamol as dm # Map of conversion operations which can be applied to dataset columns diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 596db4ec..745485b3 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,12 +1,14 @@ import json +import uuid from hashlib import md5 +from pathlib import Path from typing import Dict, List, MutableMapping, Optional, Tuple, Union import fsspec import numpy as np import pandas as pd import zarr -from datamol.utils import fs +from datamol.utils import fs as dmfs from loguru import logger from pydantic import ( Field, @@ -18,14 +20,21 @@ ) from polaris._artifact import BaseArtifactModel +from polaris._mixins import ChecksumMixin from polaris.dataset._adapters import Adapter from polaris.dataset._column import ColumnAnnotation -from polaris.dataset.zarr import MemoryMappedDirectoryStore +from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum, compute_zarr_checksum from polaris.hub.polarisfs import PolarisFileSystem from 
polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError -from polaris.utils.types import AccessType, HttpUrlString, HubOwner, SupportedLicenseType +from polaris.utils.errors import InvalidDatasetError +from polaris.utils.types import ( + AccessType, + HttpUrlString, + HubOwner, + SupportedLicenseType, + ZarrConflictResolution, +) # Constants _SUPPORTED_TABLE_EXTENSIONS = ["parquet"] @@ -33,7 +42,7 @@ _INDEX_SEP = "#" -class Dataset(BaseArtifactModel): +class Dataset(BaseArtifactModel, ChecksumMixin): """Basic data-model for a Polaris dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. At its core, a dataset in Polaris is a tabular data structure that stores data-points in a row-wise manner. @@ -51,8 +60,6 @@ class Dataset(BaseArtifactModel): default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data for specific columns. zarr_root_path: The data for any pointer column should be saved in the Zarr archive this path points to. - md5sum: The checksum is used to verify the version of the dataset specification. If specified, it will - raise an error if the specified checksum doesn't match the computed checksum. readme: Markdown text that can be used to provide a formatted description of the dataset. If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI as it provides a rich text editor for writing markdown. @@ -65,7 +72,6 @@ class Dataset(BaseArtifactModel): Raises: InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. - PolarisChecksumError: If the specified checksum does not match the computed checksum. """ # Public attributes @@ -73,7 +79,6 @@ class Dataset(BaseArtifactModel): table: Union[pd.DataFrame, str] default_adapters: Dict[str, Adapter] = Field(default_factory=dict) zarr_root_path: Optional[str] = None - md5sum: Optional[str] = None # Additional meta-data readme: str = "" @@ -83,12 +88,15 @@ class Dataset(BaseArtifactModel): curation_reference: Optional[HttpUrlString] = None # Config - cache_dir: Optional[str] = None # Where to cache the data to if cache() is called. + cache_dir: Optional[Path] = None # Where to cache the data to if cache() is called. # Private attributes _zarr_root: Optional[zarr.Group] = PrivateAttr(None) _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) + _md5sum: Optional[str] = PrivateAttr(None) + _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) _client = PrivateAttr(None) # Optional[PolarisHubClient] + _warn_about_remote_zarr: bool = PrivateAttr(True) @field_validator("table") def _validate_table(cls, v): @@ -99,7 +107,7 @@ def _validate_table(cls, v): """ # Load from path if not a dataframe if not isinstance(v, pd.DataFrame): - if not fs.is_file(v) or fs.get_extension(v) not in _SUPPORTED_TABLE_EXTENSIONS: + if not dmfs.is_file(v) or dmfs.get_extension(v) not in _SUPPORTED_TABLE_EXTENSIONS: raise InvalidDatasetError(f"{v} is not a valid DataFrame or .parquet path.") v = pd.read_parquet(v) # Check if there are any duplicate columns @@ -113,10 +121,7 @@ def _validate_table(cls, v): @model_validator(mode="after") @classmethod def _validate_model(cls, m: "Dataset"): - """If a checksum is provided, verify it matches what the checksum should be. - If no checksum is provided, make sure it is set. 
- If no cache_dir is provided, set it to the default cache dir and make sure it exists - """ + """Verifies some dependencies between properties""" # Verify that all annotations are for columns that exist if any(k not in m.table.columns for k in m.annotations): @@ -140,24 +145,12 @@ def _validate_model(cls, m: "Dataset"): m.annotations[c] = ColumnAnnotation() m.annotations[c].dtype = m.table[c].dtype - # Verify the checksum - # NOTE (cwognum): Is it still reasonable to always verify this as the dataset size grows? - actual = m.md5sum - expected = cls._compute_checksum(m.table) - - if actual is None: - m.md5sum = expected - elif actual != expected: - raise PolarisChecksumError( - "The dataset md5sum does not match what was specified in the meta-data. " - f"{actual} != {expected}" - ) - # Set the default cache dir if none and make sure it exists if m.cache_dir is None: - m.cache_dir = fs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, m.name, m.md5sum) - fs.mkdir(m.cache_dir, exist_ok=True) + dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) + m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + m.cache_dir.mkdir(parents=True, exist_ok=True) return m @field_validator("default_adapters", mode="before") @@ -170,31 +163,46 @@ def _serialize_adapters(self, value: List[Adapter]): """Serializes the adapters""" return {k: v.name for k, v in value.items()} - @staticmethod - def _compute_checksum(table): + def _compute_checksum(self): """Computes a hash of the dataset. This is meant to uniquely identify the dataset and can be used to verify the version. 1. Is not sensitive to the ordering of the columns or rows in the table. 2. Purposefully does not include the meta-data (source, description, name, annotations). - 3. For any pointer column, it uses a hash of the path instead of the file contents. - This is a limitation, but probably a reasonable assumption that helps practicality. - A big downside is that as the dataset is saved elsewhere, the hash changes. + 3. Includes a hash for the Zarr archive. """ hash_fn = md5() # Sort the columns s.t. the checksum is not sensitive to the column-ordering - df = table.copy(deep=True) + df = self.table.copy(deep=True) df = df[sorted(df.columns.tolist())] # Use the sum of the row-wise hashes s.t. the hash is insensitive to the row-ordering table_hash = pd.util.hash_pandas_object(df, index=False).sum() hash_fn.update(table_hash) + # If the Zarr archive exists, we hash its contents too. + if self.uses_zarr: + zarr_hash, self._zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) + hash_fn.update(zarr_hash.encode()) + checksum = hash_fn.hexdigest() return checksum + @computed_field + @property + def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: + """ + The Zarr Checksum manifest stores the checksums of all files in a Zarr archive. + If the dataset doesn't use Zarr, this will simply return an empty list. 
+ """ + if len(self._zarr_md5sum_manifest) == 0 and not self.has_md5sum: + # The manifest is set as an instance variable + # as a side-effect of the compute_checksum method + self.md5sum = self._compute_checksum() + return self._zarr_md5sum_manifest + @property def client(self): """The Polaris Hub client used to interact with the Polaris Hub.""" @@ -206,6 +214,11 @@ def client(self): self._client = PolarisHubClient() return self._client + @property + def uses_zarr(self) -> bool: + """Whether any of the data in this dataset is stored in a Zarr Archive.""" + return self.zarr_root_path is not None + @property def zarr_data(self): """Get the Zarr data. @@ -241,14 +254,17 @@ def zarr_root(self): # We open the archive in read-only mode if it is saved on the Hub saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) - saved_remote = saved_on_hub or not fs.is_local_path(self.zarr_root_path) - if saved_remote: - logger.warning( - f"You're loading data from a remote location. " - f"To speed up this process, consider caching the dataset first " - f"using {self.__class__.__name__}.cache()" - ) + if self._warn_about_remote_zarr: + saved_remote = saved_on_hub or not Path(self.zarr_root_path).exists() + + if saved_remote: + logger.warning( + f"You're loading data from a remote location. " + f"To speed up this process, consider caching the dataset first " + f"using {self.__class__.__name__}.cache()" + ) + self._warn_about_remote_zarr = False try: if saved_on_hub: @@ -369,7 +385,11 @@ def from_json(cls, path: str): data.pop("cache_dir", None) return cls.model_validate(data) - def to_json(self, destination: str) -> str: + def to_json( + self, + destination: str, + if_exists: ZarrConflictResolution = "replace", + ) -> str: """ Save the dataset to a destination directory as a JSON file. @@ -384,31 +404,35 @@ def to_json(self, destination: str) -> str: Args: destination: The _directory_ to save the associated data to. + if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw + an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. Returns: The path to the JSON file. """ - fs.mkdir(destination, exist_ok=True) - table_path = fs.join(destination, "table.parquet") - dataset_path = fs.join(destination, "dataset.json") - zarr_archive = fs.join(destination, "data.zarr") + dmfs.mkdir(destination, exist_ok=True) + table_path = dmfs.join(destination, "table.parquet") + dataset_path = dmfs.join(destination, "dataset.json") + new_zarr_root_path = dmfs.join(destination, "data.zarr") # Lu: Avoid serilizing and sending None to hub app. serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) serialized["table"] = table_path # Copy over Zarr data to the destination - if self.zarr_root is not None: - dest = zarr.open(zarr_archive, "w") - zarr.copy_all(source=self.zarr_root, dest=dest) + if self.uses_zarr: + self._warn_about_remote_zarr = False + + logger.info(f"Copying Zarr archive to {new_zarr_root_path}. 
This may take a while.") - # Copy the .zmetadata file - # To track discussions on whether this should be done by copy_all() - # see https://github.com/zarr-developers/zarr-python/issues/1731 - zmetadata_content = self.zarr_root.store.store[".zmetadata"] - dest.store[".zmetadata"] = zmetadata_content + dest = zarr.open(new_zarr_root_path, "w") - serialized["zarr_root_path"] = zarr_archive + zarr.copy_store( + source=self.zarr_root.store.store, + dest=dest.store, + log=logger.debug, + if_exists=if_exists, + ) self.table.to_parquet(table_path) with fsspec.open(dataset_path, "w") as f: @@ -416,12 +440,13 @@ def to_json(self, destination: str) -> str: return dataset_path - def cache(self, cache_dir: Optional[str] = None) -> str: + def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) -> str: """Caches the dataset by downloading all additional data for pointer columns to a local directory. Args: cache_dir: The directory to cache the data to. If not provided, this will fall back to the `Dataset.cache_dir` attribute + verify_checksum: Whether to verify the checksum of the dataset after caching. Returns: The path to the cache directory. @@ -432,10 +457,13 @@ def cache(self, cache_dir: Optional[str] = None) -> str: self.to_json(self.cache_dir) - if self.zarr_root_path is not None: - self.zarr_root_path = fs.join(self.cache_dir, "data.zarr") + if self.uses_zarr: + self.zarr_root_path = dmfs.join(self.cache_dir, "data.zarr") self._zarr_root = None + if verify_checksum: + self.verify_checksum() + return self.cache_dir def size(self): diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 0b0e02db..448abbff 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -76,6 +76,10 @@ def __init__( self._adapters = adapters self._featurization_fn = featurization_fn + + # NOTE (cwognum): Note to future self. As we're starting to think about competition-style benchmarks, + # we will likely split up datasets. In that case, this default iloc_to_loc mapping won't work. + # By that time, we should probably be able to overwrite this mapping. 
self._iloc_to_loc = self.dataset.table.index # For the iterator implementation diff --git a/polaris/dataset/converters/_zarr.py b/polaris/dataset/converters/_zarr.py index 5ed706d0..4380325c 100644 --- a/polaris/dataset/converters/_zarr.py +++ b/polaris/dataset/converters/_zarr.py @@ -35,7 +35,7 @@ def convert(self, path: str, factory: "DatasetFactory") -> FactoryProduct: raise ValueError("The root of the zarr hierarchy should only contain arrays.") # Copy to the source zarr, so everything is in one place - zarr.copy_all(source=src, dest=factory.zarr_root) + zarr.copy_store(source=src.store, dest=factory.zarr_root.store, if_exists="skip") # Construct the table # Parse any group into a column diff --git a/polaris/dataset/zarr/__init__.py b/polaris/dataset/zarr/__init__.py index cb984e02..57f500ed 100644 --- a/polaris/dataset/zarr/__init__.py +++ b/polaris/dataset/zarr/__init__.py @@ -1,3 +1,4 @@ +from ._checksum import ZarrFileChecksum, compute_zarr_checksum from ._memmap import MemoryMappedDirectoryStore -__all__ = ["MemoryMappedDirectoryStore"] +__all__ = ["MemoryMappedDirectoryStore", "compute_zarr_checksum", "ZarrFileChecksum"] diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py new file mode 100644 index 00000000..a06f9491 --- /dev/null +++ b/polaris/dataset/zarr/_checksum.py @@ -0,0 +1,410 @@ +""" +The code in this file is based on the zarr-checksum package + +Mainted by Jacob Nesbitt, released under the DANDI org on Github +and with Kitware, Inc. credited as the author. This code is released +with the Apache 2.0 license. + +See also: https://github.com/dandi/zarr_checksum + +Instead of adding the package as a dependency, we opted to copy over the code +because it is a small and self-contained module that we will want to alter to +support our Polaris code base. + +NOTE: We have made some modifications to the original code. + +---- + +Copyright 2023 Kitware, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import hashlib +import heapq +import os +import re +from dataclasses import asdict, dataclass, field +from functools import total_ordering +from json import dumps +from pathlib import Path +from typing import List, Tuple + +import fsspec +import zarr +import zarr.errors +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel +from tqdm import tqdm + +from polaris.utils.errors import InvalidZarrChecksum + +ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)" + + +def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, List["ZarrFileChecksum"]]: + r""" + Implements an algorithm to compute the Zarr checksum. + + Warning: This checksum is sensitive to Zarr configuration. + This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size, + the checksum will also change. + + To understand how this works, consider the following directory structure: + + . (root) + / \ + a c + / + b + + Within zarr, this would for example be: + + - `root`: A Zarr Group with a single Array. 
+ - `a`: A Zarr Array + - `b`: A single chunk of the Zarr Array + - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup) + + To compute the checksum, we first find all the trees in the node, in this case b and c. + We compute the hash of the content (the raw bytes) for each of these files. + + We then work our way up the tree. For any node (directory), we find all children of that node. + In an sorted order, we then serialize a list with - for each of the children - the checksum, size, and number of children. + The hash of the directory is then equal to the hash of the serialized JSON. + + The Polaris implementation is heavily based on the [`zarr-checksum` package](https://github.com/dandi/zarr_checksum). + This method is the biggest deviation of the original code. + """ + + # Get the protocol of the path + protocol = fsspec.utils.get_protocol(zarr_root_path) + + # We only support computing checksum for local datasets. + # NOTE (cwognum): We don't have to verify the checksum for datasets stored on the Hub, + # as the Hub will do this on upload. And if you're streaming the data from the Hub, + # we will check the checksum of each chunk on download. + if protocol != "file": + raise RuntimeError( + "You can only compute the checksum for a local Zarr archive. " + "You can cache a dataset to your local machine with `dataset.cache()`." + ) + + # Normalize the path + zarr_root_path = os.path.expandvars(zarr_root_path) + zarr_root_path = os.path.expanduser(zarr_root_path) + zarr_root_path = os.path.abspath(zarr_root_path) + + fs, zarr_root_path = fsspec.url_to_fs(zarr_root_path) + + # Make sure the path exists and is a Zarr archive + zarr.open_group(zarr_root_path, mode="r") + + # Generate the checksum + tree = _ZarrChecksumTree() + + # Find all files below the root + leaves = fs.find(zarr_root_path, detail=True) + zarr_md5sum_manifest = [] + + for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"): + path = file["name"] + + relpath = path.removeprefix(zarr_root_path) + relpath = relpath.lstrip("/") + relpath = Path(relpath) + + size = file["size"] + + # Compute md5sum of file + md5sum = hashlib.md5() + with fs.open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + md5sum.update(chunk) + digest = md5sum.hexdigest() + + # Add a leaf to the tree + # (This actually adds the file's checksum to the parent directory's manifest) + tree.add_leaf( + path=relpath, + size=size, + digest=digest, + ) + + # We persist the checksums for leaf nodes separately, + # because this is what the Hub needs to verify data integrity. + zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size)) + + # Compute digest + return tree.process().digest, zarr_md5sum_manifest + + +class ZarrFileChecksum(BaseModel): + """ + This data is sent to the Hub to verify the integrity of the Zarr archive on upload. + + Attributes: + path: The path of the file relative to the Zarr root. + md5sum: The md5sum of the file. + size: The size of the file in bytes. + """ + + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True) + + path: str + md5sum: str + size: int + + +# ================================ +# Overview of the data structures +# ================================ + +# NOTE (cwognum): I kept forgetting how this works, so I'm writing it down +# - The ZarrChecksumTree is a binary tree (heap queue). It determines the order in which to process the nodes. 
+# - The ZarrChecksumNode is a node in the ZarrChecksumTree queue. It represents a directory in the Zarr archive and +# stores a manifest with all the data needed to compute the checksum for that node. +# - The ZarrChecksumManifest is a collection of checksums for all direct (non-recursive) children of a directory. +# - The ZarrChecksum is the data used to compute the checksum for a file or directory in a Zarr Archive. +# This is the object that the ZarrChecksumManifest stores a collection of. +# - A ZarrDirectoryDigest is the result of processing a directory. Once completed, +# it is added to the ZarrChecksumManifest of its parent as part of a ZarrChecksum. + +# NOTE (cwognum): As a first impression, it seems there is some redundancy in the data structures. +# My feeling is that we could reduce the redundancy to simplify things and improve maintainability. +# However, for the time being, let's stick close to the original code. + +# ================================ + + +# Pydantic models aren't used for performance reasons +class _ZarrChecksumTree: + """ + The ZarrChecksumTree is a tree structure that maintains the state of the checksum algorithm. + + Initialized with a set of leafs (i.e. files), the nodes in this tree correspond to all directories + that are above those leafs and below the Zarr Root. + + The tree then implements the logic for retrieving the next node (i.e. directory) to process, + and for computing the checksum for that node based on its children. + Once it reaches the root, it has computed the checksum for the entire Zarr archive. + """ + + def __init__(self) -> None: + # Queue to prioritize the next node to process + self._heap: list[tuple[int, _ZarrChecksumNode]] = [] + + # Map of (relative) paths to nodes. + self._path_map: dict[Path, _ZarrChecksumNode] = {} + + @property + def empty(self) -> bool: + """Check if the tree is empty.""" + # This is used as an exit condition in the process() method + return len(self._heap) == 0 + + def _add_path(self, key: Path) -> None: + """Adds a new entry to the heap queue for which we need to compute the checksum.""" + + # Create a new node + # A node represents a file or directory. + # A node refers to a node in the heap queue (i.e. binary tree) + # The structure of the heap is thus _not_ the same as the structure of the file system! + node = _ZarrChecksumNode(path=key, checksums=_ZarrChecksumManifest()) + self._path_map[key] = node + + # Add node to heap with length (negated to represent a max heap) + # We use the length of the parents (relative to the Zarr root) to structure the heap. + # The node with the longest path is the deepest node in the tree. + # This node will be prioritized for processing next. + length = len(key.parents) + heapq.heappush(self._heap, (-1 * length, node)) + + def _get_path(self, key: Path) -> "_ZarrChecksumNode": + """ + If an entry for this path already exists, return it. + Otherwise create a new one and return that. 
+ """ + if key not in self._path_map: + self._add_path(key) + return self._path_map[key] + + def add_leaf(self, path: Path, size: int, digest: str) -> None: + """Add a leaf file to the tree.""" + parent_node = self._get_path(path.parent) + parent_node.checksums.files.append(_ZarrChecksum(name=path.name, size=size, digest=digest)) + + def add_node(self, path: Path, size: int, digest: str, count: int) -> None: + """Add an internal node to the tree.""" + parent_node = self._get_path(path.parent) + parent_node.checksums.directories.append( + _ZarrChecksum( + name=path.name, + size=size, + digest=digest, + count=count, + ) + ) + + def pop_deepest(self) -> "_ZarrChecksumNode": + """ + Returns the node with the highest priority for processing next. + + Returns (one of the) node(s) with the most parent directories + (i.e. the deepest directory in the file system) + """ + _, node = heapq.heappop(self._heap) + del self._path_map[node.path] + return node + + def process(self) -> "_ZarrDirectoryDigest": + """Process the tree, returning the resulting top level digest.""" + + # Begin with empty root node, so that if no files are present, the empty checksum is returned + node = _ZarrChecksumNode(path=Path("."), checksums=_ZarrChecksumManifest()) + + while not self.empty: + # Get the next directory to process + # Priority is based on the number of parents a directory has + # In other word, the depth of the directory in the file system. + node = self.pop_deepest() + + # If we have reached the root node, then we're done. + if node.path == Path(".") or node.path == Path("/"): + break + + # Add the parent of this node to the tree + directory_digest = node.checksums.generate_digest() + self.add_node( + path=node.path, + size=directory_digest.size, + digest=directory_digest.digest, + count=directory_digest.count, + ) + + # Return digest + return node.checksums.generate_digest() + + +@dataclass +class _ZarrChecksumNode: + """ + A node in the ZarrChecksumTree. + + This node represents a file or directory in the Zarr archive, + but "node" here refers to a node in the heap queue (i.e. binary tree). + The structure of the heap is thus _not_ the same as the structure of the file system! + + The node stores a manifest of checksums for all files and directories below it. + """ + + path: Path + checksums: "_ZarrChecksumManifest" + + def __lt__(self, other: "_ZarrChecksumNode") -> bool: + return str(self.path) < str(other.path) + + +@dataclass +class _ZarrChecksumManifest: + """ + For a directory in the Zarr archive (i.e. a node in the heap queue), + we maintain a manifest of the checksums for all files and directories + below that directory. + + This data is then used to calculate the checksum of a directory. 
+ """ + + directories: list["_ZarrChecksum"] = field(default_factory=list) + files: list["_ZarrChecksum"] = field(default_factory=list) + + @property + def is_empty(self) -> bool: + return not (self.files or self.directories) + + def generate_digest(self) -> "_ZarrDirectoryDigest": + """Generate an aggregated digest for the provided files/directories.""" + + # Sort everything to ensure the checksum is deterministic + self.files.sort() + self.directories.sort() + + # Aggregate total file count + count = len(self.files) + sum(checksum.count for checksum in self.directories) + + # Aggregate total size + size = sum(file.size for file in self.files) + sum(directory.size for directory in self.directories) + + # Serialize json without any spacing + json = dumps(asdict(self), separators=(",", ":")) + + # Generate digest + md5 = hashlib.md5(json.encode("utf-8")).hexdigest() + + # Construct and return + return _ZarrDirectoryDigest(md5=md5, count=count, size=size) + + +@total_ordering +@dataclass +class _ZarrChecksum: + """ + The data used to compute the checksum for a file or directory in a Zarr Archive. + + This class is serialized to JSON, and as such, key order should not be modified. + """ + + digest: str + name: str + size: int + count: int = 0 + + # To make this class sortable + def __lt__(self, other: "_ZarrChecksum") -> bool: + return self.name < other.name + + +@dataclass +class _ZarrDirectoryDigest: + """ + The digest for a directory in a Zarr Archive. + + The digest is a string representation that serves as a checksum for the directory. + This is a utility class to (de)serialize that string. + """ + + md5: str + count: int + size: int + + @classmethod + def parse(cls, checksum: str | None) -> "_ZarrDirectoryDigest": + if checksum is None: + return cls.parse(EMPTY_CHECKSUM) + + match = re.match(ZARR_DIGEST_PATTERN, checksum) + if match is None: + raise InvalidZarrChecksum() + + md5, count, size = match.groups() + return cls(md5=md5, count=int(count), size=int(size)) + + def __str__(self) -> str: + return self.digest + + @property + def digest(self) -> str: + return f"{self.md5}-{self.count}-{self.size}" + + +# The "null" zarr checksum +EMPTY_CHECKSUM = _ZarrChecksumManifest().generate_digest().digest diff --git a/polaris/dataset/zarr/_memmap.py b/polaris/dataset/zarr/_memmap.py index 55d13d94..6a3f2da3 100644 --- a/polaris/dataset/zarr/_memmap.py +++ b/polaris/dataset/zarr/_memmap.py @@ -6,7 +6,7 @@ class MemoryMappedDirectoryStore(zarr.DirectoryStore): """ A Zarr Store to open chunks as memory-mapped files. - See https://github.com/zarr-developers/zarr-python/issues/1245 + See also [this Github issue](https://github.com/zarr-developers/zarr-python/issues/1245). Memory mapping leverages low-level OS functionality to reduce the time it takes to read the content of a file by directly mapping to memory. 
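For orientation, here is a minimal usage sketch of the checksum utilities introduced in `polaris/dataset/zarr/_checksum.py` above. This is illustrative only and not part of the patch; the archive path and array contents are made up.

```python
import numpy as np
import zarr

from polaris.dataset.zarr import compute_zarr_checksum

# Build a small, throwaway local Zarr archive (illustrative data only).
root = zarr.open("example_data.zarr", mode="w")
root.array("A", data=np.random.random((10, 16)))
zarr.consolidate_metadata(root.store)

# compute_zarr_checksum walks every file under the archive root and returns the
# aggregated digest of the form "<md5>-<file count>-<total size>", together with a
# per-file manifest of ZarrFileChecksum entries that the Hub uses to verify uploads.
digest, manifest = compute_zarr_checksum("example_data.zarr")
print(digest)
print(len(manifest))  # one entry per file in the archive
```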
diff --git a/polaris/hub/client.py b/polaris/hub/client.py index cbb7231d..62fad481 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -13,7 +13,7 @@ from authlib.integrations.httpx_client import OAuth2Client, OAuthError from authlib.oauth2 import TokenAuth from authlib.oauth2.rfc6749 import OAuth2Token -from httpx import HTTPStatusError +from httpx import HTTPStatusError, Response from httpx._types import HeaderTypes, URLTypes from loguru import logger @@ -29,9 +29,15 @@ from polaris.hub.polarisfs import PolarisFileSystem from polaris.hub.settings import PolarisHubSettings from polaris.utils.context import tmp_attribute_change -from polaris.utils.errors import InvalidDatasetError, PolarisHubError, PolarisUnauthorizedError +from polaris.utils.errors import ( + InvalidDatasetError, + PolarisHubError, + PolarisUnauthorizedError, +) +from polaris.utils.misc import should_verify_checksum from polaris.utils.types import ( AccessType, + ChecksumStrategy, HubOwner, IOMode, SupportedLicenseType, @@ -174,6 +180,11 @@ def _base_request_to_hub(self, url: str, method: str, **kwargs): return response + def get_metadata_from_response(self, response: Response, key: str) -> str | None: + """Get custom metadata saved to the R2 object from the headers.""" + key = f"{self.settings.custom_metadata_prefix}{key}" + return response.headers.get(key) + def request(self, method, url, withhold_token=False, auth=httpx.USE_CLIENT_DEFAULT, **kwargs): """Wraps the base request method to handle errors""" try: @@ -233,13 +244,19 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: dataset_list = [bm["artifactId"] for bm in response["data"]] return dataset_list - def get_dataset(self, owner: str | HubOwner, name: str, verify_checksum: bool = True) -> Dataset: + def get_dataset( + self, + owner: str | HubOwner, + name: str, + verify_checksum: ChecksumStrategy = "verify_unless_zarr", + ) -> Dataset: """Load a dataset from the Polaris Hub. Args: owner: The owner of the dataset. Can be either a user or organization from the Polaris Hub. name: The name of the dataset. - verify_checksum: Whether to use the checksum to verify the integrity of the dataset. + verify_checksum: Whether to use the checksum to verify the integrity of the dataset. If None, + will infer a practical default based on the dataset's storage location. Returns: A `Dataset` instance, if it exists. @@ -261,10 +278,13 @@ def get_dataset(self, owner: str | HubOwner, name: str, verify_checksum: bool = response["table"] = self._load_from_signed_url(url=url, headers=headers, load_fn=pd.read_parquet) - if not verify_checksum: - response.pop("md5Sum", None) + dataset = Dataset(**response) - return Dataset(**response) + if should_verify_checksum(verify_checksum, dataset): + dataset.verify_checksum() + else: + dataset.md5sum = response["md5Sum"] + return dataset def open_zarr_file( self, owner: str | HubOwner, name: str, path: str, mode: IOMode, as_consolidated: bool = True @@ -276,7 +296,8 @@ def open_zarr_file( name: Name of the dataset. path: Path to the Zarr file within the dataset. mode: The mode in which the file is opened. - as_consolidated: Whether to open the store with consolidated metadata for optimized reading. This is only applicable in 'r' and 'r+' modes. + as_consolidated: Whether to open the store with consolidated metadata for optimized reading. + This is only applicable in 'r' and 'r+' modes. Returns: The Zarr object representing the dataset. 
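To make the new `verify_checksum` strategies concrete, a hedged usage sketch follows; the owner and dataset name are placeholders, not real artifacts.

```python
from polaris.hub.client import PolarisHubClient

client = PolarisHubClient()

# Default strategy "verify_unless_zarr": skip the (potentially slow) full re-hash for
# Zarr-backed datasets, whose chunks are already checked on download, and verify the
# checksum for everything else.
dataset = client.get_dataset("some-org", "some-dataset")

# Force or skip verification explicitly.
dataset = client.get_dataset("some-org", "some-dataset", verify_checksum="verify")
dataset = client.get_dataset("some-org", "some-dataset", verify_checksum="ignore")
```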
@@ -318,14 +339,17 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: return benchmarks_list def get_benchmark( - self, owner: str | HubOwner, name: str, verify_checksum: bool = True + self, + owner: str | HubOwner, + name: str, + verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> BenchmarkSpecification: """Load a benchmark from the Polaris Hub. Args: owner: The owner of the benchmark. Can be either a user or organization from the Polaris Hub. name: The name of the benchmark. - verify_checksum: Whether to use the checksum to verify the integrity of the dataset. + verify_checksum: Whether to use the checksum to verify the integrity of the benchmark. Returns: A `BenchmarkSpecification` instance, if it exists. @@ -347,10 +371,14 @@ def get_benchmark( else MultiTaskBenchmarkSpecification ) - if not verify_checksum: - response.pop("md5Sum", None) + benchmark = benchmark_cls(**response) + + if should_verify_checksum(verify_checksum, benchmark.dataset): + benchmark.verify_checksum() + else: + benchmark.md5sum = response["md5Sum"] - return benchmark_cls(**response) + return benchmark def upload_results( self, @@ -451,19 +479,20 @@ def upload_dataset( dataset_json["zarrRootPath"] = f"{PolarisFileSystem.protocol}://data.zarr" # Uploading a dataset is a three-step process. - # 1. Upload the dataset meta data to the hub and prepare the hub to receive the parquet file + # 1. Upload the dataset meta data to the hub and prepare the hub to receive the data # 2. Upload the parquet file to the hub # 3. Upload the associated Zarr archive # TODO: Revert step 1 in case step 2 fails - Is this needed? Or should this be taken care of by the hub? - # Write the parquet file directly to a buffer + # Prepare the parquet file buffer = BytesIO() dataset.table.to_parquet(buffer, engine="auto") parquet_size = len(buffer.getbuffer()) parquet_md5 = md5(buffer.getbuffer()).hexdigest() # Step 1: Upload meta-data - # Instead of directly uploading the table, we announce to the hub that we intend to upload one. + # Instead of directly uploading the data, we announce to the hub that we intend to upload it. + # We do so separately for the Zarr archive and Parquet file. 
url = f"/dataset/{dataset.artifact_id}" response = self._base_request_to_hub( url=url, @@ -472,8 +501,9 @@ def upload_dataset( "tableContent": { "size": parquet_size, "fileType": "parquet", - "md5sum": parquet_md5, + "md5Sum": parquet_md5, }, + "zarrContent": [md5sum.model_dump() for md5sum in dataset._zarr_md5sum_manifest], "access": access, **dataset_json, }, @@ -494,6 +524,7 @@ def upload_dataset( if hub_response.status_code == 307: # If the hub returns a 307 redirect, we need to follow it to get the signed URL hub_response_body = hub_response.json() + # Upload the data to the cloudflare url bucket_response = self.request( url=hub_response_body["url"], @@ -509,7 +540,7 @@ def upload_dataset( hub_response.raise_for_status() # Step 3: Upload any associated Zarr archive - if dataset.zarr_root is not None: + if dataset.uses_zarr: with tmp_attribute_change(self.settings, "default_timeout", timeout): # Copy the Zarr archive to the hub dest = self.open_zarr_file( @@ -531,7 +562,7 @@ def upload_dataset( zarr.copy_store( source=dataset.zarr_root.store.store, dest=dest.store, - log=logger.info, + log=logger.debug, if_exists=if_exists, ) diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index ea201385..9bc8ab85 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -1,8 +1,8 @@ -import hashlib -from datetime import datetime, timezone +from hashlib import md5 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import fsspec +from loguru import logger from polaris.utils.errors import PolarisHubError from polaris.utils.types import TimeoutTypes @@ -87,10 +87,6 @@ def ls( if timeout is None: timeout = self.default_timeout - cached_listings = self._ls_from_cache(path) - if cached_listings is not None: - return cached_listings if detail else [d["name"] for d in cached_listings] - ls_path = self.sep.join([self.base_path, "ls", path]) # GET request to Polaris Hub to list objects in path @@ -145,11 +141,36 @@ def cat_file( if response.status_code != 307: raise PolarisHubError("Could not get signed URL from Polaris Hub.") - signed_url = response.json()["url"] + hub_response_body = response.json() + signed_url = hub_response_body["url"] + + headers = {"Content-Type": "application/octet-stream", **hub_response_body["headers"]} + + response = self.polaris_client.request( + url=signed_url, + method="GET", + auth=None, + headers=headers, + timeout=timeout, + ) + response.raise_for_status() + response_content = response.content + + # Verify the checksum on download + expected_md5sum = self.polaris_client.get_metadata_from_response(response, "md5sum") + if expected_md5sum is None: + raise PolarisHubError("MD5 checksum not found in response headers.") + logger.debug(f"MD5 checksum found in response headers: {expected_md5sum}.") + + md5sum = md5(response_content).hexdigest() + logger.debug(f"MD5 checksum computed for response content: {md5sum}.") - with fsspec.open(signed_url, "rb", **kwargs) as f: - data = f.read() - return data[start:end] + if md5sum != expected_md5sum: + raise PolarisHubError( + f"MD5 checksum verification failed. Expected {expected_md5sum}, got {md5sum}." + ) + + return response_content[start:end] def rm(self, path: str, recursive: bool = False, maxdepth: Optional[int] = None) -> None: """Remove a file or directory from the Polaris dataset. 
@@ -202,14 +223,7 @@ def pipe_file( hub_response_body = response.json() signed_url = hub_response_body["url"] - sha256_hash = hashlib.sha256(content).hexdigest() - - headers = { - "Content-Type": "application/octet-stream", - "x-amz-content-sha256": sha256_hash, - "x-amz-date": datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ"), - **hub_response_body["headers"], - } + headers = {"Content-Type": "application/octet-stream", **hub_response_body["headers"]} response = self.polaris_client.request( url=signed_url, diff --git a/polaris/hub/settings.py b/polaris/hub/settings.py index e1ba7752..1780a4d3 100644 --- a/polaris/hub/settings.py +++ b/polaris/hub/settings.py @@ -34,6 +34,7 @@ class PolarisHubSettings(BaseSettings): # Hub settings hub_url: HttpUrlString = "https://polarishub.io/" api_url: HttpUrlString | None = None + custom_metadata_prefix: str = "X-Amz-Meta-" # Hub authentication settings hub_token_url: HttpUrlString | None = None diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 5bebf291..49fea5bb 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -9,9 +9,11 @@ ) from polaris.dataset import Dataset, create_dataset_from_file from polaris.hub.client import PolarisHubClient +from polaris.utils.misc import should_verify_checksum +from polaris.utils.types import ChecksumStrategy -def load_dataset(path: str, verify_checksum: bool = True) -> Dataset: +def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr") -> Dataset: """ Loads a Polaris dataset. @@ -37,12 +39,20 @@ def load_dataset(path: str, verify_checksum: bool = True) -> Dataset: client = PolarisHubClient() return client.get_dataset(*path.split("/"), verify_checksum=verify_checksum) + # Load from local file if extension == "json": - return Dataset.from_json(path) - return create_dataset_from_file(path) + dataset = Dataset.from_json(path) + else: + dataset = create_dataset_from_file(path) + # Verify checksum if requested + if should_verify_checksum(verify_checksum, dataset): + dataset.verify_checksum() -def load_benchmark(path: str, verify_checksum: bool = True): + return dataset + + +def load_benchmark(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr"): """ Loads a Polaris benchmark. @@ -75,4 +85,11 @@ def load_benchmark(path: str, verify_checksum: bool = True): # e.g. we might end up with a single class per benchmark. is_single_task = isinstance(data["target_cols"], str) or len(data["target_cols"]) == 1 cls = SingleTaskBenchmarkSpecification if is_single_task else MultiTaskBenchmarkSpecification - return cls.from_json(path) + + benchmark = cls.from_json(path) + + # Verify checksum if requested + if should_verify_checksum(verify_checksum, benchmark.dataset): + benchmark.verify_checksum() + + return benchmark diff --git a/polaris/utils/errors.py b/polaris/utils/errors.py index 1cb46581..5d4494c1 100644 --- a/polaris/utils/errors.py +++ b/polaris/utils/errors.py @@ -33,3 +33,7 @@ class TestAccessError(Exception): __test__ = False pass + + +class InvalidZarrChecksum(Exception): + pass diff --git a/polaris/utils/httpx.py b/polaris/utils/httpx.py deleted file mode 100644 index 8c2a3709..00000000 --- a/polaris/utils/httpx.py +++ /dev/null @@ -1,35 +0,0 @@ -from httpx import Response - - -def _log_response(response: Response) -> str: - """ - Fully logs a request/response pair for HTTPX. - Used for debugging purposes. 
- """ - req_prefix = "< " - res_prefix = "> " - request = response.request - output = [f"{req_prefix}{request.method} {request.url}"] - - for name, value in request.headers.items(): - output.append(f"{req_prefix}{name}: {value}") - - output.append(req_prefix) - - if isinstance(request.content, (str, bytes)): - output.append(f"{req_prefix}{request.content}") - else: - output.append("<< Request body is not a string-like type >>") - - output.append("") - - output.append(f"{res_prefix} {response.status_code} {response.reason_phrase}") - - for name, value in response.headers.items(): - output.append(f"{res_prefix}{name}: {value}") - - output.append(res_prefix) - - output.append(f"{res_prefix}{response.text}") - - return "\n".join(output) diff --git a/polaris/utils/misc.py b/polaris/utils/misc.py index 98eb0be6..9a8199eb 100644 --- a/polaris/utils/misc.py +++ b/polaris/utils/misc.py @@ -1,6 +1,9 @@ -from typing import Any +from typing import TYPE_CHECKING, Any -from polaris.utils.types import SlugCompatibleStringType +from polaris.utils.types import ChecksumStrategy, SlugCompatibleStringType + +if TYPE_CHECKING: + from polaris.dataset import Dataset def listit(t: Any): @@ -16,3 +19,15 @@ def sluggify(sluggable: SlugCompatibleStringType): Converts a string to a slug-compatible string. """ return sluggable.lower().replace("_", "-") + + +def should_verify_checksum(strategy: ChecksumStrategy, dataset: "Dataset") -> bool: + """ + Determines whether a checksum should be verified. + """ + if strategy == "ignore": + return False + elif strategy == "verify": + return True + else: + return not dataset.uses_zarr diff --git a/polaris/utils/types.py b/polaris/utils/types.py index 45dc0332..fbee9d0e 100644 --- a/polaris/utils/types.py +++ b/polaris/utils/types.py @@ -107,6 +107,11 @@ Type to specify which action to take when encountering existing files within a Zarr archive. """ +ChecksumStrategy: TypeAlias = Literal["verify", "verify_unless_zarr", "ignore"] +""" +Type to specify which action to take to verify the data integrity of an artifact through a checksum. +""" + class HubOwner(BaseModel): """An owner of an artifact on the Polaris Hub diff --git a/pyproject.toml b/pyproject.toml index e640063c..6f42c52a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,8 +113,14 @@ data_file = ".coverage/coverage" omit = [ "polaris/__init__.py", "polaris/_version.py", + # We cannot yet test the interaction with the Hub. + # See e.g. 
https://github.com/polaris-hub/polaris/issues/30 "polaris/hub/client.py", + "polaris/hub/external_auth_client.py", + "polaris/hub/oauth2.py", "polaris/hub/settings.py", + "polaris/hub/polarisfs.py", + "polaris/hub/__init__.py", "polaris/hub/__init__.py", ] diff --git a/tests/conftest.py b/tests/conftest.py index aa02c3f7..2170a62a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,7 +60,7 @@ def test_user_owner(): return HubOwner(userId="test-user", slug="test-user") -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def test_dataset(test_data, test_org_owner): dataset = Dataset( table=test_data, @@ -81,8 +81,8 @@ def test_dataset(test_data, test_org_owner): def zarr_archive(tmp_path): tmp_path = fs.join(tmp_path, "data.zarr") root = zarr.open(tmp_path, mode="w") - root.array("A", data=np.random.random((100, 2048))) - root.array("B", data=np.random.random((100, 2048))) + root.array("A", data=np.random.random((100, 2048)), chunks=(1, None)) + root.array("B", data=np.random.random((100, 2048)), chunks=(1, None)) zarr.consolidate_metadata(root.store) return tmp_path diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index d3ccca2e..b959739c 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -6,7 +6,6 @@ MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification, ) -from polaris.utils.errors import PolarisChecksumError @pytest.mark.parametrize("is_single_task", [True, False]) @@ -133,6 +132,10 @@ def test_benchmark_checksum(is_single_task, test_single_task_benchmark, test_mul obj = test_single_task_benchmark if is_single_task else test_multi_task_benchmark cls = SingleTaskBenchmarkSpecification if is_single_task else MultiTaskBenchmarkSpecification + # Make sure the `md5sum` is part of the model dump even if not initiated yet. + # This is important for uploads to the Hub. 
+ assert obj._md5sum is None and "md5sum" in obj.model_dump() + original = obj.md5sum assert original is not None @@ -140,34 +143,31 @@ def test_benchmark_checksum(is_single_task, test_single_task_benchmark, test_mul # Without any changes, same hash kwargs = obj.model_dump() - cls(**kwargs) + assert cls(**kwargs).md5sum == original # With a different ordering of the target columns kwargs["target_cols"] = kwargs["target_cols"][::-1] - cls(**kwargs) + assert cls(**kwargs).md5sum == original # With a different ordering of the metrics kwargs["metrics"] = kwargs["metrics"][::-1] - cls(**kwargs) + assert cls(**kwargs).md5sum == original # With a different ordering of the split kwargs["split"] = kwargs["split"][0][::-1], kwargs["split"][1] - cls(**kwargs) + assert cls(**kwargs).md5sum == original # --- Test that the checksum is NOT the same --- def _check_for_failure(_kwargs): - with pytest.raises((ValidationError, TypeError)) as error: - cls(**_kwargs) - assert error.error_count() == 1 # noqa - assert isinstance(error.errors()[0], PolarisChecksumError) # noqa + assert cls(**_kwargs).md5sum != _kwargs["md5sum"] # Split kwargs = obj.model_dump() - kwargs["split"] = kwargs["split"][0][1:] + [-1], kwargs["split"][1] + kwargs["split"] = kwargs["split"][0][1:], kwargs["split"][1] _check_for_failure(kwargs) kwargs = obj.model_dump() - kwargs["split"] = kwargs["split"][0], kwargs["split"][1][1:] + [-1] + kwargs["split"] = kwargs["split"][0], kwargs["split"][1][1:] _check_for_failure(kwargs) # Metrics @@ -188,3 +188,17 @@ def _check_for_failure(_kwargs): kwargs["md5sum"] = None dataset = cls(**kwargs) assert dataset.md5sum is not None + + +def test_setting_an_invalid_checksum(test_single_task_benchmark): + """Test whether setting an invalid checksum raises an error.""" + with pytest.raises(ValueError): + test_single_task_benchmark.md5sum = "invalid" + + +def test_checksum_verification(test_single_task_benchmark): + """Test whether setting an invalid checksum raises an error.""" + test_single_task_benchmark.verify_checksum() + test_single_task_benchmark.md5sum = "0" * 32 + with pytest.raises(ValueError): + test_single_task_benchmark.verify_checksum() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a78919f2..db4336e1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -5,35 +5,9 @@ import pytest import zarr from datamol.utils import fs -from pydantic import ValidationError from polaris.dataset import Dataset, Subset, create_dataset_from_file from polaris.loader import load_dataset -from polaris.utils.errors import PolarisChecksumError - - -def _equality_test(dataset_1, dataset_2): - """ - Utility function. - - When saving a dataset to a different location, it should be considered the same dataset - but currently the dataset checksum is used for equality and with pointer columns, - the checksum uses the file path, not the file content (which thus changes when saving). 
- - See also: https://github.com/polaris-hub/polaris/issues/16 - """ - if dataset_1 == dataset_2: - return True - if len(dataset_1) != len(dataset_2): - return False - if (dataset_1.table.columns != dataset_2.table.columns).all(): - return False - - for i in range(len(dataset_1)): - for col in dataset_1.table.columns: - if (dataset_1.get_data(row=i, col=col) != dataset_2.get_data(row=i, col=col)).all(): - return False - return True @pytest.mark.parametrize("with_caching", [True, False]) @@ -75,43 +49,35 @@ def test_load_data(tmp_path, with_slice, with_caching): def test_dataset_checksum(test_dataset): """Test whether the checksum is a good indicator of whether the dataset has changed in a meaningful way.""" - original = test_dataset.md5sum - assert original is not None + # Make sure the `md5sum` is part of the model dump even if not initiated yet. + # This is important for uploads to the Hub. + assert test_dataset._md5sum is None + assert "md5sum" in test_dataset.model_dump() # Without any changes, same hash kwargs = test_dataset.model_dump() - Dataset(**kwargs) + assert Dataset(**kwargs) == test_dataset # With unimportant changes, same hash kwargs["name"] = "changed" kwargs["description"] = "changed" kwargs["source"] = "https://changed.com" - Dataset(**kwargs) + assert Dataset(**kwargs) == test_dataset # Check sensitivity to the row and column ordering kwargs["table"] = kwargs["table"].iloc[::-1] kwargs["table"] = kwargs["table"][kwargs["table"].columns[::-1]] - Dataset(**kwargs) - - def _check_for_failure(_kwargs): - with pytest.raises(ValidationError) as error: - Dataset(**_kwargs) - assert error.error_count() == 1 # noqa - assert isinstance(error.errors()[0], PolarisChecksumError) # noqa + assert Dataset(**kwargs) == test_dataset # Without any changes, but different hash - kwargs["md5sum"] = "invalid" - _check_for_failure(kwargs) + dataset = Dataset(**kwargs) + dataset._md5sum = "invalid" + assert dataset != test_dataset # With changes, but same hash - kwargs["md5sum"] = original + kwargs["md5sum"] = test_dataset.md5sum kwargs["table"] = kwargs["table"].iloc[:-1] - _check_for_failure(kwargs) - - # With changes, but no hash - kwargs["md5sum"] = None - dataset = Dataset(**kwargs) - assert dataset.md5sum is not None + assert Dataset(**kwargs) != test_dataset def test_dataset_from_zarr(zarr_archive, tmpdir): @@ -132,10 +98,10 @@ def test_dataset_from_json(test_dataset, tmpdir): path = fs.join(str(tmpdir), "dataset.json") new_dataset = Dataset.from_json(path) - assert _equality_test(test_dataset, new_dataset) + assert test_dataset == new_dataset new_dataset = load_dataset(path) - assert _equality_test(test_dataset, new_dataset) + assert test_dataset == new_dataset def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmpdir): @@ -152,24 +118,23 @@ def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmpdir): path = dataset.to_json(json_dir) new_dataset = Dataset.from_json(path) - assert _equality_test(dataset, new_dataset) + assert dataset == new_dataset new_dataset = load_dataset(path) - assert _equality_test(dataset, new_dataset) + assert dataset == new_dataset def test_dataset_caching(zarr_archive, tmpdir): """Test whether the dataset remains the same after caching.""" - archive = zarr_archive - original_dataset = create_dataset_from_file(archive, tmpdir.join("original1")) - cached_dataset = create_dataset_from_file(archive, tmpdir.join("original2")) + original_dataset = create_dataset_from_file(zarr_archive, tmpdir.join("original1")) + cached_dataset = 
create_dataset_from_file(zarr_archive, tmpdir.join("original2"))
 
     assert original_dataset == cached_dataset
 
     cache_dir = cached_dataset.cache(tmpdir.join("cached").strpath)
     assert cached_dataset.zarr_root_path.startswith(cache_dir)
 
-    assert _equality_test(cached_dataset, original_dataset)
+    assert cached_dataset == original_dataset
 
 
 def test_dataset_index():
@@ -198,3 +163,17 @@ def test_dataset_in_memory_optimization(zarr_archive, tmpdir):
 
     d2 = perf_counter() - t2
     assert d2 < d1
+
+
+def test_setting_an_invalid_checksum(test_dataset):
+    """Test whether setting an invalid checksum raises an error."""
+    with pytest.raises(ValueError):
+        test_dataset.md5sum = "invalid"
+
+
+def test_checksum_verification(test_dataset):
+    """Test whether checksum verification fails when the stored checksum does not match."""
+    test_dataset.verify_checksum()
+    test_dataset.md5sum = "0" * 32
+    with pytest.raises(ValueError):
+        test_dataset.verify_checksum()
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 5d9a983b..5a6c1f69 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,6 +1,6 @@
 import datamol as dm
 import numpy as np
-from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
 
 from polaris.evaluate import BenchmarkResults
 
@@ -57,7 +57,6 @@ def test_single_task_benchmark_clf_loop_with_multiple_test_sets(
     y_prob = {}
     y_pred = {}
     for k, test_subset in test.items():
-        print(k, test_subset)
         x_test = np.array([dm.to_fp(dm.to_mol(smi)) for smi in test_subset.inputs])
         y_prob[k] = model.predict_proba(x_test)[:, :1]  # for binary classification
         y_pred[k] = model.predict(x_test)
diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py
new file mode 100644
index 00000000..9bcf1765
--- /dev/null
+++ b/tests/test_zarr_checksum.py
@@ -0,0 +1,163 @@
+"""
+The code in this file is based on the zarr-checksum package.
+
+Maintained by Jacob Nesbitt, released under the DANDI org on GitHub
+and with Kitware, Inc. credited as the author. This code is released
+with the Apache 2.0 license.
+
+See also: https://github.com/dandi/zarr_checksum
+
+Instead of adding the package as a dependency, we opted to copy over the code
+because it is a small and self-contained module that we will want to alter to
+support our Polaris code base.
+
+NOTE: We have made some modifications to the original code.
+
+----
+
+Copyright 2023 Kitware, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +import os +import uuid +from pathlib import Path +from shutil import copytree, rmtree + +import pytest +import zarr + +from polaris.dataset.zarr._checksum import ( + EMPTY_CHECKSUM, + InvalidZarrChecksum, + _ZarrChecksum, + _ZarrChecksumManifest, + _ZarrChecksumTree, + _ZarrDirectoryDigest, + compute_zarr_checksum, +) + + +def test_generate_digest() -> None: + manifest = _ZarrChecksumManifest( + directories=[_ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1-1", name="a/b", size=1)], + files=[_ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-0-1", name="b", size=1)], + ) + assert manifest.generate_digest().digest == "9c5294e46908cf397cb7ef53ffc12efc-1-2" + + +def test_zarr_checksum_sort_order() -> None: + # The a < b in the name should take precedence over z > y in the md5 + a = _ZarrChecksum(name="a", digest="z", size=3) + b = _ZarrChecksum(name="b", digest="y", size=4) + assert sorted([b, a]) == [a, b] + + +def test_parse_zarr_directory_digest() -> None: + # Parse valid + _ZarrDirectoryDigest.parse("c228464f432c4376f0de6ddaea32650c-37481-38757151179") + _ZarrDirectoryDigest.parse(None) + + # Ensure exception is raised + with pytest.raises(InvalidZarrChecksum): + _ZarrDirectoryDigest.parse("asd") + with pytest.raises(InvalidZarrChecksum): + _ZarrDirectoryDigest.parse("asd-0--0") + + +def test_pop_deepest() -> None: + tree = _ZarrChecksumTree() + tree.add_leaf(Path("a/b"), size=1, digest="asd") + tree.add_leaf(Path("a/b/c"), size=1, digest="asd") + node = tree.pop_deepest() + + # Assert popped node is a/b/c, not a/b + assert str(node.path) == "a/b" + assert len(node.checksums.files) == 1 + assert len(node.checksums.directories) == 0 + assert node.checksums.files[0].name == "c" + + +def test_process_empty_tree() -> None: + tree = _ZarrChecksumTree() + assert tree.process().digest == EMPTY_CHECKSUM + + +def test_process_tree() -> None: + tree = _ZarrChecksumTree() + tree.add_leaf(Path("a/b"), size=1, digest="9dd4e461268c8034f5c8564e155c67a6") + tree.add_leaf(Path("c"), size=1, digest="415290769594460e2e485922904f345d") + checksum = tree.process() + + # This zarr checksum was computed against the same file structure using the previous + # zarr checksum implementation + # Assert the current implementation produces a matching checksum + assert checksum.digest == "e53fcb7b5c36b2f4647fbf826a44bdc9-2-2" + + +def test_checksum_for_zarr_archive(zarr_archive, tmpdir): + # NOTE: This test was not in the original code base of the zarr-checksum package. + checksum, _ = compute_zarr_checksum(zarr_archive) + + path = tmpdir.join("copy") + copytree(zarr_archive, path) + assert checksum == compute_zarr_checksum(str(path))[0] + + root = zarr.open(path) + root["A"][0:10] = 0 + assert checksum != compute_zarr_checksum(str(path))[0] + + +def test_zarr_leaf_to_checksum(zarr_archive): + # NOTE: This test was not in the original code base of the zarr-checksum package. + _, leaf_to_checksum = compute_zarr_checksum(zarr_archive) + root = zarr.open(zarr_archive) + + # Check the basic structure - Each key corresponds to a file in the zarr archive + assert len(leaf_to_checksum) == len(root.store) + assert all(k.path in root.store for k in leaf_to_checksum) + + +def test_zarr_checksum_fails_for_remote_storage(zarr_archive): + # NOTE: This test was not in the original code base of the zarr-checksum package. 
+    with pytest.raises(RuntimeError):
+        compute_zarr_checksum("s3://bucket/data.zarr")
+    with pytest.raises(RuntimeError):
+        compute_zarr_checksum("gs://bucket/data.zarr")
+
+
+def test_zarr_checksum_with_path_normalization(zarr_archive):
+    # NOTE: This test was not in the original code base of the zarr-checksum package.
+
+    baseline = compute_zarr_checksum(zarr_archive)[0]
+    rootdir = os.path.dirname(zarr_archive)
+
+    # Test a relative path
+    copytree(zarr_archive, os.path.join(rootdir, "relative", "data.zarr"))
+    assert compute_zarr_checksum(f"{zarr_archive}/../relative/data.zarr")[0] == baseline
+
+    # Test with environment variables
+    rng_id = str(uuid.uuid4())
+    os.environ["TMP_TEST_DIR"] = rng_id
+    copytree(zarr_archive, os.path.join(rootdir, "vars", rng_id))
+    assert compute_zarr_checksum(f"{rootdir}/vars/${{TMP_TEST_DIR}}")[0] == baseline  # Format ${...}
+    assert compute_zarr_checksum(f"{rootdir}/vars/$TMP_TEST_DIR")[0] == baseline  # Format $...
+
+    # And with the user directory abbreviation (~)
+    path = os.path.expanduser("~/data.zarr")
+    try:
+        copytree(zarr_archive, path)
+        assert compute_zarr_checksum("~/data.zarr")[0] == baseline
+    finally:
+        rmtree(path)
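
For reference, the workflow these tests exercise can be summarized in a short usage sketch. This is illustrative only: the file paths and the output directory passed to create_dataset_from_file below are assumptions for the example, not part of the diff.

# Minimal usage sketch of the checksum workflow exercised by the tests above.
# The paths are illustrative assumptions, not part of the changeset.
from polaris.dataset import create_dataset_from_file
from polaris.dataset.zarr._checksum import compute_zarr_checksum

# Convert a local Zarr archive into a Polaris dataset (mirrors the test fixtures).
dataset = create_dataset_from_file("./data.zarr", "./converted")

# The md5sum is computed lazily on first access and cached on the model,
# and it is included in the model dump even before it has been computed.
print(dataset.md5sum)

# Recompute the checksum and compare it against the stored value;
# a mismatch raises an error (see test_checksum_verification).
dataset.verify_checksum()

# The Zarr-level checksum can also be computed directly. It returns the root
# checksum together with a mapping from each leaf file to its checksum
# (see test_zarr_leaf_to_checksum).
checksum, leaf_to_checksum = compute_zarr_checksum("./data.zarr")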