feat: chat content use tiktoken count
This commit is contained in:
@@ -2,7 +2,6 @@ import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { getOpenAIApi, authChat } from '@/service/utils/auth';
|
||||
import { httpsAgent, openaiChatFilter } from '@/service/utils/tools';
|
||||
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
@@ -64,42 +63,23 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
|
||||
|
||||
// 格式化文本内容成 chatgpt 格式
|
||||
const map = {
|
||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
||||
};
|
||||
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
|
||||
(item: ChatItemType) => ({
|
||||
role: map[item.obj],
|
||||
content: item.value
|
||||
})
|
||||
);
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
// console.log({
|
||||
// model: model.service.chatModel,
|
||||
// temperature: temperature,
|
||||
// // max_tokens: modelConstantsData.maxToken,
|
||||
// messages: formatPrompts,
|
||||
// frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
// presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
// stream: true,
|
||||
// stop: ['.!?。']
|
||||
// });
|
||||
// console.log(filterPrompts);
|
||||
// 获取 chatAPI
|
||||
const chatAPI = getOpenAIApi(userApiKey || systemKey);
|
||||
// 发出请求
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature: temperature,
|
||||
// max_tokens: modelConstantsData.maxToken,
|
||||
messages: formatPrompts,
|
||||
temperature,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
stream: true,
|
||||
@@ -121,7 +101,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
stream,
|
||||
chatResponse
|
||||
});
|
||||
const promptsContent = formatPrompts.map((item) => item.content).join('');
|
||||
|
||||
// 只有使用平台的 key 才计费
|
||||
pushChatBill({
|
||||
@@ -129,7 +108,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
modelName: model.service.modelName,
|
||||
userId,
|
||||
chatId,
|
||||
text: promptsContent + responseContent
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
} catch (err: any) {
|
||||
if (step === 1) {
|
||||
|
||||
@@ -2,10 +2,8 @@ import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authChat } from '@/service/utils/auth';
|
||||
import { httpsAgent, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
|
||||
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { PassThrough } from 'stream';
|
||||
import {
|
||||
modelList,
|
||||
@@ -105,9 +103,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
value: model.systemPrompt
|
||||
});
|
||||
} else {
|
||||
// 有匹配情况下,添加知识库内容。
|
||||
// 系统提示词过滤,最多 3000 tokens
|
||||
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
|
||||
// 有匹配情况下,system 添加知识库内容。
|
||||
// 系统提示词过滤,最多 2500 tokens
|
||||
const systemPrompt = systemPromptFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts: formatRedisPrompt,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
@@ -124,21 +126,13 @@ ${
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// 格式化文本内容成 chatgpt 格式
|
||||
const map = {
|
||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
||||
};
|
||||
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
|
||||
(item: ChatItemType) => ({
|
||||
role: map[item.obj],
|
||||
content: item.value
|
||||
})
|
||||
);
|
||||
// console.log(formatPrompts);
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
|
||||
@@ -146,9 +140,8 @@ ${
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature: temperature,
|
||||
// max_tokens: modelConstantsData.maxToken,
|
||||
messages: formatPrompts,
|
||||
temperature,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
stream: true
|
||||
@@ -170,14 +163,13 @@ ${
|
||||
chatResponse
|
||||
});
|
||||
|
||||
const promptsContent = formatPrompts.map((item) => item.content).join('');
|
||||
// 只有使用平台的 key 才计费
|
||||
pushChatBill({
|
||||
isPay: !userApiKey,
|
||||
modelName: model.service.modelName,
|
||||
userId,
|
||||
chatId,
|
||||
text: promptsContent + responseContent
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
// jsonRes(res);
|
||||
} catch (err: any) {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { connectToDatabase, DataItem, Data } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { generateQA } from '@/service/events/generateQA';
|
||||
import { generateAbstract } from '@/service/events/generateAbstract';
|
||||
import { encode } from 'gpt-token-utils';
|
||||
import { countChatTokens } from '@/utils/tools';
|
||||
|
||||
/* 拆分数据成QA */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
@@ -34,7 +34,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
|
||||
chunks.forEach((chunk) => {
|
||||
splitText += chunk;
|
||||
const tokens = encode(splitText).length;
|
||||
const tokens = countChatTokens({ messages: [{ role: 'system', content: splitText }] });
|
||||
if (tokens >= 780) {
|
||||
dataItems.push({
|
||||
userId,
|
||||
|
||||
@@ -3,14 +3,14 @@ import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
|
||||
import { ModelStatusEnum, modelList, ModelNameEnum, Model2ChatModelMap } from '@/constants/model';
|
||||
import { Model } from '@/service/models/model';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { name, serviceModelName } = req.body as {
|
||||
name: string;
|
||||
serviceModelName: `${ChatModelNameEnum}`;
|
||||
serviceModelName: `${ModelNameEnum}`;
|
||||
};
|
||||
const { authorization } = req.headers;
|
||||
|
||||
@@ -48,7 +48,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
status: ModelStatusEnum.running,
|
||||
service: {
|
||||
trainId: '',
|
||||
chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
|
||||
chatModel: Model2ChatModelMap[modelItem.model], // 聊天时用的模型
|
||||
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
|
||||
}
|
||||
});
|
||||
|
||||
@@ -75,21 +75,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// 格式化文本内容成 chatgpt 格式
|
||||
const map = {
|
||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
||||
};
|
||||
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
|
||||
(item: ChatItemType) => ({
|
||||
role: map[item.obj],
|
||||
content: item.value
|
||||
})
|
||||
);
|
||||
// console.log(formatPrompts);
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
|
||||
@@ -99,9 +91,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature: temperature,
|
||||
// max_tokens: modelConstantsData.maxToken,
|
||||
messages: formatPrompts,
|
||||
temperature,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
stream: isStream,
|
||||
@@ -133,14 +124,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
});
|
||||
}
|
||||
|
||||
const promptsContent = formatPrompts.map((item) => item.content).join('');
|
||||
|
||||
// 只有使用平台的 key 才计费
|
||||
pushChatBill({
|
||||
isPay: true,
|
||||
modelName: model.service.modelName,
|
||||
userId,
|
||||
text: promptsContent + responseContent
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
} catch (err: any) {
|
||||
if (step === 1) {
|
||||
|
||||
@@ -3,15 +3,14 @@ import { connectToDatabase, Model } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/auth';
|
||||
import { authOpenApiKey } from '@/service/utils/tools';
|
||||
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
|
||||
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { PassThrough } from 'stream';
|
||||
import {
|
||||
ChatModelNameEnum,
|
||||
ModelNameEnum,
|
||||
modelList,
|
||||
ChatModelNameMap,
|
||||
ModelVectorSearchModeMap
|
||||
ModelVectorSearchModeMap,
|
||||
ChatModelEnum
|
||||
} from '@/constants/model';
|
||||
import { pushChatBill } from '@/service/events/pushBill';
|
||||
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
||||
@@ -60,9 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
throw new Error('找不到模型');
|
||||
}
|
||||
|
||||
const modelConstantsData = modelList.find(
|
||||
(item) => item.model === ChatModelNameEnum.VECTOR_GPT
|
||||
);
|
||||
const modelConstantsData = modelList.find((item) => item.model === ModelNameEnum.VECTOR_GPT);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型已下架');
|
||||
}
|
||||
@@ -74,7 +71,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
// 请求一次 chatgpt 拆解需求
|
||||
const promptResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: ChatModelNameMap[ChatModelNameEnum.GPT35],
|
||||
model: ChatModelEnum.GPT35,
|
||||
temperature: 0,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
@@ -122,7 +119,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
]
|
||||
},
|
||||
{
|
||||
timeout: 120000,
|
||||
timeout: 180000,
|
||||
httpsAgent: httpsAgent(true)
|
||||
}
|
||||
);
|
||||
@@ -163,30 +160,26 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
|
||||
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||
|
||||
// textArr 筛选,最多 2500 tokens
|
||||
const systemPrompt = systemPromptFilter(formatRedisPrompt, 2500);
|
||||
// system 筛选,最多 2500 tokens
|
||||
const systemPrompt = systemPromptFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts: formatRedisPrompt,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
|
||||
});
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
|
||||
// 控制上下文 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// 格式化文本内容成 chatgpt 格式
|
||||
const map = {
|
||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
||||
};
|
||||
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
|
||||
(item: ChatItemType) => ({
|
||||
role: map[item.obj],
|
||||
content: item.value
|
||||
})
|
||||
);
|
||||
// console.log(formatPrompts);
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
|
||||
@@ -195,13 +188,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature,
|
||||
messages: formatPrompts,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
stream: isStream
|
||||
},
|
||||
{
|
||||
timeout: 120000,
|
||||
timeout: 180000,
|
||||
responseType: isStream ? 'stream' : 'json',
|
||||
httpsAgent: httpsAgent(true)
|
||||
}
|
||||
@@ -228,13 +221,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
|
||||
console.log('laf gpt done. time:', `${(Date.now() - startTime) / 1000}s`);
|
||||
|
||||
const promptsContent = formatPrompts.map((item) => item.content).join('');
|
||||
|
||||
pushChatBill({
|
||||
isPay: true,
|
||||
modelName: model.service.modelName,
|
||||
userId,
|
||||
text: promptsContent + responseContent
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
} catch (err: any) {
|
||||
if (step === 1) {
|
||||
|
||||
@@ -126,8 +126,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
});
|
||||
} else {
|
||||
// 有匹配或者低匹配度模式情况下,添加知识库内容。
|
||||
// 系统提示词过滤,最多 3000 tokens
|
||||
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
|
||||
// 系统提示词过滤,最多 2500 tokens
|
||||
const systemPrompt = systemPromptFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts: formatRedisPrompt,
|
||||
maxTokens: 2500
|
||||
});
|
||||
|
||||
prompts.unshift({
|
||||
obj: 'SYSTEM',
|
||||
@@ -144,21 +148,13 @@ ${
|
||||
}
|
||||
|
||||
// 控制在 tokens 数量,防止超出
|
||||
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
|
||||
const filterPrompts = openaiChatFilter({
|
||||
model: model.service.chatModel,
|
||||
prompts,
|
||||
maxTokens: modelConstantsData.contextMaxToken - 500
|
||||
});
|
||||
|
||||
// 格式化文本内容成 chatgpt 格式
|
||||
const map = {
|
||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
||||
};
|
||||
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
|
||||
(item: ChatItemType) => ({
|
||||
role: map[item.obj],
|
||||
content: item.value
|
||||
})
|
||||
);
|
||||
// console.log(formatPrompts);
|
||||
// console.log(filterPrompts);
|
||||
// 计算温度
|
||||
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
|
||||
|
||||
@@ -166,14 +162,14 @@ ${
|
||||
const chatResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: model.service.chatModel,
|
||||
temperature: temperature,
|
||||
messages: formatPrompts,
|
||||
temperature,
|
||||
messages: filterPrompts,
|
||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||
stream: isStream
|
||||
},
|
||||
{
|
||||
timeout: 120000,
|
||||
timeout: 180000,
|
||||
responseType: isStream ? 'stream' : 'json',
|
||||
httpsAgent: httpsAgent(true)
|
||||
}
|
||||
@@ -198,12 +194,11 @@ ${
|
||||
});
|
||||
}
|
||||
|
||||
const promptsContent = formatPrompts.map((item) => item.content).join('');
|
||||
pushChatBill({
|
||||
isPay: true,
|
||||
modelName: model.service.modelName,
|
||||
userId,
|
||||
text: promptsContent + responseContent
|
||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
||||
});
|
||||
// jsonRes(res);
|
||||
} catch (err: any) {
|
||||
|
||||
Reference in New Issue
Block a user