mirror of
https://github.com/EthanMarti/infio-copilot.git
synced 2026-05-13 19:11:02 +00:00
325 lines
8.2 KiB
TypeScript
325 lines
8.2 KiB
TypeScript
import {
|
|
Content,
|
|
GoogleGenAI,
|
|
Part,
|
|
type GenerateContentConfig,
|
|
type GenerateContentParameters,
|
|
type GenerateContentResponse,
|
|
} from "@google/genai"
|
|
|
|
import { LLMModel } from '../../types/llm/model'
|
|
import {
|
|
LLMOptions,
|
|
LLMRequestNonStreaming,
|
|
LLMRequestStreaming,
|
|
RequestMessage,
|
|
} from '../../types/llm/request'
|
|
import {
|
|
LLMResponseNonStreaming,
|
|
LLMResponseStreaming,
|
|
} from '../../types/llm/response'
|
|
import {
|
|
GeminiModelId,
|
|
ModelInfo,
|
|
geminiDefaultModelId,
|
|
geminiModels
|
|
} from "../../utils/api"
|
|
import { parseImageDataUrl } from '../../utils/image'
|
|
|
|
import { BaseLLMProvider } from './base'
|
|
import {
|
|
LLMAPIKeyInvalidException,
|
|
LLMAPIKeyNotSetException,
|
|
} from './exception'
|
|
|
|
/**
|
|
* Note on OpenAI Compatibility API:
|
|
* Gemini provides an OpenAI-compatible endpoint (https://ai.google.dev/gemini-api/docs/openai)
|
|
* which allows using the OpenAI SDK with Gemini models. However, there are currently CORS issues
|
|
* preventing its use in Obsidian. Consider switching to this endpoint in the future once these
|
|
* issues are resolved.
|
|
*/
|
|
export class GeminiProvider implements BaseLLMProvider {
|
|
private client: GoogleGenAI
|
|
private apiKey: string
|
|
private baseUrl: string
|
|
|
|
constructor(apiKey: string, baseUrl?: string) {
|
|
this.apiKey = apiKey
|
|
this.baseUrl = baseUrl
|
|
this.client = new GoogleGenAI({ apiKey })
|
|
}
|
|
|
|
getModel(modelId: string) {
|
|
let id = modelId
|
|
let info: ModelInfo = geminiModels[id as GeminiModelId]
|
|
|
|
if (id?.endsWith(":thinking")) {
|
|
id = id.slice(0, -":thinking".length)
|
|
|
|
if (geminiModels[id as GeminiModelId]) {
|
|
info = geminiModels[id as GeminiModelId]
|
|
|
|
return {
|
|
id,
|
|
info,
|
|
thinkingConfig: undefined,
|
|
maxOutputTokens: info.maxTokens ?? undefined,
|
|
}
|
|
}
|
|
}
|
|
|
|
if (!info) {
|
|
id = geminiDefaultModelId
|
|
info = geminiModels[geminiDefaultModelId]
|
|
}
|
|
|
|
return { id, info }
|
|
}
|
|
|
|
async generateResponse(
|
|
model: LLMModel,
|
|
request: LLMRequestNonStreaming,
|
|
options?: LLMOptions,
|
|
): Promise<LLMResponseNonStreaming> {
|
|
if (!this.apiKey) {
|
|
throw new LLMAPIKeyNotSetException(
|
|
`Gemini API key is missing. Please set it in settings menu.`,
|
|
)
|
|
}
|
|
|
|
const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
|
|
|
|
const systemMessages = request.messages.filter((m) => m.role === 'system')
|
|
const systemInstruction: string | undefined =
|
|
systemMessages.length > 0
|
|
? systemMessages.map((m) => m.content).join('\n')
|
|
: undefined
|
|
|
|
try {
|
|
|
|
const config: GenerateContentConfig = {
|
|
systemInstruction,
|
|
httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
|
|
thinkingConfig,
|
|
maxOutputTokens: maxOutputTokens ?? request.max_tokens,
|
|
temperature: request.temperature ?? 0,
|
|
topP: request.top_p ?? 1,
|
|
presencePenalty: request.presence_penalty ?? 0,
|
|
frequencyPenalty: request.frequency_penalty ?? 0,
|
|
}
|
|
const params: GenerateContentParameters = {
|
|
model: modelName,
|
|
contents: request.messages
|
|
.map((message) => GeminiProvider.parseRequestMessage(message))
|
|
.filter((m): m is Content => m !== null),
|
|
config,
|
|
}
|
|
|
|
const result = await this.client.models.generateContent(params)
|
|
const messageId = crypto.randomUUID() // Gemini does not return a message id
|
|
return GeminiProvider.parseNonStreamingResponse(
|
|
result,
|
|
request.model,
|
|
messageId,
|
|
)
|
|
} catch (error) {
|
|
const isInvalidApiKey =
|
|
error.message?.includes('API_KEY_INVALID') ||
|
|
error.message?.includes('API key not valid')
|
|
|
|
if (isInvalidApiKey) {
|
|
throw new LLMAPIKeyInvalidException(
|
|
`Gemini API key is invalid. Please update it in settings menu.`,
|
|
)
|
|
}
|
|
|
|
throw error
|
|
}
|
|
}
|
|
|
|
async streamResponse(
|
|
model: LLMModel,
|
|
request: LLMRequestStreaming,
|
|
options?: LLMOptions,
|
|
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
|
if (!this.apiKey) {
|
|
throw new LLMAPIKeyNotSetException(
|
|
`Gemini API key is missing. Please set it in settings menu.`,
|
|
)
|
|
}
|
|
const { id: modelName, thinkingConfig, maxOutputTokens, info } = this.getModel(model.modelId)
|
|
|
|
const systemMessages = request.messages.filter((m) => m.role === 'system')
|
|
const systemInstruction: string | undefined =
|
|
systemMessages.length > 0
|
|
? systemMessages.map((m) => m.content).join('\n')
|
|
: undefined
|
|
|
|
try {
|
|
const config: GenerateContentConfig = {
|
|
systemInstruction,
|
|
httpOptions: this.baseUrl ? { baseUrl: this.baseUrl } : undefined,
|
|
thinkingConfig,
|
|
maxOutputTokens: maxOutputTokens ?? request.max_tokens,
|
|
temperature: request.temperature ?? 0,
|
|
topP: request.top_p ?? 1,
|
|
presencePenalty: request.presence_penalty ?? 0,
|
|
frequencyPenalty: request.frequency_penalty ?? 0,
|
|
}
|
|
const params: GenerateContentParameters = {
|
|
model: modelName,
|
|
contents: request.messages
|
|
.map((message) => GeminiProvider.parseRequestMessage(message))
|
|
.filter((m): m is Content => m !== null),
|
|
config,
|
|
}
|
|
|
|
const stream = await this.client.models.generateContentStream(params)
|
|
const messageId = crypto.randomUUID() // Gemini does not return a message id
|
|
return this.streamResponseGenerator(stream, request.model, messageId)
|
|
} catch (error) {
|
|
const isInvalidApiKey =
|
|
error.message?.includes('API_KEY_INVALID') ||
|
|
error.message?.includes('API key not valid')
|
|
|
|
if (isInvalidApiKey) {
|
|
throw new LLMAPIKeyInvalidException(
|
|
`Gemini API key is invalid. Please update it in settings menu.`,
|
|
)
|
|
}
|
|
|
|
throw error
|
|
}
|
|
}
|
|
|
|
private async *streamResponseGenerator(
|
|
stream: AsyncGenerator<GenerateContentResponse>,
|
|
model: string,
|
|
messageId: string,
|
|
): AsyncIterable<LLMResponseStreaming> {
|
|
for await (const chunk of stream) {
|
|
yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
|
|
}
|
|
}
|
|
|
|
static parseRequestMessage(message: RequestMessage): Content | null {
|
|
if (message.role === 'system') {
|
|
return null
|
|
}
|
|
|
|
if (Array.isArray(message.content)) {
|
|
return {
|
|
role: message.role === 'user' ? 'user' : 'model',
|
|
parts: message.content.map((part) => {
|
|
switch (part.type) {
|
|
case 'text':
|
|
return { text: part.text }
|
|
case 'image_url': {
|
|
const { mimeType, base64Data } = parseImageDataUrl(
|
|
part.image_url.url,
|
|
)
|
|
GeminiProvider.validateImageType(mimeType)
|
|
|
|
return {
|
|
inlineData: {
|
|
data: base64Data,
|
|
mimeType,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}) as Part[],
|
|
}
|
|
}
|
|
|
|
return {
|
|
role: message.role === 'user' ? 'user' : 'model',
|
|
parts: [
|
|
{
|
|
text: message.content,
|
|
},
|
|
],
|
|
}
|
|
}
|
|
|
|
static parseNonStreamingResponse(
|
|
response: GenerateContentResponse,
|
|
model: string,
|
|
messageId: string,
|
|
): LLMResponseNonStreaming {
|
|
return {
|
|
id: messageId,
|
|
choices: [
|
|
{
|
|
finish_reason:
|
|
response.candidates?.[0]?.finishReason ?? null,
|
|
message: {
|
|
content: response.candidates?.[0]?.content?.parts?.[0]?.text ?? '',
|
|
role: 'assistant',
|
|
},
|
|
},
|
|
],
|
|
created: Date.now(),
|
|
model: model,
|
|
object: 'chat.completion',
|
|
usage: response.usageMetadata
|
|
? {
|
|
prompt_tokens: response.usageMetadata.promptTokenCount,
|
|
completion_tokens:
|
|
response.usageMetadata.candidatesTokenCount,
|
|
total_tokens: response.usageMetadata.totalTokenCount,
|
|
}
|
|
: undefined,
|
|
}
|
|
}
|
|
|
|
static parseStreamingResponseChunk(
|
|
chunk: GenerateContentResponse,
|
|
model: string,
|
|
messageId: string,
|
|
): LLMResponseStreaming {
|
|
const firstCandidate = chunk.candidates?.[0]
|
|
const textContent = firstCandidate?.content?.parts?.[0]?.text || ''
|
|
|
|
return {
|
|
id: messageId,
|
|
choices: [
|
|
{
|
|
finish_reason: firstCandidate?.finishReason ?? null,
|
|
delta: {
|
|
content: textContent,
|
|
},
|
|
},
|
|
],
|
|
created: Date.now(),
|
|
model: model,
|
|
object: 'chat.completion.chunk',
|
|
usage: chunk.usageMetadata
|
|
? {
|
|
prompt_tokens: chunk.usageMetadata.promptTokenCount,
|
|
completion_tokens: chunk.usageMetadata.candidatesTokenCount,
|
|
total_tokens: chunk.usageMetadata.totalTokenCount,
|
|
}
|
|
: undefined,
|
|
}
|
|
}
|
|
|
|
private static validateImageType(mimeType: string) {
|
|
const SUPPORTED_IMAGE_TYPES = [
|
|
'image/png',
|
|
'image/jpeg',
|
|
'image/webp',
|
|
'image/heic',
|
|
'image/heif',
|
|
]
|
|
if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
|
|
throw new Error(
|
|
`Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
|
|
', ',
|
|
)}`,
|
|
)
|
|
}
|
|
}
|
|
}
|