feat: config vector model and qa model

This commit is contained in:
archer
2023-08-25 15:00:51 +08:00
parent a9970dd694
commit 6d93059e25
35 changed files with 337 additions and 196 deletions

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { connectToDatabase, TrainingData, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { authKb } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools';
@@ -14,7 +14,6 @@ export type DateItemType = { a: string; q: string; source?: string };
export type Props = {
kbId: string;
data: DateItemType[];
model: string;
mode: `${TrainingModeEnum}`;
prompt?: string;
};
@@ -30,23 +29,12 @@ const modeMaxToken = {
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode, prompt, model } = req.body as Props;
const { kbId, data, mode, prompt } = req.body as Props;
if (!kbId || !Array.isArray(data) || !model) {
if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数');
}
// auth model
if (mode === TrainingModeEnum.qa && !global.qaModels.find((item) => item.model === model)) {
throw new Error('不支持的 QA 拆分模型');
}
if (
mode === TrainingModeEnum.index &&
!global.vectorModels.find((item) => item.model === model)
) {
throw new Error('不支持的向量生成模型');
}
await connectToDatabase();
// 凭证校验
@@ -58,8 +46,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
data,
userId,
mode,
prompt,
model
prompt
})
});
} catch (err) {
@@ -75,8 +62,7 @@ export async function pushDataToKb({
kbId,
data,
mode,
prompt,
model
prompt
}: { userId: string } & Props): Promise<Response> {
await authKb({
userId,
@@ -152,17 +138,24 @@ export async function pushDataToKb({
.filter((item) => item.status === 'fulfilled')
.map<DateItemType>((item: any) => item.value);
const vectorModel = await (async () => {
if (mode === TrainingModeEnum.index) {
return (await KB.findById(kbId, 'vectorModel'))?.vectorModel || global.vectorModels[0].model;
}
return global.vectorModels[0].model;
})();
// 插入记录
await TrainingData.insertMany(
insertData.map((item) => ({
q: item.q,
a: item.a,
model,
source: item.source,
userId,
kbId,
mode,
prompt
prompt,
vectorModel
}))
);