feat: chat quote

This commit is contained in:
archer
2023-05-23 15:09:57 +08:00
parent ee2c259c3d
commit 944e876aaa
29 changed files with 933 additions and 660 deletions

View File

@@ -2,19 +2,24 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { authChat } from '@/service/utils/auth';
import { modelServiceToolMap } from '@/service/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { resStreamResponse } from '@/service/utils/chat';
import { searchKb } from '@/service/plugins/searchKb';
import { appKbSearch } from '../openapi/kb/appKbSearch';
import { ChatRoleEnum } from '@/constants/chat';
import { BillTypeEnum } from '@/constants/user';
import { sensitiveCheck } from '@/service/api/text';
import { NEW_CHATID_HEADER } from '@/constants/chat';
import { saveChat } from './saveChat';
import { Types } from 'mongoose';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
res.on('close', () => {
res.end();
});
res.on('error', () => {
console.log('error: ', 'request error');
res.end();
@@ -22,9 +27,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
try {
const { chatId, prompt, modelId } = req.body as {
prompt: ChatItemSimpleType;
prompt: [ChatItemType, ChatItemType];
modelId: string;
chatId: '' | string;
chatId?: string;
};
if (!modelId || !prompt) {
@@ -44,42 +49,69 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const modelConstantsData = ChatModelMap[model.chat.chatModel];
// 读取对话内容
const prompts = [...content, prompt];
let systemPrompts: {
obj: ChatRoleEnum;
value: string;
}[] = [];
const prompts = [...content, prompt[0]];
const {
code = 200,
systemPrompts = [],
quote = []
} = await (async () => {
// 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts, rawSearch } = await appKbSearch({
model,
userId,
prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
});
// 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts } = await searchKb({
userOpenAiKey,
prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
model,
return {
code,
quote: rawSearch,
systemPrompts: searchPrompts
};
}
if (model.chat.systemPrompt) {
return {
systemPrompts: [
{
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
]
};
}
return {};
})();
// get conversationId. create a newId if it is null
const conversationId = chatId || String(new Types.ObjectId());
!chatId && res?.setHeader(NEW_CHATID_HEADER, conversationId);
// search result is empty
if (code === 201) {
const response = systemPrompts[0]?.value;
await saveChat({
chatId,
newChatId: conversationId,
modelId,
prompts: [
prompt[0],
{
...prompt[1],
quote: [],
value: response
}
],
userId
});
// search result is empty
if (code === 201) {
return res.send(searchPrompts[0]?.value);
}
systemPrompts = searchPrompts;
} else if (model.chat.systemPrompt) {
systemPrompts = [
{
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
];
return res.end(response);
}
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
// content check
await sensitiveCheck({
input: [...systemPrompts, prompt].map((item) => item.value).join('')
input: [...systemPrompts, prompt[0]].map((item) => item.value).join('')
});
// 计算温度
@@ -87,54 +119,65 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
2
);
// 发出请求
// 发出 chat 请求
const { streamResponse } = await modelServiceToolMap[model.chat.chatModel].chatCompletion({
apiKey: userOpenAiKey || systemAuthKey,
temperature: +temperature,
messages: prompts,
stream: true,
res,
chatId
chatId: conversationId
});
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
step = 1;
if (res.closed) return res.end();
const { totalTokens, finishMessages } = await resStreamResponse({
model: model.chat.chatModel,
res,
chatResponse: streamResponse,
prompts,
systemPrompt: showModelDetail
? prompts
.filter((item) => item.obj === ChatRoleEnum.System)
.map((item) => item.value)
.join('\n')
: ''
});
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userOpenAiKey,
chatModel: model.chat.chatModel,
userId,
chatId,
textLen: finishMessages.map((item) => item.value).join('').length,
tokens: totalTokens,
type: BillTypeEnum.chat
});
} catch (err: any) {
if (step === 1) {
// 直接结束流
res.end();
console.log('error结束');
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
try {
const { totalTokens, finishMessages, responseContent } = await resStreamResponse({
model: model.chat.chatModel,
res,
chatResponse: streamResponse,
prompts
});
// save chat
await saveChat({
chatId,
newChatId: conversationId,
modelId,
prompts: [
prompt[0],
{
...prompt[1],
quote: showModelDetail ? quote : [],
value: responseContent
}
],
userId
});
res.end();
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userOpenAiKey,
chatModel: model.chat.chatModel,
userId,
chatId: conversationId,
textLen: finishMessages.map((item) => item.value).join('').length,
tokens: totalTokens,
type: BillTypeEnum.chat
});
} catch (error) {
res.end();
console.log('error结束', error);
}
} catch (err: any) {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -73,7 +73,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
_id: '$content._id',
obj: '$content.obj',
value: '$content.value',
systemPrompt: '$content.systemPrompt'
quote: '$content.quote'
}
}
]);

View File

