Skip to content

Commit ca9ed12

Browse files
committed
Update StepReasoningTool to use _BaseStepReasoningTool
1 parent 647220f commit ca9ed12

File tree

5 files changed

+149
-168
lines changed

5 files changed

+149
-168
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import warnings
2+
from dataclasses import dataclass, field
3+
from typing import Any, Dict, List, Optional, Set
4+
5+
from labelbox.schema.tool_building.tool_type import ToolType
6+
7+
8+
@dataclass
9+
class _Variant:
10+
id: int
11+
name: str
12+
actions: List[str] = field(default_factory=list)
13+
_available_actions: Set[str] = field(default_factory=set)
14+
15+
def set_actions(self, actions: List[str]) -> None:
16+
self.actions = []
17+
for action in actions:
18+
if action in self._available_actions:
19+
self.actions.append(action)
20+
21+
def reset_actions(self) -> None:
22+
self.actions = []
23+
24+
def asdict(self) -> Dict[str, Any]:
25+
return {
26+
"id": self.id,
27+
"name": self.name,
28+
"actions": self.actions,
29+
}
30+
31+
32+
@dataclass
33+
class _Definition:
34+
variants: List[_Variant]
35+
version: int = field(default=1)
36+
title: Optional[str] = None
37+
value: Optional[str] = None
38+
39+
def __post_init__(self):
40+
if self.version != 1:
41+
raise ValueError("Invalid version")
42+
43+
def asdict(self) -> Dict[str, Any]:
44+
result = {
45+
"variants": [variant.asdict() for variant in self.variants],
46+
"version": self.version,
47+
}
48+
if self.title is not None:
49+
result["title"] = self.title
50+
if self.value is not None:
51+
result["value"] = self.value
52+
return result
53+
54+
@classmethod
55+
def from_dict(cls, dictionary: Dict[str, Any]) -> "_Definition":
56+
variants = [_Variant(**variant) for variant in dictionary["variants"]]
57+
title = dictionary.get("title", None)
58+
value = dictionary.get("value", None)
59+
return cls(variants=variants, title=title, value=value)
60+
61+
62+
@dataclass
63+
class _BaseStepReasoningTool:
64+
name: str
65+
type: ToolType
66+
definition: _Definition
67+
schema_id: Optional[str] = None
68+
feature_schema_id: Optional[str] = None
69+
color: Optional[str] = None
70+
required: bool = False
71+
72+
def __post_init__(self):
73+
warnings.warn(
74+
"This feature is experimental and subject to change.",
75+
)
76+
77+
if not self.name.strip():
78+
raise ValueError("Name is required")
79+
80+
def asdict(self) -> Dict[str, Any]:
81+
return {
82+
"tool": self.type.value,
83+
"name": self.name,
84+
"required": self.required,
85+
"schemaNodeId": self.schema_id,
86+
"featureSchemaId": self.feature_schema_id,
87+
"definition": self.definition.asdict(),
88+
"color": self.color,
89+
}
90+
91+
@classmethod
92+
def from_dict(cls, dictionary: Dict[str, Any]) -> "_BaseStepReasoningTool":
93+
return cls(
94+
name=dictionary["name"],
95+
schema_id=dictionary.get("schemaNodeId", None),
96+
feature_schema_id=dictionary.get("featureSchemaId", None),
97+
required=dictionary.get("required", False),
98+
definition=_Definition.from_dict(dictionary["definition"]),
99+
color=dictionary.get("color", None),
100+
)
Lines changed: 18 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
import warnings
21
from dataclasses import dataclass, field
32
from enum import Enum
4-
from typing import Any, Dict, List, Optional
53

4+
from labelbox.schema.tool_building.base_step_reasoning_tool import (
5+
_BaseStepReasoningTool,
6+
_Definition,
7+
_Variant,
8+
)
69
from labelbox.schema.tool_building.tool_type import ToolType
7-
from labelbox.schema.tool_building.variant import VariantWithActions
810

911

1012
class IncorrectStepActions(Enum):
@@ -14,156 +16,29 @@ class IncorrectStepActions(Enum):
1416
JUSTIFICATION = "justification"
1517

