feat: chat content use tiktoken count
This commit is contained in:
@@ -4,7 +4,7 @@ import { httpsAgent } from '@/service/utils/tools';
|
||||
import { getOpenApiKey } from '../utils/openai';
|
||||
import type { ChatCompletionRequestMessage } from 'openai';
|
||||
import { DataItemSchema } from '@/types/mongoSchema';
|
||||
import { ChatModelNameEnum } from '@/constants/model';
|
||||
import { ChatModelEnum } from '@/constants/model';
|
||||
import { pushSplitDataBill } from '@/service/events/pushBill';
|
||||
|
||||
export async function generateAbstract(next = false): Promise<any> {
|
||||
@@ -68,7 +68,7 @@ export async function generateAbstract(next = false): Promise<any> {
|
||||
// 请求 chatgpt 获取摘要
|
||||
const abstractResponse = await chatAPI.createChatCompletion(
|
||||
{
|
||||
model: ChatModelNameEnum.GPT35,
|
||||
model: ChatModelEnum.GPT35,
|
||||
temperature: 0.8,
|
||||
n: 1,
|
||||
messages: [
|
||||
|
||||
@@ -3,7 +3,7 @@ import { getOpenAIApi } from '@/service/utils/auth';
|
||||
import { httpsAgent } from '@/service/utils/tools';
|
||||
import { getOpenApiKey } from '../utils/openai';
|
||||
import type { ChatCompletionRequestMessage } from 'openai';
|
||||
import { ChatModelNameEnum } from '@/constants/model';
|
||||
import { ChatModelEnum } from '@/constants/model';
|
||||
import { pushSplitDataBill } from '@/service/events/pushBill';
|
||||
import { generateVector } from './generateVector';
|
||||
import { openaiError2 } from '../errorCode';
|
||||
@@ -84,7 +84,7 @@ A2:
|
||||
chatAPI
|
||||
.createChatCompletion(
|
||||
{
|
||||
model: ChatModelNameEnum.GPT35,
|
||||
model: ChatModelEnum.GPT35,
|
||||
temperature: 0.8,
|
||||
n: 1,
|
||||
messages: [
|
||||
|
||||
@@ -1,27 +1,34 @@
|
||||
import { connectToDatabase, Bill, User } from '../mongo';
|
||||
import { modelList, ChatModelNameEnum } from '@/constants/model';
|
||||
import { encode } from 'gpt-token-utils';
|
||||
import {
|
||||
modelList,
|
||||
ChatModelEnum,
|
||||
ModelNameEnum,
|
||||
Model2ChatModelMap,
|
||||
embeddingModel
|
||||
} from '@/constants/model';
|
||||
import { BillTypeEnum } from '@/constants/user';
|
||||
import type { DataType } from '@/types/data';
|
||||
import { countChatTokens } from '@/utils/tools';
|
||||
|
||||
export const pushChatBill = async ({
|
||||
isPay,
|
||||
modelName,
|
||||
userId,
|
||||
chatId,
|
||||
text
|
||||
messages
|
||||
}: {
|
||||
isPay: boolean;
|
||||
modelName: string;
|
||||
modelName: `${ModelNameEnum}`;
|
||||
userId: string;
|
||||
chatId?: '' | string;
|
||||
text: string;
|
||||
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
|
||||
}) => {
|
||||
let billId;
|
||||
let billId = '';
|
||||
|
||||
try {
|
||||
// 计算 token 数量
|
||||
const tokens = Math.floor(encode(text).length * 0.75);
|
||||
const tokens = countChatTokens({ model: Model2ChatModelMap[modelName] as any, messages });
|
||||
const text = messages.map((item) => item.content).join('');
|
||||
|
||||
console.log(
|
||||
`chat generate success. text len: ${text.length}. token len: ${tokens}. pay:${isPay}`
|
||||
@@ -88,7 +95,7 @@ export const pushSplitDataBill = async ({
|
||||
if (isPay) {
|
||||
try {
|
||||
// 获取模型单价格, 都是用 gpt35 拆分
|
||||
const modelItem = modelList.find((item) => item.model === ChatModelNameEnum.GPT35);
|
||||
const modelItem = modelList.find((item) => item.model === ChatModelEnum.GPT35);
|
||||
const unitPrice = modelItem?.price || 3;
|
||||
// 计算价格
|
||||
const price = unitPrice * tokenLen;
|
||||
@@ -97,7 +104,7 @@ export const pushSplitDataBill = async ({
|
||||
const res = await Bill.create({
|
||||
userId,
|
||||
type,
|
||||
modelName: ChatModelNameEnum.GPT35,
|
||||
modelName: ChatModelEnum.GPT35,
|
||||
textLen: text.length,
|
||||
tokenLen,
|
||||
price
|
||||
@@ -149,7 +156,7 @@ export const pushGenerateVectorBill = async ({
|
||||
const res = await Bill.create({
|
||||
userId,
|
||||
type: BillTypeEnum.vector,
|
||||
modelName: ChatModelNameEnum.VECTOR,
|
||||
modelName: embeddingModel,
|
||||
textLen: text.length,
|
||||
tokenLen,
|
||||
price
|
||||
|
||||
@@ -5,7 +5,7 @@ import { getOpenAIApi } from '@/service/utils/auth';
|
||||
import { httpsAgent } from './tools';
|
||||
import { User } from '../models/user';
|
||||
import { formatPrice } from '@/utils/user';
|
||||
import { ChatModelNameEnum } from '@/constants/model';
|
||||
import { embeddingModel } from '@/constants/model';
|
||||
import { pushGenerateVectorBill } from '../events/pushBill';
|
||||
|
||||
/* 获取用户 api 的 openai 信息 */
|
||||
@@ -80,7 +80,7 @@ export const openaiCreateEmbedding = async ({
|
||||
const res = await chatAPI
|
||||
.createEmbedding(
|
||||
{
|
||||
model: ChatModelNameEnum.VECTOR,
|
||||
model: embeddingModel,
|
||||
input: text
|
||||
},
|
||||
{
|
||||
@@ -134,11 +134,11 @@ export const gpt35StreamResponse = ({
|
||||
try {
|
||||
const json = JSON.parse(data);
|
||||
const content: string = json?.choices?.[0].delta.content || '';
|
||||
// console.log('content:', content);
|
||||
if (!content || (responseContent === '' && content === '\n')) return;
|
||||
|
||||
responseContent += content;
|
||||
!stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
|
||||
|
||||
if (!stream.destroyed && content) {
|
||||
stream.push(content.replace(/\n/g, '<br/>'));
|
||||
}
|
||||
} catch (error) {
|
||||
error;
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ import type { NextApiRequest } from 'next';
|
||||
import crypto from 'crypto';
|
||||
import jwt from 'jsonwebtoken';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { encode } from 'gpt-token-utils';
|
||||
import { OpenApi, User } from '../mongo';
|
||||
import { formatPrice } from '@/utils/user';
|
||||
import { ERROR_ENUM } from '../errorCode';
|
||||
import { countChatTokens } from '@/utils/tools';
|
||||
import { ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||
import { ChatModelEnum } from '@/constants/model';
|
||||
|
||||
/* 密码加密 */
|
||||
export const hashPassword = (psw: string) => {
|
||||
@@ -86,8 +88,16 @@ export const authOpenApiKey = async (req: NextApiRequest) => {
|
||||
export const httpsAgent = (fast: boolean) =>
|
||||
fast ? global.httpsAgentFast : global.httpsAgentNormal;
|
||||
|
||||
/* tokens 截断 */
|
||||
export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) => {
|
||||
/* 聊天内容 tokens 截断 */
|
||||
export const openaiChatFilter = ({
|
||||
model,
|
||||
prompts,
|
||||
maxTokens
|
||||
}: {
|
||||
model: `${ChatModelEnum}`;
|
||||
prompts: ChatItemType[];
|
||||
maxTokens: number;
|
||||
}) => {
|
||||
const formatPrompts = prompts.map((item) => ({
|
||||
obj: item.obj,
|
||||
value: item.value
|
||||
@@ -97,41 +107,64 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
|
||||
.trim()
|
||||
}));
|
||||
|
||||
let res: ChatItemType[] = [];
|
||||
|
||||
let chats: ChatItemType[] = [];
|
||||
let systemPrompt: ChatItemType | null = null;
|
||||
|
||||
// System 词保留
|
||||
if (formatPrompts[0]?.obj === 'SYSTEM') {
|
||||
systemPrompt = formatPrompts.shift() as ChatItemType;
|
||||
maxTokens -= encode(formatPrompts[0].value).length;
|
||||
}
|
||||
|
||||
// 从后往前截取
|
||||
// 格式化文本内容成 chatgpt 格式
|
||||
const map = {
|
||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
||||
};
|
||||
|
||||
let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
|
||||
|
||||
// 从后往前截取对话内容
|
||||
for (let i = formatPrompts.length - 1; i >= 0; i--) {
|
||||
const tokens = encode(formatPrompts[i].value).length;
|
||||
res.unshift(formatPrompts[i]);
|
||||
chats.unshift(formatPrompts[i]);
|
||||
|
||||
messages = (systemPrompt ? [systemPrompt, ...chats] : chats).map((item) => ({
|
||||
role: map[item.obj],
|
||||
content: item.value
|
||||
}));
|
||||
|
||||
const tokens = countChatTokens({
|
||||
model,
|
||||
messages
|
||||
});
|
||||
|
||||
/* 整体 tokens 超出范围 */
|
||||
if (tokens >= maxTokens) {
|
||||
break;
|
||||
}
|
||||
|
||||
maxTokens -= tokens;
|
||||
}
|
||||
|
||||
return systemPrompt ? [systemPrompt, ...res] : res;
|
||||
return messages;
|
||||
};
|
||||
|
||||
/* system 内容截断 */
|
||||
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
|
||||
export const systemPromptFilter = ({
|
||||
model,
|
||||
prompts,
|
||||
maxTokens
|
||||
}: {
|
||||
model: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
|
||||
prompts: string[];
|
||||
maxTokens: number;
|
||||
}) => {
|
||||
let splitText = '';
|
||||
|
||||
// 从前往前截取
|
||||
for (let i = 0; i < prompts.length; i++) {
|
||||
const prompt = prompts[i];
|
||||
const prompt = prompts[i].replace(/\n+/g, '\n');
|
||||
|
||||
splitText += `${prompt}\n`;
|
||||
const tokens = encode(splitText).length;
|
||||
const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] });
|
||||
if (tokens >= maxTokens) {
|
||||
break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user