perf: open push data api

This commit is contained in:
archer
2023-08-29 10:40:48 +08:00
parent 19d7edb585
commit e0de04dddb
6 changed files with 51 additions and 77 deletions

View File

@@ -8,6 +8,7 @@ import { PgTrainingTableName, TrainingModeEnum } from '@/constants/plugin';
import { startQueue } from '@/service/utils/tools';
import { PgClient } from '@/service/pg';
import { modelToolMap } from '@/utils/plugin';
import { getVectorModel } from '@/service/utils/data';
export type DateItemType = { a: string; q: string; source?: string };
@@ -22,17 +23,25 @@ export type Response = {
insertLen: number;
};
const modeMaxToken = {
[TrainingModeEnum.index]: 6000,
[TrainingModeEnum.qa]: 12000
const modeMap = {
[TrainingModeEnum.index]: true,
[TrainingModeEnum.qa]: true
};
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode, prompt } = req.body as Props;
const { kbId, data, mode = TrainingModeEnum.index, prompt } = req.body as Props;
if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数');
throw new Error('KbId or data is empty');
}
if (modeMap[mode] === undefined) {
throw new Error('Mode is error');
}
if (data.length > 500) {
throw new Error('Data is too long, max 500');
}
await connectToDatabase();
@@ -64,25 +73,42 @@ export async function pushDataToKb({
mode,
prompt
}: { userId: string } & Props): Promise<Response> {
await authKb({
userId,
kbId
});
const [kb, vectorModel] = await Promise.all([
authKb({
userId,
kbId
}),
(async () => {
if (mode === TrainingModeEnum.index) {
const vectorModel = (await KB.findById(kbId, 'vectorModel'))?.vectorModel;
return getVectorModel(vectorModel || global.vectorModels[0].model);
}
return global.vectorModels[0];
})()
]);
const modeMaxToken = {
[TrainingModeEnum.index]: vectorModel.maxToken,
[TrainingModeEnum.qa]: global.qaModel.maxToken * 0.8
};
// 过滤重复的 qa 内容
const set = new Set();
const filterData: DateItemType[] = [];
data.forEach((item) => {
if (!item.q) return;
const text = item.q + item.a;
// count token
// count q token
const token = modelToolMap.countTokens({
model: 'gpt-3.5-turbo',
messages: [{ obj: 'System', value: item.q }]
});
if (token > modeMaxToken[TrainingModeEnum.qa]) {
if (token > modeMaxToken[mode]) {
return;
}
@@ -138,15 +164,8 @@ export async function pushDataToKb({
.filter((item) => item.status === 'fulfilled')
.map<DateItemType>((item: any) => item.value);
const vectorModel = await (async () => {
if (mode === TrainingModeEnum.index) {
return (await KB.findById(kbId, 'vectorModel'))?.vectorModel || global.vectorModels[0].model;
}
return global.vectorModels[0].model;
})();
// 插入记录
await TrainingData.insertMany(
const insertRes = await TrainingData.insertMany(
insertData.map((item) => ({
q: item.q,
a: item.a,
@@ -155,21 +174,21 @@ export async function pushDataToKb({
kbId,
mode,
prompt,
vectorModel
vectorModel: vectorModel.model
}))
);
insertData.length > 0 && startQueue();
insertRes.length > 0 && startQueue();
return {
insertLen: insertData.length
insertLen: insertRes.length
};
}
export const config = {
api: {
bodyParser: {
sizeLimit: '20mb'
sizeLimit: '12mb'
}
}
};