diff --git a/.env.template b/.env.template
index 33c0a8203..0d1b43381 100644
--- a/.env.template
+++ b/.env.template
@@ -3,4 +3,6 @@ AXIOS_PROXY_PORT=33210
 MONGODB_URI=
 MY_MAIL=
 MAILE_CODE=
-TOKEN_KEY=
\ No newline at end of file
+TOKEN_KEY=
+OPENAIKEY=
+REDIS_URL=
\ No newline at end of file
diff --git a/src/api/model.ts b/src/api/model.ts
index f20ce46ee..0d5cdc1e1 100644
--- a/src/api/model.ts
+++ b/src/api/model.ts
@@ -44,8 +44,8 @@ export const postModelDataInput = (data: {
   data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
 }) => POST(`/model/data/pushModelDataInput`, data);
 
-export const postModelDataSelect = (modelId: string, dataIds: string[]) =>
-  POST(`/model/data/pushModelDataSelectData`, { modelId, dataIds });
+export const postModelDataFileText = (modelId: string, text: string) =>
+  POST(`/model/data/splitData`, { modelId, text });
 
 export const putModelDataById = (data: { dataId: string; text: string }) =>
   PUT('/model/data/putModelData', data);
diff --git a/src/components/Layout/index.tsx b/src/components/Layout/index.tsx
index ee448021f..358ee25db 100644
--- a/src/components/Layout/index.tsx
+++ b/src/components/Layout/index.tsx
@@ -26,12 +26,12 @@ const navbarList = [
     link: '/model/list',
     activeLink: ['/model/list', '/model/detail']
   },
