1
1
import warnings
2
2
from dataclasses import dataclass , field
3
+ from enum import Enum
3
4
from typing import Any , Dict , List , Optional
4
5
5
6
from labelbox .schema .tool_building .tool_type import ToolType
7
+ from labelbox .schema .tool_building .variant import Variant , VariantWithActions
6
8
7
9
8
- @dataclass
9
- class StepReasoningVariant :
10
- id : int
11
- name : str
12
- actions : List [str ] = field (default_factory = list )
13
-
14
- def asdict (self ) -> Dict [str , Any ]:
15
- return {"id" : self .id , "name" : self .name , "actions" : self .actions }
16
-
17
-
18
- @dataclass
19
- class IncorrectStepReasoningVariant :
20
- id : int
21
- name : str
22
- regenerate_steps : Optional [bool ] = True
23
- generate_and_rate_alternative_steps : Optional [bool ] = True
24
- rewrite_step : Optional [bool ] = True
25
- justification : Optional [bool ] = True
26
-
27
- def asdict (self ) -> Dict [str , Any ]:
28
- actions = []
29
- if self .regenerate_steps :
30
- actions .append ("regenerateSteps" )
31
- if self .generate_and_rate_alternative_steps :
32
- actions .append ("generateAndRateAlternativeSteps" )
33
- if self .rewrite_step :
34
- actions .append ("rewriteStep" )
35
- if self .justification :
36
- actions .append ("justification" )
37
- return {"id" : self .id , "name" : self .name , "actions" : actions }
38
-
39
- @classmethod
40
- def from_dict (
41
- cls , dictionary : Dict [str , Any ]
42
- ) -> "IncorrectStepReasoningVariant" :
43
- return cls (
44
- id = dictionary ["id" ],
45
- name = dictionary ["name" ],
46
- regenerate_steps = "regenerateSteps" in dictionary .get ("actions" , []),
47
- generate_and_rate_alternative_steps = "generateAndRateAlternativeSteps"
48
- in dictionary .get ("actions" , []),
49
- rewrite_step = "rewriteStep" in dictionary .get ("actions" , []),
50
- justification = "justification" in dictionary .get ("actions" , []),
51
- )
52
-
53
-
54
- def _create_correct_step () -> StepReasoningVariant :
55
- return StepReasoningVariant (
56
- id = StepReasoningVariants .CORRECT_STEP_ID , name = "Correct"
57
- )
58
-
59
-
60
- def _create_neutral_step () -> StepReasoningVariant :
61
- return StepReasoningVariant (
62
- id = StepReasoningVariants .NEUTRAL_STEP_ID , name = "Neutral"
63
- )
64
-
65
-
66
- def _create_incorrect_step () -> IncorrectStepReasoningVariant :
67
- return IncorrectStepReasoningVariant (
68
- id = StepReasoningVariants .INCORRECT_STEP_ID , name = "Incorrect"
69
- )
10
+ class IncorrectStepActions (Enum ):
11
+ REGENERATE_STEPS = "regenerateSteps"
12
+ GENERATE_AND_RATE_ALTERNATIVE_STEPS = "generateAndRateAlternativeSteps"
70
13
71
14
72
15
@dataclass
@@ -76,18 +19,22 @@ class StepReasoningVariants:
76
19
Currently the options are correct, neutral, and incorrect
77
20
"""
78
21
79
- CORRECT_STEP_ID = 0
80
- NEUTRAL_STEP_ID = 1
81
- INCORRECT_STEP_ID = 2
82
-
83
- correct_step : StepReasoningVariant = field (
84
- default_factory = _create_correct_step
22
+ correct_step : Variant = field (
23
+ default_factory = lambda : Variant (id = 0 , name = "Correct" )
85
24
)
86
- neutral_step : StepReasoningVariant = field (
87
- default_factory = _create_neutral_step
25
+ neutral_step : Variant = field (
26
+ default_factory = lambda : Variant ( id = 1 , name = "Neutral" )
88
27
)
89
- incorrect_step : IncorrectStepReasoningVariant = field (
90
- default_factory = _create_incorrect_step
28
+
29
+ incorrect_step : VariantWithActions = field (
30
+ default_factory = lambda : VariantWithActions (
31
+ id = 2 ,
32
+ name = "Incorrect" ,
33
+ _available_actions = {
34
+ action .value for action in IncorrectStepActions
35
+ },
36
+ actions = ["regenerateSteps" ], # regenerateSteps is on by default
37
+ )
91
38
)
92
39
93
40
def asdict (self ):
@@ -104,14 +51,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
104
51
incorrect_step = None
105
52
106
53
for variant in dictionary :
107
- if variant ["id" ] == cls .CORRECT_STEP_ID :
108
- correct_step = StepReasoningVariant (** variant )
109
- elif variant ["id" ] == cls .NEUTRAL_STEP_ID :
110
- neutral_step = StepReasoningVariant (** variant )
111
- elif variant ["id" ] == cls .INCORRECT_STEP_ID :
112
- incorrect_step = IncorrectStepReasoningVariant .from_dict (
113
- variant
114
- )
54
+ if variant ["id" ] == 0 :
55
+ correct_step = Variant (** variant )
56
+ elif variant ["id" ] == 1 :
57
+ neutral_step = Variant (** variant )
58
+ elif variant ["id" ] == 2 :
59
+ incorrect_step = VariantWithActions (** variant )
115
60
116
61
if not all ([correct_step , neutral_step , incorrect_step ]):
117
62
raise ValueError ("Invalid step reasoning variants" )
@@ -170,30 +115,12 @@ def __post_init__(self):
170
115
"This feature is experimental and subject to change." ,
171
116
)
172
117
173
- def reset_regenerate_steps (self ):
174
- """
175
- For live models, the default acation will invoke the model to generate alternatives if a step is marked as incorrect
176
- This method will reset the action to not regenerate the conversation
177
- """
178
- self .definition .variants .incorrect_step .regenerate_steps = False
179
-
180
- def reset_generate_and_rate_alternative_steps (self ):
181
- """
182
- For live models, will require labelers to rate the alternatives generated by the model
183
- """
184
- self .definition .variants .incorrect_step .generate_and_rate_alternative_steps = False
185
-
186
- def reset_rewrite_step (self ):
187
- """
188
- For live models, will require labelers to rewrite the conversation
189
- """
190
- self .definition .variants .incorrect_step .rewrite_step = False
191
-
192
- def reset_justification (self ):
118
+ def set_incorrect_step_actions (self , actions : List [IncorrectStepActions ]):
193
119
"""
194
- For live models, will require labelers to provide a justification for their evaluation
120
+ For live models, will invoke the model to generate alternatives if a step is marked as incorrect
195
121
"""
196
- self .definition .variants .incorrect_step .justification = False
122
+ actions_values = [action .value for action in actions ]
123
+ self .definition .variants .incorrect_step .set_actions (actions_values )
197
124
198
125
def asdict (self ) -> Dict [str , Any ]:
199
126
return {
0 commit comments