training queue
This commit is contained in:
@@ -5,7 +5,7 @@ import { PgClient } from '@/service/pg';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
import type { ChatItemSimpleType } from '@/types/chat';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { ModelVectorSearchModeEnum } from '@/constants/model';
|
||||
import { appVectorSearchModeEnum } from '@/constants/model';
|
||||
import { authModel } from '@/service/utils/auth';
|
||||
import { ChatModelMap } from '@/constants/model';
|
||||
import { ChatRoleEnum } from '@/constants/chat';
|
||||
@@ -92,7 +92,8 @@ export async function appKbSearch({
|
||||
// get vector
|
||||
const promptVectors = await openaiEmbedding({
|
||||
userId,
|
||||
input
|
||||
input,
|
||||
type: 'chat'
|
||||
});
|
||||
|
||||
// search kb
|
||||
@@ -138,7 +139,7 @@ export async function appKbSearch({
|
||||
obj: ChatRoleEnum.System,
|
||||
value: model.chat.systemPrompt
|
||||
}
|
||||
: model.chat.searchMode === ModelVectorSearchModeEnum.noContext
|
||||
: model.chat.searchMode === appVectorSearchModeEnum.noContext
|
||||
? {
|
||||
obj: ChatRoleEnum.System,
|
||||
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
|
||||
@@ -176,7 +177,7 @@ export async function appKbSearch({
|
||||
const systemPrompt = sliceResult.flat().join('\n').trim();
|
||||
|
||||
/* 高相似度+不回复 */
|
||||
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
|
||||
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
|
||||
return {
|
||||
code: 201,
|
||||
rawSearch: [],
|
||||
@@ -190,7 +191,7 @@ export async function appKbSearch({
|
||||
};
|
||||
}
|
||||
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
|
||||
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
|
||||
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.noContext) {
|
||||
return {
|
||||
code: 200,
|
||||
rawSearch: [],
|
||||
|
||||
@@ -1,84 +1,36 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import type { KbDataItemType } from '@/types/plugin';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { connectToDatabase, TrainingData } from '@/service/mongo';
|
||||
import { authUser } from '@/service/utils/auth';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
import { PgClient, insertKbItem } from '@/service/pg';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { authKb } from '@/service/utils/auth';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
|
||||
interface Props {
|
||||
kbId: string;
|
||||
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
|
||||
}
|
||||
|
||||
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const {
|
||||
kbId,
|
||||
data,
|
||||
formatLineBreak = true
|
||||
} = req.body as {
|
||||
kbId: string;
|
||||
formatLineBreak?: boolean;
|
||||
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
|
||||
};
|
||||
const { kbId, data } = req.body as Props;
|
||||
|
||||
if (!kbId || !Array.isArray(data)) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 凭证校验
|
||||
const { userId } = await authUser({ req });
|
||||
|
||||
await authKb({
|
||||
userId,
|
||||
kbId
|
||||
});
|
||||
|
||||
// 过滤重复的内容
|
||||
const searchRes = await Promise.allSettled(
|
||||
data.map(async ({ q, a = '' }) => {
|
||||
if (!q) {
|
||||
return Promise.reject('q为空');
|
||||
}
|
||||
|
||||
if (formatLineBreak) {
|
||||
q = q.replace(/\\n/g, '\n');
|
||||
a = a.replace(/\\n/g, '\n');
|
||||
}
|
||||
|
||||
// Exactly the same data, not push
|
||||
try {
|
||||
const count = await PgClient.count('modelData', {
|
||||
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
|
||||
});
|
||||
if (count > 0) {
|
||||
return Promise.reject('已经存在');
|
||||
}
|
||||
} catch (error) {
|
||||
error;
|
||||
}
|
||||
return Promise.resolve({
|
||||
q,
|
||||
a
|
||||
});
|
||||
})
|
||||
);
|
||||
const filterData = searchRes
|
||||
.filter((item) => item.status === 'fulfilled')
|
||||
.map<{ q: string; a: string }>((item: any) => item.value);
|
||||
|
||||
// 插入记录
|
||||
const insertRes = await insertKbItem({
|
||||
userId,
|
||||
kbId,
|
||||
data: filterData
|
||||
});
|
||||
|
||||
generateVector();
|
||||
|
||||
jsonRes(res, {
|
||||
message: `共插入 ${insertRes.rowCount} 条数据`,
|
||||
data: insertRes.rowCount
|
||||
data: await pushDataToKb({
|
||||
kbId,
|
||||
data,
|
||||
userId
|
||||
})
|
||||
});
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
@@ -88,6 +40,32 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
||||
}
|
||||
});
|
||||
|
||||
export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) {
|
||||
await authKb({
|
||||
userId,
|
||||
kbId
|
||||
});
|
||||
|
||||
if (data.length === 0) {
|
||||
return {
|
||||
trainingId: ''
|
||||
};
|
||||
}
|
||||
|
||||
// 插入记录
|
||||
const { _id } = await TrainingData.create({
|
||||
userId,
|
||||
kbId,
|
||||
vectorList: data
|
||||
});
|
||||
|
||||
generateVector(_id);
|
||||
|
||||
return {
|
||||
trainingId: _id
|
||||
};
|
||||
}
|
||||
|
||||
export const config = {
|
||||
api: {
|
||||
bodyParser: {
|
||||
|
||||
@@ -5,10 +5,11 @@ import { ModelDataStatusEnum } from '@/constants/model';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
import { openaiEmbedding } from '../plugin/openaiEmbedding';
|
||||
|
||||
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string };
|
||||
const { dataId, a = '', q = '' } = req.body as { dataId: string; a?: string; q?: string };
|
||||
|
||||
if (!dataId) {
|
||||
throw new Error('缺少参数');
|
||||
@@ -17,22 +18,24 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
||||
// 凭证校验
|
||||
const { userId } = await authUser({ req });
|
||||
|
||||
// get vector
|
||||
const vector = await (async () => {
|
||||
if (q) {
|
||||
return openaiEmbedding({
|
||||
userId,
|
||||
input: [q],
|
||||
type: 'chat'
|
||||
});
|
||||
}
|
||||
return [];
|
||||
})();
|
||||
|
||||
// 更新 pg 内容.仅修改a,不需要更新向量。
|
||||
await PgClient.update('modelData', {
|
||||
where: [['id', dataId], 'AND', ['user_id', userId]],
|
||||
values: [
|
||||
{ key: 'a', value: a },
|
||||
...(q
|
||||
? [
|
||||
{ key: 'q', value: q },
|
||||
{ key: 'status', value: ModelDataStatusEnum.waiting }
|
||||
]
|
||||
: [])
|
||||
]
|
||||
values: [{ key: 'a', value: a }, ...(q ? [{ key: 'q', value: `${vector[0]}` }] : [])]
|
||||
});
|
||||
|
||||
q && generateVector();
|
||||
|
||||
jsonRes(res);
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
|
||||
Reference in New Issue
Block a user