diff --git a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py index b65f21d16..0e9c4e934 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py @@ -1,8 +1,10 @@ +from typing import Union from pydantic import BaseModel from enum import Enum from labelbox.data.annotation_types.annotation import ( BaseAnnotation, ObjectAnnotation, + ClassificationAnnotation, ) @@ -11,7 +13,7 @@ class Type(Enum): UNIDIRECTIONAL = "unidirectional" BIDIRECTIONAL = "bidirectional" - source: ObjectAnnotation + source: Union[ObjectAnnotation, ClassificationAnnotation] target: ObjectAnnotation type: Type = Type.UNIDIRECTIONAL diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index e822f3c42..5b146b660 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -24,6 +24,7 @@ VideoMaskAnnotation, VideoObjectAnnotation, ) +from labelbox.types import DocumentRectangle, DocumentEntity from .classification import ( NDChecklistSubclass, NDClassification, @@ -169,6 +170,7 @@ def _create_non_video_annotations(cls, label: Label): VideoClassificationAnnotation, VideoObjectAnnotation, VideoMaskAnnotation, + RelationshipAnnotation, ), ) ] @@ -179,8 +181,6 @@ def _create_non_video_annotations(cls, label: Label): yield NDObject.from_common(annotation, label.data) elif isinstance(annotation, (ScalarMetric, ConfusionMatrixMetric)): yield NDMetricAnnotation.from_common(annotation, label.data) - elif isinstance(annotation, RelationshipAnnotation): - yield NDRelationship.from_common(annotation, label.data) elif isinstance(annotation, PromptClassificationAnnotation): yield NDPromptClassification.from_common(annotation, label.data) elif isinstance(annotation, MessageEvaluationTaskAnnotation): @@ -191,19 +191,54 @@ def _create_non_video_annotations(cls, label: Label): ) @classmethod - def _create_relationship_annotations(cls, label: Label): + def _create_relationship_annotations( + cls, label: Label + ) -> Generator[NDRelationship, None, None]: + """Processes relationship annotations from a label, converting them to NDJSON format. + + Args: + label: Label containing relationship annotations to be processed + + Yields: + NDRelationship: Validated relationship annotations in NDJSON format + + Raises: + TypeError: If source/target types are invalid: + - Source: + - For PDF target annotations (DocumentRectangle, DocumentEntity): source must be ObjectAnnotation or ClassificationAnnotation + - For other target annotations: source must be ObjectAnnotation + - Target: + - Target must always be ObjectAnnotation + """ for annotation in label.annotations: if isinstance(annotation, RelationshipAnnotation): uuid1 = uuid4() uuid2 = uuid4() source = copy.copy(annotation.value.source) target = copy.copy(annotation.value.target) - if not isinstance(source, ObjectAnnotation) or not isinstance( - target, ObjectAnnotation + + # Check if source type is valid based on target type + if isinstance( + target.value, (DocumentRectangle, DocumentEntity) ): + if not isinstance( + source, (ObjectAnnotation, ClassificationAnnotation) + ): + raise TypeError( + f"Unable to create relationship with invalid source. For PDF targets, " + f"source must be ObjectAnnotation or ClassificationAnnotation. Got: {type(source)}" + ) + elif not isinstance(source, ObjectAnnotation): raise TypeError( - f"Unable to create relationship with non ObjectAnnotations. `Source: {type(source)} Target: {type(target)}`" + f"Unable to create relationship with non ObjectAnnotation source: {type(source)}" ) + + # Check if target type is valid + if not isinstance(target, ObjectAnnotation): + raise TypeError( + f"Unable to create relationship with non ObjectAnnotation target: {type(target)}" + ) + if not source._uuid: source._uuid = uuid1 if not target._uuid: diff --git a/libs/labelbox/tests/data/annotation_import/test_relationships.py b/libs/labelbox/tests/data/annotation_import/test_relationships.py index 1335261e5..f4a80dab9 100644 --- a/libs/labelbox/tests/data/annotation_import/test_relationships.py +++ b/libs/labelbox/tests/data/annotation_import/test_relationships.py @@ -10,7 +10,17 @@ RelationshipAnnotation, Relationship, TextEntity, + DocumentRectangle, + DocumentEntity, + Point, + Text, + ClassificationAnnotation, + DocumentTextSelection, + Radio, + ClassificationAnswer, + Checklist, ) +from labelbox.data.serialization.ndjson import NDJsonConverter import pytest @@ -220,3 +230,110 @@ def test_import_media_types( assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 + + +def test_valid_classification_relationships(): + def create_pdf_annotation(target_type: str) -> ObjectAnnotation: + if target_type == "bbox": + return ObjectAnnotation( + name="bbox", + value=DocumentRectangle( + start=Point(x=0, y=0), + end=Point(x=0.5, y=0.5), + page=1, + unit="PERCENT", + ), + ) + elif target_type == "entity": + return ObjectAnnotation( + name="entity", + value=DocumentEntity( + page=1, + textSelections=[ + DocumentTextSelection(token_ids=[], group_id="", page=1) + ], + ), + ) + raise ValueError(f"Unknown target type: {target_type}") + + def verify_relationship( + source: ClassificationAnnotation, target: ObjectAnnotation + ): + relationship = RelationshipAnnotation( + name="relationship", + value=Relationship( + source=source, + target=target, + type=Relationship.Type.UNIDIRECTIONAL, + ), + ) + label = Label( + data={"global_key": "global_key"}, annotations=[relationship] + ) + result = list(NDJsonConverter.serialize([label])) + assert len(result) == 1 + + # Test case 1: Text Classification -> DocumentRectangle + text_source = ClassificationAnnotation( + name="text", value=Text(answer="test") + ) + verify_relationship(text_source, create_pdf_annotation("bbox")) + + # Test case 2: Text Classification -> DocumentEntity + verify_relationship(text_source, create_pdf_annotation("entity")) + + # Test case 3: Radio Classification -> DocumentRectangle + radio_source = ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + classifications=[ + ClassificationAnnotation( + name="second_sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="second_sub_radio_answer" + ) + ), + ) + ], + ) + ), + ) + verify_relationship(radio_source, create_pdf_annotation("bbox")) + + # Test case 4: Checklist Classification -> DocumentEntity + checklist_source = ClassificationAnnotation( + name="sub_checklist_question", + value=Checklist( + answer=[ClassificationAnswer(name="first_sub_checklist_answer")] + ), + ) + verify_relationship(checklist_source, create_pdf_annotation("entity")) + + +def test_classification_relationship_restrictions(): + """Test all relationship validation error messages.""" + text = ClassificationAnnotation(name="text", value=Text(answer="test")) + point = ObjectAnnotation(name="point", value=Point(x=1, y=1)) + + # Test case: Classification -> Point (invalid) + # Should fail because classifications can only connect to PDF targets + relationship = RelationshipAnnotation( + name="relationship", + value=Relationship( + source=text, + target=point, + type=Relationship.Type.UNIDIRECTIONAL, + ), + ) + + with pytest.raises( + TypeError, + match="Unable to create relationship with non ObjectAnnotation source: .*", + ): + label = Label( + data={"global_key": "test_key"}, annotations=[relationship] + ) + list(NDJsonConverter.serialize([label]))