simplify model config

duanfuxiang
2025-02-17 13:06:22 +08:00
parent bf29a42baa
commit 025dc85c59
34 changed files with 12098 additions and 708 deletions
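
The diff below replaces the old CustomLLMModel type (which carried its own apiKey and baseUrl) with a slimmer LLMModel selected by provider, and moves all credential wiring into InfioSettings. The new definitions in types/llm/model.ts are not part of this commit's visible hunks, so the following is a hypothetical sketch inferred purely from the call sites below (model.provider, model.modelId, the model.apiKey guards, and the ApiProvider switch cases); the real file may differ:

// Hypothetical reconstruction of the new types, inferred from this diff.
export enum ApiProvider {
  OpenRouter = 'OpenRouter',
  SiliconFlow = 'SiliconFlow',
  AlibabaQwen = 'AlibabaQwen',
  Deepseek = 'Deepseek',
  OpenAI = 'OpenAI',
  Anthropic = 'Anthropic',
  Google = 'Google',
  Groq = 'Groq',
  Ollama = 'Ollama',
}

export type LLMModel = {
  provider: ApiProvider
  modelId: string
  // Several providers below guard on model.apiKey, so it plausibly
  // remains here as an optional field.
  apiKey?: string
}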

View File

@@ -2,7 +2,7 @@ import * as Handlebars from "handlebars";
import { Result, err, ok } from "neverthrow";
import { FewShotExample } from "../../settings/versions";
import { CustomLLMModel } from "../../types/llm/model";
import { LLMModel } from "../../types/llm/model";
import { RequestMessage } from '../../types/llm/request';
import { InfioSettings } from "../../types/settings";
import LLMManager from '../llm/manager';
@@ -25,9 +25,9 @@ import {
class LLMClient {
private llm: LLMManager;
private model: CustomLLMModel;
private model: LLMModel;
constructor(llm: LLMManager, model: CustomLLMModel) {
constructor(llm: LLMManager, model: LLMModel) {
this.llm = llm;
this.model = model;
}
@@ -100,17 +100,11 @@ class AutoComplete implements AutocompleteService {
postProcessors.push(new RemoveOverlap());
postProcessors.push(new RemoveWhitespace());
const llm_manager = new LLMManager({
deepseek: settings.deepseekApiKey,
openai: settings.openAIApiKey,
anthropic: settings.anthropicApiKey,
gemini: settings.geminiApiKey,
groq: settings.groqApiKey,
infio: settings.infioApiKey,
})
const model = settings.activeModels.find(
(option) => option.name === settings.chatModelId,
) as CustomLLMModel;
const llm_manager = new LLMManager(settings)
const model = {
provider: settings.applyModelProvider,
modelId: settings.applyModelId,
}
const llm = new LLMClient(llm_manager, model);
return new AutoComplete(
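
With this change, autocomplete no longer searches settings.activeModels by name; it builds the model descriptor directly from settings. A minimal sketch of the new wiring, using the LLMModel shape sketched above (the explicit type annotation is an assumption; the hunk itself uses a plain object literal):

// New call-site pattern, consolidated from the hunk above.
const llm_manager = new LLMManager(settings)
const model: LLMModel = {
  provider: settings.applyModelProvider,
  modelId: settings.applyModelId,
}
const llm = new LLMClient(llm_manager, model)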

View File

@@ -6,7 +6,7 @@ import {
TextBlockParam,
} from '@anthropic-ai/sdk/resources/messages'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
@@ -36,21 +36,14 @@ export class AnthropicProvider implements BaseLLMProvider {
}
async generateResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Anthropic API key is missing. Please set it in settings menu.',
)
}
this.client = new Anthropic({
baseURL: model.baseUrl,
apiKey: model.apiKey,
dangerouslyAllowBrowser: true
})
throw new LLMAPIKeyNotSetException(
'Anthropic API key is missing. Please set it in settings menu.',
)
}
const systemMessage = AnthropicProvider.validateSystemMessages(
@@ -89,21 +82,14 @@ export class AnthropicProvider implements BaseLLMProvider {
}
async streamResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Anthropic API key is missing. Please set it in settings menu.',
)
}
this.client = new Anthropic({
baseURL: model.baseUrl,
apiKey: model.apiKey,
dangerouslyAllowBrowser: true
})
throw new LLMAPIKeyNotSetException(
'Anthropic API key is missing. Please set it in settings menu.',
)
}
const systemMessage = AnthropicProvider.validateSystemMessages(

View File

@@ -1,4 +1,4 @@
import { CustomLLMModel } from '../../types/llm/model'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
@@ -11,12 +11,12 @@ import {
export type BaseLLMProvider = {
generateResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming>
streamResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>>

View File

@@ -7,7 +7,7 @@ import {
Part,
} from '@google/generative-ai'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
@@ -43,18 +43,14 @@ export class GeminiProvider implements BaseLLMProvider {
}
async generateResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
`Gemini API key is missing. Please set it in settings menu.`,
)
}
this.apiKey = model.apiKey
this.client = new GoogleGenerativeAI(model.apiKey)
throw new LLMAPIKeyNotSetException(
`Gemini API key is missing. Please set it in settings menu.`,
)
}
const systemMessages = request.messages.filter((m) => m.role === 'system')
@@ -110,18 +106,14 @@ export class GeminiProvider implements BaseLLMProvider {
}
async streamResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
`Gemini API key is missing. Please set it in settings menu.`,
)
}
this.apiKey = model.apiKey
this.client = new GoogleGenerativeAI(model.apiKey)
throw new LLMAPIKeyNotSetException(
`Gemini API key is missing. Please set it in settings menu.`,
)
}
const systemMessages = request.messages.filter((m) => m.role === 'system')

View File

@@ -6,7 +6,7 @@ import {
ChatCompletionMessageParam,
} from 'groq-sdk/resources/chat/completions'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
@@ -35,20 +35,14 @@ export class GroqProvider implements BaseLLMProvider {
}
async generateResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
this.client = new Groq({
apiKey: model.apiKey,
dangerouslyAllowBrowser: true,
})
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
try {
@@ -78,20 +72,14 @@ export class GroqProvider implements BaseLLMProvider {
}
async streamResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
this.client = new Groq({
apiKey: model.apiKey,
dangerouslyAllowBrowser: true,
})
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
try {

View File

@@ -4,12 +4,12 @@ import {
ChatCompletionChunk,
} from 'openai/resources/chat/completions'
import { CustomLLMModel } from '../../types/llm/model'
import { INFIO_BASE_URL } from '../../constants'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
RequestMessage
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
@@ -85,13 +85,13 @@ export class InfioProvider implements BaseLLMProvider {
// this.client = new OpenAI({ apiKey, dangerouslyAllowBrowser: true })
// this.adapter = new OpenAIMessageAdapter()
this.apiKey = apiKey
this.baseUrl = 'https://api.infio.com/api/raw_message'
this.baseUrl = INFIO_BASE_URL
}
async generateResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
// options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.apiKey) {
throw new LLMAPIKeyNotSetException(
@@ -107,7 +107,7 @@ export class InfioProvider implements BaseLLMProvider {
presence_penalty: request.presence_penalty,
max_tokens: request.max_tokens,
}
const options = {
const req_options = {
method: 'POST',
headers: {
Authorization: this.apiKey,
@@ -117,7 +117,7 @@ export class InfioProvider implements BaseLLMProvider {
body: JSON.stringify(req)
};
const response = await fetch(this.baseUrl, options);
const response = await fetch(this.baseUrl, req_options);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
@@ -134,9 +134,8 @@ export class InfioProvider implements BaseLLMProvider {
}
async streamResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.apiKey) {
throw new LLMAPIKeyNotSetException(
@@ -154,7 +153,7 @@ export class InfioProvider implements BaseLLMProvider {
presence_penalty: request.presence_penalty,
max_tokens: request.max_tokens,
}
const options = {
const req_options = {
method: 'POST',
headers: {
Authorization: this.apiKey,
@@ -164,7 +163,7 @@ export class InfioProvider implements BaseLLMProvider {
body: JSON.stringify(req)
};
const response = await fetch(this.baseUrl, options);
const response = await fetch(this.baseUrl, req_options);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}

View File

@@ -1,14 +1,15 @@
import { DEEPSEEK_BASE_URL } from '../../constants'
import { CustomLLMModel } from '../../types/llm/model'
import { ALIBABA_QWEN_BASE_URL, DEEPSEEK_BASE_URL, OPENROUTER_BASE_URL, SILICONFLOW_BASE_URL } from '../../constants'
import { ApiProvider, LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { InfioSettings } from '../../types/settings'
import { AnthropicProvider } from './anthropic'
import { GeminiProvider } from './gemini'
@@ -20,123 +21,147 @@ import { OpenAICompatibleProvider } from './openai-compatible-provider'
export type LLMManagerInterface = {
generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming>
streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>>
generateResponse(
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming>
streamResponse(
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>>
}
class LLMManager implements LLMManagerInterface {
private openaiProvider: OpenAIAuthenticatedProvider
private deepseekProvider: OpenAICompatibleProvider
private anthropicProvider: AnthropicProvider
private googleProvider: GeminiProvider
private groqProvider: GroqProvider
private infioProvider: InfioProvider
private ollamaProvider: OllamaProvider
private isInfioEnabled: boolean
private openaiProvider: OpenAIAuthenticatedProvider
private deepseekProvider: OpenAICompatibleProvider
private anthropicProvider: AnthropicProvider
private googleProvider: GeminiProvider
private groqProvider: GroqProvider
private infioProvider: InfioProvider
private openrouterProvider: OpenAICompatibleProvider
private siliconflowProvider: OpenAICompatibleProvider
private alibabaQwenProvider: OpenAICompatibleProvider
private ollamaProvider: OllamaProvider
private isInfioEnabled: boolean
constructor(apiKeys: {
deepseek?: string
openai?: string
anthropic?: string
gemini?: string
groq?: string
infio?: string
}) {
this.deepseekProvider = new OpenAICompatibleProvider(apiKeys.deepseek ?? '', DEEPSEEK_BASE_URL)
this.openaiProvider = new OpenAIAuthenticatedProvider(apiKeys.openai ?? '')
this.anthropicProvider = new AnthropicProvider(apiKeys.anthropic ?? '')
this.googleProvider = new GeminiProvider(apiKeys.gemini ?? '')
this.groqProvider = new GroqProvider(apiKeys.groq ?? '')
this.infioProvider = new InfioProvider(apiKeys.infio ?? '')
this.ollamaProvider = new OllamaProvider()
this.isInfioEnabled = !!apiKeys.infio
}
constructor(settings: InfioSettings) {
this.infioProvider = new InfioProvider(settings.infioProvider.apiKey)
this.openrouterProvider = new OpenAICompatibleProvider(settings.openrouterProvider.apiKey, OPENROUTER_BASE_URL)
this.siliconflowProvider = new OpenAICompatibleProvider(settings.siliconflowProvider.apiKey, SILICONFLOW_BASE_URL)
this.alibabaQwenProvider = new OpenAICompatibleProvider(settings.alibabaQwenProvider.apiKey, ALIBABA_QWEN_BASE_URL)
this.deepseekProvider = new OpenAICompatibleProvider(settings.deepseekProvider.apiKey, DEEPSEEK_BASE_URL)
this.openaiProvider = new OpenAIAuthenticatedProvider(settings.openaiProvider.apiKey)
this.anthropicProvider = new AnthropicProvider(settings.anthropicProvider.apiKey)
this.googleProvider = new GeminiProvider(settings.googleProvider.apiKey)
this.groqProvider = new GroqProvider(settings.groqProvider.apiKey)
this.ollamaProvider = new OllamaProvider(settings.ollamaProvider.baseUrl)
this.isInfioEnabled = !!settings.infioProvider.apiKey
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (this.isInfioEnabled) {
return await this.infioProvider.generateResponse(
model,
request,
options,
)
}
// use custom provider
switch (model.provider) {
case 'deepseek':
return await this.deepseekProvider.generateResponse(
model,
request,
options,
)
case 'openai':
return await this.openaiProvider.generateResponse(
model,
request,
options,
)
case 'anthropic':
return await this.anthropicProvider.generateResponse(
model,
request,
options,
)
case 'google':
return await this.googleProvider.generateResponse(
model,
request,
options,
)
case 'groq':
return await this.groqProvider.generateResponse(model, request, options)
case 'ollama':
return await this.ollamaProvider.generateResponse(
model,
request,
options,
)
}
}
async generateResponse(
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (this.isInfioEnabled) {
return await this.infioProvider.generateResponse(
model,
request,
)
}
// use custom provider
switch (model.provider) {
case ApiProvider.OpenRouter:
return await this.openrouterProvider.generateResponse(
model,
request,
options,
)
case ApiProvider.SiliconFlow:
return await this.siliconflowProvider.generateResponse(
model,
request,
options,
)
case ApiProvider.AlibabaQwen:
return await this.alibabaQwenProvider.generateResponse(
model,
request,
options,
)
case ApiProvider.Deepseek:
return await this.deepseekProvider.generateResponse(
model,
request,
options,
)
case ApiProvider.OpenAI:
return await this.openaiProvider.generateResponse(
model,
request,
options,
)
case ApiProvider.Anthropic:
return await this.anthropicProvider.generateResponse(
model,
request,
options,
)
case ApiProvider.Google:
return await this.googleProvider.generateResponse(
model,
request,
options,
)
case ApiProvider.Groq:
return await this.groqProvider.generateResponse(model, request, options)
case ApiProvider.Ollama:
return await this.ollamaProvider.generateResponse(
model,
request,
options,
)
default:
throw new Error(`Unsupported model provider: ${model.provider}`)
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (this.isInfioEnabled) {
return await this.infioProvider.streamResponse(model, request, options)
}
// use custom provider
switch (model.provider) {
case 'deepseek':
return await this.deepseekProvider.streamResponse(model, request, options)
case 'openai':
return await this.openaiProvider.streamResponse(model, request, options)
case 'anthropic':
return await this.anthropicProvider.streamResponse(
model,
request,
options,
)
case 'google':
return await this.googleProvider.streamResponse(model, request, options)
case 'groq':
return await this.groqProvider.streamResponse(model, request, options)
case 'ollama':
return await this.ollamaProvider.streamResponse(model, request, options)
}
}
async streamResponse(
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (this.isInfioEnabled) {
return await this.infioProvider.streamResponse(model, request)
}
// use custom provider
switch (model.provider) {
case ApiProvider.OpenRouter:
return await this.openrouterProvider.streamResponse(model, request, options)
case ApiProvider.SiliconFlow:
return await this.siliconflowProvider.streamResponse(model, request, options)
case ApiProvider.AlibabaQwen:
return await this.alibabaQwenProvider.streamResponse(model, request, options)
case ApiProvider.Deepseek:
return await this.deepseekProvider.streamResponse(model, request, options)
case ApiProvider.OpenAI:
return await this.openaiProvider.streamResponse(model, request, options)
case ApiProvider.Anthropic:
return await this.anthropicProvider.streamResponse(
model,
request,
options,
)
case ApiProvider.Google:
return await this.googleProvider.streamResponse(model, request, options)
case ApiProvider.Groq:
return await this.groqProvider.streamResponse(model, request, options)
case ApiProvider.Ollama:
return await this.ollamaProvider.streamResponse(model, request, options)
}
}
}
export default LLMManager
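
Taken together, the manager is now built from the whole settings object and dispatches on model.provider, with the Infio provider short-circuiting everything when an Infio API key is set. A hedged usage sketch; the request fields are assumptions based on what this diff reads elsewhere (request.messages, message role/content), since the exact LLMRequestNonStreaming shape is not shown:

const manager = new LLMManager(settings)
const model: LLMModel = {
  provider: ApiProvider.Deepseek,
  modelId: 'deepseek-chat', // example id, not taken from this diff
}

// Routed to deepseekProvider, unless isInfioEnabled is true, in which
// case infioProvider handles the call regardless of model.provider.
const response = await manager.generateResponse(model, {
  messages: [{ role: 'user', content: 'Summarize this note.' }],
})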

View File

@@ -7,7 +7,7 @@
import OpenAI from 'openai'
import { FinalRequestOptions } from 'openai/core'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
@@ -19,7 +19,7 @@ import {
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException, LLMModelNotSetException } from './exception'
import { LLMBaseUrlNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
export class NoStainlessOpenAI extends OpenAI {
@@ -35,7 +35,7 @@ export class NoStainlessOpenAI extends OpenAI {
{ retryCount = 0 }: { retryCount?: number } = {},
): { req: RequestInit; url: string; timeout: number } {
const req = super.buildRequest(options, { retryCount })
const headers = req.req.headers as Record<string, string>
const headers: Record<string, string> = req.req.headers
Object.keys(headers).forEach((k) => {
if (k.startsWith('x-stainless')) {
// eslint-disable-next-line @typescript-eslint/no-dynamic-delete
@@ -48,30 +48,26 @@ export class NoStainlessOpenAI extends OpenAI {
export class OllamaProvider implements BaseLLMProvider {
private adapter: OpenAIMessageAdapter
private baseUrl: string
constructor() {
constructor(baseUrl: string) {
this.adapter = new OpenAIMessageAdapter()
this.baseUrl = baseUrl
}
async generateResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!model.baseUrl) {
if (!this.baseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama base URL is missing. Please set it in settings menu.',
)
}
if (!model.name) {
throw new LLMModelNotSetException(
'Ollama model is missing. Please set it in settings menu.',
)
}
const client = new NoStainlessOpenAI({
baseURL: `${model.baseUrl}/v1`,
baseURL: `${this.baseUrl}/v1`,
apiKey: '',
dangerouslyAllowBrowser: true,
})
@@ -79,24 +75,18 @@ export class OllamaProvider implements BaseLLMProvider {
}
async streamResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!model.baseUrl) {
if (!this.baseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama base URL is missing. Please set it in settings menu.',
)
}
if (!model.name) {
throw new LLMModelNotSetException(
'Ollama model is missing. Please set it in settings menu.',
)
}
const client = new NoStainlessOpenAI({
baseURL: `${model.baseUrl}/v1`,
baseURL: `${this.baseUrl}/v1`,
apiKey: '',
dangerouslyAllowBrowser: true,
})
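
OllamaProvider now receives its base URL once, at construction time (from the manager's constructor above), instead of reading model.baseUrl on every call, and the per-call model.name guard is gone. A one-line sketch of the new construction:

// Base URL comes from settings, as in LLMManager's constructor.
const ollama = new OllamaProvider(settings.ollamaProvider.baseUrl)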

View File

@@ -1,6 +1,6 @@
import OpenAI from 'openai'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
@@ -33,7 +33,7 @@ export class OpenAICompatibleProvider implements BaseLLMProvider {
}
async generateResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
@@ -47,7 +47,7 @@ export class OpenAICompatibleProvider implements BaseLLMProvider {
}
async streamResponse(
model: CustomLLMModel,
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {

View File

@@ -1,91 +1,77 @@
import OpenAI from 'openai'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
export class OpenAIAuthenticatedProvider implements BaseLLMProvider {
private adapter: OpenAIMessageAdapter
private client: OpenAI
private adapter: OpenAIMessageAdapter
private client: OpenAI
constructor(apiKey: string) {
this.client = new OpenAI({
apiKey,
dangerouslyAllowBrowser: true,
})
this.adapter = new OpenAIMessageAdapter()
}
constructor(apiKey: string) {
this.client = new OpenAI({
apiKey,
dangerouslyAllowBrowser: true,
})
this.adapter = new OpenAIMessageAdapter()
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
if (!model.baseUrl) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
this.client = new OpenAI({
apiKey: model.apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
})
}
try {
return this.adapter.generateResponse(this.client, request, options)
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async generateResponse(
model: LLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
try {
return this.adapter.generateResponse(this.client, request, options)
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
if (!model.baseUrl) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
this.client = new OpenAI({
apiKey: model.apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
})
}
async streamResponse(
model: LLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
try {
return this.adapter.streamResponse(this.client, request, options)
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
try {
return this.adapter.streamResponse(this.client, request, options)
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
}

View File

@@ -1,7 +1,11 @@
import { GoogleGenerativeAI } from '@google/generative-ai'
import { OpenAI } from 'openai'
import { ALIBABA_QWEN_BASE_URL, OPENAI_BASE_URL, SILICONFLOW_BASE_URL } from "../../constants"
import { EmbeddingModel } from '../../types/embedding'
import { ApiProvider } from '../../types/llm/model'
import { InfioSettings } from '../../types/settings'
import { GetEmbeddingModelInfo } from '../../utils/api'
import {
LLMAPIKeyNotSetException,
LLMBaseUrlNotSetException,
@@ -10,22 +14,20 @@ import {
import { NoStainlessOpenAI } from '../llm/ollama'
export const getEmbeddingModel = (
embeddingModelId: string,
apiKeys: {
openAIApiKey: string
geminiApiKey: string
},
ollamaBaseUrl: string,
settings: InfioSettings,
): EmbeddingModel => {
switch (embeddingModelId) {
case 'text-embedding-3-small': {
switch (settings.embeddingModelProvider) {
case ApiProvider.OpenAI: {
const baseURL = settings.openaiProvider.useCustomUrl ? settings.openaiProvider.baseUrl : OPENAI_BASE_URL
const openai = new OpenAI({
apiKey: apiKeys.openAIApiKey,
apiKey: settings.openaiProvider.apiKey,
baseURL: baseURL,
dangerouslyAllowBrowser: true,
})
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
return {
id: 'text-embedding-3-small',
dimension: 1536,
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@@ -34,7 +36,7 @@ export const getEmbeddingModel = (
)
}
const embedding = await openai.embeddings.create({
model: 'text-embedding-3-small',
model: settings.embeddingModelId,
input: text,
})
return embedding.data[0].embedding
@@ -52,12 +54,87 @@ export const getEmbeddingModel = (
},
}
}
case 'text-embedding-004': {
const client = new GoogleGenerativeAI(apiKeys.geminiApiKey)
const model = client.getGenerativeModel({ model: 'text-embedding-004' })
case ApiProvider.SiliconFlow: {
const baseURL = settings.siliconflowProvider.useCustomUrl ? settings.siliconflowProvider.baseUrl : SILICONFLOW_BASE_URL
const openai = new OpenAI({
apiKey: settings.siliconflowProvider.apiKey,
baseURL: baseURL,
dangerouslyAllowBrowser: true,
})
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
return {
id: 'text-embedding-004',
dimension: 768,
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'SiliconFlow API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: text,
})
return embedding.data[0].embedding
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'SiliconFlow API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case ApiProvider.AlibabaQwen: {
const baseURL = settings.alibabaQwenProvider.useCustomUrl ? settings.alibabaQwenProvider.baseUrl : ALIBABA_QWEN_BASE_URL
const openai = new OpenAI({
apiKey: settings.alibabaQwenProvider.apiKey,
baseURL: baseURL,
dangerouslyAllowBrowser: true,
})
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'Alibaba Qwen API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: text,
})
return embedding.data[0].embedding
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'Alibaba Qwen API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case ApiProvider.Google: {
const client = new GoogleGenerativeAI(settings.googleProvider.apiKey)
const model = client.getGenerativeModel({ model: settings.embeddingModelId })
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
getEmbedding: async (text: string) => {
try {
const response = await model.embedContent(text)
@@ -76,69 +153,24 @@ export const getEmbeddingModel = (
},
}
}
case 'nomic-embed-text': {
case ApiProvider.Ollama: {
const openai = new NoStainlessOpenAI({
apiKey: '',
apiKey: settings.ollamaProvider.apiKey,
dangerouslyAllowBrowser: true,
baseURL: `${ollamaBaseUrl}/v1`,
baseURL: `${settings.ollamaProvider.baseUrl}/v1`,
})
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
return {
id: 'nomic-embed-text',
dimension: 768,
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
getEmbedding: async (text: string) => {
if (!ollamaBaseUrl) {
if (!settings.ollamaProvider.baseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama Address is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: 'nomic-embed-text',
input: text,
})
return embedding.data[0].embedding
},
}
}
case 'mxbai-embed-large': {
const openai = new NoStainlessOpenAI({
apiKey: '',
dangerouslyAllowBrowser: true,
baseURL: `${ollamaBaseUrl}/v1`,
})
return {
id: 'mxbai-embed-large',
dimension: 1024,
getEmbedding: async (text: string) => {
if (!ollamaBaseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama Address is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: 'mxbai-embed-large',
input: text,
})
return embedding.data[0].embedding
},
}
}
case 'bge-m3': {
const openai = new NoStainlessOpenAI({
apiKey: '',
dangerouslyAllowBrowser: true,
baseURL: `${ollamaBaseUrl}/v1`,
})
return {
id: 'bge-m3',
dimension: 1024,
getEmbedding: async (text: string) => {
if (!ollamaBaseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama Address is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: 'bge-m3',
model: settings.embeddingModelId,
input: text,
})
return embedding.data[0].embedding
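
The embedding factory now switches on settings.embeddingModelProvider rather than on hard-coded model ids, and looks the vector dimension up via GetEmbeddingModelInfo instead of inlining it per model. A minimal consumption sketch, assuming only the EmbeddingModel fields visible in this diff (id, dimension, getEmbedding):

const embeddingModel = getEmbeddingModel(settings)
// getEmbedding resolves to the raw embedding vector for the input text.
const vector = await embeddingModel.getEmbedding('some note text')
console.log(embeddingModel.id, embeddingModel.dimension, vector.length)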

View File

@@ -23,26 +23,12 @@ export class RAGEngine {
this.app = app
this.settings = settings
this.vectorManager = dbManager.getVectorManager()
this.embeddingModel = getEmbeddingModel(
settings.embeddingModelId,
{
openAIApiKey: settings.openAIApiKey,
geminiApiKey: settings.geminiApiKey,
},
settings.ollamaEmbeddingModel.baseUrl,
)
this.embeddingModel = getEmbeddingModel(settings)
}
setSettings(settings: InfioSettings) {
this.settings = settings
this.embeddingModel = getEmbeddingModel(
settings.embeddingModelId,
{
openAIApiKey: settings.openAIApiKey,
geminiApiKey: settings.geminiApiKey,
},
settings.ollamaEmbeddingModel.baseUrl,
)
this.embeddingModel = getEmbeddingModel(settings)
}
// TODO: Implement automatic vault re-indexing when settings are changed.