2525from starlette .websockets import WebSocketDisconnect
2626
2727from nat .authentication .interfaces import FlowHandlerBase
28+ from nat .data_models .api_server import ChatRequest
2829from nat .data_models .api_server import ChatResponse
2930from nat .data_models .api_server import ChatResponseChunk
3031from nat .data_models .api_server import Error
3334from nat .data_models .api_server import ResponseSerializable
3435from nat .data_models .api_server import SystemResponseContent
3536from nat .data_models .api_server import TextContent
37+ from nat .data_models .api_server import UserMessageContentRoleType
38+ from nat .data_models .api_server import UserMessages
3639from nat .data_models .api_server import WebSocketMessageStatus
3740from nat .data_models .api_server import WebSocketMessageType
3841from nat .data_models .api_server import WebSocketSystemInteractionMessage
@@ -64,12 +67,12 @@ def __init__(self, socket: WebSocket, session_manager: SessionManager, step_adap
6467 self ._running_workflow_task : asyncio .Task | None = None
6568 self ._message_parent_id : str = "default_id"
6669 self ._conversation_id : str | None = None
67- self ._workflow_schema_type : str = None
68- self ._user_interaction_response : asyncio .Future [HumanResponse ] | None = None
70+ self ._workflow_schema_type : str | None = None
71+ self ._user_interaction_response : asyncio .Future [TextContent ] | None = None
6972
7073 self ._flow_handler : FlowHandlerBase | None = None
7174
72- self ._schema_output_mapping : dict [str , type [BaseModel ] | None ] = {
75+ self ._schema_output_mapping : dict [str , type [BaseModel ] | type [ None ] ] = {
7376 WorkflowSchemaType .GENERATE : self ._session_manager .workflow .single_output_schema ,
7477 WorkflowSchemaType .CHAT : ChatResponse ,
7578 WorkflowSchemaType .CHAT_STREAM : ChatResponseChunk ,
@@ -114,55 +117,74 @@ async def run(self) -> None:
114117 pass
115118
116119 elif (isinstance (validated_message , WebSocketUserInteractionResponseMessage )):
117- user_content = await self .process_user_message_content (validated_message )
120+ user_content = await self ._process_websocket_user_interaction_response_message (validated_message )
121+ assert self ._user_interaction_response is not None
118122 self ._user_interaction_response .set_result (user_content )
119123 except (asyncio .CancelledError , WebSocketDisconnect ):
120124 # TODO: Handle the disconnect
121125 break
122126
123- async def process_user_message_content (
124- self , user_content : WebSocketUserMessage | WebSocketUserInteractionResponseMessage ) -> BaseModel | None :
127+ def _extract_last_user_message_content (self , messages : list [UserMessages ]) -> TextContent :
125128 """
126- Processes the contents of a user message .
129+ Extracts the last user's TextContent from a list of messages .
127130
128- :param user_content: Incoming content data model.
129- :return: A validated Pydantic user content model or None if not found.
130- """
131+ Args:
132+ messages: List of UserMessages.
131133
132- for user_message in user_content . content . messages [:: - 1 ] :
133- if ( user_message . role == " user" ):
134+ Returns :
135+ TextContent object from the last user message.
134136
137+ Raises:
138+ ValueError: If no user text content is found.
139+ """
140+ for user_message in messages [::- 1 ]:
141+ if user_message .role == UserMessageContentRoleType .USER :
135142 for attachment in user_message .content :
136-
137143 if isinstance (attachment , TextContent ):
138144 return attachment
145+ raise ValueError ("No user text content found in messages." )
146+
147+ async def _process_websocket_user_interaction_response_message (
148+ self , user_content : WebSocketUserInteractionResponseMessage ) -> TextContent :
149+ """
150+ Processes a WebSocketUserInteractionResponseMessage.
151+ """
152+ return self ._extract_last_user_message_content (user_content .content .messages )
139153
140- return None
154+ async def _process_websocket_user_message (self , user_content : WebSocketUserMessage ) -> ChatRequest | str :
155+ """
156+ Processes a WebSocketUserMessage based on schema type.
157+ """
158+ if self ._workflow_schema_type in [WorkflowSchemaType .CHAT , WorkflowSchemaType .CHAT_STREAM ]:
159+ return ChatRequest (** user_content .content .model_dump (include = {"messages" }))
160+
161+ elif self ._workflow_schema_type in [WorkflowSchemaType .GENERATE , WorkflowSchemaType .GENERATE_STREAM ]:
162+ return self ._extract_last_user_message_content (user_content .content .messages ).text
163+
164+ raise ValueError ("Unsupported workflow schema type for WebSocketUserMessage" )
141165
142166 async def process_workflow_request (self , user_message_as_validated_type : WebSocketUserMessage ) -> None :
143167 """
144168 Process user messages and routes them appropriately.
145169
146- :param user_message_as_validated_type: A WebSocketUserMessage Data Model instance.
170+ Args:
171+ user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
147172 """
148173
149174 try :
150175 self ._message_parent_id = user_message_as_validated_type .id
151176 self ._workflow_schema_type = user_message_as_validated_type .schema_type
152177 self ._conversation_id = user_message_as_validated_type .conversation_id
153178
154- content : BaseModel | None = await self .process_user_message_content (user_message_as_validated_type )
155-
156- if content is None :
157- raise ValueError (f"User message content could not be found: { user_message_as_validated_type } " )
179+ message_content : typing .Any = await self ._process_websocket_user_message (user_message_as_validated_type )
158180
159- if isinstance ( content , TextContent ) and (self ._running_workflow_task is None ):
181+ if (self ._running_workflow_task is None ):
160182
161- def _done_callback (task : asyncio .Task ):
183+ def _done_callback (_task : asyncio .Task ):
162184 self ._running_workflow_task = None
163185
164186 self ._running_workflow_task = asyncio .create_task (
165- self ._run_workflow (payload = content . text ,
187+ self ._run_workflow (payload = message_content ,
166188 user_message_id = self ._message_parent_id ,
167189 conversation_id = self ._conversation_id ,
168190 result_type = self ._schema_output_mapping [self ._workflow_schema_type ],
@@ -180,13 +202,14 @@ def _done_callback(task: asyncio.Task):
180202 async def create_websocket_message (self ,
181203 data_model : BaseModel ,
182204 message_type : str | None = None ,
183- status : str = WebSocketMessageStatus .IN_PROGRESS ) -> None :
205+ status : WebSocketMessageStatus = WebSocketMessageStatus .IN_PROGRESS ) -> None :
184206 """
185207 Creates a websocket message that will be ready for routing based on message type or data model.
186208
187- :param data_model: Message content model.
188- :param message_type: Message content model.
189- :param status: Message content model.
209+ Args:
210+ data_model (BaseModel): Message content model.
211+ message_type (str | None): Message content model.
212+ status (WebSocketMessageStatus): Message content model.
190213 """
191214 try :
192215 message : BaseModel | None = None
@@ -196,8 +219,8 @@ async def create_websocket_message(self,
196219
197220 message_schema : type [BaseModel ] = await self ._message_validator .get_message_schema_by_type (message_type )
198221
199- if 'id' in data_model . model_fields :
200- message_id : str = data_model . id
222+ if hasattr ( data_model , 'id' ) :
223+ message_id : str = str ( getattr ( data_model , 'id' ))
201224 else :
202225 message_id = str (uuid .uuid4 ())
203226
@@ -253,12 +276,15 @@ async def human_interaction_callback(self, prompt: InteractionPrompt) -> HumanRe
253276 Registered human interaction callback that processes human interactions and returns
254277 responses from websocket connection.
255278
256- :param prompt: Incoming interaction content data model.
257- :return: A Text Content Base Pydantic model.
279+ Args:
280+ prompt: Incoming interaction content data model.
281+
282+ Returns:
283+ A Text Content Base Pydantic model.
258284 """
259285
260286 # First create a future from the loop for the human response
261- human_response_future : asyncio .Future [HumanResponse ] = asyncio .get_running_loop ().create_future ()
287+ human_response_future : asyncio .Future [TextContent ] = asyncio .get_running_loop ().create_future ()
262288
263289 # Then add the future to the outstanding human prompts dictionary
264290 self ._user_interaction_response = human_response_future
@@ -274,10 +300,10 @@ async def human_interaction_callback(self, prompt: InteractionPrompt) -> HumanRe
274300 return HumanResponseNotification ()
275301
276302 # Wait for the human response future to complete
277- interaction_response : HumanResponse = await human_response_future
303+ text_content : TextContent = await human_response_future
278304
279305 interaction_response : HumanResponse = await self ._message_validator .convert_text_content_to_human_response (
280- interaction_response , prompt .content )
306+ text_content , prompt .content )
281307
282308 return interaction_response
283309
@@ -293,13 +319,12 @@ async def _run_workflow(self,
293319 output_type : type | None = None ) -> None :
294320
295321 try :
296- async with self ._session_manager .session (
297- user_message_id = user_message_id ,
298- conversation_id = conversation_id ,
299- http_connection = self ._socket ,
300- user_input_callback = self .human_interaction_callback ,
301- user_authentication_callback = (self ._flow_handler .authenticate
302- if self ._flow_handler else None )) as session :
322+ auth_callback = self ._flow_handler .authenticate if self ._flow_handler else None
323+ async with self ._session_manager .session (user_message_id = user_message_id ,
324+ conversation_id = conversation_id ,
325+ http_connection = self ._socket ,
326+ user_input_callback = self .human_interaction_callback ,
327+ user_authentication_callback = auth_callback ) as session :
303328
304329 async for value in generate_streaming_response (payload ,
305330 session_manager = session ,
0 commit comments