Skip to content

Commit 8a69c9c

Browse files
authored
feat: multi log folder, replace "epxx" in workspace path (#555)
* multi log folder, replace "epxx" in workspace path * valid trace path * hypothesis show change * fix CI * add total stat in all summary page
1 parent fa86e4d commit 8a69c9c

File tree

1 file changed

+67
-14
lines changed

1 file changed

+67
-14
lines changed

rdagent/log/ui/dsapp.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@
77
import plotly.express as px
88
import plotly.graph_objects as go
99
import streamlit as st
10-
from plotly.subplots import make_subplots
1110
from streamlit import session_state as state
1211

13-
from rdagent.log.mle_summary import extract_mle_json
12+
from rdagent.log.mle_summary import extract_mle_json, is_valid_session
1413
from rdagent.log.storage import FileStorage
1514

1615
st.set_page_config(layout="wide", page_title="RD-Agent", page_icon="🎓", initial_sidebar_state="expanded")
1716

1817
# 设置主日志路径
1918
if "log_folder" not in state:
2019
state.log_folder = Path("./log")
20+
if "log_folders" not in state:
21+
state.log_folders = ["./log"]
2122
if "log_path" not in state:
2223
state.log_path = None
2324
if "show_all_summary" not in state:
@@ -39,7 +40,7 @@ def extract_evoid(tag):
3940
# @st.cache_data
4041
def load_data(log_path):
4142
state.data = defaultdict(lambda: defaultdict(dict))
42-
for msg in FileStorage(state.log_folder / log_path).iter_msg():
43+
for msg in FileStorage(log_path).iter_msg():
4344
if msg.tag and "llm" not in msg.tag and "session" not in msg.tag:
4445
if msg.tag == "competition":
4546
state.data["competition"] = msg.content
@@ -67,7 +68,7 @@ def get_folders_sorted(log_path):
6768
"""缓存并返回排序后的文件夹列表,并加入进度打印"""
6869
with st.spinner("正在加载文件夹列表..."):
6970
folders = sorted(
70-
(folder for folder in log_path.iterdir() if folder.is_dir() and list(folder.iterdir())),
71+
(folder for folder in log_path.iterdir() if is_valid_session(folder)),
7172
key=lambda folder: folder.stat().st_mtime,
7273
reverse=True,
7374
)
@@ -77,9 +78,15 @@ def get_folders_sorted(log_path):
7778

7879
# UI - Sidebar
7980
with st.sidebar:
80-
state.log_folder = Path(st.text_input("**Log Folder**", placeholder=state.log_folder, value=state.log_folder))
81+
log_folder_str = st.text_area(
82+
"**Log Folders**(split by ';')", placeholder=state.log_folder, value=";".join(state.log_folders)
83+
)
84+
state.log_folders = [folder.strip() for folder in log_folder_str.split(";") if folder.strip()]
85+
86+
state.log_folder = Path(st.radio(f"Select :blue[**one log folder**]", state.log_folders))
8187
if not state.log_folder.exists():
8288
st.warning(f"Path {state.log_folder} does not exist!")
89+
8390
folders = get_folders_sorted(state.log_folder)
8491
st.selectbox(f"Select from :blue[**{state.log_folder.absolute()}**]", folders, key="log_path")
8592

@@ -88,7 +95,7 @@ def get_folders_sorted(log_path):
8895
st.toast("Please select a log path first!", type="error")
8996
st.stop()
9097

91-
load_data(state.log_path)
98+
load_data(state.log_folder / state.log_path)
9299

93100
st.toggle("One Trace / Log Folder Summary", key="show_all_summary")
94101

@@ -111,7 +118,7 @@ def task_win(data):
111118
def workspace_win(data):
112119
show_files = {k: v for k, v in data.file_dict.items() if not "test" in k}
113120
if len(show_files) > 0:
114-
with st.expander(f"Files in :blue[{data.workspace_path}]"):
121+
with st.expander(f"Files in :blue[{replace_ep_path(data.workspace_path)}]"):
115122
code_tabs = st.tabs(show_files.keys())
116123
for ct, codename in zip(code_tabs, show_files.keys()):
117124
with ct:
@@ -127,7 +134,7 @@ def workspace_win(data):
127134
def exp_gen_win(data):
128135
st.header("Exp Gen", divider="blue")
129136
st.subheader("Hypothesis")
130-
st.markdown(data.hypothesis)
137+
st.code(str(data.hypothesis).replace("\n", "\n\n"), wrap_lines=True)
131138

132139
st.subheader("pending_tasks")
133140
for tasks in data.pending_tasks_list:
@@ -225,6 +232,18 @@ def main_win(data):
225232
)
226233

