mirror of
https://github.com/EthanMarti/infio-copilot.git
synced 2026-05-08 16:10:09 +00:00
更新嵌入管理器以支持 GPU 加速,调整批处理大小,优化内容处理逻辑,并添加获取数据库最大修改时间的功能以提高文件索引效率。同时修复了向量管理器中的类型问题,确保模型加载和嵌入过程的稳定性。
This commit is contained in:
@@ -27,7 +27,7 @@ export class VectorManager {
|
||||
constructor(app: App, dbManager: DBManager) {
|
||||
this.app = app
|
||||
this.dbManager = dbManager
|
||||
this.repository = new VectorRepository(app, dbManager.getPgClient())
|
||||
this.repository = new VectorRepository(app, dbManager.getPgClient() as any)
|
||||
}
|
||||
|
||||
async performSimilaritySearch(
|
||||
@@ -88,6 +88,7 @@ export class VectorManager {
|
||||
): Promise<void> {
|
||||
let filesToIndex: TFile[]
|
||||
if (options.reindexAll) {
|
||||
console.log("updateVaultIndex reindexAll")
|
||||
filesToIndex = await this.getFilesToIndex({
|
||||
embeddingModel: embeddingModel,
|
||||
excludePatterns: options.excludePatterns,
|
||||
@@ -96,17 +97,22 @@ export class VectorManager {
|
||||
})
|
||||
await this.repository.clearAllVectors(embeddingModel)
|
||||
} else {
|
||||
console.log("updateVaultIndex for update files")
|
||||
await this.cleanVectorsForDeletedFiles(embeddingModel)
|
||||
console.log("updateVaultIndex cleanVectorsForDeletedFiles")
|
||||
filesToIndex = await this.getFilesToIndex({
|
||||
embeddingModel: embeddingModel,
|
||||
excludePatterns: options.excludePatterns,
|
||||
includePatterns: options.includePatterns,
|
||||
})
|
||||
console.log("get files to index: ", filesToIndex.length)
|
||||
await this.repository.deleteVectorsForMultipleFiles(
|
||||
filesToIndex.map((file) => file.path),
|
||||
embeddingModel,
|
||||
)
|
||||
console.log("delete vectors for multiple files: ", filesToIndex.length)
|
||||
}
|
||||
console.log("get files to index: ", filesToIndex.length)
|
||||
|
||||
if (filesToIndex.length === 0) {
|
||||
return
|
||||
@@ -131,6 +137,7 @@ export class VectorManager {
|
||||
"",
|
||||
],
|
||||
});
|
||||
console.log("textSplitter chunkSize: ", options.chunkSize, "overlap: ", overlap)
|
||||
|
||||
const skippedFiles: string[] = []
|
||||
const contentChunks: InsertVector[] = (
|
||||
@@ -145,15 +152,16 @@ export class VectorManager {
|
||||
])
|
||||
return fileDocuments
|
||||
.map((chunk): InsertVector | null => {
|
||||
const content = removeMarkdown(chunk.pageContent).replace(/\0/g, '')
|
||||
if (!content || content.trim().length === 0) {
|
||||
// 保存原始内容,不在此处调用 removeMarkdown
|
||||
const rawContent = chunk.pageContent.replace(/\0/g, '')
|
||||
if (!rawContent || rawContent.trim().length === 0) {
|
||||
console.log("skipped chunk", chunk.pageContent)
|
||||
return null
|
||||
}
|
||||
return {
|
||||
path: file.path,
|
||||
mtime: file.stat.mtime,
|
||||
content,
|
||||
content: rawContent, // 保存原始内容
|
||||
embedding: [],
|
||||
metadata: {
|
||||
startLine: Number(chunk.metadata.loc.lines.from),
|
||||
@@ -171,6 +179,8 @@ export class VectorManager {
|
||||
)
|
||||
).flat()
|
||||
|
||||
console.log("contentChunks: ", contentChunks.length)
|
||||
|
||||
if (skippedFiles.length > 0) {
|
||||
console.warn(`跳过了 ${skippedFiles.length} 个有问题的文件:`, skippedFiles)
|
||||
new Notice(`跳过了 ${skippedFiles.length} 个有问题的文件`)
|
||||
@@ -186,31 +196,42 @@ export class VectorManager {
|
||||
// 减少批量大小以降低内存压力
|
||||
const insertBatchSize = 32
|
||||
let batchCount = 0
|
||||
|
||||
|
||||
try {
|
||||
if (embeddingModel.supportsBatch) {
|
||||
// 支持批量处理的提供商:使用流式处理逻辑
|
||||
const embeddingBatchSize = 32
|
||||
|
||||
const embeddingBatchSize = 32
|
||||
|
||||
for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
|
||||
batchCount++
|
||||
const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
|
||||
const batchTexts = batchChunks.map(chunk => chunk.content)
|
||||
|
||||
|
||||
const embeddedBatch: InsertVector[] = []
|
||||
|
||||
|
||||
await backOff(
|
||||
async () => {
|
||||
// 在嵌入之前处理 markdown,只处理一次
|
||||
const cleanedBatchData = batchChunks.map(chunk => {
|
||||
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
|
||||
return { chunk, cleanContent }
|
||||
}).filter(({ cleanContent }) => cleanContent && cleanContent.trim().length > 0)
|
||||
|
||||
if (cleanedBatchData.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const batchTexts = cleanedBatchData.map(({ cleanContent }) => cleanContent)
|
||||
const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
|
||||
|
||||
|
||||
// 合并embedding结果到chunk数据
|
||||
for (let j = 0; j < batchChunks.length; j++) {
|
||||
for (let j = 0; j < cleanedBatchData.length; j++) {
|
||||
const { chunk, cleanContent } = cleanedBatchData[j]
|
||||
const embeddedChunk: InsertVector = {
|
||||
path: batchChunks[j].path,
|
||||
mtime: batchChunks[j].mtime,
|
||||
content: batchChunks[j].content,
|
||||
path: chunk.path,
|
||||
mtime: chunk.mtime,
|
||||
content: cleanContent, // 使用已经清理过的内容
|
||||
embedding: batchEmbeddings[j],
|
||||
metadata: batchChunks[j].metadata,
|
||||
metadata: chunk.metadata,
|
||||
}
|
||||
embeddedBatch.push(embeddedChunk)
|
||||
}
|
||||
@@ -229,7 +250,7 @@ export class VectorManager {
|
||||
// 清理批次数据
|
||||
embeddedBatch.length = 0
|
||||
}
|
||||
|
||||
|
||||
embeddingProgress.completed += batchChunks.length
|
||||
updateProgress?.({
|
||||
completedChunks: embeddingProgress.completed,
|
||||
@@ -244,17 +265,17 @@ export class VectorManager {
|
||||
// 不支持批量处理的提供商:使用流式处理逻辑
|
||||
const limit = pLimit(32) // 从50降低到10,减少并发压力
|
||||
const abortController = new AbortController()
|
||||
|
||||
|
||||
// 流式处理:分批处理并立即插入
|
||||
for (let i = 0; i < contentChunks.length; i += insertBatchSize) {
|
||||
if (abortController.signal.aborted) {
|
||||
throw new Error('Operation was aborted')
|
||||
}
|
||||
|
||||
|
||||
batchCount++
|
||||
const batchChunks = contentChunks.slice(i, Math.min(i + insertBatchSize, contentChunks.length))
|
||||
const embeddedBatch: InsertVector[] = []
|
||||
|
||||
|
||||
const tasks = batchChunks.map((chunk) =>
|
||||
limit(async () => {
|
||||
if (abortController.signal.aborted) {
|
||||
@@ -263,11 +284,18 @@ export class VectorManager {
|
||||
try {
|
||||
await backOff(
|
||||
async () => {
|
||||
const embedding = await embeddingModel.getEmbedding(chunk.content)
|
||||
// 在嵌入之前处理 markdown
|
||||
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
|
||||
// 跳过清理后为空的内容
|
||||
if (!cleanContent || cleanContent.trim().length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const embedding = await embeddingModel.getEmbedding(cleanContent)
|
||||
const embeddedChunk = {
|
||||
path: chunk.path,
|
||||
mtime: chunk.mtime,
|
||||
content: chunk.content,
|
||||
content: cleanContent, // 使用清理后的内容
|
||||
embedding,
|
||||
metadata: chunk.metadata,
|
||||
}
|
||||
@@ -286,16 +314,16 @@ export class VectorManager {
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
await Promise.all(tasks)
|
||||
|
||||
|
||||
// 立即插入当前批次
|
||||
if (embeddedBatch.length > 0) {
|
||||
await this.repository.insertVectors(embeddedBatch, embeddingModel)
|
||||
// 清理批次数据
|
||||
embeddedBatch.length = 0
|
||||
}
|
||||
|
||||
|
||||
embeddingProgress.completed += batchChunks.length
|
||||
updateProgress?.({
|
||||
completedChunks: embeddingProgress.completed,
|
||||
@@ -339,9 +367,23 @@ export class VectorManager {
|
||||
)
|
||||
|
||||
// Embed the files
|
||||
const textSplitter = new MarkdownTextSplitter({
|
||||
const overlap = Math.floor(chunkSize * 0.15)
|
||||
const textSplitter = new RecursiveCharacterTextSplitter({
|
||||
chunkSize: chunkSize,
|
||||
chunkOverlap: Math.floor(chunkSize * 0.15)
|
||||
chunkOverlap: overlap,
|
||||
separators: [
|
||||
"\n\n",
|
||||
"\n",
|
||||
".",
|
||||
",",
|
||||
" ",
|
||||
"\u200b", // Zero-width space
|
||||
"\uff0c", // Fullwidth comma
|
||||
"\u3001", // Ideographic comma
|
||||
"\uff0e", // Fullwidth full stop
|
||||
"\u3002", // Ideographic full stop
|
||||
"",
|
||||
],
|
||||
});
|
||||
let fileContent = await this.app.vault.cachedRead(file)
|
||||
// 清理null字节,防止PostgreSQL UTF8编码错误
|
||||
@@ -352,14 +394,15 @@ export class VectorManager {
|
||||
|
||||
const contentChunks: InsertVector[] = fileDocuments
|
||||
.map((chunk): InsertVector | null => {
|
||||
const content = removeMarkdown(chunk.pageContent).replace(/\0/g, '')
|
||||
if (!content || content.trim().length === 0) {
|
||||
// 保存原始内容,不在此处调用 removeMarkdown
|
||||
const rawContent = String(chunk.pageContent || '').replace(/\0/g, '')
|
||||
if (!rawContent || rawContent.trim().length === 0) {
|
||||
return null
|
||||
}
|
||||
return {
|
||||
path: file.path,
|
||||
mtime: file.stat.mtime,
|
||||
content,
|
||||
content: rawContent, // 保存原始内容
|
||||
embedding: [],
|
||||
metadata: {
|
||||
startLine: Number(chunk.metadata.loc.lines.from),
|
||||
@@ -372,32 +415,43 @@ export class VectorManager {
|
||||
// 减少批量大小以降低内存压力
|
||||
const insertBatchSize = 16 // 从64降低到16
|
||||
let batchCount = 0
|
||||
|
||||
|
||||
try {
|
||||
if (embeddingModel.supportsBatch) {
|
||||
// 支持批量处理的提供商:使用流式处理逻辑
|
||||
const embeddingBatchSize = 16 // 从64降低到16
|
||||
|
||||
|
||||
for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
|
||||
batchCount++
|
||||
console.log(`Embedding batch ${batchCount} of ${Math.ceil(contentChunks.length / embeddingBatchSize)}`)
|
||||
const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
|
||||
const batchTexts = batchChunks.map(chunk => chunk.content)
|
||||
|
||||
|
||||
const embeddedBatch: InsertVector[] = []
|
||||
|
||||
|
||||
await backOff(
|
||||
async () => {
|
||||
// 在嵌入之前处理 markdown,只处理一次
|
||||
const cleanedBatchData = batchChunks.map(chunk => {
|
||||
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
|
||||
return { chunk, cleanContent }
|
||||
}).filter(({ cleanContent }) => cleanContent && cleanContent.trim().length > 0)
|
||||
|
||||
if (cleanedBatchData.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const batchTexts = cleanedBatchData.map(({ cleanContent }) => cleanContent)
|
||||
const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
|
||||
|
||||
|
||||
// 合并embedding结果到chunk数据
|
||||
for (let j = 0; j < batchChunks.length; j++) {
|
||||
for (let j = 0; j < cleanedBatchData.length; j++) {
|
||||
const { chunk, cleanContent } = cleanedBatchData[j]
|
||||
const embeddedChunk: InsertVector = {
|
||||
path: batchChunks[j].path,
|
||||
mtime: batchChunks[j].mtime,
|
||||
content: batchChunks[j].content,
|
||||
path: chunk.path,
|
||||
mtime: chunk.mtime,
|
||||
content: cleanContent, // 使用已经清理过的内容
|
||||
embedding: batchEmbeddings[j],
|
||||
metadata: batchChunks[j].metadata,
|
||||
metadata: chunk.metadata,
|
||||
}
|
||||
embeddedBatch.push(embeddedChunk)
|
||||
}
|
||||
@@ -424,17 +478,17 @@ export class VectorManager {
|
||||
// 不支持批量处理的提供商:使用流式处理逻辑
|
||||
const limit = pLimit(10) // 从50降低到10
|
||||
const abortController = new AbortController()
|
||||
|
||||
|
||||
// 流式处理:分批处理并立即插入
|
||||
for (let i = 0; i < contentChunks.length; i += insertBatchSize) {
|
||||
if (abortController.signal.aborted) {
|
||||
throw new Error('Operation was aborted')
|
||||
}
|
||||
|
||||
|
||||
batchCount++
|
||||
const batchChunks = contentChunks.slice(i, Math.min(i + insertBatchSize, contentChunks.length))
|
||||
const embeddedBatch: InsertVector[] = []
|
||||
|
||||
|
||||
const tasks = batchChunks.map((chunk) =>
|
||||
limit(async () => {
|
||||
if (abortController.signal.aborted) {
|
||||
@@ -443,11 +497,18 @@ export class VectorManager {
|
||||
try {
|
||||
await backOff(
|
||||
async () => {
|
||||
const embedding = await embeddingModel.getEmbedding(chunk.content)
|
||||
// 在嵌入之前处理 markdown
|
||||
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
|
||||
// 跳过清理后为空的内容
|
||||
if (!cleanContent || cleanContent.trim().length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const embedding = await embeddingModel.getEmbedding(cleanContent)
|
||||
const embeddedChunk = {
|
||||
path: chunk.path,
|
||||
mtime: chunk.mtime,
|
||||
content: chunk.content,
|
||||
content: cleanContent, // 使用清理后的内容
|
||||
embedding,
|
||||
metadata: chunk.metadata,
|
||||
}
|
||||
@@ -466,9 +527,9 @@ export class VectorManager {
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
await Promise.all(tasks)
|
||||
|
||||
|
||||
// 立即插入当前批次
|
||||
if (embeddedBatch.length > 0) {
|
||||
await this.repository.insertVectors(embeddedBatch, embeddingModel)
|
||||
@@ -522,8 +583,9 @@ export class VectorManager {
|
||||
excludePatterns: string[]
|
||||
includePatterns: string[]
|
||||
reindexAll?: boolean
|
||||
}): Promise<TFile[]> {
|
||||
}): Promise<TFile[]> {
|
||||
let filesToIndex = this.app.vault.getMarkdownFiles()
|
||||
console.log("get all vault files: ", filesToIndex.length)
|
||||
|
||||
filesToIndex = filesToIndex.filter((file) => {
|
||||
return !excludePatterns.some((pattern) => minimatch(file.path, pattern))
|
||||
@@ -538,39 +600,24 @@ export class VectorManager {
|
||||
if (reindexAll) {
|
||||
return filesToIndex
|
||||
}
|
||||
// Check for updated or new files
|
||||
filesToIndex = await Promise.all(
|
||||
filesToIndex.map(async (file) => {
|
||||
try {
|
||||
const fileChunks = await this.repository.getVectorsByFilePath(
|
||||
file.path,
|
||||
embeddingModel,
|
||||
)
|
||||
if (fileChunks.length === 0) {
|
||||
// File is not indexed, so we need to index it
|
||||
let fileContent = await this.app.vault.cachedRead(file)
|
||||
// 清理null字节,防止PostgreSQL UTF8编码错误
|
||||
fileContent = fileContent.replace(/\0/g, '')
|
||||
if (fileContent.length === 0) {
|
||||
// Ignore empty files
|
||||
return null
|
||||
}
|
||||
return file
|
||||
}
|
||||
const outOfDate = file.stat.mtime > fileChunks[0].mtime
|
||||
if (outOfDate) {
|
||||
// File has changed, so we need to re-index it
|
||||
console.log("File has changed, so we need to re-index it", file.path)
|
||||
return file
|
||||
}
|
||||
return null
|
||||
} catch (error) {
|
||||
console.warn(`跳过文件 ${file.path}:`, error.message)
|
||||
return null
|
||||
}
|
||||
}),
|
||||
).then((files) => files.filter(Boolean))
|
||||
|
||||
return filesToIndex
|
||||
// 优化流程:使用数据库最大mtime来过滤需要更新的文件
|
||||
try {
|
||||
const maxMtime = await this.repository.getMaxMtime(embeddingModel)
|
||||
console.log("Database max mtime:", maxMtime)
|
||||
|
||||
if (maxMtime === null) {
|
||||
// 数据库中没有任何向量,需要索引所有文件
|
||||
return filesToIndex
|
||||
}
|
||||
|
||||
// 筛选出在数据库最后更新时间之后修改的文件
|
||||
return filesToIndex.filter((file) => {
|
||||
return file.stat.mtime > maxMtime
|
||||
})
|
||||
} catch (error) {
|
||||
console.error("Error getting max mtime from database:", error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user