gadget/packages/ai/src/openai.ts
2026-05-07 00:10:57 -04:00

369 lines
10 KiB
TypeScript

// src/openai.ts
// Copyright (C) 2026 Rob Colbert <rob.colbert@openplatform.us>
// Licensed under the Apache License, Version 2.0
import OpenAI from "openai";
import numeral from "numeral";
import {
AiApi,
IAiChatOptions,
IAiChatResponse,
IToolCallResult,
IAiGenerateOptions,
IAiGenerateResponse,
IAiLogger,
IAiModelConfig,
IAiModelListResult,
IAiModelProbeResult,
IAiProvider,
IAiResponseStreamFn,
} from "./api.js";
import {
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionTool,
ChatCompletionMessageParam,
ChatCompletionTool,
ChatCompletionToolMessageParam,
} from "openai/resources";
import { IAiEnvironment } from "./config/env.ts";
/**
 * Optional capability flags that a provider may attach to a model record
 * under its `capabilities` key (see `OpenAIModelInfo.capabilities`).
 *
 * Every flag is optional. `OpenAiApi.analyzeCapabilities` coerces the flags it
 * reads with `!!`, so an absent flag is treated the same as `false`.
 *
 * NOTE(review): the "Gab" prefix suggests this shape comes from one specific
 * OpenAI-compatible provider — confirm against that provider's API docs.
 */
interface GabAiCapabilities {
  // Modality flags without an `_input` suffix.
  // NOTE(review): presumably these describe what the model can produce — confirm.
  text?: boolean;
  images?: boolean;
  audio?: boolean;
  video?: boolean;

  // Modality flags with an `_input` suffix.
  // NOTE(review): presumably these describe what the model accepts — confirm.
  image_input?: boolean;
  audio_input?: boolean;
  video_input?: boolean;
  file_input?: boolean;

  // Feature flags.
  streaming?: boolean;
  thinking?: boolean;
  function_calling?: boolean;
  web_search?: boolean;
  embeddings?: boolean;
}
/**
 * Shape of a model record as this module reads it from the models endpoints.
 *
 * The four required fields are the baseline model object; the optional fields
 * are extras that only some providers return. SDK model objects are cast to
 * this interface (via `unknown`) so the optional extras can be accessed.
 */
interface OpenAIModelInfo {
  // Baseline fields.
  object: "model";
  id: string;
  owned_by: string;
  created: number;

  // Optional size hints. `listModels` reports `max_tokens` as the context
  // window when it is truthy, otherwise `context_window`.
  max_tokens?: number;
  context_window?: number;

  // Optional capability descriptors consumed by `analyzeCapabilities`:
  // the structured `capabilities` object wins when present; otherwise the
  // `features` / `supported_methods` string lists are inspected.
  capabilities?: GabAiCapabilities;
  features?: string[];
  supported_methods?: string[];

  // Not read anywhere in this module.
  groups?: string[];
}
/**
 * `AiApi` implementation that talks to an OpenAI-compatible endpoint through
 * the official `openai` SDK client.
 *
 * All requests are issued with `stream: false`; the `streamCallback`
 * parameters on `generate` and `chat` are accepted for signature
 * compatibility but are not invoked by this implementation.
 *
 * `this.log` and `this.executeToolCalls` are not defined in this class —
 * presumably they are inherited from `AiApi`.
 */
export class OpenAiApi extends AiApi {
  /** SDK client configured with the provider's base URL and API key. */
  protected client: OpenAI;

  /**
   * @param env      Environment passed through to the `AiApi` base constructor.
   * @param provider Provider record; `baseUrl` and `apiKey` configure the SDK client.
   * @param logger   Optional logger passed through to the base constructor.
   */
  constructor(env: IAiEnvironment, provider: IAiProvider, logger?: IAiLogger) {
    super(env, provider, logger);
    this.client = new OpenAI({
      baseURL: provider.baseUrl,
      apiKey: provider.apiKey,
    });
  }

