|
| 1 | +""" |
| 2 | +Tools for grouping features and labels so that we can compute metrics on the individual groups |
| 3 | +""" |
| 4 | +from collections import defaultdict |
| 5 | +from typing import Dict, List, Tuple, Union |
| 6 | +try: |
| 7 | + from typing import Literal |
| 8 | +except ImportError: |
| 9 | + from typing_extensions import Literal |
| 10 | + |
| 11 | +from labelbox.data.annotation_types import Label |
| 12 | +from labelbox.data.annotation_types.collection import LabelList |
| 13 | +from labelbox.data.annotation_types.feature import FeatureSchema |
| 14 | + |
| 15 | + |
def get_identifying_key(
    features_a: List[FeatureSchema], features_b: List[FeatureSchema]
) -> Union[Literal['name'], Literal['feature_schema_id']]:
    """
    Determines which identifying field both feature lists can be matched on.

    A field qualifies only when every feature in both lists has it set.
    The candidates are the feature name and the feature schema id.

    Args:
        features_a : List of FeatureSchemas (usually ObjectAnnotations or ClassificationAnnotations)
        features_b : List of FeatureSchemas (usually ObjectAnnotations or ClassificationAnnotations)
    Returns:
        'name' when every feature in both lists has a name, otherwise
        'feature_schema_id' when every feature in both lists has a schema id.
    Raises:
        ValueError: If features_a has neither field set on all of its features,
            or if the two lists do not share a fully populated field.
    """
    ids_complete_a, names_complete_a = all_have_key(features_a)
    if not (ids_complete_a or names_complete_a):
        raise ValueError("All data must have feature_schema_ids or names set")

    ids_complete_b, names_complete_b = all_have_key(features_b)

    # Names take priority because they are meaningful to the user;
    # schema ids are only the fallback.
    if names_complete_a and names_complete_b:
        return 'name'
    if ids_complete_a and ids_complete_b:
        return 'feature_schema_id'
    raise ValueError(
        "Ground truth and prediction annotations must have set all name or feature ids. "
        "Otherwise there is no key to match on. Please update.")
| 47 | + |
| 48 | + |
def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]:
    """
    Reports whether every feature has a name and whether every feature has a feature_schema_id.

    Args:
        features (List[FeatureSchema]) : The features to inspect.

    Returns:
        A tuple `(all_schema_ids_set, all_names_set)`. Note the order: schema ids first,
        names second. Both entries are True for an empty list.
    """
    # One pass over the features, recording (has name, has schema id) per feature.
    presence = [(feature.name is not None,
                 feature.feature_schema_id is not None) for feature in features]
    every_name_set = all(has_name for has_name, _ in presence)
    every_schema_id_set = all(has_schema_id for _, has_schema_id in presence)
    return every_schema_id_set, every_name_set
| 65 | + |
| 66 | + |
def get_label_pairs(labels_a: LabelList,
                    labels_b: LabelList,
                    match_on="uid",
                    filter=False) -> Dict[str, Tuple[Label, Label]]:
    """
    Pairs up two collections of labels (e.g. predictions and ground truth) by a data row key.

    There are a few potential problems with this function.
    We are assuming that the data row `uid` or `external_id` has been provided by the user.
    However, these particular fields are not required and can be empty.
    If this assumption fails, then the user has to determine their own matching strategy.
    If several labels in one collection share the same key, only the last one is kept.

    Args:
        labels_a (LabelList): A collection of labels to match with labels_b
        labels_b (LabelList): A collection of labels to match with labels_a
        match_on ('uid' or 'external_id'): The data row key to match labels by. Can either be uid or external id.
        filter (bool): Whether or not to ignore mismatches. When False, a key that is
            present in only one of the collections raises a ValueError.

    Returns:
        A dict keyed by the uid or external id of every data row found in both collections.
        Each value is the tuple `(label_from_labels_a, label_from_labels_b)`.

    Raises:
        ValueError: If `match_on` is not 'uid' or 'external_id', if any label's data row
            lacks the `match_on` key, or if a key is missing from one collection
            while `filter` is False.
    """
    # NOTE: `filter` shadows the builtin, but it is part of the public signature
    # and is kept for backward compatibility.
    if match_on not in ('uid', 'external_id'):
        raise ValueError("Can only match on `uid` or `external_id`.")

    label_lookup_a = {
        getattr(label.data, match_on, None): label for label in labels_a
    }
    label_lookup_b = {
        getattr(label.data, match_on, None): label for label in labels_b
    }
    if None in label_lookup_a or None in label_lookup_b:
        raise ValueError(
            f"One or more of the labels has a data row without the required key {match_on}."
            " It cannot be determined which labels match without this information."
            f" Either assign {match_on} to each Label or create your own pairing function."
        )

    pairs = {}
    for key in label_lookup_a.keys() | label_lookup_b.keys():
        a, b = label_lookup_a.get(key), label_lookup_b.get(key)
        if a is None or b is None:
            if not filter:
                raise ValueError(
                    f"{match_on} {key} is not available in both LabelLists. "
                    "Set `filter = True` to filter out these examples, assign the ids manually, or create your own matching function."
                )
            continue
        # Store a flat (a, b) tuple so callers can unpack `a, b = pairs[key]`,
        # as promised by the return annotation.
        pairs[key] = (a, b)
    return pairs
| 118 | + |
| 119 | + |
def get_feature_pairs(
    features_a: List[FeatureSchema], features_b: List[FeatureSchema]
) -> Dict[str, Tuple[List[FeatureSchema], List[FeatureSchema]]]:
    """
    Groups two feature lists by their shared identifying key and lines the groups up.

    The identifying key is whatever `get_identifying_key` selects for the two lists
    (the feature name, or the feature schema id).

    Args:
        features_a (List[FeatureSchema]): A list of features to match with features_b
        features_b (List[FeatureSchema]): A list of features to match with features_a
    Returns:
        The matched features as a dict. Each key is a value of the identifying field and
        each value holds two lists: the features from features_a with that key, followed
        by the features from features_b with that key. A side that has no feature for a
        key contributes an empty list.
    """
    match_field = get_identifying_key(features_a, features_b)
    grouped_a = _create_feature_lookup(features_a, match_field)
    grouped_b = _create_feature_lookup(features_b, match_field)

    paired = defaultdict(list)
    for group_id in set(grouped_a.keys()).union(set(grouped_b.keys())):
        # Both lookups are defaultdicts, so a key missing on one side
        # resolves to an empty list instead of raising.
        paired[group_id] = [grouped_a[group_id], grouped_b[group_id]]
    return paired
| 144 | + |
| 145 | + |
def _create_feature_lookup(features: List[FeatureSchema],
                           key: str) -> Dict[str, List[FeatureSchema]]:
    """
    Buckets features by the value of one of their attributes.

    Args:
        features: List of features to group
        key: Name of the attribute to group on (e.g. 'name' or 'feature_schema_id')
    Returns:
        a defaultdict mapping each distinct attribute value to the list of features
        carrying that value, in their original order. Looking up an unknown value
        yields an empty list.
    """
    lookup = defaultdict(list)
    for item in features:
        group_id = getattr(item, key)
        lookup[group_id].append(item)
    return lookup
0 commit comments