2
2
import copy
3
3
from itertools import groupby
4
4
from operator import itemgetter
5
- from typing import Generator , List , Tuple , Union
5
+ from typing import Generator , List , Tuple , Union , Iterator , Dict
6
6
from uuid import uuid4
7
7
8
8
from pydantic import BaseModel
24
24
VideoMaskAnnotation ,
25
25
VideoObjectAnnotation ,
26
26
)
27
+ from labelbox .types import DocumentRectangle , DocumentEntity
27
28
from .classification import (
28
29
NDChecklistSubclass ,
29
30
NDClassification ,
@@ -61,9 +62,7 @@ class NDLabel(BaseModel):
61
62
annotations : AnnotationType
62
63
63
64
@classmethod
64
- def from_common (
65
- cls , data : LabelCollection
66
- ) -> Generator ["NDLabel" , None , None ]:
65
+ def from_common (cls , data : LabelCollection ) -> Generator ["NDLabel" , None , None ]:
67
66
for label in data :
68
67
yield from cls ._create_relationship_annotations (label )
69
68
yield from cls ._create_non_video_annotations (label )
@@ -127,16 +126,12 @@ def _create_video_annotations(
127
126
if isinstance (
128
127
annot , (VideoClassificationAnnotation , VideoObjectAnnotation )
129
128
):
130
- video_annotations [annot .feature_schema_id or annot .name ].append (
131
- annot
132
- )
129
+ video_annotations [annot .feature_schema_id or annot .name ].append (annot )
133
130
elif isinstance (annot , VideoMaskAnnotation ):
134
131
yield NDObject .from_common (annotation = annot , data = label .data )
135
132
136
133
for annotation_group in video_annotations .values ():
137
- segment_frame_ranges = cls ._get_segment_frame_ranges (
138
- annotation_group
139
- )
134
+ segment_frame_ranges = cls ._get_segment_frame_ranges (annotation_group )
140
135
if isinstance (annotation_group [0 ], VideoClassificationAnnotation ):
141
136
annotation = annotation_group [0 ]
142
137
frames_data = []
@@ -169,6 +164,7 @@ def _create_non_video_annotations(cls, label: Label):
169
164
VideoClassificationAnnotation ,
170
165
VideoObjectAnnotation ,
171
166
VideoMaskAnnotation ,
167
+ RelationshipAnnotation ,
172
168
),
173
169
)
174
170
]
@@ -179,8 +175,6 @@ def _create_non_video_annotations(cls, label: Label):
179
175
yield NDObject .from_common (annotation , label .data )
180
176
elif isinstance (annotation , (ScalarMetric , ConfusionMatrixMetric )):
181
177
yield NDMetricAnnotation .from_common (annotation , label .data )
182
- elif isinstance (annotation , RelationshipAnnotation ):
183
- yield NDRelationship .from_common (annotation , label .data )
184
178
elif isinstance (annotation , PromptClassificationAnnotation ):
185
179
yield NDPromptClassification .from_common (annotation , label .data )
186
180
elif isinstance (annotation , MessageEvaluationTaskAnnotation ):
@@ -191,19 +185,35 @@ def _create_non_video_annotations(cls, label: Label):
191
185
)
192
186
193
187
@classmethod
194
- def _create_relationship_annotations (cls , label : Label ):
188
+ def _create_relationship_annotations (
189
+ cls , label : Label
190
+ ) -> Generator [NDRelationship , None , None ]:
195
191
for annotation in label .annotations :
196
192
if isinstance (annotation , RelationshipAnnotation ):
197
193
uuid1 = uuid4 ()
198
194
uuid2 = uuid4 ()
199
195
source = copy .copy (annotation .value .source )
200
196
target = copy .copy (annotation .value .target )
201
- if not isinstance (source , ObjectAnnotation ) or not isinstance (
202
- target , ObjectAnnotation
203
- ):
197
+
198
+ # Check if source type is valid based on target type
199
+ if isinstance (target .value , (DocumentRectangle , DocumentEntity )):
200
+ if not isinstance (
201
+ source , (ObjectAnnotation , ClassificationAnnotation )
202
+ ):
203
+ raise TypeError (
204
+ f"Unable to create relationship with invalid source. For PDF targets, "
205
+ f"source must be ObjectAnnotation or ClassificationAnnotation. Got: { type (source )} "
206
+ )
207
+ elif not isinstance (source , ObjectAnnotation ):
204
208
raise TypeError (
205
- f"Unable to create relationship with non ObjectAnnotations. `Source : { type (source )} Target: { type ( target ) } ` "
209
+ f"Unable to create relationship with non ObjectAnnotation source : { type (source )} "
206
210
)
211
+
212
+ if not isinstance (target , ObjectAnnotation ):
213
+ raise TypeError (
214
+ f"Unable to create relationship with non ObjectAnnotation target: { type (target )} "
215
+ )
216
+
207
217
if not source ._uuid :
208
218
source ._uuid = uuid1
209
219
if not target ._uuid :
0 commit comments