1
- import warnings
2
1
from dataclasses import dataclass , field
3
2
from enum import Enum
4
- from typing import Any , Dict , List , Optional
5
3
4
+ from labelbox .schema .tool_building .base_step_reasoning_tool import (
5
+ _BaseStepReasoningTool ,
6
+ _Definition ,
7
+ _Variant ,
8
+ )
6
9
from labelbox .schema .tool_building .tool_type import ToolType
7
- from labelbox .schema .tool_building .variant import VariantWithActions
8
10
9
11
10
12
class IncorrectStepActions (Enum ):
@@ -14,156 +16,29 @@ class IncorrectStepActions(Enum):
14
16
JUSTIFICATION = "justification"
15
17
16
18
17
- @dataclass
18
- class StepReasoningVariants :
19
- """
20
- This class is used to define the possible options for evaluating a step
21
- Currently the options are correct, neutral, and incorrect
22
- NOTE: do not change the variant values
23
- """
24
-
25
- correct_step : VariantWithActions = field (
26
- default_factory = lambda : VariantWithActions (
27
- id = 0 ,
28
- name = "Correct" ,
29
- actions = [],
30
- )
31
- )
32
- neutral_step : VariantWithActions = field (
33
- default_factory = lambda : VariantWithActions (
34
- id = 1 ,
35
- name = "Neutral" ,
36
- actions = [],
37
- )
38
- )
39
-
40
- incorrect_step : VariantWithActions = field (
41
- default_factory = lambda : VariantWithActions (
42
- id = 2 ,
43
- name = "Incorrect" ,
44
- _available_actions = {
45
- action .value for action in IncorrectStepActions
46
- },
47
- actions = [
48
- action .value for action in IncorrectStepActions
49
- ], # initialize to all IncorrectStepActions by default
50
- )
51
- )
52
-
53
- def asdict (self ):
54
- return [
55
- self .correct_step .asdict (),
56
- self .neutral_step .asdict (),
57
- self .incorrect_step .asdict (),
58
- ]
59
-
60
- @classmethod
61
- def from_dict (cls , dictionary : List [Dict [str , Any ]]):
62
- correct_step = None
63
- neutral_step = None
64
- incorrect_step = None
65
-
66
- for variant in dictionary :
67
- if variant ["id" ] == 0 :
68
- correct_step = VariantWithActions (** variant )
69
- elif variant ["id" ] == 1 :
70
- neutral_step = VariantWithActions (** variant )
71
- elif variant ["id" ] == 2 :
72
- incorrect_step = VariantWithActions (** variant )
73
-
74
- if not all ([correct_step , neutral_step , incorrect_step ]):
75
- raise ValueError ("Invalid step reasoning variants" )
76
-
77
- return cls (
78
- correct_step = correct_step , # type: ignore
79
- neutral_step = neutral_step , # type: ignore
80
- incorrect_step = incorrect_step , # type: ignore
81
- )
82
-
83
-
84
- @dataclass
85
- class StepReasoningDefinition :
86
- variants : StepReasoningVariants = field (
87
- default_factory = StepReasoningVariants
19
+ def build_step_reasoning_definition ():
20
+ correct_step = _Variant (id = 0 , name = "Correct" , actions = [])
21
+ neutral_step = _Variant (id = 1 , name = "Neutral" , actions = [])
22
+ incorrect_step = _Variant (
23
+ id = 2 ,
24
+ name = "Incorrect" ,
25
+ _available_actions = {action .value for action in IncorrectStepActions },
26
+ actions = [action .value for action in IncorrectStepActions ],
88
27
)
89
- version : int = field (default = 1 )
90
- title : Optional [str ] = None
91
- value : Optional [str ] = None
92
-
93
- def __post_init__ (self ):
94
- if self .version != 1 :
95
- raise ValueError ("Invalid version" )
96
-
97
- def asdict (self ) -> Dict [str , Any ]:
98
- result = {"variants" : self .variants .asdict (), "version" : self .version }
99
- if self .title is not None :
100
- result ["title" ] = self .title
101
- if self .value is not None :
102
- result ["value" ] = self .value
103
- return result
104
-
105
- @classmethod
106
- def from_dict (cls , dictionary : Dict [str , Any ]) -> "StepReasoningDefinition" :
107
- variants = StepReasoningVariants .from_dict (dictionary ["variants" ])
108
- title = dictionary .get ("title" , None )
109
- value = dictionary .get ("value" , None )
110
- return cls (variants = variants , title = title , value = value )
28
+ variants = [correct_step , neutral_step , incorrect_step ]
29
+ return _Definition (variants = variants )
111
30
112
31
113
32
@dataclass
114
- class StepReasoningTool :
33
+ class StepReasoningTool ( _BaseStepReasoningTool ) :
115
34
"""
116
35
Use this class in OntologyBuilder to create a tool for step reasoning
117
36
The definition field lists the possible options to evaulate a step
118
37
119
38
NOTE: color attribute is for backward compatibility only and should not be set directly
120
39
"""
121
40
122
- name : str
123
41
type : ToolType = field (default = ToolType .STEP_REASONING , init = False )
124
- required : bool = False
125
- schema_id : Optional [str ] = None
126
- feature_schema_id : Optional [str ] = None
127
- color : Optional [str ] = None
128
- definition : StepReasoningDefinition = field (
129
- default_factory = StepReasoningDefinition
42
+ definition : _Definition = field (
43
+ default_factory = build_step_reasoning_definition
130
44
)
131
-
132
- def __post_init__ (self ):
133
- warnings .warn (
134
- "This feature is experimental and subject to change." ,
135
- )
136
-
137
- if not self .name .strip ():
138
- raise ValueError ("Name is required" )
139
-
140
- def set_incorrect_step_actions (self , actions : List [IncorrectStepActions ]):
141
- """
142
- For live models, will invoke the model to generate alternatives if a step is marked as incorrect
143
- NOTE by default all actions are set to True
144
- Pass empty list to reset to false
145
- """
146
- actions_values = [action .value for action in actions ]
147
- self .definition .variants .incorrect_step .set_actions (actions_values )
148
-
149
- def asdict (self ) -> Dict [str , Any ]:
150
- return {
151
- "tool" : self .type .value ,
152
- "name" : self .name ,
153
- "required" : self .required ,
154
- "schemaNodeId" : self .schema_id ,
155
- "featureSchemaId" : self .feature_schema_id ,
156
- "definition" : self .definition .asdict (),
157
- }
158
-
159
- @classmethod
160
- def from_dict (cls , dictionary : Dict [str , Any ]) -> "StepReasoningTool" :
161
- return cls (
162
- name = dictionary ["name" ],
163
- schema_id = dictionary .get ("schemaNodeId" , None ),
164
- feature_schema_id = dictionary .get ("featureSchemaId" , None ),
165
- required = dictionary .get ("required" , False ),
166
- definition = StepReasoningDefinition .from_dict (
167
- dictionary ["definition" ]
168
- ),
169
- )
0 commit comments