Skip to content

Commit 2112d67

Browse files
xuangu-fangyou-n-g
andauthored
feat: multi-trace online merge (#886)
* prompt: highlight overfitting rist in AutoSOTAexpSelector * set online merge time in conf * online multi-trace merge with time-limit policy * fix typo * feat: allow soft-knowledge-base + multi_trace * fix: improve file tree and _walk symlink handling (#877) * refactor: improve file tree and _walk symlink handling * remove unused code * lint * prompt: highlight overfitting rist in AutoSOTAexpSelector * set online merge time in conf * online multi-trace merge with time-limit policy * fix typo * feat: allow soft-knowledge-base + multi_trace * auto-lint * put the multi-trace related config together --------- Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
1 parent ee23e4e commit 2112d67

5 files changed

Lines changed: 282 additions & 18 deletions

File tree

rdagent/app/data_science/conf.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,6 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
5151
enable_doc_dev: bool = False
5252
model_dump_check_level: Literal["medium", "high"] = "medium"
5353

54-
### selector related
55-
56-
#### checkpoint selector related
57-
# selector_name: str = "latest"
58-
selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.ckp_select.LatestCKPSelector"
59-
"""The name of the selector to use"""
60-
sota_count_window: int = 5
61-
"""The number of trials to consider for SOTA count"""
62-
sota_count_threshold: int = 1
63-
"""The threshold for SOTA count"""
64-
65-
#### SOTA experiment selector related
66-
sota_exp_selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.sota_exp_select.GlobalSOTASelector"
67-
"""The name of the SOTA experiment selector to use"""
6854
### knowledge base
6955
enable_knowledge_base: bool = False
7056
knowledge_base_version: str = "v1"
@@ -83,8 +69,34 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
8369
"""We'll use f"{DS_RD_SETTING.local_data_path}/{DS_RD_SETTING.eval_sub_dir}/{competition}"
8470
to find the scriipt to evaluate the submission on test"""
8571

86-
### inject diverse
72+
"""---below are the settings for multi-trace---"""
73+
74+
### multi-trace related
75+
max_trace_num: int = 3
76+
"""The maximum number of traces to grow before merging"""
77+
78+
#### multi-trace:checkpoint selector
79+
selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.ckp_select.LatestCKPSelector"
80+
"""The name of the selector to use"""
81+
sota_count_window: int = 5
82+
"""The number of trials to consider for SOTA count"""
83+
sota_count_threshold: int = 1
84+
"""The threshold for SOTA count"""
85+
86+
#### multi-trace: SOTA experiment selector
87+
sota_exp_selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.sota_exp_select.GlobalSOTASelector"
88+
"""The name of the SOTA experiment selector to use"""
89+
90+
### multi-trace:inject optimals for multi-trace
91+
# inject diverse when start a new sub-trace
8792
enable_inject_diverse: bool = False
8893

94+
# inject diverse at the root of the trace
95+
enable_inject_knowledge_at_root: bool = False
96+
97+
#### multi-trace: time for final multi-trace merge
98+
merge_hours: int = 2
99+
"""The time for merge"""
100+
89101

90102
DS_RD_SETTING = DataScienceBasePropSetting()

rdagent/scenarios/data_science/proposal/exp_gen/ckp_select.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import random
2+
from datetime import datetime, timedelta
23

34
from rdagent.app.data_science.conf import DS_RD_SETTING
45
from rdagent.core.proposal import CheckpointSelector, Trace
56
from rdagent.log import rdagent_logger as logger
7+
from rdagent.log.timer import RD_Agent_TIMER_wrapper, RDAgentTimer
68

79
# # TODO: more advanced selector
810
# # TODO/Discussion: load selector function here or define selector class in `proposal.py`?
@@ -23,6 +25,80 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
2325
return (-1,)
2426

2527

28+
class LimitTimeCKPSelector(CheckpointSelector):
29+
"""
30+
recore the time of current sub-trace, and jump to a new sub-trace if the time is up
31+
"""
32+
33+
def __init__(
34+
self,
35+
):
36+
self.global_timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
37+
self.sub_trace_start_times = {}
38+
self.MAX_TRACE_NUM = DS_RD_SETTING.max_trace_num
39+
self.time_limit_pre_trace = None
40+
41+
def set_time_limit(self):
42+
43+
# Calculate total time excluding merge hours
44+
remaining_time = (
45+
self.global_timer.all_duration.total_seconds() - timedelta(hours=DS_RD_SETTING.merge_hours).total_seconds()
46+
)
47+
# Convert to timedelta after division
48+
self.time_limit_pre_trace = timedelta(seconds=remaining_time / DS_RD_SETTING.max_trace_num)
49+
# Track when each sub-trace starts
50+
logger.info(f"Using limit time selector with time limit {self.time_limit_pre_trace} per trace")
51+
52+
def get_selection(self, trace: Trace) -> tuple[int, ...]:
53+
"""
54+
Determine whether to continue with the current sub-trace or start a new one
55+
based on the time spent in the current sub-trace.
56+
57+
Returns:
58+
(-1,): Continue with the current latest trial
59+
(): Start a new sub-trace if max trace limit not reached
60+
"""
61+
62+
if self.time_limit_pre_trace is None:
63+
self.set_time_limit()
64+
65+
current_time = datetime.now()
66+
67+
if len(trace.hist) == 0:
68+
trace.sub_trace_count = 0
69+
self.sub_trace_start_times[trace.sub_trace_count] = current_time
70+
logger.info(f"Starting initial sub-trace {trace.sub_trace_count} at {current_time}")
71+
return (-1,) # Continue with latest trial for new sub-trace
72+
73+
# Calculate elapsed time for current sub-trace
74+
elapsed_time = current_time - self.sub_trace_start_times[trace.sub_trace_count]
75+
76+
if elapsed_time < self.time_limit_pre_trace:
77+
# Continue with current sub-trace
78+
logger.info(
79+
f"Elapsed time {elapsed_time} is below time limit {self.time_limit_pre_trace}, continue the current sub-trace"
80+
)
81+
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
82+
return (-1,)
83+
else:
84+
# Check if we've reached the maximum number of traces
85+
if trace.sub_trace_count + 1 >= self.MAX_TRACE_NUM:
86+
logger.info(
87+
f"Reached maximum trace count ({self.MAX_TRACE_NUM}), continuing with the current sub-trace"
88+
)
89+
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
90+
return (-1,)
91+
92+
# Time limit exceeded, start a new sub-trace
93+
trace.sub_trace_count += 1
94+
self.sub_trace_start_times[trace.sub_trace_count] = current_time
95+
logger.info(
96+
f"Elapsed time {elapsed_time} exceeds time limit {self.time_limit_pre_trace}, jump to a new sub-trace"
97+
)
98+
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
99+
return tuple() # Empty tuple signals starting a new sub-trace
100+
101+
26102
class SOTAJumpCKPSelector(CheckpointSelector):
27103
"""
28104
SOTA jump policy:
@@ -35,13 +111,13 @@ def __init__(
35111
) -> None:
36112
self.SOTA_COUNT_WINDOW = DS_RD_SETTING.sota_count_window
37113
self.SOTA_COUNT_THRESHOLD = DS_RD_SETTING.sota_count_threshold
114+
self.MAX_TRACE_NUM = DS_RD_SETTING.max_trace_num
38115

39116
logger.info(
40117
f"Using SOTA-jump selector with window {self.SOTA_COUNT_WINDOW} and threshold {self.SOTA_COUNT_THRESHOLD}"
41118
)
42119

43120
def get_selection(self, trace: Trace) -> tuple[int, ...]:
44-
45121
current_trace = trace.retrieve_search_list(search_type="ancestors")
46122
if len(trace.hist) > 0 and len(current_trace) > self.SOTA_COUNT_WINDOW:
47123
all_exp_list = trace.experiment_and_feedback_list_after_init(return_type="all", search_type="ancestors")
@@ -54,6 +130,14 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
54130
if fb.decision:
55131
sota_count += 1
56132
if sota_count < self.SOTA_COUNT_THRESHOLD:
133+
# Check if we've reached the maximum number of traces
134+
if trace.sub_trace_count + 1 >= self.MAX_TRACE_NUM:
135+
logger.info(
136+
f"Reached maximum trace count ({self.MAX_TRACE_NUM}), continuing with the current sub-trace"
137+
)
138+
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
139+
return (-1,)
140+
57141
trace.sub_trace_count += 1
58142
logger.info(
59143
f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump to a new sub-trace"
@@ -85,6 +169,7 @@ def __init__(
85169
) -> None:
86170
self.SOTA_COUNT_WINDOW = DS_RD_SETTING.sota_count_window
87171
self.SOTA_COUNT_THRESHOLD = DS_RD_SETTING.sota_count_threshold
172+
self.MAX_TRACE_NUM = DS_RD_SETTING.max_trace_num
88173

89174
logger.info(
90175
f"Using back-jump selector with window {self.SOTA_COUNT_WINDOW} and threshold {self.SOTA_COUNT_THRESHOLD}"
@@ -106,6 +191,13 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
106191
sota_count += 1
107192

108193
if sota_count < self.SOTA_COUNT_THRESHOLD:
194+
# Check if we've reached the maximum number of traces before creating a new one
195+
if trace.sub_trace_count + 1 >= self.MAX_TRACE_NUM:
196+
logger.info(
197+
f"Reached maximum trace count ({self.MAX_TRACE_NUM}), continuing with the current sub-trace"
198+
)
199+
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
200+
return (-1,)
109201

110202
random_choice = random.random()
111203
if random_choice < 0.5:
@@ -127,6 +219,14 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
127219
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
128220
return (last_second_sota_idx,)
129221
else:
222+
# Check max trace limit again before creating a new trace
223+
if trace.sub_trace_count + 1 >= self.MAX_TRACE_NUM:
224+
logger.info(
225+
f"Reached maximum trace count ({self.MAX_TRACE_NUM}), continuing with the current sub-trace"
226+
)
227+
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
228+
return (-1,)
229+
130230
trace.sub_trace_count += 1
131231
logger.info(
132232
f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump a new sub-trace"

rdagent/scenarios/data_science/proposal/exp_gen/merge.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperimen
8282
return exp
8383

8484

85+
# dual-target version
8586
class ExpGen2TraceAndMerge(ExpGen):
8687
def __init__(self, *args, **kwargs):
8788
super().__init__(*args, **kwargs)
@@ -92,7 +93,7 @@ def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperimen
9293
timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
9394
logger.info(f"Remain time: {timer.remain_time_duration}")
9495

95-
if timer.remain_time_duration >= timedelta(hours=2):
96+
if timer.remain_time_duration >= timedelta(hours=DS_RD_SETTING.merge_hours):
9697
leaves: list[int] = trace.get_leaves()
9798
if len(leaves) < 2:
9899
selection = tuple() # create new trace
@@ -111,3 +112,124 @@ def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperimen
111112
return self.exp_gen.gen(trace, selection)
112113
else:
113114
return self.merge_exp_gen.gen(trace, selection)
115+
116+
117+
class MergeExpGen_MultiTrace(ExpGen):
118+
def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperiment:
119+
# Ignore the selection argument and use all leaves instead.
120+
leaves: list[int] = trace.get_leaves()
121+
trace.set_current_selection(selection) #
122+
123+
# assuming merging the first and sencond trace.
124+
sota_exp_fb = trace.sota_experiment_fb(selection=(leaves[0],))
125+
if sota_exp_fb is None:
126+
sota_exp_fb = trace.hist[leaves[0]]
127+
128+
sota_exp_desc = T("scenarios.data_science.share:describe.exp").r(
129+
exp=sota_exp_fb[0],
130+
heading="Best previous exploration of the scenario",
131+
)
132+
sota_exp_fb_desc = T("scenarios.data_science.share:describe.feedback").r(
133+
exp_and_feedback=sota_exp_fb,
134+
heading="The feedback for best previous exploration",
135+
)
136+
137+
exp_fb_desc_to_merge_list = []
138+
# find the best exp to merge
139+
for i in range(1, len(leaves)):
140+
exp_to_merge_fb = trace.sota_experiment_fb(selection=(leaves[i],))
141+
if exp_to_merge_fb is None:
142+
exp_to_merge_fb = trace.hist[leaves[i]]
143+
144+
exp_to_merge_desc = T("scenarios.data_science.share:describe.exp").r(
145+
exp=exp_to_merge_fb[0],
146+
heading="A solution that to be merged into previous best solution",
147+
)
148+
149+
success_fb_list = trace.experiment_and_feedback_list_after_init(
150+
return_type="sota", search_type="ancestors", selection=(leaves[i],)
151+
)
152+
if len(success_fb_list) > 0:
153+
exp_to_merge_fb_desc = T("scenarios.data_science.share:describe.trace").r(
154+
exp_and_feedback_list=success_fb_list,
155+
type="success",
156+
heading="Successful iterations:",
157+
success_trial_desc="These trials are the steps or changes that led to the success of the solution to be merged",
158+
pipeline=DS_RD_SETTING.coder_on_whole_pipeline,
159+
)
160+
else:
161+
exp_to_merge_fb_desc = T("scenarios.data_science.share:describe.feedback").r(
162+
exp_and_feedback=exp_to_merge_fb,
163+
heading="The feedback for the solution to be merged",
164+
)
165+
166+
exp_fb_desc_to_merge_list.append((exp_to_merge_desc, exp_to_merge_fb_desc))
167+
168+
task = PipelineTask(
169+
description=T("scenarios.data_science.proposal.exp_gen.merge:multi_trace").r(
170+
sota_exp_desc=sota_exp_desc,
171+
sota_exp_fb_desc=sota_exp_fb_desc,
172+
exp_fb_desc_to_merge_list=exp_fb_desc_to_merge_list,
173+
)
174+
)
175+
176+
exp = DSExperiment(
177+
pending_tasks_list=[[task]],
178+
hypothesis=DSHypothesis(
179+
component="Pipeline",
180+
hypothesis="Merging two different versions of solutions would get the best of both sides and result in a better solution",
181+
),
182+
)
183+
184+
if sota_exp_fb is not None:
185+
exp.experiment_workspace.inject_code_from_file_dict(sota_exp_fb[0].experiment_workspace)
186+
return exp
187+
188+
189+
# multi-target version
190+
# allow multiple traces to grow and then merge
191+
class ExpGen2TraceAndMergeV2(ExpGen):
192+
def __init__(self, *args, **kwargs):
193+
super().__init__(*args, **kwargs)
194+
self.merge_exp_gen = MergeExpGen_MultiTrace(self.scen)
195+
self.exp_gen = DSExpGen(self.scen)
196+
self.MAX_TRACE_NUM = DS_RD_SETTING.max_trace_num # maximum number of traces to grow before merging
197+
self.flag_start_merge = False
198+
199+
def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperiment:
200+
timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
201+
logger.info(f"Remain time: {timer.remain_time_duration}")
202+
203+
if timer.remain_time_duration >= timedelta(hours=DS_RD_SETTING.merge_hours):
204+
205+
if DS_RD_SETTING.enable_inject_knowledge_at_root:
206+
207+
if len(trace.hist) == 0:
208+
# set the knowledge base option to True for the first trace
209+
DS_RD_SETTING.enable_knowledge_base = True
210+
211+
else:
212+
# set the knowledge base option back to False for the other traces
213+
DS_RD_SETTING.enable_knowledge_base = False
214+
215+
return self.exp_gen.gen(trace, selection)
216+
217+
else:
218+
# disable reset in merging stage
219+
DS_RD_SETTING.coding_fail_reanalyze_threshold = 100000
220+
DS_RD_SETTING.consecutive_errors = 100000
221+
222+
leaves: list[int] = trace.get_leaves()
223+
if len(leaves) < 2:
224+
return self.exp_gen.gen(trace, selection=(-1,))
225+
else:
226+
227+
if not self.flag_start_merge: # root node of the merge trace
228+
self.flag_start_merge = True
229+
selection = tuple()
230+
return self.merge_exp_gen.gen(trace, selection)
231+
else:
232+
# return self.merge_exp_gen.gen(trace, selection)
233+
return self.exp_gen.gen(
234+
trace, selection=(-1,)
235+
) # continue the last trace, to polish the merged solution

rdagent/scenarios/data_science/proposal/exp_gen/merge.yaml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,33 @@ task: |-
2222
{% if exp_to_merge_fb_desc %}
2323
{{ exp_to_merge_fb_desc }}
2424
{% endif %}
25+
26+
multi_trace: |-
27+
{% include "scenarios.data_science.share:scen.role" %}
28+
The user is improving a Kaggle competition implementation iteratively.
29+
Your task is to merge multiple solutions to create a better version (Combine the strengths of multiple solutions while discarding their weaknesses, to create a new version that is better than any of the given solutions alone). We expect the merged version to perform better than all given solutions.
30+
31+
You will be given:
32+
1) Previous Main Solution: this is the main solution you will build on to create an improved version;
33+
- Feedback to the main solutions
34+
2) Solution to be merged: multiple trials of solutions that you will combine with the previous main solution. For each solution, you will be given:
35+
- Solution: the approach or method used in this solution.
36+
- Successful iterations (the steps or changes that led to the success of the Solution to be merged) or feedback to the Solution to be merged.
37+
38+
# Previous Main Solution
39+
{{ sota_exp_desc }}
40+
{{ sota_exp_fb_desc }}
41+
42+
# Multiple Trials of Solutions to be merged
43+
{% for exp_to_merge_desc, exp_to_merge_fb_desc in exp_fb_desc_to_merge_list %}
44+
## Trial Index: {{ loop.index }}
45+
46+
### Solution Description:
47+
{{ exp_to_merge_desc }}
48+
49+
### Feedback to the Solution:
50+
{% if exp_to_merge_fb_desc %}
51+
{{ exp_to_merge_fb_desc }}
52+
{% endif %}
53+
54+
{% endfor %}

0 commit comments

Comments
 (0)