add local embed

This commit is contained in:
duanfuxiang
2025-07-04 09:28:12 +08:00
parent cd65d6b3de
commit 65c5df3d22
22 changed files with 2156 additions and 195 deletions

View File

@@ -0,0 +1,274 @@
// Import the full embedding Worker
// @ts-nocheck
import EmbedWorker from './embed.worker';
// Type definitions shared between the manager and the embedding Worker.

// Result of embedding one input text.
export interface EmbedResult {
  vec: number[];          // embedding vector
  tokens: number;         // token count of the embedded (possibly truncated) input
  embed_input?: string;   // the original input text, echoed back by the Worker
}
// Response of the 'load' Worker method.
export interface ModelLoadResult {
  model_loaded: boolean;  // whether the model loaded successfully
}
// Response of the 'unload' Worker method.
export interface ModelUnloadResult {
  model_unloaded: boolean; // whether the model was unloaded successfully
}
// Response of the 'count_tokens' Worker method.
export interface TokenCountResult {
  tokens: number;         // token count of the input
}
/**
 * Manages a dedicated embedding Worker: loads/unloads models, generates
 * embeddings, and counts tokens. All Worker communication is promise-based
 * request/response matched by a monotonically increasing request id.
 */
export class EmbeddingManager {
  private worker: Worker;
  // In-flight RPC calls keyed by request id; settled when the Worker replies.
  private requests = new Map<number, { resolve: (value: any) => void; reject: (reason?: any) => void }>();
  private nextRequestId = 0;
  private isModelLoaded = false;
  private currentModelId: string | null = null;

  constructor() {
    // Create the Worker (same pattern as pgworker).
    this.worker = new EmbedWorker();
    // Single handler for every message coming back from the Worker.
    this.worker.onmessage = (event) => {
      const { id, result, error } = event.data;
      // Find the Promise callbacks registered for this request id.
      const request = this.requests.get(id);
      if (request) {
        if (error) {
          request.reject(new Error(error));
        } else {
          request.resolve(result);
        }
        // Drop the entry once the request is settled.
        this.requests.delete(id);
      }
    };
    this.worker.onerror = (error) => {
      console.error("EmbeddingWorker error:", error);
      // A Worker-level error is fatal for every pending request.
      this.requests.forEach(request => {
        request.reject(error);
      });
      this.requests.clear();
    };
  }

  /**
   * Send a request to the Worker and return a Promise that settles when the
   * matching response arrives.
   * @param method Method to invoke (e.g. 'load', 'embed_batch')
   * @param params Parameters for the method
   */
  private postRequest<T>(method: string, params: any): Promise<T> {
    return new Promise<T>((resolve, reject) => {
      const id = this.nextRequestId++;
      this.requests.set(id, { resolve, reject });
      this.worker.postMessage({ method, params, id });
    });
  }

  /**
   * Load the given embedding model into the Worker.
   * @param modelId Model id, e.g. 'Xenova/all-MiniLM-L6-v2'
   * @param useGpu Whether to use GPU acceleration (default false)
   */
  public async loadModel(modelId: string, useGpu: boolean = false): Promise<ModelLoadResult> {
    console.log(`Loading embedding model: ${modelId}, GPU: ${useGpu}`);
    try {
      // Same model already loaded: nothing to do.
      if (this.isModelLoaded && this.currentModelId === modelId) {
        console.log(`Model ${modelId} already loaded`);
        return { model_loaded: true };
      }
      // A different model is loaded: unload it first.
      if (this.isModelLoaded && this.currentModelId !== modelId) {
        console.log(`Unloading previous model: ${this.currentModelId}`);
        await this.unloadModel();
      }
      const result = await this.postRequest<ModelLoadResult>('load', {
        model_key: modelId,
        use_gpu: useGpu
      });
      this.isModelLoaded = result.model_loaded;
      this.currentModelId = result.model_loaded ? modelId : null;
      if (result.model_loaded) {
        console.log(`Model ${modelId} loaded successfully`);
      }
      return result;
    } catch (error) {
      console.error(`Failed to load model ${modelId}:`, error);
      this.isModelLoaded = false;
      this.currentModelId = null;
      throw error;
    }
  }

