Skip to content

Commit b9fb7cf

Browse files
RolandMinruiXuWinstonLiyt
authored
fix: add wait_retry to exp_gen v2 (#783)
* add wait retry to v2 * format * fix a bug --------- Co-authored-by: Xu <[email protected]> Co-authored-by: yuanteli <[email protected]>
1 parent ac008a6 commit b9fb7cf

File tree

1 file changed

+15
-11
lines changed
  • rdagent/scenarios/data_science/proposal/exp_gen

1 file changed

+15
-11
lines changed

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,7 @@ def identify_scenario_problem(self, scenario_desc: str, competition_desc: str, s
242242
)
243243
return json.loads(response)
244244

245-
def identify_feedback_problem(
246-
self,
247-
scenario_desc: str,
248-
exp_feedback_list_desc: str,
249-
sota_exp_desc: str,
250-
pipeline: bool,
251-
) -> Dict:
245+
def identify_feedback_problem(self, scenario_desc: str, exp_feedback_list_desc: str, sota_exp_desc: str) -> Dict:
252246
sys_prompt = T(".prompts_v2:scenario_problem.system").r(
253247
problem_spec=T(".prompts_v2:specification.problem").r(),
254248
problem_output_format=T(".prompts_v2:output_format.problem").r(),
@@ -266,6 +260,13 @@ def identify_feedback_problem(
266260
)
267261
return json.loads(response)
268262

263+
def _append_retry(args: tuple, kwargs: dict) -> tuple[tuple, dict]:
264+
# Only modify the user_prompt on retries (i > 0)
265+
user_prompt = args[0]
266+
user_prompt += "\n\nretrying..."
267+
return (user_prompt,), kwargs
268+
269+
@wait_retry(retry_n=5, transform_args_fn=_append_retry)
269270
def hypothesis_gen(
270271
self,
271272
component_desc: str,
@@ -293,7 +294,13 @@ def hypothesis_gen(
293294
json_mode=True,
294295
json_target_type=Dict[str, Dict[str, str | Dict[str, str | int]]],
295296
)
296-
return json.loads(response)
297+
resp_dict = json.loads(response)
298+
for key, value in resp_dict.items():
299+
assert "reason" in value, "Reason not provided."
300+
assert "component" in value, "Component not provided."
301+
assert "hypothesis" in value, "Hypothesis not provided."
302+
assert "evaluation" in value, "Evaluation not provided."
303+
return resp_dict
297304

298305
def hypothesis_rank(self, hypothesis_dict: dict, problem_dict: dict, pipeline: bool) -> DSHypothesis:
299306
weights = {
@@ -361,7 +368,6 @@ def task_gen(
361368
component_desc=component_desc,
362369
workflow_check=not pipeline and hypothesis.component != "Workflow",
363370
)
364-
365371
user_prompt = T(".prompts_v2:task_gen.user").r(
366372
scenario_desc=scenario_desc,
367373
sota_exp_desc=sota_exp_desc,
@@ -394,7 +400,6 @@ def task_gen(
394400
# exp.experiment_workspace.inject_code_from_folder(sota_exp.experiment_workspace.workspace_path)
395401
if sota_exp is not None:
396402
exp.experiment_workspace.inject_code_from_file_dict(sota_exp.experiment_workspace)
397-
398403
if not pipeline and new_workflow_desc != "No update needed":
399404
workflow_task = WorkflowTask(
400405
name="Workflow",
@@ -442,7 +447,6 @@ def gen(self, trace: DSTrace, pipeline: bool = False) -> DSExperiment:
442447
scenario_desc=scenario_desc,
443448
exp_feedback_list_desc=exp_feedback_list_desc,
444449
sota_exp_desc=sota_exp_desc,
445-
pipeline=pipeline,
446450
)
447451
all_problems = {**scen_problems, **fb_problems}
448452

0 commit comments

Comments
 (0)