Skip to content

Commit 184114b

Browse files
authored
Enable Chat History for WebSocket Messages (#999)
This PR updates the WebSocket message handler to pass the correct payload to the agent, enabling chat history for supported workflow types. ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. ## Summary by CodeRabbit * **New Features** * Enhanced WebSocket handling for multi-message chats and unified Chat/Generate streams with consistent outputs. * Improved human-in-the-loop prompts and responses, returning clearer, consistent message content. * **Bug Fixes** * Prevents workflows from starting while another task runs, reducing race conditions. * More robust message creation and error handling for fewer failed interactions. Authors: - Eric Evans II (https://github.com/ericevans-nv) - Will Killian (https://github.com/willkill07) Approvers: - Will Killian (https://github.com/willkill07) URL: #999
1 parent 0e920aa commit 184114b

File tree

2 files changed

+66
-42
lines changed

2 files changed

+66
-42
lines changed

src/nat/front_ends/fastapi/message_handler.py

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from starlette.websockets import WebSocketDisconnect
2626

2727
from nat.authentication.interfaces import FlowHandlerBase
28+
from nat.data_models.api_server import ChatRequest
2829
from nat.data_models.api_server import ChatResponse
2930
from nat.data_models.api_server import ChatResponseChunk
3031
from nat.data_models.api_server import Error
@@ -33,6 +34,8 @@
3334
from nat.data_models.api_server import ResponseSerializable
3435
from nat.data_models.api_server import SystemResponseContent
3536
from nat.data_models.api_server import TextContent
37+
from nat.data_models.api_server import UserMessageContentRoleType
38+
from nat.data_models.api_server import UserMessages
3639
from nat.data_models.api_server import WebSocketMessageStatus
3740
from nat.data_models.api_server import WebSocketMessageType
3841
from nat.data_models.api_server import WebSocketSystemInteractionMessage
@@ -64,12 +67,12 @@ def __init__(self, socket: WebSocket, session_manager: SessionManager, step_adap
6467
self._running_workflow_task: asyncio.Task | None = None
6568
self._message_parent_id: str = "default_id"
6669
self._conversation_id: str | None = None
67-
self._workflow_schema_type: str = None
68-
self._user_interaction_response: asyncio.Future[HumanResponse] | None = None
70+
self._workflow_schema_type: str | None = None
71+
self._user_interaction_response: asyncio.Future[TextContent] | None = None
6972

7073
self._flow_handler: FlowHandlerBase | None = None
7174

72-
self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
75+
self._schema_output_mapping: dict[str, type[BaseModel] | type[None]] = {
7376
WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
7477
WorkflowSchemaType.CHAT: ChatResponse,
7578
WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
@@ -114,55 +117,74 @@ async def run(self) -> None:
114117
pass
115118

116119
elif (isinstance(validated_message, WebSocketUserInteractionResponseMessage)):
117-
user_content = await self.process_user_message_content(validated_message)
120+
user_content = await self._process_websocket_user_interaction_response_message(validated_message)
121+
assert self._user_interaction_response is not None
118122
self._user_interaction_response.set_result(user_content)
119123
except (asyncio.CancelledError, WebSocketDisconnect):
120124
# TODO: Handle the disconnect
121125
break
122126

123-
async def process_user_message_content(
124-
self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
127+
def _extract_last_user_message_content(self, messages: list[UserMessages]) -> TextContent:
125128
"""
126-
Processes the contents of a user message.
129+
Extracts the last user's TextContent from a list of messages.
127130
128-
:param user_content: Incoming content data model.
129-
:return: A validated Pydantic user content model or None if not found.
130-
"""
131+
Args:
132+
messages: List of UserMessages.
131133
132-
for user_message in user_content.content.messages[::-1]:
133-
if (user_message.role == "user"):
134+
Returns:
135+
TextContent object from the last user message.
134136
137+
Raises:
138+
ValueError: If no user text content is found.
139+
"""
140+
for user_message in messages[::-1]:
141+
if user_message.role == UserMessageContentRoleType.USER:
135142
for attachment in user_message.content:
136-
137143
if isinstance(attachment, TextContent):
138144
return attachment
145+
raise ValueError("No user text content found in messages.")
146+
147+
async def _process_websocket_user_interaction_response_message(
148+
self, user_content: WebSocketUserInteractionResponseMessage) -> TextContent:
149+
"""
150+
Processes a WebSocketUserInteractionResponseMessage.
151+
"""
152+
return self._extract_last_user_message_content(user_content.content.messages)
139153

140-
return None
154+
async def _process_websocket_user_message(self, user_content: WebSocketUserMessage) -> ChatRequest | str:
155+
"""
156+
Processes a WebSocketUserMessage based on schema type.
157+
"""
158+
if self._workflow_schema_type in [WorkflowSchemaType.CHAT, WorkflowSchemaType.CHAT_STREAM]:
159+
return ChatRequest(**user_content.content.model_dump(include={"messages"}))
160+
161+
elif self._workflow_schema_type in [WorkflowSchemaType.GENERATE, WorkflowSchemaType.GENERATE_STREAM]:
162+
return self._extract_last_user_message_content(user_content.content.messages).text
163+
164+
raise ValueError("Unsupported workflow schema type for WebSocketUserMessage")
141165

142166
async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
143167
"""
144168
Process user messages and routes them appropriately.
145169
146-
:param user_message_as_validated_type: A WebSocketUserMessage Data Model instance.
170+
Args:
171+
user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
147172
"""
148173

149174
try:
150175
self._message_parent_id = user_message_as_validated_type.id
151176
self._workflow_schema_type = user_message_as_validated_type.schema_type
152177
self._conversation_id = user_message_as_validated_type.conversation_id
153178

154-
content: BaseModel | None = await self.process_user_message_content(user_message_as_validated_type)
155-
156-
if content is None:
157-
raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")
179+
message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)
158180

159-
if isinstance(content, TextContent) and (self._running_workflow_task is None):
181+
if (self._running_workflow_task is None):
160182

161-
def _done_callback(task: asyncio.Task):
183+
def _done_callback(_task: asyncio.Task):
162184
self._running_workflow_task = None
163185

164186
self._running_workflow_task = asyncio.create_task(
165-
self._run_workflow(payload=content.text,
187+
self._run_workflow(payload=message_content,
166188
user_message_id=self._message_parent_id,
167189
conversation_id=self._conversation_id,
168190
result_type=self._schema_output_mapping[self._workflow_schema_type],
@@ -180,13 +202,14 @@ def _done_callback(task: asyncio.Task):
180202
async def create_websocket_message(self,
181203
data_model: BaseModel,
182204
message_type: str | None = None,
183-
status: str = WebSocketMessageStatus.IN_PROGRESS) -> None:
205+
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS) -> None:
184206
"""
185207
Creates a websocket message that will be ready for routing based on message type or data model.
186208
187-
:param data_model: Message content model.
188-
:param message_type: Message content model.
189-
:param status: Message content model.
209+
Args:
210+
data_model (BaseModel): Message content model.
211+
message_type (str | None): Message content model.
212+
status (WebSocketMessageStatus): Message content model.
190213
"""
191214
try:
192215
message: BaseModel | None = None
@@ -196,8 +219,8 @@ async def create_websocket_message(self,
196219

197220
message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)
198221

199-
if 'id' in data_model.model_fields:
200-
message_id: str = data_model.id
222+
if hasattr(data_model, 'id'):
223+
message_id: str = str(getattr(data_model, 'id'))
201224
else:
202225
message_id = str(uuid.uuid4())
203226

@@ -253,12 +276,15 @@ async def human_interaction_callback(self, prompt: InteractionPrompt) -> HumanRe
253276
Registered human interaction callback that processes human interactions and returns
254277
responses from websocket connection.
255278
256-
:param prompt: Incoming interaction content data model.
257-
:return: A Text Content Base Pydantic model.
279+
Args:
280+
prompt: Incoming interaction content data model.
281+
282+
Returns:
283+
A Text Content Base Pydantic model.
258284
"""
259285

260286
# First create a future from the loop for the human response
261-
human_response_future: asyncio.Future[HumanResponse] = asyncio.get_running_loop().create_future()
287+
human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()
262288

263289
# Then add the future to the outstanding human prompts dictionary
264290
self._user_interaction_response = human_response_future
@@ -274,10 +300,10 @@ async def human_interaction_callback(self, prompt: InteractionPrompt) -> HumanRe
274300
return HumanResponseNotification()
275301

276302
# Wait for the human response future to complete
277-
interaction_response: HumanResponse = await human_response_future
303+
text_content: TextContent = await human_response_future
278304

279305
interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
280-
interaction_response, prompt.content)
306+
text_content, prompt.content)
281307

282308
return interaction_response
283309

@@ -293,13 +319,12 @@ async def _run_workflow(self,
293319
output_type: type | None = None) -> None:
294320

295321
try:
296-
async with self._session_manager.session(
297-
user_message_id=user_message_id,
298-
conversation_id=conversation_id,
299-
http_connection=self._socket,
300-
user_input_callback=self.human_interaction_callback,
301-
user_authentication_callback=(self._flow_handler.authenticate
302-
if self._flow_handler else None)) as session:
322+
auth_callback = self._flow_handler.authenticate if self._flow_handler else None
323+
async with self._session_manager.session(user_message_id=user_message_id,
324+
conversation_id=conversation_id,
325+
http_connection=self._socket,
326+
user_input_callback=self.human_interaction_callback,
327+
user_authentication_callback=auth_callback) as session:
303328

304329
async for value in generate_streaming_response(payload,
305330
session_manager=session,

src/nat/front_ends/fastapi/message_validator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,7 @@ async def create_system_response_token_message(
240240
thread_id: str = "default",
241241
parent_id: str = "default",
242242
conversation_id: str | None = None,
243-
content: SystemResponseContent
244-
| Error = SystemResponseContent(),
243+
content: SystemResponseContent | Error = SystemResponseContent(),
245244
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
246245
timestamp: str = str(datetime.datetime.now(datetime.UTC))
247246
) -> WebSocketSystemResponseTokenMessage | None:

0 commit comments

Comments
 (0)