init
src/core/rag/embedding.ts (new file, 151 lines)
@@ -0,0 +1,151 @@
import { GoogleGenerativeAI } from '@google/generative-ai'
import { OpenAI } from 'openai'

import { EmbeddingModel } from '../../types/embedding'
import {
  LLMAPIKeyNotSetException,
  LLMBaseUrlNotSetException,
  LLMRateLimitExceededException,
} from '../llm/exception'
import { NoStainlessOpenAI } from '../llm/ollama'

// Resolves an embedding model id to an EmbeddingModel backed by the
// corresponding provider client (OpenAI, Gemini, or a local Ollama server).
export const getEmbeddingModel = (
  embeddingModelId: string,
  apiKeys: {
    openAIApiKey: string
    geminiApiKey: string
  },
  ollamaBaseUrl: string,
): EmbeddingModel => {
  switch (embeddingModelId) {
    case 'text-embedding-3-small': {
      const openai = new OpenAI({
        apiKey: apiKeys.openAIApiKey,
        dangerouslyAllowBrowser: true,
      })
      return {
        id: 'text-embedding-3-small',
        dimension: 1536,
        getEmbedding: async (text: string) => {
          try {
            if (!openai.apiKey) {
              throw new LLMAPIKeyNotSetException(
                'OpenAI API key is missing. Please set it in the settings menu.',
              )
            }
            const embedding = await openai.embeddings.create({
              model: 'text-embedding-3-small',
              input: text,
            })
            return embedding.data[0].embedding
          } catch (error) {
            if (
              error.status === 429 &&
              error.message.toLowerCase().includes('rate limit')
            ) {
              throw new LLMRateLimitExceededException(
                'OpenAI API rate limit exceeded. Please try again later.',
              )
            }
            throw error
          }
        },
      }
    }
    case 'text-embedding-004': {
      const client = new GoogleGenerativeAI(apiKeys.geminiApiKey)
      const model = client.getGenerativeModel({ model: 'text-embedding-004' })
      return {
        id: 'text-embedding-004',
        dimension: 768,
        getEmbedding: async (text: string) => {
          try {
            const response = await model.embedContent(text)
            return response.embedding.values
          } catch (error) {
            if (
              error.status === 429 &&
              error.message.includes('RATE_LIMIT_EXCEEDED')
            ) {
              throw new LLMRateLimitExceededException(
                'Gemini API rate limit exceeded. Please try again later.',
              )
            }
            throw error
          }
        },
      }
    }
    // The three Ollama-served models below share the same OpenAI-compatible
    // endpoint; the base URL is only validated lazily, at embedding time.
    case 'nomic-embed-text': {
      const openai = new NoStainlessOpenAI({
        apiKey: '',
        dangerouslyAllowBrowser: true,
        baseURL: `${ollamaBaseUrl}/v1`,
      })
      return {
        id: 'nomic-embed-text',
        dimension: 768,
        getEmbedding: async (text: string) => {
          if (!ollamaBaseUrl) {
            throw new LLMBaseUrlNotSetException(
              'Ollama address is missing. Please set it in the settings menu.',
            )
          }
          const embedding = await openai.embeddings.create({
            model: 'nomic-embed-text',
            input: text,
          })
          return embedding.data[0].embedding
        },
      }
    }
    case 'mxbai-embed-large': {
      const openai = new NoStainlessOpenAI({
        apiKey: '',
        dangerouslyAllowBrowser: true,
        baseURL: `${ollamaBaseUrl}/v1`,
      })
      return {
        id: 'mxbai-embed-large',
        dimension: 1024,
        getEmbedding: async (text: string) => {
          if (!ollamaBaseUrl) {
            throw new LLMBaseUrlNotSetException(
              'Ollama address is missing. Please set it in the settings menu.',
            )
          }
          const embedding = await openai.embeddings.create({
            model: 'mxbai-embed-large',
            input: text,
          })
          return embedding.data[0].embedding
        },
      }
    }
    case 'bge-m3': {
      const openai = new NoStainlessOpenAI({
        apiKey: '',
        dangerouslyAllowBrowser: true,
        baseURL: `${ollamaBaseUrl}/v1`,
      })
      return {
        id: 'bge-m3',
        dimension: 1024,
        getEmbedding: async (text: string) => {
          if (!ollamaBaseUrl) {
            throw new LLMBaseUrlNotSetException(
              'Ollama address is missing. Please set it in the settings menu.',
            )
          }
          const embedding = await openai.embeddings.create({
            model: 'bge-m3',
            input: text,
          })
          return embedding.data[0].embedding
        },
      }
    }
    default:
      throw new Error(`Invalid embedding model: ${embeddingModelId}`)
  }
}
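For orientation, a minimal sketch of how this factory might be called; the API key value, model id, import path, and Ollama address below are illustrative placeholders, not taken from this commit:

  import { getEmbeddingModel } from './core/rag/embedding'

  const model = getEmbeddingModel(
    'text-embedding-3-small',
    { openAIApiKey: 'sk-placeholder', geminiApiKey: '' }, // placeholder keys
    'http://localhost:11434', // assumed default Ollama address
  )
  // Inside an async context:
  const vector = await model.getEmbedding('hello world')
  console.log(vector.length) // 1536, matching model.dimension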
src/core/rag/rag-engine.ts (new file, 124 lines)
@@ -0,0 +1,124 @@
import { App } from 'obsidian'

import { QueryProgressState } from '../../components/chat-view/QueryProgress'
import { DBManager } from '../../database/database-manager'
import { VectorManager } from '../../database/modules/vector/vector-manager'
import { SelectVector } from '../../database/schema'
import { EmbeddingModel } from '../../types/embedding'
import { InfioSettings } from '../../types/settings'

import { getEmbeddingModel } from './embedding'

// Orchestrates RAG over the vault: keeps the vector index up to date and
// answers queries via similarity search over the indexed chunks.
export class RAGEngine {
  private app: App
  private settings: InfioSettings
  private vectorManager: VectorManager
  private embeddingModel: EmbeddingModel | null = null

  constructor(
    app: App,
    settings: InfioSettings,
    dbManager: DBManager,
  ) {
    this.app = app
    this.settings = settings
    this.vectorManager = dbManager.getVectorManager()
    this.embeddingModel = getEmbeddingModel(
      settings.embeddingModelId,
      {
        openAIApiKey: settings.openAIApiKey,
        geminiApiKey: settings.geminiApiKey,
      },
      settings.ollamaEmbeddingModel.baseUrl,
    )
  }

  setSettings(settings: InfioSettings) {
    this.settings = settings
    this.embeddingModel = getEmbeddingModel(
      settings.embeddingModelId,
      {
        openAIApiKey: settings.openAIApiKey,
        geminiApiKey: settings.geminiApiKey,
      },
      settings.ollamaEmbeddingModel.baseUrl,
    )
  }

  // TODO: Implement automatic vault re-indexing when settings are changed.
  // Currently, users must manually re-index the vault.
  async updateVaultIndex(
    options: { reindexAll: boolean } = {
      reindexAll: false,
    },
    onQueryProgressChange?: (queryProgress: QueryProgressState) => void,
  ): Promise<void> {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    await this.vectorManager.updateVaultIndex(
      this.embeddingModel,
      {
        chunkSize: this.settings.ragOptions.chunkSize,
        excludePatterns: this.settings.ragOptions.excludePatterns,
        includePatterns: this.settings.ragOptions.includePatterns,
        reindexAll: options.reindexAll,
      },
      (indexProgress) => {
        onQueryProgressChange?.({
          type: 'indexing',
          indexProgress,
        })
      },
    )
  }

  // Returns the most similar indexed chunks for a query, refreshing the
  // vault index first.
  async processQuery({
    query,
    scope,
    onQueryProgressChange,
  }: {
    query: string
    scope?: {
      files: string[]
      folders: string[]
    }
    onQueryProgressChange?: (queryProgress: QueryProgressState) => void
  }): Promise<
    (Omit<SelectVector, 'embedding'> & {
      similarity: number
    })[]
  > {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    // TODO: Decide the vault index update strategy.
    // Current approach: update on every query.
    await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
    const queryEmbedding = await this.getQueryEmbedding(query)
    onQueryProgressChange?.({
      type: 'querying',
    })
    const queryResult = await this.vectorManager.performSimilaritySearch(
      queryEmbedding,
      this.embeddingModel,
      {
        minSimilarity: this.settings.ragOptions.minSimilarity,
        limit: this.settings.ragOptions.limit,
        scope,
      },
    )
    onQueryProgressChange?.({
      type: 'querying-done',
      queryResult,
    })
    return queryResult
  }

  private async getQueryEmbedding(query: string): Promise<number[]> {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    return this.embeddingModel.getEmbedding(query)
  }
}
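A hypothetical end-to-end call, assuming an app, settings, and dbManager already exist (none are constructed in this commit), run inside an async context:

  const engine = new RAGEngine(app, settings, dbManager)
  const results = await engine.processQuery({
    query: 'what did I write about embeddings?',
    scope: { files: [], folders: ['notes'] }, // optional: restrict the search
    onQueryProgressChange: (progress) => console.log(progress.type),
  })
  // Each result is a SelectVector row without its embedding, plus a
  // similarity score; rows are presumably filtered by ragOptions.minSimilarity.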