@@ -363,29 +363,49 @@ def check_factor_viability(
def __check_factor_duplication_simulate_json_mode(
    factor_df: pd.DataFrame,
) -> list[list[str]]:
    """Ask the LLM for groups of duplicated factors among the rows of ``factor_df``.

    The frame is first halved by rows, repeatedly, until the prompt built from
    each chunk is within ``RD_AGENT_SETTINGS.chat_token_limit``.  Every chunk is
    then sent to its own chat session, which is queried up to 10 times; the
    parsed JSON answers of all chunks are accumulated into one list.

    Args:
        factor_df: Frame describing the factors to compare; it is rendered with
            ``DataFrame.to_string()`` and used as the user prompt.

    Returns:
        The accumulated duplicated groups reported by the LLM over all chunks.
    """
    system_prompt = document_process_prompts["factor_duplicate_system"]

    # Breadth-first halving of the frame until each chunk fits the token limit.
    working_list = [factor_df]
    final_list = []
    while working_list:
        current_df = working_list.pop(0)
        row_count = current_df.shape[0]
        token_count = APIBackend().build_messages_and_calculate_token(
            user_prompt=current_df.to_string(),
            system_prompt=system_prompt,
        )
        # A chunk with fewer than two rows cannot be halved: ``row_count // 2``
        # would be 0, producing an empty frame plus the very same frame again
        # and therefore an endless loop.  Such a chunk is kept as-is.
        if token_count > RD_AGENT_SETTINGS.chat_token_limit and row_count > 1:
            middle = row_count // 2
            working_list.append(current_df.iloc[:middle, :])
            working_list.append(current_df.iloc[middle:, :])
        else:
            final_list.append(current_df)

    generated_duplicated_groups = []
    for current_df in final_list:
        current_factor_to_string = current_df.to_string()
        # Each chunk gets a fresh session so its history only contains that chunk.
        session = APIBackend().build_chat_session(
            session_system_prompt=system_prompt,
        )
        for _ in range(10):
            extract_result_resp = session.build_chat_completion(
                user_prompt=current_factor_to_string,
                json_mode=True,
            )
            ret_dict = json.loads(extract_result_resp)
            if not ret_dict:
                # This chunk is exhausted; move on to the next chunk instead of
                # returning, otherwise the remaining chunks are never checked.
                break
            generated_duplicated_groups.extend(ret_dict)
            current_factor_to_string = "Continue to extract duplicated groups. If no more duplicated group found please respond empty dict."
    return generated_duplicated_groups
384402
385403
386404def __kmeans_embeddings (embeddings : np .ndarray , k : int = 20 ) -> list [list [str ]]:
387405 x_normalized = normalize (embeddings )
388406
407+ np .random .seed (42 )
408+
389409 kmeans = KMeans (
390410 n_clusters = k ,
391411 init = "random" ,
@@ -468,7 +488,7 @@ def __deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[li
468488 else :
469489 for k in range (
470490 len (full_str_list ) // RD_AGENT_SETTINGS .max_input_duplicate_factor_group ,
471- 40 ,
491+ RD_AGENT_SETTINGS . max_kmeans_group_number ,
472492 ):
473493 kmeans_index_group = __kmeans_embeddings (embeddings = embeddings , k = k )
474494 if len (kmeans_index_group [0 ]) < RD_AGENT_SETTINGS .max_input_duplicate_factor_group :
@@ -479,26 +499,22 @@ def __deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[li
479499
480500 duplication_names_list = []
481501
482- pool = mp .Pool (target_k )
483- result_list = [
484- pool .apply_async (
485- __check_factor_duplication_simulate_json_mode ,
486- (factor_df .loc [factor_name_group , :],),
487- )
488- for factor_name_group in factor_name_groups
489- ]
502+ result_list = multiprocessing_wrapper (
503+ [
504+ (__check_factor_duplication_simulate_json_mode , (factor_df .loc [factor_name_group , :],))
505+ for factor_name_group in factor_name_groups
506+ ],
507+ n = RD_AGENT_SETTINGS .multi_proc_n ,
508+ )
490509
491- pool .close ()
492- pool .join ()
510+ duplication_names_list = []
493511
494- for result in result_list :
495- deduplication_factor_names_list = result .get ()
496- for deduplication_factor_names in deduplication_factor_names_list :
497- filter_factor_names = [
498- factor_name for factor_name in set (deduplication_factor_names ) if factor_name in factor_dict
499- ]
500- if len (filter_factor_names ) > 1 :
501- duplication_names_list .append (filter_factor_names )
512+ for deduplication_factor_names_list in result_list :
513+ filter_factor_names = [
514+ factor_name for factor_name in set (deduplication_factor_names_list ) if factor_name in factor_dict
515+ ]
516+ if len (filter_factor_names ) > 1 :
517+ duplication_names_list .append (filter_factor_names )
502518
503519 return duplication_names_list
504520
@@ -509,6 +525,8 @@ def deduplicate_factors_by_llm( # noqa: C901, PLR0912
509525) -> list [list [str ]]:
510526 final_duplication_names_list = []
511527 current_round_factor_dict = factor_dict
528+
529+ # handle multi-round deduplication
512530 for _ in range (10 ):
513531 duplication_names_list = __deduplicate_factor_dict (current_round_factor_dict )
514532
@@ -524,11 +542,13 @@ def deduplicate_factors_by_llm( # noqa: C901, PLR0912
524542 else :
525543 break
526544
545+ # sort the final list of duplicates by their length, largest first
527546 final_duplication_names_list = sorted (final_duplication_names_list , key = lambda x : len (x ), reverse = True )
528547
529- to_replace_dict = {}
548+ to_replace_dict = {} # to map duplicates to the target factor names
530549 for duplication_names in duplication_names_list :
531550 if factor_viability_dict is not None :
551+ # check viability of each factor in the duplicates group
532552 viability_list = [factor_viability_dict [name ]["viability" ] for name in duplication_names ]
533553 if True not in viability_list :
534554 continue
@@ -543,6 +563,7 @@ def deduplicate_factors_by_llm( # noqa: C901, PLR0912
543563 llm_deduplicated_factor_dict = {}
544564 added_lower_name_set = set ()
545565 for factor_name in factor_dict :
566+ # only add factors that haven't been replaced and are not duplicates
546567 if factor_name not in to_replace_dict and factor_name .lower () not in added_lower_name_set :
547568 if factor_viability_dict is not None and not factor_viability_dict [factor_name ]["viability" ]:
548569 continue
0 commit comments