update local pdf tools

This commit is contained in:
duanfuxiang
2025-06-07 17:13:02 +08:00
parent 8e5a1c75f6
commit 2b571f67a7
8 changed files with 668 additions and 36 deletions

View File

@@ -576,7 +576,7 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
if (!opFile) {
throw new Error(`File not found: ${toolArgs.filepath}`)
}
const fileContent = await readTFileContent(opFile, app.vault)
const fileContent = await readTFileContent(opFile, app.vault, app)
const formattedContent = `[read_file for '${toolArgs.filepath}'] Result:\n${addLineNumbers(fileContent)}\n`;
return {
type: 'read_file',

View File

@@ -1,22 +1,62 @@
import * as path from 'path'
import { App, Editor, MarkdownView, TFile, TFolder, Vault, WorkspaceLeaf } from 'obsidian'
import { App, Editor, MarkdownView, TFile, TFolder, Vault, WorkspaceLeaf, loadPdfJs } from 'obsidian'
import { MentionableBlockData } from '../types/mentionable'
export async function parsePdfContent(file: TFile, app: App): Promise<string> {
try {
// 使用 Obsidian 内置的 PDF.js
const pdfjsLib = await loadPdfJs()
// Read PDF file as binary buffer
const pdfBuffer = await app.vault.readBinary(file)
// 使用 Obsidian 内置的 PDF.js 处理 PDF
const loadingTask = pdfjsLib.getDocument({ data: pdfBuffer })
const doc = await loadingTask.promise
let fullText = ''
for (let pageNum = 1; pageNum <= doc.numPages; pageNum++) {
const page = await doc.getPage(pageNum)
const textContent = await page.getTextContent()
const pageText = textContent.items
.map((item: any) => item.str)
.join(' ')
fullText += pageText + '\n\n'
}
return fullText || '(Empty PDF content)'
} catch (error: any) {
console.error('Error parsing PDF:', error)
return `(Error reading PDF file: ${error?.message || 'Unknown error'})`
}
}
export async function readTFileContent(
file: TFile,
vault: Vault,
app?: App,
): Promise<string> {
if (file.extension === 'pdf') {
if (app) {
return await parsePdfContent(file, app)
}
return "(PDF file, app context required for processing)"
}
if (file.extension != 'md') {
return "(Binary file, unable to display content)"
}
return await vault.cachedRead(file)
}
export async function readMultipleTFiles(
files: TFile[],
vault: Vault,
app?: App,
): Promise<string[]> {
// Read files in parallel
const readPromises = files.map((file) => readTFileContent(file, vault))
const readPromises = files.map((file) => readTFileContent(file, vault, app))
return await Promise.all(readPromises)
}

View File

@@ -7,7 +7,7 @@ import { McpHub } from '../core/mcp/McpHub'
import { SystemPrompt } from '../core/prompts/system'
import { RAGEngine } from '../core/rag/rag-engine'
import { SelectVector } from '../database/schema'
import { ChatAssistantMessage, ChatMessage, ChatUserMessage } from '../types/chat'
import { ChatMessage, ChatUserMessage } from '../types/chat'
import { ContentPart, RequestMessage } from '../types/llm/request'
import {
MentionableBlock,
@@ -21,10 +21,14 @@ import { InfioSettings } from '../types/settings'
import { CustomModePrompts, Mode, ModeConfig, getFullModeDetails } from "../utils/modes"
import {
readTFileContent
readTFileContent,
readMultipleTFiles,
getNestedFiles,
parsePdfContent
} from './obsidian'
import { tokenCount } from './token'
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript'
import { isVideoUrl, isYoutubeUrl } from './video-detector'
import { YoutubeTranscript } from './youtube-transcript'
export function addLineNumbers(content: string, startLine: number = 1): string {
const lines = content.split("\n")
@@ -66,13 +70,20 @@ async function getFolderTreeContent(path: TFolder): Promise<string> {
}
}
async function getFileOrFolderContent(path: TAbstractFile, vault: Vault): Promise<string> {
async function getFileOrFolderContent(path: TAbstractFile, vault: Vault, app?: App): Promise<string> {
try {
if (path instanceof TFile) {
if (path.extension === 'pdf') {
// Handle PDF files without line numbers
if (app) {
return await parsePdfContent(path, app)
}
return "(PDF file, app context required for processing)"
}
if (path.extension != 'md') {
return "(Binary file, unable to display content)"
}
return addLineNumbers(await readTFileContent(path, vault))
return addLineNumbers(await readTFileContent(path, vault, app))
} else if (path instanceof TFolder) {
const entries = path.children
let folderContent = ""
@@ -85,10 +96,18 @@ async function getFileOrFolderContent(path: TAbstractFile, vault: Vault): Promis
fileContentPromises.push(
(async () => {
try {
if (entry.extension === 'pdf') {
// Handle PDF files in folders
if (app) {
const content = await parsePdfContent(entry, app)
return `<file_content path="${entry.path}">\n${content}\n</file_content>`
}
return `<file_content path="${entry.path}">\n(PDF file, app context required for processing)\n</file_content>`
}
if (entry.extension != 'md') {
return undefined
}
const content = addLineNumbers(await readTFileContent(entry, vault))
const content = addLineNumbers(await readTFileContent(entry, vault, app))
return `<file_content path="${entry.path}">\n${content}\n</file_content>`
} catch (error) {
return undefined
@@ -196,18 +215,18 @@ export class PromptGenerator {
...compiledMessages.slice(-19)
.filter((message) => !(message.role === 'assistant' && message.isToolResult))
.map((message): RequestMessage => {
if (message.role === 'user') {
return {
role: 'user',
content: message.promptContent ?? '',
if (message.role === 'user') {
return {
role: 'user',
content: message.promptContent ?? '',
}
} else {
return {
role: 'assistant',
content: message.content,
}
}
} else {
return {
role: 'assistant',
content: message.content,
}
}
}),
}),
]
return {
@@ -336,7 +355,7 @@ export class PromptGenerator {
.map((m) => m.file)
let fileContentsPrompts = files.length > 0
? (await Promise.all(files.map(async (file) => {
const content = await getFileOrFolderContent(file, this.app.vault)
const content = await getFileOrFolderContent(file, this.app.vault, this.app)
return `<file_content path="${file.path}">\n${content}\n</file_content>`
}))).join('\n')
: undefined
@@ -347,7 +366,7 @@ export class PromptGenerator {
.map((m) => m.folder)
let folderContentsPrompts = folders.length > 0
? (await Promise.all(folders.map(async (folder) => {
const content = await getFileOrFolderContent(folder, this.app.vault)
const content = await getFileOrFolderContent(folder, this.app.vault, this.app)
return `<folder_content path="${folder.path}">\n${content}\n</folder_content>`
}))).join('\n')
: undefined
@@ -387,7 +406,7 @@ export class PromptGenerator {
.filter((m): m is MentionableFile => m.type === 'current-file')
.first()
const currentFileContent = currentFile && currentFile.file != null
? await getFileOrFolderContent(currentFile.file, this.app.vault)
? await getFileOrFolderContent(currentFile.file, this.app.vault, this.app)
: undefined
// Check if current file content should be included
@@ -647,7 +666,7 @@ ${customInstruction}
private async getCurrentFileMessage(
currentFile: TFile,
): Promise<RequestMessage> {
const fileContent = await readTFileContent(currentFile, this.app.vault)
const fileContent = await readTFileContent(currentFile, this.app.vault, this.app)
return {
role: 'user',
content: `# Inputs
@@ -669,7 +688,7 @@ ${fileContent}
return null;
}
const fileContent = await readTFileContent(currentFile, this.app.vault);
const fileContent = await readTFileContent(currentFile, this.app.vault, this.app);
const lines = fileContent.split('\n');
// 计算上下文范围,并处理边界情况
@@ -743,6 +762,12 @@ When writing out new markdown blocks, remember not to include "line_number|" at
return linesWithNumbers.join('\n')
}
private async getPdfContent(file: TFile): Promise<string> {
return await parsePdfContent(file, this.app)
}
/**
* TODO: Improve markdown conversion logic
* - filter visually hidden elements
@@ -763,4 +788,26 @@ ${transcript.map((t) => `${t.offset}: ${t.text}`).join('\n')}`
return htmlToMarkdown(response.text)
}
private async callMcpToolGetWebsiteContent(url: string, mcpHub: McpHub | null): Promise<string> {
if (isVideoUrl(url)) {
return this.callMcpToolConvertVideo(url, mcpHub)
}
return this.callMcpToolFetchUrlContent(url, mcpHub)
}
private async callMcpToolConvertVideo(url: string, mcpHub: McpHub | null): Promise<string> {
// TODO: implement
return ''
}
private async callMcpToolFetchUrlContent(url: string, mcpHub: McpHub | null): Promise<string> {
// TODO: implement
return ''
}
private async callMcpToolConvertDocument(file: TFile, mcpHub: McpHub | null): Promise<string> {
// TODO: implement
return ''
}
}

View File

@@ -0,0 +1,107 @@
import {
extractVideoId,
getSupportedVideoProviders,
getVideoProvider,
isBilibiliUrl,
isTikTokUrl,
isVideoUrl,
isVimeoUrl
} from './video-detector'
describe('video-detector', () => {
describe('isVideoUrl', () => {
it('should correctly identify YouTube URLs', () => {
expect(isVideoUrl('https://www.youtube.com/watch?v=dQw4w9WgXcQ')).toBe(true)
expect(isVideoUrl('https://youtu.be/dQw4w9WgXcQ')).toBe(true)
})
it('should correctly identify Bilibili URLs', () => {
expect(isVideoUrl('https://www.bilibili.com/video/BV1GJ411x7h7')).toBe(true)
expect(isVideoUrl('https://b23.tv/BV1GJ411x7h7')).toBe(true)
})
it('should correctly identify Vimeo URLs', () => {
expect(isVideoUrl('https://vimeo.com/123456789')).toBe(true)
})
it('should correctly identify TikTok URLs', () => {
expect(isVideoUrl('https://www.tiktok.com/@username/video/1234567890')).toBe(true)
expect(isVideoUrl('https://vm.tiktok.com/ZMeABCDEF/')).toBe(true)
})
it('should correctly identify video file URLs', () => {
expect(isVideoUrl('https://example.com/video.mp4')).toBe(true)
expect(isVideoUrl('https://example.com/movie.avi?t=123')).toBe(true)
expect(isVideoUrl('https://example.com/clip.webm')).toBe(true)
})
it('should correctly reject non-video URLs', () => {
expect(isVideoUrl('https://www.google.com')).toBe(false)
expect(isVideoUrl('https://github.com/user/repo')).toBe(false)
expect(isVideoUrl('https://docs.google.com/document/123')).toBe(false)
})
})
describe('getVideoProvider', () => {
it('should correctly identify YouTube provider', () => {
expect(getVideoProvider('https://www.youtube.com/watch?v=dQw4w9WgXcQ')).toBe('youtube')
expect(getVideoProvider('https://youtu.be/dQw4w9WgXcQ')).toBe('youtube')
})
it('should correctly identify Bilibili provider', () => {
expect(getVideoProvider('https://www.bilibili.com/video/BV1GJ411x7h7')).toBe('bilibili')
})
it('should correctly identify Vimeo provider', () => {
expect(getVideoProvider('https://vimeo.com/123456789')).toBe('vimeo')
})
it('should return null for non-video URLs', () => {
expect(getVideoProvider('https://www.google.com')).toBeNull()
expect(getVideoProvider('https://github.com/user/repo')).toBeNull()
})
})
describe('extractVideoId', () => {
it('should extract YouTube video IDs', () => {
expect(extractVideoId('https://www.youtube.com/watch?v=dQw4w9WgXcQ')).toBe('dQw4w9WgXcQ')
expect(extractVideoId('https://youtu.be/dQw4w9WgXcQ')).toBe('dQw4w9WgXcQ')
})
it('should extract Vimeo video IDs', () => {
expect(extractVideoId('https://vimeo.com/123456789')).toBe('123456789')
})
it('should return null for non-video URLs', () => {
expect(extractVideoId('https://www.google.com')).toBeNull()
})
})
describe('platform-specific detectors', () => {
it('should correctly detect Bilibili URLs', () => {
expect(isBilibiliUrl('https://www.bilibili.com/video/BV1GJ411x7h7')).toBe(true)
expect(isBilibiliUrl('https://www.youtube.com/watch?v=123')).toBe(false)
})
it('should correctly detect Vimeo URLs', () => {
expect(isVimeoUrl('https://vimeo.com/123456789')).toBe(true)
expect(isVimeoUrl('https://www.youtube.com/watch?v=123')).toBe(false)
})
it('should correctly detect TikTok URLs', () => {
expect(isTikTokUrl('https://www.tiktok.com/@user/video/123')).toBe(true)
expect(isTikTokUrl('https://www.youtube.com/watch?v=123')).toBe(false)
})
})
describe('getSupportedVideoProviders', () => {
it('should return an array of supported providers', () => {
const providers = getSupportedVideoProviders()
expect(Array.isArray(providers)).toBe(true)
expect(providers.length).toBeGreaterThan(0)
expect(providers).toContain('youtube')
expect(providers).toContain('bilibili')
expect(providers).toContain('vimeo')
})
})
})

142
src/utils/video-detector.ts Normal file
View File

@@ -0,0 +1,142 @@
/**
* 视频平台URL检测工具
* 支持多种主流视频平台的URL识别
*/
// 各种视频平台的正则表达式
const VIDEO_PATTERNS = {
// YouTube
youtube: /(?:youtube\.com\/(?:[^/]+\/.+\/|(?:v|e(?:mbed)?)\/|.*[?&]v=)|youtu\.be\/)([^"&?/\s]{11})/i,
// Bilibili
bilibili: /(?:bilibili\.com\/video\/|b23\.tv\/)[A-Za-z0-9]+/i,
// Vimeo
vimeo: /(?:vimeo\.com\/)([0-9]+)/i,
// Dailymotion
dailymotion: /(?:dailymotion\.com\/video\/|dai\.ly\/)([A-Za-z0-9]+)/i,
// TikTok
tiktok: /(?:tiktok\.com\/@[^/]+\/video\/|vm\.tiktok\.com\/)[A-Za-z0-9]+/i,
// Twitch
twitch: /(?:twitch\.tv\/videos\/|clips\.twitch\.tv\/)[A-Za-z0-9]+/i,
// 腾讯视频
tencent: /(?:v\.qq\.com\/x\/cover\/|v\.qq\.com\/x\/page\/)[A-Za-z0-9]+/i,
// 爱奇艺
iqiyi: /(?:iqiyi\.com\/v_)[A-Za-z0-9]+/i,
// 优酷
youku: /(?:youku\.com\/v_show\/id_)[A-Za-z0-9]+/i,
// Facebook/Meta
facebook: /(?:facebook\.com\/watch\/|fb\.watch\/)[A-Za-z0-9]+/i,
// Instagram
instagram: /(?:instagram\.com\/(?:p|reel)\/)[A-Za-z0-9_-]+/i,
// Twitter/X
twitter: /(?:twitter\.com\/[^/]+\/status\/|x\.com\/[^/]+\/status\/)[0-9]+/i,
// 抖音
douyin: /(?:douyin\.com\/video\/)[0-9]+/i,
// 快手
kuaishou: /(?:kuaishou\.com\/short-video\/)[A-Za-z0-9]+/i,
// 小红书
xiaohongshu: /(?:xiaohongshu\.com\/explore\/)[A-Za-z0-9]+/i,
// 微博视频
weibo: /(?:weibo\.com\/[^/]+\/[A-Za-z0-9]+|weibo\.cn\/sinaurl)/i,
// Rumble
rumble: /(?:rumble\.com\/)[A-Za-z0-9_-]+/i,
// Odysee
odysee: /(?:odysee\.com\/@[^/]+\/)[A-Za-z0-9_-]+/i,
// JW Player (通用嵌入式播放器)
jwplayer: /(?:jwplayer\.com\/players\/)[A-Za-z0-9_-]+/i,
// 通用视频文件扩展名
videoFile: /\.(mp4|avi|mov|wmv|flv|webm|mkv|m4v|3gp|ogv)(\?.*)?$/i,
// 通用视频流媒体
streaming: /(?:stream|live|video|watch|play).*\.(m3u8|mpd|f4m)(\?.*)?$/i
}
export type VideoProvider = keyof typeof VIDEO_PATTERNS
/**
* 检测URL是否为视频内容
* @param url 要检测的URL
* @returns 是否为视频URL
*/
export function isVideoUrl(url: string): boolean {
return Object.values(VIDEO_PATTERNS).some(pattern => pattern.test(url))
}
/**
* 检测URL属于哪个视频平台
* @param url 要检测的URL
* @returns 视频平台名称如果不是视频URL则返回null
*/
export function getVideoProvider(url: string): VideoProvider | null {
for (const [provider, pattern] of Object.entries(VIDEO_PATTERNS)) {
if (pattern.test(url)) {
return provider as VideoProvider
}
}
return null
}
/**
* 检测特定平台的视频URL
* @param url 要检测的URL
* @param provider 视频平台
* @returns 是否为指定平台的视频URL
*/
export function isVideoUrlFromProvider(url: string, provider: VideoProvider): boolean {
const pattern = VIDEO_PATTERNS[provider]
return pattern ? pattern.test(url) : false
}
/**
* 从URL中提取视频ID如果可能
* @param url 视频URL
* @returns 视频ID或null
*/
export function extractVideoId(url: string): string | null {
const provider = getVideoProvider(url)
if (!provider) return null
const pattern = VIDEO_PATTERNS[provider]
const match = url.match(pattern)
// 返回第一个捕获组(如果存在)
return match && match[1] ? match[1] : null
}
/**
* 获取支持的视频平台列表
* @returns 支持的视频平台名称数组
*/
export function getSupportedVideoProviders(): VideoProvider[] {
return Object.keys(VIDEO_PATTERNS) as VideoProvider[]
}
// 为了向后兼容保留原有的YouTube检测函数
export function isYoutubeUrl(url: string): boolean {
return isVideoUrlFromProvider(url, 'youtube')
}
// 导出常用的视频平台检测函数
export const isYouTubeUrl = isYoutubeUrl // 别名
export const isBilibiliUrl = (url: string) => isVideoUrlFromProvider(url, 'bilibili')
export const isVimeoUrl = (url: string) => isVideoUrlFromProvider(url, 'vimeo')
export const isTikTokUrl = (url: string) => isVideoUrlFromProvider(url, 'tiktok')
export const isTwitchUrl = (url: string) => isVideoUrlFromProvider(url, 'twitch')

View File

@@ -5,6 +5,7 @@ import { htmlToMarkdown, requestUrl } from 'obsidian';
import { JINA_BASE_URL, SERPER_BASE_URL } from '../constants';
import { RAGEngine } from '../core/rag/rag-engine';
import { isVideoUrl, getVideoProvider } from './video-detector';
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript';
@@ -172,18 +173,37 @@ async function filterByEmbedding(query: string, results: SearchResult[], ragEngi
}
async function fetchByLocalTool(url: string): Promise<string> {
if (isYoutubeUrl(url)) {
// TODO: pass language based on user preferences
const { title, transcript } =
await YoutubeTranscript.fetchTranscriptAndMetadata(url)
// 检查是否为视频内容
if (isVideoUrl(url)) {
const provider = getVideoProvider(url)
// 对于YouTube使用现有的转录功能
if (provider === 'youtube') {
try {
// TODO: pass language based on user preferences
const { title, transcript } =
await YoutubeTranscript.fetchTranscriptAndMetadata(url)
return `Title: ${title}
return `Title: ${title}
Video Transcript:
${transcript.map((t) => `${t.offset}: ${t.text}`).join('\n')}`
} catch (error) {
console.warn('Failed to extract YouTube transcript:', error)
// 如果转录失败,返回视频信息提示
return `Video Content Detected: ${url}
Platform: YouTube
Note: This is a video content. Transcript extraction failed. Please use specialized video processing tools for content analysis.`
}
}
// 对于其他视频平台,返回视频信息提示
return `Video Content Detected: ${url}
Platform: ${provider || 'Unknown'}
Note: This is a video content. Please use specialized video processing tools for content analysis.`
}
// 非视频内容,使用常规方式获取网页内容
const response = await requestUrl({ url })
return htmlToMarkdown(response.text)
}
@@ -236,7 +256,8 @@ async function fetchByJina(url: string, apiKey: string): Promise<string> {
export async function fetchUrlContent(url: string, apiKey: string): Promise<string | null> {
try {
if (isYoutubeUrl(url)) {
// 如果是视频内容,直接使用本地工具处理
if (isVideoUrl(url)) {
return await fetchByLocalTool(url);
}
let content: string | null = null;