  /**
   * Lists the models returned by the provider's models endpoint.
   *
   * @returns One entry per model with `id` and `name` both set to the model
   *          id. `contextWindow` is `max_tokens` when truthy, otherwise
   *          `context_window`, otherwise `undefined`. Parameter label/count
   *          are always `undefined`.
   */
  async listModels(): Promise<IAiModelListResult> {
    await this.log.debug("OpenAiApi.listModels called");
    const response = await this.client.models.list();
    await this.log.debug("OpenAI models list response", {
      data: response.data,
    });

    const models = response.data.map((model) => {
      // The SDK's model type does not declare provider-specific extras, so
      // view the record through the wider local interface.
      const modelInfo = model as unknown as OpenAIModelInfo;
      // `||` (not `??`) on purpose: a reported size of 0 is treated as absent.
      const maxTokens = modelInfo.max_tokens || modelInfo.context_window;
      return {
        id: model.id,
        name: model.id,
        parameterLabel: undefined,
        parameterCount: undefined,
        contextWindow: maxTokens,
      };
    });

    return { models };
  }

  /**
   * Determines the capabilities of a single model.
   *
   * Resolution order:
   *  1. Retrieve the model record and analyze it.
   *  2. If that throws, look the model up in the list endpoint and analyze
   *     that record, but only when it carries a `capabilities` object.
   *  3. Otherwise guess from substrings of the model id.
   *
   * @param modelId Id of the model to probe.
   * @returns Capabilities with `settings` always `undefined`.
   * @throws Whatever the list call throws if the fallback listing itself fails.
   */
  async probeModel(modelId: string): Promise<IAiModelProbeResult> {
    await this.log.debug("OpenAiApi.probeModel called", { modelId });

    try {
      const response = await this.client.models.retrieve(modelId);
      const modelInfo = response as unknown as OpenAIModelInfo;

      await this.log.debug("OpenAI model retrieve response", {
        modelId,
        features: modelInfo.features,
        supported_methods: modelInfo.supported_methods,
        capabilities: modelInfo.capabilities,
      });

      const capabilities = this.analyzeCapabilities(modelInfo);
      return {
        capabilities,
        settings: undefined,
      };
    } catch (error) {
      await this.log.debug("Failed to retrieve model details, using fallback", {
        modelId,
        // A thrown value is not guaranteed to be an Error instance.
        error: error instanceof Error ? error.message : String(error),
      });

      const listResponse = await this.client.models.list();
      const modelFromList = listResponse.data.find((m) => m.id === modelId);
      const listedInfo = modelFromList as unknown as
        | OpenAIModelInfo
        | undefined;

      if (listedInfo?.capabilities) {
        await this.log.debug("Using capabilities from list endpoint", {
          modelId,
        });
        return {
          capabilities: this.analyzeCapabilities(listedInfo),
          settings: undefined,
        };
      }

      return {
        capabilities: this.guessCapabilitiesFromModelId(modelId),
        settings: undefined,
      };
    }
  }

  /**
   * Last-resort capability guess based only on substrings of the model id
   * (case-insensitive). `isInstructTuned` is always reported as `true`.
   */
  private guessCapabilitiesFromModelId(
    modelId: string,
  ): IAiModelProbeResult["capabilities"] {
    const id = modelId.toLowerCase();
    const mentions = (...needles: string[]): boolean =>
      needles.some((needle) => id.includes(needle));

    return {
      canCallTools: mentions("gpt"),
      hasVision: mentions("vision", "4o", "image"),
      // "embed" also matches every id that contains "embedding".
      hasEmbedding: mentions("embed"),
      hasThinking: mentions("o1", "o3", "reasoning"),
      isInstructTuned: true,
    };
  }

