Skip to content

Commit c3af277

Browse files
committed
Add FactCheckingTool
Refactor StepReasoning to also reuse Variants
1 parent 41104a1 commit c3af277

16 files changed

+480
-234
lines changed

docs/labelbox/fact-checking-tool.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Fact Checking Tool
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.tool_building.fact_checking_tool
5+
:members:
6+
:show-inheritance:

docs/labelbox/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Labelbox Python SDK Documentation
1919
enums
2020
exceptions
2121
export-task
22+
fact-checking-tool
2223
foundry-client
2324
foundry-model
2425
identifiable

libs/labelbox/src/labelbox/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555
ResponseOption,
5656
Tool,
5757
)
58-
from labelbox.schema.ontology import PromptResponseClassification
59-
from labelbox.schema.ontology import ResponseOption
58+
from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool
59+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
6060
from labelbox.schema.role import Role, ProjectRole
6161
from labelbox.schema.invite import Invite, InviteLimit
6262
from labelbox.schema.data_row_metadata import (

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@
1212

1313
from labelbox.orm.db_object import DbObject
1414
from labelbox.orm.model import Field, Relationship
15+
from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool
1516
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
1617
from labelbox.schema.tool_building.tool_type import ToolType
18+
from labelbox.schema.tool_building.tool_type_mapping import (
19+
map_tool_type_to_tool_cls,
20+
)
1721

1822
FeatureSchemaId: Type[str] = Annotated[
1923
str, StringConstraints(min_length=25, max_length=25)
@@ -490,14 +494,20 @@ def add_classification(self, classification: Classification) -> None:
490494
self.classifications.append(classification)
491495

492496

497+
"""
498+
The following 2 functions help to bridge the gap between the step reasoning all other tool ontologies.
499+
"""
500+
501+
493502
def tool_cls_from_type(tool_type: str):
494-
if tool_type.lower() == ToolType.STEP_REASONING.value:
495-
return StepReasoningTool
503+
tool_cls = map_tool_type_to_tool_cls(tool_type)
504+
if tool_cls is not None:
505+
return tool_cls
496506
return Tool
497507

498508

499509
def tool_type_cls_from_type(tool_type: str):
500-
if tool_type.lower() == ToolType.STEP_REASONING.value:
510+
if ToolType.valid(tool_type):
501511
return ToolType
502512
return Tool.Type
503513

@@ -596,7 +606,9 @@ class OntologyBuilder:
596606
597607
"""
598608

599-
tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list)
609+
tools: List[Union[Tool, StepReasoningTool, FactCheckingTool]] = field(
610+
default_factory=list
611+
)
600612
classifications: List[
601613
Union[Classification, PromptResponseClassification]
602614
] = field(default_factory=list)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
import labelbox.schema.tool_building.tool_type
22
import labelbox.schema.tool_building.step_reasoning_tool
3+
import labelbox.schema.tool_building.fact_checking_tool
4+
import labelbox.schema.tool_building.tool_type_mapping
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import warnings
2+
from abc import ABC
3+
from dataclasses import dataclass, field
4+
from typing import Any, Dict, List, Optional, Set
5+
6+
from labelbox.schema.tool_building.tool_type import ToolType
7+
8+
9+
@dataclass
10+
class _Variant:
11+
id: int
12+
name: str
13+
actions: List[str] = field(default_factory=list)
14+
_available_actions: Set[str] = field(default_factory=set)
15+
16+
def set_actions(self, actions: List[str]) -> None:
17+
self.actions = []
18+
for action in actions:
19+
if action in self._available_actions:
20+
self.actions.append(action)
21+
else:
22+
warnings.warn(
23+
f"Variant ID {self.id} {action} is an invalid action, skipping"
24+
)
25+
26+
def reset_actions(self) -> None:
27+
self.actions = []
28+
29+
def asdict(self) -> Dict[str, Any]:
30+
return {
31+
"id": self.id,
32+
"name": self.name,
33+
"actions": self.actions,
34+
}
35+
36+
def _post_init(self):
37+
# Call set_actions to remove any invalid actions
38+
self.set_actions(self.actions)
39+
40+
41+
@dataclass
42+
class _Definition:
43+
variants: List[_Variant]
44+
version: int = field(default=1)
45+
title: Optional[str] = None
46+
value: Optional[str] = None
47+
48+
def __post_init__(self):
49+
if self.version != 1:
50+
raise ValueError("Invalid version")
51+
52+
def asdict(self) -> Dict[str, Any]:
53+
result = {
54+
"variants": [variant.asdict() for variant in self.variants],
55+
"version": self.version,
56+
}
57+
if self.title is not None:
58+
result["title"] = self.title
59+
if self.value is not None:
60+
result["value"] = self.value
61+
return result
62+
63+
@classmethod
64+
def from_dict(cls, dictionary: Dict[str, Any]) -> "_Definition":
65+
variants = [_Variant(**variant) for variant in dictionary["variants"]]
66+
title = dictionary.get("title", None)
67+
value = dictionary.get("value", None)
68+
return cls(variants=variants, title=title, value=value)
69+
70+
71+
@dataclass
72+
class _BaseStepReasoningTool(ABC):
73+
name: str
74+
definition: _Definition
75+
type: Optional[ToolType] = None
76+
schema_id: Optional[str] = None
77+
feature_schema_id: Optional[str] = None
78+
color: Optional[str] = None
79+
required: bool = False
80+
81+
def __post_init__(self):
82+
warnings.warn(
83+
"This feature is experimental and subject to change.",
84+
)
85+
86+
if not self.name.strip():
87+
raise ValueError("Name is required")
88+
89+
def asdict(self) -> Dict[str, Any]:
90+
return {
91+
"tool": self.type.value if self.type else None,
92+
"name": self.name,
93+
"required": self.required,
94+
"schemaNodeId": self.schema_id,
95+
"featureSchemaId": self.feature_schema_id,
96+
"definition": self.definition.asdict(),
97+
"color": self.color,
98+
}
99+
100+
@classmethod
101+
def from_dict(cls, dictionary: Dict[str, Any]) -> "_BaseStepReasoningTool":
102+
return cls(
103+
name=dictionary["name"],
104+
schema_id=dictionary.get("schemaNodeId", None),
105+
feature_schema_id=dictionary.get("featureSchemaId", None),
106+
required=dictionary.get("required", False),
107+
definition=_Definition.from_dict(dictionary["definition"]),
108+
color=dictionary.get("color", None),
109+
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from dataclasses import dataclass, field
2+
from enum import Enum
3+
4+
from labelbox.schema.tool_building.base_step_reasoning_tool import (
5+
_BaseStepReasoningTool,
6+
_Definition,
7+
_Variant,
8+
)
9+
from labelbox.schema.tool_building.tool_type import ToolType
10+
11+
12+
class FactCheckingActions(Enum):
13+
WRITE_JUSTIFICATION = "justification"
14+
15+
16+
def build_fact_checking_definition():
17+
accurate_step = _Variant(
18+
id=0,
19+
name="Accurate",
20+
_available_actions={action.value for action in FactCheckingActions},
21+
actions=[action.value for action in FactCheckingActions],
22+
)
23+
inaccurate_step = _Variant(
24+
id=1,
25+
name="Inaccurate",
26+
_available_actions={action.value for action in FactCheckingActions},
27+
actions=[action.value for action in FactCheckingActions],
28+
)
29+
disputed_step = _Variant(
30+
id=2,
31+
name="Disputed",
32+
_available_actions={action.value for action in FactCheckingActions},
33+
actions=[action.value for action in FactCheckingActions],
34+
)
35+
unsupported_step = _Variant(
36+
id=3,
37+
name="Unsupported",
38+
_available_actions=set(),
39+
actions=[],
40+
)
41+
cant_confidently_assess_step = _Variant(
42+
id=4,
43+
name="Can't confidently assess",
44+
_available_actions=set(),
45+
actions=[],
46+
)
47+
no_factual_information_step = _Variant(
48+
id=5,
49+
name="No factual information",
50+
_available_actions=set(),
51+
actions=[],
52+
)
53+
variants = [
54+
accurate_step,
55+
inaccurate_step,
56+
disputed_step,
57+
unsupported_step,
58+
cant_confidently_assess_step,
59+
no_factual_information_step,
60+
]
61+
return _Definition(variants=variants)
62+
63+
64+
@dataclass
65+
class FactCheckingTool(_BaseStepReasoningTool):
66+
"""
67+
Use this class in OntologyBuilder to create a tool for fact checking
68+
"""
69+
70+
type: ToolType = field(default=ToolType.FACT_CHECKING, init=False)
71+
definition: _Definition = field(
72+
default_factory=build_fact_checking_definition
73+
)
74+
75+
def __post_init__(self):
76+
super().__post_init__()
77+
# Set available actions for variants 0, 1, 2 'out of band' since they are not passed in the definition
78+
self._set_variant_available_actions()
79+
80+
def _set_variant_available_actions(self):
81+
for variant in self.definition.variants:
82+
if variant.id in [0, 1, 2]:
83+
for action in FactCheckingActions:
84+
variant._available_actions.add(action.value)

0 commit comments

Comments
 (0)