diff --git a/libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py b/libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py index 624e951a7..4d7ce54c5 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py @@ -76,7 +76,7 @@ class _BaseStepReasoningTool(ABC): schema_id: Optional[str] = None feature_schema_id: Optional[str] = None color: Optional[str] = None - required: bool = False + required: bool = False # This attribute is for consistency with other tools and backend, default is False def __post_init__(self): if not self.name.strip(): diff --git a/libs/labelbox/src/labelbox/schema/tool_building/prompt_issue_tool.py b/libs/labelbox/src/labelbox/schema/tool_building/prompt_issue_tool.py index 5b9b55d9c..1a965ff95 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/prompt_issue_tool.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/prompt_issue_tool.py @@ -36,7 +36,7 @@ class PromptIssueTool: name: str type: ToolType = field(default=ToolType.PROMPT_ISSUE, init=False) - required: bool = False + required: bool = False # This attribute is for consistency with other tools and backend, default is False schema_id: Optional[str] = None feature_schema_id: Optional[str] = None color: Optional[str] = None @@ -64,11 +64,20 @@ def _validate_classifications( if ( len(classifications) != 1 or classifications[0].class_type != Classification.Type.CHECKLIST + or len(classifications[0].options) < 1 ): return False return True def asdict(self) -> Dict[str, Any]: + classifications_valid = self._validate_classifications( + self.classifications + ) + if not classifications_valid: + raise ValueError( + "Classifications for Prompt Issue Tool are invalid" + ) + return { "tool": self.type.value, "name": self.name, diff --git a/libs/labelbox/tests/integration/test_ontology.py b/libs/labelbox/tests/integration/test_ontology.py index febfdacfd..d56f54958 100644 --- a/libs/labelbox/tests/integration/test_ontology.py +++ b/libs/labelbox/tests/integration/test_ontology.py @@ -3,9 +3,9 @@ import pytest -from labelbox import MediaType, OntologyBuilder, Tool +from labelbox import MediaType, OntologyBuilder, OntologyKind, Tool from labelbox.orm.model import Entity -from labelbox.schema.tool_building.classification import Classification +from labelbox.schema.tool_building.classification import Classification, Option from labelbox.schema.tool_building.fact_checking_tool import ( FactCheckingTool, ) @@ -488,3 +488,33 @@ def test_prompt_issue_ontology(chat_evaluation_ontology): classification = prompt_issue_tool.classifications[0] assert classification.class_type == Classification.Type.CHECKLIST assert len(classification.options) == 3 # Check number of options + + +def test_invalid_prompt_issue_ontology(client): + tool = PromptIssueTool(name="Prompt Issue Tool") + + option1 = Option(value="value") + radio_class = Classification( + class_type=Classification.Type.RADIO, + name="radio-class", + options=[option1], + ) + text_class = Classification( + class_type=Classification.Type.TEXT, name="text-class" + ) + + tool.classifications.append(radio_class) + tool.classifications.append(text_class) + + builder = OntologyBuilder( + tools=[tool], + ) + with pytest.raises( + ValueError, match="Classifications for Prompt Issue Tool are invalid" + ): + client.create_ontology( + name="plt-1710", + media_type=MediaType.Conversational, + ontology_kind=OntologyKind.ModelEvaluation, + normalized=builder.asdict(), + ) diff --git a/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py b/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py index 04dc89668..5a18d5248 100644 --- a/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py +++ b/libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py @@ -53,6 +53,9 @@ def test_as_dict(): } assert tool_dict == expected_dict + +def test_classification_validation(): + tool = PromptIssueTool(name="Prompt Issue Tool") with pytest.raises(ValueError): tool.classifications = [ Classification(Classification.Type.TEXT, "prompt_issue")