Skip to content

Commit d38eae9

Browse files
peteryang1peteryangmsyou-n-g
authored
feat: add type checker to api backend & align litellm and old backend (#647)
* move cache auto continue and retry to all api backend * add type checker to json mode output * fix CI * feat: Add json_mode handling and streaming support in chat completion function * lint * fix a bug when returning a dict which value could contain int or bool * remove litellm --------- Co-authored-by: Xu Yang <[email protected]> Co-authored-by: Young <[email protected]>
1 parent 14e664b commit d38eae9

File tree

25 files changed

+635
-607
lines changed

25 files changed

+635
-607
lines changed

rdagent/app/qlib_rd_loop/factor_from_report.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from pathlib import Path
3-
from typing import Any, Tuple
3+
from typing import Any, Dict, Tuple
44

55
import fire
66
from jinja2 import Environment, StrictUndefined
@@ -49,6 +49,7 @@ def generate_hypothesis(factor_result: dict, report_content: str) -> str:
4949
user_prompt=user_prompt,
5050
system_prompt=system_prompt,
5151
json_mode=True,
52+
json_target_type=Dict[str, str],
5253
)
5354

5455
response_json = json.loads(response)

rdagent/components/coder/CoSTEER/knowledge_management.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
from itertools import combinations
88
from pathlib import Path
9-
from typing import Union
9+
from typing import List, Union
1010

1111
from jinja2 import Environment, StrictUndefined
1212

@@ -339,6 +339,7 @@ def analyze_component(
339339
system_prompt=analyze_component_system_prompt,
340340
user_prompt=analyze_component_user_prompt,
341341
json_mode=True,
342+
json_target_type=List[int],
342343
),
343344
)["component_no_list"]
344345
return [all_component_nodes[index - 1] for index in sorted(list(set(component_no_list)))]

