diff --git a/pyproject.toml b/pyproject.toml index e2c7b58d65..5a882b3a00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "neo>=0.14.0", "probeinterface>=0.2.23", "packaging", + "pydantic", ] [build-system] diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 47ce8cf848..ca8c731040 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -232,8 +232,13 @@ def random_spikes_selection( def apply_merges_to_sorting( - sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append" -): + sorting: BaseSorting, + merge_unit_groups: list[list[int | str]] | list[tuple[int | str]], + new_unit_ids: list[int | str] | None = None, + censor_ms: float | None = None, + return_extra: bool = False, + new_id_strategy: str = "append", +) -> NumpySorting | tuple[NumpySorting, np.ndarray, list[int | str]]: """ Apply a resolved representation of the merges to a sorting object. @@ -245,9 +250,9 @@ def apply_merges_to_sorting( Parameters ---------- - sorting : Sorting + sorting : BaseSorting The Sorting object to apply merges. - merge_unit_groups : list/tuple of lists/tuples + merge_unit_groups : list of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids : list | None, default: None diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index bb4ee4db1c..a53b4c5cb9 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1140,18 +1140,18 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin def merge_units( self, - merge_unit_groups, - new_unit_ids=None, - censor_ms=None, - merging_mode="soft", - sparsity_overlap=0.75, - new_id_strategy="append", - return_new_unit_ids=False, - format="memory", - folder=None, - verbose=False, + merge_unit_groups: list[list[str | int]] | list[tuple[str | int]], + new_unit_ids: list[int | str] | None = None, + censor_ms: float | None = None, + merging_mode: str = "soft", + sparsity_overlap: float = 0.75, + new_id_strategy: str = "append", + return_new_unit_ids: bool = False, + format: str = "memory", + folder: Path | str | None = None, + verbose: bool = False, **job_kwargs, - ) -> "SortingAnalyzer": + ) -> "SortingAnalyzer | tuple[SortingAnalyzer, list[int | str]]": """ This method is equivalent to `save_as()` but with a list of merges that have to be achieved. 
Merges units by creating a new SortingAnalyzer object with the appropriate merges diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 80f251ca43..ea1c19e719 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,14 +1,12 @@ -from itertools import combinations +from __future__ import annotations import numpy as np from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting -import copy +from spikeinterface.curation.curation_model import CurationModel -supported_curation_format_versions = {"1"} - -def validate_curation_dict(curation_dict): +def validate_curation_dict(curation_dict: dict): """ Validate that the curation dictionary given as parameter complies with the format @@ -19,119 +17,11 @@ def validate_curation_dict(curation_dict): curation_dict : dict """ + # this will validate the format of the curation_dict + CurationModel(**curation_dict) - # format - if "format_version" not in curation_dict: - raise ValueError("No version_format") - - if curation_dict["format_version"] not in supported_curation_format_versions: - raise ValueError( - f"Format version ({curation_dict['format_version']}) not supported. " - f"Only {supported_curation_format_versions} are valid" - ) - - # unit_ids - labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) - merged_units_set = set(sum(curation_dict["merge_unit_groups"], [])) - removed_units_set = set(curation_dict["removed_units"]) - - if curation_dict["unit_ids"] is not None: - # old format v0 did not contain unit_ids so this can contains None - unit_set = set(curation_dict["unit_ids"]) - if not labeled_unit_set.issubset(unit_set): - raise ValueError("Curation format: some labeled units are not in the unit list") - if not merged_units_set.issubset(unit_set): - raise ValueError("Curation format: some merged units are not in the unit list") - if not removed_units_set.issubset(unit_set): - raise ValueError("Curation format: some removed units are not in the unit list") - - for group in curation_dict["merge_unit_groups"]: - if len(group) < 2: - raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") - - all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]] - for gp_1, gp_2 in combinations(all_merging_groups, 2): - if len(gp_1.intersection(gp_2)) != 0: - raise ValueError("Curation format: some units belong to multiple merge groups") - if len(removed_units_set.intersection(merged_units_set)) != 0: - raise ValueError("Curation format: some units were merged and deleted") - - # Check the labels exclusivity - for lbl in curation_dict["manual_labels"]: - for label_key in curation_dict["label_definitions"].keys(): - if label_key in lbl: - unit_id = lbl["unit_id"] - label_value = lbl[label_key] - if not isinstance(label_value, list): - raise ValueError(f"Curation format: manual_labels {unit_id} is invalid shoudl be a list") - - is_exclusive = curation_dict["label_definitions"][label_key]["exclusive"] - - if is_exclusive and not len(label_value) <= 1: - raise ValueError( - f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. 
{label_value} is invalid" - ) - - -def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_format="1"): - """ - Converts the old sortingview curation format (v0) into a curation dictionary new format (v1) - Couple of caveats: - * The list of units is not available in the original sortingview dictionary. We set it to None - * Labels can not be mutually exclusive. - * Labels have no category, so we regroup them under the "all_labels" category - - Parameters - ---------- - sortingview_dict : dict - Dictionary containing the curation information from sortingview - destination_format : str - Version of the format to use. - Default to "1" - - Returns - ------- - curation_dict : dict - A curation dictionary - """ - - assert destination_format == "1" - if "mergeGroups" not in sortingview_dict.keys(): - sortingview_dict["mergeGroups"] = [] - merge_groups = sortingview_dict["mergeGroups"] - merged_units = sum(merge_groups, []) - first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) - if str.isdigit(first_unit_id): - unit_id_type = int - else: - unit_id_type = str - - all_units = [] - all_labels = [] - manual_labels = [] - general_cat = "all_labels" - for unit_id_, l_labels in sortingview_dict["labelsByUnit"].items(): - all_labels.extend(l_labels) - # recorver the correct type for unit_id - unit_id = unit_id_type(unit_id_) - all_units.append(unit_id) - manual_labels.append({"unit_id": unit_id, general_cat: l_labels}) - labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}} - - curation_dict = { - "format_version": destination_format, - "unit_ids": None, - "label_definitions": labels_def, - "manual_labels": manual_labels, - "merge_unit_groups": merge_groups, - "removed_units": [], - } - - return curation_dict - - -def curation_label_to_vectors(curation_dict): +def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into dict of vectors. For label category with exclusive=True : a column is created and values are the unique label. 
@@ -141,66 +31,46 @@ def curation_label_to_vectors(curation_dict): Parameters ---------- - curation_dict : dict - A curation dictionary + curation_dict : dict or CurationModel + A curation dictionary or model Returns ------- labels : dict of numpy vector """ - unit_ids = list(curation_dict["unit_ids"]) + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + unit_ids = list(curation_model.unit_ids) n = len(unit_ids) labels = {} - for label_key, label_def in curation_dict["label_definitions"].items(): - if label_def["exclusive"]: + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: assert label_key not in labels, f"{label_key} is already a key" labels[label_key] = [""] * n - for lbl in curation_dict["manual_labels"]: - value = lbl.get(label_key, []) - if len(value) == 1: - unit_index = unit_ids.index(lbl["unit_id"]) - labels[label_key][unit_index] = value[0] + for manual_label in curation_model.manual_labels: + values = manual_label.labels.get(label_key, []) + if len(values) == 1: + unit_index = unit_ids.index(manual_label.unit_id) + labels[label_key][unit_index] = values[0] labels[label_key] = np.array(labels[label_key]) else: - for label_opt in label_def["label_options"]: + for label_opt in label_def.label_options: assert label_opt not in labels, f"{label_opt} is already a key" labels[label_opt] = np.zeros(n, dtype=bool) - for lbl in curation_dict["manual_labels"]: - values = lbl.get(label_key, []) + for manual_label in curation_model.manual_labels: + values = manual_label.labels.get(label_key, []) for value in values: - unit_index = unit_ids.index(lbl["unit_id"]) + unit_index = unit_ids.index(manual_label.unit_id) labels[value][unit_index] = True - return labels -def clean_curation_dict(curation_dict): - """ - In some cases the curation_dict can have inconsistencies (like in the sorting view format). - For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'. - This is ambiguous! - - This cleaner helper function ensures units tagged as `removed_units` are removed from the `merge_unit_groups` - """ - curation_dict = copy.deepcopy(curation_dict) - - clean_merge_unit_groups = [] - for group in curation_dict["merge_unit_groups"]: - clean_group = [] - for unit_id in group: - if unit_id not in curation_dict["removed_units"]: - clean_group.append(unit_id) - if len(clean_group) > 1: - clean_merge_unit_groups.append(clean_group) - - curation_dict["merge_unit_groups"] = clean_merge_unit_groups - return curation_dict - - -def curation_label_to_dataframe(curation_dict): +def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into a pandas dataframe. For label category with exclusive=True : a column is created and values are the unique label. 
@@ -220,11 +90,18 @@ def curation_label_to_dataframe(curation_dict): """ import pandas as pd - labels = pd.DataFrame(curation_label_to_vectors(curation_dict), index=curation_dict["unit_ids"]) + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + labels = pd.DataFrame(curation_label_to_vectors(curation_model), index=curation_model.unit_ids) return labels -def apply_curation_labels(sorting, new_unit_ids, curation_dict): +def apply_curation_labels( + sorting: BaseSorting, new_unit_ids: list[int | str], curation_dict_or_model: dict | CurationModel +): """ Apply manual labels after merges. @@ -233,25 +110,30 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): * for merged group, when exclusive=True, if all have the same label then this label is applied * for merged group, when exclusive=False, if one unit has the label then the new one have also it """ + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model # Please note that manual_labels is done on the unit_ids before the merge!!! - manual_labels = curation_label_to_vectors(curation_dict) + manual_labels = curation_label_to_vectors(curation_model) # apply on non merged for key, values in manual_labels.items(): all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) for unit_ind, unit_id in enumerate(sorting.unit_ids): if unit_id not in new_unit_ids: - ind = list(curation_dict["unit_ids"]).index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) - for new_unit_id, old_group_ids in zip(new_unit_ids, curation_dict["merge_unit_groups"]): - for label_key, label_def in curation_dict["label_definitions"].items(): - if label_def["exclusive"]: + for new_unit_id, merge in zip(new_unit_ids, curation_model.merges): + old_group_ids = merge.merge_unit_group + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: group_values = [] for unit_id in old_group_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) value = manual_labels[label_key][ind] if value != "": group_values.append(value) @@ -260,10 +142,10 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): sorting.set_property(key, values=group_values[:1], ids=[new_unit_id]) else: - for key in label_def["label_options"]: + for key in label_def.label_options: group_values = [] for unit_id in old_group_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) value = manual_labels[key][ind] group_values.append(value) new_value = np.any(group_values) @@ -271,21 +153,21 @@ def apply_curation( - sorting_or_analyzer, - curation_dict, - censor_ms=None, - new_id_strategy="append", - merging_mode="soft", - sparsity_overlap=0.75, - verbose=False, + sorting_or_analyzer: BaseSorting | SortingAnalyzer, + curation_dict_or_model: dict | CurationModel, + censor_ms: float | None = None, + new_id_strategy: str = "append", + merging_mode: str = "soft", + sparsity_overlap: float = 0.75, + verbose: bool = False, **job_kwargs, ): """ Apply curation dict to a Sorting or a SortingAnalyzer. Steps are done in this order: - 1. Apply removal using curation_dict["removed_units"] - 2.
Apply merges using curation_dict["merge_unit_groups"] + 1. Apply removal using curation_dict["removed"] + 2. Apply merges using curation_dict["merges"] 3. Set labels using curation_dict["manual_labels"] A new Sorting or SortingAnalyzer (in memory) is returned. @@ -294,17 +176,18 @@ def apply_curation( Parameters ---------- sorting_or_analyzer : Sorting | SortingAnalyzer - The Sorting object to apply merges. - curation_dict : dict - The curation dict. + The Sorting or SortingAnalyzer object to apply merges. + curation_dict_or_model : dict | CurationModel + The curation dict or model. censor_ms : float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of as the desired refractory period. If `censor_ms=None`, no spikes are discarded. - new_id_strategy : "append" | "take_first", default: "append" + new_id_strategy : "append" | "take_first" | "join", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges + * "join" : new_unit_ids will be the concatenation of all unit_ids of every list of merges merging_mode : "soft" | "hard", default: "soft" How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately @@ -324,30 +207,37 @@ """ - validate_curation_dict(curation_dict) - if not np.array_equal(np.asarray(curation_dict["unit_ids"]), sorting_or_analyzer.unit_ids): + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids): raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer") if isinstance(sorting_or_analyzer, BaseSorting): sorting = sorting_or_analyzer - sorting = sorting.remove_units(curation_dict["removed_units"]) - sorting, _, new_unit_ids = apply_merges_to_sorting( - sorting, - curation_dict["merge_unit_groups"], - censor_ms=censor_ms, - return_extra=True, - new_id_strategy=new_id_strategy, - ) - apply_curation_labels(sorting, new_unit_ids, curation_dict) + sorting = sorting.remove_units(curation_model.removed) + if len(curation_model.merges) > 0: + sorting, _, new_unit_ids = apply_merges_to_sorting( + sorting, + merge_unit_groups=[m.merge_unit_group for m in curation_model.merges], + censor_ms=censor_ms, + return_extra=True, + new_id_strategy=new_id_strategy, + ) + else: + new_unit_ids = [] + apply_curation_labels(sorting, new_unit_ids, curation_model) return sorting elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer - if len(curation_dict["removed_units"]) > 0: - analyzer = analyzer.remove_units(curation_dict["removed_units"]) - if len(curation_dict["merge_unit_groups"]) > 0: + if len(curation_model.removed) > 0: + analyzer = analyzer.remove_units(curation_model.removed) + if len(curation_model.merges) > 0: + analyzer, new_unit_ids = analyzer.merge_units( - curation_dict["merge_unit_groups"], + merge_unit_groups=[m.merge_unit_group for m in curation_model.merges], censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, @@ -359,7 +249,7 @@
def apply_curation( ) else: new_unit_ids = [] - apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) + apply_curation_labels(analyzer.sorting, new_unit_ids, curation_model) return analyzer else: raise TypeError( diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py new file mode 100644 index 0000000000..afcbef17d5 --- /dev/null +++ b/src/spikeinterface/curation/curation_model.py @@ -0,0 +1,360 @@ +from pydantic import BaseModel, Field, model_validator, field_validator +from typing import List, Dict, Union, Optional, Literal, Tuple +from itertools import chain, combinations +import numpy as np + + +class LabelDefinition(BaseModel): + name: str = Field(..., description="Name of the label") + label_options: List[str] = Field(..., description="List of possible label options", min_length=2) + exclusive: bool = Field(..., description="Whether the label is exclusive") + + +class ManualLabel(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + labels: Dict[str, List[str]] = Field(..., description="Dictionary of labels for the unit") + + +class Merge(BaseModel): + merge_unit_group: List[Union[int, str]] = Field(..., description="List of unit IDs to be merged") + merge_new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit ID for the merge group") + + +class Split(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + split_mode: Literal["indices", "labels"] = Field( + default="indices", + description=( + "Mode of the split. The split can be defined by indices or labels. " + "If indices, the split is defined by a list of lists of indices of the spikes " + "belonging to the unit (`split_indices`). " + "If labels, the split is defined by a list of labels for each spike (`split_labels`). 
" + ), + ) + split_indices: Optional[Union[List[List[int]]]] = Field(default=None, description="List of indices for the split") + split_labels: Optional[List[int]] = Field(default=None, description="List of labels for the split") + split_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new unit IDs for each split" + ) + + +class CurationModel(BaseModel): + supported_versions: Tuple[Literal["1"], Literal["2"]] = Field( + default=["1", "2"], description="Supported versions of the curation format" + ) + format_version: str = Field(..., description="Version of the curation format") + unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") + label_definitions: Optional[Dict[str, LabelDefinition]] = Field( + default=None, description="Dictionary of label definitions" + ) + manual_labels: Optional[List[ManualLabel]] = Field(default=None, description="List of manual labels") + removed: Optional[List[Union[int, str]]] = Field(default=None, description="List of removed unit IDs") + merges: Optional[List[Merge]] = Field(default=None, description="List of merges") + splits: Optional[List[Split]] = Field(default=None, description="List of splits") + + @field_validator("label_definitions", mode="before") + def add_label_definition_name(cls, label_definitions): + if label_definitions is None: + return {} + if isinstance(label_definitions, dict): + label_definitions = dict(label_definitions) + for key in list(label_definitions.keys()): + if isinstance(label_definitions[key], dict): + label_definitions[key] = dict(label_definitions[key]) + label_definitions[key]["name"] = key + return label_definitions + return label_definitions + + @classmethod + def check_manual_labels(cls, values): + + unit_ids = list(values["unit_ids"]) + manual_labels = values.get("manual_labels") + if manual_labels is None: + values["manual_labels"] = [] + else: + manual_labels = list(manual_labels) + for i, manual_label in enumerate(manual_labels): + manual_label = dict(manual_label) + unit_id = manual_label["unit_id"] + labels = manual_label.get("labels") + if labels is None: + labels = set(manual_label.keys()) - {"unit_id"} + manual_label["labels"] = {} + else: + manual_label["labels"] = {k: list(v) for k, v in labels.items()} + for label in labels: + if label not in values["label_definitions"]: + raise ValueError(f"Manual label {unit_id} has an unknown label {label}") + if label not in manual_label["labels"]: + if label in manual_label: + manual_label["labels"][label] = list(manual_label[label]) + else: + raise ValueError(f"Manual label {unit_id} has no value for label {label}") + if unit_id not in unit_ids: + raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") + manual_labels[i] = manual_label + values["manual_labels"] = manual_labels + return values + + @classmethod + def check_merges(cls, values): + + unit_ids = list(values["unit_ids"]) + merges = values.get("merges") + if merges is None: + values["merges"] = [] + return values + + if isinstance(merges, dict): + # Convert dict format to list of Merge objects + merge_list = [] + for merge_new_id, merge_group in merges.items(): + merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_id": merge_new_id}) + merges = merge_list + + # Make a copy of the list + merges = list(merges) + + # Convert items to Merge objects + for i, merge in enumerate(merges): + if isinstance(merge, list): + merge = {"merge_unit_group": list(merge)} + if isinstance(merge, dict): + merge = dict(merge) 
+ if "merge_unit_group" in merge: + merge["merge_unit_group"] = list(merge["merge_unit_group"]) + merges[i] = Merge(**merge) + + # Validate merges + for merge in merges: + # Check unit ids exist + for unit_id in merge.merge_unit_group: + if unit_id not in unit_ids: + raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") + + # Check minimum group size + if len(merge.merge_unit_group) < 2: + raise ValueError("Merge unit groups must have at least 2 elements") + + # Check new unit id not already used + if merge.merge_new_unit_id is not None: + if merge.merge_new_unit_id in unit_ids: + raise ValueError(f"New unit ID {merge.merge_new_unit_id} is already in the unit list") + + values["merges"] = merges + return values + + @classmethod + def check_splits(cls, values): + + unit_ids = list(values["unit_ids"]) + splits = values.get("splits") + if splits is None: + values["splits"] = [] + return values + + # Convert dict format to list format + if isinstance(splits, dict): + split_list = [] + for unit_id, split_data in splits.items(): + if isinstance(split_data[0], (list, np.ndarray)) if split_data else False: + split_list.append( + { + "unit_id": unit_id, + "split_mode": "indices", + "split_indices": [list(indices) for indices in split_data], + } + ) + else: + split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": list(split_data)}) + splits = split_list + + # Make a copy of the list + splits = list(splits) + + # Convert items to Split objects + for i, split in enumerate(splits): + if isinstance(split, dict): + split = dict(split) + if "split_indices" in split: + split["split_indices"] = [list(indices) for indices in split["split_indices"]] + if "split_labels" in split: + split["split_labels"] = list(split["split_labels"]) + if "split_new_unit_ids" in split: + split["split_new_unit_ids"] = list(split["split_new_unit_ids"]) + splits[i] = Split(**split) + + # Validate splits + for split in splits: + # Check unit exists + if split.unit_id not in unit_ids: + raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list") + + # Validate based on mode + if split.split_mode == "indices": + if split.split_indices is None: + raise ValueError(f"Split unit {split.unit_id} has no split_indices defined") + if len(split.split_indices) < 1: + raise ValueError(f"Split unit {split.unit_id} has empty split_indices") + # Check no duplicate indices + all_indices = list(chain.from_iterable(split.split_indices)) + if len(all_indices) != len(set(all_indices)): + raise ValueError(f"Split unit {split.unit_id} has duplicate indices") + + elif split.split_mode == "labels": + if split.split_labels is None: + raise ValueError(f"Split unit {split.unit_id} has no split_labels defined") + if len(split.split_labels) == 0: + raise ValueError(f"Split unit {split.unit_id} has empty split_labels") + + # Validate new unit IDs + if split.split_new_unit_ids is not None: + if split.split_mode == "indices": + if len(split.split_new_unit_ids) != len(split.split_indices): + raise ValueError( + f"Number of new unit IDs does not match number of splits for unit {split.unit_id}" + ) + elif split.split_mode == "labels": + if len(split.split_new_unit_ids) != len(set(split.split_labels)): + raise ValueError( + f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}" + ) + + for new_id in split.split_new_unit_ids: + if new_id in unit_ids: + raise ValueError(f"New unit ID {new_id} is already in the unit list") + + values["splits"] = splits + return values + + 
@classmethod + def check_removed(cls, values): + unit_ids = list(values["unit_ids"]) + removed = values.get("removed") + if removed is None: + values["removed"] = [] + else: + removed = list(removed) + for unit_id in removed: + if unit_id not in unit_ids: + raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + values["removed"] = removed + return values + + @classmethod + def convert_old_format(cls, values): + format_version = values.get("format_version", "0") + if format_version == "0": + print("Conversion from format version v0 (sortingview) to v2") + if "mergeGroups" not in values.keys(): + values["mergeGroups"] = [] + merge_groups = values["mergeGroups"] + + first_unit_id = next(iter(values["labelsByUnit"].keys())) + if str.isdigit(first_unit_id): + unit_id_type = int + else: + unit_id_type = str + + all_units = [] + all_labels = [] + manual_labels = [] + general_cat = "all_labels" + for unit_id_, l_labels in values["labelsByUnit"].items(): + all_labels.extend(l_labels) + unit_id = unit_id_type(unit_id_) + if unit_id not in all_units: + all_units.append(unit_id) + manual_labels.append({"unit_id": unit_id, general_cat: list(l_labels)}) + labels_def = { + "all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False} + } + for merge_group in merge_groups: + all_units.extend(merge_group) + all_units = list(set(all_units)) + + values = { + "format_version": "2", + "unit_ids": values.get("unit_ids", all_units), + "label_definitions": labels_def, + "manual_labels": list(manual_labels), + "merges": [{"merge_unit_group": list(group)} for group in merge_groups], + "splits": [], + "removed": [], + } + elif values["format_version"] == "1": + merge_unit_groups = values.get("merge_unit_groups") + if merge_unit_groups is not None: + values["merges"] = [{"merge_unit_group": list(group)} for group in merge_unit_groups] + removed_units = values.get("removed_units") + if removed_units is not None: + values["removed"] = list(removed_units) + return values + + @model_validator(mode="before") + def validate_fields(cls, values): + values = dict(values) + values["label_definitions"] = values.get("label_definitions", {}) + values = cls.convert_old_format(values) + values = cls.check_manual_labels(values) + values = cls.check_merges(values) + values = cls.check_splits(values) + values = cls.check_removed(values) + return values + + @model_validator(mode="after") + def validate_curation_dict(cls, values): + if values.format_version not in values.supported_versions: + raise ValueError( + f"Format version {values.format_version} not supported. 
Only {values.supported_versions} are valid" + ) + + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set() + merged_units_set = ( + set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set() + ) + split_units_set = set(split.unit_id for split in values.splits) if values.splits else set() + removed_set = set(values.removed) if values.removed else set() + unit_ids = values.unit_ids + + unit_set = set(unit_ids) + if not labeled_unit_set.issubset(unit_set): + raise ValueError("Curation format: some labeled units are not in the unit list") + if not merged_units_set.issubset(unit_set): + raise ValueError("Curation format: some merged units are not in the unit list") + if not split_units_set.issubset(unit_set): + raise ValueError("Curation format: some split units are not in the unit list") + if not removed_set.issubset(unit_set): + raise ValueError("Curation format: some removed units are not in the unit list") + + # Check for units being merged multiple times + all_merging_groups = [set(merge.merge_unit_group) for merge in values.merges] if values.merges else [] + for gp_1, gp_2 in combinations(all_merging_groups, 2): + if len(gp_1.intersection(gp_2)) != 0: + raise ValueError("Curation format: some units belong to multiple merge groups") + + # Check no overlaps between operations + if len(removed_set.intersection(merged_units_set)) != 0: + raise ValueError("Curation format: some units were merged and deleted") + if len(removed_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were split and deleted") + if len(merged_units_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were both merged and split") + + for manual_label in values.manual_labels: + for label_key in values.label_definitions.keys(): + if label_key in manual_label.labels: + unit_id = manual_label.unit_id + label_value = manual_label.labels[label_key] + if not isinstance(label_value, list): + raise ValueError(f"Curation format: manual_labels {unit_id} is invalid should be a list") + + is_exclusive = values.label_definitions[label_key].exclusive + + if is_exclusive and not len(label_value) <= 1: + raise ValueError( + f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. 
{label_value} is invalid" + ) + + return values diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f33051309c..fe21b72263 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -7,13 +7,11 @@ import numpy as np from pathlib import Path -from .curationsorting import CurationSorting from .curation_format import ( - convert_from_sortingview_curation_format_v0, apply_curation, curation_label_to_vectors, - clean_curation_dict, ) +from .curation_model import CurationModel, Merge def get_kachery(): @@ -82,15 +80,14 @@ def apply_sortingview_curation( except: raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - # convert to new format - if "format_version" not in curation_dict: - curation_dict = convert_from_sortingview_curation_format_v0(curation_dict) - unit_ids = sorting_or_analyzer.unit_ids + curation_dict["unit_ids"] = unit_ids + curation_model = CurationModel(**curation_dict) - # this is a hack because it was not in the old format - curation_dict["unit_ids"] = list(unit_ids) + if skip_merge: + curation_model.merges = [] + # this is a hack because it was not in the old format if exclude_labels is not None: assert include_labels is None, "Use either `include_labels` or `exclude_labels` to filter units." manual_labels = curation_label_to_vectors(curation_dict) @@ -99,7 +96,7 @@ def apply_sortingview_curation( remove_mask = manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units + curation_model.removed = removed_units if include_labels is not None: manual_labels = curation_label_to_vectors(curation_dict) @@ -108,122 +105,21 @@ def apply_sortingview_curation( remove_mask = ~manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units - - if skip_merge: - curation_dict["merge_unit_groups"] = [] - - # cleaner to ensure validity - curation_dict = clean_curation_dict(curation_dict) - - # apply - sorting_curated = apply_curation(sorting_or_analyzer, curation_dict, new_id_strategy="join") + curation_model.removed = removed_units + + # make merges and removed units + if len(curation_model.removed) > 0: + clean_merges = [] + for merge in curation_model.merges: + clean_merge = [] + for unit_id in merge.merge_unit_group: + if unit_id not in curation_model.removed: + clean_merge.append(unit_id) + if len(clean_merge) > 1: + clean_merges.append(Merge(merge_unit_group=clean_merge)) + curation_model.merges = clean_merges + + # apply curation + sorting_curated = apply_curation(sorting_or_analyzer, curation_model, new_id_strategy="join") return sorting_curated - - -# TODO @alessio you remove this after testing -def apply_sortingview_curation_legacy( - sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False -): - """ - Apply curation from SortingView manual curation. - First, merges (if present) are applied. Then labels are loaded and units - are optionally filtered based on exclude_labels and include_labels. - - Parameters - ---------- - sorting : BaseSorting - The sorting object to be curated - uri_or_json : str or Path - The URI curation link from SortingView or the path to the curation json file - exclude_labels : list, default: None - Optional list of labels to exclude (e.g. ["reject", "noise"]). 
- Mutually exclusive with include_labels - include_labels : list, default: None - Optional list of labels to include (e.g. ["accept"]). - Mutually exclusive with exclude_labels, by default None - skip_merge : bool, default: False - If True, merges are not applied (only labels) - verbose : bool, default: False - If True, output is verbose - - Returns - ------- - sorting_curated : BaseSorting - The curated sorting - """ - ka = get_kachery() - curation_sorting = CurationSorting(sorting, make_graph=False, properties_policy="keep") - - # get sorting view curation - if Path(uri_or_json).suffix == ".json" and not str(uri_or_json).startswith("gh://"): - with open(uri_or_json, "r") as f: - sortingview_curation_dict = json.load(f) - else: - try: - sortingview_curation_dict = ka.load_json(uri=uri_or_json) - except: - raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - - unit_ids_dtype = sorting.unit_ids.dtype - - # STEP 1: merge groups - labels_dict = sortingview_curation_dict["labelsByUnit"] - if "mergeGroups" in sortingview_curation_dict and not skip_merge: - merge_groups = sortingview_curation_dict["mergeGroups"] - for merge_group in merge_groups: - # Store labels of units that are about to be merged - labels_to_inherit = [] - for unit in merge_group: - labels_to_inherit.extend(labels_dict.get(str(unit), [])) - labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates - - if verbose: - print(f"Merging {merge_group}") - if unit_ids_dtype.kind in ("U", "S"): - merge_group = [str(unit) for unit in merge_group] - # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(merge_group) - curation_sorting.merge(merge_group, new_unit_id=new_unit_id) - else: - # in this case, the CurationSorting takes care of finding a new unused int - curation_sorting.merge(merge_group, new_unit_id=None) - new_unit_id = curation_sorting.max_used_id # merged unit id - labels_dict[str(new_unit_id)] = labels_to_inherit - - # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. - # For example, the first 3 units could be labeled as "accept". - # In this case, the first 3 values of the property "accept" will be True, the rest False - - # Initialize the properties dictionary - properties = { - label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for labels in labels_dict.values() - for label in labels - } - - # Populate the properties dictionary - for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - unit_id_str = str(unit_id) - if unit_id_str in labels_dict: - for label in labels_dict[unit_id_str]: - properties[label][unit_index] = True - - for prop_name, prop_values in properties.items(): - curation_sorting.current_sorting.set_property(prop_name, prop_values) - - if include_labels is not None or exclude_labels is not None: - units_to_remove = [] - unit_ids = curation_sorting.current_sorting.unit_ids - assert include_labels or exclude_labels, "Use either `include_labels` or `exclude_labels` to filter units." 
- if include_labels: - for include_label in include_labels: - units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(include_label) == False]) - if exclude_labels: - for exclude_label in exclude_labels: - units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) - units_to_remove = np.unique(units_to_remove) - curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index af9d8e1eac..c3ed4a115f 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -8,15 +8,34 @@ from spikeinterface.curation.curation_format import ( validate_curation_dict, - convert_from_sortingview_curation_format_v0, curation_label_to_vectors, curation_label_to_dataframe, apply_curation, ) +""" +v1 = { + 'format_version': '1', + 'unit_ids': List[int | str], + 'label_definitions': { + 'category_key1': + { + 'label_options': List[str], + 'exclusive': bool} + }, + 'manual_labels': [ + { + 'unit_id': str or int, + 'category_key1': List[str], + } + ], + 'removed_units': List[int | str] # Can not be in the merged_units + 'merge_unit_groups': List[List[int | str]], # one cell goes into at most one list +} -"""example = { - 'unit_ids': List[str, int], +v2 = { + 'format_version': '2', + 'unit_ids': List[int | int], 'label_definitions': { 'category_key1': { @@ -24,18 +43,38 @@ 'exclusive': bool} }, 'manual_labels': [ - {'unit_id': str or int, - category_key1': List[str], + { + 'unit_id': str | int, + 'category_key1': List[str], } ], - 'merge_unit_groups': List[List[unit_ids]], # one cell goes into at most one list - 'removed_units': List[unit_ids] # Can not be in the merged_units -} -""" + 'removed': List[unit_ids], # Can not be in the merged_units + 'merges': [ + { + 'merge_unit_group': List[unit_ids], + 'merge_new_unit_id': int | str (optional) + } + ], + 'splits': [ + { + 'unit_id': int | str + 'split_mode': 'indices' or 'labels', + 'split_indices': List[List[int]], + 'split_labels': List[int]], + 'split_new_unit_ids': List[int | str] + } + ] + +sortingview_curation = { + 'mergeGroups': List[List[int | str]], + 'labelsByUnit': { + 'unit_id': List[str] + } +""" curation_ids_int = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], "label_definitions": { "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, @@ -48,19 +87,21 @@ {"unit_id": 1, "quality": ["good"]}, { "unit_id": 2, - "quality": [ - "noise", - ], + "quality": ["noise"], "putative_type": ["excitatory", "pyramidal"], }, {"unit_id": 3, "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list - "removed_units": [31, 42], # Can not be in the merged_units + "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}], + "splits": [], + "removed": [31, 42], } +# Test dictionary format for merges +curation_ids_int_dict = {**curation_ids_int, "merges": {50: [3, 6], 51: [10, 14, 20]}} + curation_ids_str = { - "format_version": "1", + "format_version": "2", "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], "label_definitions": { "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, @@ -73,40 +114,65 @@ {"unit_id": "u1", "quality": ["good"]}, { "unit_id": 
"u2", - "quality": [ - "noise", - ], + "quality": ["noise"], "putative_type": ["excitatory", "pyramidal"], }, {"unit_id": "u3", "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list - "removed_units": ["u31", "u42"], # Can not be in the merged_units + "merges": [{"merge_unit_group": ["u3", "u6"]}, {"merge_unit_group": ["u10", "u14", "u20"]}], + "splits": [], + "removed": ["u31", "u42"], } -# This is a failure example with duplicated merge -duplicate_merge = curation_ids_int.copy() -duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] +# Test dictionary format for merges with string IDs +curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}} + +# Test with splits +curation_with_splits = { + **curation_ids_int, + "splits": [ + {"unit_id": 2, "split_mode": "indices", "split_indices": [[0, 1, 2], [3, 4, 5]], "split_new_unit_ids": [50, 51]} + ], +} + +# Test dictionary format for splits +curation_with_splits_dict = {**curation_ids_int, "splits": {2: [[0, 1, 2], [3, 4, 5]]}} +# This is a failure example with duplicated merge +duplicate_merge = {**curation_ids_int, "merges": [{"merge_unit_group": [3, 6, 10]}, {"merge_unit_group": [10, 14, 20]}]} # This is a failure example with unit 3 both in removed and merged -merged_and_removed = curation_ids_int.copy() -merged_and_removed["merge_unit_groups"] = [[3, 6], [10, 14, 20]] -merged_and_removed["removed_units"] = [3, 31, 42] +merged_and_removed = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}], + "removed": [3, 31, 42], +} -# this is a failure because unit 99 is not in the initial list -unknown_merged_unit = curation_ids_int.copy() -unknown_merged_unit["merge_unit_groups"] = [[3, 6, 99], [10, 14, 20]] +# This is a failure because unit 99 is not in the initial list +unknown_merged_unit = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6, 99]}, {"merge_unit_group": [10, 14, 20]}], +} -# this is a failure because unit 99 is not in the initial list -unknown_removed_unit = curation_ids_int.copy() -unknown_removed_unit["removed_units"] = [31, 42, 99] +# This is a failure because unit 99 is not in the initial list +unknown_removed_unit = {**curation_ids_int, "removed": [31, 42, 99]} def test_curation_format_validation(): + # Test basic formats + print(curation_ids_int) validate_curation_dict(curation_ids_int) + print(curation_ids_int) validate_curation_dict(curation_ids_str) + # Test dictionary formats + validate_curation_dict(curation_ids_int_dict) + validate_curation_dict(curation_ids_str_dict) + + # Test splits + validate_curation_dict(curation_with_splits) + validate_curation_dict(curation_with_splits_dict) + with pytest.raises(ValueError): # Raised because duplicated merged units validate_curation_dict(duplicate_merge) @@ -122,13 +188,13 @@ def test_curation_format_validation(): def test_to_from_json(): - json.loads(json.dumps(curation_ids_int, indent=4)) json.loads(json.dumps(curation_ids_str, indent=4)) + json.loads(json.dumps(curation_ids_int_dict, indent=4)) + json.loads(json.dumps(curation_with_splits, indent=4)) def test_convert_from_sortingview_curation_format_v0(): - parent_folder = Path(__file__).parent for filename in ( "sv-sorting-curation.json", @@ -136,18 +202,13 @@ def test_convert_from_sortingview_curation_format_v0(): "sv-sorting-curation-str.json", "sv-sorting-curation-false-positive.json", ): - json_file = parent_folder 
/ filename with open(json_file, "r") as f: curation_v0 = json.load(f) - # print(curation_v0) - curation_v1 = convert_from_sortingview_curation_format_v0(curation_v0) - # print(curation_v1) - validate_curation_dict(curation_v1) + validate_curation_dict(curation_v0) def test_curation_label_to_vectors(): - labels = curation_label_to_vectors(curation_ids_int) assert "quality" in labels assert "excitatory" in labels @@ -158,35 +219,45 @@ def test_curation_label_to_vectors(): def test_curation_label_to_dataframe(): - df = curation_label_to_dataframe(curation_ids_int) assert "quality" in df.columns assert "excitatory" in df.columns print(df) df = curation_label_to_dataframe(curation_ids_str) - # print(df) + print(df) def test_apply_curation(): recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) - sorting._main_ids = np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]) + sorting = sorting.rename_units([1, 2, 3, 6, 10, 14, 20, 31, 42]) analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + # Test with list format sorting_curated = apply_curation(sorting, curation_ids_int) assert sorting_curated.get_property("quality", ids=[1])[0] == "good" assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" assert sorting_curated.get_property("excitatory", ids=[2])[0] + # Test with dictionary format + sorting_curated = apply_curation(sorting, curation_ids_int_dict) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" + assert sorting_curated.get_property("excitatory", ids=[2])[0] + + # Test with splits + sorting_curated = apply_curation(sorting, curation_with_splits) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + + # Test analyzer analyzer_curated = apply_curation(analyzer, curation_ids_int) assert "quality" in analyzer_curated.sorting.get_property_keys() if __name__ == "__main__": - # test_curation_format_validation() - # test_to_from_json() - # test_convert_from_sortingview_curation_format_v0() - # test_curation_label_to_vectors() - # test_curation_label_to_dataframe() - + test_curation_format_validation() + test_to_from_json() + test_convert_from_sortingview_curation_format_v0() + test_curation_label_to_vectors() + test_curation_label_to_dataframe() test_apply_curation() diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py new file mode 100644 index 0000000000..7354ac1892 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -0,0 +1,288 @@ +import pytest + +from pydantic import ValidationError +import numpy as np + +from spikeinterface.curation.curation_model import CurationModel, LabelDefinition + + +# Test data for format version +def test_format_version(): + # Valid format version + CurationModel(format_version="1", unit_ids=[1, 2, 3]) + + # Invalid format version + with pytest.raises(ValidationError): + CurationModel(format_version="3", unit_ids=[1, 2, 3]) + with pytest.raises(ValidationError): + CurationModel(format_version="0.1", unit_ids=[1, 2, 3]) + + +# Test data for label definitions +def test_label_definitions(): + valid_label_def = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + } + + model = 
CurationModel(**valid_label_def) + assert "quality" in model.label_definitions + assert model.label_definitions["quality"].name == "quality" + assert model.label_definitions["quality"].exclusive is True + + # Test invalid label definition + with pytest.raises(ValidationError): + LabelDefinition(name="quality", label_options=[], exclusive=True) # Empty options should be invalid + + +# Test manual labels +def test_manual_labels(): + valid_labels = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst", "fast"]}}, + {"unit_id": 2, "labels": {"quality": ["noise"]}}, + ], + } + + model = CurationModel(**valid_labels) + assert len(model.manual_labels) == 2 + + # Test invalid unit ID + invalid_unit = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [{"unit_id": 4, "labels": {"quality": ["good"]}}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit) + + # Test violation of exclusive label + invalid_exclusive = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good", "noise"]}} # Multiple values for exclusive label + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_exclusive) + + +# Test merge functionality +def test_merge_units(): + # Test list format + valid_merge = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_id": 5}, + {"merge_unit_group": [3, 4], "merge_new_unit_id": 6}, + ], + } + + model = CurationModel(**valid_merge) + assert len(model.merges) == 2 + assert model.merges[0].merge_new_unit_id == 5 + assert model.merges[1].merge_new_unit_id == 6 + + # Test dictionary format + valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} + + model = CurationModel(**valid_merge_dict) + assert len(model.merges) == 2 + merge_new_ids = {merge.merge_new_unit_id for merge in model.merges} + assert merge_new_ids == {5, 6} + + # Test list format + valid_merge_list = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4], + "merges": [[1, 2], [3, 4]], # Merge each pair into a new unit + } + model = CurationModel(**valid_merge_list) + assert len(model.merges) == 2 + + # Test invalid merge group (single unit) + invalid_merge_group = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1], "merge_new_unit_id": 4}], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_group) + + # Test overlapping merge groups + invalid_overlap = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_id": 4}, + {"merge_unit_group": [2, 3], "merge_new_unit_id": 5}, + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_overlap) + + +# Test split functionality +def test_split_units(): + # Test indices mode with list format + valid_split_indices = { + "format_version": "2", + "unit_ids": [1, 
2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + "split_indices": [[0, 1, 2], [3, 4, 5]], + "split_new_unit_ids": [4, 5], + } + ], + } + + model = CurationModel(**valid_split_indices) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "indices" + assert len(model.splits[0].split_indices) == 2 + + # Test labels mode with list format + valid_split_labels = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": [ + {"unit_id": 1, "split_mode": "labels", "split_labels": [0, 0, 1, 1, 0, 2], "split_new_unit_ids": [4, 5, 6]} + ], + } + + model = CurationModel(**valid_split_labels) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "labels" + assert len(set(model.splits[0].split_labels)) == 3 + + # Test dictionary format with indices + valid_split_dict = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": { + 1: [[0, 1, 2], [3, 4, 5]], # Split unit 1 into two parts + 2: [[0, 1], [2, 3], [4, 5]], # Split unit 2 into three parts + }, + } + + model = CurationModel(**valid_split_dict) + assert len(model.splits) == 2 + assert all(split.split_mode == "indices" for split in model.splits) + + # Test invalid unit ID + invalid_unit_id = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": [{"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]]}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit_id) + + # Test invalid new unit IDs count for indices mode + invalid_new_ids = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + "split_indices": [[0, 1], [2, 3]], + "split_new_unit_ids": [4], # Should have 2 new IDs for 2 splits + } + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_new_ids) + + +# Test removed units +def test_removed_units(): + valid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [2]} + + model = CurationModel(**valid_remove) + assert len(model.removed) == 1 + + # Test removing non-existent unit + invalid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit + with pytest.raises(ValidationError): + CurationModel(**invalid_remove) + + # Test conflict between merge and remove + invalid_merge_remove = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_id": 4}], + "removed": [1], # Unit is both merged and removed + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_remove) + + +# Test complete model with multiple operations +def test_complete_model(): + complete_model = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_id": 6}], + "splits": [ + {"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]], "split_new_unit_ids": [7, 8]} + ], + "removed": [5], + } + + model = CurationModel(**complete_model) + assert model.format_version == "2" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1 
+ + # Test dictionary format for complete model + complete_model_dict = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": {6: [2, 3]}, + "splits": {4: [[0, 1], [2, 3]]}, + "removed": [5], + } + + model = CurationModel(**complete_model_dict) + assert model.format_version == "2" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1
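
For reference, here is a minimal usage sketch of the new v2 curation format introduced by this patch. It is not part of the diff itself; the import paths, helper functions, and example values are taken from the tests above, and any deviation from the final API should be treated as an assumption.

```python
# Hedged example: apply a v2 curation dict to a Sorting, as exercised in the tests above.
from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.curation.curation_format import apply_curation

# Toy sorting with the same unit ids used in the tests
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205)
sorting = sorting.rename_units([1, 2, 3, 6, 10, 14, 20, 31, 42])

curation = {
    "format_version": "2",
    "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42],
    "label_definitions": {
        "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True},
    },
    "manual_labels": [{"unit_id": 1, "quality": ["good"]}],
    "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}],
    "splits": [],
    "removed": [31, 42],
}

# Validation now happens inside the pydantic model; an invalid dict raises a ValidationError.
model = CurationModel(**curation)

# apply_curation accepts either the raw dict or the validated model.
sorting_curated = apply_curation(sorting, model)
print(sorting_curated.get_property("quality", ids=[1])[0])  # -> "good"
```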