Skip to content

Commit 2fce6e3

Browse files
author
Val Brodsky
committed
Refactor StepReasoning to also reuse Variants
1 parent a728135 commit 2fce6e3

File tree

5 files changed

+90
-123
lines changed

5 files changed

+90
-123
lines changed

libs/labelbox/src/labelbox/schema/tool_building/fact_checking_tool.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import warnings
12
from dataclasses import dataclass, field
2-
from typing import Any, Dict, List, Optional, Set
3+
from enum import Enum
4+
from typing import Any, Dict, List, Optional
35

46
from labelbox.schema.tool_building.tool_type import ToolType
57
from labelbox.schema.tool_building.variant import (
@@ -8,6 +10,18 @@
810
)
911

1012

13+
class UnsupportedStepActions(Enum):
14+
WRITE_JUSTIFICATION = "writeJustification"
15+
16+
17+
class CanConfidentlyAssessStepActions(Enum):
18+
WRITE_JUSTIFICATION = "writeJustification"
19+
20+
21+
class NoFactualInformationStepActions(Enum):
22+
WRITE_JUSTIFICATION = "writeJustification"
23+
24+
1125
@dataclass
1226
class FactCheckingVariants:
1327
"""
@@ -26,21 +40,32 @@ class FactCheckingVariants:
2640
)
2741
unsupported_step: VariantWithActions = field(
2842
default_factory=lambda: VariantWithActions(
29-
id=3, name="Unsupported", _available_actions={"writeJustification"}
43+
id=3,
44+
name="Unsupported",
45+
_available_actions={
46+
action.value for action in UnsupportedStepActions
47+
},
48+
actions=[UnsupportedStepActions.WRITE_JUSTIFICATION.value],
3049
)
3150
)
3251
cant_confidently_assess_step: VariantWithActions = field(
3352
default_factory=lambda: VariantWithActions(
3453
id=4,
3554
name="Can't confidently assess",
36-
_available_actions={"writeJustification"},
55+
_available_actions={
56+
action.value for action in CanConfidentlyAssessStepActions
57+
},
58+
actions=[CanConfidentlyAssessStepActions.WRITE_JUSTIFICATION.value],
3759
)
3860
)
3961
no_factual_information_step: VariantWithActions = field(
4062
default_factory=lambda: VariantWithActions(
4163
id=5,
4264
name="No factual information",
43-
_available_actions={"writeJustification"},
65+
_available_actions={
66+
action.value for action in NoFactualInformationStepActions
67+
},
68+
actions=[NoFactualInformationStepActions.WRITE_JUSTIFICATION.value],
4469
)
4570
)
4671

@@ -138,23 +163,31 @@ class FactCheckingTool:
138163
default_factory=FactCheckingDefinition
139164
)
140165

166+
def __post_init__(self):
167+
warnings.warn(
168+
"This feature is experimental and subject to change.",
169+
)
170+
141171
def set_unsupported_step_actions(
142-
self, actions: Set[str] = {"writeJustification"}
172+
self, actions: List[UnsupportedStepActions]
143173
) -> None:
144-
self.definition.variants.unsupported_step.set_actions(actions)
174+
actions_values = [action.value for action in actions]
175+
self.definition.variants.unsupported_step.set_actions(actions_values)
145176

146177
def set_cant_confidently_assess_step_actions(
147-
self, actions: Set[str] = {"writeJustification"}
178+
self, actions: List[CanConfidentlyAssessStepActions]
148179
) -> None:
180+
actions_values = [action.value for action in actions]
149181
self.definition.variants.cant_confidently_assess_step.set_actions(
150-
actions
182+
actions_values
151183
)
152184

153185
def set_no_factual_information_step_actions(
154-
self, actions: Set[str] = {"writeJustification"}
186+
self, actions: List[NoFactualInformationStepActions]
155187
) -> None:
188+
actions_values = [action.value for action in actions]
156189
self.definition.variants.no_factual_information_step.set_actions(
157-
actions
190+
actions_values
158191
)
159192

160193
def asdict(self) -> Dict[str, Any]:

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 29 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,15 @@
11
import warnings
22
from dataclasses import dataclass, field
3+
from enum import Enum
34
from typing import Any, Dict, List, Optional
45

56
from labelbox.schema.tool_building.tool_type import ToolType
7+
from labelbox.schema.tool_building.variant import Variant, VariantWithActions
68

79

8-
@dataclass
9-
class StepReasoningVariant:
10-
id: int
11-
name: str
12-
actions: List[str] = field(default_factory=list)
13-
14-
def asdict(self) -> Dict[str, Any]:
15-
return {"id": self.id, "name": self.name, "actions": self.actions}
16-
17-
18-
@dataclass
19-
class IncorrectStepReasoningVariant:
20-
id: int
21-
name: str
22-
regenerate_steps: Optional[bool] = True
23-
generate_and_rate_alternative_steps: Optional[bool] = True
24-
rewrite_step: Optional[bool] = True
25-
justification: Optional[bool] = True
26-
27-
def asdict(self) -> Dict[str, Any]:
28-
actions = []
29-
if self.regenerate_steps:
30-
actions.append("regenerateSteps")
31-
if self.generate_and_rate_alternative_steps:
32-
actions.append("generateAndRateAlternativeSteps")
33-
if self.rewrite_step:
34-
actions.append("rewriteStep")
35-
if self.justification:
36-
actions.append("justification")
37-
return {"id": self.id, "name": self.name, "actions": actions}
38-
39-
@classmethod
40-
def from_dict(
41-
cls, dictionary: Dict[str, Any]
42-
) -> "IncorrectStepReasoningVariant":
43-
return cls(
44-
id=dictionary["id"],
45-
name=dictionary["name"],
46-
regenerate_steps="regenerateSteps" in dictionary.get("actions", []),
47-
generate_and_rate_alternative_steps="generateAndRateAlternativeSteps"
48-
in dictionary.get("actions", []),
49-
rewrite_step="rewriteStep" in dictionary.get("actions", []),
50-
justification="justification" in dictionary.get("actions", []),
51-
)
52-
53-
54-
def _create_correct_step() -> StepReasoningVariant:
55-
return StepReasoningVariant(
56-
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
57-
)
58-
59-
60-
def _create_neutral_step() -> StepReasoningVariant:
61-
return StepReasoningVariant(
62-
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
63-
)
64-
65-
66-
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
67-
return IncorrectStepReasoningVariant(
68-
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
69-
)
10+
class IncorrectStepActions(Enum):
11+
REGENERATE_STEPS = "regenerateSteps"
12+
GENERATE_AND_RATE_ALTERNATIVE_STEPS = "generateAndRateAlternativeSteps"
7013

7114

7215
@dataclass
@@ -76,18 +19,22 @@ class StepReasoningVariants:
7619
Currently the options are correct, neutral, and incorrect
7720
"""
7821

