Skip to content

Commit f75a009

Browse files
committed
fix comments
Signed-off-by: Yuchen Zhang <[email protected]>
1 parent e79a724 commit f75a009

File tree

4 files changed

+140
-92
lines changed

4 files changed

+140
-92
lines changed

packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ async def _with_reconnect(self, coro):
264264
raise
265265

266266
@mcp_exception_handler
267-
async def get_tools(self):
267+
async def get_tools(self) -> dict[str, "MCPToolClient"]:
268268
"""
269269
Retrieve a dictionary of all tools served by the MCP server.
270270
Uses unauthenticated session for discovery.

packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_impl.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,50 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35+
class MCPFunctionGroup(FunctionGroup):
36+
"""
37+
A specialized FunctionGroup for MCP clients that includes MCP-specific attributes
38+
with proper type safety.
39+
"""
40+
41+
def __init__(self, *args, **kwargs):
42+
super().__init__(*args, **kwargs)
43+
# MCP client attributes with proper typing
44+
self._mcp_client = None # Will be set to the actual MCP client instance
45+
self._mcp_client_server_name: str | None = None
46+
self._mcp_client_transport: str | None = None
47+
48+
@property
49+
def mcp_client(self):
50+
"""Get the MCP client instance."""
51+
return self._mcp_client
52+
53+
@mcp_client.setter
54+
def mcp_client(self, client):
55+
"""Set the MCP client instance."""
56+
self._mcp_client = client
57+
58+
@property
59+
def mcp_client_server_name(self) -> str | None:
60+
"""Get the MCP client server name."""
61+
return self._mcp_client_server_name
62+
63+
@mcp_client_server_name.setter
64+
def mcp_client_server_name(self, server_name: str | None):
65+
"""Set the MCP client server name."""
66+
self._mcp_client_server_name = server_name
67+
68+
@property
69+
def mcp_client_transport(self) -> str | None:
70+
"""Get the MCP client transport type."""
71+
return self._mcp_client_transport
72+
73+
@mcp_client_transport.setter
74+
def mcp_client_transport(self, transport: str | None):
75+
"""Set the MCP client transport type."""
76+
self._mcp_client_transport = transport
77+
78+
3579
class MCPToolOverrideConfig(BaseModel):
3680
"""
3781
Configuration for overriding tool properties when exposing from MCP server.
@@ -175,16 +219,15 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
175219

176220
logger.info("Configured to use MCP server at %s", client.server_name)
177221

178-
# Create the function group
179-
group = FunctionGroup(config=config)
222+
# Create the MCP function group
223+
group = MCPFunctionGroup(config=config)
180224

181225
async with client:
182226
# Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
183227
# can reuse the already-established session instead of creating a new client per request.
184-
# These attributes are intentionally internal/private to avoid API surface commitments.
185-
setattr(group, "_mcp_client", client)
186-
setattr(group, "_mcp_client_server_name", client.server_name)
187-
setattr(group, "_mcp_client_transport", client.transport)
228+
group.mcp_client = client
229+
group.mcp_client_server_name = client.server_name
230+
group.mcp_client_transport = client.transport
188231

189232
all_tools = await client.get_tools()
190233
tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)

src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py

Lines changed: 89 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from pydantic import Field
3838
from starlette.websockets import WebSocket
3939

40+
from nat.builder.function import Function
4041
from nat.builder.workflow_builder import WorkflowBuilder
4142
from nat.data_models.api_server import ChatRequest
4243
from nat.data_models.api_server import ChatResponse
@@ -1117,97 +1118,101 @@ async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
11171118

11181119
# Find MCP client function groups
11191120
for group_name, configured_group in function_groups.items():
1120-
if configured_group.config.type == "mcp_client":
1121-
from nat.plugins.mcp.client_impl import MCPClientConfig
1121+
if configured_group.config.type != "mcp_client":
1122+
continue
11221123

1123-
config = configured_group.config
1124-
assert isinstance(config, MCPClientConfig)
1124+
from nat.plugins.mcp.client_impl import MCPClientConfig
11251125

1126-
# Reuse the existing MCP client session stored on the function group instance
1127-
group_instance = configured_group.instance
1128-
client = getattr(group_instance, "_mcp_client", None)
1129-
if client is None:
1130-
raise RuntimeError(f"MCP client not found for group {group_name}")
1126+
config = configured_group.config
1127+
assert isinstance(config, MCPClientConfig)
1128+
1129+
# Reuse the existing MCP client session stored on the function group instance
1130+
group_instance = configured_group.instance
1131+
1132+
client = group_instance.mcp_client
1133+
if client is None:
1134+
raise RuntimeError(f"MCP client not found for group {group_name}")
1135+
1136+
try:
1137+
session_healthy = False
1138+
server_tools: dict[str, Any] = {}
11311139

