@@ -242,13 +242,7 @@ def identify_scenario_problem(self, scenario_desc: str, competition_desc: str, s
242242 )
243243 return json .loads (response )
244244
245- def identify_feedback_problem (
246- self ,
247- scenario_desc : str ,
248- exp_feedback_list_desc : str ,
249- sota_exp_desc : str ,
250- pipeline : bool ,
251- ) -> Dict :
245+ def identify_feedback_problem (self , scenario_desc : str , exp_feedback_list_desc : str , sota_exp_desc : str ) -> Dict :
252246 sys_prompt = T (".prompts_v2:scenario_problem.system" ).r (
253247 problem_spec = T (".prompts_v2:specification.problem" ).r (),
254248 problem_output_format = T (".prompts_v2:output_format.problem" ).r (),
@@ -266,6 +260,13 @@ def identify_feedback_problem(
266260 )
267261 return json .loads (response )
268262
263+ def _append_retry (args : tuple , kwargs : dict ) -> tuple [tuple , dict ]:
264+ # Argument transformer handed to wait_retry as transform_args_fn: appends a retry marker to args[0] and returns it as the sole positional argument with kwargs unchanged. The "retries only (i > 0)" gating is not done here; presumably wait_retry enforces it (TODO confirm). NOTE(review): positional args after the first are dropped, and on a decorated method args[0] may be `self` rather than the prompt string; verify against wait_retry.
265+ user_prompt = args [0 ]
266+ user_prompt += "\n \n retrying..."
267+ return (user_prompt ,), kwargs
268+
269+ @wait_retry (retry_n = 5 , transform_args_fn = _append_retry )
269270 def hypothesis_gen (
270271 self ,
271272 component_desc : str ,
@@ -293,7 +294,13 @@ def hypothesis_gen(
293294 json_mode = True ,
294295 json_target_type = Dict [str , Dict [str , str | Dict [str , str | int ]]],
295296 )
296- return json .loads (response )
297+ resp_dict = json .loads (response )
298+ for key , value in resp_dict .items ():
299+ assert "reason" in value , "Reason not provided."
300+ assert "component" in value , "Component not provided."
301+ assert "hypothesis" in value , "Hypothesis not provided."
302+ assert "evaluation" in value , "Evaluation not provided."
303+ return resp_dict
297304
298305 def hypothesis_rank (self , hypothesis_dict : dict , problem_dict : dict , pipeline : bool ) -> DSHypothesis :
299306 weights = {
@@ -361,7 +368,6 @@ def task_gen(
361368 component_desc = component_desc ,
362369 workflow_check = not pipeline and hypothesis .component != "Workflow" ,
363370 )
364-
365371 user_prompt = T (".prompts_v2:task_gen.user" ).r (
366372 scenario_desc = scenario_desc ,
367373 sota_exp_desc = sota_exp_desc ,
@@ -394,7 +400,6 @@ def task_gen(
394400 # exp.experiment_workspace.inject_code_from_folder(sota_exp.experiment_workspace.workspace_path)
395401 if sota_exp is not None :
396402 exp .experiment_workspace .inject_code_from_file_dict (sota_exp .experiment_workspace )
397-
398403 if not pipeline and new_workflow_desc != "No update needed" :
399404 workflow_task = WorkflowTask (
400405 name = "Workflow" ,
@@ -442,7 +447,6 @@ def gen(self, trace: DSTrace, pipeline: bool = False) -> DSExperiment:
442447 scenario_desc = scenario_desc ,
443448 exp_feedback_list_desc = exp_feedback_list_desc ,
444449 sota_exp_desc = sota_exp_desc ,
445- pipeline = pipeline ,
446450 )
447451 all_problems = {** scen_problems , ** fb_problems }
448452
0 commit comments