  /**
   * Maps a model record to the capability flags used by this package.
   *
   * When the record has a structured `capabilities` object it is the sole
   * source of truth; otherwise the `features` and `supported_methods` string
   * lists are inspected (missing lists are treated as empty).
   */
  private analyzeCapabilities(
    modelInfo: OpenAIModelInfo,
  ): IAiModelProbeResult["capabilities"] {
    const caps = modelInfo.capabilities;
    if (caps) {
      return {
        canCallTools: !!caps.function_calling,
        hasVision: !!caps.images || !!caps.image_input,
        hasEmbedding: !!caps.embeddings,
        hasThinking: !!caps.thinking,
        isInstructTuned: !!caps.text,
      };
    }

    const features = modelInfo.features || [];
    const supportedMethods = modelInfo.supported_methods || [];
    return {
      canCallTools:
        features.includes("function_calling") ||
        features.includes("parallel_tool_calls"),
      hasVision: features.includes("image_content"),
      hasEmbedding: supportedMethods.includes("embedding"),
      hasThinking: features.includes("reasoning_effort"),
      isInstructTuned: supportedMethods.includes("chat.completions"),
    };
  }

  /**
   * Builds the `stats` object shared by every response this class returns.
   * The thinking-token count is always reported as 0.
   *
   * @param durationMs     Elapsed wall-clock time in milliseconds.
   * @param inputTokens    Prompt tokens to report.
   * @param responseTokens Completion tokens to report.
   */
  private makeCompletionStats(
    durationMs: number,
    inputTokens: number,
    responseTokens: number,
  ) {
    const seconds = durationMs / 1000;
    return {
      duration: {
        seconds,
        text: numeral(seconds).format("hh:mm:ss"),
      },
      tokenCounts: {
        input: inputTokens,
        response: responseTokens,
        thinking: 0,
      },
    };
  }

  /**
   * Runs a single non-streaming completion for a prompt, optionally preceded
   * by a system prompt.
   *
   * @param model          Model configuration; `modelId` selects the model.
   * @param options        `prompt` is sent as the user message; `systemPrompt`,
   *                       when truthy, is sent first as a system message.
   * @param streamCallback Not used; the request is sent with `stream: false`.
   * @returns The first choice's text (empty string when it has no content)
   *          together with duration and token usage.
   * @throws Error when the API response contains no choices.
   */
  async generate(
    model: IAiModelConfig,
    options: IAiGenerateOptions,
    streamCallback?: IAiResponseStreamFn,
  ): Promise<IAiGenerateResponse> {
    await this.log.debug("OpenAiApi.generate called", {
      provider: model.provider.name,
      modelId: model.modelId,
    });

    const startTime = Date.now();
    const response = await this.client.chat.completions.create({
      model: model.modelId,
      messages: [
        ...(options.systemPrompt
          ? [{ role: "system" as const, content: options.systemPrompt }]
          : []),
        { role: "user" as const, content: options.prompt },
      ],
      stream: false,
    });
    const durationMs = Date.now() - startTime;

    const choice = response.choices[0];
    if (!choice) {
      throw new Error("OpenAI chat completion returned no choices");
    }

    return {
      done: true,
      response: choice.message.content || "",
      thinking: undefined,
      stats: this.makeCompletionStats(
        durationMs,
        response.usage?.prompt_tokens ?? 0,
        response.usage?.completion_tokens ?? 0,
      ),
    };
  }