-  {
-    label: '数据',
-    icon: 'icon-datafull',
-    link: '/data/list',
-    activeLink: ['/data/list', '/data/detail']
-  },
+  // {
+  //   label: '数据',
+  //   icon: 'icon-datafull',
+  //   link: '/data/list',
+  //   activeLink: ['/data/list', '/data/detail']
+  // },
   {
     label: '账号',
     icon: 'icon-yonghu-yuan',
diff --git a/src/constants/model.ts b/src/constants/model.ts
index cb7e37239..784908bcd 100644
--- a/src/constants/model.ts
+++ b/src/constants/model.ts
@@ -2,9 +2,16 @@ import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchem
 
 export enum ChatModelNameEnum {
   GPT35 = 'gpt-3.5-turbo',
+  VECTOR_GPT = 'VECTOR_GPT',
   GPT3 = 'text-davinci-003'
 }
 
+export const ChatModelNameMap = {
+  [ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
+  [ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
+  [ChatModelNameEnum.GPT3]: 'text-davinci-003'
+};
+
 export type ModelConstantsData = {
   serviceCompany: `${ServiceName}`;
   name: string;
@@ -28,6 +35,17 @@ export const modelList: ModelConstantsData[] = [
     trainedMaxToken: 2000,
     maxTemperature: 2,
     price: 3
+  },
+  {
+    serviceCompany: 'openai',
+    name: '知识库',
+    model: ChatModelNameEnum.VECTOR_GPT,
+    trainName: 'vector',
+    maxToken: 4000,
+    contextMaxToken: 7500,
+    trainedMaxToken: 2000,
+    maxTemperature: 1,
+    price: 3
   }
   // {
   //   serviceCompany: 'openai',
diff --git a/src/constants/redis.ts b/src/constants/redis.ts
index 6e8cf1b36..9b0edc618 100644
--- a/src/constants/redis.ts
+++ b/src/constants/redis.ts
@@ -1,2 +1 @@
-export const ModelDataIndex = 'model:data';
-export const VecModelDataIndex = 'vec:model:data';
+export const VecModelDataIndex = 'model:data';
diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts
index 8a877bbaf..c42f95411 100644
--- a/src/pages/api/chat/chatGpt.ts
+++ b/src/pages/api/chat/chatGpt.ts
@@ -46,7 +46,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     const model: ModelSchema = chat.modelId;
     const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
     if (!modelConstantsData) {
-      throw new Error('模型异常,请用 chatgpt 模型');
+      throw new Error('模型加载异常');
     }
 
     // 读取对话内容
diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts
new file mode 100644
index 000000000..9f0a14d0f
--- /dev/null
+++ b/src/pages/api/chat/vectorGpt.ts
@@ -0,0 +1,248 @@
+import type { NextApiRequest, NextApiResponse } from 'next';
+import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
+import { connectToDatabase, ModelData } from '@/service/mongo';
+import { getOpenAIApi, authChat } from '@/service/utils/chat';
+import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
+import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
+import { ChatItemType } from '@/types/chat';
+import { jsonRes } from '@/service/response';
+import type { ModelSchema } from '@/types/mongoSchema';
+import { PassThrough } from 'stream';
+import { modelList } from '@/constants/model';
+import { pushChatBill } from '@/service/events/pushBill';
+import { connectRedis } from '@/service/redis';
+import { VecModelDataIndex } from '@/constants/redis';
+import { vectorToBuffer } from '@/utils/tools';
+
+/**
+ * Knowledge-base chat endpoint: embeds the incoming prompt, looks up related
+ * entries in the Redis vector index, loads their text from MongoDB as a system
+ * prompt, then streams the chat completion back to the client.
+ *
+ * Expects `chatId` and `prompt` in the request body and the caller's
+ * `authorization` header. Failures before streaming starts are answered with a
+ * 500 JSON body; failures after that simply end the stream.
+ */
+export default async function handler(req: NextApiRequest, res: NextApiResponse) {
+  let step = 0; // set to 1 once the streamed response has started
+  const stream = new PassThrough();
+  stream.on('error', () => {
+    console.log('error: ', 'stream error');
+    stream.destroy();
+  });
+  res.on('close', () => {
+    stream.destroy();
+  });
+  res.on('error', () => {
+    console.log('error: ', 'request error');
+    stream.destroy();
+  });
+
+  try {
+    const { chatId, prompt } = req.body as {
+      prompt: ChatItemType;
+      chatId: string;
+    };
+
+    const { authorization } = req.headers;
+    if (!chatId || !prompt) {
+      throw new Error('缺少参数');
+    }
+
+    await connectToDatabase();
+    const redis = await connectRedis();
+
+    const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
+
+    const model: ModelSchema = chat.modelId;
+    const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
+    if (!modelConstantsData) {
+      throw new Error('模型加载异常');
+    }
+
+    // Conversation so far, followed by the new prompt
+    const prompts = [...chat.content, prompt];
+
+    // OpenAI client; the user's own key takes precedence over the system key
+    const chatAPI = getOpenAIApi(userApiKey || systemKey);
+
+    // Turn the prompt text into an embedding vector
+    const embeddingResponse = await chatAPI.createEmbedding(
+      {
+        model: 'text-embedding-ada-002',
+        input: prompt.value
+      },
+      {
+        timeout: 120000,
+        httpsAgent
+      }
+    );
+    const promptVector = embeddingResponse?.data?.data?.[0]?.embedding || [];
+
+    const binary = vectorToBuffer(promptVector);
+
+    // Look up this model's entries whose vector lies within range 0.2 of the prompt
+    const redisData: any[] = await redis.sendCommand([
+      'FT.SEARCH',
+      `idx:${VecModelDataIndex}`,
+      `@modelId:{${String(chat.modelId._id)}} @vector:[VECTOR_RANGE 0.2 $blob]`,
+      // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
+      'RETURN',
+      '1',
+      'dataId',
+      // 'SORTBY',
+      // 'score',
+      'PARAMS',
+      '2',
+      'blob',
+      binary,
+      'DIALECT',
+      '2'
+    ]);
+
+    // The dataId of each hit is read from redisData[2], [4], ... at position 1;
+    // collect them in order without duplicates
+    const formatIds: string[] = [];
+    for (let i = 2; i < redisData.length; i += 2) {
+      const dataId = redisData[i]?.[1];
+      if (dataId && !formatIds.includes(dataId)) {
+        formatIds.push(dataId);
+      }
+    }
+
+    if (formatIds.length === 0) {
+      throw new Error('对不起,我没有找到你的问题');
+    }
+
+    // Load the original text from mongo, once per de-duplicated id
+    const textArr = (
+      await Promise.all(
+        formatIds.map(async (dataId) => {
+          const record = await ModelData.findById(dataId).select('text');
+          return record?.text || '';
+        })
+      )
+    ).filter((item) => item);
+
+    // Trim the knowledge text to the 2800 budget given to systemPromptFilter
+    const systemPrompt = systemPromptFilter(textArr, 2800);
+
+    prompts.unshift({
+      obj: 'SYSTEM',
+      value: `请根据下面的知识回答问题: ${systemPrompt}`
+    });
+
+    // Keep the whole conversation within the model's context limit
+    const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
+
+    // Map the internal speaker tags onto OpenAI chat roles
+    const map = {
+      Human: ChatCompletionRequestMessageRoleEnum.User,
+      AI: ChatCompletionRequestMessageRoleEnum.Assistant,
+      SYSTEM: ChatCompletionRequestMessageRoleEnum.System
+    };
+    const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
+      (item: ChatItemType) => ({
+        role: map[item.obj],
+        content: item.value
+      })
+    );
+
+    // Scale the configured temperature by the maximum this model allows
+    const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
+
+    const startTime = Date.now();
+    // Request a streamed chat completion
+    const chatResponse = await chatAPI.createChatCompletion(
+      {
+        model: model.service.chatModel,
+        temperature: temperature,
+        // max_tokens: modelConstantsData.maxToken,
+        messages: formatPrompts,
+        frequency_penalty: 0.5, // higher value -> less repeated content
+        presence_penalty: -0.5, // higher value -> more likely to bring up new content
+        stream: true
+      },
+      {
+        timeout: 40000,
+        responseType: 'stream',
+        httpsAgent
+      }
+    );
+
+    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
+
+    // Start the streamed response
+    res.setHeader('Content-Type', 'text/event-stream;charset=utf-8');
+    res.setHeader('Access-Control-Allow-Origin', '*');
+    res.setHeader('X-Accel-Buffering', 'no');
+    res.setHeader('Cache-Control', 'no-cache, no-transform');
+    step = 1;
+
+    let responseContent = '';
+    stream.pipe(res);
+
+    const onParse = (event: ParsedEvent | ReconnectInterval) => {
+      if (event.type !== 'event') return;
+      const data = event.data;
+      if (data === '[DONE]') return;
+      try {
+        const json = JSON.parse(data);
+        const content: string = json?.choices?.[0]?.delta?.content || '';
+        // Skip empty deltas and a newline arriving before any other content
+        if (!content || (responseContent === '' && content === '\n')) return;
+
+        responseContent += content;
+        // NOTE(review): the replacement text is unreadable in the reviewed copy of
+        // this patch; '<br/>' is assumed - confirm against the sibling chat handlers
+        !stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
+      } catch (error) {
+        // A malformed event must not abort the stream: log it and carry on
+        console.log('parse error', error);
+      }
+    };
+
+    // A single parser for the whole response, so an event that is split across
+    // two chunks is reassembled instead of being dropped
+    const parser = createParser(onParse);
+    const decoder = new TextDecoder();
+    try {
+      for await (const chunk of chatResponse.data as any) {
+        if (stream.destroyed) {
+          // The client went away; ignore the rest of the response
+          break;
+        }
+        // stream: true buffers a multi-byte character cut by a chunk boundary
+        parser.feed(decoder.decode(chunk, { stream: true }));
+      }
+    } catch (error) {
+      console.log('pipe error', error);
+    }
+    // close stream
+    !stream.destroyed && stream.push(null);
+    stream.destroy();
+
+    const promptsContent = formatPrompts.map((item) => item.content).join('');
+    // Only usage on the platform key is billed
+    pushChatBill({
+      isPay: !userApiKey,
+      modelName: model.service.modelName,
+      userId,
+      chatId,
+      text: promptsContent + responseContent
+    });
+    // jsonRes(res);
+  } catch (err: any) {
+    if (step === 1) {
+      // Streaming already started: just end the stream
+      console.log('error,结束');
+      stream.destroy();
+    } else {
+      res.status(500);
+      jsonRes(res, {
+        code: 500,
+        error: err
+      });
+    }
+  }
+}
diff --git a/src/pages/api/data/splitData.ts b/src/pages/api/data/splitData.ts
index 16bded265..13143ae6e 100644
--- a/src/pages/api/data/splitData.ts
+++ b/src/pages/api/data/splitData.ts
@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     if (!DataRecord) {
       throw new Error('找不到数据集');
     }
-    const replaceText = text.replace(/[\r\n\\n]+/g, ' ');
+    const replaceText = text.replace(/(\\n|\r?\n)+/g, ' ');
 
     // 文本拆分成 chunk
     let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
@@ -35,7 +35,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     chunks.forEach((chunk) => {
       splitText += chunk;
       const tokens = encode(splitText).length;
-      if (tokens >= 980) {
+      if (tokens >= 780) {
         dataItems.push({
           userId,
           dataId,
diff --git a/src/pages/api/model/create.ts b/src/pages/api/model/create.ts
index 259c46bea..b7def6ae5 100644
--- a/src/pages/api/model/create.ts
+++ b/src/pages/api/model/create.ts
@@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
 import { jsonRes } from '@/service/response';
 import { connectToDatabase } 
from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; -import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model'; +import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model'; import { Model } from '@/service/models/model'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); - // 重名校验 - const authRepeatName = await Model.findOne({ - name, - userId - }); - if (authRepeatName) { - throw new Error('模型名重复'); - } - // 上限校验 const authCount = await Model.countDocuments({ userId @@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< status: ModelStatusEnum.running, service: { company: modelItem.serviceCompany, - trainId: modelItem.trainName, - chatModel: modelItem.model, - modelName: modelItem.model + trainId: '', + chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型 + modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作 } }); diff --git a/src/pages/api/model/data/pushModelDataInput.ts b/src/pages/api/model/data/pushModelDataInput.ts index 6f3362eed..9b0b5614e 100644 --- a/src/pages/api/model/data/pushModelDataInput.ts +++ b/src/pages/api/model/data/pushModelDataInput.ts @@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response'; import { connectToDatabase, ModelData, Model } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { ModelDataSchema } from '@/types/mongoSchema'; +import { generateVector } from '@/service/events/generateVector'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -44,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< })) ); + generateVector(true); + jsonRes(res, { data: model }); diff --git a/src/pages/api/model/data/splitData.ts 
b/src/pages/api/model/data/splitData.ts
new file mode 100644
index 000000000..379d952e1
--- /dev/null
+++ b/src/pages/api/model/data/splitData.ts
@@ -0,0 +1,80 @@
+import type { NextApiRequest, NextApiResponse } from 'next';
+import { jsonRes } from '@/service/response';
+import { connectToDatabase, SplitData, Model } from '@/service/mongo';
+import { authToken } from '@/service/utils/tools';
+import { generateQA } from '@/service/events/generateQA';
+import { encode } from 'gpt-token-utils';
+
+/**
+ * Splits a block of raw text into sentence-aligned chunks of roughly 980 tokens
+ * and stores them, together with the raw text, as one SplitData record.
+ *
+ * Expects `text` and `modelId` in the request body; the model must belong to
+ * the user identified by the `authorization` header. Any failure is answered
+ * with a code-500 JSON body.
+ */
+export default async function handler(req: NextApiRequest, res: NextApiResponse) {
+  try {
+    const { text, modelId } = req.body as { text: string; modelId: string };
+    if (!text || !modelId) {
+      throw new Error('参数错误');
+    }
+    await connectToDatabase();
+
+    const { authorization } = req.headers;
+
+    const userId = await authToken(authorization);
+
+    // Make sure the model belongs to this user
+    const model = await Model.findOne({
+      _id: modelId,
+      userId
+    });
+
+    if (!model) {
+      throw new Error('无权操作该模型');
+    }
+
+    // Collapse literal "\n" sequences and real line breaks into single spaces
+    const replaceText = text.replace(/(\\n|\r?\n)+/g, ' ');
+
+    // Split into sentences; the terminator is optional so that trailing text
+    // without closing punctuation is kept instead of being dropped
+    const chunks = replaceText.match(/[^!?.。]+[!?.。]?/g) || [];
+
+    const textList: string[] = [];
+    let splitText = '';
+
+    chunks.forEach((chunk) => {
+      splitText += chunk;
+      const tokens = encode(splitText).length;
+      if (tokens >= 980) {
+        textList.push(splitText);
+        splitText = '';
+      }
+    });
+    // Keep the final remainder that never reached the token threshold
+    if (splitText.trim()) {
+      textList.push(splitText);
+    }
+
+    // Store the raw text together with its chunks
+    await SplitData.create({
+      userId,
+      modelId,
+      rawText: text,
+      textList
+    });
+
+    // generateQA();
+
+    jsonRes(res, {
+      data: { chunks, replaceText }
+    });
+  } catch (err) {
+    jsonRes(res, {
+      code: 500,
+      error: err
+    });
+  }
+}
diff --git a/src/pages/api/model/del.ts b/src/pages/api/model/del.ts
index 18dc5b8fa..976ca96f6 100644
--- a/src/pages/api/model/del.ts
+++ b/src/pages/api/model/del.ts
@@ -1,6 +1,6 @@
 import type { NextApiRequest, NextApiResponse } from 'next';
 import { jsonRes } from '@/service/response';
-import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
+import 
{ Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
 import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
 import { TrainingStatusEnum } from '@/constants/model';
 import { getOpenAIApi } from '@/service/utils/chat';
@@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
 
     await connectToDatabase();
 
-    // 删除模型
-    await Model.deleteOne({
-      _id: modelId,
-      userId
-    });
-
+    const requestQueue: any[] = [];
     // 删除对应的聊天
-    await Chat.deleteMany({
-      modelId
-    });
+    requestQueue.push(
+      Chat.deleteMany({
+        modelId
+      })
+    );
+
+    // Delete the model's data set
+    requestQueue.push(
+      ModelData.deleteMany({
+        modelId
+      })
+    );
 
     // 查看是否正在训练
     const training: TrainingItemType | null = await Training.findOne({
@@ -56,9 +60,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
     }
 
     // 删除对应训练记录
-    await Training.deleteMany({
-      modelId
-    });
+    requestQueue.push(
+      Training.deleteMany({
+        modelId
+      })
+    );
+
+    // Delete the model itself
+    requestQueue.push(
+      Model.deleteOne({
+        _id: modelId,
+        userId
+      })
+    );
+    // Awaiting the bare array would resolve immediately without waiting for
+    // (or even starting) the queued operations; Promise.all settles each one
+    await Promise.all(requestQueue);
 
     jsonRes(res);
   } catch (err) {
diff --git a/src/pages/api/model/update.ts b/src/pages/api/model/update.ts
index 1de24e856..af9d013c2 100644
--- a/src/pages/api/model/update.ts
+++ b/src/pages/api/model/update.ts
@@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
         systemPrompt,
         intro,
         temperature,
-        service,
+        // service,
         security
       }
     );
diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx
index 7f95376b3..6eff04351 100644
--- a/src/pages/chat/index.tsx
+++ b/src/pages/chat/index.tsx
@@ -119,6 +119,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
     async (prompts: ChatSiteItemType) => {
       const urlMap: Record = {
         [ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
+        [ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
         [ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
       };
 
diff --git a/src/pages/data/list.tsx b/src/pages/data/list.tsx
index 
e1dd0d07a..378d21634 100644 --- a/src/pages/data/list.tsx +++ b/src/pages/data/list.tsx @@ -184,7 +184,7 @@ const DataList = () => { > 导入 - + {/* 导出 @@ -200,7 +200,7 @@ const DataList = () => { )} - + */}