227234

235+
def replace_ep_path(p: Path):
236+
# 替换workspace path为对应ep机器mount在ep03的path
237+
# TODO: FIXME: 使用配置项来处理
238+
match = re.search(r"ep\d+", str(state.log_folder))
239+
if match:
240+
ep = match.group(0)
241+
return Path(
242+
str(p).replace("repos/RD-Agent-Exp", f"repos/batch_ctrl/all_projects/{ep}").replace("/Data", "/data")
243+
)
244+
return p
245+
246+
228247
def summarize_data():
229248
st.header("Summary", divider="rainbow")
230249
df = pd.DataFrame(columns=["Component", "Running Score", "Feedback"], index=range(len(state.data) - 1))
@@ -235,7 +254,9 @@ def summarize_data():
235254

236255
if "running" in loop_data:
237256
if "mle_score" not in state.data[loop]:
238-
mle_score_path = loop_data["running"].experiment_workspace.workspace_path / "mle_score.txt"
257+
mle_score_path = (
258+
replace_ep_path(loop_data["running"].experiment_workspace.workspace_path) / "mle_score.txt"
259+
)
239260
try:
240261
mle_score_txt = mle_score_path.read_text()
241262
state.data[loop]["mle_score"] = extract_mle_json(mle_score_txt)
@@ -264,12 +285,23 @@ def summarize_data():
264285

265286

266287
def all_summarize_win():
267-
if not (state.log_folder / "summary.pkl").exists():
268-
st.warning(
269-
f"No summary file found in {state.log_folder}\nRun:`dotenv run -- python rdagent/log/mle_summary.py grade_summary --log_folder=<your trace folder>`"
270-
)
288+
summarys = {}
289+
for lf in state.log_folders:
290+
if not (Path(lf) / "summary.pkl").exists():
291+
st.warning(
292+
f"No summary file found in {lf}\nRun:`dotenv run -- python rdagent/log/mle_summary.py grade_summary --log_folder=<your trace folder>`"
293+
)
294+
else:
295+
summarys[lf] = pd.read_pickle(Path(lf) / "summary.pkl")
296+
297+
if len(summarys) == 0:
271298
return
272-
summary = pd.read_pickle(state.log_folder / "summary.pkl")
299+
300+
summary = {}
301+
for lf, s in summarys.items():
302+
for k, v in s.items():
303+
summary[f"{lf[lf.rfind('ep'):]}{k}"] = v
304+
273305
summary = {k: v for k, v in summary.items() if "competition" in v}
274306
base_df = pd.DataFrame(
275307
columns=[
@@ -311,6 +343,27 @@ def all_summarize_win():
311343
base_df.loc[k, "Any Medal"] = f"{v['get_medal_num']} ({round(v['get_medal_num'] / loop_num * 100, 2)}%)"
312344

313345
st.dataframe(base_df)
346+
total_stat = (
347+
(
348+
base_df[
349+
[
350+
"Made Submission",
351+
"Valid Submission",
352+
"Above Median",
353+
"Bronze",
354+
"Silver",
355+
"Gold",
356+
"Any Medal",
357+
]
358+
]
359+
!= "0 (0.0%)"
360+
).sum()
361+
/ base_df.shape[0]
362+
* 100
363+
)
364+
total_stat.name = "总体统计(%)"
365+
st.dataframe(total_stat.round(2))
366+
314367
# write curve
315368
for k, v in summary.items():
316369
with st.container(border=True):

0 commit comments

Comments
 (0)