更新 RAGEngine 和嵌入管理器以支持嵌入管理器的传递,添加本地提供者的嵌入模型加载逻辑,优化错误处理和消息处理机制。

This commit is contained in:
duanfuxiang
2025-07-04 15:52:00 +08:00
parent bed96a5233
commit 4e139ecc4f
7 changed files with 453 additions and 312 deletions

View File

@@ -16,10 +16,67 @@ import {
} from '../llm/exception'
import { NoStainlessOpenAI } from '../llm/ollama'
// EmbeddingManager 类型定义
type EmbeddingManager = {
modelLoaded: boolean
currentModel: string | null
loadModel(modelId: string, useGpu: boolean): Promise<any>
embed(text: string): Promise<{ vec: number[] }>
embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
}
export const getEmbeddingModel = (
settings: InfioSettings,
embeddingManager?: EmbeddingManager,
): EmbeddingModel => {
switch (settings.embeddingModelProvider) {
case ApiProvider.LocalProvider: {
if (!embeddingManager) {
throw new Error('EmbeddingManager is required for LocalProvider')
}
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
if (!modelInfo) {
throw new Error(`Embedding model ${settings.embeddingModelId} not found for provider ${settings.embeddingModelProvider}`)
}
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: true,
getEmbedding: async (text: string) => {
try {
// 确保模型已加载
if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
console.log(`Loading model: ${settings.embeddingModelId}`)
await embeddingManager.loadModel(settings.embeddingModelId, true)
}
const result = await embeddingManager.embed(text)
return result.vec
} catch (error) {
console.error('LocalProvider embedding error:', error)
throw new Error(`LocalProvider embedding failed: ${error.message}`)
}
},
getBatchEmbeddings: async (texts: string[]) => {
try {
// 确保模型已加载
if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
console.log(`Loading model: ${settings.embeddingModelId}`)
await embeddingManager.loadModel(settings.embeddingModelId, false)
}
const results = await embeddingManager.embedBatch(texts)
console.log('results', results)
return results.map(result => result.vec)
} catch (error) {
console.error('LocalProvider batch embedding error:', error)
throw new Error(`LocalProvider batch embedding failed: ${error.message}`)
}
},
}
}
case ApiProvider.Infio: {
const openai = new OpenAI({
apiKey: settings.infioProvider.apiKey,

View File

@@ -10,9 +10,19 @@ import { InfioSettings } from '../../types/settings'
import { getEmbeddingModel } from './embedding'
// EmbeddingManager 类型定义
type EmbeddingManager = {
modelLoaded: boolean
currentModel: string | null
loadModel(modelId: string, useGpu: boolean): Promise<any>
embed(text: string): Promise<{ vec: number[] }>
embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
}
export class RAGEngine {
private app: App
private settings: InfioSettings
private embeddingManager?: EmbeddingManager
private vectorManager: VectorManager | null = null
private embeddingModel: EmbeddingModel | null = null
private initialized = false
@@ -21,13 +31,15 @@ export class RAGEngine {
app: App,
settings: InfioSettings,
dbManager: DBManager,
embeddingManager?: EmbeddingManager,
) {
this.app = app
this.settings = settings
this.embeddingManager = embeddingManager
this.vectorManager = dbManager.getVectorManager()
if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
try {
this.embeddingModel = getEmbeddingModel(settings)
this.embeddingModel = getEmbeddingModel(settings, embeddingManager)
} catch (error) {
console.warn('Failed to initialize embedding model:', error)
this.embeddingModel = null
@@ -46,7 +58,7 @@ export class RAGEngine {
this.settings = settings
if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
try {
this.embeddingModel = getEmbeddingModel(settings)
this.embeddingModel = getEmbeddingModel(settings, this.embeddingManager)
} catch (error) {
console.warn('Failed to initialize embedding model:', error)
this.embeddingModel = null