simple model config
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import { GoogleGenerativeAI } from '@google/generative-ai'
|
||||
import { OpenAI } from 'openai'
|
||||
|
||||
import { ALIBABA_QWEN_BASE_URL, OPENAI_BASE_URL, SILICONFLOW_BASE_URL } from "../../constants"
|
||||
import { EmbeddingModel } from '../../types/embedding'
|
||||
import { ApiProvider } from '../../types/llm/model'
|
||||
import { InfioSettings } from '../../types/settings'
|
||||
import { GetEmbeddingModelInfo } from '../../utils/api'
|
||||
import {
|
||||
LLMAPIKeyNotSetException,
|
||||
LLMBaseUrlNotSetException,
|
||||
@@ -10,22 +14,20 @@ import {
|
||||
import { NoStainlessOpenAI } from '../llm/ollama'
|
||||
|
||||
export const getEmbeddingModel = (
|
||||
embeddingModelId: string,
|
||||
apiKeys: {
|
||||
openAIApiKey: string
|
||||
geminiApiKey: string
|
||||
},
|
||||
ollamaBaseUrl: string,
|
||||
settings: InfioSettings,
|
||||
): EmbeddingModel => {
|
||||
switch (embeddingModelId) {
|
||||
case 'text-embedding-3-small': {
|
||||
switch (settings.embeddingModelProvider) {
|
||||
case ApiProvider.OpenAI: {
|
||||
const baseURL = settings.openaiProvider.useCustomUrl ? settings.openaiProvider.baseUrl : OPENAI_BASE_URL
|
||||
const openai = new OpenAI({
|
||||
apiKey: apiKeys.openAIApiKey,
|
||||
apiKey: settings.openaiProvider.apiKey,
|
||||
baseURL: baseURL,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
||||
return {
|
||||
id: 'text-embedding-3-small',
|
||||
dimension: 1536,
|
||||
id: settings.embeddingModelId,
|
||||
dimension: modelInfo.dimensions,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
if (!openai.apiKey) {
|
||||
@@ -34,7 +36,7 @@ export const getEmbeddingModel = (
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'text-embedding-3-small',
|
||||
model: settings.embeddingModelId,
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
@@ -52,12 +54,87 @@ export const getEmbeddingModel = (
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'text-embedding-004': {
|
||||
const client = new GoogleGenerativeAI(apiKeys.geminiApiKey)
|
||||
const model = client.getGenerativeModel({ model: 'text-embedding-004' })
|
||||
case ApiProvider.SiliconFlow: {
|
||||
const baseURL = settings.siliconflowProvider.useCustomUrl ? settings.siliconflowProvider.baseUrl : SILICONFLOW_BASE_URL
|
||||
const openai = new OpenAI({
|
||||
apiKey: settings.siliconflowProvider.apiKey,
|
||||
baseURL: baseURL,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
||||
return {
|
||||
id: 'text-embedding-004',
|
||||
dimension: 768,
|
||||
id: settings.embeddingModelId,
|
||||
dimension: modelInfo.dimensions,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
if (!openai.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'SiliconFlow API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: settings.embeddingModelId,
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
} catch (error) {
|
||||
if (
|
||||
error.status === 429 &&
|
||||
error.message.toLowerCase().includes('rate limit')
|
||||
) {
|
||||
throw new LLMRateLimitExceededException(
|
||||
'SiliconFlow API rate limit exceeded. Please try again later.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
case ApiProvider.AlibabaQwen: {
|
||||
const baseURL = settings.alibabaQwenProvider.useCustomUrl ? settings.alibabaQwenProvider.baseUrl : ALIBABA_QWEN_BASE_URL
|
||||
const openai = new OpenAI({
|
||||
apiKey: settings.alibabaQwenProvider.apiKey,
|
||||
baseURL: baseURL,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
||||
return {
|
||||
id: settings.embeddingModelId,
|
||||
dimension: modelInfo.dimensions,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
if (!openai.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'Alibaba Qwen API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: settings.embeddingModelId,
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
} catch (error) {
|
||||
if (
|
||||
error.status === 429 &&
|
||||
error.message.toLowerCase().includes('rate limit')
|
||||
) {
|
||||
throw new LLMRateLimitExceededException(
|
||||
'Alibaba Qwen API rate limit exceeded. Please try again later.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
case ApiProvider.Google: {
|
||||
const client = new GoogleGenerativeAI(settings.googleProvider.apiKey)
|
||||
const model = client.getGenerativeModel({ model: settings.embeddingModelId })
|
||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
||||
return {
|
||||
id: settings.embeddingModelId,
|
||||
dimension: modelInfo.dimensions,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
const response = await model.embedContent(text)
|
||||
@@ -76,69 +153,24 @@ export const getEmbeddingModel = (
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'nomic-embed-text': {
|
||||
case ApiProvider.Ollama: {
|
||||
const openai = new NoStainlessOpenAI({
|
||||
apiKey: '',
|
||||
apiKey: settings.ollamaProvider.apiKey,
|
||||
dangerouslyAllowBrowser: true,
|
||||
baseURL: `${ollamaBaseUrl}/v1`,
|
||||
baseURL: `${settings.ollamaProvider.baseUrl}/v1`,
|
||||
})
|
||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
||||
return {
|
||||
id: 'nomic-embed-text',
|
||||
dimension: 768,
|
||||
id: settings.embeddingModelId,
|
||||
dimension: modelInfo.dimensions,
|
||||
getEmbedding: async (text: string) => {
|
||||
if (!ollamaBaseUrl) {
|
||||
if (!settings.ollamaProvider.baseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama Address is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'nomic-embed-text',
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'mxbai-embed-large': {
|
||||
const openai = new NoStainlessOpenAI({
|
||||
apiKey: '',
|
||||
dangerouslyAllowBrowser: true,
|
||||
baseURL: `${ollamaBaseUrl}/v1`,
|
||||
})
|
||||
return {
|
||||
id: 'mxbai-embed-large',
|
||||
dimension: 1024,
|
||||
getEmbedding: async (text: string) => {
|
||||
if (!ollamaBaseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama Address is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'mxbai-embed-large',
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'bge-m3': {
|
||||
const openai = new NoStainlessOpenAI({
|
||||
apiKey: '',
|
||||
dangerouslyAllowBrowser: true,
|
||||
baseURL: `${ollamaBaseUrl}/v1`,
|
||||
})
|
||||
return {
|
||||
id: 'bge-m3',
|
||||
dimension: 1024,
|
||||
getEmbedding: async (text: string) => {
|
||||
if (!ollamaBaseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama Address is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'bge-m3',
|
||||
model: settings.embeddingModelId,
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
|
||||
Reference in New Issue
Block a user