77import plotly .express as px
88import plotly .graph_objects as go
99import streamlit as st
10- from plotly .subplots import make_subplots
1110from streamlit import session_state as state
1211
13- from rdagent .log .mle_summary import extract_mle_json
12+ from rdagent .log .mle_summary import extract_mle_json , is_valid_session
1413from rdagent .log .storage import FileStorage
1514
1615st .set_page_config (layout = "wide" , page_title = "RD-Agent" , page_icon = "🎓" , initial_sidebar_state = "expanded" )
1716
1817# 设置主日志路径
1918if "log_folder" not in state :
2019 state .log_folder = Path ("./log" )
20+ if "log_folders" not in state :
21+ state .log_folders = ["./log" ]
2122if "log_path" not in state :
2223 state .log_path = None
2324if "show_all_summary" not in state :
@@ -39,7 +40,7 @@ def extract_evoid(tag):
3940# @st.cache_data
4041def load_data (log_path ):
4142 state .data = defaultdict (lambda : defaultdict (dict ))
42- for msg in FileStorage (state . log_folder / log_path ).iter_msg ():
43+ for msg in FileStorage (log_path ).iter_msg ():
4344 if msg .tag and "llm" not in msg .tag and "session" not in msg .tag :
4445 if msg .tag == "competition" :
4546 state .data ["competition" ] = msg .content
@@ -67,7 +68,7 @@ def get_folders_sorted(log_path):
6768 """缓存并返回排序后的文件夹列表,并加入进度打印"""
6869 with st .spinner ("正在加载文件夹列表..." ):
6970 folders = sorted (
70- (folder for folder in log_path .iterdir () if folder . is_dir () and list ( folder . iterdir () )),
71+ (folder for folder in log_path .iterdir () if is_valid_session ( folder )),
7172 key = lambda folder : folder .stat ().st_mtime ,
7273 reverse = True ,
7374 )
@@ -77,9 +78,15 @@ def get_folders_sorted(log_path):
7778
7879# UI - Sidebar
7980with st .sidebar :
80- state .log_folder = Path (st .text_input ("**Log Folder**" , placeholder = state .log_folder , value = state .log_folder ))
81+ log_folder_str = st .text_area (
82+ "**Log Folders**(split by ';')" , placeholder = state .log_folder , value = ";" .join (state .log_folders )
83+ )
84+ state .log_folders = [folder .strip () for folder in log_folder_str .split (";" ) if folder .strip ()]
85+
86+ state .log_folder = Path (st .radio (f"Select :blue[**one log folder**]" , state .log_folders ))
8187 if not state .log_folder .exists ():
8288 st .warning (f"Path { state .log_folder } does not exist!" )
89+
8390 folders = get_folders_sorted (state .log_folder )
8491 st .selectbox (f"Select from :blue[**{ state .log_folder .absolute ()} **]" , folders , key = "log_path" )
8592
@@ -88,7 +95,7 @@ def get_folders_sorted(log_path):
8895 st .toast ("Please select a log path first!" , type = "error" )
8996 st .stop ()
9097
91- load_data (state .log_path )
98+ load_data (state .log_folder / state . log_path )
9299
93100 st .toggle ("One Trace / Log Folder Summary" , key = "show_all_summary" )
94101
@@ -111,7 +118,7 @@ def task_win(data):
111118def workspace_win (data ):
112119 show_files = {k : v for k , v in data .file_dict .items () if not "test" in k }
113120 if len (show_files ) > 0 :
114- with st .expander (f"Files in :blue[{ data .workspace_path } ]" ):
121+ with st .expander (f"Files in :blue[{ replace_ep_path ( data .workspace_path ) } ]" ):
115122 code_tabs = st .tabs (show_files .keys ())
116123 for ct , codename in zip (code_tabs , show_files .keys ()):
117124 with ct :
@@ -127,7 +134,7 @@ def workspace_win(data):
127134def exp_gen_win (data ):
128135 st .header ("Exp Gen" , divider = "blue" )
129136 st .subheader ("Hypothesis" )
130- st .markdown ( data .hypothesis )
137+ st .code ( str ( data .hypothesis ). replace ( " \n " , " \n \n " ), wrap_lines = True )
131138
132139 st .subheader ("pending_tasks" )
133140 for tasks in data .pending_tasks_list :
@@ -225,6 +232,18 @@ def main_win(data):
225232 )
226233
227234
235+ def replace_ep_path (p : Path ):
236+ # Replace the workspace path with the path at which the corresponding ep machine is mounted on ep03.
237+ # TODO: FIXME: handle this through a configuration option instead.
238+ match = re .search (r"ep\d+" , str (state .log_folder ))  # look for an "ep<digits>" token in the currently selected log folder
239+ if match :
240+ ep = match .group (0 )
241+ return Path (
242+ str (p ).replace ("repos/RD-Agent-Exp" , f"repos/batch_ctrl/all_projects/{ ep } " ).replace ("/Data" , "/data" )
243+ )
244+ return p  # no "ep<digits>" token in the log folder: hand back p unchanged
245+
246+
228247def summarize_data ():
229248 st .header ("Summary" , divider = "rainbow" )
230249 df = pd .DataFrame (columns = ["Component" , "Running Score" , "Feedback" ], index = range (len (state .data ) - 1 ))
@@ -235,7 +254,9 @@ def summarize_data():
235254
236255 if "running" in loop_data :
237256 if "mle_score" not in state .data [loop ]:
238- mle_score_path = loop_data ["running" ].experiment_workspace .workspace_path / "mle_score.txt"
257+ mle_score_path = (
258+ replace_ep_path (loop_data ["running" ].experiment_workspace .workspace_path ) / "mle_score.txt"
259+ )
239260 try :
240261 mle_score_txt = mle_score_path .read_text ()
241262 state .data [loop ]["mle_score" ] = extract_mle_json (mle_score_txt )
@@ -264,12 +285,23 @@ def summarize_data():
264285
265286
266287def all_summarize_win ():
267- if not (state .log_folder / "summary.pkl" ).exists ():
268- st .warning (
269- f"No summary file found in { state .log_folder } \n Run:`dotenv run -- python rdagent/log/mle_summary.py grade_summary --log_folder=<your trace folder>`"
270- )
288+ summarys = {}
289+ for lf in state .log_folders :
290+ if not (Path (lf ) / "summary.pkl" ).exists ():
291+ st .warning (
292+ f"No summary file found in { lf } \n Run:`dotenv run -- python rdagent/log/mle_summary.py grade_summary --log_folder=<your trace folder>`"
293+ )
294+ else :
295+ summarys [lf ] = pd .read_pickle (Path (lf ) / "summary.pkl" )
296+
297+ if len (summarys ) == 0 :
271298 return
272- summary = pd .read_pickle (state .log_folder / "summary.pkl" )
299+
300+ summary = {}
301+ for lf , s in summarys .items ():
302+ for k , v in s .items ():
303+ summary [f"{ lf [lf .rfind ('ep' ):]} { k } " ] = v
304+
273305 summary = {k : v for k , v in summary .items () if "competition" in v }
274306 base_df = pd .DataFrame (
275307 columns = [
@@ -311,6 +343,27 @@ def all_summarize_win():
311343 base_df .loc [k , "Any Medal" ] = f"{ v ['get_medal_num' ]} ({ round (v ['get_medal_num' ] / loop_num * 100 , 2 )} %)"
312344
313345 st .dataframe (base_df )
346+ total_stat = (
347+ (
348+ base_df [
349+ [
350+ "Made Submission" ,
351+ "Valid Submission" ,
352+ "Above Median" ,
353+ "Bronze" ,
354+ "Silver" ,
355+ "Gold" ,
356+ "Any Medal" ,
357+ ]
358+ ]
359+ != "0 (0.0%)"
360+ ).sum ()
361+ / base_df .shape [0 ]
362+ * 100
363+ )
364+ total_stat .name = "总体统计(%)"
365+ st .dataframe (total_stat .round (2 ))
366+
314367 # write curve
315368 for k , v in summary .items ():
316369 with st .container (border = True ):
0 commit comments