From 98c458dcf8c5c54496ce9ac3b564e0d2f0e0a6a4 Mon Sep 17 00:00:00 2001
From: archer <545436317@qq.com>
Date: Sun, 26 Mar 2023 13:56:00 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E8=AE=AD=E7=BB=83=E5=90=8E=E6=A8=A1?=
=?UTF-8?q?=E5=9E=8B=E6=B2=A1=E9=80=89=E4=B8=AD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/api/response/chat.d.ts | 3 ++-
src/pages/api/chat/gpt3.ts | 11 +++--------
src/pages/api/chat/init.ts | 1 +
src/pages/api/model/getTrainings.ts | 16 ++++------------
src/pages/api/model/putTrainStatus.ts | 9 ++++++---
src/pages/api/model/train.ts | 3 ++-
src/pages/chat/index.tsx | 8 +++++---
src/pages/model/detail.tsx | 14 +++++++++++---
src/service/models/model.ts | 2 +-
9 files changed, 35 insertions(+), 32 deletions(-)
diff --git a/src/api/response/chat.d.ts b/src/api/response/chat.d.ts
index fe802c5ee..3b41658c0 100644
--- a/src/api/response/chat.d.ts
+++ b/src/api/response/chat.d.ts
@@ -8,7 +8,8 @@ export type InitChatResponse = {
avatar: string;
intro: string;
secret: ModelSchema.secret;
- chatModel: ModelSchema.service.ChatModel; // 模型名
+ chatModel: ModelSchema.service.chatModel; // 对话模型名
+ modelName: ModelSchema.service.modelName; // 底层模型
history: ChatItemType[];
isExpiredTime: boolean;
};
diff --git a/src/pages/api/chat/gpt3.ts b/src/pages/api/chat/gpt3.ts
index 286024368..03786f4a8 100644
--- a/src/pages/api/chat/gpt3.ts
+++ b/src/pages/api/chat/gpt3.ts
@@ -51,11 +51,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompts.length > maxContext ? prompts.slice(prompts.length - maxContext) : prompts;
// 格式化文本内容
- const map = {
- Human: 'Human',
- AI: 'AI',
- SYSTEM: 'SYSTEM'
- };
const formatPrompts: string[] = filterPrompts.map((item: ChatItemType) => item.value);
// 如果有系统提示词,自动插入
if (model.systemPrompt) {
@@ -85,7 +80,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
max_tokens: modelConstantsData.maxToken,
presence_penalty: 0, // 越大,越容易出现新内容
frequency_penalty: 0, // 越大,重复内容越少
- stop: ['。!?.!.', ``]
+ stop: [``, '。!?.!.']
},
{
timeout: 40000,
@@ -113,10 +108,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
try {
const json = JSON.parse(data);
const content: string = json?.choices?.[0].text || '';
+ console.log('content:', content);
if (!content || (responseContent === '' && content === '\n')) return;
responseContent += content;
- // console.log('content:', content);
!stream.destroyed && stream.push(content.replace(/\n/g, '
'));
} catch (error) {
error;
@@ -143,7 +138,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 只有使用平台的 key 才计费
!userApiKey &&
pushChatBill({
- modelName: model.service.modelName,
+ modelName: model.service.chatModel,
userId,
chatId,
text: promptText + responseContent
diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts
index 7378a111d..cdacccd83 100644
--- a/src/pages/api/chat/init.ts
+++ b/src/pages/api/chat/init.ts
@@ -52,6 +52,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
avatar: model.avatar,
intro: model.intro,
secret: model.security,
+ modelName: model.service.modelName,
chatModel: model.service.chatModel,
history: chat.content
}
diff --git a/src/pages/api/model/getTrainings.ts b/src/pages/api/model/getTrainings.ts
index bc7c54c98..f50c0b747 100644
--- a/src/pages/api/model/getTrainings.ts
+++ b/src/pages/api/model/getTrainings.ts
@@ -1,15 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
-import { connectToDatabase, Model, Training } from '@/service/mongo';
-import { getOpenAIApi } from '@/service/utils/chat';
-import formidable from 'formidable';
-import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
-import { join } from 'path';
-import fs from 'fs';
-import type { ModelSchema } from '@/types/mongoSchema';
-import type { OpenAIApi } from 'openai';
-import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model';
-import { httpsAgent } from '@/service/utils/tools';
+import { connectToDatabase, Training } from '@/service/mongo';
+import { authToken } from '@/service/utils/tools';
// 关闭next默认的bodyParser处理方式
export const config = {
@@ -18,7 +10,7 @@ export const config = {
}
};
-/* 上传文件,开始微调 */
+/* 获取模型训练记录 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { authorization } = req.headers;
@@ -30,7 +22,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (!modelId) {
throw new Error('参数错误');
}
- const userId = await authToken(authorization);
+ await authToken(authorization);
await connectToDatabase();
diff --git a/src/pages/api/model/putTrainStatus.ts b/src/pages/api/model/putTrainStatus.ts
index 3ce4ba2c8..dedfebfff 100644
--- a/src/pages/api/model/putTrainStatus.ts
+++ b/src/pages/api/model/putTrainStatus.ts
@@ -52,7 +52,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 删除训练文件
openai.deleteFile(data.training_files[0].id, { httpsAgent });
- // 更新模型
+ // 更新模型状态和模型内容
await Model.findByIdAndUpdate(modelId, {
status: ModelStatusEnum.running,
updateTime: new Date(),
@@ -72,6 +72,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
});
}
+ /* 取消微调 */
if (data.status === OpenAiTuneStatusEnum.cancelled) {
// 删除训练文件
openai.deleteFile(data.training_files[0].id, { httpsAgent });
@@ -87,11 +88,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
});
return jsonRes(res, {
- data: '模型微调取消'
+ data: '模型微调已取消'
});
}
- throw new Error('模型还在训练中');
+ jsonRes(res, {
+ data: '模型还在训练中'
+ });
} catch (err: any) {
jsonRes(res, {
code: 500,
diff --git a/src/pages/api/model/train.ts b/src/pages/api/model/train.ts
index 95bc033ec..8bebb65c9 100644
--- a/src/pages/api/model/train.ts
+++ b/src/pages/api/model/train.ts
@@ -30,6 +30,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('无权操作');
}
const { modelId } = req.query;
+
if (!modelId) {
throw new Error('参数错误');
}
@@ -67,7 +68,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
});
const file = files.file;
- // 上传文件
+ // 上传文件到 openai
// @ts-ignore
const uploadRes = await openai.createFile(
// @ts-ignore
diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx
index e5d8098d7..1037e14aa 100644
--- a/src/pages/chat/index.tsx
+++ b/src/pages/chat/index.tsx
@@ -62,6 +62,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
intro: '',
secret: {},
chatModel: '',
+ modelName: '',
history: [],
isExpiredTime: false
}); // 聊天框整体数据
@@ -156,7 +157,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
};
- if (!urlMap[chatData.chatModel]) return Promise.reject('找不到模型');
+
+ if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型');
const prompt = {
obj: prompts.obj,
@@ -164,7 +166,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
};
// 流请求,获取数据
const res = await streamFetch({
- url: urlMap[chatData.chatModel],
+ url: urlMap[chatData.modelName],
data: {
prompt,
chatId
@@ -217,7 +219,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
})
}));
},
- [chatData.chatModel, chatId, toast]
+ [chatData.modelName, chatId, toast]
);
/**
diff --git a/src/pages/model/detail.tsx b/src/pages/model/detail.tsx
index 7c1976f08..50c0e44e7 100644
--- a/src/pages/model/detail.tsx
+++ b/src/pages/model/detail.tsx
@@ -108,9 +108,9 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
// 重新获取模型
loadModel();
- } catch (err) {
+ } catch (err: any) {
toast({
- title: typeof err === 'string' ? err : '文件格式错误',
+ title: err?.message || '上传文件失败',
status: 'error'
});
console.log('error->', err);
@@ -126,7 +126,12 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
setLoading(true);
try {
- await putModelTrainingStatus(model._id);
+ const res = await putModelTrainingStatus(model._id);
+ typeof res === 'string' &&
+ toast({
+ title: res,
+ status: 'info'
+ });
loadModel();
} catch (error: any) {
console.log('error->', error);
@@ -284,6 +289,9 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
{/* 提示 */}
+
+ 暂时需要使用自己的openai key
+
可以使用