  /**
   * Generate embedding vectors for a batch of texts.
   * @param texts Texts to embed
   * @returns One result (vector + token info) per input text
   * @throws If no model is loaded
   */
  public async embedBatch(texts: string[]): Promise<EmbedResult[]> {
    if (!this.isModelLoaded) {
      throw new Error('Model not loaded. Please call loadModel() first.');
    }
    if (!texts || texts.length === 0) {
      return [];
    }
    console.log(`Generating embeddings for ${texts.length} texts`);
    try {
      const inputs = texts.map(text => ({ embed_input: text }));
      const results = await this.postRequest<EmbedResult[]>('embed_batch', { inputs });
      console.log(`Generated ${results.length} embeddings`);
      return results;
    } catch (error) {
      console.error('Failed to generate embeddings:', error);
      throw error;
    }
  }

  /**
   * Generate the embedding vector for a single text.
   * @param text Text to embed (must be non-empty)
   * @returns Vector and token info for the text
   * @throws If the text is empty or embedding fails
   */
  public async embed(text: string): Promise<EmbedResult> {
    if (!text || text.trim().length === 0) {
      throw new Error('Text cannot be empty');
    }
    const results = await this.embedBatch([text]);
    if (results.length === 0) {
      throw new Error('Failed to generate embedding');
    }
    return results[0];
  }

  /**
   * Count the tokens in a text using the loaded model's tokenizer.
   * @param text Text to tokenize
   * @throws If no model is loaded
   */
  public async countTokens(text: string): Promise<TokenCountResult> {
    if (!this.isModelLoaded) {
      throw new Error('Model not loaded. Please call loadModel() first.');
    }
    if (!text) {
      return { tokens: 0 };
    }
    try {
      return await this.postRequest<TokenCountResult>('count_tokens', text);
    } catch (error) {
      console.error('Failed to count tokens:', error);
      throw error;
    }
  }

  /**
   * Unload the current model and free its memory.
   */
  public async unloadModel(): Promise<ModelUnloadResult> {
    if (!this.isModelLoaded) {
      console.log('No model to unload');
      return { model_unloaded: true };
    }
    try {
      console.log(`Unloading model: ${this.currentModelId}`);
      const result = await this.postRequest<ModelUnloadResult>('unload', {});
      this.isModelLoaded = false;
      this.currentModelId = null;
      console.log('Model unloaded successfully');
      return result;
    } catch (error) {
      console.error('Failed to unload model:', error);
      // Reset state even when unloading fails.
      this.isModelLoaded = false;
      this.currentModelId = null;
      throw error;
    }
  }

  /** Whether a model is currently loaded. */
  public get modelLoaded(): boolean {
    return this.isModelLoaded;
  }

  /** Id of the currently loaded model, or null. */
  public get currentModel(): string | null {
    return this.currentModelId;
  }

  /** List of model ids this manager supports. */
  public getSupportedModels(): string[] {
    return [
      'Xenova/all-MiniLM-L6-v2',
      'Xenova/bge-small-en-v1.5',
      'Xenova/bge-base-en-v1.5',
      'Xenova/jina-embeddings-v2-base-zh',
      'Xenova/jina-embeddings-v2-small-en',
      'Xenova/multilingual-e5-small',
      'Xenova/multilingual-e5-base',
      'Xenova/gte-small',
      'Xenova/e5-small-v2',
      'Xenova/e5-base-v2'
    ];
  }

  /**
   * Look up static metadata for a model id.
   * @returns Dimensions, max token window, and a description — or null if unknown.
   */
  public getModelInfo(modelId: string): { dims: number; maxTokens: number; description: string } | null {
    const modelInfoMap: Record<string, { dims: number; maxTokens: number; description: string }> = {
      'Xenova/all-MiniLM-L6-v2': { dims: 384, maxTokens: 512, description: 'All-MiniLM-L6-v2 (推荐,轻量级)' },
      'Xenova/bge-small-en-v1.5': { dims: 384, maxTokens: 512, description: 'BGE-small-en-v1.5' },
      'Xenova/bge-base-en-v1.5': { dims: 768, maxTokens: 512, description: 'BGE-base-en-v1.5 (更高质量)' },
      'Xenova/jina-embeddings-v2-base-zh': { dims: 768, maxTokens: 8192, description: 'Jina-v2-base-zh (中英双语)' },
      'Xenova/jina-embeddings-v2-small-en': { dims: 512, maxTokens: 8192, description: 'Jina-v2-small-en' },
      'Xenova/multilingual-e5-small': { dims: 384, maxTokens: 512, description: 'E5-small (多语言)' },
      'Xenova/multilingual-e5-base': { dims: 768, maxTokens: 512, description: 'E5-base (多语言,更高质量)' },
      'Xenova/gte-small': { dims: 384, maxTokens: 512, description: 'GTE-small' },
      'Xenova/e5-small-v2': { dims: 384, maxTokens: 512, description: 'E5-small-v2' },
      'Xenova/e5-base-v2': { dims: 768, maxTokens: 512, description: 'E5-base-v2 (更高质量)' }
    };
    return modelInfoMap[modelId] || null;
  }

