perf: model framwork

This commit is contained in:
archer
2023-04-29 15:55:47 +08:00
parent cd9acab938
commit 78762498eb
30 changed files with 649 additions and 757 deletions

View File

@@ -0,0 +1,202 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const {
prompts,
modelId,
isStream = true
} = req.body as {
prompts: ChatItemType[];
modelId: string;
isStream: boolean;
};
if (!prompts || !modelId) {
throw new Error('缺少参数');
}
if (!Array.isArray(prompts)) {
throw new Error('prompts is not array');
}
if (prompts.length > 30 || prompts.length === 0) {
throw new Error('prompts length range 1-30');
}
await connectToDatabase();
let startTime = Date.now();
/* 凭证校验 */
const { apiKey, userId } = await authOpenApiKey(req);
const { model } = await authModel({
userId,
modelId
});
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 使用了知识库搜索
if (model.chat.useKb) {
const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
const { systemPrompts } = await searchKb_openai({
apiKey,
isPay: true,
text: prompts[prompts.length - 1].value,
similarity,
modelId,
userId
});
// filter system prompt
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
return jsonRes(res, {
code: 500,
message: '对不起,你的问题不在知识库中。',
data: '对不起,你的问题不在知识库中。'
});
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
} else {
// 有匹配情况下system 添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const filterSystemPrompt = systemPromptFilter({
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `
${model.chat.systemPrompt}
${
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
? `不回答知识库外的内容.`
: ''
}
知识库内容为: ${filterSystemPrompt}'
`
});
}
} else {
// 没有用知识库搜索,仅用系统提示词
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
}
}
// 控制总 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 计算温度
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// console.log(filterPrompts);
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream,
stop: ['.!?。']
},
{
timeout: 180000,
responseType: isStream ? 'stream' : 'json',
...axiosConfig()
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
let responseContent = '';
if (isStream) {
step = 1;
const streamResponse = await gpt35StreamResponse({
res,
stream,
chatResponse
});
responseContent = streamResponse.responseContent;
} else {
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
jsonRes(res, {
data: responseContent
});
}
// 只有使用平台的 key 才计费
pushChatBill({
isPay: true,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@@ -1,7 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
@@ -60,37 +60,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('无权使用该模型');
}
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 如果有系统提示词,自动插入
if (model.systemPrompt) {
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
value: model.chat.systemPrompt
});
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -126,7 +127,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 只有使用平台的 key 才计费
pushChatBill({
isPay: true,
modelName: model.service.modelName,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});

View File

@@ -1,20 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth';
import { authOpenApiKey } from '@/service/utils/tools';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
ModelNameEnum,
modelList,
ModelVectorSearchModeMap,
ChatModelEnum
} from '@/constants/model';
import { modelList, ModelVectorSearchModeMap, ChatModelEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import { PgClient } from '@/service/pg';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -59,10 +53,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('找不到模型');
}
const modelConstantsData = modelList.find((item) => item.model === ModelNameEnum.VECTOR_GPT);
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型已下架');
throw new Error('model is undefined');
}
console.log('laf gpt start');
// 获取 chatAPI
@@ -132,62 +127,48 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompt.value += ` ${promptResolve}`;
console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
// 获取提示词的向量
const { vector: promptVector } = await openaiCreateEmbedding({
isPay: true,
apiKey,
userId,
text: prompt.value
});
// 读取对话内容
const prompts = [prompt];
// 相似度搜索
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
where: [
['model_id', model._id],
'AND',
['user_id', userId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
limit: 30
// 获取向量匹配到的提示词
const { systemPrompts } = await searchKb_openai({
isPay: true,
apiKey,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
text: prompt.value,
modelId,
userId
});
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// system 筛选,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
const filterSystemPrompt = systemPromptFilter({
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
value: `${model.chat.systemPrompt} 知识库是最新的,下面是知识库内容:${filterSystemPrompt}`
});
// 控制上下文 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -223,7 +204,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
pushChatBill({
isPay: true,
modelName: model.service.modelName,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});

View File

@@ -1,24 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import {
axiosConfig,
systemPromptFilter,
authOpenApiKey,
openaiChatFilter
} from '@/service/utils/tools';
import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
modelList,
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum,
ModelDataStatusEnum
} from '@/constants/model';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -72,96 +62,86 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('无权使用该模型');
}
const modelConstantsData = modelList.find((item) => item.model === model?.service?.modelName);
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型初始化异常');
}
// 获取提示词的向量
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({
// 获取向量匹配到的提示词
const { systemPrompts } = await searchKb_openai({
isPay: true,
apiKey,
userId,
text: prompts[prompts.length - 1].value // 取最后一个
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
text: prompts[prompts.length - 1].value,
modelId,
userId
});
// 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
['model_id', model._id],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
});
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// system 合并
if (prompts[0].obj === 'SYSTEM') {
formatRedisPrompt.unshift(prompts.shift()?.value || '');
systemPrompts.unshift(prompts.shift()?.value || '');
}
/* 高相似度+退出,无法匹配时直接退出 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
return res.send('对不起,你的问题不在知识库中。');
return jsonRes(res, {
code: 500,
message: '对不起,你的问题不在知识库中。',
data: '对不起,你的问题不在知识库中。'
});
}
/* 高相似度+无上下文,不添加额外知识 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.noContext
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
value: model.chat.systemPrompt
});
} else {
// 有匹配或者低匹配度模式情况下,添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `
${model.systemPrompt}
${model.chat.systemPrompt}
${
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
? `你只能从知识库选择内容回答.不在知识库内容拒绝回复`
: ''
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : ''
}
知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}'
知识库内容为: ${systemPrompt}'
`
});
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -196,7 +176,7 @@ ${
pushChatBill({
isPay: true,
modelName: model.service.modelName,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});