Skip to content

Commit 9e60c32

Browse files
authored
feat: async mechanism for multi-trace (#981)
* start to work on multi-trace + async * init ver of async-multi-tarce, to test * add eng-ver log * complete version of async+ mul-trace * debug * fix bug on DS_RD_SETTING.get() * update * fix bug + simplif the usage of async in multi-trace * fix mini bug of arg_name * Move local_selection into class Experiment & clean the code
1 parent 76df96e commit 9e60c32

File tree

10 files changed

+232
-8
lines changed

10 files changed

+232
-8
lines changed

rdagent/core/experiment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,9 @@ def __init__(
324324
{}
325325
) # TODO: in Kaggle, now sub results are all saved in self.result, remove this in the future.
326326

327+
# For parallel multi-trace support
328+
self.local_selection: tuple[int, ...] | None = None
329+
327330

328331
ASpecificExp = TypeVar("ASpecificExp", bound=Experiment)
329332

rdagent/scenarios/data_science/dev/feedback.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def generate_feedback(self, exp: DSExperiment, trace: DSTrace) -> ExperimentFeed
2626
# 3. 相对sota_exp的改动
2727
# 4. result 任务的结果
2828
# 5. sota_exp.result 之前最好的结果
29+
2930
sota_exp = trace.sota_experiment()
3031
sota_desc = T("scenarios.data_science.share:describe.exp").r(
3132
exp=sota_exp, heading="SOTA of previous exploration of the scenario"

rdagent/scenarios/data_science/experiment/experiment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,6 @@ def is_ready_to_run(self) -> bool:
2929
(so it is different from `trace.next_incomplete_component`.)
3030
"""
3131
return self.experiment_workspace is not None and "main.py" in self.experiment_workspace.file_dict
32+
33+
def set_local_selection(self, local_selection: tuple[int, ...]) -> None:
34+
self.local_selection = local_selection

rdagent/scenarios/data_science/loop.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def __init__(self, PROP_SETTING: BasePropSetting):
142142
super(RDLoop, self).__init__()
143143

144144
async def direct_exp_gen(self, prev_out: dict[str, Any]):
145-
146145
# set the SOTA experiment to submit
147146
sota_exp_to_submit = self.sota_exp_selector.get_sota_exp_to_submit(self.trace)
148147
self.trace.set_sota_exp_to_submit(sota_exp_to_submit)
@@ -151,7 +150,11 @@ async def direct_exp_gen(self, prev_out: dict[str, Any]):
151150
selection = self.ckp_selector.get_selection(self.trace)
152151
# set the current selection for the trace
153152
self.trace.set_current_selection(selection)
153+
154+
# in parallel + multi-trace mode, the above global "trace.current_selection" will not be used
155+
# instead, we will use the "local_selection" attached to each exp in async_gen().
154156
exp = await self.exp_gen.async_gen(self.trace, self)
157+
155158
logger.log_object(exp)
156159

157160
# FIXME: this is for LLM debug webapp, remove this when the debugging is done.
@@ -197,6 +200,11 @@ def feedback(self, prev_out: dict[str, Any]) -> ExperimentFeedback:
197200
- If we come to feedback phase, the previous development steps are successful.
198201
"""
199202
exp: DSExperiment = prev_out["running"]
203+
204+
# set the local selection to the trace after feedback
205+
if exp.local_selection is not None:
206+
self.trace.set_current_selection(exp.local_selection)
207+
200208
if self.trace.next_incomplete_component() is None or DS_RD_SETTING.coder_on_whole_pipeline:
201209
# we have already completed the components in the previous trace, so the current loop is focusing on a newly proposed idea.
202210
# So we need feedback for the proposal.
@@ -211,19 +219,36 @@ def feedback(self, prev_out: dict[str, Any]) -> ExperimentFeedback:
211219
return feedback
212220

213221
def record(self, prev_out: dict[str, Any]):
214-
# set the DAG parent for the trace
215-
self.trace.sync_dag_parent_and_hist()
222+
223+
exp: DSExperiment = None
216224

217225
e = prev_out.get(self.EXCEPTION_KEY, None)
218226
if e is None:
219-
self.trace.hist.append((prev_out["running"], prev_out["feedback"]))
227+
exp = prev_out["running"]
228+
self.trace.hist.append((exp, prev_out["feedback"]))
229+
230+
# NOTE: we put below operations on selections here, instead of out of the if-else block,
231+
# to fit the corner case that the trace will be reset
232+
233+
# set the local selection to the trace as global selection, then set the DAG parent for the trace
234+
if exp.local_selection is not None:
235+
self.trace.set_current_selection(exp.local_selection)
236+
self.trace.sync_dag_parent_and_hist()
237+
220238
else:
239+
exp: DSExperiment = prev_out["direct_exp_gen"] if isinstance(e, CoderError) else prev_out["coding"]
221240
self.trace.hist.append(
222241
(
223-
prev_out["direct_exp_gen"] if isinstance(e, CoderError) else prev_out["coding"],
242+
exp,
224243
ExperimentFeedback.from_exception(e),
225244
)
226245
)
246+
247+
# set the local selection to the trace as global selection, then set the DAG parent for the trace
248+
if exp.local_selection is not None:
249+
self.trace.set_current_selection(exp.local_selection)
250+
self.trace.sync_dag_parent_and_hist()
251+
227252
if self.trace.sota_experiment() is None:
228253
if DS_RD_SETTING.coder_on_whole_pipeline:
229254
# check if feedback is not generated

rdagent/scenarios/data_science/proposal/exp_gen/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,12 @@ def last_exp(
266266
def last_exp_fb(
267267
self,
268268
search_type: Literal["all", "ancestors"] = "ancestors",
269+
selection: tuple[int, ...] | None = None,
269270
) -> tuple[DSExperiment, ExperimentFeedback] | None:
270271
"""
271272
Access the last experiment and feedback
272273
"""
273-
search_list = self.retrieve_search_list(search_type)
274+
search_list = self.retrieve_search_list(search_type, selection=selection)
274275
for exp, ef in search_list[::-1]:
275276
return exp, ef
276277
return None

rdagent/scenarios/data_science/proposal/exp_gen/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
218218
scenario_desc=scenario_desc,
219219
sota_exp_desc=sota_exp_desc,
220220
sota_exp=sota_exp_fb[0] if sota_exp_fb else None,
221-
hypothesis=new_hypothesis,
221+
hypotheses=[new_hypothesis],
222222
pipeline=DS_RD_SETTING.coder_on_whole_pipeline,
223223
failed_exp_feedback_list_desc="",
224224
)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from datetime import timedelta
5+
from typing import TYPE_CHECKING
6+
7+
from rdagent.app.data_science.conf import DS_RD_SETTING
8+
from rdagent.core.conf import RD_AGENT_SETTINGS
9+
from rdagent.core.proposal import ExpGen
10+
from rdagent.log import rdagent_logger as logger
11+
from rdagent.log.timer import RD_Agent_TIMER_wrapper, RDAgentTimer
12+
from rdagent.scenarios.data_science.loop import DataScienceRDLoop
13+
from rdagent.scenarios.data_science.proposal.exp_gen.merge import ExpGen2Hypothesis
14+
from rdagent.scenarios.data_science.proposal.exp_gen.trace_scheduler import (
15+
RoundRobinScheduler,
16+
TraceScheduler,
17+
)
18+
19+
if TYPE_CHECKING:
20+
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
21+
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace, Experiment
22+
from rdagent.utils.workflow.loop import LoopBase
23+
24+
25+
class ParallelMultiTraceExpGen(ExpGen):
    """
    An experiment generation strategy that enables parallel multi-trace exploration.

    This generator is designed to work with the "Attribute Injection" model.
    It uses a TraceScheduler to determine which parent node to expand, and
    injects this parent context into the experiment object itself
    (through ``exp.set_local_selection``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The underlying generator for creating a single experiment
        self.exp_gen = DataScienceRDLoop._get_exp_gen(
            "rdagent.scenarios.data_science.proposal.exp_gen.DSExpGen", self.scen
        )
        # Generator used in the merging stage (see `async_gen`)
        self.merge_exp_gen = ExpGen2Hypothesis(self.scen)
        self.trace_scheduler: TraceScheduler = RoundRobinScheduler()
        self.max_trace_num = DS_RD_SETTING.max_trace_num

    def gen(self, trace: "DSTrace") -> "Experiment":
        """
        Synchronous generation is not supported by this generator.

        Raises:
            NotImplementedError: always; callers must use `async_gen`.
        """
        raise NotImplementedError(
            "ParallelMultiTraceExpGen is designed for async usage, please call async_gen instead."
        )

    async def async_gen(self, trace: DSTrace, loop: LoopBase) -> DSExperiment:
        """
        Waits for a free execution slot, selects a parent trace using the
        scheduler, generates a new experiment, and injects the parent context
        into it before returning.

        Two stages are distinguished by the remaining time of the global timer:

        - Exploration stage (remaining time >= `DS_RD_SETTING.merge_hours`):
          once the number of unfinished loops is below the parallel limit, a
          parent is chosen (a new root while fewer than `max_trace_num`
          sub-traces exist, otherwise the scheduler's choice), an experiment is
          generated, and the selection is attached to the experiment.
        - Merging stage (otherwise): once no loop is unfinished, either a normal
          experiment is generated (fewer than two leaves) or the merge generator
          is used.

        When neither stage can make progress yet, the coroutine sleeps for one
        second and polls again.

        NOTE: this method mutates the global `DS_RD_SETTING`
        (`enable_knowledge_base`, `coding_fail_reanalyze_threshold`,
        `consecutive_errors`) and the trace's current selection.

        Args:
            trace: The trace holding the experiment history.
            loop: The running loop, used to query the number of unfinished loops.

        Returns:
            The newly generated experiment.
        """
        timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
        logger.info(f"Remain time: {timer.remain_time_duration}")

        while True:
            if timer.remain_time_duration >= timedelta(hours=DS_RD_SETTING.merge_hours):
                if DS_RD_SETTING.enable_inject_knowledge_at_root:
                    # enable the knowledge base only while the history is empty,
                    # and switch it back off once any experiment is recorded
                    DS_RD_SETTING.enable_knowledge_base = len(trace.hist) == 0

                if loop.get_unfinished_loop_cnt(loop.loop_idx) < RD_AGENT_SETTINGS.get_max_parallel():
                    # step 1: select the parent trace to expand
                    # NOTE: the selection is made only after a slot is free; the scheduler
                    # may be stateful, so querying it on every polling iteration would
                    # advance its state without any experiment being generated.
                    local_selection: tuple[int, ...]
                    if trace.sub_trace_count < self.max_trace_num:
                        # Policy: if we have fewer traces than our target, start a new one.
                        local_selection = trace.NEW_ROOT
                    else:
                        # Otherwise, use the scheduler to pick an existing trace to expand.
                        local_selection = await self.trace_scheduler.select_trace(trace)

                    # set the local selection as the global current selection for the trace
                    trace.set_current_selection(local_selection)
                    # step 2: generate the experiment with the local selection
                    exp = self.exp_gen.gen(trace)

                    # Inject the local selection to the experiment object
                    exp.set_local_selection(local_selection)

                    return exp

            else:
                # enter the merging stage
                # make sure all the loops are finished
                if loop.get_unfinished_loop_cnt(loop.loop_idx) < 1:
                    # disable reset in merging stage
                    DS_RD_SETTING.coding_fail_reanalyze_threshold = 100000
                    DS_RD_SETTING.consecutive_errors = 100000

                    leaves: list[int] = trace.get_leaves()
                    if len(leaves) < 2:
                        # nothing to merge: continue from the latest experiment
                        trace.set_current_selection(selection=(-1,))
                        return self.exp_gen.gen(trace)

                    # choose leaves[0] by default; switch to leaves[1] when
                    # `trace.is_parent(<index of the SOTA to submit>, leaves[1])` holds
                    selection = (leaves[0],)
                    if trace.sota_exp_to_submit is not None and trace.is_parent(
                        trace.exp2idx(trace.sota_exp_to_submit), leaves[1]
                    ):
                        selection = (leaves[1],)
                    trace.set_current_selection(selection)
                    return self.merge_exp_gen.gen(trace)

            # no slot is free yet (or loops are still running): poll again later
            await asyncio.sleep(1)

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,11 @@ def get_all_hypotheses(self, problem_dict: dict, hypothesis_dict: dict) -> list[
802802
)
803803
return result
804804

805-
def gen(self, trace: DSTrace) -> DSExperiment:
805+
def gen(
806+
self,
807+
trace: DSTrace,
808+
) -> DSExperiment:
809+
806810
pipeline = DS_RD_SETTING.coder_on_whole_pipeline
807811
if not pipeline and (draft_exp := draft_exp_in_decomposition(self.scen, trace)):
808812
return draft_exp
@@ -839,6 +843,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
839843
pipeline=pipeline,
840844
)
841845

846+
# NOTE: we currently don't support injecting diverse problems in the parallel + multi-trace mode.
842847
if DS_RD_SETTING.enable_inject_diverse and len(trace.hist) > 0:
843848
if len(trace.current_selection) == 0:
844849
# start a new sub-trace, and inject diverse problems.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from abc import ABC, abstractmethod
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
9+
10+
11+
class TraceScheduler(ABC):
    """
    Interface for trace scheduling strategies.

    A scheduler decides which active trace should be expanded next during
    parallel exploration.
    """

    @abstractmethod
    async def select_trace(self, trace: DSTrace) -> tuple[int, ...]:
        """
        Pick the parent node for the next experiment.

        The method is declared as a coroutine; the original contract states
        this is required to allow for safe concurrent access.

        Args:
            trace: The DSTrace object containing the full experiment history.

        Returns:
            The selection of the parent node for the new experiment,
            e.g. ``(leaf_idx,)`` for an existing trace, or ``trace.NEW_ROOT``
            for a new one.

        Raises:
            NotImplementedError: if a subclass delegates to this base implementation.
        """
        raise NotImplementedError
32+
33+
34+
class RoundRobinScheduler(TraceScheduler):
    """
    Scheduling strategy that walks over the current leaves of the trace in a
    round-robin fashion.

    NOTE: no asyncio.Lock is used here; the original authors rely on the
    kickoff_loop keeping ExpGen sequential rather than parallel.
    """

    def __init__(self):
        # Leaf id returned by the previous call; -1 until a leaf has been selected.
        self._last_selected_leaf_id = -1

    async def select_trace(self, trace: DSTrace) -> tuple[int, ...]:
        """
        Return the leaf that follows the previously selected one.

        Args:
            trace: The trace whose leaves are cycled through.

        Returns:
            ``trace.NEW_ROOT`` when the trace has no leaves, otherwise a
            one-element tuple holding the chosen leaf index.
        """
        candidates = trace.get_leaves()
        if not candidates:
            # This is the very first experiment in a new tree.
            return trace.NEW_ROOT

        previous = self._last_selected_leaf_id
        if previous in candidates:
            # Step to the neighbour of the previous pick, wrapping around at the end.
            slot = (candidates.index(previous) + 1) % len(candidates)
        else:
            # The previous pick is not among the current leaves (it has been
            # expanded, or nothing has been picked yet): restart from the front.
            slot = 0

        chosen = candidates[slot]
        self._last_selected_leaf_id = chosen
        return (chosen,)

rdagent/utils/workflow/loop.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def get_semaphore(self, step_name: str) -> asyncio.Semaphore:
136136
if isinstance(limit := RD_AGENT_SETTINGS.step_semaphore, dict):
137137
limit = limit.get(step_name, 1) # default to 1 if not specified
138138

139+
# NOTE: we assume the record step is always the last step to modify the global environment,
140+
# so we set the limit to 1 to avoid race condition
141+
if step_name == "record":
142+
limit = 1
143+
139144
if step_name not in self.semaphores:
140145
self.semaphores[step_name] = asyncio.Semaphore(limit)
141146
return self.semaphores[step_name]

0 commit comments

Comments
 (0)