feat: chat quote

This commit is contained in:
archer
2023-05-23 15:09:57 +08:00
parent ee2c259c3d
commit 944e876aaa
29 changed files with 933 additions and 660 deletions

View File

@@ -6,15 +6,19 @@ import { ChatItemSimpleType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { searchKb } from '@/service/plugins/searchKb';
import { ChatRoleEnum } from '@/constants/chat';
import { withNextCors } from '@/service/utils/tools';
import { BillTypeEnum } from '@/constants/user';
import { sensitiveCheck } from '@/service/api/text';
import { NEW_CHATID_HEADER } from '@/constants/chat';
import { Types } from 'mongoose';
import { appKbSearch } from '../kb/appKbSearch';
/* 发送提示词 */
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
res.on('close', () => {
res.end();
});
res.on('error', () => {
console.log('error: ', 'request error');
res.end();
@@ -70,7 +74,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
// 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts } = await searchKb({
const { code, searchPrompts } = await appKbSearch({
prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
model,
@@ -109,6 +113,10 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
2
);
// get conversationId. create a newId if it is null
const conversationId = chatId || String(new Types.ObjectId());
!chatId && res?.setHeader(NEW_CHATID_HEADER, conversationId);
// 发出请求
const { streamResponse, responseMessages, responseText, totalTokens } =
await modelServiceToolMap[model.chat.chatModel].chatCompletion({
@@ -117,30 +125,41 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
messages: prompts,
stream: isStream,
res,
chatId
chatId: conversationId
});
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
let textLen = 0;
let tokens = totalTokens;
if (res.closed) return res.end();
if (isStream) {
step = 1;
const { finishMessages, totalTokens } = await resStreamResponse({
model: model.chat.chatModel,
res,
chatResponse: streamResponse,
prompts
});
textLen = finishMessages.map((item) => item.value).join('').length;
tokens = totalTokens;
} else {
textLen = responseMessages.map((item) => item.value).join('').length;
jsonRes(res, {
data: responseText
});
}
const { textLen = 0, tokens = totalTokens } = await (async () => {
if (isStream) {
try {
const { finishMessages, totalTokens } = await resStreamResponse({
model: model.chat.chatModel,
res,
chatResponse: streamResponse,
prompts
});
res.end();
return {
textLen: finishMessages.map((item) => item.value).join('').length,
tokens: totalTokens
};
} catch (error) {
res.end();
console.log('error结束', error);
}
} else {
jsonRes(res, {
data: responseText
});
return {
textLen: responseMessages.map((item) => item.value).join('').length
};
}
return {};
})();
pushChatBill({
isPay: true,
@@ -151,16 +170,10 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
type: BillTypeEnum.openapiChat
});
} catch (err: any) {
if (step === 1) {
// 直接结束流
res.end();
console.log('error结束');
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
});

View File

@@ -0,0 +1,38 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { Chat } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { QuoteItemType } from '../kb/appKbSearch';
type Props = {
chatId: string;
};
export type Response = {
quote: QuoteItemType[];
};
/* Fetch the quote list attached to the most recent message of a chat */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    const { chatId } = req.query as Props;
    if (!chatId) throw new Error('缺少参数');

    // auth: the chat must belong to the requesting user
    const { userId } = await authUser({ req });

    // project only the last content item — we just need its quote
    const record = await Chat.findOne({ _id: chatId, userId }, { content: { $slice: -1 } });
    const quote = record?.content[0]?.quote || [];

    jsonRes<Response>(res, { data: { quote } });
  } catch (err) {
    jsonRes(res, { code: 500, error: err });
  }
}

View File

@@ -0,0 +1,224 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
import type { ChatItemSimpleType } from '@/types/chat';
import type { ModelSchema } from '@/types/mongoSchema';
import { ModelVectorSearchModeEnum } from '@/constants/model';
import { authModel } from '@/service/utils/auth';
import { ChatModelMap } from '@/constants/model';
import { ChatRoleEnum } from '@/constants/chat';
import { openaiEmbedding } from '../plugin/openaiEmbedding';
import { ModelDataStatusEnum } from '@/constants/model';
import { modelToolMap } from '@/utils/plugin';
export type QuoteItemType = { id: string; q: string; a: string };
type Props = {
prompts: ChatItemSimpleType[];
similarity: number;
appId: string;
};
type Response = {
code: 200 | 201;
rawSearch: QuoteItemType[];
searchPrompts: {
obj: ChatRoleEnum;
value: string;
}[];
};
/* API wrapper around appKbSearch: authenticates the user, validates params,
   checks app ownership, then runs the knowledge-base search. */
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
  try {
    // must be a logged-in user
    const { userId } = await authUser({ req });
    if (!userId) {
      throw new Error('userId is empty');
    }

    const { prompts, similarity, appId } = req.body as Props;
    if (!similarity || !Array.isArray(prompts) || !appId) {
      throw new Error('params is error');
    }

    // the app must belong to this user
    const { model } = await authModel({ modelId: appId, userId });

    const searchResult = await appKbSearch({ userId, prompts, similarity, model });

    jsonRes<Response>(res, { data: searchResult });
  } catch (err) {
    console.log(err);
    jsonRes(res, { code: 500, error: err });
  }
});
/**
 * Search the app's related knowledge bases with the user's latest question(s)
 * and build the system prompts that inject the matched content.
 *
 * @returns code 200 with search prompts (or plain system prompt), or
 *          code 201 with a "not in knowledge base" refusal when hightSimilarity
 *          mode finds nothing.
 */