11321140
try:
1141+
server_tools = await client.get_tools()
1142+
session_healthy = True
1143+
except Exception as e:
1144+
logger.exception(f"Failed to connect to MCP server {client.server_name}: {e}")
11331145
session_healthy = False
1134-
server_tools: dict[str, Any] = {}
1135-
1136-
try:
1137-
server_tools = await client.get_tools()
1138-
session_healthy = True
1139-
except Exception as e:
1140-
logger.warning(f"Failed to connect to MCP server {client.server_name}: {e}")
1141-
session_healthy = False
1142-
1143-
# Get workflow function group configuration (configured client-side tools)
1144-
configured_short_names: set[str] = set()
1145-
configured_full_to_fn: dict[str, Any] = {}
1146-
try:
1147-
accessible_functions = await group_instance.get_accessible_functions()
1148-
configured_full_to_fn = accessible_functions
1149-
configured_short_names = {
1150-
name.split('.', 1)[1] if '.' in name else name
1151-
for name in accessible_functions.keys()
1152-
}
1153-
except Exception as e:
1154-
logger.warning(f"Failed to get accessible functions for group {group_name}: {e}")
1155-
1156-
# Build alias->original mapping from overrides
1157-
alias_to_original: dict[str, str] = {}
1158-
try:
1159-
overrides = getattr(config, "tool_overrides", None) or {}
1160-
for orig_name, override in overrides.items():
1161-
alias = getattr(override, "alias", None)
1162-
if alias:
1163-
alias_to_original[alias] = orig_name
1164-
except Exception:
1165-
pass
1166-
1167-
# Create tool info list (always return configured tools; mark availability)
1168-
tools_info: list[dict[str, Any]] = []
1169-
available_count = 0
1170-
for fn_short in sorted(configured_short_names):
1171-
orig_name = alias_to_original.get(fn_short, fn_short)
1172-
available = session_healthy and (orig_name in server_tools)
1173-
if available:
1174-
available_count += 1
1175-
1176-
# Prefer the workflow function description (includes overrides)
1177-
full_name = f"{group_name}.{fn_short}"
1178-
wf_fn = configured_full_to_fn.get(full_name)
1179-
description = getattr(
1180-
wf_fn, "description",
1181-
None) or (server_tools[orig_name].description if available else "")
1182-
1183-
tools_info.append(
1184-
MCPToolInfo(name=fn_short,
1185-
description=description or "",
1186-
server=client.server_name,
1187-
available=available).dict())
1188-
1189-
mcp_clients_info.append({
1190-
"function_group": group_name,
1191-
"server": client.server_name,
1192-
"transport": config.server.transport,
1193-
"session_healthy": session_healthy,
1194-
"tools": tools_info,
1195-
"total_tools": len(configured_short_names),
1196-
"available_tools": available_count
1197-
})
11981146

1147+
# Get workflow function group configuration (configured client-side tools)
1148+
configured_short_names: set[str] = set()
1149+
configured_full_to_fn: dict[str, Function] = {}
1150+
try:
1151+
# Pass a no-op filter function to bypass any default filtering that might check
1152+
# health status, preventing potential infinite recursion during health status checks.
1153+
async def pass_through_filter(fn):
1154+
return fn
1155+
1156+
accessible_functions = await group_instance.get_accessible_functions(
1157+
filter_fn=pass_through_filter)
1158+
configured_full_to_fn = accessible_functions
1159+
configured_short_names = {name.split('.', 1)[1] for name in accessible_functions.keys()}
11991160
except Exception as e:
1200-
logger.error(f"Error processing MCP client {group_name}: {e}")
1201-
mcp_clients_info.append({
1202-
"function_group": group_name,
1203-
"server": "unknown",
1204-
"transport": config.server.transport if config.server else "unknown",
1205-
"session_healthy": False,
1206-
"error": str(e),
1207-
"tools": [],
1208-
"total_tools": 0,
1209-
"workflow_tools": 0
1210-
})
1161+
logger.exception(f"Failed to get accessible functions for group {group_name}: {e}")
1162+
1163+
# Build alias->original mapping from overrides
1164+
alias_to_original: dict[str, str] = {}
1165+
try:
1166+
if config.tool_overrides is not None:
1167+
for orig_name, override in config.tool_overrides.items():
1168+
if override.alias is not None:
1169+
alias_to_original[override.alias] = orig_name
1170+
except Exception:
1171+
pass
1172+
1173+
# Create tool info list (always return configured tools; mark availability)
1174+
tools_info: list[dict[str, Any]] = []
1175+
available_count = 0
1176+
for wf_fn, fn_short in zip(configured_full_to_fn.values(), configured_short_names):
1177+
orig_name = alias_to_original.get(fn_short, fn_short)
1178+
available = session_healthy and (orig_name in server_tools)
1179+
if available:
1180+
available_count += 1
1181+
1182+
description = (server_tools[orig_name].description
1183+
if available else None) or wf_fn.description or ""
1184+
1185+
tools_info.append(
1186+
MCPToolInfo(name=fn_short,
1187+
description=description or "",
1188+
server=client.server_name,
1189+
available=available).model_dump())
1190+
1191+
# Sort tools_info by name to maintain consistent ordering
1192+
tools_info.sort(key=lambda x: x['name'])
1193+
1194+
mcp_clients_info.append({
1195+
"function_group": group_name,
1196+
"server": client.server_name,
1197+
"transport": config.server.transport,
1198+
"session_healthy": session_healthy,
1199+
"tools": tools_info,
1200+
"total_tools": len(configured_short_names),
1201+
"available_tools": available_count
1202+
})
1203+
1204+
except Exception as e:
1205+
logger.error(f"Error processing MCP client {group_name}: {e}")
1206+
mcp_clients_info.append({
1207+
"function_group": group_name,
1208+
"server": "unknown",
1209+
"transport": config.server.transport if config.server else "unknown",
1210+
"session_healthy": False,
1211+
"error": str(e),
1212+
"tools": [],
1213+
"total_tools": 0,
1214+
"workflow_tools": 0
1215+
})
12111216

12121217
return MCPClientToolListResponse(mcp_clients=mcp_clients_info)
12131218

tests/nat/front_ends/fastapi/test_mcp_client_endpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class _GroupInstanceStub:
5151

5252
def __init__(self, client: _ClientStub, functions_map: dict[str, _FnStub]):
5353
# Reuse the pre-established client session on the group, like runtime
54-
self._mcp_client = client
54+
self.mcp_client = client
5555
self._functions_map = functions_map
5656

5757
async def get_accessible_functions(self) -> dict[str, _FnStub]:

0 commit comments

Comments
 (0)