feat: 模型数据管理

feat: 模型数据导入

feat: redis 向量入库

feat: 向量索引

feat: 文件导入模型

perf: 交互

perf: prompt
This commit is contained in:
archer
2023-03-29 00:22:48 +08:00
parent 713332522f
commit 2099a87908
45 changed files with 1522 additions and 284 deletions

View File

@@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model';
import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
@@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
// 重名校验
const authRepeatName = await Model.findOne({
name,
userId
});
if (authRepeatName) {
throw new Error('模型名重复');
}
// 上限校验
const authCount = await Model.countDocuments({
userId
@@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
status: ModelStatusEnum.running,
service: {
company: modelItem.serviceCompany,
trainId: modelItem.trainName,
chatModel: modelItem.model,
modelName: modelItem.model
trainId: '',
chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
}
});

View File

@@ -5,8 +5,8 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
let { modelId } = req.query as {
modelId: string;
let { dataId } = req.query as {
dataId: string;
};
const { authorization } = req.headers;
@@ -14,7 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作');
}
if (!modelId) {
if (!dataId) {
throw new Error('缺少参数');
}
@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
await ModelData.deleteOne({
modelId,
_id: dataId,
userId
});

View File

@@ -14,6 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
pageNum: string;
pageSize: string;
};
const { authorization } = req.headers;
pageNum = +pageNum;
@@ -41,7 +42,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
.limit(pageSize);
jsonRes(res, {
data
data: {
pageNum,
pageSize,
data,
total: await ModelData.countDocuments({
modelId,
userId
})
}
});
} catch (err) {
jsonRes(res, {

View File

@@ -2,12 +2,14 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema';
import { generateVector } from '@/service/events/generateVector';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { modelId, data } = req.body as {
modelId: string;
data: { q: string; a: string }[];
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
};
const { authorization } = req.headers;
@@ -43,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}))
);
generateVector(true);
jsonRes(res, {
data: model
});

View File

@@ -0,0 +1,57 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, DataItem, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
let { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string };
if (!dataIds) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
const dataItems = (
await Promise.all(
dataIds.map((dataId) =>
DataItem.find<{ _id: string; result: { q: string }[]; text: string }>(
{
userId,
dataId
},
'result text'
)
)
)
).flat();
// push data
await ModelData.insertMany(
dataItems.map((item) => ({
modelId: modelId,
userId,
text: item.text,
q: item.result.map((item) => ({
id: nanoid(),
text: item.q
}))
}))
);
jsonRes(res, {
data: dataItems
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -5,9 +5,9 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
let { modelId, answer } = req.body as {
modelId: string;
answer: string;
let { dataId, text } = req.body as {
dataId: string;
text: string;
};
const { authorization } = req.headers;
@@ -15,7 +15,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作');
}
if (!modelId) {
if (!dataId) {
throw new Error('缺少参数');
}
@@ -26,11 +26,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await ModelData.updateOne(
{
modelId,
_id: dataId,
userId
},
{
a: answer
text
}
);

View File

@@ -0,0 +1,67 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateQA } from '@/service/events/generateQA';
import { encode } from 'gpt-token-utils';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { text, modelId } = req.body as { text: string; modelId: string };
if (!text || !modelId) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
// 验证是否是该用户的 model
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权操作该模型');
}
const replaceText = text.replace(/(\\n|\n)+/g, ' ');
// 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
const textList: string[] = [];
let splitText = '';
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
if (tokens >= 980) {
textList.push(splitText);
splitText = '';
}
});
// 批量插入数据
await SplitData.create({
userId,
modelId,
rawText: text,
textList
});
// generateQA();
jsonRes(res, {
data: { chunks, replaceText }
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat';
@@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
// 删除模型
await Model.deleteOne({
_id: modelId,
userId
});
let requestQueue: any[] = [];
// 删除对应的聊天
await Chat.deleteMany({
modelId
});
requestQueue.push(
Chat.deleteMany({
modelId
})
);
// 删除数据集
requestQueue.push(
ModelData.deleteMany({
modelId
})
);
// 查看是否正在训练
const training: TrainingItemType | null = await Training.findOne({
@@ -56,9 +60,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}
// 删除对应训练记录
await Training.deleteMany({
modelId
});
requestQueue.push(
Training.deleteMany({
modelId
})
);
// 删除模型
requestQueue.push(
Model.deleteOne({
_id: modelId,
userId
})
);
await requestQueue;
jsonRes(res);
} catch (err) {

View File

@@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
systemPrompt,
intro,
temperature,
service,
// service,
security
}
);