Skip to content

Commit 6747a4e

Browse files
committed
feat(ai): 添加工具调用的流式处理支持并限制输出大小
添加流式工具调用接口以实时显示AI思考过程 限制命令输出大小防止内存溢出 处理流式JSON参数不完整的情况
1 parent 1f625ba commit 6747a4e

File tree

3 files changed

+203
-32
lines changed

3 files changed

+203
-32
lines changed

app_ai_chat.go

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,16 @@ func (a *App) runAgentLoop(conversationId string, messages []AIChatMessage, prom
249249
return nil
250250
}
251251

252-
resp, err := client.ChatWithTools(ctx, aiMessages, toolDefs)
252+
// Use streaming tool call so users can see AI thinking in real-time
253+
resp, err := client.ChatWithToolsStream(ctx, aiMessages, toolDefs, func(chunk string) error {
254+
a.emitEvent("ai-chat-chunk", map[string]string{
255+
"conversationId": conversationId,
256+
"chunk": chunk,
257+
})
258+
return nil
259+
})
253260
if err != nil {
254261
if ctx.Err() != nil {
255-
// Cancelled by user or timeout
256262
a.emitEvent("ai-chat-chunk", map[string]string{
257263
"conversationId": conversationId,
258264
"chunk": "\n\n⏹️ 操作已被用户停止。",
@@ -263,30 +269,17 @@ func (a *App) runAgentLoop(conversationId string, messages []AIChatMessage, prom
263269
})
264270
return nil
265271
}
266-
a.emitEvent( "ai-chat-complete", map[string]interface{}{
272+
a.emitEvent("ai-chat-complete", map[string]interface{}{
267273
"conversationId": conversationId,
268274
"success": false,
269275
})
270276
return fmt.Errorf(i18n.Tf("app_ai_analysis_failed", err))
271277
}
272278

273-
// No tool calls → stream the final answer
279+
// No tool calls → final answer (already streamed via callback)
274280
if len(resp.ToolCalls) == 0 {
275281
aiMessages = append(aiMessages, ai.Message{Role: "assistant", Content: resp.Content})
276-
277-
words := strings.Split(resp.Content, "")
278-
chunkSize := 8
279-
for i := 0; i < len(words); i += chunkSize {
280-
end := i + chunkSize
281-
if end > len(words) {
282-
end = len(words)
283-
}
284-
a.emitEvent( "ai-chat-chunk", map[string]string{
285-
"conversationId": conversationId,
286-
"chunk": strings.Join(words[i:end], ""),
287-
})
288-
}
289-
a.emitEvent( "ai-chat-complete", map[string]interface{}{
282+
a.emitEvent("ai-chat-complete", map[string]interface{}{
290283
"conversationId": conversationId,
291284
"success": true,
292285
})
@@ -303,30 +296,46 @@ func (a *App) runAgentLoop(conversationId string, messages []AIChatMessage, prom
303296
// Execute each tool call
304297
for _, tc := range resp.ToolCalls {
305298
var args map[string]interface{}
299+
var jsonParseErr error
306300
if tc.Function.Arguments != "" {
307301
if jsonErr := json.Unmarshal([]byte(tc.Function.Arguments), &args); jsonErr != nil {
308-
args = map[string]interface{}{}
302+
// Streaming may produce incomplete JSON for no-arg tools; treat as empty
303+
trimmed := strings.TrimSpace(tc.Function.Arguments)
304+
if trimmed == "{" || trimmed == "" {
305+
args = map[string]interface{}{}
306+
} else {
307+
jsonParseErr = jsonErr
308+
args = map[string]interface{}{}
309+
}
309310
}
310311
}
311312

312-
a.emitEvent( "ai-agent-tool-call", map[string]interface{}{
313+
a.emitEvent("ai-agent-tool-call", map[string]interface{}{
313314
"conversationId": conversationId,
314315
"toolCallId": tc.ID,
315316
"toolName": tc.Function.Name,
316317
"toolArgs": args,
317318
})
318319

319-
result, execErr := mcpServer.ExecuteTool(tc.Function.Name, args)
320320
var resultContent string
321-
success := execErr == nil
322-
if execErr != nil {
323-
resultContent = fmt.Sprintf("工具执行失败: %v", execErr)
324-
} else if len(result.Content) > 0 {
325-
var parts []string
326-
for _, item := range result.Content {
327-
parts = append(parts, item.Text)
321+
var success bool
322+
323+
if jsonParseErr != nil {
324+
// Report JSON parse failure as tool result so AI knows the root cause
325+
resultContent = fmt.Sprintf("工具参数 JSON 解析失败: %v\n原始参数: %s", jsonParseErr, tc.Function.Arguments)
326+
success = false
327+
} else {
328+
result, execErr := mcpServer.ExecuteTool(tc.Function.Name, args)
329+
success = execErr == nil
330+
if execErr != nil {
331+
resultContent = fmt.Sprintf("工具执行失败: %v", execErr)
332+
} else if len(result.Content) > 0 {
333+
var parts []string
334+
for _, item := range result.Content {
335+
parts = append(parts, item.Text)
336+
}
337+
resultContent = strings.Join(parts, "\n")
328338
}
329-
resultContent = strings.Join(parts, "\n")
330339
}
331340

332341
// Truncate large tool results to prevent context window overflow
@@ -335,7 +344,7 @@ func (a *App) runAgentLoop(conversationId string, messages []AIChatMessage, prom
335344
resultContent = resultContent[:maxToolResultLen] + "\n\n... (output truncated, total " + fmt.Sprintf("%d", len(resultContent)) + " bytes)"
336345
}
337346

338-
a.emitEvent( "ai-agent-tool-result", map[string]interface{}{
347+
a.emitEvent("ai-agent-tool-result", map[string]interface{}{
339348
"conversationId": conversationId,
340349
"toolCallId": tc.ID,
341350
"toolName": tc.Function.Name,

mod/ai/client.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,137 @@ type ToolCallResponse struct {
114114
// StreamCallback is called for each chunk of the stream
115115
type StreamCallback func(chunk string) error
116116

117+
// ToolStreamCallback is called for each content chunk during streaming tool calls
118+
type ToolStreamCallback func(chunk string) error
119+
120+
// ChatWithToolsStream sends a streaming chat request with tool definitions.
121+
// Content chunks are sent to the callback in real-time. Returns the full response (content + tool_calls).
122+
func (c *Client) ChatWithToolsStream(ctx context.Context, messages []Message, tools []ToolDefinition, callback ToolStreamCallback) (*ToolCallResponse, error) {
123+
reqBody := map[string]interface{}{
124+
"model": c.Model,
125+
"messages": messages,
126+
"tools": tools,
127+
"tool_choice": "auto",
128+
"stream": true,
129+
}
130+
131+
bodyBytes, err := json.Marshal(reqBody)
132+
if err != nil {
133+
return nil, fmt.Errorf("failed to marshal request: %w", err)
134+
}
135+
136+
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewReader(bodyBytes))
137+
if err != nil {
138+
return nil, fmt.Errorf("failed to create request: %w", err)
139+
}
140+
req.Header.Set("Content-Type", "application/json")
141+
req.Header.Set("Authorization", "Bearer "+c.APIKey)
142+
143+
resp, err := c.client.Do(req)
144+
if err != nil {
145+
return nil, fmt.Errorf("failed to send request: %w", err)
146+
}
147+
defer resp.Body.Close()
148+
149+
if resp.StatusCode != http.StatusOK {
150+
body, _ := io.ReadAll(resp.Body)
151+
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
152+
}
153+
154+
var contentBuilder strings.Builder
155+
toolCallMap := make(map[int]*ToolCall) // index → accumulated tool call
156+
157+
reader := bufio.NewReader(resp.Body)
158+
for {
159+
line, err := reader.ReadBytes('\n')
160+
if err != nil {
161+
if err == io.EOF {
162+
break
163+
}
164+
return nil, fmt.Errorf("failed to read stream: %w", err)
165+
}
166+
167+
line = bytes.TrimSpace(line)
168+
if len(line) == 0 {
169+
continue
170+
}
171+
if !bytes.HasPrefix(line, []byte("data: ")) {
172+
continue
173+
}
174+
data := bytes.TrimPrefix(line, []byte("data: "))
175+
if bytes.Equal(data, []byte("[DONE]")) {
176+
break
177+
}
178+
179+
var chunk struct {
180+
Choices []struct {
181+
Delta struct {
182+
Content string `json:"content"`
183+
ToolCalls []struct {
184+
Index int `json:"index"`
185+
ID string `json:"id"`
186+
Type string `json:"type"`
187+
Function struct {
188+
Name string `json:"name"`
189+
Arguments string `json:"arguments"`
190+
} `json:"function"`
191+
} `json:"tool_calls"`
192+
} `json:"delta"`
193+
} `json:"choices"`
194+
}
195+
if err := json.Unmarshal(data, &chunk); err != nil {
196+
continue
197+
}
198+
if len(chunk.Choices) == 0 {
199+
continue
200+
}
201+
202+
delta := chunk.Choices[0].Delta
203+
204+
// Stream content chunks to callback
205+
if delta.Content != "" {
206+
contentBuilder.WriteString(delta.Content)
207+
if callback != nil {
208+
_ = callback(delta.Content)
209+
}
210+
}
211+
212+
// Accumulate tool calls
213+
for _, dtc := range delta.ToolCalls {
214+
tc, ok := toolCallMap[dtc.Index]
215+
if !ok {
216+
tc = &ToolCall{Type: "function"}
217+
toolCallMap[dtc.Index] = tc
218+
}
219+
if dtc.ID != "" {
220+
tc.ID = dtc.ID
221+
}
222+
if dtc.Function.Name != "" {
223+
tc.Function.Name = dtc.Function.Name
224+
}
225+
tc.Function.Arguments += dtc.Function.Arguments
226+
}
227+
}
228+
229+
// Build sorted tool calls list
230+
var toolCalls []ToolCall
231+
for i := 0; i < len(toolCallMap); i++ {
232+
if tc, ok := toolCallMap[i]; ok {
233+
// Fix incomplete arguments from streaming (e.g., "{" without closing "}")
234+
args := strings.TrimSpace(tc.Function.Arguments)
235+
if args == "" || args == "{" {
236+
tc.Function.Arguments = "{}"
237+
}
238+
toolCalls = append(toolCalls, *tc)
239+
}
240+
}
241+
242+
return &ToolCallResponse{
243+
Content: contentBuilder.String(),
244+
ToolCalls: toolCalls,
245+
}, nil
246+
}
247+
117248
// ChatWithTools sends a non-streaming chat request with tool definitions, returns full response
118249
func (c *Client) ChatWithTools(ctx context.Context, messages []Message, tools []ToolDefinition) (*ToolCallResponse, error) {
119250
reqBody := map[string]interface{}{

mod/mcp/mcp.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,8 +1265,11 @@ func (s *MCPServer) toolExecCommand(caseID string, command string) (ToolResult,
12651265
}
12661266
defer session.Close()
12671267

1268-
session.Stdout = &outputBuf
1269-
session.Stderr = &outputBuf
1268+
// Limit captured output to prevent memory/context explosion
1269+
const maxOutputBytes = 32 * 1024 // 32KB
1270+
lw := &limitedWriter{w: &outputBuf, limit: maxOutputBytes}
1271+
session.Stdout = lw
1272+
session.Stderr = lw
12701273

12711274
// Run with timeout to prevent hanging on blocking commands
12721275
const execTimeout = 120 * time.Second
@@ -1275,8 +1278,10 @@ func (s *MCPServer) toolExecCommand(caseID string, command string) (ToolResult,
12751278
done <- session.Run(command)
12761279
}()
12771280

1281+
var truncated bool
12781282
select {
12791283
case err := <-done:
1284+
truncated = lw.truncated
12801285
if err != nil {
12811286
return ToolResult{}, fmt.Errorf("command failed: %v\nOutput: %s", err, outputBuf.String())
12821287
}
@@ -1287,6 +1292,9 @@ func (s *MCPServer) toolExecCommand(caseID string, command string) (ToolResult,
12871292

12881293
output := fmt.Sprintf("Command executed on case '%s' (%s):\n", c.Name, c.GetId())
12891294
output += fmt.Sprintf("\nOutput:\n%s", outputBuf.String())
1295+
if truncated {
1296+
output += fmt.Sprintf("\n\n... (output truncated at %d bytes, total output exceeded limit)", maxOutputBytes)
1297+
}
12901298

12911299
return ToolResult{
12921300
Content: []ContentItem{{
@@ -1971,3 +1979,26 @@ func (m *MCPServerManager) runSSEServer(ctx context.Context, addr string) error
19711979

19721980
return nil
19731981
}
1982+
1983+
// limitedWriter wraps an io.Writer and stops writing after a byte limit
1984+
type limitedWriter struct {
1985+
w io.Writer
1986+
limit int
1987+
written int
1988+
truncated bool
1989+
}
1990+
1991+
func (lw *limitedWriter) Write(p []byte) (int, error) {
1992+
if lw.written >= lw.limit {
1993+
lw.truncated = true
1994+
return len(p), nil // discard silently
1995+
}
1996+
remaining := lw.limit - lw.written
1997+
if len(p) > remaining {
1998+
p = p[:remaining]
1999+
lw.truncated = true
2000+
}
2001+
n, err := lw.w.Write(p)
2002+
lw.written += n
2003+
return n, err
2004+
}

0 commit comments

Comments
 (0)