diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 0032aaad1..3238dc4a5 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -25,6 +25,11 @@ from labelbox.schema.tool_building.tool_type_mapping import ( map_tool_type_to_tool_cls, ) +from labelbox.schema.tool_building.types import ( + FeatureSchemaAttribute, + FeatureSchemaAttributes, +) +import warnings class DeleteFeatureFromOntologyResult: @@ -73,6 +78,7 @@ class Tool: classifications: (list) schema_id: (str) feature_schema_id: (str) + attributes: (list) """ class Type(Enum): @@ -95,6 +101,13 @@ class Type(Enum): classifications: List[Classification] = field(default_factory=list) schema_id: Optional[str] = None feature_schema_id: Optional[str] = None + attributes: Optional[FeatureSchemaAttributes] = None + + def __post_init__(self): + if self.attributes is not None: + warnings.warn( + "The attributes for Tools are in beta. The attribute name and signature may change in the future." + ) @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: @@ -109,6 +122,12 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: for c in dictionary["classifications"] ], color=dictionary["color"], + attributes=[ + FeatureSchemaAttribute.from_dict(attr) + for attr in dictionary.get("attributes", []) or [] + ] + if dictionary.get("attributes") + else None, ) def asdict(self) -> Dict[str, Any]: @@ -122,6 +141,9 @@ def asdict(self) -> Dict[str, Any]: ], "schemaNodeId": self.schema_id, "featureSchemaId": self.feature_schema_id, + "attributes": [a.asdict() for a in self.attributes] + if self.attributes is not None + else None, } def add_classification(self, classification: Classification) -> None: diff --git a/libs/labelbox/src/labelbox/schema/tool_building/classification.py b/libs/labelbox/src/labelbox/schema/tool_building/classification.py index 9c0c69bea..62fee2dda 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/classification.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/classification.py @@ -5,7 +5,11 @@ from lbox.exceptions import InconsistentOntologyException -from labelbox.schema.tool_building.types import FeatureSchemaId +from labelbox.schema.tool_building.types import ( + FeatureSchemaId, + FeatureSchemaAttributes, + FeatureSchemaAttribute, +) @dataclass @@ -42,6 +46,7 @@ class Classification: schema_id: (str) feature_schema_id: (str) scope: (str) + attributes: (list) """ class Type(Enum): @@ -70,6 +75,7 @@ class UIMode(Enum): ui_mode: Optional[UIMode] = ( None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) ) + attributes: Optional[FeatureSchemaAttributes] = None def __post_init__(self): if self.name is None: @@ -88,6 +94,10 @@ def __post_init__(self): else: if self.instructions is None: self.instructions = self.name + if self.attributes is not None: + warnings.warn( + "The attributes for Classifications are in beta. The attribute name and signature may change in the future." + ) @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> "Classification": @@ -103,6 +113,12 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "Classification": schema_id=dictionary.get("schemaNodeId", None), feature_schema_id=dictionary.get("featureSchemaId", None), scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL)), + attributes=[ + FeatureSchemaAttribute.from_dict(attr) + for attr in dictionary.get("attributes", []) or [] + ] + if dictionary.get("attributes") + else None, ) def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: @@ -118,6 +134,9 @@ def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: "options": [o.asdict() for o in self.options], "schemaNodeId": self.schema_id, "featureSchemaId": self.feature_schema_id, + "attributes": [a.asdict() for a in self.attributes] + if self.attributes is not None + else None, } if ( self.class_type == self.Type.RADIO diff --git a/libs/labelbox/src/labelbox/schema/tool_building/types.py b/libs/labelbox/src/labelbox/schema/tool_building/types.py index 0d6e34717..38c789837 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/types.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/types.py @@ -1,6 +1,33 @@ -from typing import Annotated - +from typing import Annotated, List from pydantic import Field + +from dataclasses import dataclass + +from typing import Any, Dict, List + + +@dataclass +class FeatureSchemaAttribute: + attributeName: str + attributeValue: str + + def asdict(self): + return { + "attributeName": self.attributeName, + "attributeValue": self.attributeValue, + } + + @classmethod + def from_dict(cls, dictionary: Dict[str, Any]) -> "FeatureSchemaAttribute": + return cls( + attributeName=dictionary["attributeName"], + attributeValue=dictionary["attributeValue"], + ) + + FeatureSchemaId = Annotated[str, Field(min_length=25, max_length=25)] SchemaId = Annotated[str, Field(min_length=25, max_length=25)] +FeatureSchemaAttributes = Annotated[ + List[FeatureSchemaAttribute], Field(default_factory=list) +] diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index 11984cbd7..87aea0468 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -25,6 +25,7 @@ from labelbox.schema.data_row import DataRowMetadataField from labelbox.schema.ontology_kind import OntologyKind from labelbox.schema.user import User +from labelbox.schema.tool_building.types import FeatureSchemaAttribute @pytest.fixture @@ -552,6 +553,76 @@ def point(): ) +@pytest.fixture +def auto_ocr_text_value_class(): + return Classification( + class_type=Classification.Type.TEXT, + name="Auto OCR Text Value", + instructions="Text value for ocr bboxes", + scope=Classification.Scope.GLOBAL, + required=False, + attributes=[ + FeatureSchemaAttribute( + attributeName="auto-ocr-text-value", attributeValue="true" + ) + ], + ) + + +@pytest.fixture +def auto_ocr_bbox(auto_ocr_text_value_class): + return Tool( + tool=Tool.Type.BBOX, + name="Auto ocr bbox", + color="ff0000", + attributes=[ + FeatureSchemaAttribute( + attributeName="auto-ocr", attributeValue="true" + ) + ], + classifications=[auto_ocr_text_value_class], + ) + + +@pytest.fixture +def requires_connection_classification(): + return Classification( + name="Requires connection radio", + instructions="Classification that requires a connection", + class_type=Classification.Type.RADIO, + attributes=[ + FeatureSchemaAttribute( + attributeName="requires-connection", attributeValue="true" + ) + ], + options=[Option(value="A"), Option(value="B")], + ) + + +@pytest.fixture +def requires_connection_classification_feature_schema( + client, requires_connection_classification +): + created_feature_schema = client.upsert_feature_schema( + requires_connection_classification.asdict() + ) + yield created_feature_schema + client.delete_unused_feature_schema( + created_feature_schema.normalized["featureSchemaId"] + ) + + +@pytest.fixture +def auto_ocr_bbox_feature_schema(client, auto_ocr_bbox): + created_feature_schema = client.upsert_feature_schema( + auto_ocr_bbox.asdict() + ) + yield created_feature_schema + client.delete_unused_feature_schema( + created_feature_schema.normalized["featureSchemaId"] + ) + + @pytest.fixture def feature_schema(client, point): created_feature_schema = client.upsert_feature_schema(point.asdict()) diff --git a/libs/labelbox/tests/integration/test_feature_schema.py b/libs/labelbox/tests/integration/test_feature_schema.py index 46ec8c067..5713a067b 100644 --- a/libs/labelbox/tests/integration/test_feature_schema.py +++ b/libs/labelbox/tests/integration/test_feature_schema.py @@ -115,3 +115,29 @@ def test_does_not_include_used_feature_schema(client, feature_schema): assert feature_schema_id not in unused_feature_schemas client.delete_unused_ontology(ontology.uid) + + +def test_upsert_tool_with_attributes(auto_ocr_bbox_feature_schema): + auto_ocr_attributes = auto_ocr_bbox_feature_schema.normalized["attributes"] + auto_ocr_text_value_attributes = auto_ocr_bbox_feature_schema.normalized[ + "classifications" + ][0]["attributes"] + assert auto_ocr_attributes == [ + {"attributeName": "auto-ocr", "attributeValue": "true"} + ] + assert auto_ocr_text_value_attributes == [ + {"attributeName": "auto-ocr-text-value", "attributeValue": "true"} + ] + + +def test_upsert_classification_with_attributes( + requires_connection_classification_feature_schema, +): + requires_connection_attributes = ( + requires_connection_classification_feature_schema.normalized[ + "attributes" + ] + ) + assert requires_connection_attributes == [ + {"attributeName": "requires-connection", "attributeValue": "true"} + ] diff --git a/libs/labelbox/tests/unit/test_unit_ontology.py b/libs/labelbox/tests/unit/test_unit_ontology.py index 61c9f523a..137e4fd1f 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology.py +++ b/libs/labelbox/tests/unit/test_unit_ontology.py @@ -15,6 +15,7 @@ "color": "#FF0000", "tool": "polygon", "classifications": [], + "attributes": None, }, { "schemaNodeId": None, @@ -24,6 +25,7 @@ "color": "#FF0000", "tool": "superpixel", "classifications": [], + "attributes": None, }, { "schemaNodeId": None, @@ -32,6 +34,12 @@ "name": "bbox", "color": "#FF0000", "tool": "rectangle", + "attributes": [ + { + "attributeName": "auto-ocr", + "attributeValue": "true", + } + ], "classifications": [ { "schemaNodeId": None, @@ -56,6 +64,7 @@ "name": "nested nested text", "type": "text", "options": [], + "attributes": None, } ], }, @@ -67,6 +76,12 @@ "options": [], }, ], + "attributes": [ + { + "attributeName": "requires-connection", + "attributeValue": "true", + } + ], }, { "schemaNodeId": None, @@ -76,6 +91,7 @@ "name": "nested text", "type": "text", "options": [], + "attributes": None, }, ], }, @@ -87,6 +103,7 @@ "color": "#FF0000", "tool": "point", "classifications": [], + "attributes": None, }, { "schemaNodeId": None, @@ -96,6 +113,7 @@ "color": "#FF0000", "tool": "line", "classifications": [], + "attributes": None, }, { "schemaNodeId": None, @@ -105,6 +123,7 @@ "color": "#FF0000", "tool": "named-entity", "classifications": [], + "attributes": None, }, ], "classifications": [ @@ -117,6 +136,7 @@ "type": "radio", "scope": "global", "uiMode": "searchable", + "attributes": None, "options": [ { "schemaNodeId": None, diff --git a/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py b/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py index 5a18d5248..bfabbd632 100644 --- a/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py +++ b/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py @@ -47,6 +47,7 @@ def test_as_dict(): "schemaNodeId": None, "featureSchemaId": None, "scope": "global", + "attributes": None, } ], "color": None,