Skip to content

Commit 13890e0

Browse files
xuangu-fangyou-n-gjingyuanlmHoder-zyf
authored
feat: mcts policy based on trace scheduler (#1203)
* init mcts class * full ver of MCTS * auto-lint * make MCTS feedback in exp-gen() * refactor: move reset logic from Trace to ExpGen and update usage accordingly * fix: reinitialize trace on consecutive errors in DataScienceRDLoop * feat: add reset method to BaseScheduler and call in MCTSScheduler reset * style: reorder imports for consistency and PEP8 compliance * lint * fix observe_feedback * fix bug * remove uncommited_rec_status * more simple * refactor: move commit observation logic to process_uncommitted_nodes method * docs: add TODO comment about rule-based virtual root node expansion * add score reward * fix bug * fix small bug * lint * change reward * lint --------- Co-authored-by: Young <[email protected]> Co-authored-by: jingyuanlm <[email protected]> Co-authored-by: amstrongzyf <[email protected]>
1 parent 0f722e1 commit 13890e0

File tree

5 files changed

+183
-1
lines changed

5 files changed

+183
-1
lines changed

rdagent/app/data_science/conf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
101101
scheduler_temperature: float = 1.0
102102
"""The temperature for the trace scheduler for softmax calculation, used in ProbabilisticScheduler"""
103103

104+
# PUCT exploration constant for MCTSScheduler (ignored by other schedulers)
105+
scheduler_c_puct: float = 1.0
106+
"""Exploration constant used by MCTSScheduler (PUCT)."""
107+
108+
enable_score_reward: bool = False
109+
"""Enable using score-based reward for trace selection in multi-trace scheduling."""
110+
104111
#### multi-trace:checkpoint selector
105112
selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.select.expand.LatestCKPSelector"
106113
"""The name of the selector to use"""

rdagent/core/proposal.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,14 @@ async def async_gen(self, trace: Trace, loop: LoopBase) -> Experiment:
325325
return self.gen(trace)
326326
await asyncio.sleep(1)
327327

328+
def reset(self) -> None:
    """
    Restore the generator to its initial state.

    The main loop may occasionally want to restart the whole proposal
    process from scratch; subclasses that hold internal state should
    override this hook. The base implementation is a no-op.
    """
335+
328336

329337
class HypothesisGen(ABC):
330338

rdagent/scenarios/data_science/loop.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from rdagent.scenarios.data_science.proposal.exp_gen.base import DataScienceScen
3535
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSKnowledgeBase
3636
from rdagent.scenarios.data_science.proposal.exp_gen.proposal import DSProposalV2ExpGen
37+
from rdagent.scenarios.data_science.proposal.exp_gen.trace_scheduler import (
38+
MCTSScheduler,
39+
)
3740
from rdagent.utils.workflow.misc import wait_retry
3841

3942

@@ -246,6 +249,7 @@ def record(self, prev_out: dict[str, Any]):
246249
),
247250
cur_loop_id,
248251
)
252+
# Value backpropagation is handled in async_gen before next() via observe_commits
249253

250254
if self.trace.sota_experiment() is None:
251255
if DS_RD_SETTING.coder_on_whole_pipeline:
@@ -271,6 +275,8 @@ def record(self, prev_out: dict[str, Any]):
271275
logger.error("Consecutive errors reached the limit. Dumping trace.")
272276
logger.log_object(self.trace, tag="trace before restart")
273277
self.trace = DSTrace(scen=self.trace.scen, knowledge_base=self.trace.knowledge_base)
278+
# Reset the trace; MCTS stats will be cleared via registered callback
279+
self.exp_gen.reset()
274280

275281
# set the SOTA experiment to submit
276282
sota_exp_to_submit = self.sota_exp_selector.get_sota_exp_to_submit(self.trace)

