Skip to content

Commit 9ca7187

Browse files
authored
Merge pull request #272 from Labelbox/ms/fix-classification-bug
fix classification bug
2 parents 8e1ebf3 + 3a085ec commit 9ca7187

File tree

8 files changed

+91
-21
lines changed

8 files changed

+91
-21
lines changed

labelbox/data/annotation_types/annotation.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Any, Dict, List, Union
22

3-
from pydantic.main import BaseModel
4-
53
from .classification import Checklist, Dropdown, Radio, Text
64
from .feature import FeatureSchema
75
from .geometry import Geometry

labelbox/data/annotation_types/classification/classification.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
from typing import Any, Dict, List
22

3-
from pydantic.main import BaseModel
3+
try:
4+
from typing import Literal
5+
except:
6+
from typing_extensions import Literal
47

8+
from pydantic import BaseModel, validator
59
from ..feature import FeatureSchema
610

711

12+
# TODO: Replace when pydantic adds support for unions that don't coerce types
13+
class _TempName(BaseModel):
14+
name: str
15+
16+
def dict(self, *args, **kwargs):
17+
res = super().dict(*args, **kwargs)
18+
res.pop('name')
19+
return res
20+
21+
822
class ClassificationAnswer(FeatureSchema):
923
"""
1024
- Represents a classification option.
@@ -19,8 +33,9 @@ class Radio(BaseModel):
1933
answer: ClassificationAnswer
2034

2135

22-
class Checklist(BaseModel):
36+
class Checklist(_TempName):
2337
""" A classification with many selected options allowed """
38+
name: Literal["checklist"] = "checklist"
2439
answer: List[ClassificationAnswer]
2540

2641

@@ -29,9 +44,10 @@ class Text(BaseModel):
2944
answer: str
3045

3146

32-
class Dropdown(BaseModel):
47+
class Dropdown(_TempName):
3348
"""
3449
- A classification with many selected options allowed .
3550
- This is not currently compatible with MAL.
3651
"""
52+
name: Literal["dropdown"] = "dropdown"
3753
answer: List[ClassificationAnswer]

labelbox/data/serialization/labelbox_v1/classification.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Union
22

33
from pydantic.main import BaseModel
4+
from pydantic.schema import schema
45

56
from ...annotation_types.annotation import ClassificationAnnotation
67
from ...annotation_types.classification import Checklist, ClassificationAnswer, Radio, Text, Dropdown
@@ -15,7 +16,7 @@ class LBV1ClassificationAnswer(LBV1Feature):
1516
class LBV1Radio(LBV1Feature):
1617
answer: LBV1ClassificationAnswer
1718

18-
def to_common(self):
19+
def to_common(self) -> Radio:
1920
return Radio(answer=ClassificationAnswer(
2021
feature_schema_id=self.answer.schema_id,
2122
name=self.answer.title,
@@ -39,7 +40,7 @@ def from_common(cls, radio: Radio, feature_schema_id: Cuid,
3940
class LBV1Checklist(LBV1Feature):
4041
answers: List[LBV1ClassificationAnswer]
4142

42-
def to_common(self):
43+
def to_common(self) -> Checklist:
4344
return Checklist(answer=[
4445
ClassificationAnswer(feature_schema_id=answer.schema_id,
4546
name=answer.title,
@@ -64,6 +65,34 @@ def from_common(cls, checklist: Checklist, feature_schema_id: Cuid,
6465
**extra)
6566

6667

68+
class LBV1Dropdown(LBV1Feature):
69+
answer: List[LBV1ClassificationAnswer]
70+
71+
def to_common(self) -> Dropdown:
72+
return Dropdown(answer=[
73+
ClassificationAnswer(feature_schema_id=answer.schema_id,
74+
name=answer.title,
75+
extra={
76+
'feature_id': answer.feature_id,
77+
'value': answer.value
78+
}) for answer in self.answer
79+
])
80+
81+
@classmethod
82+
def from_common(cls, dropdown: Dropdown, feature_schema_id: Cuid,
83+
**extra) -> "LBV1Dropdown":
84+
return cls(schema_id=feature_schema_id,
85+
answer=[
86+
LBV1ClassificationAnswer(
87+
schema_id=answer.feature_schema_id,
88+
title=answer.name,
89+
value=answer.extra.get('value'),
90+
feature_id=answer.extra.get('feature_id'))
91+
for answer in dropdown.answer
92+
],
93+
**extra)
94+
95+
6796
class LBV1Text(LBV1Feature):
6897
answer: str
6998

@@ -77,7 +106,8 @@ def from_common(cls, text: Text, feature_schema_id: Cuid,
77106

78107

79108
class LBV1Classifications(BaseModel):
80-
classifications: List[Union[LBV1Radio, LBV1Checklist, LBV1Text]] = []
109+
classifications: List[Union[LBV1Text, LBV1Radio, LBV1Dropdown,
110+
LBV1Checklist]] = []
81111

82112
def to_common(self) -> List[ClassificationAnnotation]:
83113
classifications = [
@@ -112,10 +142,10 @@ def from_common(
112142
@staticmethod
113143
def lookup_classification(
114144
annotation: ClassificationAnnotation
115-
) -> Union[LBV1Text, LBV1Checklist, LBV1Radio]:
145+
) -> Union[LBV1Text, LBV1Checklist, LBV1Radio, LBV1Checklist]:
116146
return {
117147
Text: LBV1Text,
118-
Dropdown: LBV1Checklist,
148+
Dropdown: LBV1Dropdown,
119149
Checklist: LBV1Checklist,
120150
Radio: LBV1Radio
121151
}.get(type(annotation.value))

labelbox/data/serialization/labelbox_v1/converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from labelbox.data.serialization.labelbox_v1.objects import LBV1Mask
12
from typing import Any, Dict, Generator, Iterable
23
import logging
34

labelbox/data/serialization/labelbox_v1/objects.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
from typing import Any, Dict, List, Optional, Union
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, validator
44

55
from ...annotation_types.annotation import (ClassificationAnnotation,
66
ObjectAnnotation)
77
from ...annotation_types.data import MaskData
88
from ...annotation_types.geometry import Line, Mask, Point, Polygon, Rectangle
99
from ...annotation_types.ner import TextEntity
1010
from ...annotation_types.types import Cuid
11-
from .classification import LBV1Checklist, LBV1Classifications, LBV1Radio, LBV1Text
11+
from .classification import LBV1Checklist, LBV1Classifications, LBV1Radio, LBV1Text, LBV1Dropdown
1212
from .feature import LBV1Feature
1313

1414

1515
class LBV1ObjectBase(LBV1Feature):
1616
color: Optional[str] = None
1717
instanceURI: Optional[str] = None
18-
classifications: List[Union[LBV1Radio, LBV1Checklist, LBV1Text]] = []
18+
classifications: List[Union[LBV1Text, LBV1Radio, LBV1Dropdown,
19+
LBV1Checklist]] = []
1920

2021
def dict(self, *args, **kwargs):
2122
res = super().dict(*args, **kwargs)
@@ -24,6 +25,14 @@ def dict(self, *args, **kwargs):
2425
res.pop('instanceURI')
2526
return res
2627

28+
@validator('classifications', pre=True)
29+
def validate_subclasses(cls, value, field):
30+
# Dropdown subclasses create extra unessesary nesting. So we just remove it.
31+
if isinstance(value, list) and len(value):
32+
if isinstance(value[0], list):
33+
return value[0]
34+
return value
35+
2736

2837
class _Point(BaseModel):
2938
x: float

labelbox/data/serialization/ndjson/classification.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def to_common(
165165
def lookup_subclassification(
166166
annotation: ClassificationAnnotation
167167
) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]:
168-
if isinstance(annotation, Dropdown):
169-
raise TypeError("Dropdowns are not supported for MAL")
168+
if isinstance(annotation.value, Dropdown):
169+
raise TypeError("Dropdowns are not supported for MAL.")
170170
return {
171171
Text: NDTextSubclass,
172172
Checklist: NDChecklistSubclass,
@@ -213,13 +213,12 @@ def lookup_classification(
213213
annotation: Union[ClassificationAnnotation,
214214
VideoClassificationAnnotation]
215215
) -> Union[NDText, NDChecklist, NDRadio]:
216-
if isinstance(annotation, Dropdown):
217-
raise TypeError("Dropdowns are not supported for MAL")
216+
if isinstance(annotation.value, Dropdown):
217+
raise TypeError("Dropdowns are not supported for MAL.")
218218
return {
219219
Text: NDText,
220220
Checklist: NDChecklist,
221-
Radio: NDRadio,
222-
Dropdown: NDChecklist,
221+
Radio: NDRadio
223222
}.get(type(annotation.value))
224223

225224

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"ID": "ckt3hu4aq524s0y9e7i0dbbtg", "DataRow ID": "ckt3h8c7e200j0y6ob2g3629b", "Labeled Data": "https://storage.labelbox.com/ckk4q1vgapsau07324awnsjq2%2Fa8c0d364-b10f-5b49-09a3-151264999cfb-1?Expires=1630706709794&KeyName=labelbox-assets-key-3&Signature=W8RPYzKo1Qs6Qxclapnx4_chhJ8", "Label": {"objects": [{"featureId": "ckt3hu9ux00003h69bzk55eaj", "schemaId": "ckt3h8e8s51li0y7ucy0bgrlz", "color": "#ff0000", "title": "deer", "value": "deer", "polygon": [{"x": 71.832, "y": 62.37}, {"x": 50.636, "y": 90.087}, {"x": 103.625, "y": 94.571}, {"x": 104.033, "y": 59.517}], "instanceURI": "https://api.labelbox.com/masks/feature/ckt3hu9ux00003h69bzk55eaj?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiJja2s0cTF2Z3djMHZwMDcwNHhoeDdtNHZrIiwib3JnYW5pemF0aW9uSWQiOiJja2s0cTF2Z2Fwc2F1MDczMjRhd25zanEyIiwiaWF0IjoxNjMwNjI1NjgxLCJleHAiOjE2MzMyMTc2ODF9.hUJ46wdigdN9RdICF8FYzaA7MLbscfjvBeP5O-QhGa8", "classifications": [{"featureId": "ckt3hud1m00023h696pofvwkq", "schemaId": "ckt3hu2yn523g0y9ehzhg5opb", "title": "rrrr", "value": "rrrr", "answer": {"featureId": "ckt3hud1l00013h69xob7gwoh", "schemaId": "ckt3hu302523w0y9e39pub4zp", "title": "rrrrrrrr", "value": "rrrrrrrr"}}]}, {"featureId": "ckt3hullq000c3h69o8iqty8k", "schemaId": "ckt3h8e8s51lm0y7ucl6bgixn", "color": "#0000ff", "title": "deer_eyes", "value": "deer_eyes", "instanceURI": "https://api.labelbox.com/masks/feature/ckt3hullq000c3h69o8iqty8k?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiJja2s0cTF2Z3djMHZwMDcwNHhoeDdtNHZrIiwib3JnYW5pemF0aW9uSWQiOiJja2s0cTF2Z2Fwc2F1MDczMjRhd25zanEyIiwiaWF0IjoxNjMwNjI1NjgxLCJleHAiOjE2MzMyMTc2ODF9.hUJ46wdigdN9RdICF8FYzaA7MLbscfjvBeP5O-QhGa8"}, {"featureId": "ckt3hug4600073h692cpxwc2v", "schemaId": "ckt3hu2ya523e0y9eaika5flx", "color": "#00ff00", "title": "deer_nose", "value": "deer_nose", "point": {"x": 73.054, "y": 165.495}, "instanceURI": "https://api.labelbox.com/masks/feature/ckt3hug4600073h692cpxwc2v?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiJja2s0cTF2Z3djMHZwMDcwNHhoeDdtNHZrIiwib3JnYW5pemF0aW9uSWQiOiJja2s0cTF2Z2Fwc2F1MDczMjRhd25zanEyIiwiaWF0IjoxNjMwNjI1NjgxLCJleHAiOjE2MzMyMTc2ODF9.hUJ46wdigdN9RdICF8FYzaA7MLbscfjvBeP5O-QhGa8", "classifications": [[{"featureId": "ckt3hui8800093h69swml8uzf", "schemaId": "ckt3hu308523x0y9eclq1g2k8", "title": "description", "value": "description", "answers": [{"featureId": "ckt3hui8800083h69696qdgdf", "schemaId": "ckt3hu31m524b0y9eeiuobmlo", "title": "wet", "value": "wet"}, {"featureId": "ckt3kb8ym00013h69lydpjjkp", "schemaId": "ckt3k5asi26wt0y6o805pf9ve", "title": "dry", "value": "dry"}]}, {"featureId": "ckt3hujza000b3h69jfu850bx", "schemaId": "ckt3hu32x524p0y9ehuzf6wtw", "title": "asdasdsadsa", "value": "asdasdsadsa", "answer": {"featureId": "ckt3hujza000a3h691fbs2iik", "schemaId": "ckt3hu33h524r0y9e5q6j91wt", "title": "aasdsadsada", "value": "aasdsadsada"}}, {"featureId": "ckt3kbapf00033h69aup591oj", "schemaId": "ckt3k5au226x50y6o5wlv9j3g", "title": "anotherone...", "value": "anotherone...", "answer": {"featureId": "ckt3kbape00023h69ohwyr625", "schemaId": "ckt3k5aum26xb0y6o2gxtctze", "title": "weeewrew", "value": "weeewrew"}}]]}], "classifications": [{"featureId": "ckt3huphm000f3h69esv5u6ch", "schemaId": "ckt3h8e8r51lg0y7u2mtz9791", "title": "image_description", "value": "image_description", "answers": [{"featureId": "ckt3huphm000e3h69a18xeme7", "schemaId": "ckt3h8e9m51lo0y7u5v81b3cz", "title": "bright", "value": "bright"}, {"featureId": "ckt3hus0i000h3h69kwde3ilk", "schemaId": "ckt3h8e9m51lq0y7u15ygdlyw", "title": "not_blurry", "value": "not_blurry"}, {"featureId": "ckt3huuhi000k3h69io60j4xq", "schemaId": "ckt3h8e9m51ls0y7u82aj5m3f", "title": "dark", "value": "dark"}]}, {"featureId": "ckt3hvcm4000p3h69vdbc0c23", "schemaId": "ckt3hu2y652300y9e6fkefo42", "title": "a", "value": "a", "answer": {"featureId": "ckt3hvcm4000o3h693oy8exqm", "schemaId": "ckt3hu2z6523l0y9e83yg91zc", "title": "ass", "value": "ass"}}, {"featureId": "ckt3hvdl4000r3h69jqvrrb67", "schemaId": "ckt3hu30w52410y9e485y1vwi", "title": "asdsadsad", "value": "asdsadsad", "answer": [{"featureId": "ckt3hvdl4000q3h69i8afxwv5", "schemaId": "ckt3hu31o524d0y9ehwk4ds2a", "title": "asdasdsa", "value": "asdasdsa"}]}, {"featureId": "ckt3hvfz4000s3h69heqauwmu", "schemaId": "ckt3hu30x52430y9ebixq8f3t", "title": "wee223", "value": "wee223", "answer": "fretre"}, {"featureId": "ckt3hvgkq000u3h69tjsg5rl1", "schemaId": "ckt3hu30x52450y9eaw056vgj", "title": "3223432", "value": "3223432", "answer": {"featureId": "ckt3hvgkq000t3h6928o3rgzf", "schemaId": "ckt3hu31v524j0y9e7x1g2shq", "title": "dddd", "value": "dddd"}}, {"featureId": "ckt3hvife000v3h6980b3ccxa", "schemaId": "ckt3hu2y652320y9e4n1lb9dx", "title": "weeeee", "value": "weeeee", "answer": "5354"}, {"featureId": "ckt3hviho000x3h69xt3npgtb", "schemaId": "ckt3hu2y652340y9e1brp3l1i", "title": "asdsadsa", "value": "asdsadsa", "answer": [{"featureId": "ckt3hviho000w3h690kvv8isp", "schemaId": "ckt3hu2zn523p0y9eel4l4ts9", "title": "asxdsds", "value": "asxdsds"}]}], "relationships": []}, "Created By": "[email protected]", "Project Name": "test_annotation_types", "Created At": "2021-09-02T22:23:24.000Z", "Updated At": "2021-09-02T23:31:27.707Z", "Seconds to Label": 220.91400000000002, "External ID": null, "Agreement": -1, "Benchmark Agreement": -1, "Benchmark ID": null, "Dataset Name": "label_dataset", "Reviews": [], "View Label": "https://editor.labelbox.com?project=ckt3h8d9z51l10y7uavvf5zcn&label=ckt3hu4aq524s0y9e7i0dbbtg", "Has Open Issues": 0, "Skipped": false}

tests/data/serialization/labelbox_v1/test_image.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import json
22

3+
import pytest
4+
35
from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter
46

57

6-
def test_image():
7-
with open('tests/data/assets/labelbox_v1/image_export.json', 'r') as file:
8+
@pytest.mark.parametrize("file_path", [
9+
'tests/data/assets/labelbox_v1/highly_nested_image.json',
10+
'tests/data/assets/labelbox_v1/image_export.json'
11+
])
12+
def test_image(file_path):
13+
with open(file_path, 'r') as file:
814
payload = json.load(file)
915

1016
collection = LBV1Converter.deserialize([payload])
@@ -20,4 +26,14 @@ def test_image():
2026
if not len(annotation_a['classifications']):
2127
# We don't add a classification key to the payload if there is no classifications.
2228
annotation_a.pop('classifications')
29+
30+
if isinstance(annotation_b.get('classifications'),
31+
list) and len(annotation_b['classifications']):
32+
if isinstance(annotation_b['classifications'][0], list):
33+
annotation_b['classifications'] = annotation_b[
34+
'classifications'][0]
35+
2336
assert annotation_a == annotation_b
37+
38+
39+
# After check the nd serializer on this shit.. It should work for almost everything (except the other horse shit..)

0 commit comments

Comments
 (0)