11from __future__ import annotations
22
33import asyncio
4+ import math
5+ import random
46from abc import ABC , abstractmethod
57from collections import defaultdict
68from typing import TYPE_CHECKING
79
10+ from rdagent .log import rdagent_logger as logger
11+
812if TYPE_CHECKING :
913 from rdagent .scenarios .data_science .proposal .exp_gen .base import DSTrace
1014
@@ -22,7 +26,7 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
2226
2327 For proposing selections, we have to follow the rules
2428 - Suggest selection: suggest a selection that is suitable for the current trace.
25- - Suggested should be garenteed to be recorded at last!!!
29+ - Suggested should be guaranteed to be recorded at last!!!!
2630 - If no suitable selection is found, the function should async wait!!!!
2731
2832 Args:
@@ -35,17 +39,8 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
3539 raise NotImplementedError
3640
3741
38- class RoundRobinScheduler (TraceScheduler ):
39- """
40- A concurrency-safe scheduling strategy that cycles through active traces
41- in a round-robin fashion.
42-
43- NOTE: we don't need to use asyncio.Lock here as the kickoff_loop ensures the ExpGen is always sequential, instead of parallel.
44- """
45-
46- def __init__ (self , max_trace_num : int ):
47- self .max_trace_num = max_trace_num
48- self ._last_selected_leaf_id = - 1
42+ class BaseScheduler (TraceScheduler ):
43+ def __init__ (self ):
4944 self .rec_commit_idx = 0 # the node before rec_idx is already committed.
5045 self .uncommited_rec_status = defaultdict (int ) # the uncommited record status
5146
@@ -56,25 +51,241 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
5651 while True :
5752 # step 0: Commit the pending selections
5853 for i in range (self .rec_commit_idx , len (trace .dag_parent )):
59-
60- if trace . dag_parent [ i ] == trace .NEW_ROOT :
54+ parent_of_i = trace . dag_parent [ i ]
55+ if parent_of_i == trace .NEW_ROOT :
6156 self .uncommited_rec_status [trace .NEW_ROOT ] -= 1
6257 else :
63- for p in trace . dag_parent [ i ] :
58+ for p in parent_of_i :
6459 self .uncommited_rec_status [p ] -= 1
65-
6660 self .rec_commit_idx = len (trace .hist )
6761
68- # step 1: select the parant trace to expand
69- # Policy: if we have fewer traces than our target, start a new one.
70- if trace .sub_trace_count + self .uncommited_rec_status [trace .NEW_ROOT ] < self .max_trace_num :
71- self .uncommited_rec_status [trace .NEW_ROOT ] += 1
72- return trace .NEW_ROOT
73-
74- # Step2: suggest a selection to a not expanding leave
75- leaves = trace .get_leaves ()
76- for leaf in leaves :
77- if self .uncommited_rec_status [leaf ] == 0 :
78- self .uncommited_rec_status [leaf ] += 1
79- return (leaf ,)
62+ parents = self .select (trace )
63+
64+ if parents is not None :
65+ if parents == trace .NEW_ROOT :
66+ self .uncommited_rec_status [trace .NEW_ROOT ] += 1
67+ else :
68+ for p in parents :
69+ self .uncommited_rec_status [p ] += 1
70+ return parents
71+
8072 await asyncio .sleep (1 )
73+
74+ @abstractmethod
75+ def select (self , trace : DSTrace ) -> tuple [int , ...] | None :
76+ """Selects the parent nodes for the new experiment, or None if no selection can be made."""
77+ raise NotImplementedError
78+
79+
class RoundRobinScheduler(BaseScheduler):
    """
    Scheduling strategy that first opens new traces until a target count is
    reached, then hands out the first leaf without a pending selection.

    NOTE: no asyncio.Lock is used here; the kickoff_loop is expected to keep ExpGen sequential instead of parallel.
    """

    def __init__(self, max_trace_num: int, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            *args, **kwargs: Accepted and ignored.
        """
        logger.info(f"RoundRobinScheduler: max_trace_num={ max_trace_num } ")
        self.max_trace_num = max_trace_num
        # Stored for compatibility; not read anywhere inside this class.
        self._last_selected_leaf_id = -1
        super().__init__()

    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """
        Pick the parent for the next experiment.

        Returns:
            ``trace.NEW_ROOT`` while existing plus pending new traces are below
            ``max_trace_num``; otherwise a 1-tuple holding the first leaf whose
            uncommitted counter is zero; ``None`` when every leaf is taken.
        """
        # Below the target number of traces (counting pending roots): open a new one.
        if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
            return trace.NEW_ROOT

        # Otherwise take the first leaf that nobody is currently expanding.
        for candidate in trace.get_leaves():
            if self.uncommited_rec_status[candidate] != 0:
                continue
            return (candidate,)

        return None
110+
111+
112+ # ======================================================================================
113+ # Probabilistic Scheduler and its potential functions
114+ # ======================================================================================
115+
116+
class ProbabilisticScheduler(BaseScheduler):
    """
    Scheduling strategy that samples the next leaf to expand from a softmax
    distribution over per-leaf potential scores.

    Subclasses customise the behaviour by overriding ``calculate_potential``.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Softmax temperature; potentials are divided by it, so
                larger values flatten the resulting distribution.
            *args, **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``max_trace_num`` or ``temperature`` is not positive.
        """
        if max_trace_num <= 0:
            raise ValueError("max_trace_num must be positive.")
        if temperature <= 0:
            raise ValueError("temperature must be positive.")

        self.max_trace_num = max_trace_num
        self.temperature = temperature
        super().__init__()

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Score a leaf; a higher score makes the leaf more likely to be sampled.

        The base implementation returns the same constant for every leaf,
        which yields a uniform distribution.

        Args:
            trace: The DSTrace object holding the experiment history.
            leaf_id: The index of the leaf node to evaluate.

        Returns:
            float: The potential score of the leaf.
        """
        return 1.0

    def _softmax_probabilities(self, potentials: list[float]) -> list[float]:
        """
        Turn potential scores into softmax probabilities.

        Args:
            potentials: List of potential scores.

        Returns:
            A list of probabilities aligned with ``potentials`` (empty when
            ``potentials`` is empty).
        """
        if not potentials:
            return []

        # Temperature scaling happens before exponentiation.
        scaled = [value / self.temperature for value in potentials]

        # Shift by the maximum so the largest exponent is exp(0) == 1.
        peak = max(scaled)
        weights = [math.exp(value - peak) for value in scaled]
        total = sum(weights)

        if total == 0:
            # Defensive fallback to a uniform distribution; with finite inputs
            # the maximum term alone contributes 1 to the sum.
            uniform = 1.0 / len(potentials)
            return [uniform] * len(potentials)

        return [weight / total for weight in weights]

    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """
        Pick the parent for the next experiment by probabilistic sampling.

        Returns:
            ``trace.NEW_ROOT`` while existing plus pending new traces are below
            ``max_trace_num``; otherwise a 1-tuple with a sampled leaf whose
            uncommitted counter is zero; ``None`` when no such leaf exists.

        Raises:
            ValueError: If ``calculate_potential`` yields a negative score.
        """
        # Reaching the desired number of traces takes priority over expansion.
        if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
            return trace.NEW_ROOT

        # Only leaves without a pending selection are eligible.
        candidates = [leaf for leaf in trace.get_leaves() if self.uncommited_rec_status[leaf] == 0]
        if not candidates:
            return None

        scores = [self.calculate_potential(trace, leaf) for leaf in candidates]
        for score in scores:
            if score < 0:
                raise ValueError("Potential function returned a negative value.")

        probabilities = self._softmax_probabilities(scores)

        # Draw exactly one leaf according to the computed distribution.
        (chosen,) = random.choices(candidates, weights=probabilities, k=1)
        return (chosen,)
208+
209+
class TraceLengthScheduler(ProbabilisticScheduler):
    """
    A scheduler whose potential depends on how long a leaf's trace is.

    -- default: longer traces (more experiments) get a higher potential.
    -- if inverse=True, shorter traces get a higher potential.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, inverse: bool = False, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation.
            inverse: If True, shorter traces get higher potential.
            *args, **kwargs: Accepted and ignored.
        """
        logger.info(
            f"TraceLengthScheduler: max_trace_num={ max_trace_num } , temperature={ temperature } , inverse={ inverse } "
        )
        super().__init__(max_trace_num, temperature)
        self.inverse = inverse

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Score a leaf by the number of nodes ``trace.get_parents`` reports for it.

        Returns:
            ``1.0`` for an empty path; otherwise the path length, or its
            reciprocal when ``inverse`` is set.
        """
        # presumably get_parents returns the root-to-leaf path -- confirm against DSTrace
        depth = len(trace.get_parents(leaf_id))

        if depth == 0:
            return 1.0
        if self.inverse:
            return 1.0 / depth
        return float(depth)
243+
244+
class SOTABasedScheduler(ProbabilisticScheduler):
    """
    A scheduler whose potential depends on how many nodes along a leaf's path
    carry a truthy feedback decision (treated as SOTA results).
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, inverse: bool = False, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation.
            inverse: If True, fewer SOTA results get higher potential.
            *args, **kwargs: Accepted and ignored.
        """
        logger.info(f"SOTABasedScheduler: max_trace_num={ max_trace_num } , temperature={ temperature } , inverse={ inverse } ")
        super().__init__(max_trace_num, temperature)
        self.inverse = inverse

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Score a leaf by the number of accepted experiments on its path.

        Returns:
            The count as a float, or ``1 / (count + 1)`` when ``inverse`` is set.
        """
        accepted = 0

        for node_id in trace.get_parents(leaf_id):
            # Ids beyond the recorded history are skipped.
            if node_id >= len(trace.hist):
                continue
            _, feedback = trace.hist[node_id]
            if feedback.decision:
                accepted += 1

        if not self.inverse:
            return float(accepted)
        # The +1 avoids division by zero and gives a zero-count path the top score of 1.0.
        return 1.0 / (accepted + 1)
280+
281+
class RandomScheduler(ProbabilisticScheduler):
    """
    A scheduler that assigns every leaf a freshly drawn random potential.
    """

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Draw the potential from ``random.random()``; ``trace`` and ``leaf_id`` are not consulted.
        """
        potential = random.random()
        return potential
0 commit comments