Skip to content

Commit 6c2f4bb

Browse files
committed
fix: add finish_reason processing to xai.ts provider
Same fix as roo.ts and openrouter.ts - process finish_reason during streaming to emit tool_call_end events immediately, ensuring tool calls are finalized without waiting for [DONE] stream termination signal.
1 parent 88a0bed commit 6c2f4bb

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

src/api/providers/__tests__/xai.spec.ts

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,5 +495,87 @@ describe("XAIHandler", () => {
495495
}),
496496
)
497497
})
498+
499+
it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
500+
// Import NativeToolCallParser to set up state
501+
const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser")
502+
503+
// Clear any previous state
504+
NativeToolCallParser.clearRawChunkState()
505+
506+
const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
507+
508+
mockCreate.mockImplementationOnce(() => {
509+
return {
510+
[Symbol.asyncIterator]: () => ({
511+
next: vi
512+
.fn()
513+
.mockResolvedValueOnce({
514+
done: false,
515+
value: {
516+
choices: [
517+
{
518+
delta: {
519+
tool_calls: [
520+
{
521+
index: 0,
522+
id: "call_xai_test",
523+
function: {
524+
name: "test_tool",
525+
arguments: '{"arg1":"value"}',
526+
},
527+
},
528+
],
529+
},
530+
},
531+
],
532+
},
533+
})
534+
.mockResolvedValueOnce({
535+
done: false,
536+
value: {
537+
choices: [
538+
{
539+
delta: {},
540+
finish_reason: "tool_calls",
541+
},
542+
],
543+
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
544+
},
545+
})
546+
.mockResolvedValueOnce({ done: true }),
547+
}),
548+
}
549+
})
550+
551+
const stream = handlerWithTools.createMessage("test prompt", [], {
552+
taskId: "test-task-id",
553+
tools: testTools,
554+
toolProtocol: "native",
555+
})
556+
557+
const chunks = []
558+
for await (const chunk of stream) {
559+
// Simulate what Task.ts does: when we receive tool_call_partial,
560+
// process it through NativeToolCallParser to populate rawChunkTracker
561+
if (chunk.type === "tool_call_partial") {
562+
NativeToolCallParser.processRawChunk({
563+
index: chunk.index,
564+
id: chunk.id,
565+
name: chunk.name,
566+
arguments: chunk.arguments,
567+
})
568+
}
569+
chunks.push(chunk)
570+
}
571+
572+
// Should have tool_call_partial and tool_call_end
573+
const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
574+
const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
575+
576+
expect(partialChunks).toHaveLength(1)
577+
expect(endChunks).toHaveLength(1)
578+
expect(endChunks[0].id).toBe("call_xai_test")
579+
})
498580
})
499581
})

src/api/providers/xai.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import OpenAI from "openai"
33

44
import { type XAIModelId, xaiDefaultModelId, xaiModels } from "@roo-code/types"
55

6+
import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
67
import type { ApiHandlerOptions } from "../../shared/api"
78

89
import { ApiStream } from "../transform/stream"
@@ -83,6 +84,7 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
8384

8485
for await (const chunk of stream) {
8586
const delta = chunk.choices[0]?.delta
87+
const finishReason = chunk.choices[0]?.finish_reason
8688

8789
if (delta?.content) {
8890
yield {
@@ -111,6 +113,15 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
111113
}
112114
}
113115

116+
// Process finish_reason to emit tool_call_end events
117+
// This ensures tool calls are finalized even if the stream doesn't properly close
118+
if (finishReason) {
119+
const endEvents = NativeToolCallParser.processFinishReason(finishReason)
120+
for (const event of endEvents) {
121+
yield event
122+
}
123+
}
124+
114125
if (chunk.usage) {
115126
// Extract detailed token information if available
116127
// First check for prompt_tokens_details structure (real API response)

0 commit comments

Comments
 (0)