  /**
   * Terminate the Worker and release all resources.
   * Rejects any in-flight requests so their callers are not left pending
   * forever, and fully resets the loaded-model state.
   */
  public terminate() {
    // Settle pending promises before killing the Worker (which can no longer reply).
    this.requests.forEach(request => {
      request.reject(new Error('EmbeddingManager terminated'));
    });
    this.requests.clear();
    this.worker.terminate();
    this.isModelLoaded = false;
    this.currentModelId = null;
  }
}

171
src/embedworker/README.md Normal file
View File

@@ -0,0 +1,171 @@
# 本地嵌入功能
这个模块提供了在 Web Worker 中运行的本地嵌入功能,使用 Transformers.js 库来生成文本的向量表示。
## 功能特性
- 🚀 **高性能**: 在 Web Worker 中运行,不阻塞主线程
- 🔒 **隐私保护**: 完全本地运行,数据不离开设备
- 🎯 **多模型支持**: 支持多种预训练的嵌入模型
- 💾 **内存管理**: 自动管理模型加载和卸载
- 🔧 **类型安全**: 完整的 TypeScript 类型支持
## 快速开始
### 基本使用
```typescript
import { embeddingManager } from './embedworker';
// 加载模型
await embeddingManager.loadModel('Xenova/all-MiniLM-L6-v2');
// 生成单个文本的嵌入向量
const result = await embeddingManager.embed('Hello, world!');
console.log(result.vec); // [0.1234, -0.5678, ...]
console.log(result.tokens); // 3
// 批量生成嵌入向量
const texts = ['Hello', 'World', 'AI is amazing'];
const results = await embeddingManager.embedBatch(texts);
// 计算 token 数量
const tokenCount = await embeddingManager.countTokens('How many tokens?');
console.log(tokenCount.tokens); // 4
```
### 高级使用
```typescript
import { EmbeddingManager } from './embedworker';
// 创建自定义实例
const customEmbedding = new EmbeddingManager();
// 使用 GPU 加速(如果支持)
await customEmbedding.loadModel('Xenova/all-MiniLM-L6-v2', true);
// 检查模型状态
console.log(customEmbedding.modelLoaded); // true
console.log(customEmbedding.currentModel); // 'Xenova/all-MiniLM-L6-v2'
// 获取支持的模型列表
const models = customEmbedding.getSupportedModels();
console.log(models);
// 获取模型信息
const modelInfo = customEmbedding.getModelInfo('Xenova/all-MiniLM-L6-v2');
console.log(modelInfo); // { dims: 384, maxTokens: 512, description: '...' }
// 切换模型
await customEmbedding.loadModel('Xenova/bge-small-en-v1.5');
// 清理资源
await customEmbedding.unloadModel();
customEmbedding.terminate();
```
## 支持的模型
| 模型 | 维度 | 最大Token | 描述 |
|------|------|-----------|------|
| Xenova/all-MiniLM-L6-v2 | 384 | 512 | All-MiniLM-L6-v2 (推荐,轻量级) |
| Xenova/bge-small-en-v1.5 | 384 | 512 | BGE-small-en-v1.5 |
| Xenova/bge-base-en-v1.5 | 768 | 512 | BGE-base-en-v1.5 (更高质量) |
| Xenova/jina-embeddings-v2-base-zh | 768 | 8192 | Jina-v2-base-zh (中英双语) |
| Xenova/jina-embeddings-v2-small-en | 512 | 8192 | Jina-v2-small-en |
| Xenova/multilingual-e5-small | 384 | 512 | E5-small (多语言) |
| Xenova/multilingual-e5-base | 768 | 512 | E5-base (多语言,更高质量) |
| Xenova/gte-small | 384 | 512 | GTE-small |
| Xenova/e5-small-v2 | 384 | 512 | E5-small-v2 |
| Xenova/e5-base-v2 | 768 | 512 | E5-base-v2 (更高质量) |
## API 参考
### EmbeddingManager
#### 方法
- `loadModel(modelId: string, useGpu?: boolean): Promise<ModelLoadResult>`
- 加载指定的嵌入模型
- `modelId`: 模型标识符
- `useGpu`: 是否使用 GPU 加速(默认 false
- `embed(text: string): Promise<EmbedResult>`
- 为单个文本生成嵌入向量
- 返回包含向量和 token 数量的结果
- `embedBatch(texts: string[]): Promise<EmbedResult[]>`
- 为多个文本批量生成嵌入向量
- 更高效的批处理方式
- `countTokens(text: string): Promise<TokenCountResult>`
- 计算文本的 token 数量
- `unloadModel(): Promise<ModelUnloadResult>`
- 卸载当前模型,释放内存
- `terminate(): void`
- 终止 Worker释放所有资源
#### 属性
- `modelLoaded: boolean` - 模型是否已加载
- `currentModel: string | null` - 当前加载的模型ID
#### 工具方法
- `getSupportedModels(): string[]` - 获取支持的模型列表
- `getModelInfo(modelId: string)` - 获取模型详细信息
### 类型定义
```typescript
interface EmbedResult {
vec: number[]; // 嵌入向量
tokens: number; // token 数量
embed_input?: string; // 原始输入文本
}
interface ModelLoadResult {
model_loaded: boolean; // 是否加载成功
}
interface ModelUnloadResult {
model_unloaded: boolean; // 是否卸载成功
}
interface TokenCountResult {
tokens: number; // token 数量
}
```
## 错误处理
```typescript
try {
await embeddingManager.loadModel('invalid-model');
} catch (error) {
console.error('加载模型失败:', error.message);
}
try {
const result = await embeddingManager.embed('');
} catch (error) {
console.error('文本不能为空:', error.message);
}
```
## 性能考虑
1. **模型加载**: 首次加载模型需要下载和初始化,可能需要几秒到几分钟
2. **批处理**: 使用 `embedBatch` 比多次调用 `embed` 更高效
3. **内存使用**: 大模型需要更多内存,注意设备限制
4. **GPU 加速**: 在支持 WebGPU 的浏览器中可以启用 GPU 加速
## 注意事项
- 首次使用某个模型时需要从 Hugging Face 下载,请确保网络连接正常
- 模型文件会被浏览器缓存,后续使用会更快
- 在移动设备上使用大模型可能会遇到内存限制
- Worker 在后台运行,不会阻塞 UI 线程

View File

@@ -0,0 +1,353 @@
// Full embedding Worker, built on Transformers.js.
console.log('Embedding worker loaded');

// --- Type definitions ---

// One text to embed.
interface EmbedInput {
  embed_input: string;
}
// Result for one embedded text.
interface EmbedResult {
  vec: number[];          // embedding vector
  tokens: number;         // token count of the (possibly truncated) input
  embed_input?: string;   // original input text, echoed back
}
// Request envelope received from the main thread.
interface WorkerMessage {
  method: string;         // method name: 'load' | 'unload' | 'embed_batch' | 'count_tokens'
  params: any;            // method-specific parameters
  id: number;             // request id, echoed in the response
  worker_id?: string;     // optional caller-supplied worker tag
}
// Response envelope posted back to the main thread.
interface WorkerResponse {
  id: number;             // matches the request id
  result?: any;           // present on success
  error?: string;         // present on failure (Error.message)
  worker_id?: string;     // echoed from the request
}

// --- Worker-global state ---
let model: any = null;             // marker object describing the loaded model
let pipeline: any = null;          // feature-extraction pipeline instance
let tokenizer: any = null;         // tokenizer matching the loaded model
let processing_message = false;    // busy flag: an embed/count request is running
let transformersLoaded = false;    // whether the library import has completed
// Dynamically import Transformers.js and configure it for a browser Worker.
async function loadTransformers() {
  if (transformersLoaded) {
    return;
  }
  try {
    console.log('Loading Transformers.js...');
    // The legacy @xenova build is used here because it is more stable in Workers.
    const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');
    // Models are fetched from the hub only — no local model files.
    env.allowLocalModels = false;
    env.allowRemoteModels = true;
    // WASM backend settings. NOTE(review): the original comment said "single
    // thread" while setting 2 threads — the value is preserved as-is here.
    env.backends.onnx.wasm.simd = true;
    env.backends.onnx.wasm.numThreads = 2;
    // Browser environment: no filesystem, but the HTTP cache is available.
    env.useFS = false;
    env.useBrowserCache = true;
    // Stash the imported symbols on globalThis so other helpers can reach them.
    (globalThis as any).env = env;
    (globalThis as any).AutoTokenizer = AutoTokenizer;
    (globalThis as any).pipelineFactory = pipelineFactory;
    transformersLoaded = true;
    console.log('Transformers.js loaded successfully');
  } catch (error) {
    console.error('Failed to load Transformers.js:', error);
    throw new Error(`Failed to load Transformers.js: ${error}`);
  }
}
// Load the embedding pipeline and tokenizer for the requested model.
async function loadModel(modelKey: string, useGpu: boolean = false) {
  try {
    console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);
    // Make sure the library itself is available first.
    await loadTransformers();
    const registry = globalThis as any;
    const pipelineFactory = registry.pipelineFactory;
    const AutoTokenizer = registry.AutoTokenizer;
    // Pipeline options: quantized weights plus progress logging.
    const pipelineOpts: any = {
      quantized: true,
      progress_callback: (progress: any) => {
        console.log('Model loading progress:', progress);
      }
    };
    if (useGpu && typeof navigator !== 'undefined' && 'gpu' in navigator) {
      console.log('[Transformers] Attempting to use GPU');
      try {
        pipelineOpts.device = 'webgpu';
        pipelineOpts.dtype = 'fp32';
      } catch (error) {
        console.warn('[Transformers] GPU not available, falling back to CPU');
      }
    } else {
      console.log('[Transformers] Using CPU');
    }
    // Build the feature-extraction pipeline and its matching tokenizer.
    pipeline = await pipelineFactory('feature-extraction', modelKey, pipelineOpts);
    tokenizer = await AutoTokenizer.from_pretrained(modelKey);
    model = {
      loaded: true,
      model_key: modelKey,
      use_gpu: useGpu
    };
    console.log(`Model ${modelKey} loaded successfully`);
    return { model_loaded: true };
  } catch (error) {
    console.error('Error loading model:', error);
    throw new Error(`Failed to load model: ${error}`);
  }
}
// Release the pipeline, tokenizer, and model marker to free memory.
async function unloadModel() {
  try {
    console.log('Unloading model...');
    if (pipeline) {
      // Invoke destroy() when the pipeline implementation provides one.
      if (pipeline.destroy) {
        pipeline.destroy();
      }
      pipeline = null;
    }
    tokenizer = null;
    model = null;
    console.log('Model unloaded successfully');
    return { model_unloaded: true };
  } catch (error) {
    console.error('Error unloading model:', error);
    throw new Error(`Failed to unload model: ${error}`);
  }
}
// Tokenize a string and report how many tokens it produces.
async function countTokens(input: string) {
  try {
    if (!tokenizer) {
      throw new Error('Tokenizer not loaded');
    }
    const encoded = await tokenizer(input);
    return { tokens: encoded.input_ids.data.length };
  } catch (error) {
    console.error('Error counting tokens:', error);
    throw new Error(`Failed to count tokens: ${error}`);
  }
}
// Embed a batch of inputs, chunking the work to bound per-call size.
async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
  try {
    if (!pipeline || !tokenizer) {
      throw new Error('Model not loaded');
    }
    console.log(`Processing ${inputs.length} inputs`);
    // Drop empty strings before doing any work.
    const nonEmpty = inputs.filter(item => item.embed_input && item.embed_input.length > 0);
    if (nonEmpty.length === 0) {
      return [];
    }
    // Chunk size (tune as needed).
    const batchSize = 1;
    if (nonEmpty.length <= batchSize) {
      return await processBatch(nonEmpty);
    }
    console.log(`Processing ${nonEmpty.length} inputs in batches of ${batchSize}`);
    const collected: EmbedResult[] = [];
    for (let start = 0; start < nonEmpty.length; start += batchSize) {
      const chunk = nonEmpty.slice(start, start + batchSize);
      collected.push(...await processBatch(chunk));
    }
    return collected;
  } catch (error) {
    console.error('Error in embed batch:', error);
    throw new Error(`Failed to generate embeddings: ${error}`);
  }
}
// Process a single batch: count tokens, truncate over-long texts, run the
// embedding pipeline, and shape the results. On batch failure, falls back
// to embedding the items one at a time.
async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
  try {
    // Count tokens for every input up front.
    const tokens = await Promise.all(
      batchInputs.map(item => countTokens(item.embed_input))
    );
    // Prepare inputs, truncating any text that exceeds the model window.
    const maxTokens = 512; // token limit of most supported models
    const embedInputs = await Promise.all(
      batchInputs.map(async (item, i) => {
        if (tokens[i].tokens < maxTokens) {
          return item.embed_input;
        }
        // Iteratively shrink over-long text until it fits: estimate a char
        // budget from the token ratio (with a 0.9 safety margin), cut, re-count.
        let tokenCt = tokens[i].tokens;
        let truncatedInput = item.embed_input;
        while (tokenCt > maxTokens) {
          const pct = maxTokens / tokenCt;
          const maxChars = Math.floor(truncatedInput.length * pct * 0.9);
          truncatedInput = truncatedInput.substring(0, maxChars) + '...';
          tokenCt = (await countTokens(truncatedInput)).tokens;
        }
        tokens[i].tokens = tokenCt;
        return truncatedInput;
      })
    );
    // Run the pipeline (mean pooling, normalized vectors).
    const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true });
    // Round each component to 8 decimal places to keep payloads compact.
    return batchInputs.map((item, i) => ({
      vec: Array.from(resp[i].data).map((val: number) => Math.round(val * 1e8) / 1e8),
      tokens: tokens[i].tokens,
      embed_input: item.embed_input
    }));
  } catch (error) {
    console.error('Error processing batch:', error);
    // Batch failed — retry each item individually so one bad input does not
    // sink the whole batch.
    return Promise.all(
      batchInputs.map(async (item) => {
        try {
          const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
          const tokenCount = await countTokens(item.embed_input);
          // NOTE(review): this path indexes result[0] while the batch path
          // indexes resp[i] — presumably both pipeline calls return an
          // indexable tensor/array; confirm against the pipeline's return shape.
          return {
            vec: Array.from(result[0].data).map((val: number) => Math.round(val * 1e8) / 1e8),
            tokens: tokenCount.tokens,
            embed_input: item.embed_input
          };
        } catch (singleError) {
          console.error('Error processing single item:', singleError);
          // Last resort: empty vector plus the error message.
          return {
            vec: [],
            tokens: 0,
            embed_input: item.embed_input,
            error: (singleError as Error).message
          } as any;
        }
      })
    );
  }
}
// Dispatch one request from the main thread to the matching handler and
// build the response envelope. 'embed_batch' and 'count_tokens' are
// serialized through runExclusive so only one runs at a time.
async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
  const { method, params, id, worker_id } = data;
  try {
    let result: any;
    switch (method) {
      case 'load':
        console.log('Load method called with params:', params);
        result = await loadModel(params.model_key, params.use_gpu || false);
        break;
      case 'unload':
        console.log('Unload method called');
        result = await unloadModel();
        break;
      case 'embed_batch':
        console.log('Embed batch method called');
        result = await runExclusive(() => embedBatch(params.inputs));
        break;
      case 'count_tokens':
        console.log('Count tokens method called');
        result = await runExclusive(() => countTokens(params));
        break;
      default:
        throw new Error(`Unknown method: ${method}`);
    }
    return { id, result, worker_id };
  } catch (error) {
    console.error('Error processing message:', error);
    processing_message = false;
    return { id, error: (error as Error).message, worker_id };
  }
}

