Skip to content

Commit d28c6e4

Browse files
cwognum and Andrewq11
authored
XL Datasets: Minimal Zarr-only dataset implementation (#186)
* Extracted common interface between V1 and V2 * Skeleton structure for tests and Dataset V2. Small changes to shared API * Implemented the test cases Test-driven development! Yeah * Basic test cases passed Now the fun starts... * Added additional validation * Improved docs * Fixed some reference errors in the docs * Disable use of iloc to loc mapping for Dataset V2 * Updated import to prevent circular import * Ruff check and format * Adding new Zarr manifest generation to DatasetV2 class (#185) * updates for calculating zarr manifests & adding basic tests for it * moving cache_dir assignment to DatasetV1 and DatasetV2 model validators * Updating argument types for parquet utils * Updating argument types for md5 util * fixing DatasetV1 export & dataset model validators * PR feedback updates * Adding test that checks the length of the manifest after update * PR feedback * fixing code check test * Move code to dataset base class * Addressed most feedback on the PR, still need to revisit the __getitem__ method * Worked on the __getitem__ method * Address special case of pointer columns * Renamed md5sum to zarr_manifest_md5sum for clarity, remove equality test from the v2 dataset and moved the verify_checksum parameter to v1 * Fix missing import * Added PR feedback * Update decorators --------- Co-authored-by: Andrew Quirke <[email protected]> Co-authored-by: Andrew Quirke <[email protected]>
1 parent 03ad4b7 commit d28c6e4

26 files changed

+1224
-421
lines changed

docs/api/dataset.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
options:
33
filters: ["!^_"]
44

5+
---
6+
7+
::: polaris.dataset._base.BaseDataset
8+
options:
9+
filters: ["!^_"]
10+
511
---
612

713
::: polaris.dataset.ColumnAnnotation

docs/quickstart.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ dataset.get_data(
8282
# Or, similarly:
8383
dataset[dataset.rows[0], dataset.columns[0]]
8484

85-
# Get the first 10 rows in memory
86-
dataset[:10]
85+
# Get an entire row
86+
dataset[0]
8787
```
8888

8989
## Core concepts

polaris/benchmark/_base.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from itertools import chain
21
import json
32
from hashlib import md5
3+
from itertools import chain
44
from typing import Any, Callable, Optional, Union
55

66
import fsspec
@@ -18,11 +18,11 @@
1818
from sklearn.utils.multiclass import type_of_target
1919

2020
from polaris._artifact import BaseArtifactModel
21-
from polaris.mixins import ChecksumMixin
22-
from polaris.dataset import Dataset, Subset, CompetitionDataset
21+
from polaris.dataset import CompetitionDataset, DatasetV1, Subset
2322
from polaris.evaluate import BenchmarkResults, Metric
2423
from polaris.evaluate.utils import evaluate_benchmark
2524
from polaris.hub.settings import PolarisHubSettings
25+
from polaris.mixins import ChecksumMixin
2626
from polaris.utils.dict2html import dict2html
2727
from polaris.utils.errors import InvalidBenchmarkError
2828
from polaris.utils.misc import listit
@@ -96,7 +96,7 @@ class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin):
9696

9797
# Public attributes
9898
# Data
99-
dataset: Union[Dataset, CompetitionDataset, str, dict[str, Any]]
99+
dataset: Union[DatasetV1, CompetitionDataset, str, dict[str, Any]]
100100
target_cols: ColumnsType
101101
input_cols: ColumnsType
102102
split: SplitType
@@ -111,12 +111,11 @@ class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin):
111111
def _validate_dataset(cls, v):
112112
"""
113113
Allows either passing a Dataset object or the kwargs to create one
114-
TODO (cwognum): Allow multiple datasets to be used as part of a benchmark
115114
"""
116115
if isinstance(v, dict):
117-
v = Dataset(**v)
116+
v = DatasetV1(**v)
118117
elif isinstance(v, str):
119-
v = Dataset.from_json(v)
118+
v = DatasetV1.from_json(v)
120119
return v
121120

122121
@field_validator("target_cols", "input_cols")
@@ -162,7 +161,7 @@ def _validate_main_metric(cls, v):
162161
return v
163162

164163
@model_validator(mode="after")
165-
def _validate_split(cls, m: "BenchmarkSpecification"):
164+
def _validate_split(self):
166165
"""
167166
Verifies that:
168167
1) There are no empty test partitions
@@ -171,7 +170,7 @@ def _validate_split(cls, m: "BenchmarkSpecification"):
171170
4) There is no overlap between the train and test set
172171
5) No row exists in the test set where all labels are missing/empty
173172
"""
174-
split = m.split
173+
split = self.split
175174

176175
# Train partition can be empty (zero-shot)
177176
# Test partitions cannot be empty
@@ -214,13 +213,13 @@ def _validate_split(cls, m: "BenchmarkSpecification"):
214213
raise InvalidBenchmarkError("The test set contains duplicate indices")
215214

216215
# All indices are valid given the dataset
217-
dataset = m.dataset
216+
dataset = self.dataset
218217
if dataset is not None:
219218
max_i = len(dataset)
220219
if any(i < 0 or i >= max_i for i in chain(train_idx_list, full_test_idx_set)):
221220
raise InvalidBenchmarkError("The predefined split contains invalid indices")
222221

223-
return m
222+
return self
224223

225224
@field_validator("target_types")
226225
def _validate_target_types(cls, v, info: ValidationInfo):
@@ -234,11 +233,20 @@ def _validate_target_types(cls, v, info: ValidationInfo):
234233

235234
for target in target_cols:
236235
if target not in v:
237-
val = dataset[:, target]
236+
# Skip inferring the target type for pointer columns.
237+
# This would be complex to implement properly.
238+
# For these columns, dataset creators can still manually specify the target type.
239+
anno = dataset.annotations.get(target)
240+
if anno is not None and anno.is_pointer:
241+
v[target] = None
242+
continue
243+
244+
val = dataset.table.loc[:, target]
238245

239246
# Non numeric columns can be targets (e.g. prediction molecular reactions),
240247
# but in that case we currently don't infer the target type.
241248
if not np.issubdtype(val.dtype, np.number):
249+
v[target] = None
242250
continue
243251

244252
# remove the nans for mutiple task dataset when the table is sparse
@@ -254,15 +262,14 @@ def _validate_target_types(cls, v, info: ValidationInfo):
254262
return v
255263

256264
@model_validator(mode="after")
257-
@classmethod
258-
def _validate_model(cls, m: "BenchmarkSpecification"):
265+
def _validate_model(self):
259266
"""
260267
Sets a default metric if missing.
261268
"""
262269
# Set a default main metric if not set yet
263-
if m.main_metric is None:
264-
m.main_metric = m.metrics[0]
265-
return m
270+
if self.main_metric is None:
271+
self.main_metric = self.metrics[0]
272+
return self
266273

267274
@field_serializer("metrics", "main_metric")
268275
def _serialize_metrics(self, v):
@@ -342,9 +349,10 @@ def n_classes(self) -> dict[str, int]:
342349
"""The number of classes for each of the target columns."""
343350
n_classes = {}
344351
for target in self.target_cols:
345-
target_type = self.target_types[target]
352+
target_type = self.target_types.get(target)
346353
if target_type is None or target_type == TargetType.REGRESSION:
347354
continue
355+
# TODO: Don't use table attribute
348356
n_classes[target] = self.dataset.table.loc[:, target].nunique()
349357
return n_classes
350358

polaris/dataset/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from polaris.dataset._column import ColumnAnnotation, Modality, KnownContentType
2-
from polaris.dataset._dataset import Dataset
1+
from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality
2+
from polaris.dataset._competition_dataset import CompetitionDataset
3+
from polaris.dataset._dataset import DatasetV1
4+
from polaris.dataset._dataset import DatasetV1 as Dataset
35
from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files
46
from polaris.dataset._subset import Subset
5-
from polaris.dataset._competition_dataset import CompetitionDataset
67

78
__all__ = [
89
"ColumnAnnotation",
@@ -14,4 +15,5 @@
1415
"DatasetFactory",
1516
"create_dataset_from_file",
1617
"create_dataset_from_files",
18+
"DatasetV1",
1719
]

0 commit comments

Comments
 (0)