feat: 模型数据管理
feat: 模型数据导入 feat: redis 向量入库 feat: 向量索引 feat: 文件导入模型 perf: 交互 perf: prompt
This commit is contained in:
@@ -1,29 +1,26 @@
|
||||
import { DataItem } from '@/service/mongo';
|
||||
import { SplitData, ModelData } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
|
||||
import type { ChatCompletionRequestMessage } from 'openai';
|
||||
import { DataItemSchema } from '@/types/mongoSchema';
|
||||
import { ChatModelNameEnum } from '@/constants/model';
|
||||
import { pushSplitDataBill } from '@/service/events/pushBill';
|
||||
import { generateVector } from './generateVector';
|
||||
import { customAlphabet } from 'nanoid';
|
||||
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
|
||||
|
||||
export async function generateQA(next = false): Promise<any> {
|
||||
if (process.env.NODE_ENV === 'development') return;
|
||||
|
||||
if (global.generatingQA && !next) return;
|
||||
global.generatingQA = true;
|
||||
|
||||
const systemPrompt: ChatCompletionRequestMessage = {
|
||||
role: 'system',
|
||||
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,请按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
|
||||
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,并按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
|
||||
};
|
||||
let dataItem: DataItemSchema | null = null;
|
||||
|
||||
try {
|
||||
// 找出一个需要生成的 dataItem
|
||||
dataItem = await DataItem.findOne({
|
||||
status: { $ne: 0 },
|
||||
times: { $gt: 0 },
|
||||
type: 'QA'
|
||||
const dataItem = await SplitData.findOne({
|
||||
textList: { $exists: true, $ne: [] }
|
||||
});
|
||||
|
||||
if (!dataItem) {
|
||||
@@ -32,10 +29,13 @@ export async function generateQA(next = false): Promise<any> {
|
||||
return;
|
||||
}
|
||||
|
||||
// 更新状态为生成中
|
||||
await DataItem.findByIdAndUpdate(dataItem._id, {
|
||||
status: 2
|
||||
});
|
||||
// 弹出文本
|
||||
await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } });
|
||||
|
||||
const text = dataItem.textList[dataItem.textList.length - 1];
|
||||
if (!text) {
|
||||
throw new Error('无文本');
|
||||
}
|
||||
|
||||
// 获取 openapi Key
|
||||
let userApiKey, systemKey;
|
||||
@@ -44,10 +44,10 @@ export async function generateQA(next = false): Promise<any> {
|
||||
userApiKey = key.userApiKey;
|
||||
systemKey = key.systemKey;
|
||||
} catch (error) {
|
||||
// 余额不够了, 把用户所有记录改成闲置
|
||||
await DataItem.updateMany({
|
||||
userId: dataItem.userId,
|
||||
status: 0
|
||||
// 余额不够了, 清空该记录
|
||||
await SplitData.findByIdAndUpdate(dataItem._id, {
|
||||
textList: [],
|
||||
errorText: '余额不足,生成数据集任务终止'
|
||||
});
|
||||
throw new Error('获取 openai key 失败');
|
||||
}
|
||||
@@ -59,84 +59,71 @@ export async function generateQA(next = false): Promise<any> {
|
||||
// 获取 openai 请求实例
|
||||
const chatAPI = getOpenAIApi(userApiKey || systemKey);
|
||||
// 请求 chatgpt 获取回答
|
||||
const response = await Promise.allSettled(
|
||||
[0.2, 0.8].map(
|
||||
(temperature) =>
|
||||
chatAPI
|
||||
.createChatCompletion(
|
||||
{
|
||||
model: ChatModelNameEnum.GPT35,
|
||||
temperature: temperature,
|
||||
n: 1,
|
||||
messages: [
|
||||
systemPrompt,
|
||||
{
|
||||
role: 'user',
|
||||
content: dataItem?.text || ''
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
timeout: 120000,
|
||||
httpsAgent
|
||||
}
|
||||
)
|
||||
.then((res) => ({
|
||||
rawContent: res?.data.choices[0].message?.content || '',
|
||||
result: splitText(res?.data.choices[0].message?.content || '')
|
||||
})) // 从 content 中提取 QA
|
||||
)
|
||||
);
|
||||
// 过滤出成功的响应
|
||||
const successResponse: {
|
||||
rawContent: string;
|
||||
result: { q: string; a: string }[];
|
||||
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
|
||||
|
||||
const rawContents = successResponse.map((item) => item.rawContent);
|
||||
const results = successResponse.map((item) => item.result).flat();
|
||||
|
||||
// 插入数据库,并修改状态
|
||||
await DataItem.findByIdAndUpdate(dataItem._id, {
|
||||
status: 0,
|
||||
$push: {
|
||||
rawResponse: {
|
||||
$each: successResponse.map((item) => item.rawContent)
|
||||
const response = await chatAPI
|
||||
.createChatCompletion(
|
||||
{
|
||||
model: ChatModelNameEnum.GPT35,
|
||||
temperature: 0.2,
|
||||
n: 1,
|
||||
messages: [
|
||||
systemPrompt,
|
||||
{
|
||||
role: 'user',
|
||||
content: text
|
||||
}
|
||||
]
|
||||
},
|
||||
result: {
|
||||
$each: results
|
||||
{
|
||||
timeout: 120000,
|
||||
httpsAgent
|
||||
}
|
||||
}
|
||||
});
|
||||
)
|
||||
.then((res) => ({
|
||||
rawContent: res?.data.choices[0].message?.content || '',
|
||||
result: splitText(res?.data.choices[0].message?.content || '')
|
||||
})); // 从 content 中提取 QA
|
||||
|
||||
// 插入 modelData 表,生成向量
|
||||
await ModelData.insertMany(
|
||||
response.result.map((item) => ({
|
||||
modelId: dataItem.modelId,
|
||||
userId: dataItem.userId,
|
||||
text: item.a,
|
||||
q: [
|
||||
{
|
||||
id: nanoid(),
|
||||
text: item.q
|
||||
}
|
||||
],
|
||||
status: 1
|
||||
}))
|
||||
);
|
||||
|
||||
console.log(
|
||||
'生成QA成功,time:',
|
||||
`${(Date.now() - startTime) / 1000}s`,
|
||||
'QA数量:',
|
||||
results.length
|
||||
response.result.length
|
||||
);
|
||||
|
||||
// 计费
|
||||
pushSplitDataBill({
|
||||
isPay: !userApiKey && results.length > 0,
|
||||
isPay: !userApiKey && response.result.length > 0,
|
||||
userId: dataItem.userId,
|
||||
type: 'QA',
|
||||
text: systemPrompt.content + dataItem.text + rawContents.join('')
|
||||
text: systemPrompt.content + text + response.rawContent
|
||||
});
|
||||
} catch (error: any) {
|
||||
console.log('error: 生成QA错误', dataItem?._id);
|
||||
console.log('response:', error?.response);
|
||||
if (dataItem?._id) {
|
||||
await DataItem.findByIdAndUpdate(dataItem._id, {
|
||||
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行
|
||||
$inc: {
|
||||
// 剩余尝试次数-1
|
||||
times: -1
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
generateQA(true);
|
||||
generateQA(true);
|
||||
generateVector(true);
|
||||
} catch (error: any) {
|
||||
console.log(error);
|
||||
console.log('生成QA错误:', error?.response);
|
||||
|
||||
setTimeout(() => {
|
||||
generateQA(true);
|
||||
}, 10000);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
88
src/service/events/generateVector.ts
Normal file
88
src/service/events/generateVector.ts
Normal file
@@ -0,0 +1,88 @@
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import { httpsAgent } from '@/service/utils/tools';
|
||||
import { ModelData } from '../models/modelData';
|
||||
import { connectRedis } from '../redis';
|
||||
import { VecModelDataIndex } from '@/constants/redis';
|
||||
|
||||
export async function generateVector(next = false): Promise<any> {
|
||||
if (global.generatingVector && !next) return;
|
||||
global.generatingVector = true;
|
||||
|
||||
try {
|
||||
const redis = await connectRedis();
|
||||
|
||||
// 找出一个需要生成的 dataItem
|
||||
const dataItem = await ModelData.findOne({
|
||||
status: { $ne: 0 }
|
||||
});
|
||||
|
||||
if (!dataItem) {
|
||||
console.log('没有需要生成 【向量】 的数据');
|
||||
global.generatingVector = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取 openapi Key
|
||||
const openAiKey = process.env.OPENAIKEY as string;
|
||||
|
||||
// 获取 openai 请求实例
|
||||
const chatAPI = getOpenAIApi(openAiKey);
|
||||
|
||||
const dataId = String(dataItem._id);
|
||||
|
||||
// 生成词向量
|
||||
const response = await Promise.allSettled(
|
||||
dataItem.q.map((item, i) =>
|
||||
chatAPI
|
||||
.createEmbedding(
|
||||
{
|
||||
model: 'text-embedding-ada-002',
|
||||
input: item.text
|
||||
},
|
||||
{
|
||||
timeout: 120000,
|
||||
httpsAgent
|
||||
}
|
||||
)
|
||||
.then((res) => res?.data?.data?.[0]?.embedding || [])
|
||||
.then((vector) =>
|
||||
redis.sendCommand([
|
||||
'JSON.SET',
|
||||
`${VecModelDataIndex}:${dataId}:${i}`,
|
||||
'$',
|
||||
JSON.stringify({
|
||||
dataId,
|
||||
modelId: String(dataItem.modelId),
|
||||
vector
|
||||
})
|
||||
])
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
if (response.filter((item) => item.status === 'fulfilled').length === 0) {
|
||||
throw new Error(JSON.stringify(response));
|
||||
}
|
||||
// 修改该数据状态
|
||||
await ModelData.findByIdAndUpdate(dataItem._id, {
|
||||
status: 0
|
||||
});
|
||||
|
||||
console.log(`生成向量成功: ${dataItem._id}`);
|
||||
|
||||
setTimeout(() => {
|
||||
generateVector(true);
|
||||
}, 3000);
|
||||
} catch (error: any) {
|
||||
console.log(error);
|
||||
console.log('error: 生成向量错误', error?.response?.data);
|
||||
|
||||
if (error?.response?.statusText === 'Too Many Requests') {
|
||||
console.log('次数限制,1分钟后尝试');
|
||||
// 限制次数,1分钟后再试
|
||||
setTimeout(() => {
|
||||
generateVector(true);
|
||||
}, 60000);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -34,7 +34,7 @@ export const pushChatBill = async ({
|
||||
// 计算价格
|
||||
const unitPrice = modelItem?.price || 5;
|
||||
const price = unitPrice * tokens.length;
|
||||
console.log(`chat bill, price: ${formatPrice(price)}元`);
|
||||
console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}元`);
|
||||
|
||||
try {
|
||||
// 插入 Bill 记录
|
||||
|
||||
@@ -13,22 +13,23 @@ const ModelDataSchema = new Schema({
|
||||
ref: 'user',
|
||||
required: true
|
||||
},
|
||||
q: {
|
||||
text: {
|
||||
type: String,
|
||||
required: true
|
||||
},
|
||||
a: {
|
||||
type: String,
|
||||
default: ''
|
||||
q: {
|
||||
type: [
|
||||
{
|
||||
id: String, // 对应redis的key
|
||||
text: String
|
||||
}
|
||||
],
|
||||
default: []
|
||||
},
|
||||
status: {
|
||||
type: Number,
|
||||
enum: [0, 1, 2],
|
||||
enum: [0, 1], // 1 训练ing
|
||||
default: 1
|
||||
},
|
||||
createTime: {
|
||||
type: Date,
|
||||
default: () => new Date()
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
31
src/service/models/splitData.ts
Normal file
31
src/service/models/splitData.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
/* 模型的知识库 */
|
||||
import { Schema, model, models, Model as MongoModel } from 'mongoose';
|
||||
import { ModelSplitDataSchema as SplitDataType } from '@/types/mongoSchema';
|
||||
|
||||
const SplitDataSchema = new Schema({
|
||||
userId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'user',
|
||||
required: true
|
||||
},
|
||||
modelId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'model',
|
||||
required: true
|
||||
},
|
||||
rawText: {
|
||||
type: String,
|
||||
required: true
|
||||
},
|
||||
textList: {
|
||||
type: [String],
|
||||
default: []
|
||||
},
|
||||
errorText: {
|
||||
type: String,
|
||||
default: ''
|
||||
}
|
||||
});
|
||||
|
||||
export const SplitData: MongoModel<SplitDataType> =
|
||||
models['splitData'] || model('splitData', SplitDataSchema);
|
||||
@@ -1,6 +1,7 @@
|
||||
import mongoose from 'mongoose';
|
||||
import { generateQA } from './events/generateQA';
|
||||
import { generateAbstract } from './events/generateAbstract';
|
||||
import { generateVector } from './events/generateVector';
|
||||
|
||||
/**
|
||||
* 连接 MongoDB 数据库
|
||||
@@ -27,7 +28,8 @@ export async function connectToDatabase(): Promise<void> {
|
||||
}
|
||||
|
||||
generateQA();
|
||||
generateAbstract();
|
||||
// generateAbstract();
|
||||
generateVector();
|
||||
}
|
||||
|
||||
export * from './models/authCode';
|
||||
@@ -40,3 +42,4 @@ export * from './models/bill';
|
||||
export * from './models/pay';
|
||||
export * from './models/data';
|
||||
export * from './models/dataItem';
|
||||
export * from './models/splitData';
|
||||
|
||||
45
src/service/redis.ts
Normal file
45
src/service/redis.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { createClient } from 'redis';
|
||||
import { customAlphabet } from 'nanoid';
|
||||
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10);
|
||||
|
||||
export const connectRedis = async () => {
|
||||
// 断开了,重连
|
||||
if (global.redisClient && !global.redisClient.isOpen) {
|
||||
await global.redisClient.disconnect();
|
||||
} else if (global.redisClient) {
|
||||
// 没断开,不再连接
|
||||
return global.redisClient;
|
||||
}
|
||||
|
||||
try {
|
||||
global.redisClient = createClient({
|
||||
url: process.env.REDIS_URL
|
||||
});
|
||||
|
||||
global.redisClient.on('error', (err) => {
|
||||
console.log('Redis Client Error', err);
|
||||
global.redisClient = null;
|
||||
});
|
||||
global.redisClient.on('end', () => {
|
||||
global.redisClient = null;
|
||||
});
|
||||
global.redisClient.on('ready', () => {
|
||||
console.log('redis connected');
|
||||
});
|
||||
|
||||
await global.redisClient.connect();
|
||||
|
||||
// 0 - 测试库,1 - 正式
|
||||
await global.redisClient.select(0);
|
||||
|
||||
return global.redisClient;
|
||||
} catch (error) {
|
||||
console.log(error, '==');
|
||||
global.redisClient = null;
|
||||
return Promise.reject('redis 连接失败');
|
||||
}
|
||||
};
|
||||
|
||||
export const getKey = (prefix = '') => {
|
||||
return `${prefix}:${nanoid()}`;
|
||||
};
|
||||
@@ -119,3 +119,21 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
|
||||
|
||||
return systemPrompt ? [systemPrompt, ...res] : res;
|
||||
};
|
||||
|
||||
/* system 内容截断 */
|
||||
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
|
||||
let splitText = '';
|
||||
|
||||
// 从前往前截取
|
||||
for (let i = 0; i < prompts.length; i++) {
|
||||
const prompt = prompts[i];
|
||||
|
||||
splitText += `${prompt}\n`;
|
||||
const tokens = encode(splitText).length;
|
||||
if (tokens >= maxTokens) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return splitText;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user