This commit is contained in:
duanfuxiang
2025-06-12 13:35:00 +08:00
parent 3ce55899df
commit b20b4f9e19
3 changed files with 296 additions and 53 deletions

View File

@@ -30,6 +30,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: true,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@@ -54,6 +55,31 @@ export const getEmbeddingModel = (
throw error
}
},
// Embed a batch of texts with a single OpenAI embeddings request.
// The API returns one vector per input, in the same order as `texts`.
// Throws LLMAPIKeyNotSetException when no API key is configured, and
// LLMRateLimitExceededException on an HTTP 429 rate-limit response.
// NOTE: removed a leftover debug console.log that was shipped here.
getBatchEmbeddings: async (texts: string[]) => {
  try {
    if (!openai.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'OpenAI API key is missing. Please set it in settings menu.',
      )
    }
    // One request for the whole batch instead of one per text.
    const embedding = await openai.embeddings.create({
      model: settings.embeddingModelId,
      input: texts,
    })
    return embedding.data.map(item => item.embedding)
  } catch (error) {
    if (
      error.status === 429 &&
      error.message.toLowerCase().includes('rate limit')
    ) {
      throw new LLMRateLimitExceededException(
        'OpenAI API rate limit exceeded. Please try again later.',
      )
    }
    throw error
  }
},
}
}
case ApiProvider.OpenAI: {
@@ -67,6 +93,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: true,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@@ -91,6 +118,30 @@ export const getEmbeddingModel = (
throw error
}
},
// Request embeddings for a batch of inputs in one OpenAI API call.
// Returns one vector per input, ordered to match `texts`.
getBatchEmbeddings: async (texts: string[]) => {
  try {
    if (!openai.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'OpenAI API key is missing. Please set it in settings menu.',
      )
    }
    const response = await openai.embeddings.create({
      model: settings.embeddingModelId,
      input: texts,
    })
    // Collect the vectors in the order the API returned them.
    const vectors: number[][] = []
    for (const item of response.data) {
      vectors.push(item.embedding)
    }
    return vectors
  } catch (err) {
    const isRateLimited =
      err.status === 429 && err.message.toLowerCase().includes('rate limit')
    if (isRateLimited) {
      throw new LLMRateLimitExceededException(
        'OpenAI API rate limit exceeded. Please try again later.',
      )
    }
    throw err
  }
},
}
}
case ApiProvider.SiliconFlow: {
@@ -104,6 +155,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: true,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@@ -128,6 +180,30 @@ export const getEmbeddingModel = (
throw error
}
},
// Embed several texts at once via the SiliconFlow endpoint
// (OpenAI-compatible client). One vector per input, same order.
getBatchEmbeddings: async (texts: string[]) => {
  try {
    if (!openai.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'SiliconFlow API key is missing. Please set it in settings menu.',
      )
    }
    const response = await openai.embeddings.create({
      input: texts,
      model: settings.embeddingModelId,
    })
    return response.data.map(({ embedding }) => embedding)
  } catch (error) {
    // Translate HTTP 429 into the app-level rate-limit exception.
    if (
      error.status === 429 &&
      error.message.toLowerCase().includes('rate limit')
    ) {
      throw new LLMRateLimitExceededException(
        'SiliconFlow API rate limit exceeded. Please try again later.',
      )
    }
    throw error
  }
},
}
}
case ApiProvider.AlibabaQwen: {
@@ -141,6 +217,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: false,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@@ -165,6 +242,30 @@ export const getEmbeddingModel = (
throw error
}
},
// Batch-embed texts through the Alibaba Qwen endpoint
// (OpenAI-compatible client). Vectors come back in input order.
getBatchEmbeddings: async (texts: string[]) => {
  try {
    const apiKeyMissing = !openai.apiKey
    if (apiKeyMissing) {
      throw new LLMAPIKeyNotSetException(
        'Alibaba Qwen API key is missing. Please set it in settings menu.',
      )
    }
    const response = await openai.embeddings.create({
      model: settings.embeddingModelId,
      input: texts,
    })
    const toVector = (item: { embedding: number[] }) => item.embedding
    return response.data.map(toVector)
  } catch (error) {
    // Surface 429s as the app's rate-limit exception; rethrow the rest.
    if (
      error.status === 429 &&
      error.message.toLowerCase().includes('rate limit')
    ) {
      throw new LLMRateLimitExceededException(
        'Alibaba Qwen API rate limit exceeded. Please try again later.',
      )
    }
    throw error
  }
},
}
}
case ApiProvider.Google: {
@@ -174,6 +275,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: false,
getEmbedding: async (text: string) => {
try {
const response = await model.embedContent(text)
@@ -190,6 +292,27 @@ export const getEmbeddingModel = (
throw error
}
},
// Embed a batch of texts with Gemini by issuing one embedContent
// request per text, all running concurrently via Promise.all.
getBatchEmbeddings: async (texts: string[]) => {
  try {
    const embedOne = async (text: string) => {
      const response = await model.embedContent(text)
      return response.embedding.values
    }
    return await Promise.all(texts.map(embedOne))
  } catch (error) {
    if (
      error.status === 429 &&
      error.message.includes('RATE_LIMIT_EXCEEDED')
    ) {
      throw new LLMRateLimitExceededException(
        'Gemini API rate limit exceeded. Please try again later.',
      )
    }
    throw error
  }
},
}
}
case ApiProvider.Ollama: {
@@ -201,6 +324,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: 0,
supportsBatch: false,
getEmbedding: async (text: string) => {
if (!settings.ollamaProvider.baseUrl) {
throw new LLMBaseUrlNotSetException(
@@ -213,6 +337,18 @@ export const getEmbeddingModel = (
})
return embedding.data[0].embedding
},
// Batch-embed texts against the local Ollama server (OpenAI-compatible
// client). Requires the Ollama base URL to be configured.
getBatchEmbeddings: async (texts: string[]) => {
  const baseUrl = settings.ollamaProvider.baseUrl
  if (!baseUrl) {
    throw new LLMBaseUrlNotSetException(
      'Ollama Address is missing. Please set it in settings menu.',
    )
  }
  const response = await openai.embeddings.create({
    model: settings.embeddingModelId,
    input: texts,
  })
  return response.data.map(({ embedding }) => embedding)
},
}
}
case ApiProvider.OpenAICompatible: {
@@ -224,6 +360,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: 0,
supportsBatch: false,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@@ -249,6 +386,31 @@ export const getEmbeddingModel = (
throw error
}
},
// Batch-embed texts against an OpenAI-compatible endpoint.
// encoding_format "float" requests plain float arrays (matches the
// single-text getEmbedding for this provider).
getBatchEmbeddings: async (texts: string[]) => {
  try {
    if (!openai.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'OpenAI Compatible API key is missing. Please set it in settings menu.',
      )
    }
    const response = await openai.embeddings.create({
      model: settings.embeddingModelId,
      input: texts,
      encoding_format: "float",
    })
    const vectors = response.data.map(item => item.embedding)
    return vectors
  } catch (err) {
    const message = err.message.toLowerCase()
    if (err.status === 429 && message.includes('rate limit')) {
      throw new LLMRateLimitExceededException(
        'OpenAI Compatible API rate limit exceeded. Please try again later.',
      )
    }
    throw err
  }
},
}
}
default: