
Commit 1ba7548

fix: merge datascience v3 and v2 (#974)
* add coder version
* merge coder and feedback prompts
* align v2 and v3 proposal prompts
* fix a small bug
* fix a bug
* fix another bug
* support both function calling and json mode in v2 proposal
* fix minor bug
* reformat
* remove proposal v3
* fix a small bug in json mode
* fix CI
* remove tmp file
* remove v3 check

---------

Co-authored-by: Xu Yang <[email protected]>
1 parent 923a326 commit 1ba7548

File tree: 11 files changed (+428 additions, -589 deletions)

rdagent/components/coder/data_science/pipeline/__init__.py

Lines changed: 6 additions & 15 deletions
@@ -98,21 +98,12 @@ def implement_one_task(
             spec=T("scenarios.data_science.share:component_spec.Pipeline").r(),
             enable_model_dump=DS_RD_SETTING.enable_model_dump,
         )
-        if DS_RD_SETTING.proposal_version == "v3":
-            # FIXME: A temporary patch for BUILD
-            user_prompt = T(".prompts:pipeline_coder.user_v3").r(
-                competition_info=competition_info,
-                folder_spec=data_folder_info,
-                latest_code=workspace.file_dict.get("main.py"),
-                latest_code_feedback=prev_task_feedback,
-            )
-        else:
-            user_prompt = T(".prompts:pipeline_coder.user").r(
-                competition_info=competition_info,
-                folder_spec=data_folder_info,
-                latest_code=workspace.file_dict.get("main.py"),
-                latest_code_feedback=prev_task_feedback,
-            )
+        user_prompt = T(".prompts:pipeline_coder.user").r(
+            competition_info=competition_info,
+            folder_spec=data_folder_info,
+            latest_code=workspace.file_dict.get("main.py"),
+            latest_code_feedback=prev_task_feedback,
+        )

         for _ in range(5):
             pipeline_code = PythonAgentOut.extract_output(

rdagent/components/coder/data_science/pipeline/prompts.yaml

Lines changed: 0 additions & 21 deletions
@@ -83,27 +83,6 @@ pipeline_coder:
     --------- Data Folder Description (All path are relative to the data folder) ---------
     {{ folder_spec }}

-    {% if latest_code %}
-    --------- Former code ---------
-    {{ latest_code }}
-    {% if latest_code_feedback is not none %}
-    --------- Feedback to former code ---------
-    {{ latest_code_feedback }}
-    The former code contains errors. You should correct the code based on the provided information, ensuring you do not repeat the same mistakes.
-    {% else %}
-    The former code is correct. You should try to improve the code based on the provided task while not changing the irrelevant parts.
-    {% endif %}
-    {% endif %}
-
-    You should strictly follow the code specifications provided by the specification to implement the function.
-
-  user_v3: |-
-    --------- Competition Information ---------
-    {{ competition_info }}
-
-    --------- Data Folder Description (All path are relative to the data folder) ---------
-    {{ folder_spec }}
-
     {% if latest_code %}
     --------- Former code ---------
     {{ latest_code }}
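After this merge, the former `user_v3` body simply becomes the tail of the shared `pipeline_coder.user` template: on a first attempt `latest_code` is None, so the former-code block renders to nothing. A standalone sketch of that branching, assuming the `T(...).r(...)` helper renders Jinja2 (which the `{% if %}` syntax suggests):

from jinja2 import Template

# reduced stand-in for the template tail shown above
tpl = Template(
    "{% if latest_code %}Former code: {{ latest_code }}"
    "{% if latest_code_feedback is not none %} / Feedback: {{ latest_code_feedback }}{% endif %}"
    "{% endif %}"
)
print(tpl.render(latest_code=None))  # first attempt: renders nothing
print(tpl.render(latest_code="print(1)", latest_code_feedback="IndexError on line 1"))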

rdagent/oai/backend/base.py

Lines changed: 7 additions & 0 deletions
@@ -499,6 +499,13 @@ def _create_embedding_with_cache(
         self.cache.embedding_set(content_to_embedding_dict)
         return [content_to_embedding_dict[content] for content in input_content_list]  # type: ignore[misc]

+    @abstractmethod
+    def support_function_calling(self) -> bool:
+        """
+        Check if the backend supports function calling
+        """
+        raise NotImplementedError("Subclasses must implement this method")
+
     @abstractmethod
     def _calculate_token_from_messages(self, messages: list[dict[str, Any]]) -> int:
         """

rdagent/oai/backend/deprec.py

Lines changed: 7 additions & 0 deletions
@@ -261,6 +261,13 @@ def _azure_patch(model: str) -> str:
             raise
         return encoding

+    def support_function_calling(self) -> bool:
+        """
+        Check if the backend supports function calling.
+        Currently, deprec backend does not support function calling so it returns False. #FIXME: maybe a mapping to the backend class is needed.
+        """
+        return False
+
     def _create_embedding_inner_function(  # type: ignore[no-untyped-def]
         self, input_content_list: list[str], *args, **kwargs
     ) -> list[list[float]]:  # noqa: ARG002
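The FIXME hints at replacing the hard-coded False with a per-backend lookup. A hypothetical sketch of that idea (class names are assumptions, not from this commit):

# hypothetical capability table keyed by backend class name
_SUPPORTS_FUNCTION_CALLING: dict[str, bool] = {
    "DeprecBackend": False,     # assumed name for this module's backend
    "LiteLLMAPIBackend": True,  # still gated by the model at runtime, see litellm.py
}

class DeprecBackend:  # stand-in so the sketch runs on its own
    def support_function_calling(self) -> bool:
        return _SUPPORTS_FUNCTION_CALLING.get(type(self).__name__, False)

assert DeprecBackend().support_function_calling() is False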

rdagent/oai/backend/litellm.py

Lines changed: 13 additions & 0 deletions
@@ -7,6 +7,7 @@
     completion,
     completion_cost,
     embedding,
+    supports_function_calling,
     supports_response_schema,
     token_counter,
 )
@@ -93,6 +94,12 @@ def _create_chat_completion_inner_function(  # type: ignore[no-untyped-def] # no
         """
         if json_mode and supports_response_schema(model=LITELLM_SETTINGS.chat_model):
             kwargs["response_format"] = {"type": "json_object"}
+        elif not supports_response_schema(model=LITELLM_SETTINGS.chat_model) and "response_format" in kwargs:
+            logger.warning(
+                f"{LogColors.RED}Model {LITELLM_SETTINGS.chat_model} does not support response schema, ignoring response_format argument.{LogColors.END}",
+                tag="llm_messages",
+            )
+            kwargs.pop("response_format")

         if LITELLM_SETTINGS.log_llm_chat_content:
             logger.info(self._build_log_messages(messages), tag="llm_messages")
@@ -183,3 +190,9 @@ def _create_chat_completion_inner_function(  # type: ignore[no-untyped-def] # no
                 tag="token_cost",
             )
         return content, finish_reason
+
+    def support_function_calling(self) -> bool:
+        """
+        Check if the backend supports function calling
+        """
+        return supports_function_calling(model=LITELLM_SETTINGS.chat_model) and LITELLM_SETTINGS.enable_function_call
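With the probe in place, a caller can pick between function calling and JSON mode at runtime, which is how the commit message describes the v2 proposal working ("support both function calling and json mode in v2 proposal"). A hedged sketch; the helper below is illustrative, not the proposal's actual code:

from rdagent.oai.backend.base import APIBackend  # used this way in feedback.py

def choose_structured_output_mode(backend: APIBackend) -> str:
    # prefer function calling when both the model and the settings allow it,
    # otherwise fall back to JSON mode
    return "function_calling" if backend.support_function_calling() else "json_mode"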

rdagent/oai/llm_conf.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@ class LLMSettings(ExtendedBaseSettings):
     embedding_model: str = "text-embedding-3-small"

     reasoning_effort: Literal["low", "medium", "high"] | None = None
+    enable_function_call: bool = (
+        True  # Whether to enable function calling in chat models. may not work for models that do not support it.
+    )

     # Handling format
     reasoning_think_rm: bool = False
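The new flag defaults to on, and litellm.py above ANDs it with the model's own capability. A minimal sketch of overriding it in code, assuming LLMSettings can be instantiated directly and its other fields have defaults or env values (env-var overrides via pydantic-settings should also work, but the exact variable prefix is not shown in this diff):

from rdagent.oai.llm_conf import LLMSettings

settings = LLMSettings(enable_function_call=False)  # force the JSON-mode fallback
assert settings.enable_function_call is False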

rdagent/scenarios/data_science/dev/feedback.py

Lines changed: 10 additions & 23 deletions
@@ -87,29 +87,16 @@ def generate_feedback(self, exp: DSExperiment, trace: DSTrace) -> ExperimentFeedback
         )

         eda_output = exp.experiment_workspace.file_dict.get("EDA.md", None)
-        if DS_RD_SETTING.proposal_version == "v3":
-            # FIXME: Some minor changes. Did not have time to test the full.
-            system_prompt = T(".prompts:exp_feedback_v3.system").r(
-                scenario=self.scen.get_scenario_all_desc(eda_output=eda_output)
-            )
-            user_prompt = T(".prompts:exp_feedback_v3.user").r(
-                sota_desc=sota_desc,
-                cur_exp=exp,
-                diff_edition=diff_edition,
-                feedback_desc=feedback_desc,
-                cur_vs_sota_score=cur_vs_sota_score,
-            )
-        else:
-            system_prompt = T(".prompts:exp_feedback.system").r(
-                scenario=self.scen.get_scenario_all_desc(eda_output=eda_output)
-            )
-            user_prompt = T(".prompts:exp_feedback.user").r(
-                sota_desc=sota_desc,
-                cur_exp=exp,
-                diff_edition=diff_edition,
-                feedback_desc=feedback_desc,
-                cur_vs_sota_score=cur_vs_sota_score,
-            )
+        system_prompt = T(".prompts:exp_feedback.system").r(
+            scenario=self.scen.get_scenario_all_desc(eda_output=eda_output)
+        )
+        user_prompt = T(".prompts:exp_feedback.user").r(
+            sota_desc=sota_desc,
+            cur_exp=exp,
+            diff_edition=diff_edition,
+            feedback_desc=feedback_desc,
+            cur_vs_sota_score=cur_vs_sota_score,
+        )

         resp_dict = json.loads(
             APIBackend().build_messages_and_create_chat_completion(
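The hunk ends mid-call; a hedged sketch of how such a call typically completes in RD-Agent (keyword names are assumptions, not read from this diff):

resp_dict = json.loads(
    APIBackend().build_messages_and_create_chat_completion(
        user_prompt=user_prompt,      # keyword names assumed, not shown in this hunk
        system_prompt=system_prompt,
        json_mode=True,               # assumed; json_mode handling is visible in litellm.py
    )
)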

rdagent/scenarios/data_science/dev/prompts.yaml

Lines changed: 7 additions & 3 deletions
@@ -15,7 +15,9 @@ exp_feedback:
       - Recommend corrective actions explicitly.
       - Set `"Replace Best Result": "no"`.
       - Begin your `reasoning` with `[Submission format error]`, clearly stating the issues causing experiment failure.
-      - If submission passes, proceed to Step 2.
+      - If submission passes the submission format check:
+        - If this is the first valid submission ever, set `"Replace Best Result": "yes"`.
+        - Otherwise, proceed to Step 2.

     Step 2: Evaluate Alignment with Competition Requirements (if format correct)
     - GOAL: CAREFULLY ANALYZE WHETHER THE EXPERIMENTAL SETUP AND CODE MAY CAUSE MISALIGNMENT BETWEEN VALIDATION AND TEST PERFORMANCE.
@@ -59,6 +61,8 @@ exp_feedback:
     Provide detailed and constructive feedback structured as follows:
     Example JSON Structure for Result Analysis:
     {
+      "Submission Format Check": "yes or no",
+      "First Valid Submission": "yes or no",
       "Observations": "Clearly summarize current and SOTA ensemble results with exact scores and notable patterns. Limit to no more than three concise, data-focused sentences. Your observation must be grounded by explicit evidence from scenario description or code implementation, not just validation scores.",
       "Feedback for Hypothesis": "Explicitly confirm or refute the hypothesis based on specific data points or performance trends. Limit to two sentences.",
       "Evaluation Aligned With Task": "yes or no",
@@ -110,11 +114,11 @@ exp_feedback:
     {{ cur_exp.experiment_workspace.all_codes }}

     ## Feedback of past experiments
-    {{ feedback_desc }}
+    {{ feedback_desc or "There has not been any experiments yet." }}
     Please refer to these hypotheses and feedback to help you recommend new experiment and hypothesis

     Tips:
-    - Step 1: If submission format has issues, prioritize fixing them before proceeding.
+    - Step 1: If submission format has issues, prioritize fixing them before proceeding. If the format is correct and it's the first valid submission ever (there has never been valid submissions in the past), set `"Replace Best Result": "yes"`. If the format is correct and this is not the first valid submission, proceed to Step 2.
     - Step 2: If evaluation alignment issues are identified (validation approach does not follow competition requirements), address these methodological discrepancies immediately.
     - Step 3: If new results significantly worse than SOTA, or repeated hyperparameter adjustments yield no improvement, it might be time to rethink or shift focus.
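The two new keys land in the same JSON reply that feedback.py decodes. A hedged sketch of consuming them (key names come from the template above; the surrounding logic is an assumption):

# resp_dict is the decoded feedback reply (see feedback.py above)
format_ok = resp_dict.get("Submission Format Check") == "yes"
first_valid = resp_dict.get("First Valid Submission") == "yes"
if format_ok and first_valid:
    # per the updated Step 1, the first valid submission always becomes the best result
    assert resp_dict.get("Replace Best Result") == "yes"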

rdagent/scenarios/data_science/loop.py

Lines changed: 1 addition & 4 deletions
@@ -90,14 +90,11 @@ def _get_exp_gen(class_uri: str, scen: Scenario):
     from rdagent.scenarios.data_science.proposal.exp_gen.proposal import (
         DSProposalV1ExpGen,
         DSProposalV2ExpGen,
-        DSProposalV3ExpGen,
     )

     if class_uri == "rdagent.scenarios.data_science.proposal.exp_gen.DSExpGen":
-        if DS_RD_SETTING.proposal_version not in ["v1", "v2", "v3"]:
+        if DS_RD_SETTING.proposal_version not in ["v1", "v2"]:
             return import_class(DS_RD_SETTING.proposal_version)(scen=scen)
-        if DS_RD_SETTING.proposal_version == "v3":
-            return DSProposalV3ExpGen(scen=scen)
         if DS_RD_SETTING.proposal_version == "v1":
             return DSProposalV1ExpGen(scen=scen)
         if DS_RD_SETTING.proposal_version == "v2":
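Any proposal_version other than "v1" or "v2" is now treated as a dotted class URI and loaded dynamically. A self-contained sketch of what an import_class helper like this typically does (assumed behavior, not this repo's exact implementation):

import importlib

def import_class(class_uri: str):
    # resolve "pkg.module.ClassName" by importing the module and fetching the attribute
    module_path, class_name = class_uri.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

# usage mirroring the dispatch above (the URI is purely illustrative):
# import_class("my_pkg.my_exp_gen.MyExpGen")(scen=scen)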