rdagent/scenarios/data_science/proposal/exp_gen/router/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from rdagent.scenarios.data_science.proposal.exp_gen.proposal import DSProposalV2ExpGen
2323
from rdagent.scenarios.data_science.proposal.exp_gen.trace_scheduler import (
24+
MCTSScheduler,
2425
RoundRobinScheduler,
2526
SOTABasedScheduler,
2627
TraceScheduler,
@@ -63,6 +64,9 @@ def gen(
6364
"ParallelMultiTraceExpGen is designed for async usage, please call async_gen instead."
6465
)
6566

67+
def reset(self) -> None:
    """Delegate the reset request to the underlying trace scheduler."""
    scheduler = self.trace_scheduler
    scheduler.reset()
69+
6670
async def async_gen(self, trace: DSTrace, loop: LoopBase) -> DSExperiment:
6771
"""
6872
Waits for a free execution slot, selects a parent trace using the

rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from collections import defaultdict
88
from typing import TYPE_CHECKING
99

10+
from rdagent.app.data_science.conf import DS_RD_SETTING
1011
from rdagent.log import rdagent_logger as logger
12+
from rdagent.scenarios.kaggle.kaggle_crawler import get_metric_direction
1113

1214
if TYPE_CHECKING:
1315
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
@@ -38,6 +40,12 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
3840
"""
3941
raise NotImplementedError
4042

43+
def reset(self) -> None:
    """Restore the scheduler to its initial state (no-op by default)."""
48+
4149

4250
class BaseScheduler(TraceScheduler):
4351
def __init__(self):
@@ -49,7 +57,10 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
4957
Atomically selects the next leaf node from the trace in order.
5058
"""
5159
while True:
52-
# step 0: Commit the pending selections
60+
# step 1: Commit the pending selections
61+
self.process_uncommitted_nodes(trace)
62+
63+
# step 2: update uncommited_rec_status & rec_commit_idx
5364
for i in range(self.rec_commit_idx, len(trace.dag_parent)):
5465
parent_of_i = trace.dag_parent[i]
5566
if parent_of_i == trace.NEW_ROOT:
@@ -71,11 +82,22 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
7182

7283
await asyncio.sleep(1)
7384

85+
def process_uncommitted_nodes(self, trace: DSTrace) -> None:
    """
    Hook for custom handling of nodes that have not been committed yet.

    Subclasses may override this to react to newly appended trace entries;
    the surrounding loop keeps `uncommited_rec_status` and `rec_commit_idx`
    up to date on its own, so the base implementation does nothing.
    """
7492
@abstractmethod
7593
def select(self, trace: DSTrace) -> tuple[int, ...] | None:
7694
"""Selects the parent nodes for the new experiment, or None if no selection can be made."""
7795
raise NotImplementedError
7896

97+
def reset(self) -> None:
    """Drop all pending-selection bookkeeping and restart commit tracking."""
    self.rec_commit_idx = 0
    self.uncommited_rec_status = defaultdict(int)
100+
79101

80102
class RoundRobinScheduler(BaseScheduler):
81103
"""
@@ -289,3 +311,138 @@ def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
289311
Return random potential for uniform random selection.
290312
"""
291313
return random.random()
314+
315+
316+
class MCTSScheduler(ProbabilisticScheduler):
    """
    A simplified MCTS-based scheduler using a PUCT-like scoring rule.

    Formula:
        U(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))
    where Q is the average reward, N is the visit count, P is the prior
    probability, and c_puct balances exploration against exploitation.

    Design notes for this initial version:
    - Reuses ProbabilisticScheduler's potential calculation as prior P (via softmax).
    - Maintains per-node visit/value statistics to compute Q and U.
    - Only the global counter is bumped at selection time; per-node value and
      visit updates are deferred to ``observe_feedback`` once committed.
    - NEW_ROOT policy and uncommitted-status handling follow the base classes.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs):
        super().__init__(max_trace_num, temperature)
        # Read c_puct from settings if available, otherwise fall back to 1.0.
        # NOTE(review): `or 1.0` also replaces an explicit 0.0 with 1.0 —
        # presumably intentional (never allow zero exploration); confirm.
        self.c_puct = getattr(DS_RD_SETTING, "scheduler_c_puct", 1.0) or 1.0
        # Statistics keyed by node index into trace.hist.
        self.node_visit_count: dict[int, int] = {}
        self.node_value_sum: dict[int, float] = {}
        self.node_prior: dict[int, float] = {}
        # Global selection counter used in the sqrt(N(s)) exploration term.
        self.global_visit_count: int = 0
        # Index of the first trace entry not yet seen by observe_feedback.
        self.last_observed_commit_idx: int = 0

    def _get_q(self, node_id: int) -> float:
        """Average observed reward for ``node_id`` (0.0 for unvisited nodes)."""
        visits = self.node_visit_count.get(node_id, 0)
        if visits <= 0:
            # Unseen nodes default to neutral Q.
            return 0.0
        return self.node_value_sum.get(node_id, 0.0) / visits

    def _get_u(self, node_id: int) -> float:
        """Exploration bonus: c_puct * P * sqrt(N_global) / (1 + N_node)."""
        prior = self.node_prior.get(node_id, 0.0)
        visits = self.node_visit_count.get(node_id, 0)
        # max(1, ...) keeps the bonus non-zero before any selection happened
        # and avoids div-by-zero concerns in the denominator via (1 + visits).
        return self.c_puct * prior * math.sqrt(max(1, self.global_visit_count)) / (1 + visits)

    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """Select the parent node for the next experiment via argmax(Q + U)."""
        # Step 1: same policy as the base class — keep spawning new root
        # traces until the target number of parallel traces is reached.
        # TODO: expanding from the virtual root node is implemented in a rule-based way.
        if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
            return trace.NEW_ROOT

        # Step 2: candidate set.
        # NOTE(review): this enumerates *every* node in trace.hist; nodes
        # currently being expanded are not filtered out — confirm intended.
        candidates = list(range(len(trace.hist)))
        if not candidates:
            return None

        # Step 3: compute priors (P) from potentials via softmax.
        potentials = [self.calculate_potential(trace, leaf) for leaf in candidates]
        if any(p < 0 for p in potentials):
            raise ValueError("Potential function returned a negative value.")
        priors = self._softmax_probabilities(potentials)
        for leaf, prior in zip(candidates, priors):
            self.node_prior[leaf] = prior

        # Step 4: score each candidate with the PUCT-like rule: Q + U.
        best_leaf = None
        best_score = -float("inf")
        for leaf in candidates:
            score = self._get_q(leaf) + self._get_u(leaf)
            if score > best_score:
                best_score = score
                best_leaf = leaf

        if best_leaf is None:
            return None

        # Step 5: bump the global counter on selection; per-node value/visit
        # updates are deferred to observe_feedback after commit.
        self.global_visit_count += 1

        return (best_leaf,)

    def observe_feedback(self, trace: DSTrace, new_idx: int, reward: float | None = None) -> None:
        """
        Back-propagate a reward along the ancestry of a committed experiment.

        Args:
            trace: The DSTrace object.
            new_idx: Index of the newly appended experiment in trace.hist.
            reward: Optional explicit reward. If None, derive it from the
                feedback: either a score-based reward (when
                ``enable_score_reward`` is set) or 1.0/0.0 from
                ``feedback.decision``.
        """
        if reward is None:
            if 0 <= new_idx < len(trace.hist):
                exp, fb = trace.hist[new_idx]
                if DS_RD_SETTING.enable_score_reward:
                    bigger_is_better = get_metric_direction(trace.scen.competition)
                    if getattr(fb, "decision", False):
                        # Squash the validation score into (-1, 1); flip the
                        # sign for smaller-is-better metrics so that better
                        # scores always map to larger rewards.
                        reward = math.tanh(exp.result.loc["ensemble"].iloc[0].round(3)) * (
                            1 if bigger_is_better else -1
                        )
                    else:
                        # Failed experiments get the minimum reward regardless
                        # of metric direction. (Previously this was +1 for
                        # smaller-is-better metrics, which ranked failures
                        # above every successful experiment.)
                        reward = -1.0
                else:
                    reward = 1.0 if getattr(fb, "decision", False) else 0.0
            else:
                # Out-of-range safety.
                reward = 0.0

        # Back-propagate along the path from the root to the new node.
        for parent_id in trace.get_parents(new_idx):
            self.node_value_sum[parent_id] = self.node_value_sum.get(parent_id, 0.0) + float(reward)
            self.node_visit_count[parent_id] = self.node_visit_count.get(parent_id, 0) + 1

    def reset(self) -> None:
        """
        Clear all maintained statistics. Should be called when the underlying trace is reset.
        """
        super().reset()
        self.node_visit_count.clear()
        self.node_value_sum.clear()
        self.node_prior.clear()
        self.global_visit_count = 0
        self.last_observed_commit_idx = 0

    def process_uncommitted_nodes(self, trace: DSTrace) -> None:
        """
        Batch-observe every experiment committed since the last observation.

        Called by BaseScheduler.next() before each new selection so the
        Q/N statistics stay up to date.
        """
        start_idx = max(0, self.last_observed_commit_idx)
        # Only observe fully committed items (present in both dag_parent and hist).
        end_idx = min(len(trace.dag_parent), len(trace.hist))
        if start_idx >= end_idx:
            return
        for idx in range(start_idx, end_idx):
            self.observe_feedback(trace, idx)
        self.last_observed_commit_idx = end_idx

0 commit comments

Comments
 (0)