Skip to content

Commit a0cd102

Browse files
XianBW and you-n-g authored
feat: add a rag mcp in proposal (#1267)
* add simple rag mcp * add rag_agent in expGen v2 * add conf config for research rag * fix CI * refactor: move context7 and rag config files to new conf modules * make rag agent general * fix CI --------- Co-authored-by: Young <[email protected]>
1 parent c73f67a commit a0cd102

File tree

7 files changed

+61
-1
lines changed

7 files changed

+61
-1
lines changed

rdagent/app/data_science/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
175175
enable_generate_unique_hypothesis: bool = False
176176
"""Enable generate unique hypothesis. If True, generate unique hypothesis for each component. If False, generate unique hypothesis for each component."""
177177

178+
enable_research_rag: bool = False
179+
"""Enable research RAG for hypothesis generation."""
180+
178181
#### hypothesis critique and rewrite
179182
enable_hypo_critique_rewrite: bool = False
180183
"""Enable hypothesis critique and rewrite stages for improving hypothesis quality"""

rdagent/components/agent/context7/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pydantic_ai.mcp import MCPServerStreamableHTTP
44

55
from rdagent.components.agent.base import PAIAgent
6-
from rdagent.components.agent.mcp.context7 import SETTINGS
6+
from rdagent.components.agent.context7.conf import SETTINGS
77
from rdagent.log import rdagent_logger as logger
88
from rdagent.utils.agent.tpl import T
99

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from pydantic_ai.mcp import MCPServerStreamableHTTP
2+
3+
from rdagent.components.agent.base import PAIAgent
4+
from rdagent.components.agent.rag.conf import SETTINGS
5+
from rdagent.utils.agent.tpl import T
6+
7+
8+
class Agent(PAIAgent):
    """
    RAG agent whose only toolset is a streamable-HTTP MCP server.

    The MCP server is created from ``SETTINGS.url`` with ``SETTINGS.timeout``
    as its timeout and handed to the base class as the single toolset.
    """

    def __init__(self, system_prompt: str | None = None):
        """
        Build the agent.

        Parameters
        ----------
        system_prompt : str | None
            Prompt forwarded to the base class. When it is ``None`` a generic
            RAG prompt is used instead.
        """
        mcp_server = MCPServerStreamableHTTP(SETTINGS.url, timeout=SETTINGS.timeout)
        # Only ``None`` triggers the fallback; an explicitly passed empty string is kept.
        effective_prompt = (
            system_prompt
            if system_prompt is not None
            else "You are a Retrieval-Augmented Generation (RAG) agent. Use the retrieved documents to answer the user's queries accurately and concisely."
        )
        super().__init__(system_prompt=effective_prompt, toolsets=[mcp_server])
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
Settings for RAG agent.
3+
4+
TODO: document how to run the RAG MCP server
5+
"""
6+
7+
from pydantic_settings import BaseSettings, SettingsConfigDict
8+
9+
10+
class Settings(BaseSettings):
    """Settings for the RAG agent, loaded through pydantic-settings."""

    # Environment variables carrying this prefix are mapped onto the fields below.
    # NOTE: extra="allow" is not enabled; whether extra settings should be
    # allowed is still an open question.
    model_config = SettingsConfigDict(env_prefix="RAG_")

    # URL of the MCP endpoint.
    url: str = "http://localhost:8124/mcp"
    # Timeout value — presumably in seconds; TODO confirm against the MCP client.
    timeout: int = 120
20+
21+
22+
SETTINGS = Settings()

rdagent/scenarios/data_science/proposal/exp_gen/prompts_v2.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ hypothesis_gen:
294294
# Identified Challenges{% if enable_idea_pool %} with Sampled Ideas{% endif %}
295295
{{ problems }}
296296
297+
{% if knowledge %}
298+
# Some reference knowledge from the community
299+
{{ knowledge }}
300+
{% endif %}
301+
297302
hypothesis_critique:
298303
system: |-
299304
{% include "scenarios.data_science.share:scen.role" %}

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import BaseModel, Field
1010

1111
from rdagent.app.data_science.conf import DS_RD_SETTING
12+
from rdagent.components.agent.rag import Agent as RAGAgent
1213
from rdagent.components.coder.data_science.ensemble.exp import EnsembleTask
1314
from rdagent.components.coder.data_science.feature.exp import FeatureTask
1415
from rdagent.components.coder.data_science.model.exp import ModelTask
@@ -645,12 +646,24 @@ def hypothesis_gen(
645646
sibling_hypotheses=sibling_hypotheses,
646647
former_user_instructions_str=str(former_user_instructions) if former_user_instructions else None,
647648
)
649+
650+
# knowledge retrieval
651+
if DS_RD_SETTING.enable_research_rag:
652+
rag_agent = RAGAgent(
653+
system_prompt="""You are a helpful assistant.
654+
You help users retrieve relevant knowledge from community discussions and public code."""
655+
)
656+
knowledge = rag_agent.query(problem_formatted_str)
657+
else:
658+
knowledge = None
659+
648660
user_prompt = T(".prompts_v2:hypothesis_gen.user").r(
649661
scenario_desc=scenario_desc,
650662
exp_and_feedback_list_desc=exp_feedback_list_desc,
651663
sota_exp_desc=sota_exp_desc,
652664
problems=problem_formatted_str,
653665
enable_idea_pool=enable_idea_pool,
666+
knowledge=knowledge,
654667
)
655668
response = APIBackend().build_messages_and_create_chat_completion(
656669
user_prompt=user_prompt,

0 commit comments

Comments
 (0)