@@ -6,15 +6,17 @@ import { authModel } from '@/service/utils/auth';
import { authUser } from '@/service/utils/auth';
import mongoose from 'mongoose';
type Props = {
newChatId?: string;
chatId?: string;
modelId: string;
prompts: [ChatItemType, ChatItemType];
};
/* 聊天内容存存储 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { chatId, modelId, prompts, newChatId } = req.body as {
newChatId: '' | string;
chatId: '' | string;
modelId: string;
prompts: [ChatItemType, ChatItemType];
};
const { chatId, modelId, prompts, newChatId } = req.body as Props;
if (!prompts) {
throw new Error('缺少参数');
@@ -22,44 +24,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { userId } = await authUser({ req, authToken: true });
await connectToDatabase();
const nId = await saveChat({
chatId,
modelId,
prompts,
newChatId,
userId
});
const content = prompts.map((item) => ({
_id: new mongoose.Types.ObjectId(item._id),
obj: item.obj,
value: item.value,
systemPrompt: item.systemPrompt
}));
await authModel({ modelId, userId, authOwner: false });
// 没有 chatId, 创建一个对话
if (!chatId) {
const { _id } = await Chat.create({
_id: newChatId ? new mongoose.Types.ObjectId(newChatId) : undefined,
userId,
modelId,
content,
title: content[0].value.slice(0, 20),
latestChat: content[1].value
});
return jsonRes(res, {
data: _id
});
} else {
// 已经有记录,追加入库
await Chat.findByIdAndUpdate(chatId, {
$push: {
content: {
$each: content
}
},
title: content[0].value.slice(0, 20),
latestChat: content[1].value,
updateTime: new Date()
});
}
jsonRes(res);
jsonRes(res, {
data: nId
});
} catch (err) {
jsonRes(res, {
code: 500,
@@ -67,3 +42,46 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
});
}
}
export async function saveChat({
chatId,
newChatId,
modelId,
prompts,
userId
}: Props & { userId: string }) {
await connectToDatabase();
await authModel({ modelId, userId, authOwner: false });
const content = prompts.map((item) => ({
_id: item._id ? new mongoose.Types.ObjectId(item._id) : undefined,
obj: item.obj,
value: item.value,
quote: item.quote
}));
// 没有 chatId, 创建一个对话
if (!chatId) {
const { _id } = await Chat.create({
_id: newChatId ? new mongoose.Types.ObjectId(newChatId) : undefined,
userId,
modelId,
content,
title: content[0].value.slice(0, 20),
latestChat: content[1].value
});
return _id;
} else {
// 已经有记录,追加入库
await Chat.findByIdAndUpdate(chatId, {
$push: {
content: {
$each: content
}
},
title: content[0].value.slice(0, 20),
latestChat: content[1].value,
updateTime: new Date()
});
}
}

View File

@@ -7,14 +7,13 @@ import { jsonRes } from '@/service/response';
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill, updateShareChatBill } from '@/service/events/pushBill';
import { resStreamResponse } from '@/service/utils/chat';
import { searchKb } from '@/service/plugins/searchKb';
import { ChatRoleEnum } from '@/constants/chat';
import { BillTypeEnum } from '@/constants/user';
import { sensitiveCheck } from '@/service/api/text';
import { appKbSearch } from '../../openapi/kb/appKbSearch';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1 时,表示开始了流响应
res.on('error', () => {
console.log('error: ', 'request error');
res.end();
@@ -42,34 +41,37 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const modelConstantsData = ChatModelMap[model.chat.chatModel];
let systemPrompts: {
obj: ChatRoleEnum;
value: string;
}[] = [];
const { code = 200, systemPrompts = [] } = await (async () => {
// 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts } = await appKbSearch({
model,
userId,
prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
});
// 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts } = await searchKb({
userOpenAiKey,
prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
model,
userId
});
// search result is empty
if (code === 201) {
return res.send(searchPrompts[0]?.value);
return {
code,
systemPrompts: searchPrompts
};
}
if (model.chat.systemPrompt) {
return {
systemPrompts: [
{
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
]
};
}
return {};
})();
systemPrompts = searchPrompts;
} else if (model.chat.systemPrompt) {
systemPrompts = [
{
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
];
// search result is empty
if (code === 201) {
return res.send(systemPrompts[0]?.value);
}
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
@@ -96,40 +98,40 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
step = 1;
if (res.closed) return res.end();
const { totalTokens, finishMessages } = await resStreamResponse({
model: model.chat.chatModel,
res,
chatResponse: streamResponse,
prompts,
systemPrompt: ''
});
/* bill */
pushChatBill({
isPay: !userOpenAiKey,
chatModel: model.chat.chatModel,
userId,
textLen: finishMessages.map((item) => item.value).join('').length,
tokens: totalTokens,
type: BillTypeEnum.chat
});
updateShareChatBill({
shareId,
tokens: totalTokens
});
} catch (err: any) {
if (step === 1) {
// 直接结束流
res.end();
console.log('error结束');
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
try {
const { totalTokens, finishMessages } = await resStreamResponse({
model: model.chat.chatModel,
res,
chatResponse: streamResponse,
prompts
});
res.end();
/* bill */
pushChatBill({
isPay: !userOpenAiKey,
chatModel: model.chat.chatModel,
userId,
textLen: finishMessages.map((item) => item.value).join('').length,
tokens: totalTokens,
type: BillTypeEnum.chat
});
updateShareChatBill({
shareId,
tokens: totalTokens
});
} catch (error) {
res.end();
console.log('error结束', error);
}
} catch (err: any) {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}