rdagent/components/coder/data_science/ensemble/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"""
1313

1414
import json
15+
from typing import Dict
1516

1617
from rdagent.components.coder.CoSTEER import CoSTEER
1718
from rdagent.components.coder.CoSTEER.evaluators import (
@@ -85,7 +86,10 @@ def implement_one_task(
8586
for _ in range(5):
8687
ensemble_code = json.loads(
8788
APIBackend().build_messages_and_create_chat_completion(
88-
user_prompt=user_prompt, system_prompt=system_prompt, json_mode=True
89+
user_prompt=user_prompt,
90+
system_prompt=system_prompt,
91+
json_mode=True,
92+
json_target_type=Dict[str, str],
8993
)
9094
)["code"]
9195
if ensemble_code != workspace.file_dict.get("ensemble.py"):

rdagent/components/coder/data_science/feature/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from typing import Dict
23

34
from rdagent.components.coder.CoSTEER import CoSTEER
45
from rdagent.components.coder.CoSTEER.evaluators import (
@@ -70,7 +71,10 @@ def implement_one_task(
7071
for _ in range(5):
7172
feature_code = json.loads(
7273
APIBackend().build_messages_and_create_chat_completion(
73-
user_prompt=user_prompt, system_prompt=system_prompt, json_mode=True
74+
user_prompt=user_prompt,
75+
system_prompt=system_prompt,
76+
json_mode=True,
77+
json_target_type=Dict[str, str],
7478
)
7579
)["code"]
7680
if feature_code != workspace.file_dict.get("feature.py"):

rdagent/components/coder/data_science/model/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Dict
2+
13
from rdagent.components.coder.CoSTEER import CoSTEER
24
from rdagent.components.coder.CoSTEER.evaluators import (
35
CoSTEERMultiEvaluator,
@@ -83,6 +85,7 @@ def implement_one_task(
8385
user_prompt=user_prompt,
8486
system_prompt=system_prompt,
8587
json_mode=BatchEditOut.json_mode,
88+
json_target_type=Dict[str, str],
8689
)
8790
)
8891

rdagent/components/coder/data_science/raw_data_loader/__init__.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import json
2626
import re
27+
from typing import Dict
2728

2829
from rdagent.app.data_science.conf import DS_RD_SETTING
2930
from rdagent.components.coder.CoSTEER import CoSTEER
@@ -108,20 +109,30 @@ def implement_one_task(
108109
spec_session = APIBackend().build_chat_session(session_system_prompt=system_prompt)
109110

110111
data_loader_spec = json.loads(
111-
spec_session.build_chat_completion(user_prompt=data_loader_prompt, json_mode=True)
112+
spec_session.build_chat_completion(
113+
user_prompt=data_loader_prompt, json_mode=True, json_target_type=Dict[str, str]
114+
)
115+
)["spec"]
116+
feature_spec = json.loads(
117+
spec_session.build_chat_completion(
118+
user_prompt=feature_prompt, json_mode=True, json_target_type=Dict[str, str]
119+
)
120+
)["spec"]
121+
model_spec = json.loads(
122+
spec_session.build_chat_completion(
123+
user_prompt=model_prompt, json_mode=True, json_target_type=Dict[str, str]
124+
)
125+
)["spec"]
126+
ensemble_spec = json.loads(
127+
spec_session.build_chat_completion(
128+
user_prompt=ensemble_prompt, json_mode=True, json_target_type=Dict[str, str]
129+
)
130+
)["spec"]
131+
workflow_spec = json.loads(
132+
spec_session.build_chat_completion(
133+
user_prompt=workflow_prompt, json_mode=True, json_target_type=Dict[str, str]
134+
)
112135
)["spec"]
113-
feature_spec = json.loads(spec_session.build_chat_completion(user_prompt=feature_prompt, json_mode=True))[
114-
"spec"
115-
]
116-
model_spec = json.loads(spec_session.build_chat_completion(user_prompt=model_prompt, json_mode=True))[
117-
"spec"
118-
]
119-
ensemble_spec = json.loads(spec_session.build_chat_completion(user_prompt=ensemble_prompt, json_mode=True))[
120-
"spec"
121-
]
122-
workflow_spec = json.loads(spec_session.build_chat_completion(user_prompt=workflow_prompt, json_mode=True))[
123-
"spec"
124-
]
125136
else:
126137
data_loader_spec = workspace.file_dict["spec/data_loader.md"]
127138
feature_spec = workspace.file_dict["spec/feature.md"]
@@ -146,7 +157,10 @@ def implement_one_task(
146157
for _ in range(5):
147158
data_loader_code = json.loads(
148159
APIBackend().build_messages_and_create_chat_completion(
149-
user_prompt=user_prompt, system_prompt=system_prompt, json_mode=True
160+
user_prompt=user_prompt,
161+
system_prompt=system_prompt,
162+
json_mode=True,
163+
json_target_type=Dict[str, str],
150164
)
151165
)["code"]
152166
if data_loader_code != workspace.file_dict.get("load_data.py"):

rdagent/components/coder/data_science/workflow/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from typing import Dict
23

34
from rdagent.components.coder.CoSTEER import CoSTEER
45
from rdagent.components.coder.CoSTEER.evaluators import (
@@ -73,7 +74,10 @@ def implement_one_task(
7374
for _ in range(5):
7475
workflow_code = json.loads(
7576
APIBackend().build_messages_and_create_chat_completion(
76-
user_prompt=user_prompt, system_prompt=system_prompt, json_mode=True
77+
user_prompt=user_prompt,
78+
system_prompt=system_prompt,
79+
json_mode=True,
80+
json_target_type=Dict[str, str],
7781
)
7882
)["code"]
7983
if workflow_code != workspace.file_dict.get("main.py"):

rdagent/components/coder/factor_coder/eva_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
from abc import abstractmethod
44
from pathlib import Path
5-
from typing import Tuple
5+
from typing import Dict, Tuple
66

77
import pandas as pd
88
from jinja2 import Environment, StrictUndefined
@@ -212,7 +212,10 @@ def evaluate(
212212
try:
213213
api = APIBackend() if attempts == 0 else APIBackend(use_chat_cache=False)
214214
resp = api.build_messages_and_create_chat_completion(
215-
user_prompt=gen_df_info_str, system_prompt=system_prompt, json_mode=True
215+
user_prompt=gen_df_info_str,
216+
system_prompt=system_prompt,
217+
json_mode=True,
218+
json_target_type=Dict[str, str | bool | int],
216219
)
217220
resp_dict = json.loads(resp)
218221
resp_dict["output_format_decision"] = str(resp_dict["output_format_decision"]).lower() in ["true", "1"]
@@ -556,6 +559,7 @@ def evaluate(
556559
system_prompt=system_prompt,
557560
json_mode=True,
558561
seed=attempts, # in case of useless retrying when cache enabled.
562+
json_target_type=Dict[str, str | bool | int],
559563
),
560564
)
561565
final_decision = final_evaluation_dict["final_decision"]

rdagent/components/coder/factor_coder/evolving_strategy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
from pathlib import Path
5+
from typing import Dict
56

67
from jinja2 import Environment, StrictUndefined
78

@@ -168,7 +169,10 @@ def implement_one_task(
168169
APIBackend(
169170
use_chat_cache=FACTOR_COSTEER_SETTINGS.coder_use_cache
170171
).build_messages_and_create_chat_completion(
171-
user_prompt=user_prompt, system_prompt=system_prompt, json_mode=True
172+
user_prompt=user_prompt,
173+
system_prompt=system_prompt,
174+
json_mode=True,
175+
json_target_type=Dict[str, str],
172176
)
173177
)["code"]
174178
return code

rdagent/components/coder/model_coder/eva_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from pathlib import Path
3-
from typing import Tuple
3+
from typing import Dict, Tuple
44

55
import numpy as np
66
from jinja2 import Environment, StrictUndefined
@@ -177,6 +177,7 @@ def evaluate(
177177
user_prompt=user_prompt,
178178
system_prompt=system_prompt,
179179
json_mode=True,
180+
json_target_type=Dict[str, str | bool | int],
180181
),
181182
)
182183
if isinstance(final_evaluation_dict["final_decision"], str) and final_evaluation_dict[

0 commit comments

Comments
 (0)