training queue

This commit is contained in:
archer
2023-05-26 23:08:25 +08:00
parent 69f32a0861
commit dc1c1d1355
32 changed files with 528 additions and 493 deletions

View File

@@ -5,7 +5,7 @@ import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
import type { ChatItemSimpleType } from '@/types/chat';
import type { ModelSchema } from '@/types/mongoSchema';
import { ModelVectorSearchModeEnum } from '@/constants/model';
import { appVectorSearchModeEnum } from '@/constants/model';
import { authModel } from '@/service/utils/auth';
import { ChatModelMap } from '@/constants/model';
import { ChatRoleEnum } from '@/constants/chat';
@@ -92,7 +92,8 @@ export async function appKbSearch({
// get vector
const promptVectors = await openaiEmbedding({
userId,
input
input,
type: 'chat'
});
// search kb
@@ -138,7 +139,7 @@ export async function appKbSearch({
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
: model.chat.searchMode === ModelVectorSearchModeEnum.noContext
: model.chat.searchMode === appVectorSearchModeEnum.noContext
? {
obj: ChatRoleEnum.System,
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
@@ -176,7 +177,7 @@ export async function appKbSearch({
const systemPrompt = sliceResult.flat().join('\n').trim();
/* 高相似度+不回复 */
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
return {
code: 201,
rawSearch: [],
@@ -190,7 +191,7 @@ export async function appKbSearch({
};
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.noContext) {
return {
code: 200,
rawSearch: [],

View File

@@ -1,84 +1,36 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import type { KbDataItemType } from '@/types/plugin';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { generateVector } from '@/service/events/generateVector';
import { PgClient, insertKbItem } from '@/service/pg';
import { PgClient } from '@/service/pg';
import { authKb } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools';
interface Props {
kbId: string;
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
}
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const {
kbId,
data,
formatLineBreak = true
} = req.body as {
kbId: string;
formatLineBreak?: boolean;
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
};
const { kbId, data } = req.body as Props;
if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数');
}
await connectToDatabase();
// 凭证校验
const { userId } = await authUser({ req });
await authKb({
userId,
kbId
});
// 过滤重复的内容
const searchRes = await Promise.allSettled(
data.map(async ({ q, a = '' }) => {
if (!q) {
return Promise.reject('q为空');
}
if (formatLineBreak) {
q = q.replace(/\\n/g, '\n');
a = a.replace(/\\n/g, '\n');
}
// Exactly the same data, not push
try {
const count = await PgClient.count('modelData', {
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
});
if (count > 0) {
return Promise.reject('已经存在');
}
} catch (error) {
error;
}
return Promise.resolve({
q,
a
});
})
);
const filterData = searchRes
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
// 插入记录
const insertRes = await insertKbItem({
userId,
kbId,
data: filterData
});
generateVector();
jsonRes(res, {
message: `共插入 ${insertRes.rowCount} 条数据`,
data: insertRes.rowCount
data: await pushDataToKb({
kbId,
data,
userId
})
});
} catch (err) {
jsonRes(res, {
@@ -88,6 +40,32 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
}
});
/**
 * Queue a batch of q/a items for vector generation on a knowledge base.
 *
 * `authKb` is awaited first, before the empty-batch check, so it runs even
 * when `data` is empty (presumably it rejects when `userId` may not use
 * `kbId` — confirm against its implementation).
 *
 * @param userId - id of the user pushing the data
 * @param kbId - id of the target knowledge base
 * @param data - q/a items stored as the `vectorList` of one training record
 * @returns `{ trainingId }`: the `_id` of the created training record, or an
 *          empty string when `data` is empty and no record was created
 */
export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) {
  await authKb({ userId, kbId });

  // Empty batch: nothing is written and no job is started.
  if (data.length === 0) {
    return { trainingId: '' };
  }

  // Insert the record: the whole batch is stored as a single training document.
  const record = await TrainingData.create({ userId, kbId, vectorList: data });
  const trainingId = record._id;

  // Not awaited: the vector job is started and the id is returned right away.
  generateVector(trainingId);

  return { trainingId };
}
export const config = {
api: {
bodyParser: {

View File

@@ -5,10 +5,11 @@ import { ModelDataStatusEnum } from '@/constants/model';
import { generateVector } from '@/service/events/generateVector';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
import { openaiEmbedding } from '../plugin/openaiEmbedding';
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string };
const { dataId, a = '', q = '' } = req.body as { dataId: string; a?: string; q?: string };
if (!dataId) {
throw new Error('缺少参数');
@@ -17,22 +18,24 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
// 凭证校验
const { userId } = await authUser({ req });
// get vector
const vector = await (async () => {
if (q) {
return openaiEmbedding({
userId,
input: [q],
type: 'chat'
});
}
return [];
})();
// 更新 pg 内容.仅修改a不需要更新向量。
await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]],
values: [
{ key: 'a', value: a },
...(q
? [
{ key: 'q', value: q },
{ key: 'status', value: ModelDataStatusEnum.waiting }
]
: [])
]
values: [{ key: 'a', value: a }, ...(q ? [{ key: 'q', value: `${vector[0]}` }] : [])]
});
q && generateVector();
jsonRes(res);
} catch (err) {
jsonRes(res, {

View File

@@ -1,30 +1,31 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
import { getApiKey } from '@/service/utils/auth';
import { getOpenAIApi } from '@/service/utils/chat/openai';
import { embeddingModel } from '@/constants/model';
import { axiosConfig } from '@/service/utils/tools';
import { pushGenerateVectorBill } from '@/service/events/pushBill';
import { ApiKeyType } from '@/service/utils/auth';
type Props = {
input: string[];
type?: ApiKeyType;
};
type Response = number[][];
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { userId } = await authUser({ req });
let { input } = req.query as Props;
let { input, type } = req.query as Props;
if (!Array.isArray(input)) {
throw new Error('缺少参数');
}
jsonRes<Response>(res, {
data: await openaiEmbedding({ userId, input, mustPay: true })
data: await openaiEmbedding({ userId, input, mustPay: true, type })
});
} catch (err) {
console.log(err);
@@ -38,12 +39,14 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
export async function openaiEmbedding({
userId,
input,
mustPay = false
mustPay = false,
type = 'chat'
}: { userId: string; mustPay?: boolean } & Props) {
const { userOpenAiKey, systemAuthKey } = await getApiKey({
model: 'gpt-3.5-turbo',
userId,
mustPay
mustPay,
type
});
// 获取 chatAPI

View File

@@ -1,19 +0,0 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { generateQA } from '@/service/events/generateQA';
import { generateVector } from '@/service/events/generateVector';
/**
 * API route that starts the QA and vector generation jobs.
 *
 * Both jobs are invoked without being awaited, then the default response is
 * sent via `jsonRes`. If anything in the try body throws synchronously, a
 * response with code 500 carrying the caught error is sent instead.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    // Start both jobs; their return values are ignored.
    generateQA();
    generateVector();

    jsonRes(res);
  } catch (error) {
    jsonRes(res, { code: 500, error });
  }
}

View File

@@ -17,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { input } = req.body as TextPluginRequestParams;
const response = await axios({
...axiosConfig(getSystemOpenAiKey()),
...axiosConfig(getSystemOpenAiKey('chat')),
method: 'POST',
url: `/moderations`,
data: {

View File

@@ -1,12 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData } from '@/service/mongo';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authKb, authUser } from '@/service/utils/auth';
import { generateVector } from '@/service/events/generateVector';
import { generateQA } from '@/service/events/generateQA';
import { insertKbItem } from '@/service/pg';
import { SplitTextTypEnum } from '@/constants/plugin';
import { TrainingTypeEnum } from '@/constants/plugin';
import { withNextCors } from '@/service/utils/tools';
import { pushDataToKb } from '../kb/pushData';
/* split text */
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -15,7 +14,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
kbId: string;
chunks: string[];
prompt: string;
mode: `${SplitTextTypEnum}`;
mode: `${TrainingTypeEnum}`;
};
if (!chunks || !kbId || !prompt) {
throw new Error('参数错误');
@@ -30,29 +29,26 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
userId
});
if (mode === SplitTextTypEnum.qa) {
if (mode === TrainingTypeEnum.qa) {
// 批量QA拆分插入数据
await SplitData.create({
const { _id } = await TrainingData.create({
userId,
kbId,
textList: chunks,
qaList: chunks,
prompt
});
generateQA();
} else if (mode === SplitTextTypEnum.subsection) {
// 待优化,直接调用另一个接口
// 插入记录
await insertKbItem({
userId,
generateQA(_id);
} else if (mode === TrainingTypeEnum.subsection) {
// 分段导入,直接插入向量队列
const response = await pushDataToKb({
kbId,
data: chunks.map((item) => ({
q: item,
a: ''
}))
data: chunks.map((item) => ({ q: item, a: '' })),
userId
});
generateVector();
return jsonRes(res, {
data: response
});
}
jsonRes(res);

View File

@@ -1,14 +1,15 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { ModelDataStatusEnum } from '@/constants/model';
import { PgClient } from '@/service/pg';
import { Types } from 'mongoose';
import { generateQA } from '@/service/events/generateQA';
import { generateVector } from '@/service/events/generateVector';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { kbId } = req.query as { kbId: string };
const { kbId, init = false } = req.body as { kbId: string; init: boolean };
if (!kbId) {
throw new Error('参数错误');
}
@@ -17,29 +18,43 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { userId } = await authUser({ req, authToken: true });
// split queue data
const data = await SplitData.find({
userId,
kbId,
textList: { $exists: true, $not: { $size: 0 } }
});
// embedding queue data
const embeddingData = await PgClient.count('modelData', {
where: [
['user_id', userId],
'AND',
['kb_id', kbId],
'AND',
['status', ModelDataStatusEnum.waiting]
]
});
const result = await TrainingData.aggregate([
{ $match: { userId: new Types.ObjectId(userId), kbId: new Types.ObjectId(kbId) } },
{
$project: {
qaListLength: { $size: { $ifNull: ['$qaList', []] } },
vectorListLength: { $size: { $ifNull: ['$vectorList', []] } }
}
},
{
$group: {
_id: null,
totalQaListLength: { $sum: '$qaListLength' },
totalVectorListLength: { $sum: '$vectorListLength' }
}
}
]);
jsonRes(res, {
data: {
splitDataQueue: data.map((item) => item.textList).flat().length,
embeddingQueue: embeddingData
qaListLen: result[0]?.totalQaListLength || 0,
vectorListLen: result[0]?.totalVectorListLength || 0
}
});
if (init) {
const list = await TrainingData.find(
{
userId,
kbId
},
'_id'
);
list.forEach((item) => {
generateQA(item._id);
generateVector(item._id);
});
}
} catch (err) {
jsonRes(res, {
code: 500,

View File

@@ -91,9 +91,9 @@ const DataCard = ({ kbId }: { kbId: string }) => {
onClose: onCloseSelectCsvModal
} = useDisclosure();
const { data: { splitDataQueue = 0, embeddingQueue = 0 } = {}, refetch } = useQuery(
const { data: { qaListLen = 0, vectorListLen = 0 } = {}, refetch } = useQuery(
['getModelSplitDataList'],
() => getTrainingData(kbId),
() => getTrainingData({ kbId, init: false }),
{
onError(err) {
console.log(err);
@@ -113,7 +113,7 @@ const DataCard = ({ kbId }: { kbId: string }) => {
// interval get data
useQuery(['refetchData'], () => refetchData(pageNum), {
refetchInterval: 5000,
enabled: splitDataQueue > 0 || embeddingQueue > 0
enabled: qaListLen > 0 || vectorListLen > 0
});
// get all data and export csv
@@ -161,7 +161,10 @@ const DataCard = ({ kbId }: { kbId: string }) => {
variant={'outline'}
mr={[2, 4]}
size={'sm'}
onClick={() => refetchData(pageNum)}
onClick={() => {
refetchData(pageNum);
getTrainingData({ kbId, init: true });
}}
/>
<Button
variant={'outline'}
@@ -194,10 +197,10 @@ const DataCard = ({ kbId }: { kbId: string }) => {
</Menu>
</Flex>
<Flex mt={4}>
{(splitDataQueue > 0 || embeddingQueue > 0) && (
{(qaListLen > 0 || vectorListLen > 0) && (
<Box fontSize={'xs'}>
{splitDataQueue > 0 ? `${splitDataQueue}条数据正在拆分,` : ''}
{embeddingQueue > 0 ? `${embeddingQueue}条数据正在生成索引,` : ''}
{qaListLen > 0 ? `${qaListLen}条数据正在拆分,` : ''}
{vectorListLen > 0 ? `${vectorListLen}条数据正在生成索引,` : ''}
...
</Box>
)}

View File

@@ -20,7 +20,8 @@ import { useMutation } from '@tanstack/react-query';
import { postSplitData } from '@/api/plugins/kb';
import Radio from '@/components/Radio';
import { splitText_token } from '@/utils/file';
import { SplitTextTypEnum } from '@/constants/plugin';
import { TrainingTypeEnum } from '@/constants/plugin';
import { getErrText } from '@/utils/tools';
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
@@ -52,7 +53,7 @@ const SelectFileModal = ({
const { toast } = useToast();
const [prompt, setPrompt] = useState('');
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [mode, setMode] = useState<`${SplitTextTypEnum}`>(SplitTextTypEnum.subsection);
const [mode, setMode] = useState<`${TrainingTypeEnum}`>(TrainingTypeEnum.subsection);
const [fileTextArr, setFileTextArr] = useState<string[]>(['']);
const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({
tokens: 0,
@@ -113,8 +114,9 @@ const SelectFileModal = ({
prompt: `下面是"${prompt || '一段长文本'}"`,
mode
});
toast({
title: '导入数据成功,需要一段拆解和训练',
title: '导入数据成功,需要一段拆解和训练. 重复数据会自动删除',
status: 'success'
});
onClose();
@@ -130,27 +132,35 @@ const SelectFileModal = ({
const onclickImport = useCallback(async () => {
setBtnLoading(true);
let promise = Promise.resolve();
try {
let promise = Promise.resolve();
const splitRes = fileTextArr
.filter((item) => item)
.map((item) =>
splitText_token({
text: item,
...modeMap[mode]
})
const splitRes = await Promise.all(
fileTextArr
.filter((item) => item)
.map((item) =>
splitText_token({
text: item,
...modeMap[mode]
})
)
);
setSplitRes({
tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
chunks: splitRes.map((item) => item.chunks).flat()
});
setSplitRes({
tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
chunks: splitRes.map((item) => item.chunks).flat()
});
await promise;
openConfirm(mutate)();
} catch (error) {
toast({
status: 'warning',
title: getErrText(error, '拆分文本异常')
});
}
setBtnLoading(false);
await promise;
openConfirm(mutate)();
}, [fileTextArr, mode, mutate, openConfirm]);
}, [fileTextArr, mode, mutate, openConfirm, toast]);
return (
<Modal isOpen={true} onClose={onClose} isCentered>