|
| 1 | +""" |
| 2 | +Functional tests for src/nat/tool/mcp/mcp_client_impl.py |
| 3 | +Focus on behavior (tool wrapping, filtering, handler integration), not class/type basics. |
| 4 | +""" |
| 5 | + |
| 6 | +import asyncio |
| 7 | +from contextlib import asynccontextmanager |
| 8 | +from typing import Any, cast |
| 9 | + |
| 10 | +import pytest |
| 11 | +from pydantic import BaseModel |
| 12 | + |
| 13 | +from mcp.types import TextContent |
| 14 | +from nat.builder.workflow_builder import WorkflowBuilder |
| 15 | +from nat.tool.mcp.mcp_client_impl import MCPClientConfig |
| 16 | +from nat.tool.mcp.mcp_client_impl import MCPServerConfig |
| 17 | +from nat.tool.mcp.mcp_client_impl import MCPSingleToolConfig |
| 18 | +from nat.tool.mcp.mcp_client_impl import ToolOverrideConfig |
| 19 | +from nat.tool.mcp.mcp_client_impl import _filter_and_configure_tools |
| 20 | +from nat.tool.mcp.mcp_client_base import MCPBaseClient |
| 21 | +from pydantic.networks import HttpUrl |
| 22 | + |
| 23 | + |
| 24 | +class _InputSchema(BaseModel): |
| 25 | + param: str |
| 26 | + |
| 27 | + |
| 28 | +class _FakeTool: |
| 29 | + def __init__(self, name: str, description: str = "desc") -> None: |
| 30 | + self.name = name |
| 31 | + self.description = description |
| 32 | + self.input_schema = _InputSchema |
| 33 | + |
| 34 | + async def acall(self, args: dict[str, Any]) -> str: |
| 35 | + return f"ok {args['param']}" |
| 36 | + |
| 37 | + def set_description(self, description: str | None) -> None: |
| 38 | + if description is not None: |
| 39 | + self.description = description |
| 40 | + |
| 41 | + |
| 42 | +class _ErrorTool(_FakeTool): |
| 43 | + async def acall(self, args: dict[str, Any]) -> str: # type: ignore[override] |
| 44 | + raise RuntimeError("boom") |
| 45 | + |
| 46 | + |
| 47 | +class _FakeSession: |
| 48 | + def __init__(self, tools: dict[str, _FakeTool]) -> None: |
| 49 | + self._tools = tools |
| 50 | + |
| 51 | + class _ToolInfo: |
| 52 | + def __init__(self, name: str, description: str) -> None: |
| 53 | + self.name = name |
| 54 | + self.description = description |
| 55 | + # Provide a trivial input schema compatible with MCPToolClient expectations |
| 56 | + self.inputSchema = {"type": "object", "properties": {"param": {"type": "string"}}, "required": ["param"]} |
| 57 | + |
| 58 | + class _ListToolsResponse: |
| 59 | + def __init__(self, tools: list["_FakeSession._ToolInfo"]) -> None: |
| 60 | + self.tools = tools |
| 61 | + |
| 62 | + async def list_tools(self) -> "_FakeSession._ListToolsResponse": |
| 63 | + infos = [self._ToolInfo(name=t.name, description=t.description) for t in self._tools.values()] |
| 64 | + return self._ListToolsResponse(tools=infos) |
| 65 | + |
| 66 | + async def call_tool(self, tool_name: str, tool_args: dict[str, Any]): |
| 67 | + tool = self._tools[tool_name] |
| 68 | + class _Result: |
| 69 | + def __init__(self, text: str): |
| 70 | + self.content = [TextContent(type="text", text=text)] |
| 71 | + self.isError = False |
| 72 | + return _Result(await tool.acall(tool_args)) |
| 73 | + |
| 74 | + |
| 75 | +class _FakeMCPClient(MCPBaseClient): |
| 76 | + def __init__(self, *, tools: dict[str, _FakeTool], transport: str = "sse", url: str | None = None, |
| 77 | + command: str | None = None) -> None: |
| 78 | + super().__init__(transport) |
| 79 | + self._tools_map = tools |
| 80 | + self._url = url |
| 81 | + self._command = command |
| 82 | + |
| 83 | + @property |
| 84 | + def url(self) -> str | None: |
| 85 | + return self._url |
| 86 | + |
| 87 | + @property |
| 88 | + def command(self) -> str | None: |
| 89 | + return self._command |
| 90 | + |
| 91 | + @asynccontextmanager |
| 92 | + async def connect_to_server(self): |
| 93 | + yield _FakeSession(self._tools_map) |
| 94 | + |
| 95 | + async def get_tools(self) -> dict[str, _FakeTool]: # type: ignore[override] |
| 96 | + return self._tools_map |
| 97 | + |
| 98 | + async def get_tool(self, tool_name: str) -> _FakeTool: # type: ignore[override] |
| 99 | + return self._tools_map[tool_name] |
| 100 | + |
| 101 | + |
| 102 | +@pytest.mark.anyio |
| 103 | +async def test_mcp_single_tool_happy_path_kwargs(): |
| 104 | + client = _FakeMCPClient(tools={"echo": _FakeTool("echo", "Echo tool")}) |
| 105 | + |
| 106 | + async with WorkflowBuilder() as builder: |
| 107 | + fn = await builder.add_function( |
| 108 | + name="echo_fn", |
| 109 | + config=MCPSingleToolConfig(client=client, tool_name="echo", tool_description="Overridden desc"), |
| 110 | + ) |
| 111 | + |
| 112 | + # Validate invocation path using kwargs |
| 113 | + result = await fn.acall_invoke(param="value") |
| 114 | + assert result == "ok value" |
| 115 | + |
| 116 | + |
| 117 | +@pytest.mark.anyio |
| 118 | +async def test_mcp_single_tool_returns_error_string_on_exception(): |
| 119 | + client = _FakeMCPClient(tools={"err": _ErrorTool("err", "Err tool")}) |
| 120 | + |
| 121 | + async with WorkflowBuilder() as builder: |
| 122 | + fn = await builder.add_function( |
| 123 | + name="err_fn", |
| 124 | + config=MCPSingleToolConfig(client=client, tool_name="err"), |
| 125 | + ) |
| 126 | + |
| 127 | + result = await fn.acall_invoke(param="value") |
| 128 | + assert "boom" in result |
| 129 | + |
| 130 | + |
| 131 | +def test_filter_and_configure_tools_none_filter_returns_all(): |
| 132 | + tools = {"a": _FakeTool("a", "da"), "b": _FakeTool("b", "db")} |
| 133 | + out = _filter_and_configure_tools(tools, tool_filter=None) |
| 134 | + assert out == { |
| 135 | + "a": {"function_name": "a", "description": "da"}, |
| 136 | + "b": {"function_name": "b", "description": "db"}, |
| 137 | + } |
| 138 | + |
| 139 | + |
| 140 | +def test_filter_and_configure_tools_list_filter_subsets(): |
| 141 | + tools = {"a": _FakeTool("a", "da"), "b": _FakeTool("b", "db"), "c": _FakeTool("c", "dc")} |
| 142 | + out = _filter_and_configure_tools(tools, tool_filter=["b", "c"]) # type: ignore[arg-type] |
| 143 | + assert out == { |
| 144 | + "b": {"function_name": "b", "description": "db"}, |
| 145 | + "c": {"function_name": "c", "description": "dc"}, |
| 146 | + } |
| 147 | + |
| 148 | + |
| 149 | +def test_filter_and_configure_tools_dict_overrides_alias_and_description(caplog): |
| 150 | + tools = {"raw": _FakeTool("raw", "original")} |
| 151 | + overrides = {"raw": ToolOverrideConfig(alias="alias", description="new desc")} |
| 152 | + out = _filter_and_configure_tools(tools, tool_filter=overrides) # type: ignore[arg-type] |
| 153 | + assert out == {"raw": {"function_name": "alias", "description": "new desc"}} |
| 154 | + |
| 155 | + |
| 156 | +@pytest.mark.anyio |
| 157 | +async def test_mcp_client_function_handler_registers_tools(monkeypatch): |
| 158 | + # Prepare fake client classes to be used by the handler |
| 159 | + fake_tools = {"t1": _FakeTool("t1", "d1"), "t2": _FakeTool("t2", "d2")} |
| 160 | + |
| 161 | + def _mk_client(_: str): |
| 162 | + return _FakeMCPClient(tools=fake_tools, transport="sse", url="http://x") |
| 163 | + |
| 164 | + # Monkeypatch the symbols the handler resolves inside the function |
| 165 | + # Patch the source module where the handler imports client classes |
| 166 | + monkeypatch.setattr("nat.tool.mcp.mcp_client_base.MCPSSEClient", lambda url: _mk_client(url), raising=True) |
| 167 | + monkeypatch.setattr("nat.tool.mcp.mcp_client_base.MCPStdioClient", |
| 168 | + lambda command, args, env: _FakeMCPClient(tools=fake_tools, transport="stdio", |
| 169 | + command=command), raising=True) |
| 170 | + monkeypatch.setattr("nat.tool.mcp.mcp_client_base.MCPStreamableHTTPClient", lambda url: _mk_client(url), |
| 171 | + raising=True) |
| 172 | + |
| 173 | + server_cfg = MCPServerConfig(transport="sse", url=cast(HttpUrl, "http://fake")) |
| 174 | + client_cfg = MCPClientConfig(server=server_cfg, tool_filter=["t1"]) # only expose t1 |
| 175 | + |
| 176 | + async with WorkflowBuilder() as builder: |
| 177 | + # Adding the handler function triggers discovery and registering of tools |
| 178 | + await builder.add_function(name="mcp_client", config=client_cfg) |
| 179 | + |
| 180 | + # Confirm that the filtered tool has been registered as a function and is invokable |
| 181 | + fn = builder.get_function("t1") |
| 182 | + out = await fn.acall_invoke(param="v") |
| 183 | + assert out == "ok v" |
0 commit comments