Skip to content

Commit 778c166

Browse files
RolandMinruiXu
andauthored
feat: add loop_n parameter to the main loop (#611)
* add loop_n parameter to main loop * complete the loop_n --------- Co-authored-by: Xu <[email protected]>
1 parent e8d7198 commit 778c166

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

rdagent/app/data_science/loop.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def record(self, prev_out: dict[str, Any]):
140140
logger.log_object(self.trace.sota_experiment(), tag="SOTA experiment")
141141

142142

143-
def main(path=None, step_n=None, competition="bms-molecular-translation"):
143+
def main(path=None, step_n=None, loop_n=None, competition="bms-molecular-translation"):
144144
"""
145145
146146
Parameters
@@ -149,6 +149,10 @@ def main(path=None, step_n=None, competition="bms-molecular-translation"):
149149
path like `$LOG_PATH/__session__/1/0_propose`. It indicates that we restore the state that after finish the step 0 in loop1
150150
step_n :
151151
How many steps to run; if None, it will run forever until error or KeyboardInterrupt
152+
loop_n :
153+
How many loops to run; if None, it will run forever until error or KeyboardInterrupt
154+
- if current loop is incomplete, it will be counted as the first loop for completion.
155+
- if both step_n and loop_n are provided, the process will stop as soon as either condition is met.
152156
competition :
153157
154158
@@ -174,7 +178,7 @@ def main(path=None, step_n=None, competition="bms-molecular-translation"):
174178
kaggle_loop = DataScienceRDLoop(DS_RD_SETTING)
175179
else:
176180
kaggle_loop = DataScienceRDLoop.load(path)
177-
kaggle_loop.run(step_n=step_n)
181+
kaggle_loop.run(step_n=step_n, loop_n=loop_n)
178182

179183

180184
if __name__ == "__main__":

rdagent/utils/workflow.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,27 @@ def __init__(self) -> None:
9191
self.loop_trace = defaultdict(list[LoopTrace]) # the key is the number of loop
9292
self.session_folder = logger.log_trace_path / "__session__"
9393

94-
def run(self, step_n: int | None = None) -> None:
94+
def run(self, step_n: int | None = None, loop_n: int | None = None) -> None:
9595
"""
9696
9797
Parameters
9898
----------
9999
step_n : int | None
100100
How many steps to run;
101101
`None` indicates to run forever until error or KeyboardInterrupt
102+
loop_n: int | None
103+
How many steps to run; if current loop is incomplete, it will be counted as the first loop for completion
104+
`None` indicates to run forever until error or KeyboardInterrupt
102105
"""
103106
with tqdm(total=len(self.steps), desc="Workflow Progress", unit="step") as pbar:
104107
while True:
105108
if step_n is not None:
106109
if step_n <= 0:
107110
break
108111
step_n -= 1
112+
if loop_n is not None:
113+
if loop_n <= 0:
114+
break
109115

110116
li, si = self.loop_idx, self.step_idx
111117
name = self.steps[si]
@@ -141,6 +147,8 @@ def run(self, step_n: int | None = None) -> None:
141147
self.step_idx = (self.step_idx + 1) % len(self.steps)
142148
if self.step_idx == 0: # reset to step 0 in next round
143149
self.loop_idx += 1
150+
if loop_n is not None:
151+
loop_n -= 1
144152
self.loop_prev_out = {}
145153
pbar.reset() # reset the progress bar for the next loop
146154

0 commit comments

Comments
 (0)