Skip to content

Commit 982b43e

Browse files
RolandMinruiXu
andauthored
feat: add output_path to load function of LoopBase (microsoft#628)
* feat: add output_path to load from checkpoint function * add default value to output_path * sort import * sort imports --------- Co-authored-by: Xu <v-xuminrui@microsoft.com>
1 parent b64ab18 commit 982b43e

2 files changed

Lines changed: 14 additions & 4 deletions

File tree

rdagent/app/data_science/loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,15 @@ def record(self, prev_out: dict[str, Any]):
142142
logger.log_object(self.trace.sota_experiment(), tag="SOTA experiment")
143143

144144

145-
def main(path=None, step_n=None, loop_n=None, competition="bms-molecular-translation"):
145+
def main(path=None, output_path=None, step_n=None, loop_n=None, competition="bms-molecular-translation"):
146146
"""
147147
148148
Parameters
149149
----------
150150
path :
151151
path like `$LOG_PATH/__session__/1/0_propose`. It indicates that we restore the state that after finish the step 0 in loop1
152+
output_path :
153+
path like `$LOG_PATH`. It indicates that where we want to save our session and log information.
152154
step_n :
153155
How many steps to run; if None, it will run forever until error or KeyboardInterrupt
154156
loop_n :
@@ -179,7 +181,7 @@ def main(path=None, step_n=None, loop_n=None, competition="bms-molecular-transla
179181
if path is None:
180182
kaggle_loop = DataScienceRDLoop(DS_RD_SETTING)
181183
else:
182-
kaggle_loop = DataScienceRDLoop.load(path)
184+
kaggle_loop = DataScienceRDLoop.load(path, output_path)
183185
kaggle_loop.run(step_n=step_n, loop_n=loop_n)
184186

185187

rdagent/utils/workflow.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections import defaultdict
1515
from dataclasses import dataclass
1616
from pathlib import Path
17-
from typing import Any, Callable, TypeVar, cast
17+
from typing import Any, Callable, Optional, TypeVar, Union, cast
1818

1919
from tqdm.auto import tqdm
2020

@@ -161,10 +161,18 @@ def dump(self, path: str | Path) -> None:
161161
pickle.dump(self, f)
162162

163163
@classmethod
164-
def load(cls, path: str | Path) -> "LoopBase":
164+
def load(cls, path: Union[str, Path], output_path: Optional[Union[str, Path]] = None) -> "LoopBase":
165165
path = Path(path)
166166
with path.open("rb") as f:
167167
session = cast(LoopBase, pickle.load(f))
168+
169+
# set session folder
170+
if output_path:
171+
output_path = Path(output_path)
172+
output_path.mkdir(parents=True, exist_ok=True)
173+
session.session_folder = output_path / "__session__"
174+
175+
# set trace path
168176
logger.set_trace_path(session.session_folder.parent)
169177

170178
max_loop = max(session.loop_trace.keys())

0 commit comments

Comments
 (0)