update LLM models

duanfuxiang
2025-05-29 22:40:20 +08:00
parent 48b95ea416
commit 120c442274
9 changed files with 1014 additions and 536 deletions


@@ -1,11 +1,11 @@
 import {
   Content,
-  EnhancedGenerateContentResponse,
-  GenerateContentResult,
-  GenerateContentStreamResult,
-  GoogleGenerativeAI,
+  GoogleGenAI,
   Part,
-} from '@google/generative-ai'
+  type GenerateContentConfig,
+  type GenerateContentParameters,
+  type GenerateContentResponse,
+} from "@google/genai"
 import { LLMModel } from '../../types/llm/model'
 import {
@@ -18,6 +18,12 @@ import {
   LLMResponseNonStreaming,
   LLMResponseStreaming,
 } from '../../types/llm/response'
+import {
+  GeminiModelId,
+  ModelInfo,
+  geminiDefaultModelId,
+  geminiModels
+} from "../../utils/api"
 import { parseImageDataUrl } from '../../utils/image'
 import { BaseLLMProvider } from './base'
@@ -34,12 +40,41 @@ import {
  * issues are resolved.
  */
 export class GeminiProvider implements BaseLLMProvider {
-  private client: GoogleGenerativeAI
+  private client: GoogleGenAI
   private apiKey: string
+  private baseUrl?: string
 
-  constructor(apiKey: string) {
+  constructor(apiKey: string, baseUrl?: string) {
     this.apiKey = apiKey
-    this.client = new GoogleGenerativeAI(apiKey)
+    this.baseUrl = baseUrl
+    this.client = new GoogleGenAI({ apiKey })
   }
 
+  getModel(modelId: string) {
+    let id = modelId
+    let info: ModelInfo = geminiModels[id as GeminiModelId]
+
+    if (id?.endsWith(":thinking")) {
+      id = id.slice(0, -":thinking".length)
+      if (geminiModels[id as GeminiModelId]) {
+        info = geminiModels[id as GeminiModelId]
+        return {
+          id,
+          info,
+          thinkingConfig: undefined,
+          maxOutputTokens: info.maxTokens ?? undefined,
+        }
+      }
+    }
+
+    if (!info) {
+      id = geminiDefaultModelId
+      info = geminiModels[geminiDefaultModelId]
+    }
+
+    return { id, info }
+  }
+
   async generateResponse(
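
A note on the getModel helper added above: an id ending in ":thinking" is resolved by stripping the suffix and looking the base id up in geminiModels; any unrecognized id falls back to geminiDefaultModelId. Note also that the thinking branch currently returns thinkingConfig: undefined, so no thinking budget actually reaches the request config. A minimal sketch of the expected resolution ('gemini-2.5-flash' is an illustrative id, not necessarily a key of this repo's geminiModels table):

  const provider = new GeminiProvider(apiKey)

  // Known model with the ":thinking" suffix: suffix stripped, limits from the table.
  provider.getModel('gemini-2.5-flash:thinking')
  // => { id: 'gemini-2.5-flash', info, thinkingConfig: undefined,
  //      maxOutputTokens: info.maxTokens }

  // Unrecognized id: falls back to the default entry.
  provider.getModel('no-such-model')
  // => { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
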
@@ -53,6 +88,8 @@ export class GeminiProvider implements BaseLLMProvider {
       )
     }
 
+    const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
+
     const systemMessages = request.messages.filter((m) => m.role === 'system')
     const systemInstruction: string | undefined =
       systemMessages.length > 0
@@ -60,30 +97,26 @@ export class GeminiProvider implements BaseLLMProvider {
         : undefined
 
     try {
-      const model = this.client.getGenerativeModel({
-        model: request.model,
-        generationConfig: {
-          maxOutputTokens: request.max_tokens,
-          temperature: request.temperature,
-          topP: request.top_p,
-          presencePenalty: request.presence_penalty,
-          frequencyPenalty: request.frequency_penalty,
-        },
-        systemInstruction: systemInstruction,
-      })
-      const result = await model.generateContent(
-        {
-          systemInstruction: systemInstruction,
-          contents: request.messages
-            .map((message) => GeminiProvider.parseRequestMessage(message))
-            .filter((m): m is Content => m !== null),
-        },
-        {
-          signal: options?.signal,
-        },
-      )
+      const config: GenerateContentConfig = {
+        systemInstruction,
+        httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
+        thinkingConfig,
+        maxOutputTokens: maxOutputTokens ?? request.max_tokens,
+        temperature: request.temperature ?? 0,
+        topP: request.top_p ?? 1,
+        presencePenalty: request.presence_penalty ?? 0,
+        frequencyPenalty: request.frequency_penalty ?? 0,
+      }
+      const params: GenerateContentParameters = {
+        model: modelName,
+        contents: request.messages
+          .map((message) => GeminiProvider.parseRequestMessage(message))
+          .filter((m): m is Content => m !== null),
+        config,
+      }
+      const result = await this.client.models.generateContent(params)
       const messageId = crypto.randomUUID() // Gemini does not return a message id
       return GeminiProvider.parseNonStreamingResponse(
         result,
@@ -115,6 +148,7 @@ export class GeminiProvider implements BaseLLMProvider {
         `Gemini API key is missing. Please set it in settings menu.`,
       )
     }
+    const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
 
     const systemMessages = request.messages.filter((m) => m.role === 'system')
     const systemInstruction: string | undefined =
@@ -123,30 +157,25 @@ export class GeminiProvider implements BaseLLMProvider {
         : undefined
 
     try {
-      const model = this.client.getGenerativeModel({
-        model: request.model,
-        generationConfig: {
-          maxOutputTokens: request.max_tokens,
-          temperature: request.temperature,
-          topP: request.top_p,
-          presencePenalty: request.presence_penalty,
-          frequencyPenalty: request.frequency_penalty,
-        },
-        systemInstruction: systemInstruction,
-      })
-      const stream = await model.generateContentStream(
-        {
-          systemInstruction: systemInstruction,
-          contents: request.messages
-            .map((message) => GeminiProvider.parseRequestMessage(message))
-            .filter((m): m is Content => m !== null),
-        },
-        {
-          signal: options?.signal,
-        },
-      )
+      const config: GenerateContentConfig = {
+        systemInstruction,
+        httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
+        thinkingConfig,
+        maxOutputTokens: maxOutputTokens ?? request.max_tokens,
+        temperature: request.temperature ?? 0,
+        topP: request.top_p ?? 1,
+        presencePenalty: request.presence_penalty ?? 0,
+        frequencyPenalty: request.frequency_penalty ?? 0,
+      }
+      const params: GenerateContentParameters = {
+        model: modelName,
+        contents: request.messages
+          .map((message) => GeminiProvider.parseRequestMessage(message))
+          .filter((m): m is Content => m !== null),
+        config,
+      }
+      const stream = await this.client.models.generateContentStream(params)
       const messageId = crypto.randomUUID() // Gemini does not return a message id
       return this.streamResponseGenerator(stream, request.model, messageId)
     } catch (error) {
@@ -165,11 +194,11 @@ export class GeminiProvider implements BaseLLMProvider {
   }
 
   private async *streamResponseGenerator(
-    stream: GenerateContentStreamResult,
+    stream: AsyncGenerator<GenerateContentResponse>,
     model: string,
     messageId: string,
   ): AsyncIterable<LLMResponseStreaming> {
-    for await (const chunk of stream.stream) {
+    for await (const chunk of stream) {
       yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
     }
   }
@@ -215,7 +244,7 @@ export class GeminiProvider implements BaseLLMProvider {
   }
 
   static parseNonStreamingResponse(
-    response: GenerateContentResult,
+    response: GenerateContentResponse,
     model: string,
     messageId: string,
   ): LLMResponseNonStreaming {
@@ -224,9 +253,9 @@ export class GeminiProvider implements BaseLLMProvider {
       choices: [
         {
           finish_reason:
-            response.response.candidates?.[0]?.finishReason ?? null,
+            response.candidates?.[0]?.finishReason ?? null,
           message: {
-            content: response.response.text(),
+            content: response.candidates?.[0]?.content?.parts?.[0]?.text ?? '',
             role: 'assistant',
           },
         },
@@ -234,29 +263,32 @@ export class GeminiProvider implements BaseLLMProvider {
       created: Date.now(),
       model: model,
       object: 'chat.completion',
-      usage: response.response.usageMetadata
+      usage: response.usageMetadata
         ? {
-            prompt_tokens: response.response.usageMetadata.promptTokenCount,
+            prompt_tokens: response.usageMetadata.promptTokenCount,
             completion_tokens:
-              response.response.usageMetadata.candidatesTokenCount,
-            total_tokens: response.response.usageMetadata.totalTokenCount,
+              response.usageMetadata.candidatesTokenCount,
+            total_tokens: response.usageMetadata.totalTokenCount,
           }
         : undefined,
     }
   }
 
   static parseStreamingResponseChunk(
-    chunk: EnhancedGenerateContentResponse,
+    chunk: GenerateContentResponse,
     model: string,
     messageId: string,
   ): LLMResponseStreaming {
+    const firstCandidate = chunk.candidates?.[0]
+    const textContent = firstCandidate?.content?.parts?.[0]?.text || ''
     return {
       id: messageId,
       choices: [
         {
-          finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
+          finish_reason: firstCandidate?.finishReason ?? null,
           delta: {
-            content: chunk.text(),
+            content: textContent,
          },
        },
      ],
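
One behavioral note on the migration above: the old generateContent call passed options?.signal as a request option, and that abort wiring is not carried over into the new config (the @google/genai GenerateContentConfig does expose an abortSignal field that could restore it). For reference, a minimal sketch of the new SDK surface this commit targets, independent of the provider wrapper — the model name and prompts are illustrative assumptions:

  import { GoogleGenAI } from '@google/genai'

  async function demo(apiKey: string) {
    const ai = new GoogleGenAI({ apiKey })

    // Non-streaming: one request returns a single GenerateContentResponse.
    const result = await ai.models.generateContent({
      model: 'gemini-2.0-flash', // illustrative model name
      contents: 'Say hello',
      config: { temperature: 0 },
    })
    console.log(result.candidates?.[0]?.content?.parts?.[0]?.text)

    // Streaming: the call resolves to an async generator of
    // GenerateContentResponse chunks, which is why streamResponseGenerator
    // above can `for await` over it directly.
    const stream = await ai.models.generateContentStream({
      model: 'gemini-2.0-flash',
      contents: 'Count to three',
    })
    for await (const chunk of stream) {
      console.log(chunk.candidates?.[0]?.content?.parts?.[0]?.text)
    }
  }
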


@@ -1,251 +0,0 @@
(The entire file is removed by this commit — the whole InfioProvider module below is deleted.)
import OpenAI from 'openai'
import {
ChatCompletion,
ChatCompletionChunk,
} from 'openai/resources/chat/completions'
import { INFIO_BASE_URL } from '../../constants'
import { LLMModel } from '../../types/llm/model'
import {
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
export type RangeFilter = {
gte?: number;
lte?: number;
}
export type ChunkFilter = {
field: string;
match_all?: string[];
range?: RangeFilter;
}
/**
* Interface for making requests to the Infio API
*/
export type InfioRequest = {
/** Required: The content of the user message to attach to the topic and then generate an assistant message in response to */
messages: RequestMessage[];
// /** Required: The ID of the topic to attach the message to */
// topic_id: string;
/** Optional: URLs to include */
links?: string[];
/** Optional: Files to include */
files?: string[];
/** Optional: Whether to highlight results in chunk_html. Default is true */
highlight_results?: boolean;
/** Optional: Delimiters for highlighting citations. Default is [".", "!", "?", "\n", "\t", ","] */
highlight_delimiters?: string[];
/** Optional: Search type - "semantic", "fulltext", or "hybrid". Default is "hybrid" */
search_type?: string;
/** Optional: Filters for chunk filtering */
filters?: ChunkFilter;
/** Optional: Whether to use web search API. Default is false */
use_web_search?: boolean;
/** Optional: LLM model to use */
llm_model?: string;
/** Optional: Force source */
force_source?: string;
/** Optional: Whether completion should come before chunks in stream. Default is false */
completion_first?: boolean;
/** Optional: Whether to stream the response. Default is true */
stream_response?: boolean;
/** Optional: Sampling temperature between 0 and 2. Default is 0.5 */
temperature?: number;
/** Optional: Frequency penalty between -2.0 and 2.0. Default is 0.7 */
frequency_penalty?: number;
/** Optional: Presence penalty between -2.0 and 2.0. Default is 0.7 */
presence_penalty?: number;
/** Optional: Maximum tokens to generate */
max_tokens?: number;
/** Optional: Stop tokens (up to 4 sequences) */
stop_tokens?: string[];
}
export class InfioProvider implements BaseLLMProvider {
// private adapter: OpenAIMessageAdapter
// private client: OpenAI
private apiKey: string
private baseUrl: string
constructor(apiKey: string) {
// this.client = new OpenAI({ apiKey, dangerouslyAllowBrowser: true })
// this.adapter = new OpenAIMessageAdapter()
this.apiKey = apiKey
this.baseUrl = INFIO_BASE_URL
}
async generateResponse(
model: LLMModel,
request: LLMRequestNonStreaming,
// options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
try {
const req: InfioRequest = {
messages: request.messages,
stream_response: false,
temperature: request.temperature,
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
max_tokens: request.max_tokens,
}
const req_options = {
method: 'POST',
headers: {
Authorization: this.apiKey,
"TR-Dataset": "74aaec22-0cf0-4cba-80a5-ae5c0518344e",
'Content-Type': 'application/json'
},
body: JSON.stringify(req)
};
const response = await fetch(this.baseUrl, req_options);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data = await response.json() as ChatCompletion;
return InfioProvider.parseNonStreamingResponse(data);
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async streamResponse(
model: LLMModel,
request: LLMRequestStreaming,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
try {
const req: InfioRequest = {
llm_model: request.model,
messages: request.messages,
stream_response: true,
temperature: request.temperature,
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
max_tokens: request.max_tokens,
}
const req_options = {
method: 'POST',
headers: {
Authorization: this.apiKey,
"TR-Dataset": "74aaec22-0cf0-4cba-80a5-ae5c0518344e",
"Content-Type": "application/json"
},
body: JSON.stringify(req)
};
const response = await fetch(this.baseUrl, req_options);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
if (!response.body) {
throw new Error('Response body is null');
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
return {
[Symbol.asyncIterator]: async function* () {
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n').filter(line => line.trim());
for (const line of lines) {
if (line.startsWith('data: ')) {
const jsonData = JSON.parse(line.slice(6)) as ChatCompletionChunk;
if (!jsonData || typeof jsonData !== 'object' || !('choices' in jsonData)) {
throw new Error('Invalid chunk format received');
}
yield InfioProvider.parseStreamingResponseChunk(jsonData);
}
}
}
} finally {
reader.releaseLock();
}
}
};
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
static parseNonStreamingResponse(
response: ChatCompletion,
): LLMResponseNonStreaming {
return {
id: response.id,
choices: response.choices.map((choice) => ({
finish_reason: choice.finish_reason,
message: {
content: choice.message.content,
role: choice.message.role,
},
})),
created: response.created,
model: response.model,
object: 'chat.completion',
system_fingerprint: response.system_fingerprint,
usage: response.usage,
}
}
static parseStreamingResponseChunk(
chunk: ChatCompletionChunk,
): LLMResponseStreaming {
return {
id: chunk.id,
choices: chunk.choices.map((choice) => ({
finish_reason: choice.finish_reason ?? null,
delta: {
content: choice.delta.content ?? null,
role: choice.delta.role,
},
})),
created: chunk.created,
model: chunk.model,
object: 'chat.completion.chunk',
system_fingerprint: chunk.system_fingerprint,
usage: chunk.usage ?? undefined,
}
}
}
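
A side note on the SSE handling in the deleted provider: it decoded and split each read in isolation, so a "data: " line straddling two reads (or a multi-byte character split across chunks) would be dropped or crash JSON.parse. A hedged sketch of a carry-buffer variant, using only standard Web Streams APIs (names are illustrative, not from this repo):

  // Buffered SSE reader: keeps the trailing partial line between reads.
  async function* readSseLines(
    body: ReadableStream<Uint8Array>,
  ): AsyncGenerator<string> {
    const reader = body.getReader()
    const decoder = new TextDecoder()
    let buffer = ''
    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) break
        // stream: true keeps multi-byte characters split across chunks intact
        buffer += decoder.decode(value, { stream: true })
        const lines = buffer.split('\n')
        buffer = lines.pop() ?? '' // last element may be partial; carry it over
        for (const line of lines) {
          if (line.startsWith('data: ')) yield line.slice(6)
        }
      }
    } finally {
      reader.releaseLock()
    }
  }

Each yielded payload can then be JSON.parsed and fed to a chunk parser like parseStreamingResponseChunk, exactly as the deleted code did per read.
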


@@ -1,4 +1,4 @@
-import { ALIBABA_QWEN_BASE_URL, DEEPSEEK_BASE_URL, GROK_BASE_URL, OPENROUTER_BASE_URL, SILICONFLOW_BASE_URL } from '../../constants'
+import { ALIBABA_QWEN_BASE_URL, DEEPSEEK_BASE_URL, GROK_BASE_URL, INFIO_BASE_URL, OPENROUTER_BASE_URL, SILICONFLOW_BASE_URL } from '../../constants'
 import { ApiProvider, LLMModel } from '../../types/llm/model'
 import {
   LLMOptions,
@@ -14,7 +14,6 @@ import { InfioSettings } from '../../types/settings'
 import { AnthropicProvider } from './anthropic'
 import { GeminiProvider } from './gemini'
 import { GroqProvider } from './groq'
-import { InfioProvider } from './infio'
 import { OllamaProvider } from './ollama'
 import { OpenAIAuthenticatedProvider } from './openai'
 import { OpenAICompatibleProvider } from './openai-compatible'
@@ -40,7 +39,7 @@ class LLMManager implements LLMManagerInterface {
   private googleProvider: GeminiProvider
   private groqProvider: GroqProvider
   private grokProvider: OpenAICompatibleProvider
-  private infioProvider: InfioProvider
+  private infioProvider: OpenAICompatibleProvider
   private openrouterProvider: OpenAICompatibleProvider
   private siliconflowProvider: OpenAICompatibleProvider
   private alibabaQwenProvider: OpenAICompatibleProvider
@@ -49,7 +48,10 @@ class LLMManager implements LLMManagerInterface {
   private isInfioEnabled: boolean
 
   constructor(settings: InfioSettings) {
-    this.infioProvider = new InfioProvider(settings.infioProvider.apiKey)
+    this.infioProvider = new OpenAICompatibleProvider(
+      settings.infioProvider.apiKey,
+      INFIO_BASE_URL
+    )
     this.openrouterProvider = new OpenAICompatibleProvider(
       settings.openrouterProvider.apiKey,
       settings.openrouterProvider.baseUrl && settings.openrouterProvider.useCustomUrl ?
@@ -93,14 +95,14 @@ class LLMManager implements LLMManagerInterface {
     request: LLMRequestNonStreaming,
     options?: LLMOptions,
   ): Promise<LLMResponseNonStreaming> {
-    if (this.isInfioEnabled) {
-      return await this.infioProvider.generateResponse(
-        model,
-        request,
-      )
-    }
     // use custom provider
     console.log("model", model)
     switch (model.provider) {
+      case ApiProvider.Infio:
+        return await this.infioProvider.generateResponse(
+          model,
+          request,
+          options,
+        )
       case ApiProvider.OpenRouter:
         return await this.openrouterProvider.generateResponse(
           model,
@@ -169,11 +171,9 @@ class LLMManager implements LLMManagerInterface {
     request: LLMRequestStreaming,
     options?: LLMOptions,
   ): Promise<AsyncIterable<LLMResponseStreaming>> {
-    if (this.isInfioEnabled) {
-      return await this.infioProvider.streamResponse(model, request)
-    }
     // use custom provider
     switch (model.provider) {
+      case ApiProvider.Infio:
+        return await this.infioProvider.streamResponse(model, request, options)
       case ApiProvider.OpenRouter:
         return await this.openrouterProvider.streamResponse(model, request, options)
       case ApiProvider.SiliconFlow:
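
With the bespoke InfioProvider deleted, Infio requests flow through the same OpenAI-compatible path as OpenRouter and SiliconFlow, and the isInfioEnabled short-circuit is gone. A hypothetical usage sketch of the new routing (the model id, request, and options values are placeholders; LLMManager construction is abbreviated):

  const manager = new LLMManager(settings)

  // model.provider === ApiProvider.Infio now dispatches to an
  // OpenAICompatibleProvider constructed with INFIO_BASE_URL, so options
  // (e.g. the abort signal) are forwarded like any other provider.
  const stream = await manager.streamResponse(
    { provider: ApiProvider.Infio, modelId: 'infio-default' } as LLMModel, // placeholder model
    request,
    options,
  )
  for await (const chunk of stream) {
    process.stdout.write(chunk.choices[0]?.delta?.content ?? '')
  }
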