// src/openai.ts
// Copyright (C) 2026 Rob Colbert <rob.colbert@openplatform.us>
// Licensed under the Apache License, Version 2.0
import OpenAI from "openai";
import numeral from "numeral";
import {
AiApi,
IAiChatOptions,
IAiChatResponse,
IToolCall,
IAiGenerateOptions,
IAiGenerateResponse,
IAiLogger,
IAiModelConfig,
IAiModelListResult,
IAiModelProbeResult,
IAiProvider,
IAiResponseStreamFn,
} from "./api.js";
import {
ChatCompletionFunctionTool,
ChatCompletionMessageParam,
ChatCompletionTool,
} from "openai/resources";
import { IAiEnvironment } from "./config/env.js";
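
/**
 * Extended capability flags that some OpenAI-compatible providers (e.g.
 * Gab AI) attach to model objects beyond the stock OpenAI schema.
 */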
interface GabAiCapabilities {
text?: boolean;
images?: boolean;
video?: boolean;
audio?: boolean;
streaming?: boolean;
thinking?: boolean;
web_search?: boolean;
function_calling?: boolean;
embeddings?: boolean;
image_input?: boolean;
file_input?: boolean;
audio_input?: boolean;
video_input?: boolean;
}
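
/**
 * Model metadata from an OpenAI-compatible `/models` endpoint, including
 * optional provider extensions beyond the stock OpenAI model object.
 */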
interface OpenAIModelInfo {
id: string;
created: number;
object: "model";
owned_by: string;
supported_methods?: string[];
groups?: string[];
features?: string[];
max_tokens?: number;
capabilities?: GabAiCapabilities;
context_window?: number;
}
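
/** A tool call under assembly from streaming deltas, keyed by its index. */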
interface StreamingToolCallAccumulator {
index: number;
id: string;
type: "function";
function: {
name: string;
arguments: string;
};
}
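
/**
 * Outcome of one chat completion pass (streaming or non-streaming), with
 * counters retained for diagnostic logging.
 */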
interface OpenAiChatIterationResult {
response: string;
thinking?: string;
toolCalls: IToolCall[];
assistantToolCalls: Array<{
id: string;
type: "function";
function: {
name: string;
arguments: string;
};
}>;
finishReason?: string | null;
chunkCount: number;
contentDeltaCount: number;
toolDeltaCount: number;
}
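
/**
 * Chat/completions adapter for OpenAI and OpenAI-compatible endpoints; the
 * provider config supplies the base URL and API key.
 */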
export class OpenAiApi extends AiApi {
protected client: OpenAI;
constructor(env: IAiEnvironment, provider: IAiProvider, logger?: IAiLogger) {
super(env, provider, logger);
this.client = new OpenAI({
baseURL: provider.baseUrl,
apiKey: provider.apiKey,
});
}
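
/** Lists the provider's models, mapping token-limit metadata when present. */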
async listModels(): Promise<IAiModelListResult> {
const response = await this.client.models.list();
const models = response.data.map((model) => {
const modelInfo = model as unknown as OpenAIModelInfo;
const maxTokens = modelInfo.max_tokens || modelInfo.context_window;
return {
id: model.id,
name: model.id,
parameterLabel: undefined,
parameterCount: undefined,
contextWindow: maxTokens,
};
});
return { models };
}
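
/**
 * Probes a model's capabilities: tries `models.retrieve`, then the models
 * list, and finally falls back to name-based heuristics.
 */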
async probeModel(modelId: string): Promise<IAiModelProbeResult> {
try {
const response = await this.client.models.retrieve(modelId);
const modelInfo = response as unknown as OpenAIModelInfo;
const capabilities = this.analyzeCapabilities(modelInfo);
return {
capabilities,
settings: undefined,
};
} catch {
// Retrieval failed; try the models list, then name-based heuristics.
const listResponse = await this.client.models.list();
const modelFromList = listResponse.data.find((m) => m.id === modelId);
if (modelFromList) {
const modelInfo = modelFromList as unknown as OpenAIModelInfo;
if (modelInfo.capabilities) {
return {
capabilities: this.analyzeCapabilities(modelInfo),
settings: undefined,
};
}
}
return {
capabilities: {
canCallTools: modelId.toLowerCase().includes("gpt"),
hasVision:
modelId.toLowerCase().includes("vision") ||
modelId.toLowerCase().includes("4o") ||
modelId.toLowerCase().includes("image"),
hasEmbedding:
modelId.toLowerCase().includes("embedding") ||
modelId.toLowerCase().includes("embed"),
hasThinking:
modelId.toLowerCase().includes("o1") ||
modelId.toLowerCase().includes("o3") ||
modelId.toLowerCase().includes("reasoning"),
isInstructTuned: true,
},
settings: undefined,
};
}
}
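
/**
 * Derives shared capability flags, preferring an explicit `capabilities`
 * block over the `features`/`supported_methods` arrays.
 */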
private analyzeCapabilities(
modelInfo: OpenAIModelInfo,
): IAiModelProbeResult["capabilities"] {
const features = modelInfo.features || [];
const supportedMethods = modelInfo.supported_methods || [];
const caps = modelInfo.capabilities;
if (caps) {
return {
canCallTools: !!caps.function_calling,
hasVision: !!caps.images || !!caps.image_input,
hasEmbedding: !!caps.embeddings,
hasThinking: !!caps.thinking,
isInstructTuned: !!caps.text,
};
}
return {
canCallTools:
features.includes("function_calling") ||
features.includes("parallel_tool_calls"),
hasVision: features.includes("image_content"),
hasEmbedding: supportedMethods.includes("embedding"),
hasThinking: features.includes("reasoning_effort"),
isInstructTuned: supportedMethods.includes("chat.completions"),
};
}
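
/**
 * Single-turn generation. Always streams, forwarding content and (when the
 * provider emits it) reasoning deltas to the callback.
 */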
async generate(
model: IAiModelConfig,
options: IAiGenerateOptions,
streamCallback?: IAiResponseStreamFn,
): Promise<IAiGenerateResponse> {
await this.log.debug("OpenAiApi.generate called", {
provider: model.provider.name,
modelId: model.modelId,
});
if (options.signal?.aborted) {
throw new DOMException("The operation was aborted", "AbortError");
}
const startTime = Date.now();
const response = await this.client.chat.completions.create({
model: model.modelId,
messages: [
...(options.systemPrompt
? [{ role: "system" as const, content: options.systemPrompt }]
: []),
{ role: "user" as const, content: options.prompt },
],
stream: true,
...(model.params.maxCompletionTokens
? { max_completion_tokens: model.params.maxCompletionTokens }
: {}),
temperature: model.params.temperature,
top_p: model.params.topP,
...(typeof model.params.reasoning === "string"
? {
reasoning_effort: model.params.reasoning as
| "low"
| "medium"
| "high",
}
: {}),
}, options.signal ? { signal: options.signal } : undefined);
let accumulatedResponse = "";
let accumulatedThinking = "";
for await (const chunk of response) {
if (options.signal?.aborted) {
throw new DOMException("The operation was aborted", "AbortError");
}
const delta = chunk.choices[0]?.delta;
if (delta) {
if (delta.content) {
accumulatedResponse += delta.content;
if (streamCallback) {
await streamCallback({
type: "response",
data: delta.content,
});
}
}
if ("reasoning" in delta && delta.reasoning) {
accumulatedThinking += delta.reasoning as string;
if (streamCallback) {
await streamCallback({
type: "thinking",
data: delta.reasoning as string,
});
}
}
}
}
const endTime = Date.now();
const durationMs = endTime - startTime;
return {
done: true,
response: accumulatedResponse,
thinking: accumulatedThinking || undefined,
stats: {
duration: {
seconds: durationMs / 1000,
text: numeral(durationMs / 1000).format("00:00:00"),
},
tokenCounts: {
input: 0,
response: 0,
thinking: 0,
},
},
};
}
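
/**
 * Multi-turn chat with tool support. Streams first; if the stream yields no
 * content, thinking, or tool calls, retries once without streaming.
 */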
async chat(
model: IAiModelConfig,
options: IAiChatOptions,
streamCallback?: IAiResponseStreamFn,
): Promise<IAiChatResponse> {
await this.log.debug("OpenAiApi.chat called", {
provider: model.provider.name,
modelId: model.modelId,
});
const startTime = Date.now();
const messages = this.buildMessages(options);
const tools = this.buildTools(options);
let iteration = await this.readStreamingChatCompletion(
model,
messages,
tools,
streamCallback,
options.signal,
);
await this.log.debug("OpenAI chat stream iteration finished", {
chunkCount: iteration.chunkCount,
contentDeltaCount: iteration.contentDeltaCount,
toolDeltaCount: iteration.toolDeltaCount,
responseLength: iteration.response.length,
thinkingLength: iteration.thinking?.length || 0,
toolCallCount: iteration.toolCalls.length,
finishReason: iteration.finishReason,
});
if (this.isEmptyIteration(iteration)) {
iteration = await this.readNonStreamingChatCompletion(model, messages, tools, options.signal);
if (streamCallback && iteration.response) {
await streamCallback({ type: "response", data: iteration.response });
}
await this.log.warn("OpenAI stream was empty; used non-streaming fallback", {
responseLength: iteration.response.length,
thinkingLength: iteration.thinking?.length || 0,
toolCallCount: iteration.toolCalls.length,
finishReason: iteration.finishReason,
});
}
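// Still empty after the fallback: run the shared assertion so the caller
// gets the standard empty-response error rather than a blank reply.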
if (this.isEmptyIteration(iteration)) {
this.assertNonEmptyChatResponse({
done: true,
response: "",
thinking: undefined,
toolCalls: undefined,
toolCallResults: undefined,
stats: this.buildStats(startTime),
});
}
const finalResponse: IAiChatResponse = {
done: true,
response: iteration.response,
thinking: iteration.thinking,
toolCalls: iteration.toolCalls.length > 0 ? iteration.toolCalls : undefined,
stats: this.buildStats(startTime),
};
this.assertNonEmptyChatResponse(finalResponse);
return finalResponse;
}
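
/**
 * Converts shared chat options into OpenAI message params, preserving tool
 * results and assistant tool-call turns; empty messages are skipped.
 */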
private buildMessages(options: IAiChatOptions): ChatCompletionMessageParam[] {
const messages: ChatCompletionMessageParam[] = [];
if (options.systemPrompt) {
messages.push({ role: "system", content: options.systemPrompt });
}
for (const msg of options.context || []) {
if (!msg.content?.trim()) continue;
if (msg.role === "tool") {
messages.push({
role: "tool",
tool_call_id: msg.toolCallId || msg.callId || "",
content: msg.content,
});
continue;
}
if (msg.role === "assistant" && msg.toolCalls && msg.toolCalls.length > 0) {
messages.push({
role: "assistant",
content: msg.content,
tool_calls: msg.toolCalls,
});
continue;
}
messages.push({
role: msg.role as "user" | "assistant" | "system",
content: msg.content,
});
}
if (options.userPrompt?.trim()) {
messages.push({ role: "user", content: options.userPrompt });
}
return messages;
}
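
/** Maps shared tool definitions onto OpenAI function-tool params. */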
private buildTools(options: IAiChatOptions): ChatCompletionTool[] {
return (options.tools || []).map((tool): ChatCompletionFunctionTool => ({
type: tool.definition.type,
function: {
name: tool.definition.function.name,
description: tool.definition.function.description,
parameters: tool.definition.function.parameters,
},
}));
}
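
/**
 * Runs one streaming completion, accumulating content, reasoning, and
 * tool-call deltas, and counting chunks for diagnostics.
 */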
private async readStreamingChatCompletion(
model: IAiModelConfig,
messages: ChatCompletionMessageParam[],
tools: ChatCompletionTool[],
streamCallback?: IAiResponseStreamFn,
signal?: AbortSignal,
): Promise<OpenAiChatIterationResult> {
if (signal?.aborted) {
throw new DOMException("The operation was aborted", "AbortError");
}
const response = await this.client.chat.completions.create({
model: model.modelId,
messages,
tools,
stream: true,
...(model.params.maxCompletionTokens
? { max_completion_tokens: model.params.maxCompletionTokens }
: {}),
temperature: model.params.temperature,
top_p: model.params.topP,
...(typeof model.params.reasoning === "string"
? {
reasoning_effort: model.params.reasoning as
| "low"
| "medium"
| "high",
}
: {}),
}, signal ? { signal } : undefined);
let content = "";
let thinking = "";
let chunkCount = 0;
let contentDeltaCount = 0;
let toolDeltaCount = 0;
let finishReason: string | null | undefined;
const toolCallMap = new Map<number, StreamingToolCallAccumulator>();
for await (const chunk of response) {
if (signal?.aborted) {
throw new DOMException("The operation was aborted", "AbortError");
}
chunkCount++;
finishReason = chunk.choices[0]?.finish_reason ?? finishReason;
const delta = chunk.choices[0]?.delta;
if (!delta) continue;
if (delta.content) {
contentDeltaCount++;
content += delta.content;
if (streamCallback) {
await streamCallback({ type: "response", data: delta.content });
}
}
if ("reasoning" in delta && delta.reasoning) {
thinking += delta.reasoning as string;
if (streamCallback) {
await streamCallback({ type: "thinking", data: delta.reasoning as string });
}
}
if (delta.tool_calls) {
toolDeltaCount += delta.tool_calls.length;
for (const toolCallDelta of delta.tool_calls) {
this.accumulateToolCallDelta(toolCallMap, toolCallDelta);
}
}
}
return {
response: content,
thinking: thinking || undefined,
...this.buildToolCallsFromMap(toolCallMap),
finishReason,
chunkCount,
contentDeltaCount,
toolDeltaCount,
};
}
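
/**
 * Non-streaming fallback used when the stream produces an empty iteration;
 * reads the complete message and extracts any function tool calls.
 */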
private async readNonStreamingChatCompletion(
model: IAiModelConfig,
messages: ChatCompletionMessageParam[],
tools: ChatCompletionTool[],
signal?: AbortSignal,
): Promise<OpenAiChatIterationResult> {
if (signal?.aborted) {
throw new DOMException("The operation was aborted", "AbortError");
}
const response = await this.client.chat.completions.create({
model: model.modelId,
messages,
tools,
stream: false,
...(model.params.maxCompletionTokens
? { max_completion_tokens: model.params.maxCompletionTokens }
: {}),
temperature: model.params.temperature,
top_p: model.params.topP,
...(typeof model.params.reasoning === "string"
? {
reasoning_effort: model.params.reasoning as
| "low"
| "medium"
| "high",
}
: {}),
}, signal ? { signal } : undefined);
const choice = response.choices[0];
const message = choice?.message;
const content = typeof message?.content === "string" ? message.content : "";
const assistantToolCalls = (message?.tool_calls || [])
.filter((tc) => tc.type === "function")
.map((tc) => ({
id: tc.id,
type: "function" as const,
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
}));
const toolCalls: IToolCall[] = assistantToolCalls.map((tc) => ({
callId: tc.id,
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
}));
return {
response: content,
toolCalls,
assistantToolCalls,
finishReason: choice?.finish_reason,
chunkCount: 0,
contentDeltaCount: content ? 1 : 0,
toolDeltaCount: assistantToolCalls.length,
};
}
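
/**
 * Merges one streamed tool-call delta into the accumulator for its index;
 * names and argument fragments are concatenated as they arrive.
 */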
private accumulateToolCallDelta(
toolCallMap: Map<number, StreamingToolCallAccumulator>,
delta: {
index: number;
id?: string;
type?: string;
function?: {
name?: string;
arguments?: string;
};
},
): void {
let accumulated = toolCallMap.get(delta.index);
if (!accumulated) {
accumulated = {
index: delta.index,
id: delta.id || `tool_${Date.now()}_${delta.index}`,
type: "function",
function: {
name: "",
arguments: "",
},
};
toolCallMap.set(delta.index, accumulated);
}
if (delta.id) accumulated.id = delta.id;
if (delta.function?.name) accumulated.function.name += delta.function.name;
if (delta.function?.arguments) {
accumulated.function.arguments += delta.function.arguments;
}
}
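
/**
 * Orders accumulated tool calls by index, drops fragments that never
 * received a function name, and emits both wire and shared shapes.
 */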
private buildToolCallsFromMap(
toolCallMap: Map<number, StreamingToolCallAccumulator>,
): Pick<OpenAiChatIterationResult, "toolCalls" | "assistantToolCalls"> {
const assistantToolCalls = Array.from(toolCallMap.values())
.sort((a, b) => a.index - b.index)
.filter((toolCall) => toolCall.function.name)
.map((toolCall) => ({
id: toolCall.id,
type: "function" as const,
function: {
name: toolCall.function.name,
arguments: toolCall.function.arguments,
},
}));
const toolCalls: IToolCall[] = assistantToolCalls.map((toolCall) => ({
callId: toolCall.id,
function: {
name: toolCall.function.name,
arguments: toolCall.function.arguments,
},
}));
return { toolCalls, assistantToolCalls };
}
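
/** True when an iteration produced no content, thinking, or tool calls. */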
private isEmptyIteration(iteration: OpenAiChatIterationResult): boolean {
return (
!iteration.response.trim() &&
!iteration.thinking?.trim() &&
iteration.toolCalls.length === 0
);
}
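
/**
 * Builds duration stats; token counts are zeroed because usage data is not
 * collected from the responses.
 */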
private buildStats(startTime: number): IAiChatResponse["stats"] {
const seconds = (Date.now() - startTime) / 1000;
return {
duration: {
seconds,
text: numeral(seconds).format("00:00:00"),
},
tokenCounts: {
input: 0,
response: 0,
thinking: 0,
},
};
}
}
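
// Usage sketch (illustrative only; the provider values and model params below
// are placeholders, and the option shapes are inferred from this file rather
// than from the package's public docs):
//
//   const provider: IAiProvider = {
//     name: "openai",
//     baseUrl: "https://api.openai.com/v1",
//     apiKey: process.env.OPENAI_API_KEY ?? "",
//   };
//   const api = new OpenAiApi(env, provider);
//   const model: IAiModelConfig = {
//     provider,
//     modelId: "gpt-4o-mini",
//     params: { temperature: 0.7, topP: 1 },
//   };
//   const reply = await api.chat(model, { userPrompt: "Hello" }, async (ev) => {
//     if (ev.type === "response") process.stdout.write(ev.data);
//   });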