79-
CORRECT_STEP_ID = 0
80-
NEUTRAL_STEP_ID = 1
81-
INCORRECT_STEP_ID = 2
82-
83-
correct_step: StepReasoningVariant = field(
84-
default_factory=_create_correct_step
22+
correct_step: Variant = field(
23+
default_factory=lambda: Variant(id=0, name="Correct")
8524
)
86-
neutral_step: StepReasoningVariant = field(
87-
default_factory=_create_neutral_step
25+
neutral_step: Variant = field(
26+
default_factory=lambda: Variant(id=1, name="Neutral")
8827
)
89-
incorrect_step: IncorrectStepReasoningVariant = field(
90-
default_factory=_create_incorrect_step
28+
29+
incorrect_step: VariantWithActions = field(
30+
default_factory=lambda: VariantWithActions(
31+
id=2,
32+
name="Incorrect",
33+
_available_actions={
34+
action.value for action in IncorrectStepActions
35+
},
36+
actions=["regenerateSteps"], # regenerateSteps is on by default
37+
)
9138
)
9239

9340
def asdict(self):
@@ -104,14 +51,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
10451
incorrect_step = None
10552

10653
for variant in dictionary:
107-
if variant["id"] == cls.CORRECT_STEP_ID:
108-
correct_step = StepReasoningVariant(**variant)
109-
elif variant["id"] == cls.NEUTRAL_STEP_ID:
110-
neutral_step = StepReasoningVariant(**variant)
111-
elif variant["id"] == cls.INCORRECT_STEP_ID:
112-
incorrect_step = IncorrectStepReasoningVariant.from_dict(
113-
variant
114-
)
54+
if variant["id"] == 0:
55+
correct_step = Variant(**variant)
56+
elif variant["id"] == 1:
57+
neutral_step = Variant(**variant)
58+
elif variant["id"] == 2:
59+
incorrect_step = VariantWithActions(**variant)
11560

11661
if not all([correct_step, neutral_step, incorrect_step]):
11762
raise ValueError("Invalid step reasoning variants")
@@ -170,30 +115,12 @@ def __post_init__(self):
170115
"This feature is experimental and subject to change.",
171116
)
172117

173-
def reset_regenerate_steps(self):
174-
"""
175-
For live models, the default action will invoke the model to generate alternatives if a step is marked as incorrect
176-
This method will reset the action to not regenerate the conversation
177-
"""
178-
self.definition.variants.incorrect_step.regenerate_steps = False
179-
180-
def reset_generate_and_rate_alternative_steps(self):
181-
"""
182-
For live models, will require labelers to rate the alternatives generated by the model
183-
"""
184-
self.definition.variants.incorrect_step.generate_and_rate_alternative_steps = False
185-
186-
def reset_rewrite_step(self):
187-
"""
188-
For live models, will require labelers to rewrite the conversation
189-
"""
190-
self.definition.variants.incorrect_step.rewrite_step = False
191-
192-
def reset_justification(self):
118+
def set_incorrect_step_actions(self, actions: List[IncorrectStepActions]):
193119
"""
194-
For live models, will require labelers to provide a justification for their evaluation
120+
For live models, will invoke the model to generate alternatives if a step is marked as incorrect
195121
"""
196-
self.definition.variants.incorrect_step.justification = False
122+
actions_values = [action.value for action in actions]
123+
self.definition.variants.incorrect_step.set_actions(actions_values)
197124

198125
def asdict(self) -> Dict[str, Any]:
199126
return {

libs/labelbox/src/labelbox/schema/tool_building/tool_type_mapping.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55

66
def map_tool_type_to_tool_cls(tool_type_str: str):
77
if not ToolType.valid(tool_type_str):
8-
raise ValueError(f"Invalid tool type {tool_type_str}")
8+
return None
99

1010
tool_type = ToolType(tool_type_str.lower())
1111
if tool_type == ToolType.STEP_REASONING:
1212
return StepReasoningTool
1313
elif tool_type == ToolType.FACT_CHECKING:
1414
return FactCheckingTool
15-
16-
return None

libs/labelbox/src/labelbox/schema/tool_building/variant.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class VariantWithActions:
2222
actions: List[str] = field(default_factory=list)
2323
_available_actions: Set[str] = field(default_factory=set)
2424

25-
def set_actions(self, actions: Set[str]) -> None:
25+
def set_actions(self, actions: List[str]) -> None:
26+
self.actions = []
2627
for action in actions:
2728
if action in self._available_actions:
2829
self.actions.append(action)
@@ -31,8 +32,11 @@ def reset_actions(self) -> None:
3132
self.actions = []
3233

3334
def asdict(self) -> Dict[str, Any]:
34-
return {
35+
data = {
3536
"id": self.id,
3637
"name": self.name,
37-
"actions": list(set(self.actions)),
3838
}
39+
if len(self.actions) > 0:
40+
data["actions"] = self.actions
41+
42+
return data

libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
1+
from labelbox.schema.tool_building.step_reasoning_tool import (
2+
IncorrectStepActions,
3+
StepReasoningTool,
4+
)
25

36

47
def test_step_reasoning_as_dict_default():
@@ -31,10 +34,12 @@ def test_step_reasoning_as_dict_default():
3134

3235
def test_step_reasoning_as_dict_with_actions():
3336
tool = StepReasoningTool(name="step reasoning")
34-
tool.reset_generate_and_rate_alternative_steps()
35-
tool.reset_regenerate_steps()
36-
tool.reset_rewrite_step()
37-
tool.reset_justification()
37+
tool.set_incorrect_step_actions(
38+
[
39+
IncorrectStepActions.REGENERATE_STEPS,
40+
IncorrectStepActions.GENERATE_AND_RATE_ALTERNATIVE_STEPS,
41+
]
42+
)
3843
assert tool.asdict() == {
3944
"tool": "step-reasoning",
4045
"name": "step reasoning",

0 commit comments

Comments
 (0)