Skip to content

Commit e4a9628

Browse files
committed
MAL and GT support for pdf relationships
1 parent bee0195 commit e4a9628

File tree

3 files changed

+138
-21
lines changed

3 files changed

+138
-21
lines changed

libs/labelbox/src/labelbox/data/annotation_types/relationship.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from typing import Union
12
from pydantic import BaseModel
23
from enum import Enum
34
from labelbox.data.annotation_types.annotation import (
45
BaseAnnotation,
56
ObjectAnnotation,
7+
ClassificationAnnotation,
68
)
79

810

@@ -11,7 +13,7 @@ class Type(Enum):
1113
UNIDIRECTIONAL = "unidirectional"
1214
BIDIRECTIONAL = "bidirectional"
1315

14-
source: ObjectAnnotation
16+
source: Union[ObjectAnnotation, ClassificationAnnotation]
1517
target: ObjectAnnotation
1618
type: Type = Type.UNIDIRECTIONAL
1719

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
from itertools import groupby
44
from operator import itemgetter
5-
from typing import Generator, List, Tuple, Union
5+
from typing import Generator, List, Tuple, Union, Iterator, Dict
66
from uuid import uuid4
77

88
from pydantic import BaseModel
@@ -24,6 +24,7 @@
2424
VideoMaskAnnotation,
2525
VideoObjectAnnotation,
2626
)
27+
from labelbox.types import DocumentRectangle, DocumentEntity
2728
from .classification import (
2829
NDChecklistSubclass,
2930
NDClassification,
@@ -61,9 +62,7 @@ class NDLabel(BaseModel):
6162
annotations: AnnotationType
6263

6364
@classmethod
64-
def from_common(
65-
cls, data: LabelCollection
66-
) -> Generator["NDLabel", None, None]:
65+
def from_common(cls, data: LabelCollection) -> Generator["NDLabel", None, None]:
6766
for label in data:
6867
yield from cls._create_relationship_annotations(label)
6968
yield from cls._create_non_video_annotations(label)
@@ -127,16 +126,12 @@ def _create_video_annotations(
127126
if isinstance(
128127
annot, (VideoClassificationAnnotation, VideoObjectAnnotation)
129128
):
130-
video_annotations[annot.feature_schema_id or annot.name].append(
131-
annot
132-
)
129+
video_annotations[annot.feature_schema_id or annot.name].append(annot)
133130
elif isinstance(annot, VideoMaskAnnotation):
134131
yield NDObject.from_common(annotation=annot, data=label.data)
135132

136133
for annotation_group in video_annotations.values():
137-
segment_frame_ranges = cls._get_segment_frame_ranges(
138-
annotation_group
139-
)
134+
segment_frame_ranges = cls._get_segment_frame_ranges(annotation_group)
140135
if isinstance(annotation_group[0], VideoClassificationAnnotation):
141136
annotation = annotation_group[0]
142137
frames_data = []
@@ -169,6 +164,7 @@ def _create_non_video_annotations(cls, label: Label):
169164
VideoClassificationAnnotation,
170165
VideoObjectAnnotation,
171166
VideoMaskAnnotation,
167+
RelationshipAnnotation,
172168
),
173169
)
174170
]
@@ -179,8 +175,6 @@ def _create_non_video_annotations(cls, label: Label):
179175
yield NDObject.from_common(annotation, label.data)
180176
elif isinstance(annotation, (ScalarMetric, ConfusionMatrixMetric)):
181177
yield NDMetricAnnotation.from_common(annotation, label.data)
182-
elif isinstance(annotation, RelationshipAnnotation):
183-
yield NDRelationship.from_common(annotation, label.data)
184178
elif isinstance(annotation, PromptClassificationAnnotation):
185179
yield NDPromptClassification.from_common(annotation, label.data)
186180
elif isinstance(annotation, MessageEvaluationTaskAnnotation):
@@ -191,19 +185,35 @@ def _create_non_video_annotations(cls, label: Label):
191185
)
192186

193187
@classmethod
194-
def _create_relationship_annotations(cls, label: Label):
188+
def _create_relationship_annotations(
189+
cls, label: Label
190+
) -> Generator[NDRelationship, None, None]:
195191
for annotation in label.annotations:
196192
if isinstance(annotation, RelationshipAnnotation):
197193
uuid1 = uuid4()
198194
uuid2 = uuid4()
199195
source = copy.copy(annotation.value.source)
200196
target = copy.copy(annotation.value.target)
201-
if not isinstance(source, ObjectAnnotation) or not isinstance(
202-
target, ObjectAnnotation
203-
):
197+
198+
# Check if source type is valid based on target type
199+
if isinstance(target.value, (DocumentRectangle, DocumentEntity)):
200+
if not isinstance(
201+
source, (ObjectAnnotation, ClassificationAnnotation)
202+
):
203+
raise TypeError(
204+
f"Unable to create relationship with invalid source. For PDF targets, "
205+
f"source must be ObjectAnnotation or ClassificationAnnotation. Got: {type(source)}"
206+
)
207+
elif not isinstance(source, ObjectAnnotation):
204208
raise TypeError(
205-
f"Unable to create relationship with non ObjectAnnotations. `Source: {type(source)} Target: {type(target)}`"
209+
f"Unable to create relationship with non ObjectAnnotation source: {type(source)}"
206210
)
211+
212+
if not isinstance(target, ObjectAnnotation):
213+
raise TypeError(
214+
f"Unable to create relationship with non ObjectAnnotation target: {type(target)}"
215+
)
216+
207217
if not source._uuid:
208218
source._uuid = uuid1
209219
if not target._uuid:

libs/labelbox/tests/data/annotation_import/test_relationships.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,17 @@
1010
RelationshipAnnotation,
1111
Relationship,
1212
TextEntity,
13+
DocumentRectangle,
14+
DocumentEntity,
15+
Point,
16+
Text,
17+
ClassificationAnnotation,
18+
DocumentTextSelection,
19+
Radio,
20+
ClassificationAnswer,
21+
Checklist,
1322
)
23+
from labelbox.data.serialization.ndjson import NDJsonConverter
1424
import pytest
1525

1626

@@ -169,9 +179,7 @@ def configured_project(
169179
data_row_data = []
170180

171181
for _ in range(3):
172-
data_row_data.append(
173-
data_row_json_by_media_type[media_type](rand_gen(str))
174-
)
182+
data_row_data.append(data_row_json_by_media_type[media_type](rand_gen(str)))
175183

176184
task = dataset.create_data_rows(data_row_data)
177185
task.wait_till_done()
@@ -220,3 +228,100 @@ def test_import_media_types(
220228

221229
assert label_import.state == AnnotationImportState.FINISHED
222230
assert len(label_import.errors) == 0
231+
232+
233+
def test_valid_classification_relationships():
234+
def create_pdf_annotation(target_type: str) -> ObjectAnnotation:
235+
if target_type == "bbox":
236+
return ObjectAnnotation(
237+
name="bbox",
238+
value=DocumentRectangle(
239+
start=Point(x=0, y=0),
240+
end=Point(x=0.5, y=0.5),
241+
page=1,
242+
unit="PERCENT",
243+
),
244+
)
245+
elif target_type == "entity":
246+
return ObjectAnnotation(
247+
name="entity",
248+
value=DocumentEntity(
249+
page=1,
250+
textSelections=[
251+
DocumentTextSelection(token_ids=[], group_id="", page=1)
252+
],
253+
),
254+
)
255+
raise ValueError(f"Unknown target type: {target_type}")
256+
257+
def verify_relationship(source: ClassificationAnnotation, target: ObjectAnnotation):
258+
relationship = RelationshipAnnotation(
259+
name="relationship",
260+
value=Relationship(
261+
source=source,
262+
target=target,
263+
type=Relationship.Type.UNIDIRECTIONAL,
264+
),
265+
)
266+
label = Label(data={"global_key": "global_key"}, annotations=[relationship])
267+
result = list(NDJsonConverter.serialize([label]))
268+
assert len(result) == 1
269+
270+
# Test case 1: Text Classification -> DocumentRectangle
271+
text_source = ClassificationAnnotation(name="text", value=Text(answer="test"))
272+
verify_relationship(text_source, create_pdf_annotation("bbox"))
273+
274+
# Test case 2: Text Classification -> DocumentEntity
275+
verify_relationship(text_source, create_pdf_annotation("entity"))
276+
277+
# Test case 3: Radio Classification -> DocumentRectangle
278+
radio_source = ClassificationAnnotation(
279+
name="sub_radio_question",
280+
value=Radio(
281+
answer=ClassificationAnswer(
282+
name="first_sub_radio_answer",
283+
classifications=[
284+
ClassificationAnnotation(
285+
name="second_sub_radio_question",
286+
value=Radio(
287+
answer=ClassificationAnswer(name="second_sub_radio_answer")
288+
),
289+
)
290+
],
291+
)
292+
),
293+
)
294+
verify_relationship(radio_source, create_pdf_annotation("bbox"))
295+
296+
# Test case 4: Checklist Classification -> DocumentEntity
297+
checklist_source = ClassificationAnnotation(
298+
name="sub_checklist_question",
299+
value=Checklist(
300+
answer=[ClassificationAnswer(name="first_sub_checklist_answer")]
301+
),
302+
)
303+
verify_relationship(checklist_source, create_pdf_annotation("entity"))
304+
305+
306+
def test_classification_relationship_restrictions():
    """Reject a classification source whose target is not a PDF annotation.

    Only the classification -> ``Point`` case is exercised here: serializing
    such a relationship is expected to raise ``TypeError`` carrying the
    "non ObjectAnnotation source" message.
    """
    text = ClassificationAnnotation(name="text", value=Text(answer="test"))
    point = ObjectAnnotation(name="point", value=Point(x=1, y=1))

    # Classification -> Point is invalid: the test expects classification
    # sources to be accepted only for PDF (document) targets.
    relationship = RelationshipAnnotation(
        name="relationship",
        value=Relationship(
            source=text,
            target=point,
            type=Relationship.Type.UNIDIRECTIONAL,
        ),
    )

    # Build the label outside the ``raises`` block so that only the
    # serialization call can satisfy the expected TypeError.
    label = Label(data={"global_key": "test_key"}, annotations=[relationship])

    with pytest.raises(
        TypeError,
        match="Unable to create relationship with non ObjectAnnotation source: .*",
    ):
        list(NDJsonConverter.serialize([label]))

0 commit comments

Comments
 (0)