// Helper: require a loaded model, wait until no other request is running,
// then run fn with the busy flag held. The finally guarantees the flag is
// cleared even when fn throws (previously only the outer catch reset it).
async function runExclusive<T>(fn: () => Promise<T>): Promise<T> {
  if (!model) {
    throw new Error('Model not loaded');
  }
  // Poll until the previous request finishes. Single-threaded event loop:
  // the check-and-set below runs without an intervening await, so two
  // waiters cannot both claim the flag.
  while (processing_message) {
    await new Promise(resolve => setTimeout(resolve, 100));
  }
  processing_message = true;
  try {
    return await fn();
  } finally {
    processing_message = false;
  }
}
// Route every incoming message through processMessage and post the reply.
self.addEventListener('message', async (event) => {
  console.log('Worker received message:', event.data);
  const response = await processMessage(event.data);
  console.log('Worker sending response:', response);
  self.postMessage(response);
});
console.log('Embedding worker ready');

15
src/embedworker/index.ts Normal file
View File

@@ -0,0 +1,15 @@
import { EmbeddingManager } from "./EmbeddingManager";
// Singleton manager so the whole app shares a single Worker.
export const embeddingManager = new EmbeddingManager();
// Also export the class for callers that need their own instance.
export { EmbeddingManager };
// Re-export the public result types.
export type {
  EmbedResult,
  ModelLoadResult,
  ModelUnloadResult,
  TokenCountResult
} from './EmbeddingManager';