new framwork

This commit is contained in:
archer
2023-06-09 12:57:42 +08:00
parent d9450bd7ee
commit ba9d9c3d5f
263 changed files with 12269 additions and 11599 deletions

View File

@@ -0,0 +1,229 @@
import { TrainingData } from '@/service/mongo';
import { getApiKey } from '../utils/auth';
import { OpenAiChatEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { openaiAccountError } from '../errorCode';
import { modelServiceToolMap } from '../utils/chat';
import { ChatRoleEnum } from '@/constants/chat';
import { BillTypeEnum } from '@/constants/user';
import { pushDataToKb } from '@/pages/api/openapi/kb/pushData';
import { TrainingModeEnum } from '@/constants/plugin';
import { ERROR_ENUM } from '../errorCode';
import { sendInform } from '@/pages/api/user/inform/send';
const reduceQueue = () => {
global.qaQueueLen = global.qaQueueLen > 0 ? global.qaQueueLen - 1 : 0;
};
export async function generateQA(): Promise<any> {
const maxProcess = Number(process.env.QA_MAX_PROCESS || 10);
if (global.qaQueueLen >= maxProcess) return;
global.qaQueueLen++;
let trainingId = '';
let userId = '';
try {
const match = {
mode: TrainingModeEnum.qa,
lockTime: { $lte: new Date(Date.now() - 4 * 60 * 1000) }
};
// random get task
const agree = await TrainingData.aggregate([
{
$match: match
},
{ $sample: { size: 1 } },
{
$project: {
_id: 1
}
}
]);
// no task
if (agree.length === 0) {
reduceQueue();
global.qaQueueLen <= 0 && console.log(`没有需要【QA】的数据, ${global.qaQueueLen}`);
return;
}
const data = await TrainingData.findOneAndUpdate(
{
_id: agree[0]._id,
...match
},
{
lockTime: new Date()
}
).select({
_id: 1,
userId: 1,
kbId: 1,
prompt: 1,
q: 1,
source: 1
});
// task preemption
if (!data) {
reduceQueue();
return generateQA();
}
trainingId = data._id;
userId = String(data.userId);
const kbId = String(data.kbId);
// 余额校验并获取 openapi Key
const { systemAuthKey } = await getApiKey({
model: OpenAiChatEnum.GPT35,
userId,
type: 'training',
mustPay: true
});
const startTime = Date.now();
// 请求 chatgpt 获取回答
const response = await Promise.all(
[data.q].map((text) =>
modelServiceToolMap[OpenAiChatEnum.GPT35]
.chatCompletion({
apiKey: systemAuthKey,
temperature: 0.8,
messages: [
{
obj: ChatRoleEnum.System,
value: `你是出题人
${data.prompt || '下面是"一段长文本"'}
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
A1:
Q2:
A2:
...`
},
{
obj: 'Human',
value: text
}
],
stream: false
})
.then(({ totalTokens, responseText, responseMessages }) => {
const result = formatSplitText(responseText); // 格式化后的QA对
console.log(`split result length: `, result.length);
// 计费
pushSplitDataBill({
isPay: result.length > 0,
userId: data.userId,
type: BillTypeEnum.QA,
textLen: responseMessages.map((item) => item.value).join('').length,
totalTokens
});
return {
rawContent: responseText,
result
};
})
.catch((err) => {
console.log('QA拆分错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
return Promise.reject(err);
})
)
);
const responseList = response.map((item) => item.result).flat();
// 创建 向量生成 队列
await pushDataToKb({
kbId,
data: responseList.map((item) => ({
...item,
source: data.source
})),
userId,
mode: TrainingModeEnum.index
});
// delete data from training
await TrainingData.findByIdAndDelete(data._id);
console.log('生成QA成功time:', `${(Date.now() - startTime) / 1000}s`);
reduceQueue();
generateQA();
} catch (err: any) {
reduceQueue();
// log
if (err?.response) {
console.log('openai error: 生成QA错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
} else {
console.log('生成QA错误:', err);
}
// message error or openai account error
if (
err?.message === 'invalid message format' ||
err.response?.statusText === 'Unauthorized' ||
openaiAccountError[err?.response?.data?.error?.code || err?.response?.data?.error?.type]
) {
await TrainingData.findByIdAndRemove(trainingId);
}
// 账号余额不足,删除任务
if (userId && err === ERROR_ENUM.insufficientQuota) {
sendInform({
type: 'system',
title: 'QA 任务中止',
content: '由于账号余额不足QA 任务中止,重新充值后将会继续。',
userId
});
console.log('余额不足,暂停向量生成任务');
await TrainingData.updateMany(
{
userId
},
{
lockTime: new Date('2999/5/5')
}
);
return generateQA();
}
// unlock
await TrainingData.findByIdAndUpdate(trainingId, {
lockTime: new Date('2000/1/1')
});
setTimeout(() => {
generateQA();
}, 1000);
}
}
/**
* 检查文本是否按格式返回
*/
function formatSplitText(text: string) {
const regex = /Q\d+:(\s*)(.*)(\s*)A\d+:(\s*)([\s\S]*?)(?=Q|$)/g; // 匹配Q和A的正则表达式
const matches = text.matchAll(regex); // 获取所有匹配到的结果
const result = []; // 存储最终的结果
for (const match of matches) {
const q = match[2];
const a = match[5];
if (q && a) {
// 如果Q和A都存在就将其添加到结果中
result.push({
q,
a: a.trim().replace(/\n\s*/g, '\n')
});
}
}
return result;
}

View File

@@ -0,0 +1,158 @@
import { openaiAccountError } from '../errorCode';
import { insertKbItem } from '@/service/pg';
import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
import { TrainingData } from '../models/trainingData';
import { ERROR_ENUM } from '../errorCode';
import { TrainingModeEnum } from '@/constants/plugin';
import { sendInform } from '@/pages/api/user/inform/send';
const reduceQueue = () => {
global.vectorQueueLen = global.vectorQueueLen > 0 ? global.vectorQueueLen - 1 : 0;
};
/* 索引生成队列。每导入一次,就是一个单独的线程 */
export async function generateVector(): Promise<any> {
const maxProcess = Number(process.env.VECTOR_MAX_PROCESS || 10);
if (global.vectorQueueLen >= maxProcess) return;
global.vectorQueueLen++;
let trainingId = '';
let userId = '';
try {
const match = {
mode: TrainingModeEnum.index,
lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) }
};
// random get task
const agree = await TrainingData.aggregate([
{
$match: match
},
{ $sample: { size: 1 } },
{
$project: {
_id: 1
}
}
]);
// no task
if (agree.length === 0) {
reduceQueue();
global.vectorQueueLen <= 0 && console.log(`没有需要【索引】的数据, ${global.vectorQueueLen}`);
return;
}
const data = await TrainingData.findOneAndUpdate(
{
_id: agree[0]._id,
...match
},
{
lockTime: new Date()
}
).select({
_id: 1,
userId: 1,
kbId: 1,
q: 1,
a: 1,
source: 1
});
// task preemption
if (!data) {
reduceQueue();
return generateVector();
}
trainingId = data._id;
userId = String(data.userId);
const kbId = String(data.kbId);
const dataItems = [
{
q: data.q,
a: data.a
}
];
// 生成词向量
const vectors = await openaiEmbedding({
input: dataItems.map((item) => item.q),
userId,
type: 'training',
mustPay: true
});
// 生成结果插入到 pg
await insertKbItem({
userId,
kbId,
data: vectors.map((vector, i) => ({
q: dataItems[i].q,
a: dataItems[i].a,
source: data.source,
vector
}))
});
// delete data from training
await TrainingData.findByIdAndDelete(data._id);
console.log(`生成向量成功: ${data._id}`);
reduceQueue();
generateVector();
} catch (err: any) {
reduceQueue();
// log
if (err?.response) {
console.log('openai error: 生成向量错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
} else {
console.log('生成向量错误:', err);
}
// message error or openai account error
if (
err?.message === 'invalid message format' ||
err.response?.statusText === 'Unauthorized' ||
openaiAccountError[err?.response?.data?.error?.code || err?.response?.data?.error?.type]
) {
console.log('删除一个任务');
await TrainingData.findByIdAndRemove(trainingId);
}
// 账号余额不足,删除任务
if (userId && err === ERROR_ENUM.insufficientQuota) {
sendInform({
type: 'system',
title: '索引生成任务中止',
content: '由于账号余额不足,索引生成任务中止,重新充值后将会继续。',
userId
});
console.log('余额不足,暂停向量生成任务');
await TrainingData.updateMany(
{
userId
},
{
lockTime: new Date('2999/5/5')
}
);
return generateVector();
}
// unlock
err.response?.statusText !== 'Too Many Requests' &&
(await TrainingData.findByIdAndUpdate(trainingId, {
lockTime: new Date('2000/1/1')
}));
setTimeout(() => {
generateVector();
}, 1000);
}
}

View File

@@ -0,0 +1,180 @@
import { connectToDatabase, Bill, User, ShareChat } from '../mongo';
import {
ChatModelMap,
OpenAiChatEnum,
ChatModelType,
embeddingModel,
embeddingPrice
} from '@/constants/model';
import { BillTypeEnum } from '@/constants/user';
export const pushChatBill = async ({
isPay,
chatModel,
userId,
chatId,
textLen,
tokens,
type
}: {
isPay: boolean;
chatModel: ChatModelType;
userId: string;
chatId?: '' | string;
textLen: number;
tokens: number;
type: BillTypeEnum.chat | BillTypeEnum.openapiChat;
}) => {
console.log(`chat generate success. text len: ${textLen}. token len: ${tokens}. pay:${isPay}`);
if (!isPay) return;
let billId = '';
try {
await connectToDatabase();
// 计算价格
const unitPrice = ChatModelMap[chatModel]?.price || 3;
const price = unitPrice * tokens;
try {
// 插入 Bill 记录
const res = await Bill.create({
userId,
type,
modelName: chatModel,
chatId: chatId ? chatId : undefined,
textLen,
tokenLen: tokens,
price
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
} catch (error) {
console.log(error);
}
};
export const updateShareChatBill = async ({
shareId,
tokens
}: {
shareId: string;
tokens: number;
}) => {
try {
await ShareChat.findByIdAndUpdate(shareId, {
$inc: { tokens },
lastTime: new Date()
});
} catch (error) {
console.log('update shareChat error', error);
}
};
export const pushSplitDataBill = async ({
isPay,
userId,
totalTokens,
textLen,
type
}: {
isPay: boolean;
userId: string;
totalTokens: number;
textLen: number;
type: BillTypeEnum.QA;
}) => {
console.log(
`splitData generate success. text len: ${textLen}. token len: ${totalTokens}. pay:${isPay}`
);
if (!isPay) return;
let billId;
try {
await connectToDatabase();
// 获取模型单价格, 都是用 gpt35 拆分
const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35].price || 3;
// 计算价格
const price = unitPrice * totalTokens;
// 插入 Bill 记录
const res = await Bill.create({
userId,
type,
modelName: OpenAiChatEnum.GPT35,
textLen,
tokenLen: totalTokens,
price
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
};
export const pushGenerateVectorBill = async ({
isPay,
userId,
text,
tokenLen
}: {
isPay: boolean;
userId: string;
text: string;
tokenLen: number;
}) => {
// console.log(
// `vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
// );
if (!isPay) return;
let billId;
try {
await connectToDatabase();
try {
// 计算价格. 至少为1
let price = embeddingPrice * tokenLen;
price = price > 1 ? price : 1;
// 插入 Bill 记录
const res = await Bill.create({
userId,
type: BillTypeEnum.vector,
modelName: embeddingModel,
textLen: text.length,
tokenLen,
price
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
} catch (error) {
console.log(error);
}
};