perf: model framwork
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { getOpenAIApi, authChat } from '@/service/utils/auth';
|
||||
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
|
||||
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
import { modelList } from '@/constants/model';
|
||||
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
||||
import { pushChatBill } from '@/service/events/pushBill';
|
||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
||||
|
||||
/* 发送提示词 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
@@ -46,7 +47,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
authorization
|
||||
});
|
||||
|
||||
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
|
||||
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型加载异常');
|
||||
}
|
||||
@@ -54,31 +55,84 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
// 读取对话内容
|
||||
const prompts = [...content, prompt];
|
||||
|
||||
// 如果有系统提示词,自动插入
|
||||
if (model.systemPrompt) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.systemPrompt
|
||||
// 使用了知识库搜索
|
||||
if (model.chat.useKb) {
|
||||
const { systemPrompts } = await searchKb_openai({
|
||||
apiKey: userApiKey || systemKey,
|
||||
isPay: !userApiKey,
|
||||
text: prompt.value,
|
||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
|
||||
modelId,
|
||||
userId
|
||||
});
|
||||
|
||||
// filter system prompt
|
||||
if (
|
||||
systemPrompts.length === 0 &&
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
) {
|
||||
return res.send('对不起,你的问题不在知识库中。');
|
||||
}
|
||||
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
|
||||
if (
|
||||
systemPrompts.length === 0 &&
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
|
||||
) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.chat.systemPrompt
|
||||
});
|
||||
} else {
|
||||
// 有匹配情况下,system 添加知识库内容。
|
||||
// 系统提示词过滤,最多 2500 tokens
|
||||
const filterSystemPrompt = systemPromptFilter({
|
||||
model: model.chat.chatModel,
|
||||
prompts: systemPrompts,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: `
|
||||
${model.chat.systemPrompt}
|
||||
${
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
? `不回答知识库外的内容.`
|
||||
: ''
|
||||
}
|
||||
知识库内容为: ${filterSystemPrompt}'
|
||||
`
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// 没有用知识库搜索,仅用系统提示词
|
||||
if (model.chat.systemPrompt) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.chat.systemPrompt
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
// 控制总 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
model: model.chat.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||
2
|
||||
);
|
||||
// console.log(filterPrompts);
|
||||
// 获取 chatAPI
|
||||
const chatAPI = getOpenAIApi(userApiKey || systemKey);
|
||||
// 发出请求
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature,
|
||||
model: model.chat.chatModel,
|
||||
temperature: Number(temperature) || 0,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
@@ -105,7 +159,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
// 只有使用平台的 key 才计费
|
||||
pushChatBill({
|
||||
isPay: !userApiKey,
|
||||
modelName: model.service.modelName,
|
||||
chatModel: model.chat.chatModel,
|
||||
userId,
|
||||
chatId,
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
@@ -59,8 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
name: model.name,
|
||||
avatar: model.avatar,
|
||||
intro: model.share.intro,
|
||||
modelName: model.service.modelName,
|
||||
chatModel: model.service.chatModel,
|
||||
chatModel: model.chat.chatModel,
|
||||
history
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authChat } from '@/service/utils/auth';
|
||||
import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
import {
|
||||
modelList,
|
||||
ModelVectorSearchModeMap,
|
||||
ModelVectorSearchModeEnum,
|
||||
ModelDataStatusEnum
|
||||
} from '@/constants/model';
|
||||
import { pushChatBill } from '@/service/events/pushBill';
|
||||
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
||||
import dayjs from 'dayjs';
|
||||
import { PgClient } from '@/service/pg';
|
||||
|
||||
/* 发送提示词 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
let step = 0; // step=1时,表示开始了流响应
|
||||
const stream = new PassThrough();
|
||||
stream.on('error', () => {
|
||||
console.log('error: ', 'stream error');
|
||||
stream.destroy();
|
||||
});
|
||||
res.on('close', () => {
|
||||
stream.destroy();
|
||||
});
|
||||
res.on('error', () => {
|
||||
console.log('error: ', 'request error');
|
||||
stream.destroy();
|
||||
});
|
||||
|
||||
try {
|
||||
const { modelId, chatId, prompt } = req.body as {
|
||||
modelId: string;
|
||||
chatId: '' | string;
|
||||
prompt: ChatItemType;
|
||||
};
|
||||
|
||||
const { authorization } = req.headers;
|
||||
if (!modelId || !prompt) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
await connectToDatabase();
|
||||
let startTime = Date.now();
|
||||
|
||||
const { model, content, userApiKey, systemKey, userId } = await authChat({
|
||||
modelId,
|
||||
chatId,
|
||||
authorization
|
||||
});
|
||||
|
||||
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型加载异常');
|
||||
}
|
||||
|
||||
// 读取对话内容
|
||||
const prompts = [...content, prompt];
|
||||
|
||||
// 获取提示词的向量
|
||||
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({
|
||||
isPay: !userApiKey,
|
||||
apiKey: userApiKey || systemKey,
|
||||
userId,
|
||||
text: prompt.value
|
||||
});
|
||||
|
||||
// 相似度搜素
|
||||
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
||||
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||
fields: ['id', 'q', 'a'],
|
||||
where: [
|
||||
['status', ModelDataStatusEnum.ready],
|
||||
'AND',
|
||||
['model_id', model._id],
|
||||
'AND',
|
||||
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||
],
|
||||
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||
limit: 20
|
||||
});
|
||||
|
||||
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||
|
||||
/* 高相似度+退出,无法匹配时直接退出 */
|
||||
if (
|
||||
formatRedisPrompt.length === 0 &&
|
||||
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
) {
|
||||
return res.send('对不起,你的问题不在知识库中。');
|
||||
}
|
||||
/* 高相似度+无上下文,不添加额外知识 */
|
||||
if (
|
||||
formatRedisPrompt.length === 0 &&
|
||||
model.search.mode === ModelVectorSearchModeEnum.noContext
|
||||
) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.systemPrompt
|
||||
});
|
||||
} else {
|
||||
// 有匹配情况下,system 添加知识库内容。
|
||||
// 系统提示词过滤,最多 2500 tokens
|
||||
const systemPrompt = systemPromptFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts: formatRedisPrompt,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: `
|
||||
${model.systemPrompt}
|
||||
${
|
||||
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
? `你只能从知识库选择内容回答.不在知识库内容拒绝回复`
|
||||
: ''
|
||||
}
|
||||
知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}'
|
||||
`
|
||||
});
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
|
||||
// 发出请求
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
stream: true,
|
||||
stop: ['.!?。']
|
||||
},
|
||||
{
|
||||
timeout: 40000,
|
||||
responseType: 'stream',
|
||||
...axiosConfig()
|
||||
}
|
||||
);
|
||||
|
||||
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
||||
|
||||
step = 1;
|
||||
|
||||
const { responseContent } = await gpt35StreamResponse({
|
||||
res,
|
||||
stream,
|
||||
chatResponse
|
||||
});
|
||||
|
||||
// 只有使用平台的 key 才计费
|
||||
pushChatBill({
|
||||
isPay: !userApiKey,
|
||||
modelName: model.service.modelName,
|
||||
userId,
|
||||
chatId,
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
// jsonRes(res);
|
||||
} catch (err: any) {
|
||||
if (step === 1) {
|
||||
// 直接结束流
|
||||
console.log('error,结束');
|
||||
stream.destroy();
|
||||
} else {
|
||||
res.status(500);
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,14 +3,13 @@ import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { ModelStatusEnum, modelList, ModelNameEnum, Model2ChatModelMap } from '@/constants/model';
|
||||
import { ModelStatusEnum } from '@/constants/model';
|
||||
import { Model } from '@/service/models/model';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { name, serviceModelName } = req.body as {
|
||||
const { name } = req.body as {
|
||||
name: string;
|
||||
serviceModelName: `${ModelNameEnum}`;
|
||||
};
|
||||
const { authorization } = req.headers;
|
||||
|
||||
@@ -18,45 +17,32 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
if (!name || !serviceModelName) {
|
||||
if (!name) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
// 凭证校验
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
const modelItem = modelList.find((item) => item.model === serviceModelName);
|
||||
|
||||
if (!modelItem) {
|
||||
throw new Error('模型不存在');
|
||||
}
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 上限校验
|
||||
const authCount = await Model.countDocuments({
|
||||
userId
|
||||
});
|
||||
if (authCount >= 20) {
|
||||
throw new Error('上限 20 个模型');
|
||||
if (authCount >= 30) {
|
||||
throw new Error('上限 30 个模型');
|
||||
}
|
||||
|
||||
// 创建模型
|
||||
const response = await Model.create({
|
||||
name,
|
||||
userId,
|
||||
status: ModelStatusEnum.running,
|
||||
service: {
|
||||
chatModel: Model2ChatModelMap[modelItem.model], // 聊天时用的模型
|
||||
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
|
||||
}
|
||||
status: ModelStatusEnum.running
|
||||
});
|
||||
|
||||
// 根据 id 获取模型信息
|
||||
const model = await Model.findById(response._id);
|
||||
|
||||
jsonRes(res, {
|
||||
data: model
|
||||
data: response._id
|
||||
});
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
|
||||
@@ -9,8 +9,7 @@ import { authModel } from '@/service/utils/auth';
|
||||
/* 获取我的模型 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { name, avatar, search, share, service, security, systemPrompt, temperature } =
|
||||
req.body as ModelUpdateParams;
|
||||
const { name, avatar, chat, share, security } = req.body as ModelUpdateParams;
|
||||
const { modelId } = req.query as { modelId: string };
|
||||
const { authorization } = req.headers;
|
||||
|
||||
@@ -18,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
if (!name || !service || !security || !modelId) {
|
||||
if (!name || !chat || !security || !modelId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
|
||||
@@ -41,12 +40,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
{
|
||||
name,
|
||||
avatar,
|
||||
systemPrompt,
|
||||
temperature,
|
||||
chat,
|
||||
'share.isShare': share.isShare,
|
||||
'share.isShareDetail': share.isShareDetail,
|
||||
'share.intro': share.intro,
|
||||
search,
|
||||
security
|
||||
}
|
||||
);
|
||||
|
||||
202
src/pages/api/openapi/chat/chat.ts
Normal file
202
src/pages/api/openapi/chat/chat.ts
Normal file
@@ -0,0 +1,202 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth';
|
||||
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
||||
import { pushChatBill } from '@/service/events/pushBill';
|
||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
||||
|
||||
/* 发送提示词 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
let step = 0; // step=1时,表示开始了流响应
|
||||
const stream = new PassThrough();
|
||||
stream.on('error', () => {
|
||||
console.log('error: ', 'stream error');
|
||||
stream.destroy();
|
||||
});
|
||||
res.on('close', () => {
|
||||
stream.destroy();
|
||||
});
|
||||
res.on('error', () => {
|
||||
console.log('error: ', 'request error');
|
||||
stream.destroy();
|
||||
});
|
||||
|
||||
try {
|
||||
const {
|
||||
prompts,
|
||||
modelId,
|
||||
isStream = true
|
||||
} = req.body as {
|
||||
prompts: ChatItemType[];
|
||||
modelId: string;
|
||||
isStream: boolean;
|
||||
};
|
||||
|
||||
if (!prompts || !modelId) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
if (!Array.isArray(prompts)) {
|
||||
throw new Error('prompts is not array');
|
||||
}
|
||||
if (prompts.length > 30 || prompts.length === 0) {
|
||||
throw new Error('prompts length range 1-30');
|
||||
}
|
||||
|
||||
await connectToDatabase();
|
||||
let startTime = Date.now();
|
||||
|
||||
/* 凭证校验 */
|
||||
const { apiKey, userId } = await authOpenApiKey(req);
|
||||
|
||||
const { model } = await authModel({
|
||||
userId,
|
||||
modelId
|
||||
});
|
||||
|
||||
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型加载异常');
|
||||
}
|
||||
|
||||
// 使用了知识库搜索
|
||||
if (model.chat.useKb) {
|
||||
const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
|
||||
|
||||
const { systemPrompts } = await searchKb_openai({
|
||||
apiKey,
|
||||
isPay: true,
|
||||
text: prompts[prompts.length - 1].value,
|
||||
similarity,
|
||||
modelId,
|
||||
userId
|
||||
});
|
||||
|
||||
// filter system prompt
|
||||
if (
|
||||
systemPrompts.length === 0 &&
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
) {
|
||||
return jsonRes(res, {
|
||||
code: 500,
|
||||
message: '对不起,你的问题不在知识库中。',
|
||||
data: '对不起,你的问题不在知识库中。'
|
||||
});
|
||||
}
|
||||
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
|
||||
if (
|
||||
systemPrompts.length === 0 &&
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
|
||||
) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.chat.systemPrompt
|
||||
});
|
||||
} else {
|
||||
// 有匹配情况下,system 添加知识库内容。
|
||||
// 系统提示词过滤,最多 2500 tokens
|
||||
const filterSystemPrompt = systemPromptFilter({
|
||||
model: model.chat.chatModel,
|
||||
prompts: systemPrompts,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: `
|
||||
${model.chat.systemPrompt}
|
||||
${
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
? `不回答知识库外的内容.`
|
||||
: ''
|
||||
}
|
||||
知识库内容为: ${filterSystemPrompt}'
|
||||
`
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// 没有用知识库搜索,仅用系统提示词
|
||||
if (model.chat.systemPrompt) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.chat.systemPrompt
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 控制总 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.chat.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// 计算温度
|
||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||
2
|
||||
);
|
||||
// console.log(filterPrompts);
|
||||
// 获取 chatAPI
|
||||
const chatAPI = getOpenAIApi(apiKey);
|
||||
// 发出请求
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.chat.chatModel,
|
||||
temperature: Number(temperature) || 0,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
stream: isStream,
|
||||
stop: ['.!?。']
|
||||
},
|
||||
{
|
||||
timeout: 180000,
|
||||
responseType: isStream ? 'stream' : 'json',
|
||||
...axiosConfig()
|
||||
}
|
||||
);
|
||||
|
||||
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
||||
|
||||
let responseContent = '';
|
||||
|
||||
if (isStream) {
|
||||
step = 1;
|
||||
const streamResponse = await gpt35StreamResponse({
|
||||
res,
|
||||
stream,
|
||||
chatResponse
|
||||
});
|
||||
responseContent = streamResponse.responseContent;
|
||||
} else {
|
||||
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
|
||||
jsonRes(res, {
|
||||
data: responseContent
|
||||
});
|
||||
}
|
||||
|
||||
// 只有使用平台的 key 才计费
|
||||
pushChatBill({
|
||||
isPay: true,
|
||||
chatModel: model.chat.chatModel,
|
||||
userId,
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
} catch (err: any) {
|
||||
if (step === 1) {
|
||||
// 直接结束流
|
||||
console.log('error,结束');
|
||||
stream.destroy();
|
||||
} else {
|
||||
res.status(500);
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase, Model } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/auth';
|
||||
import { axiosConfig, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools';
|
||||
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
|
||||
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
@@ -60,37 +60,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
throw new Error('无权使用该模型');
|
||||
}
|
||||
|
||||
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
|
||||
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型加载异常');
|
||||
}
|
||||
|
||||
// 如果有系统提示词,自动插入
|
||||
if (model.systemPrompt) {
|
||||
if (model.chat.systemPrompt) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.systemPrompt
|
||||
value: model.chat.systemPrompt
|
||||
});
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
model: model.chat.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
|
||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||
2
|
||||
);
|
||||
// 获取 chatAPI
|
||||
const chatAPI = getOpenAIApi(apiKey);
|
||||
// 发出请求
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature,
|
||||
model: model.chat.chatModel,
|
||||
temperature: Number(temperature) || 0,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
@@ -126,7 +127,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
// 只有使用平台的 key 才计费
|
||||
pushChatBill({
|
||||
isPay: true,
|
||||
modelName: model.service.modelName,
|
||||
chatModel: model.chat.chatModel,
|
||||
userId,
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase, Model } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/auth';
|
||||
import { authOpenApiKey } from '@/service/utils/tools';
|
||||
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
|
||||
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
import {
|
||||
ModelNameEnum,
|
||||
modelList,
|
||||
ModelVectorSearchModeMap,
|
||||
ChatModelEnum
|
||||
} from '@/constants/model';
|
||||
import { modelList, ModelVectorSearchModeMap, ChatModelEnum } from '@/constants/model';
|
||||
import { pushChatBill } from '@/service/events/pushBill';
|
||||
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
||||
|
||||
/* 发送提示词 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
@@ -59,10 +53,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
throw new Error('找不到模型');
|
||||
}
|
||||
|
||||
const modelConstantsData = modelList.find((item) => item.model === ModelNameEnum.VECTOR_GPT);
|
||||
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型已下架');
|
||||
throw new Error('model is undefined');
|
||||
}
|
||||
|
||||
console.log('laf gpt start');
|
||||
|
||||
// 获取 chatAPI
|
||||
@@ -132,62 +127,48 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
prompt.value += ` ${promptResolve}`;
|
||||
console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
|
||||
|
||||
// 获取提示词的向量
|
||||
const { vector: promptVector } = await openaiCreateEmbedding({
|
||||
isPay: true,
|
||||
apiKey,
|
||||
userId,
|
||||
text: prompt.value
|
||||
});
|
||||
|
||||
// 读取对话内容
|
||||
const prompts = [prompt];
|
||||
|
||||
// 相似度搜索
|
||||
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
||||
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||
fields: ['id', 'q', 'a'],
|
||||
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||
where: [
|
||||
['model_id', model._id],
|
||||
'AND',
|
||||
['user_id', userId],
|
||||
'AND',
|
||||
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||
],
|
||||
limit: 30
|
||||
// 获取向量匹配到的提示词
|
||||
const { systemPrompts } = await searchKb_openai({
|
||||
isPay: true,
|
||||
apiKey,
|
||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
|
||||
text: prompt.value,
|
||||
modelId,
|
||||
userId
|
||||
});
|
||||
|
||||
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||
|
||||
// system 筛选,最多 2500 tokens
|
||||
const systemPrompt = systemPromptFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts: formatRedisPrompt,
|
||||
const filterSystemPrompt = systemPromptFilter({
|
||||
model: model.chat.chatModel,
|
||||
prompts: systemPrompts,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
|
||||
value: `${model.chat.systemPrompt} 知识库是最新的,下面是知识库内容:${filterSystemPrompt}`
|
||||
});
|
||||
|
||||
// 控制上下文 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
model: model.chat.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
|
||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||
2
|
||||
);
|
||||
// 发出请求
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature,
|
||||
model: model.chat.chatModel,
|
||||
temperature: Number(temperature) || 0,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
@@ -223,7 +204,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
|
||||
pushChatBill({
|
||||
isPay: true,
|
||||
modelName: model.service.modelName,
|
||||
chatModel: model.chat.chatModel,
|
||||
userId,
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
|
||||
@@ -1,24 +1,14 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase, Model } from '@/service/mongo';
|
||||
import {
|
||||
axiosConfig,
|
||||
systemPromptFilter,
|
||||
authOpenApiKey,
|
||||
openaiChatFilter
|
||||
} from '@/service/utils/tools';
|
||||
import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
|
||||
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
import {
|
||||
modelList,
|
||||
ModelVectorSearchModeMap,
|
||||
ModelVectorSearchModeEnum,
|
||||
ModelDataStatusEnum
|
||||
} from '@/constants/model';
|
||||
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
||||
import { pushChatBill } from '@/service/events/pushBill';
|
||||
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
||||
import dayjs from 'dayjs';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
||||
|
||||
/* 发送提示词 */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
@@ -72,96 +62,86 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
throw new Error('无权使用该模型');
|
||||
}
|
||||
|
||||
const modelConstantsData = modelList.find((item) => item.model === model?.service?.modelName);
|
||||
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型初始化异常');
|
||||
}
|
||||
|
||||
// 获取提示词的向量
|
||||
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({
|
||||
// 获取向量匹配到的提示词
|
||||
const { systemPrompts } = await searchKb_openai({
|
||||
isPay: true,
|
||||
apiKey,
|
||||
userId,
|
||||
text: prompts[prompts.length - 1].value // 取最后一个
|
||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
|
||||
text: prompts[prompts.length - 1].value,
|
||||
modelId,
|
||||
userId
|
||||
});
|
||||
|
||||
// 相似度搜素
|
||||
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
||||
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||
fields: ['id', 'q', 'a'],
|
||||
where: [
|
||||
['status', ModelDataStatusEnum.ready],
|
||||
'AND',
|
||||
['model_id', model._id],
|
||||
'AND',
|
||||
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||
],
|
||||
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||
limit: 20
|
||||
});
|
||||
|
||||
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||
|
||||
// system 合并
|
||||
if (prompts[0].obj === 'SYSTEM') {
|
||||
formatRedisPrompt.unshift(prompts.shift()?.value || '');
|
||||
systemPrompts.unshift(prompts.shift()?.value || '');
|
||||
}
|
||||
|
||||
/* 高相似度+退出,无法匹配时直接退出 */
|
||||
if (
|
||||
formatRedisPrompt.length === 0 &&
|
||||
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
systemPrompts.length === 0 &&
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
) {
|
||||
return res.send('对不起,你的问题不在知识库中。');
|
||||
return jsonRes(res, {
|
||||
code: 500,
|
||||
message: '对不起,你的问题不在知识库中。',
|
||||
data: '对不起,你的问题不在知识库中。'
|
||||
});
|
||||
}
|
||||
/* 高相似度+无上下文,不添加额外知识 */
|
||||
if (
|
||||
formatRedisPrompt.length === 0 &&
|
||||
model.search.mode === ModelVectorSearchModeEnum.noContext
|
||||
systemPrompts.length === 0 &&
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
|
||||
) {
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: model.systemPrompt
|
||||
value: model.chat.systemPrompt
|
||||
});
|
||||
} else {
|
||||
// 有匹配或者低匹配度模式情况下,添加知识库内容。
|
||||
// 系统提示词过滤,最多 2500 tokens
|
||||
const systemPrompt = systemPromptFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts: formatRedisPrompt,
|
||||
model: model.chat.chatModel,
|
||||
prompts: systemPrompts,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: `
|
||||
${model.systemPrompt}
|
||||
${model.chat.systemPrompt}
|
||||
${
|
||||
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
|
||||
? `你只能从知识库选择内容回答.不在知识库内容拒绝回复`
|
||||
: ''
|
||||
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : ''
|
||||
}
|
||||
知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}'
|
||||
知识库内容为: ${systemPrompt}'
|
||||
`
|
||||
});
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
model: model.chat.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||
2
|
||||
);
|
||||
const chatAPI = getOpenAIApi(apiKey);
|
||||
|
||||
// 发出请求
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature,
|
||||
model: model.chat.chatModel,
|
||||
temperature: Number(temperature) || 0,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
@@ -196,7 +176,7 @@ ${
|
||||
|
||||
pushChatBill({
|
||||
isPay: true,
|
||||
modelName: model.service.modelName,
|
||||
chatModel: model.chat.chatModel,
|
||||
userId,
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user