1618

17-
@dataclass
18-
class StepReasoningVariants:
19-
"""
20-
This class is used to define the possible options for evaluating a step
21-
Currently the options are correct, neutral, and incorrect
22-
NOTE: do not change the variant values
23-
"""
24-
25-
correct_step: VariantWithActions = field(
26-
default_factory=lambda: VariantWithActions(
27-
id=0,
28-
name="Correct",
29-
actions=[],
30-
)
31-
)
32-
neutral_step: VariantWithActions = field(
33-
default_factory=lambda: VariantWithActions(
34-
id=1,
35-
name="Neutral",
36-
actions=[],
37-
)
38-
)
39-
40-
incorrect_step: VariantWithActions = field(
41-
default_factory=lambda: VariantWithActions(
42-
id=2,
43-
name="Incorrect",
44-
_available_actions={
45-
action.value for action in IncorrectStepActions
46-
},
47-
actions=[
48-
action.value for action in IncorrectStepActions
49-
], # initialize to all IncorrectStepActions by default
50-
)
51-
)
52-
53-
def asdict(self):
54-
return [
55-
self.correct_step.asdict(),
56-
self.neutral_step.asdict(),
57-
self.incorrect_step.asdict(),
58-
]
59-
60-
@classmethod
61-
def from_dict(cls, dictionary: List[Dict[str, Any]]):
62-
correct_step = None
63-
neutral_step = None
64-
incorrect_step = None
65-
66-
for variant in dictionary:
67-
if variant["id"] == 0:
68-
correct_step = VariantWithActions(**variant)
69-
elif variant["id"] == 1:
70-
neutral_step = VariantWithActions(**variant)
71-
elif variant["id"] == 2:
72-
incorrect_step = VariantWithActions(**variant)
73-
74-
if not all([correct_step, neutral_step, incorrect_step]):
75-
raise ValueError("Invalid step reasoning variants")
76-
77-
return cls(
78-
correct_step=correct_step, # type: ignore
79-
neutral_step=neutral_step, # type: ignore
80-
incorrect_step=incorrect_step, # type: ignore
81-
)
82-
83-
84-
@dataclass
85-
class StepReasoningDefinition:
86-
variants: StepReasoningVariants = field(
87-
default_factory=StepReasoningVariants
19+
def build_step_reasoning_definition():
20+
correct_step = _Variant(id=0, name="Correct", actions=[])
21+
neutral_step = _Variant(id=1, name="Neutral", actions=[])
22+
incorrect_step = _Variant(
23+
id=2,
24+
name="Incorrect",
25+
_available_actions={action.value for action in IncorrectStepActions},
26+
actions=[action.value for action in IncorrectStepActions],
8827
)
89-
version: int = field(default=1)
90-
title: Optional[str] = None
91-
value: Optional[str] = None
92-
93-
def __post_init__(self):
94-
if self.version != 1:
95-
raise ValueError("Invalid version")
96-
97-
def asdict(self) -> Dict[str, Any]:
98-
result = {"variants": self.variants.asdict(), "version": self.version}
99-
if self.title is not None:
100-
result["title"] = self.title
101-
if self.value is not None:
102-
result["value"] = self.value
103-
return result
104-
105-
@classmethod
106-
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
107-
variants = StepReasoningVariants.from_dict(dictionary["variants"])
108-
title = dictionary.get("title", None)
109-
value = dictionary.get("value", None)
110-
return cls(variants=variants, title=title, value=value)
28+
variants = [correct_step, neutral_step, incorrect_step]
29+
return _Definition(variants=variants)
11130

11231

11332
@dataclass
114-
class StepReasoningTool:
33+
class StepReasoningTool(_BaseStepReasoningTool):
11534
"""
11635
Use this class in OntologyBuilder to create a tool for step reasoning
11736
The definition field lists the possible options to evaulate a step
11837
11938
NOTE: color attribute is for backward compatibility only and should not be set directly
12039
"""
12140

122-
name: str
12341
type: ToolType = field(default=ToolType.STEP_REASONING, init=False)
124-
required: bool = False
125-
schema_id: Optional[str] = None
126-
feature_schema_id: Optional[str] = None
127-
color: Optional[str] = None
128-
definition: StepReasoningDefinition = field(
129-
default_factory=StepReasoningDefinition
42+
definition: _Definition = field(
43+
default_factory=build_step_reasoning_definition
13044
)
131-
132-
def __post_init__(self):
133-
warnings.warn(
134-
"This feature is experimental and subject to change.",
135-
)
136-
137-
if not self.name.strip():
138-
raise ValueError("Name is required")
139-
140-
def set_incorrect_step_actions(self, actions: List[IncorrectStepActions]):
141-
"""
142-
For live models, will invoke the model to generate alternatives if a step is marked as incorrect
143-
NOTE by default all actions are set to True
144-
Pass empty list to reset to false
145-
"""
146-
actions_values = [action.value for action in actions]
147-
self.definition.variants.incorrect_step.set_actions(actions_values)
148-
149-
def asdict(self) -> Dict[str, Any]:
150-
return {
151-
"tool": self.type.value,
152-
"name": self.name,
153-
"required": self.required,
154-
"schemaNodeId": self.schema_id,
155-
"featureSchemaId": self.feature_schema_id,
156-
"definition": self.definition.asdict(),
157-
}
158-
159-
@classmethod
160-
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
161-
return cls(
162-
name=dictionary["name"],
163-
schema_id=dictionary.get("schemaNodeId", None),
164-
feature_schema_id=dictionary.get("featureSchemaId", None),
165-
required=dictionary.get("required", False),
166-
definition=StepReasoningDefinition.from_dict(
167-
dictionary["definition"]
168-
),
169-
)

libs/labelbox/tests/integration/test_ontology.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -358,28 +358,33 @@ def test_step_reasoning_ontology(chat_evaluation_ontology):
358358
step_reasoning_tool = tool
359359
break
360360
assert step_reasoning_tool is not None
361-
assert step_reasoning_tool.definition.variants.asdict() == [
362-
{
363-
"id": 0,
364-
"name": "Correct",
365-
"actions": [],
366-
},
367-
{
368-
"id": 1,
369-
"name": "Neutral",
370-
"actions": [],
371-
},
372-
{
373-
"id": 2,
374-
"name": "Incorrect",
375-
"actions": [
376-
"regenerateSteps",
377-
"generateAndRateAlternativeSteps",
378-
"rewriteStep",
379-
"justification",
380-
],
381-
},
382-
]
361+
assert step_reasoning_tool.definition.asdict() == {
362+
"title": "step reasoning",
363+
"value": "step_reasoning",
364+
"variants": [
365+
{
366+
"id": 0,
367+
"name": "Correct",
368+
"actions": [],
369+
},
370+
{
371+
"id": 1,
372+
"name": "Neutral",
373+
"actions": [],
374+
},
375+
{
376+
"id": 2,
377+
"name": "Incorrect",
378+
"actions": [
379+
"regenerateSteps",
380+
"generateAndRateAlternativeSteps",
381+
"rewriteStep",
382+
"justification",
383+
],
384+
},
385+
],
386+
"version": 1,
387+
}
383388

384389

385390
def test_fact_checking_ontology(chat_evaluation_ontology):

libs/labelbox/tests/unit/test_unit_step_ontology_variants.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from labelbox.schema.tool_building.variant import VariantWithActions
1+
from labelbox.schema.tool_building.base_step_reasoning_tool import _Variant
22

33

4-
def test_variant_with_actions_as_dict():
5-
variant = VariantWithActions(
4+
def test_variant():
5+
variant = _Variant(
66
id=0, name="Correct", _available_actions={"regenerateSteps"}
77
)
88
variant.set_actions(["regenerateSteps"])

libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_step_reasoning_as_dict_default():
1818
"required": False,
1919
"schemaNodeId": None,
2020
"featureSchemaId": None,
21+
"color": None,
2122
"definition": {
2223
"variants": [
2324
{"id": 0, "name": "Correct", "actions": []},

0 commit comments

Comments
 (0)