Skip to content

Commit 50ea033

Browse files
xuangu-fang and you-n-g authored
feat: advanced checkpoint selectors (#790)
* rebase selection code * bug-free run: checkpoint selection and dynamic EDA loading * add prototypes of various selectors, to imp. and test later * fix EDA write bug * imp SOTA-Jump policy * fix small bug * allow to set different selector by .env * add always-win selector * add init length for AlwaysWinCKPSelector * add back_jump selector * auto lint * add sota_exp_to_submit attribute; change the name of ckp_selector and sota-selector * fix bug * auto lint * working on auto sota selector * add subtrace counter * fix bug, remove unuse selector * add auto sota selector * auto lint * fix bug * fix small logic bug * add logging * add inject_diverse feat * auto lint * capable to None-select * feat: add hypothesis_gen config and ExpGen2TraceAndMerge functionality * refactor: use dynamic import for experiment generator instantiation * feat: add BestValidSelector for improved SOTA experiment selection * runnable twin-trace version * fix logic error of trace-merge * auto lint * use import_class to set selector, * auto-lint --------- Co-authored-by: Young <afe.young@gmail.com>
1 parent 0e54c9f commit 50ea033

13 files changed

Lines changed: 486 additions & 124 deletions

File tree

rdagent/app/data_science/conf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
1414
scen: str = "rdagent.scenarios.data_science.scen.KaggleScen"
1515
"""Scenario class for data mining model"""
1616

17+
hypothesis_gen: str = "rdagent.scenarios.data_science.proposal.exp_gen.DSExpGen"
18+
"""Hypothesis generation class"""
19+
1720
## Workflow Related
1821
consecutive_errors: int = 5
1922

@@ -47,6 +50,20 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
4750
enable_doc_dev: bool = False
4851
model_dump_check_level: Literal["medium", "high"] = "medium"
4952

53+
### selector related
54+
55+
#### checkpoint selector related
56+
# selector_name: str = "latest"
57+
selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.ckp_select.LatestCKPSelector"
58+
"""The name of the selector to use"""
59+
sota_count_window: int = 5
60+
"""The number of trials to consider for SOTA count"""
61+
sota_count_threshold: int = 1
62+
"""The threshold for SOTA count"""
63+
64+
#### SOTA experiment selector related
65+
sota_exp_selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.sota_exp_select.GlobalSOTASelector"
66+
"""The name of the SOTA experiment selector to use"""
5067
### knowledge base
5168
enable_knowledge_base: bool = False
5269
knowledge_base_version: str = "v1"
@@ -65,5 +82,8 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
6582
"""We'll use f"{DS_RD_SETTING.local_data_path}/{DS_RD_SETTING.eval_sub_dir}/{competition}"
6683
to find the script to evaluate the submission on test"""
6784

85+
### inject diverse
86+
enable_inject_diverse: bool = False
87+
6888

6989
DS_RD_SETTING = DataScienceBasePropSetting()

rdagent/app/data_science/loop.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,31 @@
3232
from rdagent.scenarios.data_science.dev.runner import DSCoSTEERRunner
3333
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
3434
from rdagent.scenarios.data_science.proposal.exp_gen import DSExpGen, DSTrace
35+
from rdagent.scenarios.data_science.proposal.exp_gen.ckp_select import (
36+
BackJumpCKPSelector,
37+
LatestCKPSelector,
38+
SOTAJumpCKPSelector,
39+
)
3540
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSKnowledgeBase
36-
from rdagent.scenarios.data_science.proposal.exp_gen.select import LatestCKPSelector
41+
from rdagent.scenarios.data_science.proposal.exp_gen.sota_exp_select import (
42+
AutoSOTAexpSelector,
43+
BestValidSelector,
44+
GlobalSOTASelector,
45+
)
3746
from rdagent.scenarios.kaggle.kaggle_crawler import download_data
3847

48+
CKP_SELECTOR_NAME_MAP = {
49+
"latest": LatestCKPSelector,
50+
"sota_jump": SOTAJumpCKPSelector,
51+
"back_jump": BackJumpCKPSelector,
52+
}
53+
54+
SOTA_EXP_SELECTOR_NAME_MAP = {
55+
"global_sota": GlobalSOTASelector,
56+
"auto_sota": AutoSOTAexpSelector,
57+
"best_valid_sota": BestValidSelector,
58+
}
59+
3960

4061
class DataScienceRDLoop(RDLoop):
4162
skip_loop_error = (CoderError, RunnerError)
@@ -49,8 +70,15 @@ def __init__(self, PROP_SETTING: BasePropSetting):
4970

5071
# 2) task generation from a complete solution
5172
# self.exp_gen: ExpGen = import_class(PROP_SETTING.exp_gen)(scen)
52-
self.ckp_selector = LatestCKPSelector()
53-
self.exp_gen = DSExpGen(scen)
73+
74+
# self.ckp_selector = CKP_SELECTOR_NAME_MAP[DS_RD_SETTING.selector_name]()
75+
# self.sota_exp_selector = SOTA_EXP_SELECTOR_NAME_MAP[DS_RD_SETTING.sota_exp_selector_name]()
76+
self.ckp_selector = import_class(PROP_SETTING.selector_name)()
77+
self.sota_exp_selector = import_class(PROP_SETTING.sota_exp_selector_name)()
78+
79+
self.exp_gen = import_class(PROP_SETTING.hypothesis_gen)(scen)
80+
81+
# coders
5482
self.data_loader_coder = DataLoaderCoSTEER(scen)
5583
self.feature_coder = FeatureCoSTEER(scen)
5684
self.model_coder = ModelCoSTEER(scen)
@@ -76,6 +104,12 @@ def __init__(self, PROP_SETTING: BasePropSetting):
76104
super(RDLoop, self).__init__()
77105

78106
def direct_exp_gen(self, prev_out: dict[str, Any]):
107+
108+
# set the SOTA experiment to submit
109+
sota_exp_to_submit = self.sota_exp_selector.get_sota_exp_to_submit(self.trace)
110+
self.trace.set_sota_exp_to_submit(sota_exp_to_submit)
111+
112+
# set the checkpoint to start from
79113
selection = self.ckp_selector.get_selection(self.trace)
80114
exp = self.exp_gen.gen(self.trace, selection)
81115
logger.log_object(exp)

rdagent/components/coder/data_science/pipeline/eval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ def evaluate(
127127

128128
eda_output = implementation.file_dict.get("EDA.md", None)
129129

130+
eda_output = implementation.file_dict.get("EDA.md", None)
131+
132+
if not isinstance(implementation, FBWorkspace):
133+
eda_output = None
134+
else:
135+
eda_output = implementation.file_dict.get("EDA.md", None)
136+
130137
system_prompt = T(".prompts:pipeline_eval.system").r(
131138
scenario=self.scen.get_scenario_all_desc(eda_output=eda_output),
132139
task_desc=target_task.get_task_information(),

rdagent/core/proposal.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,22 @@ def get_selection(self, trace: Trace) -> tuple[int, ...] | None:
149149
- `(idx, )` represents starting from the `idx`-th trial in the trace.
150150
- `None` represents starting from scratch (start a new trace)
151151
152-
153152
- More advanced selection strategies in `ckp_select.py`
154153
"""
155154

156155

156+
class SOTAexpSelector(ABC):
    """
    Select the SOTA experiment from the trace to submit

    The class derives from `ABC` so that `@abstractmethod` is enforced:
    without an `ABCMeta` metaclass the decorator has no effect and the base
    class (or a subclass that forgot to override the method) could be
    instantiated and would silently return `None`.
    """

    @abstractmethod
    def get_sota_exp_to_submit(self, trace: Trace) -> Experiment | None:
        """
        Select the SOTA experiment from the trace to submit

        Parameters
        ----------
        trace : Trace
            The trace to select the experiment from.

        Returns
        -------
        Experiment | None
            The experiment to submit. `None` presumably means that no
            experiment is selected -- confirm against the concrete selectors.
        """
166+
167+
157168
class ExpGen(ABC):
158169

159170
def __init__(self, scen: Scenario) -> None:

rdagent/scenarios/data_science/dev/feedback.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ def generate_feedback(self, exp: DSExperiment, trace: DSTrace) -> ExperimentFeed
3131
exp=sota_exp, heading="SOTA of previous exploration of the scenario"
3232
)
3333

34+
last_exp = trace.last_exp()
35+
3436
# Get feedback description using shared template
3537
feedback_desc = T("scenarios.data_science.share:describe.feedback").r(
36-
exp_and_feedback=(trace.hist[-1] if trace.hist else None), heading="Previous Trial Feedback"
38+
exp_and_feedback=trace.hist[-1] if trace.hist else None, heading="Previous Trial Feedback"
3739
)
3840

3941
# TODO:

rdagent/scenarios/data_science/proposal/exp_gen/base.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,17 @@ def __init__(self, scen: DataScienceScen, knowledge_base: KnowledgeBase | None =
6161

6262
self.knowledge_base = knowledge_base
6363

64+
self.sub_trace_count: int = 0
65+
6466
self.current_selection: tuple[int, ...] = (-1,)
6567

68+
self.sota_exp_to_submit: DSExperiment | None = None # grab the global best exp to submit
69+
6670
COMPLETE_ORDER = ("DataLoadSpec", "FeatureEng", "Model", "Ensemble", "Workflow")
6771

72+
def set_sota_exp_to_submit(self, exp: DSExperiment | None) -> None:
    """
    Record the experiment that should be submitted as SOTA.

    Parameters
    ----------
    exp : DSExperiment | None
        The experiment to store in `self.sota_exp_to_submit`. `None` is
        accepted because the attribute itself is declared as
        `DSExperiment | None`.
    """
    self.sota_exp_to_submit = exp
74+
6875
def get_current_selection(self) -> tuple[int, ...]:
6976
return self.current_selection
7077

@@ -127,15 +134,22 @@ def retrieve_search_list(
127134
list[tuple[DSExperiment, ExperimentFeedback]]
128135
The search list.
129136
"""
137+
if search_type == "all":
138+
return self.hist
130139

131-
if selection is None:
132-
selection = self.get_current_selection()
140+
elif search_type == "ancestors":
133141

134-
if selection is None:
135-
# selection is None, which means we switch to a new trace, which is not implemented yet
136-
return []
142+
if selection is None:
143+
selection = self.get_current_selection()
137144

138-
return self.collect_all_ancestors(selection) if search_type == "ancestors" else self.hist
145+
if len(selection) == 0:
146+
# selection is (), which means we switch to a new trace
147+
return []
148+
149+
return self.collect_all_ancestors(selection)
150+
151+
else:
152+
raise ValueError(f"Invalid search type: {search_type}")
139153

140154
def collect_all_ancestors(
141155
self,
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import random
2+
3+
from rdagent.app.data_science.conf import DS_RD_SETTING
4+
from rdagent.core.proposal import CheckpointSelector, Trace
5+
from rdagent.log import rdagent_logger as logger
6+
7+
# # TODO: more advanced selector
8+
# # TODO/Discussion: load selector function here or define selector class in `proposal.py`?
9+
10+
11+
class LatestCKPSelector(CheckpointSelector):
    """
    Checkpoint selector that always continues from the most recent trial.

    -`(-1, )` represents starting from the latest trial in the trace
    """

    def __init__(self) -> None:
        # Nothing to configure; only announce which selector is active.
        logger.info("Using latest selector by default")

    def get_selection(self, trace: Trace) -> tuple[int, ...]:
        """Return `(-1,)` regardless of what `trace` contains."""
        latest_trial: tuple[int, ...] = (-1,)
        return latest_trial
24+
25+
26+
class SOTAJumpCKPSelector(CheckpointSelector):
    """
    SOTA jump policy:
    count the trials with a positive feedback decision among the last
    `SOTA_COUNT_WINDOW` trials of the current ancestor chain;
    if that count is below `SOTA_COUNT_THRESHOLD`, jump to a new sub-trace,
    otherwise continue from the latest trial.
    """

    def __init__(self) -> None:
        # Both knobs are read once from the global data-science settings.
        self.SOTA_COUNT_WINDOW = DS_RD_SETTING.sota_count_window
        self.SOTA_COUNT_THRESHOLD = DS_RD_SETTING.sota_count_threshold

        logger.info(
            f"Using SOTA-jump selector with window {self.SOTA_COUNT_WINDOW} and threshold {self.SOTA_COUNT_THRESHOLD}"
        )

    def get_selection(self, trace: Trace) -> tuple[int, ...]:
        """
        Decide where the next trial starts.

        Returns
        -------
        tuple[int, ...]
            `(-1,)` to continue from the latest trial, or `()` to start a new
            sub-trace (in which case `trace.sub_trace_count` is incremented).
        """
        ancestors = trace.retrieve_search_list(search_type="ancestors")

        # Guard: the ancestor chain must be longer than the window to decide.
        if len(trace.hist) == 0 or len(ancestors) <= self.SOTA_COUNT_WINDOW:
            logger.info("Not enough history to make a decision, continue the current latest trial")
            return (-1,)

        history = trace.experiment_and_feedback_list_after_init(return_type="all", search_type="ancestors")
        recent = history[-self.SOTA_COUNT_WINDOW :]

        # Number of trials in the window whose feedback decision is truthy.
        sota_count = sum(1 for _, feedback in recent if feedback.decision)

        if sota_count >= self.SOTA_COUNT_THRESHOLD:
            logger.info(
                f"SOTA count {sota_count} is above threshold {self.SOTA_COUNT_THRESHOLD}, continue the current latest trial"
            )
            logger.info(f"current sub-trace count: {trace.sub_trace_count}")
            return (-1,)

        trace.sub_trace_count += 1
        logger.info(
            f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump to a new sub-trace"
        )
        logger.info(f"current sub-trace count: {trace.sub_trace_count}")
        return ()
73+
74+
75+
class BackJumpCKPSelector(CheckpointSelector):
    """
    back-jump policy:
    if the number of SOTA trials in the recent window is below a threshold,
        with 50% probability, reboot a new sub-trace
        with 50% probability, jump back to the second-to-last SOTA trial in the
        whole history (we assume the latest SOTA trial is not a good enough
        selection); if fewer than two SOTA trials exist, reboot a new
        sub-trace instead
    otherwise, continue from the latest trial
    """

    def __init__(self) -> None:
        # Both knobs are read once from the global data-science settings.
        self.SOTA_COUNT_WINDOW = DS_RD_SETTING.sota_count_window
        self.SOTA_COUNT_THRESHOLD = DS_RD_SETTING.sota_count_threshold

        logger.info(
            f"Using back-jump selector with window {self.SOTA_COUNT_WINDOW} and threshold {self.SOTA_COUNT_THRESHOLD}"
        )

    def get_selection(self, trace: Trace) -> tuple[int, ...]:
        """
        Decide where the next trial starts.

        Returns
        -------
        tuple[int, ...]
            `(-1,)` to continue from the latest trial, `()` to start a new
            sub-trace (`trace.sub_trace_count` is incremented), or `(idx,)`
            with the index in `trace.hist` of the second-to-last SOTA trial.
        """
        ancestors = trace.retrieve_search_list(search_type="ancestors")

        # Guard: the ancestor chain must be longer than the window to decide.
        if len(trace.hist) == 0 or len(ancestors) <= self.SOTA_COUNT_WINDOW:
            logger.info("Not enough history to make a decision, continue the current latest trial")
            logger.info(f"current sub-trace count: {trace.sub_trace_count}")
            return (-1,)

        history = trace.experiment_and_feedback_list_after_init(return_type="all", search_type="ancestors")
        recent = history[-self.SOTA_COUNT_WINDOW :]

        # Number of trials in the window whose feedback decision is truthy.
        sota_count = sum(1 for _, feedback in recent if feedback.decision)

        if sota_count >= self.SOTA_COUNT_THRESHOLD:
            logger.info(
                f"SOTA count {sota_count} is above threshold {self.SOTA_COUNT_THRESHOLD}, continue the current latest trial"
            )
            logger.info(f"current sub-trace count: {trace.sub_trace_count}")
            return (-1,)

        # Below the threshold: coin flip between rebooting and jumping back.
        if random.random() < 0.5:
            trace.sub_trace_count += 1
            logger.info(
                f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump a new sub-trace"
            )
            return ()  # reboot a new sub-trace

        logger.info(
            f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump back to the last second SOTA in hist (may not in current sub-trace)"
        )
        sota_exp_list = trace.experiment_and_feedback_list_after_init(return_type="sota", search_type="all")

        if len(sota_exp_list) > 1:
            last_second_sota_idx = trace.hist.index(sota_exp_list[-2])
            logger.info(
                f"jump back to the last second SOTA in hist (may not in current sub-trace), index: {last_second_sota_idx}"
            )
            logger.info(f"current sub-trace count: {trace.sub_trace_count}")
            return (last_second_sota_idx,)

        # Fewer than two SOTA trials in the whole history: nothing to jump back to.
        trace.sub_trace_count += 1
        logger.info(
            f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump a new sub-trace"
        )
        logger.info(f"current sub-trace count: {trace.sub_trace_count}")
        return ()  # reboot a new sub-trace
147+
148+
149+
# TODO: implement these selectors and more

0 commit comments

Comments
 (0)