Skip to content

Commit dc34cf5

Browse files
committed
Update FactCheckingTool to use _BaseStepReasoningTool
1 parent ca9ed12 commit dc34cf5

File tree

6 files changed

+88
-292
lines changed

6 files changed

+88
-292
lines changed

libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ def set_actions(self, actions: List[str]) -> None:
1717
for action in actions:
1818
if action in self._available_actions:
1919
self.actions.append(action)
20+
else:
21+
warnings.warn(
22+
f"Variant ID {self.id} {action} is an invalid action, skipping"
23+
)
2024

2125
def reset_actions(self) -> None:
2226
self.actions = []
@@ -28,6 +32,10 @@ def asdict(self) -> Dict[str, Any]:
2832
"actions": self.actions,
2933
}
3034

35+
def _post_init(self):
36+
# Call set_actions to remove any invalid actions
37+
self.set_actions(self.actions)
38+
3139

3240
@dataclass
3341
class _Definition:
Lines changed: 48 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -1,233 +1,73 @@
1-
import warnings
21
from dataclasses import dataclass, field
32
from enum import Enum
4-
from typing import Any, Dict, List, Optional
53

6-
from labelbox.schema.tool_building.tool_type import ToolType
7-
from labelbox.schema.tool_building.variant import (
8-
VariantWithActions,
4+
from labelbox.schema.tool_building.base_step_reasoning_tool import (
5+
_BaseStepReasoningTool,
6+
_Definition,
7+
_Variant,
98
)
9+
from labelbox.schema.tool_building.tool_type import ToolType
1010

1111

1212
class FactCheckingActions(Enum):
1313
WRITE_JUSTIFICATION = "justification"
1414

1515

16-
class UnsupportedStepActions(Enum):
17-
WRITE_JUSTIFICATION = "justification"
18-
19-
20-
class CanConfidentlyAssessStepActions(Enum):
21-
WRITE_JUSTIFICATION = "justification"
22-
23-
24-
class NoFactualInformationStepActions(Enum):
25-
WRITE_JUSTIFICATION = "justification"
26-
27-
28-
@dataclass
29-
class FactCheckingVariants:
30-
"""
31-
This class is used to define the possible options for fact-checking a step
32-
NOTE do not change the variants directly
33-
"""
34-
35-
accurate_step: VariantWithActions = field(
36-
default_factory=lambda: VariantWithActions(
37-
id=0,
38-
name="Accurate",
39-
_available_actions={action.value for action in FactCheckingActions},
40-
actions=[action.value for action in FactCheckingActions],
41-
)
16+
def build_fact_checking_definition():
17+
accurate_step = _Variant(
18+
id=0,
19+
name="Accurate",
20+
_available_actions={action.value for action in FactCheckingActions},
21+
actions=[action.value for action in FactCheckingActions],
4222
)
43-
44-
inaccurate_step: VariantWithActions = field(
45-
default_factory=lambda: VariantWithActions(
46-
id=1,
47-
name="Inaccurate",
48-
_available_actions={action.value for action in FactCheckingActions},
49-
actions=[action.value for action in FactCheckingActions],
50-
)
23+
inaccurate_step = _Variant(
24+
id=1,
25+
name="Inaccurate",
26+
_available_actions={action.value for action in FactCheckingActions},
27+
actions=[action.value for action in FactCheckingActions],
5128
)
52-
disputed_step: VariantWithActions = field(
53-
default_factory=lambda: VariantWithActions(
54-
id=2,
55-
name="Disputed",
56-
_available_actions={action.value for action in FactCheckingActions},
57-
actions=[action.value for action in FactCheckingActions],
58-
)
29+
disputed_step = _Variant(
30+
id=2,
31+
name="Disputed",
32+
_available_actions={action.value for action in FactCheckingActions},
33+
actions=[action.value for action in FactCheckingActions],
5934
)
60-
unsupported_step: VariantWithActions = field(
61-
default_factory=lambda: VariantWithActions(
62-
id=3,
63-
name="Unsupported",
64-
_available_actions=set(),
65-
actions=[],
66-
)
35+
unsupported_step = _Variant(
36+
id=3,
37+
name="Unsupported",
38+
_available_actions=set(),
39+
actions=[],
6740
)
68-
cant_confidently_assess_step: VariantWithActions = field(
69-
default_factory=lambda: VariantWithActions(
70-
id=4,
71-
name="Can't confidently assess",
72-
_available_actions=set(),
73-
actions=[],
74-
)
41+
cant_confidently_assess_step = _Variant(
42+
id=4,
43+
name="Can't confidently assess",
44+
_available_actions=set(),
45+
actions=[],
7546
)
76-
no_factual_information_step: VariantWithActions = field(
77-
default_factory=lambda: VariantWithActions(
78-
id=5,
79-
name="No factual information",
80-
_available_actions=set(),
81-
actions=[],
82-
)
47+
no_factual_information_step = _Variant(
48+
id=5,
49+
name="No factual information",
50+
_available_actions=set(),
51+
actions=[],
8352
)
84-
85-
def asdict(self):
86-
return [
87-
self.accurate_step.asdict(),
88-
self.inaccurate_step.asdict(),
89-
self.disputed_step.asdict(),
90-
self.unsupported_step.asdict(),
91-
self.cant_confidently_assess_step.asdict(),
92-
self.no_factual_information_step.asdict(),
93-
]
94-
95-
@classmethod
96-
def from_dict(cls, dictionary: List[Dict[str, Any]]):
97-
accurate_step = None
98-
inaccurate_step = None
99-
disputed_step = None
100-
unsupported_step = None
101-
cant_confidently_assess_step = None
102-
no_factual_information_step = None
103-
104-
for variant in dictionary:
105-
if variant["id"] == 0:
106-
accurate_step = VariantWithActions(**variant)
107-
elif variant["id"] == 1:
108-
inaccurate_step = VariantWithActions(**variant)
109-
elif variant["id"] == 2:
110-
disputed_step = VariantWithActions(**variant)
111-
elif variant["id"] == 3:
112-
unsupported_step = VariantWithActions(**variant)
113-
elif variant["id"] == 4:
114-
cant_confidently_assess_step = VariantWithActions(**variant)
115-
elif variant["id"] == 5:
116-
no_factual_information_step = VariantWithActions(**variant)
117-
118-
if not all(
119-
[
120-
accurate_step,
121-
inaccurate_step,
122-
disputed_step,
123-
unsupported_step,
124-
cant_confidently_assess_step,
125-
no_factual_information_step,
126-
]
127-
):
128-
raise ValueError("Missing variant")
129-
130-
return cls(
131-
accurate_step=accurate_step, # type: ignore
132-
inaccurate_step=inaccurate_step, # type: ignore
133-
disputed_step=disputed_step, # type: ignore
134-
unsupported_step=unsupported_step, # type: ignore
135-
cant_confidently_assess_step=cant_confidently_assess_step, # type: ignore
136-
no_factual_information_step=no_factual_information_step, # type: ignore
137-
)
53+
variants = [
54+
accurate_step,
55+
inaccurate_step,
56+
disputed_step,
57+
unsupported_step,
58+
cant_confidently_assess_step,
59+
no_factual_information_step,
60+
]
61+
return _Definition(variants=variants)
13862

13963

14064
@dataclass
141-
class FactCheckingDefinition:
142-
variants: FactCheckingVariants = field(default_factory=FactCheckingVariants)
143-
version: int = field(default=1)
144-
title: Optional[str] = None
145-
value: Optional[str] = None
146-
147-
def __post_init__(self):
148-
if self.version != 1:
149-
raise ValueError("Invalid version")
150-
151-
def asdict(self) -> Dict[str, Any]:
152-
result = {"variants": self.variants.asdict(), "version": self.version}
153-
if self.title is not None:
154-
result["title"] = self.title
155-
if self.value is not None:
156-
result["value"] = self.value
157-
return result
158-
159-
@classmethod
160-
def from_dict(cls, dictionary: Dict[str, Any]) -> "FactCheckingDefinition":
161-
variants = FactCheckingVariants.from_dict(dictionary["variants"])
162-
title = dictionary.get("title", None)
163-
value = dictionary.get("value", None)
164-
return cls(variants=variants, title=title, value=value)
165-
166-
167-
@dataclass
168-
class FactCheckingTool:
65+
class FactCheckingTool(_BaseStepReasoningTool):
16966
"""
17067
Use this class in OntologyBuilder to create a tool for fact checking
17168
"""
17269

173-
name: str
17470
type: ToolType = field(default=ToolType.FACT_CHECKING, init=False)
175-
required: bool = False
176-
schema_id: Optional[str] = None
177-
feature_schema_id: Optional[str] = None
178-
color: Optional[str] = None
179-
definition: FactCheckingDefinition = field(
180-
default_factory=FactCheckingDefinition
71+
definition: _Definition = field(
72+
default_factory=build_fact_checking_definition
18173
)
182-
183-
def __post_init__(self):
184-
warnings.warn(
185-
"This feature is experimental and subject to change.",
186-
)
187-
188-
if self.name.strip() == "":
189-
raise ValueError("Name cannot be empty")
190-
191-
def set_unsupported_step_actions(
192-
self, actions: List[UnsupportedStepActions]
193-
) -> None:
194-
actions_values = [action.value for action in actions]
195-
self.definition.variants.unsupported_step.set_actions(actions_values)
196-
197-
def set_cant_confidently_assess_step_actions(
198-
self, actions: List[CanConfidentlyAssessStepActions]
199-
) -> None:
200-
actions_values = [action.value for action in actions]
201-
self.definition.variants.cant_confidently_assess_step.set_actions(
202-
actions_values
203-
)
204-
205-
def set_no_factual_information_step_actions(
206-
self, actions: List[NoFactualInformationStepActions]
207-
) -> None:
208-
actions_values = [action.value for action in actions]
209-
self.definition.variants.no_factual_information_step.set_actions(
210-
actions_values
211-
)
212-
213-
def asdict(self) -> Dict[str, Any]:
214-
return {
215-
"tool": self.type.value,
216-
"name": self.name,
217-
"required": self.required,
218-
"schemaNodeId": self.schema_id,
219-
"featureSchemaId": self.feature_schema_id,
220-
"definition": self.definition.asdict(),
221-
}
222-
223-
@classmethod
224-
def from_dict(cls, dictionary: Dict[str, Any]) -> "FactCheckingTool":
225-
return cls(
226-
name=dictionary["name"],
227-
schema_id=dictionary.get("schemaNodeId", None),
228-
feature_schema_id=dictionary.get("featureSchemaId", None),
229-
required=dictionary.get("required", False),
230-
definition=FactCheckingDefinition.from_dict(
231-
dictionary["definition"]
232-
),
233-
)

libs/labelbox/tests/integration/test_ontology.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -422,19 +422,24 @@ def test_fact_checking_ontology(chat_evaluation_ontology):
422422
break
423423
assert fact_checking is not None
424424

425-
assert fact_checking.definition.variants.asdict() == [
426-
{"id": 0, "name": "Accurate", "actions": ["justification"]},
427-
{"id": 1, "name": "Inaccurate", "actions": ["justification"]},
428-
{"id": 2, "name": "Disputed", "actions": ["justification"]},
429-
{"id": 3, "name": "Unsupported", "actions": []},
430-
{
431-
"id": 4,
432-
"name": "Can't confidently assess",
433-
"actions": [],
434-
},
435-
{
436-
"id": 5,
437-
"name": "No factual information",
438-
"actions": [],
439-
},
440-
]
425+
assert fact_checking.definition.asdict() == {
426+
"title": "fact checking",
427+
"value": "fact_checking",
428+
"variants": [
429+
{"id": 0, "name": "Accurate", "actions": ["justification"]},
430+
{"id": 1, "name": "Inaccurate", "actions": ["justification"]},
431+
{"id": 2, "name": "Disputed", "actions": ["justification"]},
432+
{"id": 3, "name": "Unsupported", "actions": []},
433+
{
434+
"id": 4,
435+
"name": "Can't confidently assess",
436+
"actions": [],
437+
},
438+
{
439+
"id": 5,
440+
"name": "No factual information",
441+
"actions": [],
442+
},
443+
],
444+
"version": 1,
445+
}

libs/labelbox/tests/unit/test_unit_fact_checking_tool.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,50 +14,7 @@ def test_fact_checking_as_dict_default():
1414
"required": False,
1515
"schemaNodeId": None,
1616
"featureSchemaId": None,
17-
"definition": {
18-
"variants": [
19-
{"id": 0, "name": "Accurate", "actions": ["justification"]},
20-
{"id": 1, "name": "Inaccurate", "actions": ["justification"]},
21-
{"id": 2, "name": "Disputed", "actions": ["justification"]},
22-
{
23-
"id": 3,
24-
"name": "Unsupported",
25-
"actions": [],
26-
},
27-
{
28-
"id": 4,
29-
"name": "Can't confidently assess",
30-
"actions": [],
31-
},
32-
{
33-
"id": 5,
34-
"name": "No factual information",
35-
"actions": [],
36-
},
37-
],
38-
"version": 1,
39-
},
40-
}
41-
42-
assert tool_dict == expected_dict
43-
44-
45-
def test_step_reasoning_as_dict_with_actions():
46-
tool = FactCheckingTool(name="Fact Checking Tool")
47-
tool.set_unsupported_step_actions([])
48-
tool.set_cant_confidently_assess_step_actions([])
49-
tool.set_no_factual_information_step_actions([])
50-
51-
# Get the dictionary representation
52-
tool_dict = tool.asdict()
53-
54-
# Expected dictionary structure
55-
expected_dict = {
56-
"tool": "fact-checking",
57-
"name": "Fact Checking Tool",
58-
"required": False,
59-
"schemaNodeId": None,
60-
"featureSchemaId": None,
17+
"color": None,
6118
"definition": {
6219
"variants": [
6320
{"id": 0, "name": "Accurate", "actions": ["justification"]},

0 commit comments

Comments
 (0)