doc gpt V0.2
This commit is contained in:
75
src/pages/api/model/create.ts
Normal file
75
src/pages/api/model/create.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { ModelStatusEnum, OpenAiList } from '@/constants/model';
|
||||
import { Model } from '@/service/models/model';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { name, serviceModelName, serviceModelCompany = 'openai' } = req.body;
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
if (!name || !serviceModelName || !serviceModelCompany) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
// 凭证校验
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
const modelItem = OpenAiList.find((item) => item.model === serviceModelName);
|
||||
|
||||
if (!modelItem) {
|
||||
throw new Error('模型错误');
|
||||
}
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 重名校验
|
||||
const authRepeatName = await Model.findOne({
|
||||
name,
|
||||
userId
|
||||
});
|
||||
if (authRepeatName) {
|
||||
throw new Error('模型名重复');
|
||||
}
|
||||
|
||||
// 上限校验
|
||||
const authCount = await Model.countDocuments({
|
||||
userId
|
||||
});
|
||||
if (authCount >= 5) {
|
||||
throw new Error('上限5个模型');
|
||||
}
|
||||
|
||||
// 创建模型
|
||||
const response = await Model.create({
|
||||
name,
|
||||
userId,
|
||||
status: ModelStatusEnum.running,
|
||||
service: {
|
||||
company: serviceModelCompany,
|
||||
trainId: modelItem.trainName,
|
||||
chatModel: modelItem.model,
|
||||
modelName: modelItem.model
|
||||
}
|
||||
});
|
||||
|
||||
// 根据 id 获取模型信息
|
||||
const model = await Model.findById(response._id);
|
||||
|
||||
jsonRes(res, {
|
||||
data: model
|
||||
});
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
70
src/pages/api/model/del.ts
Normal file
70
src/pages/api/model/del.ts
Normal file
@@ -0,0 +1,70 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
|
||||
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
|
||||
import { TrainingStatusEnum } from '@/constants/model';
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import { TrainingItemType } from '@/types/training';
|
||||
import { openaiProxy } from '@/service/utils/tools';
|
||||
|
||||
/* 获取我的模型 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { modelId } = req.query;
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
if (!modelId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
|
||||
// 凭证校验
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 删除模型
|
||||
await Model.deleteOne({
|
||||
_id: modelId,
|
||||
userId
|
||||
});
|
||||
|
||||
// 删除对应的聊天
|
||||
await Chat.deleteMany({
|
||||
modelId
|
||||
});
|
||||
|
||||
// 查看是否正在训练
|
||||
const training: TrainingItemType | null = await Training.findOne({
|
||||
modelId,
|
||||
status: TrainingStatusEnum.pending
|
||||
});
|
||||
|
||||
// 如果正在训练,需要删除openai上的相关信息
|
||||
if (training) {
|
||||
const openai = getOpenAIApi(await getUserOpenaiKey(userId));
|
||||
// 获取训练记录
|
||||
const tuneRecord = await openai.retrieveFineTune(training.tuneId, openaiProxy);
|
||||
|
||||
// 删除训练文件
|
||||
openai.deleteFile(tuneRecord.data.training_files[0].id, openaiProxy);
|
||||
// 取消训练
|
||||
openai.cancelFineTune(training.tuneId, openaiProxy);
|
||||
}
|
||||
|
||||
// 删除对应训练记录
|
||||
await Training.deleteMany({
|
||||
modelId
|
||||
});
|
||||
|
||||
jsonRes(res);
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
47
src/pages/api/model/detail.tsx
Normal file
47
src/pages/api/model/detail.tsx
Normal file
@@ -0,0 +1,47 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { Model } from '@/service/models/model';
|
||||
import { ModelType } from '@/types/model';
|
||||
|
||||
/* 获取我的模型 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
const { modelId } = req.query;
|
||||
|
||||
if (!modelId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
|
||||
// 凭证校验
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 根据 userId 获取模型信息
|
||||
const model: ModelType | null = await Model.findOne({
|
||||
userId,
|
||||
_id: modelId
|
||||
});
|
||||
|
||||
if (!model) {
|
||||
throw new Error('模型不存在');
|
||||
}
|
||||
|
||||
jsonRes(res, {
|
||||
data: model
|
||||
});
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
60
src/pages/api/model/getTrainings.ts
Normal file
60
src/pages/api/model/getTrainings.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, Model, Training } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import formidable from 'formidable';
|
||||
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
|
||||
import { join } from 'path';
|
||||
import fs from 'fs';
|
||||
import type { ModelType } from '@/types/model';
|
||||
import type { OpenAIApi } from 'openai';
|
||||
import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model';
|
||||
import { openaiProxy } from '@/service/utils/tools';
|
||||
|
||||
// 关闭next默认的bodyParser处理方式
|
||||
export const config = {
|
||||
api: {
|
||||
bodyParser: false
|
||||
}
|
||||
};
|
||||
|
||||
/* 上传文件,开始微调 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
try {
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
const { modelId } = req.query;
|
||||
if (!modelId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
/* 获取 modelId 下的 training 记录 */
|
||||
const records = await Training.find({
|
||||
modelId
|
||||
});
|
||||
|
||||
jsonRes(res, {
|
||||
data: records
|
||||
});
|
||||
} catch (err: any) {
|
||||
/* 清除上传的文件,关闭训练记录 */
|
||||
// @ts-ignore
|
||||
if (openai) {
|
||||
// @ts-ignore
|
||||
uploadFileId && openai.deleteFile(uploadFileId);
|
||||
// @ts-ignore
|
||||
trainId && openai.cancelFineTune(trainId);
|
||||
}
|
||||
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
35
src/pages/api/model/list.ts
Normal file
35
src/pages/api/model/list.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { Model } from '@/service/models/model';
|
||||
|
||||
/* 获取我的模型 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
// 凭证校验
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 根据 userId 获取模型信息
|
||||
const models = await Model.find({
|
||||
userId
|
||||
});
|
||||
|
||||
jsonRes(res, {
|
||||
data: models
|
||||
});
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
101
src/pages/api/model/putTrainStatus.ts
Normal file
101
src/pages/api/model/putTrainStatus.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, Model, Training } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
|
||||
import type { ModelType } from '@/types/model';
|
||||
import { TrainingItemType } from '@/types/training';
|
||||
import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model';
|
||||
import { OpenAiTuneStatusEnum } from '@/service/constants/training';
|
||||
import { openaiProxy } from '@/service/utils/tools';
|
||||
|
||||
/* 更新训练状态 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
try {
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
const { modelId } = req.query as { modelId: string };
|
||||
if (!modelId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 获取模型
|
||||
const model: ModelType | null = await Model.findById(modelId);
|
||||
|
||||
if (!model || model.status !== 'training') {
|
||||
throw new Error('模型不在训练中');
|
||||
}
|
||||
|
||||
// 查询正在训练中的训练记录
|
||||
const training: TrainingItemType | null = await Training.findOne({
|
||||
modelId,
|
||||
status: 'pending'
|
||||
});
|
||||
|
||||
if (!training) {
|
||||
throw new Error('找不到训练记录');
|
||||
}
|
||||
|
||||
// 用户的 openai 实例
|
||||
const openai = getOpenAIApi(await getUserOpenaiKey(userId));
|
||||
|
||||
// 获取 openai 的训练情况
|
||||
const { data } = await openai.retrieveFineTune(training.tuneId, openaiProxy);
|
||||
|
||||
if (data.status === OpenAiTuneStatusEnum.succeeded) {
|
||||
// 删除训练文件
|
||||
openai.deleteFile(data.training_files[0].id, openaiProxy);
|
||||
|
||||
// 更新模型
|
||||
await Model.findByIdAndUpdate(modelId, {
|
||||
status: ModelStatusEnum.running,
|
||||
updateTime: new Date(),
|
||||
service: {
|
||||
...model.service,
|
||||
trainId: data.fine_tuned_model, // 训练完后,再次训练和对话使用的 model 是一样的
|
||||
chatModel: data.fine_tuned_model
|
||||
}
|
||||
});
|
||||
// 更新训练数据
|
||||
await Training.findByIdAndUpdate(training._id, {
|
||||
status: TrainingStatusEnum.succeed
|
||||
});
|
||||
|
||||
return jsonRes(res, {
|
||||
data: '模型微调完成'
|
||||
});
|
||||
}
|
||||
|
||||
if (data.status === OpenAiTuneStatusEnum.cancelled) {
|
||||
// 删除训练文件
|
||||
openai.deleteFile(data.training_files[0].id, openaiProxy);
|
||||
|
||||
// 更新模型
|
||||
await Model.findByIdAndUpdate(modelId, {
|
||||
status: ModelStatusEnum.running,
|
||||
updateTime: new Date()
|
||||
});
|
||||
// 更新训练数据
|
||||
await Training.findByIdAndUpdate(training._id, {
|
||||
status: TrainingStatusEnum.canceled
|
||||
});
|
||||
|
||||
return jsonRes(res, {
|
||||
data: '模型微调取消'
|
||||
});
|
||||
}
|
||||
|
||||
throw new Error('模型还在训练中');
|
||||
} catch (err: any) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
127
src/pages/api/model/train.ts
Normal file
127
src/pages/api/model/train.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, Model, Training } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import formidable from 'formidable';
|
||||
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
|
||||
import { join } from 'path';
|
||||
import fs from 'fs';
|
||||
import type { ModelType } from '@/types/model';
|
||||
import type { OpenAIApi } from 'openai';
|
||||
import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model';
|
||||
import { openaiProxy } from '@/service/utils/tools';
|
||||
|
||||
// 关闭next默认的bodyParser处理方式
|
||||
export const config = {
|
||||
api: {
|
||||
bodyParser: false
|
||||
}
|
||||
};
|
||||
|
||||
/* 上传文件,开始微调 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
let openai: OpenAIApi, trainId: string, uploadFileId: string;
|
||||
|
||||
try {
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
const { modelId } = req.query;
|
||||
if (!modelId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 获取模型的状态
|
||||
const model: ModelType | null = await Model.findById(modelId);
|
||||
|
||||
if (!model || model.status !== 'running') {
|
||||
throw new Error('模型正忙');
|
||||
}
|
||||
|
||||
// const trainingType = model.service.modelType
|
||||
const trainingType = model.service.trainId; // 目前都默认是 openai text-davinci-03
|
||||
|
||||
// 获取用户的 API Key 实例化后的对象
|
||||
openai = getOpenAIApi(await getUserOpenaiKey(userId));
|
||||
|
||||
// 接收文件并保存
|
||||
const form = formidable({
|
||||
uploadDir: join(process.cwd(), 'public/trainData'),
|
||||
keepExtensions: true
|
||||
});
|
||||
|
||||
const { files } = await new Promise<{
|
||||
fields: formidable.Fields;
|
||||
files: formidable.Files;
|
||||
}>((resolve, reject) => {
|
||||
form.parse(req, (err, fields, files) => {
|
||||
if (err) return reject(err);
|
||||
resolve({ fields, files });
|
||||
});
|
||||
});
|
||||
const file = files.file;
|
||||
|
||||
// 上传文件
|
||||
// @ts-ignore
|
||||
const uploadRes = await openai.createFile(
|
||||
// @ts-ignore
|
||||
fs.createReadStream(file.filepath),
|
||||
'fine-tune',
|
||||
openaiProxy
|
||||
);
|
||||
uploadFileId = uploadRes.data.id; // 记录上传文件的 ID
|
||||
|
||||
// 开始训练
|
||||
const trainRes = await openai.createFineTune(
|
||||
{
|
||||
training_file: uploadFileId,
|
||||
model: trainingType,
|
||||
suffix: model.name
|
||||
},
|
||||
openaiProxy
|
||||
);
|
||||
|
||||
trainId = trainRes.data.id; // 记录训练 ID
|
||||
|
||||
// 创建训练记录
|
||||
await Training.create({
|
||||
serviceName: 'openai',
|
||||
tuneId: trainId,
|
||||
status: TrainingStatusEnum.pending,
|
||||
modelId
|
||||
});
|
||||
|
||||
// 修改模型状态
|
||||
await Model.findByIdAndUpdate(modelId, {
|
||||
$inc: {
|
||||
trainingTimes: +1
|
||||
},
|
||||
updateTime: new Date(),
|
||||
status: ModelStatusEnum.training
|
||||
});
|
||||
|
||||
jsonRes(res, {
|
||||
data: 'start training'
|
||||
});
|
||||
} catch (err: any) {
|
||||
/* 清除上传的文件,关闭训练记录 */
|
||||
// @ts-ignore
|
||||
if (openai) {
|
||||
// @ts-ignore
|
||||
uploadFileId && openai.deleteFile(uploadFileId, openaiProxy);
|
||||
// @ts-ignore
|
||||
trainId && openai.cancelFineTune(trainId, openaiProxy);
|
||||
}
|
||||
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
49
src/pages/api/model/update.ts
Normal file
49
src/pages/api/model/update.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { Model } from '@/service/models/model';
|
||||
import type { ModelUpdateParams } from '@/types/model';
|
||||
|
||||
/* 获取我的模型 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { name, service, security, systemPrompt } = req.body as ModelUpdateParams;
|
||||
const { modelId } = req.query as { modelId: string };
|
||||
const { authorization } = req.headers;
|
||||
|
||||
if (!authorization) {
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
if (!name || !service || !security || !modelId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
|
||||
// 凭证校验
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 更新模型
|
||||
await Model.updateOne(
|
||||
{
|
||||
_id: modelId,
|
||||
userId
|
||||
},
|
||||
{
|
||||
name,
|
||||
service,
|
||||
systemPrompt,
|
||||
security
|
||||
}
|
||||
);
|
||||
|
||||
jsonRes(res);
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user