update LLM models

duanfuxiang
2025-05-29 22:40:20 +08:00
parent 48b95ea416
commit 120c442274
9 changed files with 1014 additions and 536 deletions
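The diff below migrates the Gemini provider from the `@google/generative-ai` SDK to the newer `@google/genai` SDK, adds optional baseUrl support, and routes model selection through a getModel() helper. For reference, here is a minimal sketch of the call shape the provider moves to; this is not part of the diff, and the model id and prompt are placeholder values:

import { GoogleGenAI, type GenerateContentConfig } from '@google/genai'

// Minimal sketch of the new SDK's call shape (not taken from this commit);
// the model id and prompt below are placeholders.
async function example(apiKey: string, baseUrl?: string) {
  const client = new GoogleGenAI({ apiKey })

  const config: GenerateContentConfig = {
    // Optional custom endpoint, wired up via httpOptions as in this commit
    httpOptions: baseUrl ? { baseUrl } : undefined,
    maxOutputTokens: 1024,
    temperature: 0,
  }

  // Non-streaming: candidates and usageMetadata sit directly on the response
  const result = await client.models.generateContent({
    model: 'gemini-2.0-flash',
    contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
    config,
  })
  console.log(result.candidates?.[0]?.content?.parts?.[0]?.text)

  // Streaming: generateContentStream yields response chunks
  const stream = await client.models.generateContentStream({
    model: 'gemini-2.0-flash',
    contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
    config,
  })
  for await (const chunk of stream) {
    console.log(chunk.candidates?.[0]?.content?.parts?.[0]?.text ?? '')
  }
}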


@@ -1,11 +1,11 @@
 import {
   Content,
-  EnhancedGenerateContentResponse,
-  GenerateContentResult,
-  GenerateContentStreamResult,
-  GoogleGenerativeAI,
+  GoogleGenAI,
   Part,
-} from '@google/generative-ai'
+  type GenerateContentConfig,
+  type GenerateContentParameters,
+  type GenerateContentResponse,
+} from "@google/genai"
 import { LLMModel } from '../../types/llm/model'
 import {
@@ -18,6 +18,12 @@ import {
   LLMResponseNonStreaming,
   LLMResponseStreaming,
 } from '../../types/llm/response'
+import {
+  GeminiModelId,
+  ModelInfo,
+  geminiDefaultModelId,
+  geminiModels
+} from "../../utils/api"
 import { parseImageDataUrl } from '../../utils/image'
 import { BaseLLMProvider } from './base'
@@ -34,12 +40,41 @@ import {
  * issues are resolved.
  */
 export class GeminiProvider implements BaseLLMProvider {
-  private client: GoogleGenerativeAI
+  private client: GoogleGenAI
   private apiKey: string
+  private baseUrl: string
-  constructor(apiKey: string) {
+  constructor(apiKey: string, baseUrl?: string) {
     this.apiKey = apiKey
-    this.client = new GoogleGenerativeAI(apiKey)
+    this.baseUrl = baseUrl
+    this.client = new GoogleGenAI({ apiKey })
   }
+  getModel(modelId: string) {
+    let id = modelId
+    let info: ModelInfo = geminiModels[id as GeminiModelId]
+    if (id?.endsWith(":thinking")) {
+      id = id.slice(0, -":thinking".length)
+      if (geminiModels[id as GeminiModelId]) {
+        info = geminiModels[id as GeminiModelId]
+        return {
+          id,
+          info,
+          thinkingConfig: undefined,
+          maxOutputTokens: info.maxTokens ?? undefined,
+        }
+      }
+    }
+    if (!info) {
+      id = geminiDefaultModelId
+      info = geminiModels[geminiDefaultModelId]
+    }
+    return { id, info }
+  }
   async generateResponse(
@@ -53,6 +88,8 @@ export class GeminiProvider implements BaseLLMProvider {
       )
     }
+    const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
     const systemMessages = request.messages.filter((m) => m.role === 'system')
     const systemInstruction: string | undefined =
       systemMessages.length > 0
@@ -60,30 +97,26 @@ export class GeminiProvider implements BaseLLMProvider {
         : undefined
     try {
-      const model = this.client.getGenerativeModel({
-        model: request.model,
-        generationConfig: {
-          maxOutputTokens: request.max_tokens,
-          temperature: request.temperature,
-          topP: request.top_p,
-          presencePenalty: request.presence_penalty,
-          frequencyPenalty: request.frequency_penalty,
-        },
-        systemInstruction: systemInstruction,
-      })
-      const result = await model.generateContent(
-        {
-          systemInstruction: systemInstruction,
-          contents: request.messages
-            .map((message) => GeminiProvider.parseRequestMessage(message))
-            .filter((m): m is Content => m !== null),
-        },
-        {
-          signal: options?.signal,
-        },
-      )
+      const config: GenerateContentConfig = {
+        systemInstruction,
+        httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
+        thinkingConfig,
+        maxOutputTokens: maxOutputTokens ?? request.max_tokens,
+        temperature: request.temperature ?? 0,
+        topP: request.top_p ?? 1,
+        presencePenalty: request.presence_penalty ?? 0,
+        frequencyPenalty: request.frequency_penalty ?? 0,
+      }
+      const params: GenerateContentParameters = {
+        model: modelName,
+        contents: request.messages
+          .map((message) => GeminiProvider.parseRequestMessage(message))
+          .filter((m): m is Content => m !== null),
+        config,
+      }
+      const result = await this.client.models.generateContent(params)
       const messageId = crypto.randomUUID() // Gemini does not return a message id
       return GeminiProvider.parseNonStreamingResponse(
         result,
@@ -115,6 +148,7 @@ export class GeminiProvider implements BaseLLMProvider {
         `Gemini API key is missing. Please set it in settings menu.`,
       )
     }
+    const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
     const systemMessages = request.messages.filter((m) => m.role === 'system')
     const systemInstruction: string | undefined =
@@ -123,30 +157,25 @@ export class GeminiProvider implements BaseLLMProvider {
         : undefined
     try {
-      const model = this.client.getGenerativeModel({
-        model: request.model,
-        generationConfig: {
-          maxOutputTokens: request.max_tokens,
-          temperature: request.temperature,
-          topP: request.top_p,
-          presencePenalty: request.presence_penalty,
-          frequencyPenalty: request.frequency_penalty,
-        },
-        systemInstruction: systemInstruction,
-      })
-      const stream = await model.generateContentStream(
-        {
-          systemInstruction: systemInstruction,
-          contents: request.messages
-            .map((message) => GeminiProvider.parseRequestMessage(message))
-            .filter((m): m is Content => m !== null),
-        },
-        {
-          signal: options?.signal,
-        },
-      )
+      const config: GenerateContentConfig = {
+        systemInstruction,
+        httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
+        thinkingConfig,
+        maxOutputTokens: maxOutputTokens ?? request.max_tokens,
+        temperature: request.temperature ?? 0,
+        topP: request.top_p ?? 1,
+        presencePenalty: request.presence_penalty ?? 0,
+        frequencyPenalty: request.frequency_penalty ?? 0,
+      }
+      const params: GenerateContentParameters = {
+        model: modelName,
+        contents: request.messages
+          .map((message) => GeminiProvider.parseRequestMessage(message))
+          .filter((m): m is Content => m !== null),
+        config,
+      }
+      const stream = await this.client.models.generateContentStream(params)
       const messageId = crypto.randomUUID() // Gemini does not return a message id
       return this.streamResponseGenerator(stream, request.model, messageId)
     } catch (error) {
@@ -165,11 +194,11 @@ export class GeminiProvider implements BaseLLMProvider {
   }
   private async *streamResponseGenerator(
-    stream: GenerateContentStreamResult,
+    stream: AsyncGenerator<GenerateContentResponse>,
     model: string,
     messageId: string,
   ): AsyncIterable<LLMResponseStreaming> {
-    for await (const chunk of stream.stream) {
+    for await (const chunk of stream) {
       yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
     }
   }
@@ -215,7 +244,7 @@ export class GeminiProvider implements BaseLLMProvider {
   }
   static parseNonStreamingResponse(
-    response: GenerateContentResult,
+    response: GenerateContentResponse,
     model: string,
     messageId: string,
   ): LLMResponseNonStreaming {
@@ -224,9 +253,9 @@ export class GeminiProvider implements BaseLLMProvider {
       choices: [
         {
           finish_reason:
-            response.response.candidates?.[0]?.finishReason ?? null,
+            response.candidates?.[0]?.finishReason ?? null,
           message: {
-            content: response.response.text(),
+            content: response.candidates?.[0]?.content?.parts?.[0]?.text ?? '',
             role: 'assistant',
           },
         },
@@ -234,29 +263,32 @@ export class GeminiProvider implements BaseLLMProvider {
       created: Date.now(),
       model: model,
       object: 'chat.completion',
-      usage: response.response.usageMetadata
+      usage: response.usageMetadata
         ? {
-            prompt_tokens: response.response.usageMetadata.promptTokenCount,
+            prompt_tokens: response.usageMetadata.promptTokenCount,
             completion_tokens:
-              response.response.usageMetadata.candidatesTokenCount,
-            total_tokens: response.response.usageMetadata.totalTokenCount,
+              response.usageMetadata.candidatesTokenCount,
+            total_tokens: response.usageMetadata.totalTokenCount,
           }
         : undefined,
     }
   }
   static parseStreamingResponseChunk(
-    chunk: EnhancedGenerateContentResponse,
+    chunk: GenerateContentResponse,
     model: string,
     messageId: string,
   ): LLMResponseStreaming {
+    const firstCandidate = chunk.candidates?.[0]
+    const textContent = firstCandidate?.content?.parts?.[0]?.text || ''
     return {
       id: messageId,
       choices: [
         {
-          finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
+          finish_reason: firstCandidate?.finishReason ?? null,
           delta: {
-            content: chunk.text(),
+            content: textContent,
           },
         },
       ],