 from collections import defaultdict
 from typing import TYPE_CHECKING

+from rdagent.app.data_science.conf import DS_RD_SETTING
 from rdagent.log import rdagent_logger as logger
+from rdagent.scenarios.kaggle.kaggle_crawler import get_metric_direction

 if TYPE_CHECKING:
     from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
@@ -38,6 +40,12 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
         """
         raise NotImplementedError

+    def reset(self) -> None:
+        """
+        Reset the scheduler to the initial state.
+        """
+        pass
+

 class BaseScheduler(TraceScheduler):
     def __init__(self):
@@ -49,7 +57,10 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
         Atomically selects the next leaf node from the trace in order.
         """
         while True:
-            # step 0: Commit the pending selections
+            # step 1: Commit the pending selections
+            self.process_uncommitted_nodes(trace)
+
+            # step 2: update uncommited_rec_status & rec_commit_idx
             for i in range(self.rec_commit_idx, len(trace.dag_parent)):
                 parent_of_i = trace.dag_parent[i]
                 if parent_of_i == trace.NEW_ROOT:
@@ -71,11 +82,22 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:

             await asyncio.sleep(1)

+    def process_uncommitted_nodes(self, trace: DSTrace) -> None:
+        """
+        A slot for implementing custom logic to process uncommitted nodes.
+
+        `uncommited_rec_status` & `rec_commit_idx` will be updated automatically.
+        """
+
     @abstractmethod
     def select(self, trace: DSTrace) -> tuple[int, ...] | None:
         """Selects the parent nodes for the new experiment, or None if no selection can be made."""
         raise NotImplementedError

+    def reset(self) -> None:
+        self.uncommited_rec_status = defaultdict(int)
+        self.rec_commit_idx = 0
+

 class RoundRobinScheduler(BaseScheduler):
     """
@@ -289,3 +311,138 @@ def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
         Return random potential for uniform random selection.
         """
         return random.random()
+
+
+class MCTSScheduler(ProbabilisticScheduler):
+    """
+    A simplified MCTS-based scheduler using a PUCT-like scoring rule.
+
+    Formula:
+        score(s, a) = Q(s, a) + U(s, a)
+        U(s, a) = c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))
+        where Q(s, a) is the average reward of the node, N(s, a) is its visit count,
+        N(s) is the global visit count, P(s, a) is the prior probability, and c_puct
+        is a weight that balances exploration against exploitation.
+
+    Design goals for the initial version:
+    - Reuse ProbabilisticScheduler's potential calculation as prior P (via softmax).
+    - Maintain visit/value statistics per leaf to compute Q and U.
+    - Update visits on selection; update values after feedback via observe_feedback.
+    - Keep NEW_ROOT policy and uncommitted status handling identical to base classes.
+    """
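+
+    # Worked example of the scoring rule (all numbers hypothetical): with c_puct = 1.0,
+    # prior P = 0.25, per-node visits = 3, value_sum = 2.4, and a global visit count of 16:
+    #   Q = 2.4 / 3 = 0.8
+    #   U = 1.0 * 0.25 * sqrt(16) / (1 + 3) = 0.25
+    # giving a score of Q + U = 1.05. Note that N(s) here is a single global counter,
+    # a simplification of textbook PUCT, which uses the parent's visit count.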
+
+    def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs):
+        super().__init__(max_trace_num, temperature)
+        # Read c_puct from settings if available, otherwise fall back to default 1.0
+        self.c_puct = getattr(DS_RD_SETTING, "scheduler_c_puct", 1.0) or 1.0
+        # Statistics keyed by leaf node index
+        self.node_visit_count: dict[int, int] = {}
+        self.node_value_sum: dict[int, float] = {}
+        self.node_prior: dict[int, float] = {}
+        # Global counter to stabilize U term
+        self.global_visit_count: int = 0
+        # Last observed commit index for batch feedback observation
+        self.last_observed_commit_idx: int = 0
+
+    def _get_q(self, node_id: int) -> float:
+        visits = self.node_visit_count.get(node_id, 0)
+        value_sum = self.node_value_sum.get(node_id, 0.0)
+        if visits <= 0:
+            # Unseen nodes default to neutral Q
+            return 0.0
+        return value_sum / visits
+
+    def _get_u(self, node_id: int) -> float:
+        prior = self.node_prior.get(node_id, 0.0)
+        visits = self.node_visit_count.get(node_id, 0)
+        # Avoid div-by-zero; encourage exploration when visits are small
+        return self.c_puct * prior * math.sqrt(max(1, self.global_visit_count)) / (1 + visits)
+
+    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
+        # Step 1: keep same policy to reach target number of parallel traces
+        # TODO: expanding from the virtual root node is implemented in a rule-based way.
+        if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
+            return trace.NEW_ROOT
+
+        # Step 2: consider only available leaves (not being expanded)
+        available_leaves = list(set(range(len(trace.hist))))
+        if not available_leaves:
+            return None
+
+        # Step 3: compute priors (P) from potentials via softmax
+        potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves]
+        if any(p < 0 for p in potentials):
+            raise ValueError("Potential function returned a negative value.")
+        priors = self._softmax_probabilities(potentials)
+        for leaf, p in zip(available_leaves, priors):
+            self.node_prior[leaf] = p
+
+        # Step 4: score each leaf using the PUCT-like rule: Q + U
+        best_leaf = None
+        best_score = -float("inf")
+        for leaf in available_leaves:
+            q = self._get_q(leaf)
+            u = self._get_u(leaf)
+            score = q + u
+            if score > best_score:
+                best_score = score
+                best_leaf = leaf
+
+        if best_leaf is None:
+            return None
+
+        # Step 5: count this selection toward the global visit total; per-node visit and
+        # value updates are deferred to observe_feedback
+        self.global_visit_count += 1
+
+        return (best_leaf,)
+
+    def observe_feedback(self, trace: DSTrace, new_idx: int, reward: float | None = None) -> None:
+        """
+        Update statistics after an experiment is committed to the trace.
+
+        Args:
+            trace: The DSTrace object.
+            new_idx: Index of the newly appended experiment in trace.hist.
+            reward: Optional explicit reward. If None, it is derived from the stored feedback:
+                a score-based reward (tanh of the ensemble score, sign-adjusted by the metric
+                direction) when enable_score_reward is set, otherwise 1.0/0.0 from feedback.decision.
+        """
+        if reward is None:
+            if 0 <= new_idx < len(trace.hist):
+                re, fb = trace.hist[new_idx]
+                if DS_RD_SETTING.enable_score_reward:
+                    bigger_is_better = get_metric_direction(trace.scen.competition)
+                    if getattr(fb, "decision", False):
+                        reward = math.tanh(re.result.loc["ensemble"].iloc[0].round(3)) * (1 if bigger_is_better else -1)
+                    else:
+                        reward = -1 if bigger_is_better else 1
+                else:
+                    reward = 1.0 if getattr(fb, "decision", False) else 0.0
+            else:
+                # Out-of-range safety
+                reward = 0.0
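+        # Example (hypothetical numbers): an accepted run with ensemble score 0.9 on a
+        # maximizing metric yields reward = tanh(0.9) ≈ 0.716, while a rejected run on
+        # the same metric yields reward = -1.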
+
+        id_list = trace.get_parents(new_idx)
+        for id in id_list:
+            self.node_value_sum[id] = self.node_value_sum.get(id, 0.0) + float(reward)
+            self.node_visit_count[id] = self.node_visit_count.get(id, 0) + 1
+
+    def reset(self) -> None:
+        """
+        Clear all maintained statistics. Should be called when the underlying trace is reset.
+        """
+        super().reset()
+        self.node_visit_count.clear()
+        self.node_value_sum.clear()
+        self.node_prior.clear()
+        self.global_visit_count = 0
+        self.last_observed_commit_idx = 0
+
+    def process_uncommitted_nodes(self, trace: DSTrace) -> None:
+        """
+        Batch observe all newly committed experiments since the last observation.
+        Should be called before making a new selection to ensure statistics are up to date.
+        """
+        start_idx = max(0, self.last_observed_commit_idx)
+        # Only observe fully committed items (both dag_parent and hist appended)
+        end_idx = min(len(trace.dag_parent), len(trace.hist))
+        if start_idx >= end_idx:
+            return
+        for idx in range(start_idx, end_idx):
+            self.observe_feedback(trace, idx)
+        self.last_observed_commit_idx = end_idx
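
For reference, the PUCT rule above can be exercised in isolation. The following is a minimal, self-contained sketch of MCTSScheduler's selection math (softmax priors over potentials, then Q + U per candidate), not the repository's code: the potentials, visit counts, and value sums are made-up toy numbers, the softmax only approximates what _softmax_probabilities is assumed to do, and no DSTrace is involved.

import math

def softmax(xs: list[float], temperature: float = 1.0) -> list[float]:
    # Numerically stable softmax over potentials (prior P per candidate).
    scaled = [x / temperature for x in xs]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def puct_score(prior: float, visits: int, value_sum: float, global_visits: int, c_puct: float = 1.0) -> float:
    q = value_sum / visits if visits > 0 else 0.0  # average reward; neutral for unseen nodes
    u = c_puct * prior * math.sqrt(max(1, global_visits)) / (1 + visits)  # exploration bonus
    return q + u

# Toy candidates: leaf -> (visits, value_sum), plus an arbitrary potential per leaf.
stats = {0: (3, 2.4), 1: (1, 0.0), 2: (0, 0.0)}
priors = softmax([0.9, 0.5, 0.1])
global_visits = sum(v for v, _ in stats.values())

scores = {leaf: puct_score(priors[leaf], *stats[leaf], global_visits) for leaf in stats}
print(scores, "-> select leaf", max(scores, key=scores.get))

In the scheduler itself, the global visit count is incremented once per select() call, while per-node visit and value statistics are only updated when process_uncommitted_nodes() routes committed results through observe_feedback().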