Optimize the search view component, add model selection functionality, support multiple search modes (notes, insights, all), update internationalization support, improve user interaction prompts, enhance log output, and ensure better user experience and code readability.

This commit is contained in:
duanfuxiang
2025-07-07 16:56:12 +08:00
parent 3db334c6e8
commit c89186a40d
11 changed files with 1393 additions and 593 deletions

View File

@@ -8,36 +8,153 @@ interface EmbedResult {
vec: number[];
tokens: number;
embed_input?: string;
error?: string;
}
// 定义工作器消息的参数类型
interface LoadParams {
model_key: string;
use_gpu?: boolean;
}
interface EmbedBatchParams {
inputs: EmbedInput[];
}
type WorkerParams = LoadParams | EmbedBatchParams | string | undefined;
interface WorkerMessage {
method: string;
params: any;
params: WorkerParams;
id: number;
worker_id?: string;
}
interface WorkerResponse {
id: number;
result?: any;
result?: unknown;
error?: string;
worker_id?: string;
}
// 定义 Transformers.js 相关类型
interface TransformersEnv {
allowLocalModels: boolean;
allowRemoteModels: boolean;
backends: {
onnx: {
wasm: {
numThreads: number;
simd: boolean;
};
};
};
useFS: boolean;
useBrowserCache: boolean;
remoteHost?: string;
}
interface PipelineOptions {
quantized?: boolean;
progress_callback?: (progress: unknown) => void;
device?: string;
dtype?: string;
}
interface ModelInfo {
loaded: boolean;
model_key: string;
use_gpu: boolean;
}
interface TokenizerResult {
input_ids: {
data: number[];
};
}
interface GlobalTransformers {
pipelineFactory: (task: string, model: string, options?: PipelineOptions) => Promise<unknown>;
AutoTokenizer: {
from_pretrained: (model: string) => Promise<unknown>;
};
env: TransformersEnv;
}
// 全局变量
let model: any = null;
let pipeline: any = null;
let tokenizer: any = null;
let model: ModelInfo | null = null;
let pipeline: unknown = null;
let tokenizer: unknown = null;
let processing_message = false;
let transformersLoaded = false;
/**
* 测试一个网络端点是否可访问
* @param {string} url 要测试的 URL
* @param {number} timeout 超时时间 (毫秒)
* @returns {Promise<boolean>} 如果可访问则返回 true否则返回 false
*/
async function testEndpoint(url: string, timeout = 3000): Promise<boolean> {
// AbortController 用于在超时后取消 fetch 请求
const controller = new AbortController();
const signal = controller.signal;
const timeoutId = setTimeout(() => {
console.log(`请求 ${url} 超时。`);
controller.abort();
}, timeout);
try {
console.log(`正在测试端点: ${url}`);
// 我们使用 'HEAD' 方法,因为它只请求头部信息,非常快速,适合做存活检测。
// 'no-cors' 模式允许我们在浏览器环境中进行跨域请求以进行简单的可达性测试,
// 即使我们不能读取响应内容,请求成功也意味着网络是通的。
await fetch(url, { method: 'HEAD', mode: 'no-cors', signal });
// 如果 fetch 成功,清除超时定时器并返回 true
clearTimeout(timeoutId);
console.log(`端点 ${url} 可访问。`);
return true;
} catch (error) {
// 如果发生网络错误或请求被中止 (超时),则进入 catch 块
clearTimeout(timeoutId); // 同样需要清除定时器
console.warn(`无法访问端点 ${url}:`, error instanceof Error && error.name === 'AbortError' ? '超时' : (error as Error).message);
return false;
}
}
/**
* 初始化 Hugging Face 端点,如果默认的不可用,则自动切换到备用镜像。
*/
async function initializeEndpoint(): Promise<void> {
const defaultEndpoint = 'https://huggingface.co';
const fallbackEndpoint = 'https://hf-mirror.com';
const isDefaultReachable = await testEndpoint(defaultEndpoint);
const globalTransformers = globalThis as unknown as { transformers?: GlobalTransformers };
if (!isDefaultReachable) {
console.log(`默认端点不可达,将切换到备用镜像: ${fallbackEndpoint}`);
// 这是关键步骤:在代码中设置 endpoint
if (globalTransformers.transformers?.env) {
globalTransformers.transformers.env.remoteHost = fallbackEndpoint;
}
} else {
console.log(`将使用默认端点: ${defaultEndpoint}`);
}
}
// 动态导入 Transformers.js
async function loadTransformers() {
async function loadTransformers(): Promise<void> {
if (transformersLoaded) return;
try {
console.log('Loading Transformers.js...');
// 首先初始化端点
await initializeEndpoint();
// 尝试使用旧版本的 Transformers.js它在 Worker 中更稳定
const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');
@@ -53,9 +170,12 @@ async function loadTransformers() {
env.useFS = false;
env.useBrowserCache = true;
(globalThis as any).pipelineFactory = pipelineFactory;
(globalThis as any).AutoTokenizer = AutoTokenizer;
(globalThis as any).env = env;
const globalTransformers = globalThis as unknown as { transformers?: GlobalTransformers };
globalTransformers.transformers = {
pipelineFactory,
AutoTokenizer,
env
};
transformersLoaded = true;
console.log('Transformers.js loaded successfully');
@@ -65,22 +185,27 @@ async function loadTransformers() {
}
}
async function loadModel(modelKey: string, useGpu: boolean = false) {
async function loadModel(modelKey: string, useGpu: boolean = false): Promise<{ model_loaded: boolean }> {
try {
console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);
// 确保 Transformers.js 已加载
await loadTransformers();
const pipelineFactory = (globalThis as any).pipelineFactory;
const AutoTokenizer = (globalThis as any).AutoTokenizer;
const env = (globalThis as any).env;
const globalTransformers = globalThis as unknown as { transformers?: GlobalTransformers };
const transformers = globalTransformers.transformers;
if (!transformers) {
throw new Error('Transformers.js not loaded');
}
const { pipelineFactory, AutoTokenizer } = transformers;
// 配置管道选项
const pipelineOpts: any = {
const pipelineOpts: PipelineOptions = {
quantized: true,
// 修复进度回调,添加错误处理
progress_callback: (progress: any) => {
progress_callback: (progress: unknown) => {
try {
if (progress && typeof progress === 'object') {
// console.log('Model loading progress:', progress);
@@ -96,9 +221,9 @@ async function loadModel(modelKey: string, useGpu: boolean = false) {
if (useGpu) {
try {
// 检查 WebGPU 支持
console.log("useGpu", useGpu)
console.log("useGpu", useGpu);
if (typeof navigator !== 'undefined' && 'gpu' in navigator) {
const gpu = (navigator as any).gpu;
const gpu = (navigator as { gpu?: { requestAdapter?: () => unknown } }).gpu;
if (gpu && typeof gpu.requestAdapter === 'function') {
console.log('[Transformers] Attempting to use GPU');
pipelineOpts.device = 'webgpu';
@@ -137,21 +262,17 @@ async function loadModel(modelKey: string, useGpu: boolean = false) {
}
}
async function unloadModel() {
async function unloadModel(): Promise<{ model_unloaded: boolean }> {
try {
console.log('Unloading model...');
if (pipeline) {
if (pipeline.destroy) {
pipeline.destroy();
}
pipeline = null;
}
if (tokenizer) {
tokenizer = null;
if (pipeline && typeof pipeline === 'object' && 'destroy' in pipeline) {
const pipelineWithDestroy = pipeline as { destroy: () => void };
pipelineWithDestroy.destroy();
}
pipeline = null;
tokenizer = null;
model = null;
console.log('Model unloaded successfully');
@@ -163,13 +284,14 @@ async function unloadModel() {
}
}
async function countTokens(input: string) {
async function countTokens(input: string): Promise<{ tokens: number }> {
try {
if (!tokenizer) {
throw new Error('Tokenizer not loaded');
}
const { input_ids } = await tokenizer(input);
const tokenizerWithCall = tokenizer as (input: string) => Promise<TokenizerResult>;
const { input_ids } = await tokenizerWithCall(input);
return { tokens: input_ids.data.length };
} catch (error) {
@@ -249,7 +371,8 @@ async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
);
// 生成嵌入向量
const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true });
const pipelineCall = pipeline as (inputs: string[], options: { pooling: string; normalize: boolean }) => Promise<{ data: number[] }[]>;
const resp = await pipelineCall(embedInputs, { pooling: 'mean', normalize: true });
// 处理结果
return batchInputs.map((item, i) => ({
@@ -262,10 +385,11 @@ async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
console.error('Error processing batch:', error);
// 如果批处理失败,尝试逐个处理
return Promise.all(
batchInputs.map(async (item) => {
const results = await Promise.all(
batchInputs.map(async (item): Promise<EmbedResult> => {
try {
const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
const pipelineCall = pipeline as (input: string, options: { pooling: string; normalize: boolean }) => Promise<{ data: number[] }[]>;
const result = await pipelineCall(item.embed_input, { pooling: 'mean', normalize: true });
const tokenCount = await countTokens(item.embed_input);
return {
@@ -279,11 +403,13 @@ async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
vec: [],
tokens: 0,
embed_input: item.embed_input,
error: (singleError as Error).message
} as any;
error: singleError instanceof Error ? singleError.message : 'Unknown error'
};
}
})
);
return results;
}
}
@@ -291,12 +417,13 @@ async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
const { method, params, id, worker_id } = data;
try {
let result: any;
let result: unknown;
switch (method) {
case 'load':
console.log('Load method called with params:', params);
result = await loadModel(params.model_key, params.use_gpu || false);
const loadParams = params as LoadParams;
result = await loadModel(loadParams.model_key, loadParams.use_gpu || false);
break;
case 'unload':
@@ -318,7 +445,8 @@ async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
}
processing_message = true;
result = await embedBatch(params.inputs);
const embedParams = params as EmbedBatchParams;
result = await embedBatch(embedParams.inputs);
processing_message = false;
break;
@@ -336,7 +464,8 @@ async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
}
processing_message = true;
result = await countTokens(params);
const tokenParams = params as string;
result = await countTokens(tokenParams);
processing_message = false;
break;
@@ -349,7 +478,7 @@ async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
} catch (error) {
console.error('Error processing message:', error);
processing_message = false;
return { id, error: (error as Error).message, worker_id };
return { id, error: error instanceof Error ? error.message : 'Unknown error', worker_id };
}
}
@@ -367,14 +496,14 @@ self.addEventListener('message', async (event) => {
return;
}
const response = await processMessage(event.data);
const response = await processMessage(event.data as WorkerMessage);
console.log('Worker sending response:', response);
self.postMessage(response);
} catch (error) {
console.error('Unhandled error in worker message handler:', error);
self.postMessage({
id: event.data?.id || -1,
error: `Worker error: ${error.message || 'Unknown error'}`
id: (event.data as { id?: number })?.id || -1,
error: `Worker error: ${error instanceof Error ? error.message : 'Unknown error'}`
});
}
});