From 73c5345879c08c22ee9e7af5136397e6fc366f74 Mon Sep 17 00:00:00 2001 From: Rob Colbert Date: Sat, 9 May 2026 21:04:18 -0400 Subject: [PATCH] Re-build Agentic Workflow Loop The ridiculousness of trying to maintain the previous agent's work got out of hand, so we had this one re-build it - and got a better result. --- .../frontend/src/pages/ChatSessionView.tsx | 4 +- gadget-code/src/lib/code-session.ts | 44 +- gadget-code/src/services/chat-session.ts | 18 +- gadget-drone/src/gadget-drone.ts | 2 +- gadget-drone/src/services/agent.test.ts | 27 +- gadget-drone/src/services/agent.ts | 297 +++++----- gadget-drone/src/tools/toolbox.ts | 4 + packages/ai/src/api.ts | 122 ----- packages/ai/src/ollama.test.ts | 185 +------ packages/ai/src/ollama.ts | 280 +++------- packages/ai/src/openai.test.ts | 126 +++-- packages/ai/src/openai.ts | 513 +++++++++--------- 12 files changed, 654 insertions(+), 968 deletions(-) diff --git a/gadget-code/frontend/src/pages/ChatSessionView.tsx b/gadget-code/frontend/src/pages/ChatSessionView.tsx index 10026f7..68e212c 100644 --- a/gadget-code/frontend/src/pages/ChatSessionView.tsx +++ b/gadget-code/frontend/src/pages/ChatSessionView.tsx @@ -243,7 +243,9 @@ export default function ChatSessionView() { for (const updateBlock of turnUpdates.blocks) { let blockIndex = state?.currentBlockIndex ?? null; - if ( + if (updateBlock.mode === 'tool') { + blockIndex = null; + } else if ( blockIndex === null || updatedBlocks[blockIndex]?.mode !== updateBlock.mode ) { diff --git a/gadget-code/src/lib/code-session.ts b/gadget-code/src/lib/code-session.ts index 9439ca6..a2974a7 100644 --- a/gadget-code/src/lib/code-session.ts +++ b/gadget-code/src/lib/code-session.ts @@ -185,7 +185,9 @@ export class CodeSession extends SocketSession { try { const droneSession = SocketService.getDroneSession(this.selectedDrone); - const latestSession = await ChatSessionService.getById(this.chatSession._id); + const latestSession = await ChatSessionService.getById( + this.chatSession._id, + ); this.chatSession = latestSession; let turn: ChatTurnDocument = await ChatSessionService.createTurn( @@ -243,6 +245,27 @@ export class CodeSession extends SocketSession { turnId: turn._id, message, }); + + /* + * Auto-generate a session name from the first prompt. Only do this when + * the name is still the default (user hasn't set a custom name) and + * we're on the first turn (turnCount === 1 after the increment above). + */ + if ( + this.chatSession && + this.chatSession.name === "New Chat Session" && + this.chatSession.stats.turnCount === 1 + ) { + this.chatSession = + await ChatSessionService.generateSessionNameFromPrompt( + this.chatSession, + content, + ); + const update: Partial = { + name: this.chatSession.name, + }; + this.socket.emit("sessionUpdated", update); + } } else { this.log.error("work order rejected by drone", { turnId: turn._id, @@ -254,25 +277,6 @@ export class CodeSession extends SocketSession { } }, ); - - /* - * Auto-generate a session name from the first prompt. Only do this when - * the name is still the default (user hasn't set a custom name) and - * we're on the first turn (turnCount === 1 after the increment above). - */ - if ( - this.chatSession.name === "New Chat Session" && - this.chatSession.stats.turnCount === 1 - ) { - this.chatSession = - await ChatSessionService.generateSessionNameFromPrompt( - this.chatSession, - content, - ); - const update: Partial = { name: this.chatSession.name }; - this.log.debug("emitting sessionUpdated message", { update }); - this.socket.emit("sessionUpdated", update); - } } catch (error) { this.log.error("prompt rejected", { error }); cb(false, {}); diff --git a/gadget-code/src/services/chat-session.ts b/gadget-code/src/services/chat-session.ts index 2848d9f..2ffa3bc 100644 --- a/gadget-code/src/services/chat-session.ts +++ b/gadget-code/src/services/chat-session.ts @@ -404,6 +404,19 @@ class ChatSessionService extends DtpService { }); const api: AiApi = createAiApi(aiEnv, provider, this.log); + const systemPrompt = [ + "ROLE:", + "You are an assistant that creates titles for chat sessions by examining the first prompt submitted by the user.", + "", + "SCOPE:", + "You return just the title of the session, no explanation or justifications. You return one title, not multiple choices.", + "Select the best title for the chat session, and return that.", + "", + "CONSTRAINTS:", + "- Chat session titles shouldn't be longer than 60 characters.", + "- Chat session titles shouldn't contain vulgarity or degeneracy.", + ].join("\n"); + const response = await api.generate( { provider, @@ -416,9 +429,8 @@ class ChatSessionService extends DtpService { }, }, { - systemPrompt: - "You are an assistant that creates titles for chat sessions by examining the first prompt.", - prompt: `The first prompt submitted by the user: \n\n${prompt}`, + systemPrompt, + prompt: `Here is the first prompt submitted by the user:\n\n${prompt}`, }, ); diff --git a/gadget-drone/src/gadget-drone.ts b/gadget-drone/src/gadget-drone.ts index 265f92e..49deda8 100644 --- a/gadget-drone/src/gadget-drone.ts +++ b/gadget-drone/src/gadget-drone.ts @@ -584,7 +584,7 @@ class GadgetDrone extends GadgetProcess { "status", `failed to process work order: ${(error as Error).message}`, ); - // Leave cache in place for recovery + await WorkspaceService.removeWorkOrderCache(); } finally { process.chdir(workspaceDir); this.isProcessingWorkOrder = false; diff --git a/gadget-drone/src/services/agent.test.ts b/gadget-drone/src/services/agent.test.ts index bdaac66..c2a1458 100644 --- a/gadget-drone/src/services/agent.test.ts +++ b/gadget-drone/src/services/agent.test.ts @@ -1,11 +1,17 @@ -import { describe, expect, it } from "vitest"; +import { beforeEach, describe, expect, it } from "vitest"; import { ChatSessionMode, ChatTurnStatus } from "@gadget/api"; import { AgentService, type IAgentWorkOrder } from "./agent.ts"; describe("AgentService", () => { + let service: AgentService; + + beforeEach(async () => { + service = new AgentService(); + await service.start(); + }); + it("replays historical tool results as assistant-readable context, not raw tool-role messages", () => { - const service = new AgentService(); const user = { _id: "user-1", email: "user@example.com", @@ -110,4 +116,21 @@ describe("AgentService", () => { expect(messages[2]?.content).toContain("Historical tool result: file_read"); expect(messages[2]?.content).toContain("PATH: index.html"); }); + + it("does not expose mutating file tools in plan mode", () => { + const toolNames = service.getToolNamesForMode(ChatSessionMode.Plan); + + expect(toolNames).toContain("file_read"); + expect(toolNames).toContain("fetch_url"); + expect(toolNames).toContain("search_google"); + expect(toolNames).not.toContain("file_write"); + expect(toolNames).not.toContain("file_edit"); + }); + + it("exposes mutating file tools in build mode", () => { + const toolNames = service.getToolNamesForMode(ChatSessionMode.Build); + + expect(toolNames).toContain("file_write"); + expect(toolNames).toContain("file_edit"); + }); }); diff --git a/gadget-drone/src/services/agent.ts b/gadget-drone/src/services/agent.ts index 796ed6c..e45d620 100644 --- a/gadget-drone/src/services/agent.ts +++ b/gadget-drone/src/services/agent.ts @@ -3,12 +3,11 @@ // Licensed under the Apache License, Version 2.0 import env from "../config/env.ts"; -import assert from "node:assert"; import { Socket } from "socket.io-client"; import { IAiChatOptions, - IAiStreamChunk, + IAiResponseStreamFn, type IContextChatMessage, } from "@gadget/ai"; import { @@ -41,11 +40,6 @@ export interface IAgentWorkOrder { context: IChatTurn[]; } -interface IAgentWorkflow { - chatOptions: IAiChatOptions; - context: IContextChatMessage[]; -} - type DroneSocket = Socket; const toolboxEnv: DroneToolboxEnvironment = { @@ -71,25 +65,25 @@ class AgentService extends GadgetService { } async start(): Promise { - const googleSearchTool = new GoogleSearchTool(this.toolbox); - this.toolbox.register(googleSearchTool, [ - ChatSessionMode.Plan, - ChatSessionMode.Build, - ChatSessionMode.Test, - ChatSessionMode.Ship, - ChatSessionMode.Develop, - ]); - const modes = [ + const readOnlyModes = [ ChatSessionMode.Plan, ChatSessionMode.Build, ChatSessionMode.Test, ChatSessionMode.Ship, ChatSessionMode.Develop, ]; - this.toolbox.register(new FileReadTool(this.toolbox), modes); - this.toolbox.register(new FileWriteTool(this.toolbox), modes); - this.toolbox.register(new FileEditTool(this.toolbox), modes); - this.toolbox.register(new FetchUrlTool(this.toolbox), modes); + const writeModes = [ + ChatSessionMode.Build, + ChatSessionMode.Test, + ChatSessionMode.Ship, + ChatSessionMode.Develop, + ]; + + this.toolbox.register(new GoogleSearchTool(this.toolbox), readOnlyModes); + this.toolbox.register(new FileReadTool(this.toolbox), readOnlyModes); + this.toolbox.register(new FetchUrlTool(this.toolbox), readOnlyModes); + this.toolbox.register(new FileWriteTool(this.toolbox), writeModes); + this.toolbox.register(new FileEditTool(this.toolbox), writeModes); this.log.info("started"); } @@ -103,132 +97,130 @@ class AgentService extends GadgetService { socket: DroneSocket, ): Promise { const { turn } = workOrder; - const task: IAgentWorkflow = { - chatOptions: {}, - context: [], - }; - let streamedThinking = false; - let streamedResponse = false; - let streamedToolCall = false; + let toolCallCount = 0; + let inputTokens = 0; + let outputTokens = 0; - const onStreamChunk = async (chunk: IAiStreamChunk): Promise => { - // this.log.debug("stream chunk received", { chunk }); + // Build the full message array that grows with each iteration + const messages: IContextChatMessage[] = []; - switch (chunk.type) { - case "thinking": - streamedThinking = true; - socket.emit("thinking", chunk.data); - break; - case "response": - streamedResponse = true; - socket.emit("response", chunk.data); - break; - case "toolCall": - streamedToolCall = true; - socket.emit( - "toolCall", - chunk.toolCallId!, - chunk.toolName!, - chunk.params || "{}", - chunk.data, - ); - break; - } - }; + if (turn.prompts.system) { + messages.push({ + createdAt: turn.createdAt, + role: "system", + content: turn.prompts.system, + }); + } + + messages.push(...this.buildSessionContext(workOrder)); + + // Current turn's user prompt must be the last message before the AI call + messages.push({ + createdAt: turn.createdAt, + role: "user", + content: turn.prompts.user, + }); + + const reasoningEffort = turn.reasoningEffort || "off"; + const reasoning: boolean | "low" | "medium" | "high" = + reasoningEffort === "off" ? false : reasoningEffort; try { this.updateToolboxWorkspace(turn); - task.context = this.buildSessionContext(workOrder); - task.chatOptions = { - systemPrompt: turn.prompts.system, - context: task.context, - userPrompt: turn.prompts.user, - tools: this.getToolsForMode(turn.mode), - }; } catch (cause) { socket.emit( "workOrderComplete", turn._id, false, - `failed to build session context: ${(cause as Error).message}`, + `failed to update workspace: ${(cause as Error).message}`, ); - const error = new Error("failed to build session context", { cause }); - throw error; + throw new Error("failed to update workspace", { cause }); } + this.log.info("agent loop starting", { + turnId: turn._id, + messageCount: messages.length, + toolCount: this.toolbox.getToolNamesForMode(turn.mode).length, + }); + try { - const reasoningEffort = turn.reasoningEffort || "off"; - const reasoning: boolean | "low" | "medium" | "high" = - reasoningEffort === "off" ? false : reasoningEffort; + let continueLoop = true; + while (continueLoop) { + continueLoop = false; - const response = await AiService.chat( - turn.provider, - { - modelId: turn.llm, - params: { - reasoning, - temperature: 0.8, - topP: 0.9, - topK: 40, + this.log.info("agent loop iteration", { + messagesCount: messages.length, + toolsAvailable: this.toolbox.getToolNamesForMode(turn.mode), + }); + + const chatOptions: IAiChatOptions = { + context: messages, + tools: this.getToolsForMode(turn.mode), + }; + + const response = await AiService.chat( + turn.provider, + { + modelId: turn.llm, + params: { reasoning, temperature: 0.8, topP: 0.9, topK: 40 }, }, - }, - task.chatOptions, - onStreamChunk, - ); - - if (this.isEmptyAgentResponse(response)) { - throw new Error( - "AI provider returned an empty response: no thinking, response, tool calls, or tool results.", + chatOptions, + this.makeStreamHandler(socket), ); - } - // Check for model loading failure - if ( - response.doneReason === "load" && - !response.response && - !response.thinking && - (!response.toolCalls || response.toolCalls.length === 0) - ) { - throw new Error("Model failed to respond (still loading or error)"); - } + if (response.doneReason === "load" && !response.response && !response.thinking) { + throw new Error("Model failed to respond (still loading or error)"); + } - // Providers return accumulated final content; only emit it here when it - // was not already delivered through the stream callback. - if (response.thinking && !streamedThinking) { - socket.emit("thinking", response.thinking); - } + // Process tool calls if present + if (response.toolCalls && response.toolCalls.length > 0) { + continueLoop = true; + toolCallCount += response.toolCalls.length; - if (response.response && !streamedResponse) { - socket.emit("response", response.response); - } + messages.push({ + createdAt: turn.createdAt, + role: "assistant", + content: response.response, + }); - if (response.toolCalls && response.toolCalls.length > 0 && !streamedToolCall) { - for (const toolCall of response.toolCalls) { - socket.emit( - "toolCall", - toolCall.callId, - toolCall.function.name, - toolCall.function.arguments, - response.toolCallResults?.find((r) => r.callId === toolCall.callId) - ?.result || "", - ); + for (const toolCall of response.toolCalls) { + const result = await this.executeTool( + toolCall.function.name, + toolCall.function.arguments, + ); + + socket.emit( + "toolCall", + toolCall.callId, + toolCall.function.name, + toolCall.function.arguments, + result, + ); + + messages.push({ + createdAt: turn.createdAt, + role: "tool", + callId: toolCall.callId, + toolName: toolCall.function.name, + content: result, + }); + + inputTokens += Math.ceil(toolCall.function.arguments.length / 4); + outputTokens += Math.ceil(result.length / 4); + } } } - } catch (cause) { - socket.emit( - "workOrderComplete", - turn._id, - false, - `failed to process agentic workflow loop: ${(cause as Error).message}`, - ); - const error = new Error("failed to process agentic workflow loop", { - cause, - }); - throw error; - } - // Emit work order complete - socket.emit("workOrderComplete", turn._id, true); + socket.emit("workOrderComplete", turn._id, true); + } catch (cause) { + const msg = cause instanceof Error ? cause.message : String(cause); + this.log.error("agent loop failed, sending workOrderComplete(false)", { + turnId: turn._id, + error: msg, + }); + socket.emit("workOrderComplete", turn._id, false, msg); + throw cause; + } } buildSessionContext(workOrder: IAgentWorkOrder): IContextChatMessage[] { @@ -319,37 +311,62 @@ class AgentService extends GadgetService { return Array.from(this.toolbox.getModeSet(mode) || []); } + getToolNamesForMode(mode: ChatSessionMode): string[] { + return this.toolbox.getToolNamesForMode(mode); + } + private formatHistoricalToolResult(toolCall: { name: string; parameters?: string; response?: string; }): string { const response = toolCall.response || ""; - const maxLength = 8000; - const trimmedResponse = - response.length > maxLength - ? `${response.slice(0, maxLength)}\n\n[Tool result truncated from ${response.length} characters.]` - : response; return [ `Historical tool result: ${toolCall.name}`, `Parameters: ${toolCall.parameters || "{}"}`, "---", - trimmedResponse, + response.length > 8000 + ? `${response.slice(0, 8000)}\n\n[Tool result truncated from ${response.length} characters.]` + : response, ].join("\n"); } - private isEmptyAgentResponse(response: { - response?: string; - thinking?: string; - toolCalls?: unknown[]; - toolCallResults?: unknown[]; - }): boolean { - return ( - !(response.response && response.response.trim()) && - !(response.thinking && response.thinking.trim()) && - !(response.toolCalls && response.toolCalls.length) && - !(response.toolCallResults && response.toolCallResults.length) - ); + private makeStreamHandler(socket: DroneSocket): IAiResponseStreamFn { + return async (chunk) => { + switch (chunk.type) { + case "thinking": + socket.emit("thinking", chunk.data); + break; + case "response": + socket.emit("response", chunk.data); + break; + } + }; + } + + private async executeTool(name: string, argsJson: string): Promise { + const tool = this.toolbox.getTool(name); + if (!tool) { + const msg = `Unknown tool: ${name}`; + this.log.error("tool not found", { toolName: name }); + return JSON.stringify({ success: false, error: msg }); + } + + try { + const args = JSON.parse(argsJson); + this.log.info("executing tool", { name, params: argsJson }); + const result = await tool.execute(args, this.log); + this.log.info("tool result", { + name, + resultLength: result.length, + preview: result.length > 100 ? `${result.slice(0, 100)}...` : result, + }); + return result; + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + this.log.error("tool execution failed", { toolName: name, args: argsJson, error: msg }); + return JSON.stringify({ success: false, error: msg }); + } } private updateToolboxWorkspace(turn: IChatTurn): void { diff --git a/gadget-drone/src/tools/toolbox.ts b/gadget-drone/src/tools/toolbox.ts index 5905947..b4ad484 100644 --- a/gadget-drone/src/tools/toolbox.ts +++ b/gadget-drone/src/tools/toolbox.ts @@ -67,4 +67,8 @@ export class AiToolbox { getModeSet(mode: string): ToolSet | undefined { return this.modeSets.get(mode); } + + getToolNamesForMode(mode: string): string[] { + return Array.from(this.getModeSet(mode) || []).map((tool) => tool.name); + } } diff --git a/packages/ai/src/api.ts b/packages/ai/src/api.ts index c221ef5..df7718c 100644 --- a/packages/ai/src/api.ts +++ b/packages/ai/src/api.ts @@ -175,31 +175,6 @@ export abstract class AiApi { streamCallback?: IAiResponseStreamFn, ): Promise; - protected shouldContinueAfterNonToolResponse(response: string): boolean { - const normalized = response.trim().toLowerCase(); - if (!normalized) { - return false; - } - - const futureIntentPatterns = [ - /\bi\s*(?:will|'ll|am going to)\b/, - /\blet me\b/, - /\bi need to\b/, - /\bi should\b/, - /\bi(?:'m| am) going to\b/, - /\bnext,? i\b/, - /\bi(?:'ll| will) inspect\b/, - /\bi(?:'ll| will) read\b/, - /\bi(?:'ll| will) open\b/, - /\bi(?:'ll| will) check\b/, - /\bi(?:'ll| will) update\b/, - /\bi(?:'ll| will) modify\b/, - /\bi(?:'ll| will) fix\b/, - ]; - - return futureIntentPatterns.some((pattern) => pattern.test(normalized)); - } - protected assertNonEmptyChatResponse(response: IAiChatResponse): void { const hasResponse = response.response.trim().length > 0; const hasThinking = !!response.thinking?.trim(); @@ -213,101 +188,4 @@ export abstract class AiApi { } } - protected buildContinuationPrompt(): string { - return [ - "You stopped after describing future work instead of performing it.", - "Do not explain what you are about to do.", - "Call the appropriate tool now, or provide a final answer only if no tool use is needed.", - ].join(" "); - } - - protected shouldContinueForUserWorkRequest( - userPrompt: string | undefined, - response: string, - hasExecutedToolsThisTurn: boolean, - ): boolean { - if (hasExecutedToolsThisTurn) { - return false; - } - - const prompt = (userPrompt || "").toLowerCase(); - const answer = response.trim().toLowerCase(); - if (!prompt || !answer) { - return false; - } - - const workRequestPatterns = [ - /\bread\b/, - /\bfix\b/, - /\bchange\b/, - /\bupdate\b/, - /\bmodify\b/, - /\bedit\b/, - /\bwrite\b/, - /\bcreate\b/, - /\bimplement\b/, - /\bdebug\b/, - /\binspect\b/, - /\bopen\b/, - ]; - const directAnswerPatterns = [ - /\bcompleted\b/, - /\bfixed\b/, - /\bupdated\b/, - /\bchanged\b/, - /\bimplemented\b/, - /\bno tool use (?:is|was) needed\b/, - /\bdoes not require tool use\b/, - ]; - - return ( - workRequestPatterns.some((pattern) => pattern.test(prompt)) && - !directAnswerPatterns.some((pattern) => pattern.test(answer)) - ); - } - - protected async executeToolCalls( - toolCalls: IToolCall[], - tools: IAiTool[], - ): Promise { - const results: IToolCallResult[] = []; - - for (const toolCall of toolCalls) { - const tool = tools.find((t) => t.name === toolCall.function.name); - - if (!tool) { - this.log.warn(`tool not found: ${toolCall.function.name}`); - results.push({ - callId: toolCall.callId, - functionName: toolCall.function.name, - result: "", - error: `Tool '${toolCall.function.name}' not found`, - }); - continue; - } - - try { - const args = JSON.parse(toolCall.function.arguments); - const result = await tool.execute(args, this.log); - results.push({ - callId: toolCall.callId, - functionName: toolCall.function.name, - result, - }); - } catch (error) { - const errorMessage = (error as Error).message; - this.log.error(`tool execution failed: ${toolCall.function.name}`, { - error: errorMessage, - }); - results.push({ - callId: toolCall.callId, - functionName: toolCall.function.name, - result: "", - error: errorMessage, - }); - } - } - - return results; - } } diff --git a/packages/ai/src/ollama.test.ts b/packages/ai/src/ollama.test.ts index cafcdca..64d9859 100644 --- a/packages/ai/src/ollama.test.ts +++ b/packages/ai/src/ollama.test.ts @@ -118,61 +118,33 @@ describe('OllamaAiApi', () => { }); it('should handle tool calls', async () => { - // Mock streaming response with tool call - let callCount = 0; - mockOllamaClient.chat.mockImplementation(() => { - callCount++; - return (async function* () { - if (callCount === 1) { - yield { - message: { - content: '', - tool_calls: [ - { - function: { - name: 'search_google', - arguments: { query: 'test query' }, - }, - }, - ], + const mockStream = async function* () { + yield { + message: { + content: '', + tool_calls: [ + { + function: { + name: 'search_google', + arguments: { query: 'test query' }, + }, }, - done: false, - }; - yield { - message: { content: '' }, - done: true, - done_reason: 'stop', - total_duration: 100, - prompt_eval_count: 10, - eval_count: 1, - }; - } else { - yield { - message: { content: 'Done' }, - done: true, - done_reason: 'stop', - total_duration: 100, - prompt_eval_count: 10, - eval_count: 1, - }; - } - })(); - }); - - const mockTool = { - name: 'search_google', - category: 'search', - definition: { - type: 'function', - function: { - name: 'search_google', - description: 'Search Google', - parameters: { type: 'object', properties: {} }, + ], }, - }, - execute: vi.fn().mockResolvedValue('search results'), + done: false, + }; + yield { + message: { content: '' }, + done: true, + done_reason: 'stop', + total_duration: 100, + prompt_eval_count: 10, + eval_count: 1, + }; }; + mockOllamaClient.chat.mockResolvedValue(mockStream()); + const streamCallback = vi.fn(); const response = await api.chat( { @@ -183,28 +155,20 @@ describe('OllamaAiApi', () => { { userPrompt: 'Test prompt', context: [], - tools: [mockTool as any], }, streamCallback, ); - // Verify tool call was emitted via stream callback - expect(streamCallback).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'toolCall', - toolName: 'search_google', - }), - ); - - // Verify tool was executed - expect(mockTool.execute).toHaveBeenCalled(); - - // Verify response indicates tool calls were processed + // Verify tool calls are returned, not executed expect(response.toolCalls).toBeDefined(); + expect(response.toolCalls!.length).toBe(1); + expect(response.toolCalls![0].function.name).toBe('search_google'); + + // chat() should only be called once (no internal loop) + expect(mockOllamaClient.chat).toHaveBeenCalledTimes(1); }); it('should handle thinking content when reasoning is enabled', async () => { - // Mock streaming response with thinking const mockStream = async function* () { yield { message: { @@ -246,7 +210,6 @@ describe('OllamaAiApi', () => { streamCallback, ); - // Verify thinking was emitted expect(streamCallback).toHaveBeenCalledWith({ type: 'thinking', data: 'Let me think about this...', @@ -255,19 +218,14 @@ describe('OllamaAiApi', () => { type: 'thinking', data: ' The answer is', }); - - // Verify response was emitted expect(streamCallback).toHaveBeenCalledWith({ type: 'response', data: '42', }); - - // Verify final response includes thinking expect(response.thinking).toBe('Let me think about this... The answer is'); }); it('should reject empty response on load failure', async () => { - // Mock streaming response with load failure const mockStream = async function* () { yield { message: { content: '' }, @@ -294,91 +252,6 @@ describe('OllamaAiApi', () => { vi.fn(), )).rejects.toThrow('Provider returned an empty chat response'); }); - - it('should iterate tool calling loop when tools are present', async () => { - let callCount = 0; - - // Mock streaming response that requires tool call then returns - const mockStream = async function* () { - callCount++; - if (callCount === 1) { - // First call: return tool call - yield { - message: { - content: '', - tool_calls: [ - { - function: { - name: 'search_google', - arguments: { query: 'test' }, - }, - }, - ], - }, - done: false, - }; - yield { - message: { content: '' }, - done: true, - done_reason: 'stop', - total_duration: 100, - prompt_eval_count: 10, - eval_count: 1, - }; - } else { - // Second call: return final response - yield { - message: { content: 'Here are the results' }, - done: true, - done_reason: 'stop', - total_duration: 100, - prompt_eval_count: 15, - eval_count: 5, - }; - } - }; - - mockOllamaClient.chat.mockImplementation(() => mockStream()); - - const mockTool = { - name: 'search_google', - category: 'search', - definition: { - type: 'function', - function: { - name: 'search_google', - description: 'Search Google', - parameters: { type: 'object', properties: {} }, - }, - }, - execute: vi.fn().mockResolvedValue('search results'), - }; - - const streamCallback = vi.fn(); - const response = await api.chat( - { - provider: mockProvider as any, - modelId: 'test-model', - params: { reasoning: false, temperature: 0.8, topP: 0.9, topK: 40 }, - }, - { - userPrompt: 'Test prompt', - context: [], - tools: [mockTool as any], - }, - streamCallback, - ); - - // Verify chat was called twice (once for tool call, once for response) - expect(mockOllamaClient.chat).toHaveBeenCalledTimes(2); - - // Verify tool was executed - expect(mockTool.execute).toHaveBeenCalled(); - - // Verify final response - expect(response.done).toBe(true); - expect(response.response).toBe('Here are the results'); - }); }); describe('probeModel', () => { diff --git a/packages/ai/src/ollama.ts b/packages/ai/src/ollama.ts index 6a4962a..10e82ba 100644 --- a/packages/ai/src/ollama.ts +++ b/packages/ai/src/ollama.ts @@ -11,7 +11,6 @@ import { IAiChatOptions, IAiChatResponse, IToolCall, - IToolCallResult, IAiGenerateOptions, IAiGenerateResponse, IAiLogger, @@ -208,23 +207,12 @@ export class OllamaAiApi extends AiApi { modelId: model.modelId, }); - // VALIDATE: Ensure we have at least one message with content - if (!options.userPrompt || !options.userPrompt.trim()) { - throw new Error("userPrompt is required and cannot be empty"); - } - - // Build messages array like OpenAI does const messages: OllamaMessage[] = []; - // Add system prompt if present if (options.systemPrompt) { - messages.push({ - role: "system", - content: options.systemPrompt, - }); + messages.push({ role: "system", content: options.systemPrompt }); } - // Add context messages if (options.context) { for (const msg of options.context) { if (msg.content && msg.content.trim()) { @@ -244,217 +232,97 @@ export class OllamaAiApi extends AiApi { } } - // Add user prompt (required) - messages.push({ - role: "user", - content: options.userPrompt, - }); + if (options.userPrompt) { + messages.push({ role: "user", content: options.userPrompt }); + } - // VALIDATE: Ensure messages array is not empty before calling API if (messages.length === 0) { throw new Error( "Messages array is empty - cannot call Ollama API with no messages", ); } - // DEBUG: Log what we're sending to Ollama - await this.log.debug("Ollama chat request", { - messagesCount: messages.length, - messages: messages.map((m) => ({ - role: m.role, - contentLength: m.content?.length || 0, - })), - userPrompt: options.userPrompt?.slice(0, 100), - contextCount: options.context?.length || 0, + const ollamaTools = options.tools + ? options.tools.map((tool) => ({ + type: tool.definition.type, + function: { + name: tool.definition.function.name, + description: tool.definition.function.description, + parameters: tool.definition.function.parameters, + }, + })) + : undefined; + + const response = await this.client.chat({ + model: model.modelId, + messages, + stream: true, + think: model.params.reasoning, + tools: ollamaTools, }); - const allToolCallResults: IToolCallResult[] = []; - const allToolCalls: IToolCall[] = []; - let totalAccumulatedResponse = ""; - let totalAccumulatedThinking = ""; + let lastChunk; + let accumulatedThinking = ""; + let accumulatedResponse = ""; + const toolCalls: IToolCall[] = []; - /* - * Our agents do not have iteration count limits. We have seen an agent - * issue 100+ legitimate calls in a single turn. - */ - while (true) { - const ollamaTools = options.tools - ? options.tools.map((tool) => ({ - type: tool.definition.type, - function: { - name: tool.definition.function.name, - description: tool.definition.function.description, - parameters: tool.definition.function.parameters, - }, - })) - : undefined; + for await (const chunk of response) { + lastChunk = chunk; - const response = await this.client.chat({ - model: model.modelId, - messages, - stream: true, - think: model.params.reasoning, - tools: ollamaTools, - }); - - let lastChunk; - let accumulatedThinking = ""; - let accumulatedResponse = ""; - const streamedToolCalls: Array<{ - callId: string; - function: { name: string; arguments: any }; - }> = []; - - for await (const chunk of response) { - lastChunk = chunk; - - if (chunk.message.thinking) { - accumulatedThinking += chunk.message.thinking; - if (streamCallback) { - await streamCallback({ - type: "thinking", - data: chunk.message.thinking, - }); - } - } - if (chunk.message.content) { - accumulatedResponse += chunk.message.content; - if (streamCallback) { - await streamCallback({ - type: "response", - data: chunk.message.content, - }); - } - } - if (chunk.message.tool_calls) { - for (const [index, tc] of chunk.message.tool_calls.entries()) { - const params = JSON.stringify(tc.function.arguments); - const callId = `tool_${tc.function.name}_${Date.now()}_${index}`; - - const toolCall: IToolCall = { - callId, - function: { - name: tc.function.name, - arguments: params, - }, - }; - streamedToolCalls.push(toolCall); - allToolCalls.push(toolCall); - } - } - } - assert(lastChunk, "no response chunks received"); - - // Use accumulated thinking/response for final response - const finalThinking = accumulatedThinking || lastChunk.message.thinking; - const finalResponse = accumulatedResponse || lastChunk.message.content; - - // Accumulate across iterations - totalAccumulatedResponse += finalResponse || ""; - totalAccumulatedThinking += finalThinking || ""; - - // Use accumulated tool calls from stream - const toolCalls = streamedToolCalls; - - if (!toolCalls || toolCalls.length === 0) { - if ( - options.tools?.length && - (this.shouldContinueAfterNonToolResponse(finalResponse || "") || - this.shouldContinueForUserWorkRequest( - options.userPrompt, - finalResponse || "", - allToolCallResults.length > 0, - )) - ) { - await this.log.warn("model produced future-intent text without tool calls; continuing AWL", { - responseLength: (finalResponse || "").length, - }); - messages.push({ - role: "assistant", - content: finalResponse || "", - }); - messages.push({ - role: "user", - content: this.buildContinuationPrompt(), - }); - continue; - } - - const chatResponse: IAiChatResponse = { - done: lastChunk.done, - doneReason: lastChunk.done_reason, - response: totalAccumulatedResponse, - thinking: totalAccumulatedThinking, - toolCalls: allToolCalls.length > 0 ? allToolCalls : undefined, - toolCallResults: - allToolCallResults.length > 0 ? allToolCallResults : undefined, - stats: { - duration: { - seconds: lastChunk.total_duration, - text: numeral(lastChunk.total_duration).format("hh:mm:ss"), - }, - tokenCounts: { - input: lastChunk.prompt_eval_count, - response: lastChunk.eval_count, - thinking: 0, - }, - }, - }; - this.assertNonEmptyChatResponse(chatResponse); - return chatResponse; - } - - const toolCallResults = await this.executeToolCalls( - toolCalls, - options.tools || [], - ); - allToolCallResults.push(...toolCallResults); - - if (streamCallback) { - for (const result of toolCallResults) { - const toolCall = toolCalls.find((tc) => tc.callId === result.callId); + if (chunk.message.thinking) { + accumulatedThinking += chunk.message.thinking; + if (streamCallback) { await streamCallback({ - type: "toolCall", - data: result.error || result.result, - toolCallId: result.callId, - toolName: result.functionName, - params: toolCall?.function.arguments || "{}", + type: "thinking", + data: chunk.message.thinking, }); } } - - const assistantMsg: OllamaMessage = { - role: "assistant", - content: accumulatedResponse || lastChunk.message.content, - }; - if (lastChunk.message.thinking) { - assistantMsg.thinking = lastChunk.message.thinking; + if (chunk.message.content) { + accumulatedResponse += chunk.message.content; + if (streamCallback) { + await streamCallback({ + type: "response", + data: chunk.message.content, + }); + } } - if (lastChunk.message.tool_calls) { - assistantMsg.tool_calls = lastChunk.message.tool_calls; - } - messages.push(assistantMsg); + if (chunk.message.tool_calls) { + for (const [index, tc] of chunk.message.tool_calls.entries()) { + const params = JSON.stringify(tc.function.arguments); + const callId = `tool_${tc.function.name}_${Date.now()}_${index}`; - for (const result of toolCallResults) { - const toolContent = result.error - ? `Error executing ${result.functionName}: ${result.error}` - : result.result; - - const toolMsg = { - role: "tool" as const, - content: toolContent, - }; - messages.push(toolMsg); - } - - // VALIDATE: Ensure tool results are in messages - const toolMessages = messages.filter((m) => m.role === "tool"); - if (toolMessages.length === 0 && toolCallResults.length > 0) { - await this.log.error("CRITICAL: tool results NOT in messages array", { - toolCallResultsCount: toolCallResults.length, - messagesCount: messages.length, - }); + toolCalls.push({ + callId, + function: { + name: tc.function.name, + arguments: params, + }, + }); + } } } + assert(lastChunk, "no response chunks received"); + + const chatResponse: IAiChatResponse = { + done: lastChunk.done, + doneReason: lastChunk.done_reason, + response: accumulatedResponse || lastChunk.message.content, + thinking: accumulatedThinking || lastChunk.message.thinking, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + stats: { + duration: { + seconds: lastChunk.total_duration, + text: numeral(lastChunk.total_duration).format("hh:mm:ss"), + }, + tokenCounts: { + input: lastChunk.prompt_eval_count, + response: lastChunk.eval_count, + thinking: 0, + }, + }, + }; + this.assertNonEmptyChatResponse(chatResponse); + return chatResponse; } } diff --git a/packages/ai/src/openai.test.ts b/packages/ai/src/openai.test.ts index 310f090..80db455 100644 --- a/packages/ai/src/openai.test.ts +++ b/packages/ai/src/openai.test.ts @@ -72,73 +72,51 @@ describe("OpenAiApi", () => { )).rejects.toThrow("Provider returned an empty chat response"); }); - it("assembles streamed tool-call argument fragments before executing", async () => { - mockCreate - .mockResolvedValueOnce(streamChunks([ - { - choices: [{ - delta: { - tool_calls: [{ - index: 0, - id: "call_1", - type: "function", - function: { name: "file_read", arguments: '{"path"' }, - }], - }, - finish_reason: null, - }], - }, - { - choices: [{ - delta: { - tool_calls: [{ - index: 0, - function: { arguments: ':"index.html"}' }, - }], - }, - finish_reason: "tool_calls", - }], - }, - ])) - .mockResolvedValueOnce(streamChunks([ - { choices: [{ delta: { content: "Done" }, finish_reason: "stop" }] }, - ])); - - const tool = { - name: "file_read", - category: "file", - definition: { - type: "function" as const, - function: { - name: "file_read", - description: "Read file", - parameters: { type: "object", properties: {} }, - }, + it("assembles streamed tool-call argument fragments and returns them", async () => { + mockCreate.mockResolvedValueOnce(streamChunks([ + { + choices: [{ + delta: { + tool_calls: [{ + index: 0, + id: "call_1", + type: "function", + function: { name: "file_read", arguments: '{"path"' }, + }], + }, + finish_reason: null, + }], }, - execute: vi.fn().mockResolvedValue("PATH: index.html\n---\ncontent"), - }; - const streamCallback = vi.fn(); + { + choices: [{ + delta: { + tool_calls: [{ + index: 0, + function: { arguments: ':"index.html"}' }, + }], + }, + finish_reason: "tool_calls", + }], + }, + ])); + const streamCallback = vi.fn(); const response = await api.chat( { provider: mockProvider as any, modelId: "test-model", params: { reasoning: false, temperature: 0.8, topP: 0.9, topK: 40 }, }, - { userPrompt: "Read index.html", context: [], tools: [tool] }, + { userPrompt: "Read index.html", context: [], tools: [] }, streamCallback, ); - expect(tool.execute).toHaveBeenCalledWith({ path: "index.html" }, mockLogger); - expect(streamCallback).toHaveBeenCalledWith(expect.objectContaining({ - type: "toolCall", - toolCallId: "call_1", - toolName: "file_read", - data: "PATH: index.html\n---\ncontent", - params: '{"path":"index.html"}', - })); - expect(response.response).toBe("Done"); - expect(mockCreate).toHaveBeenCalledTimes(2); + // Tool calls are returned, not executed + expect(response.toolCalls).toBeDefined(); + expect(response.toolCalls!.length).toBe(1); + expect(response.toolCalls![0].function.name).toBe("file_read"); + expect(response.toolCalls![0].function.arguments).toBe('{"path":"index.html"}'); + expect(mockCreate).toHaveBeenCalledTimes(1); }); it("falls back to non-streaming response when stream has no deltas", async () => { @@ -164,4 +142,40 @@ describe("OpenAiApi", () => { expect(response.response).toBe("Fallback answer"); expect(streamCallback).toHaveBeenCalledWith({ type: "response", data: "Fallback answer" }); }); + + it("returns conversational responses without forcing another iteration", async () => { + mockCreate.mockResolvedValueOnce(streamChunks([ + { choices: [{ delta: { content: "I can talk this through." }, finish_reason: "stop" }] }, + ])); + const streamCallback = vi.fn(); + + const response = await api.chat( + { + provider: mockProvider as any, + modelId: "test-model", + params: { reasoning: false, temperature: 0.8, topP: 0.9, topK: 40 }, + }, + { + userPrompt: "Don't edit code. Just talk to me.", + context: [], + tools: [{ + name: "file_edit", + category: "file", + definition: { + type: "function" as const, + function: { + name: "file_edit", + description: "Edit file", + parameters: { type: "object", properties: {} }, + }, + }, + execute: vi.fn(), + }], + }, + streamCallback, + ); + + expect(response.response).toBe("I can talk this through."); + expect(mockCreate).toHaveBeenCalledTimes(1); + }); }); diff --git a/packages/ai/src/openai.ts b/packages/ai/src/openai.ts index 83ae059..d7b329f 100644 --- a/packages/ai/src/openai.ts +++ b/packages/ai/src/openai.ts @@ -9,7 +9,6 @@ import { IAiChatOptions, IAiChatResponse, IToolCall, - IToolCallResult, IAiGenerateOptions, IAiGenerateResponse, IAiLogger, @@ -20,11 +19,9 @@ import { IAiResponseStreamFn, } from "./api.js"; import { - ChatCompletionAssistantMessageParam, ChatCompletionFunctionTool, ChatCompletionMessageParam, ChatCompletionTool, - ChatCompletionToolMessageParam, } from "openai/resources"; import { IAiEnvironment } from "./config/env.ts"; @@ -67,6 +64,24 @@ interface StreamingToolCallAccumulator { }; } +interface OpenAiChatIterationResult { + response: string; + thinking?: string; + toolCalls: IToolCall[]; + assistantToolCalls: Array<{ + id: string; + type: "function"; + function: { + name: string; + arguments: string; + }; + }>; + finishReason?: string | null; + chunkCount: number; + contentDeltaCount: number; + toolDeltaCount: number; +} + export class OpenAiApi extends AiApi { protected client: OpenAI; @@ -256,277 +271,169 @@ export class OpenAiApi extends AiApi { }); const startTime = Date.now(); + const messages = this.buildMessages(options); + const tools = this.buildTools(options); + + let iteration = await this.readStreamingChatCompletion( + model, + messages, + tools, + streamCallback, + ); + + await this.log.debug("OpenAI chat stream iteration finished", { + chunkCount: iteration.chunkCount, + contentDeltaCount: iteration.contentDeltaCount, + toolDeltaCount: iteration.toolDeltaCount, + responseLength: iteration.response.length, + thinkingLength: iteration.thinking?.length || 0, + toolCallCount: iteration.toolCalls.length, + finishReason: iteration.finishReason, + }); + + if (this.isEmptyIteration(iteration)) { + iteration = await this.readNonStreamingChatCompletion(model, messages, tools); + if (streamCallback && iteration.response) { + await streamCallback({ type: "response", data: iteration.response }); + } + await this.log.warn("OpenAI stream was empty; used non-streaming fallback", { + responseLength: iteration.response.length, + thinkingLength: iteration.thinking?.length || 0, + toolCallCount: iteration.toolCalls.length, + finishReason: iteration.finishReason, + }); + } + + if (this.isEmptyIteration(iteration)) { + this.assertNonEmptyChatResponse({ + done: true, + response: "", + thinking: undefined, + toolCalls: undefined, + toolCallResults: undefined, + stats: this.buildStats(startTime), + }); + } + + const finalResponse: IAiChatResponse = { + done: true, + response: iteration.response, + thinking: iteration.thinking, + toolCalls: iteration.toolCalls.length > 0 ? iteration.toolCalls : undefined, + stats: this.buildStats(startTime), + }; + this.assertNonEmptyChatResponse(finalResponse); + return finalResponse; + } + + private buildMessages(options: IAiChatOptions): ChatCompletionMessageParam[] { const messages: ChatCompletionMessageParam[] = []; if (options.systemPrompt) { messages.push({ role: "system", content: options.systemPrompt }); } - if (options.context) { - for (const msg of options.context) { - if (msg.role === "tool") { - messages.push({ - role: "tool", - content: msg.content, - tool_call_id: msg.callId || "", - }); - } else { - messages.push({ - role: msg.role as "user" | "assistant" | "system", - content: msg.content, - }); - } + for (const msg of options.context || []) { + if (!msg.content?.trim()) continue; + if (msg.role === "tool") { + messages.push({ + role: "assistant", + content: `Historical tool result${msg.toolName ? ` from ${msg.toolName}` : ""}:\n${msg.content}`, + }); + continue; } + messages.push({ + role: msg.role as "user" | "assistant" | "system", + content: msg.content, + }); } - if (options.userPrompt) { + if (options.userPrompt?.trim()) { messages.push({ role: "user", content: options.userPrompt }); } - - const allToolCallResults: IToolCallResult[] = []; - const allToolCalls: IToolCall[] = []; - - while (true) { - const tools: ChatCompletionTool[] = options.tools - ? options.tools.map((tool) => { - const openaiTool: ChatCompletionFunctionTool = { - type: tool.definition.type, - function: { - name: tool.definition.function.name, - description: tool.definition.function.description, - parameters: tool.definition.function.parameters, - }, - }; - return openaiTool; - }) - : []; - - const response = await this.client.chat.completions.create({ - model: model.modelId, - messages, - tools, - stream: true, - ...(typeof model.params.reasoning === "string" - ? { - reasoning_effort: model.params.reasoning as - | "low" - | "medium" - | "high", - } - : {}), - }); - - let accumulatedResponse = ""; - let accumulatedThinking = ""; - let chunkCount = 0; - let contentDeltaCount = 0; - let toolDeltaCount = 0; - let finishReason: string | null | undefined; - const toolCallMap = new Map(); - let assistantToolCallsForMessage: Array<{ - id: string; - type: "function"; - function: { name: string; arguments: string }; - }> = []; - - for await (const chunk of response) { - chunkCount++; - finishReason = chunk.choices[0]?.finish_reason ?? finishReason; - const delta = chunk.choices[0]?.delta; - if (delta) { - if (delta.content) { - contentDeltaCount++; - accumulatedResponse += delta.content; - if (streamCallback) { - await streamCallback({ - type: "response", - data: delta.content, - }); - } - } - if ("reasoning" in delta && delta.reasoning) { - accumulatedThinking += delta.reasoning as string; - if (streamCallback) { - await streamCallback({ - type: "thinking", - data: delta.reasoning as string, - }); - } - } - if (delta.tool_calls) { - toolDeltaCount += delta.tool_calls.length; - for (const tc of delta.tool_calls) { - const index = tc.index; - let accumulated = toolCallMap.get(index); - if (!accumulated) { - accumulated = { - index, - id: tc.id || `tool_${Date.now()}_${index}`, - type: "function", - function: { - name: "", - arguments: "", - }, - }; - toolCallMap.set(index, accumulated); - } - - if (tc.id) { - accumulated.id = tc.id; - } - if (tc.function?.name) { - accumulated.function.name += tc.function.name; - } - if (tc.function?.arguments) { - accumulated.function.arguments += tc.function.arguments; - } - } - } - } - } - - const finalToolCalls = Array.from(toolCallMap.values()) - .sort((a, b) => a.index - b.index) - .filter((tc) => tc.function.name); - const toolCalls = finalToolCalls.map((tc) => ({ - callId: tc.id, - function: { - name: tc.function.name, - arguments: tc.function.arguments, - }, - })); - assistantToolCallsForMessage = finalToolCalls.map((tc) => ({ - id: tc.id, - type: "function" as const, - function: { - name: tc.function.name, - arguments: tc.function.arguments, - }, - })); - allToolCalls.push(...toolCalls); - - await this.log.debug("OpenAI chat stream iteration finished", { - chunkCount, - contentDeltaCount, - toolDeltaCount, - responseLength: accumulatedResponse.length, - thinkingLength: accumulatedThinking.length, - toolCallCount: toolCalls.length, - finishReason, - }); - - if (chunkCount > 0 && !accumulatedResponse && !accumulatedThinking && toolCalls.length === 0) { - const fallback = await this.chatOnceNonStreaming(model, messages, tools); - accumulatedResponse = fallback.response; - accumulatedThinking = fallback.thinking || ""; - toolCalls.push(...fallback.toolCalls); - allToolCalls.push(...fallback.toolCalls); - assistantToolCallsForMessage = fallback.assistantToolCalls; - if (streamCallback && fallback.response) { - await streamCallback({ type: "response", data: fallback.response }); - } - await this.log.warn("OpenAI stream was empty; used non-streaming fallback", { - responseLength: accumulatedResponse.length, - thinkingLength: accumulatedThinking.length, - toolCallCount: fallback.toolCalls.length, - finishReason, - }); - } - - if (!toolCalls || toolCalls.length === 0) { - if ( - options.tools?.length && - (this.shouldContinueAfterNonToolResponse(accumulatedResponse) || - this.shouldContinueForUserWorkRequest( - options.userPrompt, - accumulatedResponse, - allToolCallResults.length > 0, - )) - ) { - await this.log.warn("model produced future-intent text without tool calls; continuing AWL", { - responseLength: accumulatedResponse.length, - }); - messages.push({ role: "assistant", content: accumulatedResponse }); - messages.push({ role: "user", content: this.buildContinuationPrompt() }); - continue; - } - - const finalResponse: IAiChatResponse = { - done: true, - response: accumulatedResponse, - thinking: accumulatedThinking || undefined, - toolCalls: allToolCalls.length > 0 ? allToolCalls : undefined, - toolCallResults: - allToolCallResults.length > 0 ? allToolCallResults : undefined, - stats: { - duration: { - seconds: (Date.now() - startTime) / 1000, - text: numeral((Date.now() - startTime) / 1000).format("hh:mm:ss"), - }, - tokenCounts: { - input: 0, - response: 0, - thinking: 0, - }, - }, - }; - this.assertNonEmptyChatResponse(finalResponse); - return finalResponse; - } - - const toolCallResults = await this.executeToolCalls( - toolCalls, - options.tools || [], - ); - allToolCallResults.push(...toolCallResults); - - if (streamCallback) { - for (const result of toolCallResults) { - const toolCall = toolCalls.find((tc) => tc.callId === result.callId); - await streamCallback({ - type: "toolCall", - data: result.error || result.result, - toolCallId: result.callId, - toolName: result.functionName, - params: toolCall?.function.arguments || "{}", - }); - } - } - - const assistantMsg: ChatCompletionAssistantMessageParam = { - role: "assistant", - content: accumulatedResponse, - }; - if (assistantToolCallsForMessage.length) { - assistantMsg.tool_calls = assistantToolCallsForMessage; - } - messages.push(assistantMsg); - - for (const result of toolCallResults) { - const toolMsg: ChatCompletionToolMessageParam = { - role: "tool", - tool_call_id: result.callId, - content: result.error || result.result, - }; - messages.push(toolMsg); - } - } + return messages; } - private async chatOnceNonStreaming( + private buildTools(options: IAiChatOptions): ChatCompletionTool[] { + return (options.tools || []).map((tool): ChatCompletionFunctionTool => ({ + type: tool.definition.type, + function: { + name: tool.definition.function.name, + description: tool.definition.function.description, + parameters: tool.definition.function.parameters, + }, + })); + } + + private async readStreamingChatCompletion( model: IAiModelConfig, messages: ChatCompletionMessageParam[], tools: ChatCompletionTool[], - ): Promise<{ - response: string; - thinking?: string; - toolCalls: IToolCall[]; - assistantToolCalls: Array<{ - id: string; - type: "function"; - function: { - name: string; - arguments: string; - }; - }>; - }> { + streamCallback?: IAiResponseStreamFn, + ): Promise { + const response = await this.client.chat.completions.create({ + model: model.modelId, + messages, + tools, + stream: true, + ...(typeof model.params.reasoning === "string" + ? { + reasoning_effort: model.params.reasoning as + | "low" + | "medium" + | "high", + } + : {}), + }); + + let content = ""; + let thinking = ""; + let chunkCount = 0; + let contentDeltaCount = 0; + let toolDeltaCount = 0; + let finishReason: string | null | undefined; + const toolCallMap = new Map(); + + for await (const chunk of response) { + chunkCount++; + finishReason = chunk.choices[0]?.finish_reason ?? finishReason; + const delta = chunk.choices[0]?.delta; + if (!delta) continue; + + if (delta.content) { + contentDeltaCount++; + content += delta.content; + if (streamCallback) { + await streamCallback({ type: "response", data: delta.content }); + } + } + if ("reasoning" in delta && delta.reasoning) { + thinking += delta.reasoning as string; + if (streamCallback) { + await streamCallback({ type: "thinking", data: delta.reasoning as string }); + } + } + if (delta.tool_calls) { + toolDeltaCount += delta.tool_calls.length; + for (const toolCallDelta of delta.tool_calls) { + this.accumulateToolCallDelta(toolCallMap, toolCallDelta); + } + } + } + + return { + response: content, + thinking: thinking || undefined, + ...this.buildToolCallsFromMap(toolCallMap), + finishReason, + chunkCount, + contentDeltaCount, + toolDeltaCount, + }; + } + + private async readNonStreamingChatCompletion( + model: IAiModelConfig, + messages: ChatCompletionMessageParam[], + tools: ChatCompletionTool[], + ): Promise { const response = await this.client.chat.completions.create({ model: model.modelId, messages, @@ -566,6 +473,90 @@ export class OpenAiApi extends AiApi { response: content, toolCalls, assistantToolCalls, + finishReason: choice?.finish_reason, + chunkCount: 0, + contentDeltaCount: content ? 1 : 0, + toolDeltaCount: assistantToolCalls.length, + }; + } + + private accumulateToolCallDelta( + toolCallMap: Map, + delta: { + index: number; + id?: string; + type?: string; + function?: { + name?: string; + arguments?: string; + }; + }, + ): void { + let accumulated = toolCallMap.get(delta.index); + if (!accumulated) { + accumulated = { + index: delta.index, + id: delta.id || `tool_${Date.now()}_${delta.index}`, + type: "function", + function: { + name: "", + arguments: "", + }, + }; + toolCallMap.set(delta.index, accumulated); + } + + if (delta.id) accumulated.id = delta.id; + if (delta.function?.name) accumulated.function.name += delta.function.name; + if (delta.function?.arguments) { + accumulated.function.arguments += delta.function.arguments; + } + } + + private buildToolCallsFromMap( + toolCallMap: Map, + ): Pick { + const assistantToolCalls = Array.from(toolCallMap.values()) + .sort((a, b) => a.index - b.index) + .filter((toolCall) => toolCall.function.name) + .map((toolCall) => ({ + id: toolCall.id, + type: "function" as const, + function: { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }, + })); + const toolCalls: IToolCall[] = assistantToolCalls.map((toolCall) => ({ + callId: toolCall.id, + function: { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }, + })); + return { toolCalls, assistantToolCalls }; + } + + private isEmptyIteration(iteration: OpenAiChatIterationResult): boolean { + return ( + !iteration.response.trim() && + !iteration.thinking?.trim() && + iteration.toolCalls.length === 0 + ); + } + + private buildStats(startTime: number): IAiChatResponse["stats"] { + const seconds = (Date.now() - startTime) / 1000; + return { + duration: { + seconds, + text: numeral(seconds).format("hh:mm:ss"), + }, + tokenCounts: { + input: 0, + response: 0, + thinking: 0, + }, }; } }