mirror of
https://github.com/EthanMarti/infio-copilot.git
synced 2026-05-09 08:30:09 +00:00
更新 RAGEngine 和嵌入管理器以支持嵌入管理器的传递,添加本地提供者的嵌入模型加载逻辑,优化错误处理和消息处理机制。
This commit is contained in:
@@ -16,10 +16,67 @@ import {
|
||||
} from '../llm/exception'
|
||||
import { NoStainlessOpenAI } from '../llm/ollama'
|
||||
|
||||
/**
 * Structural contract for the embedding manager that `getEmbeddingModel`
 * requires when the local provider is selected.
 */
type EmbeddingManager = {
	/** Checked before every embedding call; when false, `loadModel` is invoked first. */
	modelLoaded: boolean
	/** Compared against the configured embedding model id to detect that a different model must be loaded. */
	currentModel: string | null
	/**
	 * Loads the model identified by `modelId`.
	 * `useGpu` presumably selects GPU execution — confirm against the implementation.
	 * The resolved value is not read by the callers in this file, so it is typed
	 * `unknown` rather than `any` to keep unchecked values from leaking out.
	 */
	loadModel(modelId: string, useGpu: boolean): Promise<unknown>
	/** Embeds a single text; `vec` is returned to the caller as the embedding. */
	embed(text: string): Promise<{ vec: number[] }>
	/** Embeds several texts; each result's `vec` is returned as one embedding (presumably in input order — TODO confirm). */
	embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
}
|
||||
|
||||
export const getEmbeddingModel = (
|
||||
settings: InfioSettings,
|
||||
embeddingManager?: EmbeddingManager,
|
||||
): EmbeddingModel => {
|
||||
switch (settings.embeddingModelProvider) {
|
||||
case ApiProvider.LocalProvider: {
|
||||
if (!embeddingManager) {
|
||||
throw new Error('EmbeddingManager is required for LocalProvider')
|
||||
}
|
||||
|
||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
||||
if (!modelInfo) {
|
||||
throw new Error(`Embedding model ${settings.embeddingModelId} not found for provider ${settings.embeddingModelProvider}`)
|
||||
}
|
||||
|
||||
return {
|
||||
id: settings.embeddingModelId,
|
||||
dimension: modelInfo.dimensions,
|
||||
supportsBatch: true,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
// 确保模型已加载
|
||||
if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
|
||||
console.log(`Loading model: ${settings.embeddingModelId}`)
|
||||
await embeddingManager.loadModel(settings.embeddingModelId, true)
|
||||
}
|
||||
|
||||
const result = await embeddingManager.embed(text)
|
||||
return result.vec
|
||||
} catch (error) {
|
||||
console.error('LocalProvider embedding error:', error)
|
||||
throw new Error(`LocalProvider embedding failed: ${error.message}`)
|
||||
}
|
||||
},
|
||||
getBatchEmbeddings: async (texts: string[]) => {
|
||||
try {
|
||||
// 确保模型已加载
|
||||
if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
|
||||
console.log(`Loading model: ${settings.embeddingModelId}`)
|
||||
await embeddingManager.loadModel(settings.embeddingModelId, false)
|
||||
}
|
||||
|
||||
const results = await embeddingManager.embedBatch(texts)
|
||||
console.log('results', results)
|
||||
return results.map(result => result.vec)
|
||||
} catch (error) {
|
||||
console.error('LocalProvider batch embedding error:', error)
|
||||
throw new Error(`LocalProvider batch embedding failed: ${error.message}`)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
case ApiProvider.Infio: {
|
||||
const openai = new OpenAI({
|
||||
apiKey: settings.infioProvider.apiKey,
|
||||
|
||||
@@ -10,9 +10,19 @@ import { InfioSettings } from '../../types/settings'
|
||||
|
||||
import { getEmbeddingModel } from './embedding'
|
||||
|
||||
/**
 * Structural contract for the optional embedding manager that `RAGEngine`
 * stores and forwards to `getEmbeddingModel`.
 */
type EmbeddingManager = {
	/** Whether a model is currently loaded. */
	modelLoaded: boolean
	/** Identifier of the currently loaded model; the type allows `null`. */
	currentModel: string | null
	/**
	 * Loads the model identified by `modelId`.
	 * `useGpu` presumably selects GPU execution — confirm against the implementation.
	 * Typed `unknown` rather than `any` so the resolved value must be narrowed before use.
	 */
	loadModel(modelId: string, useGpu: boolean): Promise<unknown>
	/** Embeds a single text and resolves to an object carrying the vector in `vec`. */
	embed(text: string): Promise<{ vec: number[] }>
	/** Embeds several texts and resolves to an array of objects, each carrying a vector in `vec`. */
	embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
}
|
||||
|
||||
export class RAGEngine {
|
||||
private app: App
|
||||
private settings: InfioSettings
|
||||
private embeddingManager?: EmbeddingManager
|
||||
private vectorManager: VectorManager | null = null
|
||||
private embeddingModel: EmbeddingModel | null = null
|
||||
private initialized = false
|
||||
@@ -21,13 +31,15 @@ export class RAGEngine {
|
||||
app: App,
|
||||
settings: InfioSettings,
|
||||
dbManager: DBManager,
|
||||
embeddingManager?: EmbeddingManager,
|
||||
) {
|
||||
this.app = app
|
||||
this.settings = settings
|
||||
this.embeddingManager = embeddingManager
|
||||
this.vectorManager = dbManager.getVectorManager()
|
||||
if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
|
||||
try {
|
||||
this.embeddingModel = getEmbeddingModel(settings)
|
||||
this.embeddingModel = getEmbeddingModel(settings, embeddingManager)
|
||||
} catch (error) {
|
||||
console.warn('Failed to initialize embedding model:', error)
|
||||
this.embeddingModel = null
|
||||
@@ -46,7 +58,7 @@ export class RAGEngine {
|
||||
this.settings = settings
|
||||
if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
|
||||
try {
|
||||
this.embeddingModel = getEmbeddingModel(settings)
|
||||
this.embeddingModel = getEmbeddingModel(settings, this.embeddingManager)
|
||||
} catch (error) {
|
||||
console.warn('Failed to initialize embedding model:', error)
|
||||
this.embeddingModel = null
|
||||
|
||||
Reference in New Issue
Block a user