Mirror of https://github.com/EthanMarti/infio-copilot.git (synced 2026-05-16 04:31:38 +00:00)
update LLM models
@@ -1,11 +1,11 @@
 import {
 	Content,
-	EnhancedGenerateContentResponse,
-	GenerateContentResult,
-	GenerateContentStreamResult,
-	GoogleGenerativeAI,
+	GoogleGenAI,
 	Part,
-} from '@google/generative-ai'
+	type GenerateContentConfig,
+	type GenerateContentParameters,
+	type GenerateContentResponse,
+} from "@google/genai"
 
 import { LLMModel } from '../../types/llm/model'
 import {
@@ -18,6 +18,12 @@ import {
 	LLMResponseNonStreaming,
 	LLMResponseStreaming,
 } from '../../types/llm/response'
+import {
+	GeminiModelId,
+	ModelInfo,
+	geminiDefaultModelId,
+	geminiModels
+} from "../../utils/api"
 import { parseImageDataUrl } from '../../utils/image'
 
 import { BaseLLMProvider } from './base'
@@ -34,12 +40,41 @@ import {
  * issues are resolved.
  */
 export class GeminiProvider implements BaseLLMProvider {
-	private client: GoogleGenerativeAI
+	private client: GoogleGenAI
 	private apiKey: string
+	private baseUrl: string
 
-	constructor(apiKey: string) {
+	constructor(apiKey: string, baseUrl?: string) {
 		this.apiKey = apiKey
-		this.client = new GoogleGenerativeAI(apiKey)
+		this.baseUrl = baseUrl
+		this.client = new GoogleGenAI({ apiKey })
 	}
 
+	getModel(modelId: string) {
+		let id = modelId
+		let info: ModelInfo = geminiModels[id as GeminiModelId]
+
+		if (id?.endsWith(":thinking")) {
+			id = id.slice(0, -":thinking".length)
+
+			if (geminiModels[id as GeminiModelId]) {
+				info = geminiModels[id as GeminiModelId]
+
+				return {
+					id,
+					info,
+					thinkingConfig: undefined,
+					maxOutputTokens: info.maxTokens ?? undefined,
+				}
+			}
+		}
+
+		if (!info) {
+			id = geminiDefaultModelId
+			info = geminiModels[geminiDefaultModelId]
+		}
+
+		return { id, info }
+	}
+
 	async generateResponse(
@@ -53,6 +88,8 @@ export class GeminiProvider implements BaseLLMProvider {
 			)
 		}
 
+		const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
+
 		const systemMessages = request.messages.filter((m) => m.role === 'system')
 		const systemInstruction: string | undefined =
 			systemMessages.length > 0
@@ -60,30 +97,26 @@ export class GeminiProvider implements BaseLLMProvider {
 				: undefined
 
 		try {
-			const model = this.client.getGenerativeModel({
-				model: request.model,
-				generationConfig: {
-					maxOutputTokens: request.max_tokens,
-					temperature: request.temperature,
-					topP: request.top_p,
-					presencePenalty: request.presence_penalty,
-					frequencyPenalty: request.frequency_penalty,
-				},
-				systemInstruction: systemInstruction,
-			})
-
-			const result = await model.generateContent(
-				{
-					systemInstruction: systemInstruction,
-					contents: request.messages
-						.map((message) => GeminiProvider.parseRequestMessage(message))
-						.filter((m): m is Content => m !== null),
-				},
-				{
-					signal: options?.signal,
-				},
-			)
+			const config: GenerateContentConfig = {
+				systemInstruction,
+				httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
+				thinkingConfig,
+				maxOutputTokens: maxOutputTokens ?? request.max_tokens,
+				temperature: request.temperature ?? 0,
+				topP: request.top_p ?? 1,
+				presencePenalty: request.presence_penalty ?? 0,
+				frequencyPenalty: request.frequency_penalty ?? 0,
+			}
+			const params: GenerateContentParameters = {
+				model: modelName,
+				contents: request.messages
+					.map((message) => GeminiProvider.parseRequestMessage(message))
+					.filter((m): m is Content => m !== null),
+				config,
+			}
+
+			const result = await this.client.models.generateContent(params)
 			const messageId = crypto.randomUUID() // Gemini does not return a message id
 			return GeminiProvider.parseNonStreamingResponse(
 				result,
@@ -115,6 +148,7 @@ export class GeminiProvider implements BaseLLMProvider {
 				`Gemini API key is missing. Please set it in settings menu.`,
 			)
 		}
+		const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
 
 		const systemMessages = request.messages.filter((m) => m.role === 'system')
 		const systemInstruction: string | undefined =
@@ -123,30 +157,25 @@ export class GeminiProvider implements BaseLLMProvider {
 				: undefined
 
 		try {
-			const model = this.client.getGenerativeModel({
-				model: request.model,
-				generationConfig: {
-					maxOutputTokens: request.max_tokens,
-					temperature: request.temperature,
-					topP: request.top_p,
-					presencePenalty: request.presence_penalty,
-					frequencyPenalty: request.frequency_penalty,
-				},
-				systemInstruction: systemInstruction,
-			})
-
-			const stream = await model.generateContentStream(
-				{
-					systemInstruction: systemInstruction,
-					contents: request.messages
-						.map((message) => GeminiProvider.parseRequestMessage(message))
-						.filter((m): m is Content => m !== null),
-				},
-				{
-					signal: options?.signal,
-				},
-			)
+			const config: GenerateContentConfig = {
+				systemInstruction,
+				httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
+				thinkingConfig,
+				maxOutputTokens: maxOutputTokens ?? request.max_tokens,
+				temperature: request.temperature ?? 0,
+				topP: request.top_p ?? 1,
+				presencePenalty: request.presence_penalty ?? 0,
+				frequencyPenalty: request.frequency_penalty ?? 0,
+			}
+			const params: GenerateContentParameters = {
+				model: modelName,
+				contents: request.messages
+					.map((message) => GeminiProvider.parseRequestMessage(message))
+					.filter((m): m is Content => m !== null),
+				config,
+			}
+
+			const stream = await this.client.models.generateContentStream(params)
 			const messageId = crypto.randomUUID() // Gemini does not return a message id
 			return this.streamResponseGenerator(stream, request.model, messageId)
 		} catch (error) {
@@ -165,11 +194,11 @@ export class GeminiProvider implements BaseLLMProvider {
 	}
 
 	private async *streamResponseGenerator(
-		stream: GenerateContentStreamResult,
+		stream: AsyncGenerator<GenerateContentResponse>,
 		model: string,
 		messageId: string,
 	): AsyncIterable<LLMResponseStreaming> {
-		for await (const chunk of stream.stream) {
+		for await (const chunk of stream) {
 			yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
 		}
 	}
@@ -215,7 +244,7 @@ export class GeminiProvider implements BaseLLMProvider {
 	}
 
 	static parseNonStreamingResponse(
-		response: GenerateContentResult,
+		response: GenerateContentResponse,
 		model: string,
 		messageId: string,
 	): LLMResponseNonStreaming {
@@ -224,9 +253,9 @@ export class GeminiProvider implements BaseLLMProvider {
 			choices: [
 				{
 					finish_reason:
-						response.response.candidates?.[0]?.finishReason ?? null,
+						response.candidates?.[0]?.finishReason ?? null,
 					message: {
-						content: response.response.text(),
+						content: response.candidates?.[0]?.content?.parts?.[0]?.text ?? '',
 						role: 'assistant',
 					},
 				},
@@ -234,29 +263,32 @@ export class GeminiProvider implements BaseLLMProvider {
 			created: Date.now(),
 			model: model,
 			object: 'chat.completion',
-			usage: response.response.usageMetadata
+			usage: response.usageMetadata
 				? {
-					prompt_tokens: response.response.usageMetadata.promptTokenCount,
+					prompt_tokens: response.usageMetadata.promptTokenCount,
 					completion_tokens:
-						response.response.usageMetadata.candidatesTokenCount,
-					total_tokens: response.response.usageMetadata.totalTokenCount,
+						response.usageMetadata.candidatesTokenCount,
+					total_tokens: response.usageMetadata.totalTokenCount,
 				}
 				: undefined,
 		}
 	}
 
 	static parseStreamingResponseChunk(
-		chunk: EnhancedGenerateContentResponse,
+		chunk: GenerateContentResponse,
 		model: string,
 		messageId: string,
 	): LLMResponseStreaming {
+		const firstCandidate = chunk.candidates?.[0]
+		const textContent = firstCandidate?.content?.parts?.[0]?.text || ''
+
 		return {
 			id: messageId,
 			choices: [
 				{
-					finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
+					finish_reason: firstCandidate?.finishReason ?? null,
 					delta: {
-						content: chunk.text(),
+						content: textContent,
 					},
 				},
 			],
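For reference, a minimal usage sketch of the @google/genai call pattern this diff migrates to. It is not part of the commit: the model id, prompt text, and environment-variable lookup are placeholder assumptions; the client construction, the params/config shapes, the candidates-based text access, and the async iteration over the stream mirror the calls introduced above.

import { GoogleGenAI, type GenerateContentParameters } from "@google/genai"

// Hypothetical standalone sketch of the new SDK's call pattern (not from the commit).
const ai = new GoogleGenAI({ apiKey: process.env.GEMINI_API_KEY ?? "" })

const params: GenerateContentParameters = {
	model: "gemini-2.0-flash", // placeholder model id, not taken from the diff
	contents: [{ role: "user", parts: [{ text: "Hello" }] }],
	config: { temperature: 0, maxOutputTokens: 1024 },
}

// Non-streaming: the response exposes candidates directly (no .response wrapper as in the old SDK).
const result = await ai.models.generateContent(params)
console.log(result.candidates?.[0]?.content?.parts?.[0]?.text ?? "")

// Streaming: generateContentStream resolves to an AsyncGenerator<GenerateContentResponse>.
for await (const chunk of await ai.models.generateContentStream(params)) {
	process.stdout.write(chunk.candidates?.[0]?.content?.parts?.[0]?.text ?? "")
}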