export async function appKbSearch({
  model,
  userId,
  prompts,
  similarity
}: {
  userId: string;
  prompts: ChatItemSimpleType[];
  similarity: number;
  model: ModelSchema;
}): Promise<Response> {
  const modelConstantsData = ChatModelMap[model.chat.chatModel];

  // Use the last two user questions as search queries.
  const userPrompts = prompts.filter((item) => item.obj === 'Human');
  if (userPrompts.length === 0) {
    // Fail loudly instead of crashing with a TypeError on undefined below.
    throw new Error('no user prompt to search with');
  }
  const input: string[] = [
    userPrompts[userPrompts.length - 1].value,
    userPrompts[userPrompts.length - 2]?.value
  ].filter((item) => item);

  // Turn the queries into embedding vectors.
  const promptVectors = await openaiEmbedding({
    userId,
    input
  });

  // Vector search across every related knowledge base, one query per vector.
  // NOTE(review): kb ids / vectors / similarity are interpolated into the SQL
  // string. They come from our own mongo model and the embedding API, not raw
  // user text, but parameterized queries would be safer — confirm PgClient
  // supports them.
  const searchRes = await Promise.all(
    promptVectors.map((promptVector) =>
      PgClient.select<{ id: string; q: string; a: string }>('modelData', {
        fields: ['id', 'q', 'a'],
        where: [
          ['status', ModelDataStatusEnum.ready],
          'AND',
          `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
          'AND',
          `vector <=> '[${promptVector}]' < ${similarity}`
        ],
        order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
        limit: promptVectors.length === 1 ? 15 : 10
      }).then((res) => res.rows)
    )
  );

  // Drop duplicate hits across the result sets (first occurrence wins).
  const idSet = new Set<string>();
  const filterSearch = searchRes.map((search) =>
    search.filter((item) => {
      if (idSet.has(item.id)) {
        return false;
      }
      idSet.add(item.id);
      return true;
    })
  );

  // How the token budget is split between the result sets:
  // one query gets the whole budget, two queries split 70/30.
  const sliceRateMap: Record<number, number[]> = {
    1: [1],
    2: [0.7, 0.3]
  };
  // Fix: the previous fallback was sliceRateMap[0], which does not exist and
  // left sliceRate undefined; split the budget evenly for any other count.
  const sliceRate =
    sliceRateMap[searchRes.length] || searchRes.map(() => 1 / searchRes.length);

  // Count the tokens consumed by the fixed prompts so the kb content gets
  // only the remaining budget.
  const fixedPrompts = [
    // user system prompt
    ...(model.chat.systemPrompt
      ? [
          {
            obj: ChatRoleEnum.System,
            value: model.chat.systemPrompt
          }
        ]
      : model.chat.searchMode === ModelVectorSearchModeEnum.noContext
      ? [
          {
            obj: ChatRoleEnum.System,
            value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
          }
        ]
      : [
          {
            obj: ChatRoleEnum.System,
            value: `玩一个问答游戏,规则为:
1.你完全忘记你已有的知识
2.你只回答关于"${model.name}"的问题
3.你只从知识库中选择内容进行回答
4.如果问题不在知识库中,你会回答:"我不知道。"
请务必遵守规则`
          }
        ])
  ];
  const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
    messages: fixedPrompts
  });
  const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens;

  // Trim each result set to its share of the token budget.
  const sliceResult = sliceRate.map((rate, i) =>
    modelToolMap[model.chat.chatModel]
      .tokenSlice({
        maxToken: Math.round(maxTokens * rate),
        messages: filterSearch[i].map((item) => ({
          obj: ChatRoleEnum.System,
          value: `${item.q}\n${item.a}`
        }))
      })
      .map((item) => item.value)
  );

  // Keep the raw rows that survived the slicing — callers use them as quotes.
  const sliceSearch = filterSearch.map((item, i) => item.slice(0, sliceResult[i].length)).flat();

  // system prompt assembled from the surviving kb content
  const systemPrompt = sliceResult.flat().join('\n').trim();

  /* hightSimilarity mode with no match: refuse to answer */
  if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
    return {
      code: 201,
      rawSearch: [],
      searchPrompts: [
        {
          obj: ChatRoleEnum.System,
          value: '对不起,你的问题不在知识库中。'
        }
      ]
    };
  }

  /* noContext mode with no match: no extra knowledge, system prompt only */
  if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
    return {
      code: 200,
      rawSearch: [],
      searchPrompts: model.chat.systemPrompt
        ? [
            {
              obj: ChatRoleEnum.System,
              value: model.chat.systemPrompt
            }
          ]
        : []
    };
  }

  return {
    code: 200,
    rawSearch: sliceSearch,
    searchPrompts: [
      {
        obj: ChatRoleEnum.System,
        value: `知识库:${systemPrompt}`
      },
      ...fixedPrompts
    ]
  };
}

View File

@@ -0,0 +1,77 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
import { getApiKey } from '@/service/utils/auth';
import { getOpenAIApi } from '@/service/utils/chat/openai';
import { embeddingModel } from '@/constants/model';
import { axiosConfig } from '@/service/utils/tools';
import { pushGenerateVectorBill } from '@/service/events/pushBill';
type Props = {
input: string[];
};
type Response = number[][];
/* Embedding API: turns an array of strings into embedding vectors (paid). */
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
  try {
    const { userId } = await authUser({ req });

    // const, not let — never reassigned.
    // NOTE(review): input is read from the query string; it only arrives as an
    // array when the caller repeats the `input` param. Confirm this shouldn't
    // read req.body instead.
    const { input } = req.query as Props;

    if (!Array.isArray(input)) {
      throw new Error('缺少参数');
    }

    jsonRes<Response>(res, {
      data: await openaiEmbedding({ userId, input, mustPay: true })
    });
  } catch (err) {
    console.log(err);
    jsonRes(res, {
      code: 500,
      error: err
    });
  }
});
/**
 * Create embedding vectors for the given inputs via the OpenAI API,
 * then record the vector-generation bill for the user.
 */
export async function openaiEmbedding({
  userId,
  input,
  mustPay = false
}: { userId: string; mustPay?: boolean } & Props) {
  // resolve which API key to bill against (user's own key or system key)
  const { userOpenAiKey, systemAuthKey } = await getApiKey({
    model: 'gpt-3.5-turbo',
    userId,
    mustPay
  });

  const chatAPI = getOpenAIApi();

  // convert the inputs into vectors
  const response = await chatAPI.createEmbedding(
    {
      model: embeddingModel,
      input
    },
    {
      timeout: 60000,
      ...axiosConfig(userOpenAiKey || systemAuthKey)
    }
  );

  const tokenLen = response.data.usage.total_tokens || 0;
  const vectors = response.data.data.map((item) => item.embedding);

  // bill the user only when the system key paid for the call
  pushGenerateVectorBill({
    isPay: !userOpenAiKey,
    userId,
    text: input.join(''),
    tokenLen
  });

  return vectors;
}

View File

@@ -0,0 +1,119 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import type { NextApiRequest, NextApiResponse } from 'next';
import { type Tiktoken } from '@dqbd/tiktoken';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import Graphemer from 'graphemer';
import type { ChatItemSimpleType } from '@/types/chat';
import { ChatCompletionRequestMessage } from 'openai';
import { getOpenAiEncMap } from '@/utils/plugin/openai';
import { adaptChatItem_openAI } from '@/utils/plugin/openai';
type ModelType = 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k';
type Props = {
messages: ChatItemSimpleType[];
model: ModelType;
maxLen: number;
};
type Response = ChatItemSimpleType[];
/* Token-slice API: trims a message list so it fits within maxLen tokens. */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    await authUser({ req });

    const { messages, model, maxLen } = req.body as Props;
    if (!Array.isArray(messages) || !model || !maxLen) {
      throw new Error('params is error');
    }

    const data = gpt_chatItemTokenSlice({ messages, model, maxToken: maxLen });
    return jsonRes<Response>(res, { data });
  } catch (err) {
    jsonRes(res, { code: 500, error: err });
  }
}
/**
 * Greedily keep messages from the front of the list while the encoded
 * conversation stays strictly under maxToken; returns the kept prefix.
 */
export function gpt_chatItemTokenSlice({
  messages,
  model,
  maxToken
}: {
  messages: ChatItemSimpleType[];
  model: ModelType;
  maxToken: number;
}) {
  const textDecoder = new TextDecoder();
  const graphemer = new Graphemer();

  // Render messages the way OpenAI's chat format is tokenized.
  function getChatGPTEncodingText(messages: ChatCompletionRequestMessage[], model: ModelType) {
    const isGpt3 = model === 'gpt-3.5-turbo';
    const msgSep = isGpt3 ? '\n' : '';
    const roleSep = isGpt3 ? '\n' : '<|im_sep|>';
    return [
      messages
        .map(({ name = '', role, content }) => {
          return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
        })
        .join(msgSep),
      `<|im_start|>assistant${roleSep}`
    ].join(msgSep);
  }

  // Count tokens, regrouping token bytes along grapheme boundaries so
  // multi-byte characters split across tokens are not miscounted.
  function text2TokensLen(encoder: Tiktoken, inputText: string) {
    const encoding = encoder.encode(inputText, 'all');
    const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];

    let byteAcc: number[] = [];
    let tokenAcc: { id: number; idx: number }[] = [];
    let inputGraphemes = graphemer.splitGraphemes(inputText);

    for (let idx = 0; idx < encoding.length; idx++) {
      const token = encoding[idx]!;
      byteAcc.push(...encoder.decode_single_token_bytes(token));
      tokenAcc.push({ id: token, idx });

      // flush the accumulator once it decodes to whole graphemes
      const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
      const graphemes = graphemer.splitGraphemes(segmentText);

      if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
        segments.push({ text: segmentText, tokens: tokenAcc });

        byteAcc = [];
        tokenAcc = [];
        inputGraphemes = inputGraphemes.slice(graphemes.length);
      }
    }

    // reduce on an array is never nullish, so the previous `?? 0` was dead code
    return segments.reduce((memo, i) => memo + i.tokens.length, 0);
  }

  const OpenAiEncMap = getOpenAiEncMap();
  const enc = OpenAiEncMap[model];

  let result: ChatItemSimpleType[] = [];

  for (let i = 0; i < messages.length; i++) {
    const msgs = [...result, messages[i]];

    // Fix: measure the candidate slice `msgs`, not the full `messages` array.
    // The old code tokenized `messages` on every iteration, so the count never
    // changed and the loop kept either everything or nothing.
    const tokens = text2TokensLen(
      enc,
      getChatGPTEncodingText(adaptChatItem_openAI({ messages: msgs }), model)
    );

    if (tokens < maxToken) {
      result = msgs;
    } else {
      break;
    }
  }

  return result;
}