feat: insert data de-weight;perf: input queue

This commit is contained in:
archer
2023-05-28 20:13:19 +08:00
parent 7e99f905bc
commit 516618b0cd
12 changed files with 187 additions and 105 deletions

View File

@@ -16,6 +16,10 @@ export type Props = {
prompt?: string;
};
export type Response = {
insertLen: number;
};
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode, prompt } = req.body as Props;
@@ -28,7 +32,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
// 凭证校验
const { userId } = await authUser({ req });
jsonRes(res, {
jsonRes<Response>(res, {
data: await pushDataToKb({
kbId,
data,
@@ -51,16 +55,12 @@ export async function pushDataToKb({
data,
mode,
prompt
}: { userId: string } & Props) {
}: { userId: string } & Props): Promise<Response> {
await authKb({
userId,
kbId
});
if (data.length === 0) {
return {};
}
// 过滤重复的 qa 内容
const set = new Set();
const filterData: {
@@ -75,41 +75,54 @@ export async function pushDataToKb({
set.add(text);
}
});
// 数据库去重
// const searchRes = await Promise.allSettled(
// data.map(async ({ q, a = '' }) => {
// if (!q) {
// return Promise.reject('q为空');
// }
const insertData = (
await Promise.allSettled(
filterData.map(async ({ q, a = '' }) => {
if (mode !== TrainingModeEnum.index) {
return Promise.resolve({
q,
a
});
}
// q = q.replace(/\\n/g, '\n');
// a = a.replace(/\\n/g, '\n');
if (!q) {
return Promise.reject('q为空');
}
// // Exactly the same data, not push
// try {
// const count = await PgClient.count('modelData', {
// where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
// });
q = q.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
a = a.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
// if (count > 0) {
// return Promise.reject('已经存在');
// }
// } catch (error) {
// error;
// }
// return Promise.resolve({
// q,
// a
// });
// })
// );
// const filterData = searchRes
// .filter((item) => item.status === 'fulfilled')
// .map<{ q: string; a: string }>((item: any) => item.value);
// Exactly the same data, not push
try {
const { rows } = await PgClient.query(`
SELECT COUNT(*) > 0 AS exists
FROM modelData
WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
`);
const exists = rows[0]?.exists || false;
if (exists) {
return Promise.reject('已经存在');
}
} catch (error) {
console.log(error);
error;
}
return Promise.resolve({
q,
a
});
})
)
)
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
// 插入记录
await TrainingData.insertMany(
data.map((item) => ({
insertData.map((item) => ({
q: item.q,
a: item.a,
userId,
@@ -119,9 +132,11 @@ export async function pushDataToKb({
}))
);
startQueue();
insertData.length > 0 && startQueue();
return {};
return {
insertLen: insertData.length
};
}
export const config = {

View File

@@ -32,10 +32,10 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]],
values: [
{ key: 'a', value: a },
{ key: 'a', value: a.replace(/'/g, '"') },
...(q
? [
{ key: 'q', value: q },
{ key: 'q', value: q.replace(/'/g, '"') },
{ key: 'vector', value: `[${vector[0]}]` }
]
: [])