Skip to content

[PLT-2011] Vb/fix prompt issue classification update plt 2011 #1917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class _BaseStepReasoningTool(ABC):
schema_id: Optional[str] = None
feature_schema_id: Optional[str] = None
color: Optional[str] = None
required: bool = False
required: bool = False # This attribute is for consistency with other tools and backend, default is False

def __post_init__(self):
if not self.name.strip():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class PromptIssueTool:

name: str
type: ToolType = field(default=ToolType.PROMPT_ISSUE, init=False)
required: bool = False
required: bool = False # This attribute is for consistency with other tools and backend, default is False
schema_id: Optional[str] = None
feature_schema_id: Optional[str] = None
color: Optional[str] = None
Expand Down Expand Up @@ -64,11 +64,20 @@ def _validate_classifications(
if (
len(classifications) != 1
or classifications[0].class_type != Classification.Type.CHECKLIST
or len(classifications[0].options) < 1
):
return False
return True

def asdict(self) -> Dict[str, Any]:
classifications_valid = self._validate_classifications(
self.classifications
)
if not classifications_valid:
raise ValueError(
"Classifications for Prompt Issue Tool are invalid"
)

return {
"tool": self.type.value,
"name": self.name,
Expand Down
34 changes: 32 additions & 2 deletions libs/labelbox/tests/integration/test_ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import pytest

from labelbox import MediaType, OntologyBuilder, Tool
from labelbox import MediaType, OntologyBuilder, OntologyKind, Tool
from labelbox.orm.model import Entity
from labelbox.schema.tool_building.classification import Classification
from labelbox.schema.tool_building.classification import Classification, Option
from labelbox.schema.tool_building.fact_checking_tool import (
FactCheckingTool,
)
Expand Down Expand Up @@ -488,3 +488,33 @@ def test_prompt_issue_ontology(chat_evaluation_ontology):
classification = prompt_issue_tool.classifications[0]
assert classification.class_type == Classification.Type.CHECKLIST
assert len(classification.options) == 3 # Check number of options


def test_invalid_prompt_issue_ontology(client):
tool = PromptIssueTool(name="Prompt Issue Tool")

option1 = Option(value="value")
radio_class = Classification(
class_type=Classification.Type.RADIO,
name="radio-class",
options=[option1],
)
text_class = Classification(
class_type=Classification.Type.TEXT, name="text-class"
)

tool.classifications.append(radio_class)
tool.classifications.append(text_class)

builder = OntologyBuilder(
tools=[tool],
)
with pytest.raises(
ValueError, match="Classifications for Prompt Issue Tool are invalid"
):
client.create_ontology(
name="plt-1710",
media_type=MediaType.Conversational,
ontology_kind=OntologyKind.ModelEvaluation,
normalized=builder.asdict(),
)
3 changes: 3 additions & 0 deletions libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def test_as_dict():
}
assert tool_dict == expected_dict


def test_classification_validation():
tool = PromptIssueTool(name="Prompt Issue Tool")
with pytest.raises(ValueError):
tool.classifications = [
Classification(Classification.Type.TEXT, "prompt_issue")
Expand Down
Loading