Skip to content

Commit 1c48be1

Browse files
More unit tests
Signed-off-by: Anuradha Karuppiah <[email protected]>
1 parent 39e8c21 commit 1c48be1

File tree

3 files changed

+186
-3
lines changed

3 files changed

+186
-3
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,19 +151,19 @@ async def test_function(mcp_client: MCPBaseClient):
151151
fn_obj = await builder.add_function(name="test_function",
152152
config=MCPToolConfig(url=HttpUrl(mcp_client.url),
153153
mcp_tool_name="return_42",
154-
transport=mcp_client.transport))
154+
transport="sse"))
155155
elif isinstance(mcp_client, MCPStdioClient):
156156
fn_obj = await builder.add_function(name="test_function",
157157
config=MCPToolConfig(mcp_tool_name="return_42",
158-
transport=mcp_client.transport,
158+
transport="stdio",
159159
command=mcp_client.command,
160160
args=mcp_client.args,
161161
env=mcp_client.env))
162162
elif isinstance(mcp_client, MCPStreamableHTTPClient):
163163
fn_obj = await builder.add_function(name="test_function",
164164
config=MCPToolConfig(url=HttpUrl(mcp_client.url),
165165
mcp_tool_name="return_42",
166-
transport=mcp_client.transport))
166+
transport="streamable-http"))
167167
else:
168168
raise ValueError(f"Invalid client type: {type(mcp_client)}")
169169

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""
2+
Functional tests for src/nat/tool/mcp/mcp_client_impl.py
3+
Focus on behavior (tool wrapping, filtering, handler integration), not class/type basics.
4+
"""
5+
6+
import asyncio
7+
from contextlib import asynccontextmanager
8+
from typing import Any, cast
9+
10+
import pytest
11+
from pydantic import BaseModel
12+
13+
from mcp.types import TextContent
14+
from nat.builder.workflow_builder import WorkflowBuilder
15+
from nat.tool.mcp.mcp_client_impl import MCPClientConfig
16+
from nat.tool.mcp.mcp_client_impl import MCPServerConfig
17+
from nat.tool.mcp.mcp_client_impl import MCPSingleToolConfig
18+
from nat.tool.mcp.mcp_client_impl import ToolOverrideConfig
19+
from nat.tool.mcp.mcp_client_impl import _filter_and_configure_tools
20+
from nat.tool.mcp.mcp_client_base import MCPBaseClient
21+
from pydantic.networks import HttpUrl
22+
23+
24+
class _InputSchema(BaseModel):
25+
param: str
26+
27+
28+
class _FakeTool:
29+
def __init__(self, name: str, description: str = "desc") -> None:
30+
self.name = name
31+
self.description = description
32+
self.input_schema = _InputSchema
33+
34+
async def acall(self, args: dict[str, Any]) -> str:
35+
return f"ok {args['param']}"
36+
37+
def set_description(self, description: str | None) -> None:
38+
if description is not None:
39+
self.description = description
40+
41+
42+
class _ErrorTool(_FakeTool):
43+
async def acall(self, args: dict[str, Any]) -> str: # type: ignore[override]
44+
raise RuntimeError("boom")
45+
46+
47+
class _FakeSession:
48+
def __init__(self, tools: dict[str, _FakeTool]) -> None:
49+
self._tools = tools
50+
51+
class _ToolInfo:
52+
def __init__(self, name: str, description: str) -> None:
53+
self.name = name
54+
self.description = description
55+
# Provide a trivial input schema compatible with MCPToolClient expectations
56+
self.inputSchema = {"type": "object", "properties": {"param": {"type": "string"}}, "required": ["param"]}
57+
58+
class _ListToolsResponse:
59+
def __init__(self, tools: list["_FakeSession._ToolInfo"]) -> None:
60+
self.tools = tools
61+
62+
async def list_tools(self) -> "_FakeSession._ListToolsResponse":
63+
infos = [self._ToolInfo(name=t.name, description=t.description) for t in self._tools.values()]
64+
return self._ListToolsResponse(tools=infos)
65+
66+
async def call_tool(self, tool_name: str, tool_args: dict[str, Any]):
67+
tool = self._tools[tool_name]
68+
class _Result:
69+
def __init__(self, text: str):
70+
self.content = [TextContent(type="text", text=text)]
71+
self.isError = False
72+
return _Result(await tool.acall(tool_args))
73+
74+
75+
class _FakeMCPClient(MCPBaseClient):
76+
def __init__(self, *, tools: dict[str, _FakeTool], transport: str = "sse", url: str | None = None,
77+
command: str | None = None) -> None:
78+
super().__init__(transport)
79+
self._tools_map = tools
80+
self._url = url
81+
self._command = command
82+
83+
@property
84+
def url(self) -> str | None:
85+
return self._url
86+
87+
@property
88+
def command(self) -> str | None:
89+
return self._command
90+
91+
@asynccontextmanager
92+
async def connect_to_server(self):
93+
yield _FakeSession(self._tools_map)
94+
95+
async def get_tools(self) -> dict[str, _FakeTool]: # type: ignore[override]
96+
return self._tools_map
97+
98+
async def get_tool(self, tool_name: str) -> _FakeTool: # type: ignore[override]
99+
return self._tools_map[tool_name]
100+
101+
102+
@pytest.mark.anyio
103+
async def test_mcp_single_tool_happy_path_kwargs():
104+
client = _FakeMCPClient(tools={"echo": _FakeTool("echo", "Echo tool")})
105+
106+
async with WorkflowBuilder() as builder:
107+
fn = await builder.add_function(
108+
name="echo_fn",
109+
config=MCPSingleToolConfig(client=client, tool_name="echo", tool_description="Overridden desc"),
110+
)
111+
112+
# Validate invocation path using kwargs
113+
result = await fn.acall_invoke(param="value")
114+
assert result == "ok value"
115+
116+
117+
@pytest.mark.anyio
118+
async def test_mcp_single_tool_returns_error_string_on_exception():
119+
client = _FakeMCPClient(tools={"err": _ErrorTool("err", "Err tool")})
120+
121+
async with WorkflowBuilder() as builder:
122+
fn = await builder.add_function(
123+
name="err_fn",
124+
config=MCPSingleToolConfig(client=client, tool_name="err"),
125+
)
126+
127+
result = await fn.acall_invoke(param="value")
128+
assert "boom" in result
129+
130+
131+
def test_filter_and_configure_tools_none_filter_returns_all():
132+
tools = {"a": _FakeTool("a", "da"), "b": _FakeTool("b", "db")}
133+
out = _filter_and_configure_tools(tools, tool_filter=None)
134+
assert out == {
135+
"a": {"function_name": "a", "description": "da"},
136+
"b": {"function_name": "b", "description": "db"},
137+
}
138+
139+
140+
def test_filter_and_configure_tools_list_filter_subsets():
141+
tools = {"a": _FakeTool("a", "da"), "b": _FakeTool("b", "db"), "c": _FakeTool("c", "dc")}
142+
out = _filter_and_configure_tools(tools, tool_filter=["b", "c"]) # type: ignore[arg-type]
143+
assert out == {
144+
"b": {"function_name": "b", "description": "db"},
145+
"c": {"function_name": "c", "description": "dc"},
146+
}
147+
148+
149+
def test_filter_and_configure_tools_dict_overrides_alias_and_description(caplog):
150+
tools = {"raw": _FakeTool("raw", "original")}
151+
overrides = {"raw": ToolOverrideConfig(alias="alias", description="new desc")}
152+
out = _filter_and_configure_tools(tools, tool_filter=overrides) # type: ignore[arg-type]
153+
assert out == {"raw": {"function_name": "alias", "description": "new desc"}}
154+
155+
156+
@pytest.mark.anyio
157+
async def test_mcp_client_function_handler_registers_tools(monkeypatch):
158+
# Prepare fake client classes to be used by the handler
159+
fake_tools = {"t1": _FakeTool("t1", "d1"), "t2": _FakeTool("t2", "d2")}
160+
161+
def _mk_client(_: str):
162+
return _FakeMCPClient(tools=fake_tools, transport="sse", url="http://x")
163+
164+
# Monkeypatch the symbols the handler resolves inside the function
165+
# Patch the source module where the handler imports client classes
166+
monkeypatch.setattr("nat.tool.mcp.mcp_client_base.MCPSSEClient", lambda url: _mk_client(url), raising=True)
167+
monkeypatch.setattr("nat.tool.mcp.mcp_client_base.MCPStdioClient",
168+
lambda command, args, env: _FakeMCPClient(tools=fake_tools, transport="stdio",
169+
command=command), raising=True)
170+
monkeypatch.setattr("nat.tool.mcp.mcp_client_base.MCPStreamableHTTPClient", lambda url: _mk_client(url),
171+
raising=True)
172+
173+
server_cfg = MCPServerConfig(transport="sse", url=cast(HttpUrl, "http://fake"))
174+
client_cfg = MCPClientConfig(server=server_cfg, tool_filter=["t1"]) # only expose t1
175+
176+
async with WorkflowBuilder() as builder:
177+
# Adding the handler function triggers discovery and registering of tools
178+
await builder.add_function(name="mcp_client", config=client_cfg)
179+
180+
# Confirm that the filtered tool has been registered as a function and is invokable
181+
fn = builder.get_function("t1")
182+
out = await fn.acall_invoke(param="v")
183+
assert out == "ok v"

0 commit comments

Comments
 (0)