diff --git a/src/api/response/chat.d.ts b/src/api/response/chat.d.ts index fe802c5ee..3b41658c0 100644 --- a/src/api/response/chat.d.ts +++ b/src/api/response/chat.d.ts @@ -8,7 +8,8 @@ export type InitChatResponse = { avatar: string; intro: string; secret: ModelSchema.secret; - chatModel: ModelSchema.service.ChatModel; // 模型名 + chatModel: ModelSchema.service.chatModel; // 对话模型名 + modelName: ModelSchema.service.modelName; // 底层模型 history: ChatItemType[]; isExpiredTime: boolean; }; diff --git a/src/pages/api/chat/gpt3.ts b/src/pages/api/chat/gpt3.ts index 286024368..03786f4a8 100644 --- a/src/pages/api/chat/gpt3.ts +++ b/src/pages/api/chat/gpt3.ts @@ -51,11 +51,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) prompts.length > maxContext ? prompts.slice(prompts.length - maxContext) : prompts; // 格式化文本内容 - const map = { - Human: 'Human', - AI: 'AI', - SYSTEM: 'SYSTEM' - }; const formatPrompts: string[] = filterPrompts.map((item: ChatItemType) => item.value); // 如果有系统提示词,自动插入 if (model.systemPrompt) { @@ -85,7 +80,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) max_tokens: modelConstantsData.maxToken, presence_penalty: 0, // 越大,越容易出现新内容 frequency_penalty: 0, // 越大,重复内容越少 - stop: ['。!?.!.', ``] + stop: [``, '。!?.!.'] }, { timeout: 40000, @@ -113,10 +108,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) try { const json = JSON.parse(data); const content: string = json?.choices?.[0].text || ''; + console.log('content:', content); if (!content || (responseContent === '' && content === '\n')) return; responseContent += content; - // console.log('content:', content); !stream.destroyed && stream.push(content.replace(/\n/g, '
')); } catch (error) { error; @@ -143,7 +138,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 只有使用平台的 key 才计费 !userApiKey && pushChatBill({ - modelName: model.service.modelName, + modelName: model.service.chatModel, userId, chatId, text: promptText + responseContent diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts index 7378a111d..cdacccd83 100644 --- a/src/pages/api/chat/init.ts +++ b/src/pages/api/chat/init.ts @@ -52,6 +52,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) avatar: model.avatar, intro: model.intro, secret: model.security, + modelName: model.service.modelName, chatModel: model.service.chatModel, history: chat.content } diff --git a/src/pages/api/model/getTrainings.ts b/src/pages/api/model/getTrainings.ts index bc7c54c98..f50c0b747 100644 --- a/src/pages/api/model/getTrainings.ts +++ b/src/pages/api/model/getTrainings.ts @@ -1,15 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, Model, Training } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/chat'; -import formidable from 'formidable'; -import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; -import { join } from 'path'; -import fs from 'fs'; -import type { ModelSchema } from '@/types/mongoSchema'; -import type { OpenAIApi } from 'openai'; -import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model'; -import { httpsAgent } from '@/service/utils/tools'; +import { connectToDatabase, Training } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; // 关闭next默认的bodyParser处理方式 export const config = { @@ -18,7 +10,7 @@ export const config = { } }; -/* 上传文件,开始微调 */ +/* 获取模型训练记录 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { const { authorization } = req.headers; @@ -30,7 +22,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) if (!modelId) { throw new Error('参数错误'); } - const userId = await authToken(authorization); + await authToken(authorization); await connectToDatabase(); diff --git a/src/pages/api/model/putTrainStatus.ts b/src/pages/api/model/putTrainStatus.ts index 3ce4ba2c8..dedfebfff 100644 --- a/src/pages/api/model/putTrainStatus.ts +++ b/src/pages/api/model/putTrainStatus.ts @@ -52,7 +52,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 删除训练文件 openai.deleteFile(data.training_files[0].id, { httpsAgent }); - // 更新模型 + // 更新模型状态和模型内容 await Model.findByIdAndUpdate(modelId, { status: ModelStatusEnum.running, updateTime: new Date(), @@ -72,6 +72,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); } + /* 取消微调 */ if (data.status === OpenAiTuneStatusEnum.cancelled) { // 删除训练文件 openai.deleteFile(data.training_files[0].id, { httpsAgent }); @@ -87,11 +88,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); return jsonRes(res, { - data: '模型微调取消' + data: '模型微调已取消' }); } - throw new Error('模型还在训练中'); + jsonRes(res, { + data: '模型还在训练中' + }); } catch (err: any) { jsonRes(res, { code: 500, diff --git a/src/pages/api/model/train.ts b/src/pages/api/model/train.ts index 95bc033ec..8bebb65c9 100644 --- a/src/pages/api/model/train.ts +++ b/src/pages/api/model/train.ts @@ -30,6 +30,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) throw new Error('无权操作'); } const { modelId } = req.query; + if (!modelId) { throw new Error('参数错误'); } @@ -67,7 +68,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); const file = files.file; - // 上传文件 + // 上传文件到 openai // @ts-ignore const uploadRes = await openai.createFile( // @ts-ignore diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index e5d8098d7..1037e14aa 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -62,6 +62,7 @@ const Chat = ({ chatId }: { chatId: string }) => { intro: '', secret: {}, chatModel: '', + modelName: '', history: [], isExpiredTime: false }); // 聊天框整体数据 @@ -156,7 +157,8 @@ const Chat = ({ chatId }: { chatId: string }) => { [ChatModelNameEnum.GPT35]: '/api/chat/chatGpt', [ChatModelNameEnum.GPT3]: '/api/chat/gpt3' }; - if (!urlMap[chatData.chatModel]) return Promise.reject('找不到模型'); + + if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型'); const prompt = { obj: prompts.obj, @@ -164,7 +166,7 @@ const Chat = ({ chatId }: { chatId: string }) => { }; // 流请求,获取数据 const res = await streamFetch({ - url: urlMap[chatData.chatModel], + url: urlMap[chatData.modelName], data: { prompt, chatId @@ -217,7 +219,7 @@ const Chat = ({ chatId }: { chatId: string }) => { }) })); }, - [chatData.chatModel, chatId, toast] + [chatData.modelName, chatId, toast] ); /** diff --git a/src/pages/model/detail.tsx b/src/pages/model/detail.tsx index 7c1976f08..50c0e44e7 100644 --- a/src/pages/model/detail.tsx +++ b/src/pages/model/detail.tsx @@ -108,9 +108,9 @@ const ModelDetail = ({ modelId }: { modelId: string }) => { // 重新获取模型 loadModel(); - } catch (err) { + } catch (err: any) { toast({ - title: typeof err === 'string' ? err : '文件格式错误', + title: err?.message || '上传文件失败', status: 'error' }); console.log('error->', err); @@ -126,7 +126,12 @@ const ModelDetail = ({ modelId }: { modelId: string }) => { setLoading(true); try { - await putModelTrainingStatus(model._id); + const res = await putModelTrainingStatus(model._id); + typeof res === 'string' && + toast({ + title: res, + status: 'info' + }); loadModel(); } catch (error: any) { console.log('error->', error); @@ -284,6 +289,9 @@ const ModelDetail = ({ modelId }: { modelId: string }) => { {/* 提示 */} + + 暂时需要使用自己的openai key + 可以使用