training queue

This commit is contained in:
archer
2023-05-26 23:08:25 +08:00
parent 69f32a0861
commit dc1c1d1355
32 changed files with 528 additions and 493 deletions

View File

@@ -1,84 +1,36 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import type { KbDataItemType } from '@/types/plugin';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { generateVector } from '@/service/events/generateVector';
import { PgClient, insertKbItem } from '@/service/pg';
import { PgClient } from '@/service/pg';
import { authKb } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools';
/**
 * Payload accepted by this route and by `pushDataToKb`.
 */
interface Props {
  /** Knowledge-base id; forwarded to `authKb` and stored on the created training record. */
  kbId: string;
  /** Question/answer pairs to push; the field types are borrowed from `KbDataItemType`. */
  data: Array<{ q: KbDataItemType['q']; a: KbDataItemType['a'] }>;
}
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const {
kbId,
data,
formatLineBreak = true
} = req.body as {
kbId: string;
formatLineBreak?: boolean;
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
};
const { kbId, data } = req.body as Props;
if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数');
}
await connectToDatabase();
// 凭证校验
const { userId } = await authUser({ req });
await authKb({
userId,
kbId
});
// 过滤重复的内容
const searchRes = await Promise.allSettled(
data.map(async ({ q, a = '' }) => {
if (!q) {
return Promise.reject('q为空');
}
if (formatLineBreak) {
q = q.replace(/\\n/g, '\n');
a = a.replace(/\\n/g, '\n');
}
// Exactly the same data, not push
try {
const count = await PgClient.count('modelData', {
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
});
if (count > 0) {
return Promise.reject('已经存在');
}
} catch (error) {
error;
}
return Promise.resolve({
q,
a
});
})
);
const filterData = searchRes
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
// 插入记录
const insertRes = await insertKbItem({
userId,
kbId,
data: filterData
});
generateVector();
jsonRes(res, {
message: `共插入 ${insertRes.rowCount} 条数据`,
data: insertRes.rowCount
data: await pushDataToKb({
kbId,
data,
userId
})
});
} catch (err) {
jsonRes(res, {
@@ -88,6 +40,32 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
}
});
/**
 * Stores a batch of question/answer items as one training record and hands
 * that record's id to `generateVector`.
 *
 * Steps, in order:
 * 1. `authKb` is awaited with the given user and knowledge-base ids
 *    (presumably it rejects when the user may not access the knowledge base —
 *    confirm against its implementation).
 * 2. An empty `data` array returns early without creating anything.
 * 3. Otherwise a single `TrainingData` document is created with the whole
 *    batch in `vectorList`, and `generateVector` is called with its `_id`.
 *
 * @param userId - Id of the user performing the push.
 * @param kbId - Id of the target knowledge base.
 * @param data - Question/answer items to store.
 * @returns An object whose `trainingId` is the `_id` of the created
 *   `TrainingData` document, or an empty string when `data` is empty.
 */
export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) {
  await authKb({ userId, kbId });

  if (data.length > 0) {
    // Insert the record: the whole batch goes into one document.
    const record = await TrainingData.create({ userId, kbId, vectorList: data });
    const trainingId = record._id;

    // Called without awaiting its result, so the caller is not held up by it.
    generateVector(trainingId);

    return { trainingId };
  }

  // Nothing to store.
  return { trainingId: '' };
}
export const config = {
api: {
bodyParser: {