Skip to content

Commit 7b010f8

Browse files
authored
fix: fix some bugs in llm calling (#217)
* fix some bugs in LLM calling
* fix a CI error
* fix a small bug in choosing the k-means group number
* fix a CI bug
1 parent 8256067 commit 7b010f8

File tree

2 files changed

+55
-33
lines changed

2 files changed

+55
-33
lines changed

rdagent/core/conf.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,7 @@ class RDAgentSettings(BaseSettings):
9898
# factor extraction conf
9999
max_input_duplicate_factor_group: int = 300
100100
max_output_duplicate_factor_group: int = 20
101+
max_kmeans_group_number: int = 40
101102

102103
# workspace conf
103104
workspace_path: Path = Path.cwd() / "git_ignore_folder" / "RD-Agent_workspace"

rdagent/scenarios/qlib/factor_experiment_loader/pdf_loader.py

Lines changed: 54 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -363,29 +363,49 @@ def check_factor_viability(
363363
def __check_factor_duplication_simulate_json_mode(
364364
factor_df: pd.DataFrame,
365365
) -> list[list[str]]:
366-
session = APIBackend().build_chat_session(
367-
session_system_prompt=document_process_prompts["factor_duplicate_system"],
368-
)
369366
current_user_prompt = factor_df.to_string()
370367

368+
working_list = [factor_df]
369+
final_list = []
370+
371+
while len(working_list) > 0:
372+
current_df = working_list.pop(0)
373+
if (
374+
APIBackend().build_messages_and_calculate_token(
375+
user_prompt=current_df.to_string(), system_prompt=document_process_prompts["factor_duplicate_system"]
376+
)
377+
> RD_AGENT_SETTINGS.chat_token_limit
378+
):
379+
working_list.append(current_df.iloc[: current_df.shape[0] // 2, :])
380+
working_list.append(current_df.iloc[current_df.shape[0] // 2 :, :])
381+
else:
382+
final_list.append(current_df)
383+
371384
generated_duplicated_groups = []
372-
for _ in range(10):
373-
extract_result_resp = session.build_chat_completion(
374-
user_prompt=current_user_prompt,
375-
json_mode=True,
385+
for current_df in final_list:
386+
current_factor_to_string = current_df.to_string()
387+
session = APIBackend().build_chat_session(
388+
session_system_prompt=document_process_prompts["factor_duplicate_system"],
376389
)
377-
ret_dict = json.loads(extract_result_resp)
378-
if len(ret_dict) == 0:
379-
return generated_duplicated_groups
380-
else:
381-
generated_duplicated_groups.extend(ret_dict)
382-
current_user_prompt = """Continue to extract duplicated groups. If no more duplicated group found please respond empty dict."""
390+
for _ in range(10):
391+
extract_result_resp = session.build_chat_completion(
392+
user_prompt=current_factor_to_string,
393+
json_mode=True,
394+
)
395+
ret_dict = json.loads(extract_result_resp)
396+
if len(ret_dict) == 0:
397+
return generated_duplicated_groups
398+
else:
399+
generated_duplicated_groups.extend(ret_dict)
400+
current_factor_to_string = """Continue to extract duplicated groups. If no more duplicated group found please respond empty dict."""
383401
return generated_duplicated_groups
384402

385403

386404
def __kmeans_embeddings(embeddings: np.ndarray, k: int = 20) -> list[list[str]]:
387405
x_normalized = normalize(embeddings)
388406

407+
np.random.seed(42)
408+
389409
kmeans = KMeans(
390410
n_clusters=k,
391411
init="random",
@@ -468,7 +488,7 @@ def __deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[li
468488
else:
469489
for k in range(
470490
len(full_str_list) // RD_AGENT_SETTINGS.max_input_duplicate_factor_group,
471-
40,
491+
RD_AGENT_SETTINGS.max_kmeans_group_number,
472492
):
473493
kmeans_index_group = __kmeans_embeddings(embeddings=embeddings, k=k)
474494
if len(kmeans_index_group[0]) < RD_AGENT_SETTINGS.max_input_duplicate_factor_group:
@@ -479,26 +499,22 @@ def __deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[li
479499

480500
duplication_names_list = []
481501

482-
pool = mp.Pool(target_k)
483-
result_list = [
484-
pool.apply_async(
485-
__check_factor_duplication_simulate_json_mode,
486-
(factor_df.loc[factor_name_group, :],),
487-
)
488-
for factor_name_group in factor_name_groups
489-
]
502+
result_list = multiprocessing_wrapper(
503+
[
504+
(__check_factor_duplication_simulate_json_mode, (factor_df.loc[factor_name_group, :],))
505+
for factor_name_group in factor_name_groups
506+
],
507+
n=RD_AGENT_SETTINGS.multi_proc_n,
508+
)
490509

491-
pool.close()
492-
pool.join()
510+
duplication_names_list = []
493511

494-
for result in result_list:
495-
deduplication_factor_names_list = result.get()
496-
for deduplication_factor_names in deduplication_factor_names_list:
497-
filter_factor_names = [
498-
factor_name for factor_name in set(deduplication_factor_names) if factor_name in factor_dict
499-
]
500-
if len(filter_factor_names) > 1:
501-
duplication_names_list.append(filter_factor_names)
512+
for deduplication_factor_names_list in result_list:
513+
filter_factor_names = [
514+
factor_name for factor_name in set(deduplication_factor_names_list) if factor_name in factor_dict
515+
]
516+
if len(filter_factor_names) > 1:
517+
duplication_names_list.append(filter_factor_names)
502518

503519
return duplication_names_list
504520

@@ -509,6 +525,8 @@ def deduplicate_factors_by_llm( # noqa: C901, PLR0912
509525
) -> list[list[str]]:
510526
final_duplication_names_list = []
511527
current_round_factor_dict = factor_dict
528+
529+
# handle multi-round deduplication
512530
for _ in range(10):
513531
duplication_names_list = __deduplicate_factor_dict(current_round_factor_dict)
514532

@@ -524,11 +542,13 @@ def deduplicate_factors_by_llm( # noqa: C901, PLR0912
524542
else:
525543
break
526544

545+
# sort the final list of duplicates by their length, largest first
527546
final_duplication_names_list = sorted(final_duplication_names_list, key=lambda x: len(x), reverse=True)
528547

529-
to_replace_dict = {}
548+
to_replace_dict = {} # to map duplicates to the target factor names
530549
for duplication_names in duplication_names_list:
531550
if factor_viability_dict is not None:
551+
# check viability of each factor in the duplicates group
532552
viability_list = [factor_viability_dict[name]["viability"] for name in duplication_names]
533553
if True not in viability_list:
534554
continue
@@ -543,6 +563,7 @@ def deduplicate_factors_by_llm( # noqa: C901, PLR0912
543563
llm_deduplicated_factor_dict = {}
544564
added_lower_name_set = set()
545565
for factor_name in factor_dict:
566+
# only add factors that haven't been replaced and are not duplicates
546567
if factor_name not in to_replace_dict and factor_name.lower() not in added_lower_name_set:
547568
if factor_viability_dict is not None and not factor_viability_dict[factor_name]["viability"]:
548569
continue

0 commit comments

Comments
 (0)