  /**
   * Runs a chat exchange, executing requested tool calls and feeding their
   * results back to the model until it answers without tool calls or the
   * iteration limit is reached.
   *
   * Message order sent to the model: system prompt (if any), every entry of
   * `options.context`, then the user prompt (if any).
   *
   * Token counts in the returned stats are summed over every model round-trip
   * made during this call.
   *
   * @param model          Model configuration; `modelId` selects the model.
   * @param options        Chat options. `maxToolIterations` defaults to 5.
   * @param streamCallback Not used; requests are sent with `stream: false`.
   * @returns `done: true` with the model's final text, or `done: false` with
   *          `doneReason: "max_tool_iterations_reached"` and an empty
   *          response when the limit is hit. Tool results collected along the
   *          way are included in both cases (`undefined` on the `done: true`
   *          path when no tool ran).
   * @throws Error when an API response contains no choices.
   */
  async chat(
    model: IAiModelConfig,
    options: IAiChatOptions,
    streamCallback?: IAiResponseStreamFn,
  ): Promise<IAiChatResponse> {
    await this.log.debug("OpenAiApi.chat called", {
      provider: model.provider.name,
      modelId: model.modelId,
    });

    const startTime = Date.now();
    const maxIterations = options.maxToolIterations ?? 5;

    const messages: ChatCompletionMessageParam[] = [];
    if (options.systemPrompt) {
      messages.push({ role: "system", content: options.systemPrompt });
    }
    if (options.context) {
      for (const msg of options.context) {
        messages.push({
          role: msg.role as "user" | "assistant" | "system",
          content: msg.content,
        });
      }
    }
    if (options.userPrompt) {
      messages.push({ role: "user", content: options.userPrompt });
    }

    // The tool list does not change between iterations, so build it once.
    const tools: ChatCompletionTool[] = (options.tools ?? []).map((tool) => {
      const openaiTool: ChatCompletionFunctionTool = {
        type: tool.definition.type,
        function: {
          name: tool.definition.function.name,
          description: tool.definition.function.description,
          parameters: tool.definition.function.parameters,
        },
      };
      return openaiTool;
    });

    const allToolCallResults: IToolCallResult[] = [];
    let inputTokens = 0;
    let responseTokens = 0;

    for (let iteration = 0; iteration < maxIterations; iteration++) {
      const response = await this.client.chat.completions.create({
        model: model.modelId,
        messages,
        // The API rejects an empty `tools` array, so omit the field entirely
        // when no tools are configured.
        ...(tools.length > 0 ? { tools } : {}),
        stream: false,
      });

      inputTokens += response.usage?.prompt_tokens ?? 0;
      responseTokens += response.usage?.completion_tokens ?? 0;

      const choice = response.choices[0];
      if (!choice) {
        throw new Error("OpenAI chat completion returned no choices");
      }

      // Only function-type tool calls are executed; other kinds are ignored.
      const toolCalls = choice.message.tool_calls
        ?.filter((tc) => tc.type === "function")
        .map((tc) => ({
          callId: tc.id,
          function: {
            name: tc.function.name,
            arguments: tc.function.arguments,
          },
        }));

      if (!toolCalls || toolCalls.length === 0) {
        return {
          done: true,
          response: choice.message.content || "",
          thinking: undefined,
          toolCalls: undefined,
          toolCallResults:
            allToolCallResults.length > 0 ? allToolCallResults : undefined,
          stats: this.makeCompletionStats(
            Date.now() - startTime,
            inputTokens,
            responseTokens,
          ),
        };
      }

      const toolCallResults = await this.executeToolCalls(
        toolCalls,
        options.tools || [],
      );
      allToolCallResults.push(...toolCallResults);

      // Echo the assistant turn (including its tool calls) back into the
      // transcript, followed by one tool message per result, so the next
      // request can reference the tool outputs by call id.
      const assistantMsg: ChatCompletionAssistantMessageParam = {
        role: "assistant",
        content: choice.message.content,
      };
      if (choice.message.tool_calls) {
        assistantMsg.tool_calls = choice.message.tool_calls;
      }
      messages.push(assistantMsg);

      for (const result of toolCallResults) {
        const toolMsg: ChatCompletionToolMessageParam = {
          role: "tool",
          tool_call_id: result.callId,
          // A tool error, when present, is sent to the model instead of the result.
          content: result.error || result.result,
        };
        messages.push(toolMsg);
      }
    }

    // The model was still requesting tools when the iteration budget ran out.
    return {
      done: false,
      doneReason: "max_tool_iterations_reached",
      response: "",
      thinking: undefined,
      toolCalls: undefined,
      toolCallResults: allToolCallResults,
      stats: this.makeCompletionStats(
        Date.now() - startTime,
        inputTokens,
        responseTokens,
      ),
    };
  }
}