mirror of
https://github.com/EthanMarti/infio-copilot.git
synced 2026-05-13 19:11:02 +00:00
add local embed
This commit is contained in:
274
src/embedworker/EmbeddingManager.ts
Normal file
274
src/embedworker/EmbeddingManager.ts
Normal file
@@ -0,0 +1,274 @@
|
||||
// Import the full embedding Worker (bundled separately; see embed.worker.ts).
// @ts-nocheck
import EmbedWorker from './embed.worker';

// ---- Result types shared with the worker protocol ----

// Result of embedding one text.
export interface EmbedResult {
	vec: number[];        // the embedding vector
	tokens: number;       // token count of the (possibly truncated) input
	embed_input?: string; // original input text, echoed back by the worker
}

// Response of the 'load' worker method.
export interface ModelLoadResult {
	model_loaded: boolean;
}

// Response of the 'unload' worker method.
export interface ModelUnloadResult {
	model_unloaded: boolean;
}

// Response of the 'count_tokens' worker method.
export interface TokenCountResult {
	tokens: number;
}
|
||||
|
||||
export class EmbeddingManager {
|
||||
private worker: Worker;
|
||||
private requests = new Map<number, { resolve: (value: any) => void; reject: (reason?: any) => void }>();
|
||||
private nextRequestId = 0;
|
||||
private isModelLoaded = false;
|
||||
private currentModelId: string | null = null;
|
||||
|
||||
constructor() {
|
||||
// 创建 Worker,使用与 pgworker 相同的模式
|
||||
this.worker = new EmbedWorker();
|
||||
|
||||
// 统一监听来自 Worker 的所有消息
|
||||
this.worker.onmessage = (event) => {
|
||||
const { id, result, error } = event.data;
|
||||
|
||||
// 根据返回的 id 找到对应的 Promise 回调
|
||||
const request = this.requests.get(id);
|
||||
|
||||
if (request) {
|
||||
if (error) {
|
||||
request.reject(new Error(error));
|
||||
} else {
|
||||
request.resolve(result);
|
||||
}
|
||||
// 完成后从 Map 中删除
|
||||
this.requests.delete(id);
|
||||
}
|
||||
};
|
||||
|
||||
this.worker.onerror = (error) => {
|
||||
console.error("EmbeddingWorker error:", error);
|
||||
// 拒绝所有待处理的请求
|
||||
this.requests.forEach(request => {
|
||||
request.reject(error);
|
||||
});
|
||||
this.requests.clear();
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 向 Worker 发送一个请求,并返回一个 Promise,该 Promise 将在收到响应时解析。
|
||||
* @param method 要调用的方法 (e.g., 'load', 'embed_batch')
|
||||
* @param params 方法所需的参数
|
||||
*/
|
||||
private postRequest<T>(method: string, params: any): Promise<T> {
|
||||
return new Promise<T>((resolve, reject) => {
|
||||
const id = this.nextRequestId++;
|
||||
this.requests.set(id, { resolve, reject });
|
||||
this.worker.postMessage({ method, params, id });
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 加载指定的嵌入模型到 Worker 中。
|
||||
* @param modelId 模型ID, 例如 'TaylorAI/bge-micro-v2'
|
||||
* @param useGpu 是否使用GPU加速,默认为false
|
||||
*/
|
||||
public async loadModel(modelId: string, useGpu: boolean = false): Promise<ModelLoadResult> {
|
||||
console.log(`Loading embedding model: ${modelId}, GPU: ${useGpu}`);
|
||||
|
||||
try {
|
||||
// 如果已经加载了相同的模型,直接返回
|
||||
if (this.isModelLoaded && this.currentModelId === modelId) {
|
||||
console.log(`Model ${modelId} already loaded`);
|
||||
return { model_loaded: true };
|
||||
}
|
||||
|
||||
// 如果加载了不同的模型,先卸载
|
||||
if (this.isModelLoaded && this.currentModelId !== modelId) {
|
||||
console.log(`Unloading previous model: ${this.currentModelId}`);
|
||||
await this.unloadModel();
|
||||
}
|
||||
|
||||
const result = await this.postRequest<ModelLoadResult>('load', {
|
||||
model_key: modelId,
|
||||
use_gpu: useGpu
|
||||
});
|
||||
|
||||
this.isModelLoaded = result.model_loaded;
|
||||
this.currentModelId = result.model_loaded ? modelId : null;
|
||||
|
||||
if (result.model_loaded) {
|
||||
console.log(`Model ${modelId} loaded successfully`);
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
console.error(`Failed to load model ${modelId}:`, error);
|
||||
this.isModelLoaded = false;
|
||||
this.currentModelId = null;
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 为一批文本生成嵌入向量。
|
||||
* @param texts 要处理的文本数组
|
||||
* @returns 返回一个包含向量和 token 信息的对象数组
|
||||
*/
|
||||
public async embedBatch(texts: string[]): Promise<EmbedResult[]> {
|
||||
if (!this.isModelLoaded) {
|
||||
throw new Error('Model not loaded. Please call loadModel() first.');
|
||||
}
|
||||
|
||||
if (!texts || texts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
console.log(`Generating embeddings for ${texts.length} texts`);
|
||||
|
||||
try {
|
||||
const inputs = texts.map(text => ({ embed_input: text }));
|
||||
const results = await this.postRequest<EmbedResult[]>('embed_batch', { inputs });
|
||||
|
||||
console.log(`Generated ${results.length} embeddings`);
|
||||
return results;
|
||||
} catch (error) {
|
||||
console.error('Failed to generate embeddings:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 为单个文本生成嵌入向量。
|
||||
* @param text 要处理的文本
|
||||
* @returns 返回包含向量和 token 信息的对象
|
||||
*/
|
||||
public async embed(text: string): Promise<EmbedResult> {
|
||||
if (!text || text.trim().length === 0) {
|
||||
throw new Error('Text cannot be empty');
|
||||
}
|
||||
|
||||
const results = await this.embedBatch([text]);
|
||||
if (results.length === 0) {
|
||||
throw new Error('Failed to generate embedding');
|
||||
}
|
||||
|
||||
return results[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算文本的 token 数量。
|
||||
* @param text 要计算的文本
|
||||
*/
|
||||
public async countTokens(text: string): Promise<TokenCountResult> {
|
||||
if (!this.isModelLoaded) {
|
||||
throw new Error('Model not loaded. Please call loadModel() first.');
|
||||
}
|
||||
|
||||
if (!text) {
|
||||
return { tokens: 0 };
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.postRequest<TokenCountResult>('count_tokens', text);
|
||||
} catch (error) {
|
||||
console.error('Failed to count tokens:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 卸载模型,释放内存。
|
||||
*/
|
||||
public async unloadModel(): Promise<ModelUnloadResult> {
|
||||
if (!this.isModelLoaded) {
|
||||
console.log('No model to unload');
|
||||
return { model_unloaded: true };
|
||||
}
|
||||
|
||||
try {
|
||||
console.log(`Unloading model: ${this.currentModelId}`);
|
||||
const result = await this.postRequest<ModelUnloadResult>('unload', {});
|
||||
|
||||
this.isModelLoaded = false;
|
||||
this.currentModelId = null;
|
||||
|
||||
console.log('Model unloaded successfully');
|
||||
return result;
|
||||
} catch (error) {
|
||||
console.error('Failed to unload model:', error);
|
||||
// 即使卸载失败,也重置状态
|
||||
this.isModelLoaded = false;
|
||||
this.currentModelId = null;
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查模型是否已加载。
|
||||
*/
|
||||
public get modelLoaded(): boolean {
|
||||
return this.isModelLoaded;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前加载的模型ID。
|
||||
*/
|
||||
public get currentModel(): string | null {
|
||||
return this.currentModelId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取支持的模型列表。
|
||||
*/
|
||||
public getSupportedModels(): string[] {
|
||||
return [
|
||||
'Xenova/all-MiniLM-L6-v2',
|
||||
'Xenova/bge-small-en-v1.5',
|
||||
'Xenova/bge-base-en-v1.5',
|
||||
'Xenova/jina-embeddings-v2-base-zh',
|
||||
'Xenova/jina-embeddings-v2-small-en',
|
||||
'Xenova/multilingual-e5-small',
|
||||
'Xenova/multilingual-e5-base',
|
||||
'Xenova/gte-small',
|
||||
'Xenova/e5-small-v2',
|
||||
'Xenova/e5-base-v2'
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型信息。
|
||||
*/
|
||||
public getModelInfo(modelId: string): { dims: number; maxTokens: number; description: string } | null {
|
||||
const modelInfoMap: Record<string, { dims: number; maxTokens: number; description: string }> = {
|
||||
'Xenova/all-MiniLM-L6-v2': { dims: 384, maxTokens: 512, description: 'All-MiniLM-L6-v2 (推荐,轻量级)' },
|
||||
'Xenova/bge-small-en-v1.5': { dims: 384, maxTokens: 512, description: 'BGE-small-en-v1.5' },
|
||||
'Xenova/bge-base-en-v1.5': { dims: 768, maxTokens: 512, description: 'BGE-base-en-v1.5 (更高质量)' },
|
||||
'Xenova/jina-embeddings-v2-base-zh': { dims: 768, maxTokens: 8192, description: 'Jina-v2-base-zh (中英双语)' },
|
||||
'Xenova/jina-embeddings-v2-small-en': { dims: 512, maxTokens: 8192, description: 'Jina-v2-small-en' },
|
||||
'Xenova/multilingual-e5-small': { dims: 384, maxTokens: 512, description: 'E5-small (多语言)' },
|
||||
'Xenova/multilingual-e5-base': { dims: 768, maxTokens: 512, description: 'E5-base (多语言,更高质量)' },
|
||||
'Xenova/gte-small': { dims: 384, maxTokens: 512, description: 'GTE-small' },
|
||||
'Xenova/e5-small-v2': { dims: 384, maxTokens: 512, description: 'E5-small-v2' },
|
||||
'Xenova/e5-base-v2': { dims: 768, maxTokens: 512, description: 'E5-base-v2 (更高质量)' }
|
||||
};
|
||||
|
||||
return modelInfoMap[modelId] || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 终止 Worker,释放资源。
|
||||
*/
|
||||
public terminate() {
|
||||
this.worker.terminate();
|
||||
this.requests.clear();
|
||||
this.isModelLoaded = false;
|
||||
}
|
||||
}
|
||||
171
src/embedworker/README.md
Normal file
171
src/embedworker/README.md
Normal file
@@ -0,0 +1,171 @@
|
||||
# 本地嵌入功能
|
||||
|
||||
这个模块提供了在 Web Worker 中运行的本地嵌入功能,使用 Transformers.js 库来生成文本的向量表示。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 🚀 **高性能**: 在 Web Worker 中运行,不阻塞主线程
|
||||
- 🔒 **隐私保护**: 完全本地运行,数据不离开设备
|
||||
- 🎯 **多模型支持**: 支持多种预训练的嵌入模型
|
||||
- 💾 **内存管理**: 自动管理模型加载和卸载
|
||||
- 🔧 **类型安全**: 完整的 TypeScript 类型支持
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 基本使用
|
||||
|
||||
```typescript
|
||||
import { embeddingManager } from './embedworker';
|
||||
|
||||
// 加载模型
|
||||
await embeddingManager.loadModel('Xenova/all-MiniLM-L6-v2');
|
||||
|
||||
// 生成单个文本的嵌入向量
|
||||
const result = await embeddingManager.embed('Hello, world!');
|
||||
console.log(result.vec); // [0.1234, -0.5678, ...]
|
||||
console.log(result.tokens); // 3
|
||||
|
||||
// 批量生成嵌入向量
|
||||
const texts = ['Hello', 'World', 'AI is amazing'];
|
||||
const results = await embeddingManager.embedBatch(texts);
|
||||
|
||||
// 计算 token 数量
|
||||
const tokenCount = await embeddingManager.countTokens('How many tokens?');
|
||||
console.log(tokenCount.tokens); // 4
|
||||
```
|
||||
|
||||
### 高级使用
|
||||
|
||||
```typescript
|
||||
import { EmbeddingManager } from './embedworker';
|
||||
|
||||
// 创建自定义实例
|
||||
const customEmbedding = new EmbeddingManager();
|
||||
|
||||
// 使用 GPU 加速(如果支持)
|
||||
await customEmbedding.loadModel('Xenova/all-MiniLM-L6-v2', true);
|
||||
|
||||
// 检查模型状态
|
||||
console.log(customEmbedding.modelLoaded); // true
|
||||
console.log(customEmbedding.currentModel); // 'Xenova/all-MiniLM-L6-v2'
|
||||
|
||||
// 获取支持的模型列表
|
||||
const models = customEmbedding.getSupportedModels();
|
||||
console.log(models);
|
||||
|
||||
// 获取模型信息
|
||||
const modelInfo = customEmbedding.getModelInfo('Xenova/all-MiniLM-L6-v2');
|
||||
console.log(modelInfo); // { dims: 384, maxTokens: 512, description: '...' }
|
||||
|
||||
// 切换模型
|
||||
await customEmbedding.loadModel('Xenova/bge-small-en-v1.5');
|
||||
|
||||
// 清理资源
|
||||
await customEmbedding.unloadModel();
|
||||
customEmbedding.terminate();
|
||||
```
|
||||
|
||||
## 支持的模型
|
||||
|
||||
| 模型 | 维度 | 最大Token | 描述 |
|
||||
|------|------|-----------|------|
|
||||
| Xenova/all-MiniLM-L6-v2 | 384 | 512 | All-MiniLM-L6-v2 (推荐,轻量级) |
|
||||
| Xenova/bge-small-en-v1.5 | 384 | 512 | BGE-small-en-v1.5 |
|
||||
| Xenova/bge-base-en-v1.5 | 768 | 512 | BGE-base-en-v1.5 (更高质量) |
|
||||
| Xenova/jina-embeddings-v2-base-zh | 768 | 8192 | Jina-v2-base-zh (中英双语) |
|
||||
| Xenova/jina-embeddings-v2-small-en | 512 | 8192 | Jina-v2-small-en |
|
||||
| Xenova/multilingual-e5-small | 384 | 512 | E5-small (多语言) |
|
||||
| Xenova/multilingual-e5-base | 768 | 512 | E5-base (多语言,更高质量) |
|
||||
| Xenova/gte-small | 384 | 512 | GTE-small |
|
||||
| Xenova/e5-small-v2 | 384 | 512 | E5-small-v2 |
|
||||
| Xenova/e5-base-v2 | 768 | 512 | E5-base-v2 (更高质量) |
|
||||
|
||||
## API 参考
|
||||
|
||||
### EmbeddingManager
|
||||
|
||||
#### 方法
|
||||
|
||||
- `loadModel(modelId: string, useGpu?: boolean): Promise<ModelLoadResult>`
|
||||
- 加载指定的嵌入模型
|
||||
- `modelId`: 模型标识符
|
||||
- `useGpu`: 是否使用 GPU 加速(默认 false)
|
||||
|
||||
- `embed(text: string): Promise<EmbedResult>`
|
||||
- 为单个文本生成嵌入向量
|
||||
- 返回包含向量和 token 数量的结果
|
||||
|
||||
- `embedBatch(texts: string[]): Promise<EmbedResult[]>`
|
||||
- 为多个文本批量生成嵌入向量
|
||||
- 更高效的批处理方式
|
||||
|
||||
- `countTokens(text: string): Promise<TokenCountResult>`
|
||||
- 计算文本的 token 数量
|
||||
|
||||
- `unloadModel(): Promise<ModelUnloadResult>`
|
||||
- 卸载当前模型,释放内存
|
||||
|
||||
- `terminate(): void`
|
||||
- 终止 Worker,释放所有资源
|
||||
|
||||
#### 属性
|
||||
|
||||
- `modelLoaded: boolean` - 模型是否已加载
|
||||
- `currentModel: string | null` - 当前加载的模型ID
|
||||
|
||||
#### 工具方法
|
||||
|
||||
- `getSupportedModels(): string[]` - 获取支持的模型列表
|
||||
- `getModelInfo(modelId: string)` - 获取模型详细信息
|
||||
|
||||
### 类型定义
|
||||
|
||||
```typescript
|
||||
interface EmbedResult {
|
||||
vec: number[]; // 嵌入向量
|
||||
tokens: number; // token 数量
|
||||
embed_input?: string; // 原始输入文本
|
||||
}
|
||||
|
||||
interface ModelLoadResult {
|
||||
model_loaded: boolean; // 是否加载成功
|
||||
}
|
||||
|
||||
interface ModelUnloadResult {
|
||||
model_unloaded: boolean; // 是否卸载成功
|
||||
}
|
||||
|
||||
interface TokenCountResult {
|
||||
tokens: number; // token 数量
|
||||
}
|
||||
```
|
||||
|
||||
## 错误处理
|
||||
|
||||
```typescript
|
||||
try {
|
||||
await embeddingManager.loadModel('invalid-model');
|
||||
} catch (error) {
|
||||
console.error('加载模型失败:', error.message);
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await embeddingManager.embed('');
|
||||
} catch (error) {
|
||||
console.error('文本不能为空:', error.message);
|
||||
}
|
||||
```
|
||||
|
||||
## 性能考虑
|
||||
|
||||
1. **模型加载**: 首次加载模型需要下载和初始化,可能需要几秒到几分钟
|
||||
2. **批处理**: 使用 `embedBatch` 比多次调用 `embed` 更高效
|
||||
3. **内存使用**: 大模型需要更多内存,注意设备限制
|
||||
4. **GPU 加速**: 在支持 WebGPU 的浏览器中可以启用 GPU 加速
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 首次使用某个模型时需要从 Hugging Face 下载,请确保网络连接正常
|
||||
- 模型文件会被浏览器缓存,后续使用会更快
|
||||
- 在移动设备上使用大模型可能会遇到内存限制
|
||||
- Worker 在后台运行,不会阻塞 UI 线程
|
||||
353
src/embedworker/embed.worker.ts
Normal file
353
src/embedworker/embed.worker.ts
Normal file
@@ -0,0 +1,353 @@
|
||||
// Full embedding worker built on Transformers.js.
console.log('Embedding worker loaded');

// ---- Protocol types ----

// One item of an 'embed_batch' request.
interface EmbedInput {
	embed_input: string;
}

// One embedding result posted back to the main thread.
interface EmbedResult {
	vec: number[];        // embedding vector
	tokens: number;       // token count of the embedded (possibly truncated) text
	embed_input?: string; // the original input text
}

// Request envelope received from the main thread.
interface WorkerMessage {
	method: string;     // 'load' | 'unload' | 'embed_batch' | 'count_tokens'
	params: any;
	id: number;         // request id, echoed back in the response
	worker_id?: string;
}

// Response envelope posted back to the main thread.
interface WorkerResponse {
	id: number;
	result?: any;
	error?: string;
	worker_id?: string;
}

// ---- Worker-global state ----
let model: any = null;            // metadata of the currently loaded model
let pipeline: any = null;         // Transformers.js feature-extraction pipeline
let tokenizer: any = null;        // Transformers.js tokenizer
let processing_message = false;   // busy flag serializing embed/count work
let transformersLoaded = false;   // whether Transformers.js has been imported yet
|
||||
|
||||
// 动态导入 Transformers.js
|
||||
async function loadTransformers() {
|
||||
if (transformersLoaded) return;
|
||||
|
||||
try {
|
||||
console.log('Loading Transformers.js...');
|
||||
|
||||
// 尝试使用旧版本的 Transformers.js,它在 Worker 中更稳定
|
||||
const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');
|
||||
|
||||
// 配置环境以适应浏览器 Worker
|
||||
env.allowLocalModels = false;
|
||||
env.allowRemoteModels = true;
|
||||
|
||||
// 配置 WASM 后端
|
||||
env.backends.onnx.wasm.numThreads = 2; // 在 Worker 中使用单线程
|
||||
env.backends.onnx.wasm.simd = true;
|
||||
|
||||
// 禁用 Node.js 特定功能
|
||||
env.useFS = false;
|
||||
env.useBrowserCache = true;
|
||||
|
||||
// 存储导入的函数
|
||||
(globalThis as any).pipelineFactory = pipelineFactory;
|
||||
(globalThis as any).AutoTokenizer = AutoTokenizer;
|
||||
(globalThis as any).env = env;
|
||||
|
||||
transformersLoaded = true;
|
||||
console.log('Transformers.js loaded successfully');
|
||||
} catch (error) {
|
||||
console.error('Failed to load Transformers.js:', error);
|
||||
throw new Error(`Failed to load Transformers.js: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 加载模型
|
||||
async function loadModel(modelKey: string, useGpu: boolean = false) {
|
||||
try {
|
||||
console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);
|
||||
|
||||
// 确保 Transformers.js 已加载
|
||||
await loadTransformers();
|
||||
|
||||
const pipelineFactory = (globalThis as any).pipelineFactory;
|
||||
const AutoTokenizer = (globalThis as any).AutoTokenizer;
|
||||
const env = (globalThis as any).env;
|
||||
|
||||
// 配置管道选项
|
||||
const pipelineOpts: any = {
|
||||
quantized: true,
|
||||
progress_callback: (progress: any) => {
|
||||
console.log('Model loading progress:', progress);
|
||||
}
|
||||
};
|
||||
|
||||
if (useGpu && typeof navigator !== 'undefined' && 'gpu' in navigator) {
|
||||
console.log('[Transformers] Attempting to use GPU');
|
||||
try {
|
||||
pipelineOpts.device = 'webgpu';
|
||||
pipelineOpts.dtype = 'fp32';
|
||||
} catch (error) {
|
||||
console.warn('[Transformers] GPU not available, falling back to CPU');
|
||||
}
|
||||
} else {
|
||||
console.log('[Transformers] Using CPU');
|
||||
}
|
||||
|
||||
// 创建嵌入管道
|
||||
pipeline = await pipelineFactory('feature-extraction', modelKey, pipelineOpts);
|
||||
|
||||
// 创建分词器
|
||||
tokenizer = await AutoTokenizer.from_pretrained(modelKey);
|
||||
|
||||
model = {
|
||||
loaded: true,
|
||||
model_key: modelKey,
|
||||
use_gpu: useGpu
|
||||
};
|
||||
|
||||
console.log(`Model ${modelKey} loaded successfully`);
|
||||
return { model_loaded: true };
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error loading model:', error);
|
||||
throw new Error(`Failed to load model: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 卸载模型
|
||||
async function unloadModel() {
|
||||
try {
|
||||
console.log('Unloading model...');
|
||||
|
||||
if (pipeline) {
|
||||
if (pipeline.destroy) {
|
||||
pipeline.destroy();
|
||||
}
|
||||
pipeline = null;
|
||||
}
|
||||
|
||||
if (tokenizer) {
|
||||
tokenizer = null;
|
||||
}
|
||||
|
||||
model = null;
|
||||
|
||||
console.log('Model unloaded successfully');
|
||||
return { model_unloaded: true };
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error unloading model:', error);
|
||||
throw new Error(`Failed to unload model: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算 token 数量
|
||||
async function countTokens(input: string) {
|
||||
try {
|
||||
if (!tokenizer) {
|
||||
throw new Error('Tokenizer not loaded');
|
||||
}
|
||||
|
||||
const { input_ids } = await tokenizer(input);
|
||||
return { tokens: input_ids.data.length };
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error counting tokens:', error);
|
||||
throw new Error(`Failed to count tokens: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 生成嵌入向量
|
||||
async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
|
||||
try {
|
||||
if (!pipeline || !tokenizer) {
|
||||
throw new Error('Model not loaded');
|
||||
}
|
||||
|
||||
console.log(`Processing ${inputs.length} inputs`);
|
||||
|
||||
// 过滤空输入
|
||||
const filteredInputs = inputs.filter(item => item.embed_input && item.embed_input.length > 0);
|
||||
|
||||
if (filteredInputs.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
// 批处理大小(可以根据需要调整)
|
||||
const batchSize = 1;
|
||||
|
||||
if (filteredInputs.length > batchSize) {
|
||||
console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);
|
||||
const results: EmbedResult[] = [];
|
||||
|
||||
for (let i = 0; i < filteredInputs.length; i += batchSize) {
|
||||
const batch = filteredInputs.slice(i, i + batchSize);
|
||||
const batchResults = await processBatch(batch);
|
||||
results.push(...batchResults);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
return await processBatch(filteredInputs);
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error in embed batch:', error);
|
||||
throw new Error(`Failed to generate embeddings: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Process one sub-batch: count tokens, truncate over-long inputs, run the
// embedding pipeline, and round vector components. If the batched pipeline
// call fails, falls back to embedding each item individually.
async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
	try {
		// Token count for each input (drives the truncation decision below).
		const tokens = await Promise.all(
			batchInputs.map(item => countTokens(item.embed_input))
		);

		// Prepare the pipeline inputs, truncating texts over the limit.
		const maxTokens = 512; // token limit of most embedding models
		const embedInputs = await Promise.all(
			batchInputs.map(async (item, i) => {
				if (tokens[i].tokens < maxTokens) {
					return item.embed_input;
				}

				// Iteratively shorten the text until it fits: estimate a
				// character budget from the token ratio (with a 10% safety
				// margin), cut, re-count, repeat.
				let tokenCt = tokens[i].tokens;
				let truncatedInput = item.embed_input;

				while (tokenCt > maxTokens) {
					const pct = maxTokens / tokenCt;
					const maxChars = Math.floor(truncatedInput.length * pct * 0.9);
					truncatedInput = truncatedInput.substring(0, maxChars) + '...';
					tokenCt = (await countTokens(truncatedInput)).tokens;
				}

				// Record the post-truncation token count for the result.
				tokens[i].tokens = tokenCt;
				return truncatedInput;
			})
		);

		// Run the embedding pipeline over the whole sub-batch.
		const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true });

		// Build results; vector values are rounded to 8 decimal places.
		// NOTE(review): assumes resp is indexable per input — confirm against
		// the pipeline's batched output shape.
		return batchInputs.map((item, i) => ({
			vec: Array.from(resp[i].data).map((val: number) => Math.round(val * 1e8) / 1e8),
			tokens: tokens[i].tokens,
			embed_input: item.embed_input
		}));

	} catch (error) {
		console.error('Error processing batch:', error);

		// Batched call failed: retry each item individually so one bad input
		// doesn't sink the whole batch. Items that still fail produce an
		// empty vector with an attached error message.
		return Promise.all(
			batchInputs.map(async (item) => {
				try {
					const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
					const tokenCount = await countTokens(item.embed_input);

					return {
						vec: Array.from(result[0].data).map((val: number) => Math.round(val * 1e8) / 1e8),
						tokens: tokenCount.tokens,
						embed_input: item.embed_input
					};
				} catch (singleError) {
					console.error('Error processing single item:', singleError);
					return {
						vec: [],
						tokens: 0,
						embed_input: item.embed_input,
						error: (singleError as Error).message
					} as any;
				}
			})
		);
	}
}
|
||||
|
||||
// Dispatch one worker message to the matching handler and wrap the outcome
// in a WorkerResponse. Errors are returned in the envelope (not thrown) so
// the main thread can reject the corresponding Promise by id.
async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
	const { method, params, id, worker_id } = data;

	try {
		let result: any;

		switch (method) {
			case 'load':
				console.log('Load method called with params:', params);
				result = await loadModel(params.model_key, params.use_gpu || false);
				break;

			case 'unload':
				console.log('Unload method called');
				result = await unloadModel();
				break;

			case 'embed_batch':
				console.log('Embed batch method called');
				if (!model) {
					throw new Error('Model not loaded');
				}

				// Wait for any in-flight embed/count work to finish
				// (polling busy flag; checked and set between awaits).
				if (processing_message) {
					while (processing_message) {
						await new Promise(resolve => setTimeout(resolve, 100));
					}
				}

				processing_message = true;
				result = await embedBatch(params.inputs);
				processing_message = false;
				break;

			case 'count_tokens':
				console.log('Count tokens method called');
				if (!model) {
					throw new Error('Model not loaded');
				}

				// Same serialization as embed_batch. Note: params is the raw
				// text string here, not an object.
				if (processing_message) {
					while (processing_message) {
						await new Promise(resolve => setTimeout(resolve, 100));
					}
				}

				processing_message = true;
				result = await countTokens(params);
				processing_message = false;
				break;

			default:
				throw new Error(`Unknown method: ${method}`);
		}

		return { id, result, worker_id };

	} catch (error) {
		console.error('Error processing message:', error);
		// Clear the busy flag so a failed request can't deadlock the queue.
		processing_message = false;
		return { id, error: (error as Error).message, worker_id };
	}
}
|
||||
|
||||
// Message entry point: every request flows through processMessage and the
// response envelope (carrying the original request id) is posted back.
self.addEventListener('message', async (event) => {
	console.log('Worker received message:', event.data);
	const response = await processMessage(event.data);
	console.log('Worker sending response:', response);
	self.postMessage(response);
});

console.log('Embedding worker ready');
|
||||
15
src/embedworker/index.ts
Normal file
15
src/embedworker/index.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import { EmbeddingManager } from "./EmbeddingManager";

// Singleton manager so the whole application shares one embedding Worker.
export const embeddingManager = new EmbeddingManager();

// Re-export the class for callers that need their own instance.
export { EmbeddingManager };

// Re-export the protocol/result types.
export type {
	EmbedResult,
	ModelLoadResult,
	ModelUnloadResult,
	TokenCountResult
} from './EmbeddingManager';
|
||||
Reference in New Issue
Block a user