feat: chat quote

This commit is contained in:
archer
2023-05-23 15:09:57 +08:00
parent ee2c259c3d
commit 944e876aaa
29 changed files with 933 additions and 660 deletions

View File

@@ -1,8 +1,8 @@
import { openaiCreateEmbedding } from '../utils/chat/openai';
import { getApiKey } from '../utils/auth';
import { openaiError2 } from '../errorCode';
import { PgClient } from '@/service/pg';
import { getErrText } from '@/utils/tools';
import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
export async function generateVector(next = false): Promise<any> {
if (process.env.queueTask !== '1') {
@@ -42,24 +42,20 @@ export async function generateVector(next = false): Promise<any> {
dataId = dataItem.id;
// 获取 openapi Key
let userOpenAiKey;
try {
const res = await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
userOpenAiKey = res.userOpenAiKey;
await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
} catch (err: any) {
await PgClient.delete('modelData', {
where: [['id', dataId]]
});
generateVector(true);
getErrText(err, '获取 OpenAi Key 失败');
return;
return generateVector(true);
}
// 生成词向量
const { vectors } = await openaiCreateEmbedding({
textArr: [dataItem.q],
userId: dataItem.userId,
userOpenAiKey
const vectors = await openaiEmbedding({
input: [dataItem.q],
userId: dataItem.userId
});
// 更新 pg 向量和状态数据

View File

@@ -47,10 +47,14 @@ const ChatSchema = new Schema({
type: String,
required: true
},
systemPrompt: {
type: String,
default: ''
quote: {
type: [{ id: String, q: String, a: String }],
default: []
}
// systemPrompt: {
// type: String,
// default: ''
// }
}
],
default: []

View File

@@ -1,175 +0,0 @@
import { PgClient } from '@/service/pg';
import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model';
import { ModelSchema } from '@/types/mongoSchema';
import { openaiCreateEmbedding } from '../utils/chat/openai';
import { ChatRoleEnum } from '@/constants/chat';
import { modelToolMap } from '@/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
/**
 * Search the knowledge base (pg `modelData` table) with OpenAI embeddings and
 * build the system prompts that will be prepended to the chat conversation.
 *
 * @param userOpenAiKey - optional user-owned OpenAI key; forwarded to the embedding call
 * @param prompts       - full chat history; only 'Human' items are used as search queries
 * @param similarity    - max vector distance accepted as a match (default 0.2)
 * @param model         - model config; supplies related kb ids, search mode and system prompt
 * @param userId        - owner id, forwarded to the embedding call (used for billing there)
 * @returns `code` 201 with an apology prompt when hightSimilarity mode finds nothing;
 *          otherwise `code` 200 with the assembled system prompts.
 */
export const searchKb = async ({
userOpenAiKey,
prompts,
similarity = 0.2,
model,
userId
}: {
userOpenAiKey?: string;
prompts: ChatItemSimpleType[];
model: ModelSchema;
userId: string;
similarity?: number;
}): Promise<{
code: 200 | 201;
searchPrompts: {
obj: ChatRoleEnum;
value: string;
}[];
}> => {
// Runs one vector search per input text and returns one joined "q\na" string per text.
async function search(textArr: string[] = []) {
// Per-mode row limit for the pg query.
const limitMap: Record<ModelVectorSearchModeEnum, number> = {
[ModelVectorSearchModeEnum.hightSimilarity]: 15,
[ModelVectorSearchModeEnum.noContext]: 15,
[ModelVectorSearchModeEnum.lowSimilarity]: 20
};
// Embed the prompt texts into vectors.
const { vectors: promptVectors } = await openaiCreateEmbedding({
userOpenAiKey,
userId,
textArr
});
// One pg query per prompt vector, in parallel. Rows must be 'ready', belong to a
// related kb, and lie within `similarity` distance (pgvector `<=>` operator).
// NOTE(review): `kb_id IN (...)` and the vector literal are built by string
// interpolation, not parameterized — kb ids come from the model config, but this
// is still an injection-prone pattern worth hardening.
const searchRes = await Promise.all(
promptVectors.map((promptVector) =>
PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
`kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: limitMap[model.chat.searchMode]
}).then((res) => res.rows)
)
);
// Remove repeat record — dedupe by row id across all result sets (first set wins).
const idSet = new Set<string>();
const filterSearch = searchRes.map((search) =>
search.filter((item) => {
if (idSet.has(item.id)) {
return false;
}
idSet.add(item.id);
return true;
})
);
// Collapse each result set into a single "question\nanswer" text block.
return filterSearch.map((item) => item.map((item) => `${item.q}\n${item.a}`).join('\n'));
}
const modelConstantsData = ChatModelMap[model.chat.chatModel];
// Search with the last one or two user ('Human') messages as queries.
const userPrompts = prompts.filter((item) => item.obj === 'Human');
const searchArr: string[] = [
userPrompts[userPrompts.length - 1].value,
userPrompts[userPrompts.length - 2]?.value
].filter((item) => item);
const systemPrompts = await search(searchArr);
// Token-budget split across result sets: 100% for one set, 70/30 for two.
// NOTE(review): filterRateMap[0] is undefined, so the `|| filterRateMap[0]`
// fallback yields undefined for lengths other than 1 or 2; safe only because
// searchArr always has 1–2 entries — confirm before reuse.
const filterRateMap: Record<number, number[]> = {
1: [1],
2: [0.7, 0.3]
};
const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
// Count tokens consumed by the fixed prompts (model system prompt + mode-specific
// instruction) so the kb excerpts can fill whatever budget remains.
const fixedPrompts = [
...(model.chat.systemPrompt
? [
{
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
]
: []),
...(model.chat.searchMode === ModelVectorSearchModeEnum.noContext
? [
{
obj: ChatRoleEnum.System,
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
}
]
: [
{
obj: ChatRoleEnum.System,
value: `玩一个问答游戏,规则为:
1.你完全忘记你已有的知识
2.你只回答关于"${model.name}"的问题
3.你只从知识库中选择内容进行回答
4.如果问题不在知识库中,你会回答:"我不知道。"
请务必遵守规则`
}
])
];
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
messages: fixedPrompts
});
// Remaining token budget for kb excerpts, split per filterRate and sliced to fit.
const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens;
const filterSystemPrompt = filterRate
.map((rate, i) =>
modelToolMap[model.chat.chatModel].sliceText({
text: systemPrompts[i],
length: Math.floor(maxTokens * rate)
})
)
.join('\n')
.trim();
/* hightSimilarity mode + no match: return code 201 with a "not in the kb" apology prompt */
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
return {
code: 201,
searchPrompts: [
{
obj: ChatRoleEnum.System,
value: '对不起,你的问题不在知识库中。'
}
]
};
}
/* noContext mode + no match: add no kb content, use only the model's own system prompt */
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
return {
code: 200,
searchPrompts: model.chat.systemPrompt
? [
{
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
]
: []
};
}
/* matched: kb excerpts first, then the fixed prompts */
return {
code: 200,
searchPrompts: [
{
obj: ChatRoleEnum.System,
value: `知识库:${filterSystemPrompt}`
},
...fixedPrompts
]
};
};

View File

@@ -38,12 +38,14 @@ export const authUser = async ({
req,
authToken = false,
authOpenApi = false,
authRoot = false
authRoot = false,
authBalance = false
}: {
req: NextApiRequest;
authToken?: boolean;
authOpenApi?: boolean;
authRoot?: boolean;
authBalance?: boolean;
}) => {
const parseOpenApiKey = async (apiKey?: string) => {
if (!apiKey) {
@@ -99,6 +101,17 @@ export const authUser = async ({
return Promise.reject(ERROR_ENUM.unAuthorization);
}
if (authBalance) {
const user = await User.findById(uid);
if (!user) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
if (!user.openaiKey && formatPrice(user.balance) <= 0) {
return Promise.reject(ERROR_ENUM.insufficientQuota);
}
}
return {
userId: uid
};
@@ -226,7 +239,7 @@ export const authChat = async ({
req
}: {
modelId: string;
chatId: '' | string;
chatId?: string;
req: NextApiRequest;
}) => {
const { userId } = await authUser({ req, authToken: true });

View File

@@ -1,17 +1,9 @@
import { ChatCompletionType, StreamResponseType } from './index';
import { ChatRoleEnum } from '@/constants/chat';
import axios from 'axios';
import mongoose from 'mongoose';
import { NEW_CHATID_HEADER } from '@/constants/chat';
/* 模型对话 */
export const claudChat = async ({ apiKey, messages, stream, chatId, res }: ChatCompletionType) => {
const conversationId = chatId || String(new mongoose.Types.ObjectId());
// create a new chat
!chatId &&
messages.filter((item) => item.obj === 'Human').length === 1 &&
res?.setHeader(NEW_CHATID_HEADER, conversationId);
export const claudChat = async ({ apiKey, messages, stream, chatId }: ChatCompletionType) => {
// get system prompt
const systemPrompt = messages
.filter((item) => item.obj === 'System')
@@ -26,7 +18,7 @@ export const claudChat = async ({ apiKey, messages, stream, chatId, res }: ChatC
{
prompt,
stream,
conversationId
conversationId: chatId
},
{
headers: {
@@ -55,8 +47,7 @@ export const claudStreamResponse = async ({ res, chatResponse, prompts }: Stream
try {
const decoder = new TextDecoder();
for await (const chunk of chatResponse.data as any) {
if (!res.writable) {
// 流被中断了,直接忽略后面的内容
if (res.closed) {
break;
}
const content = decoder.decode(chunk);

View File

@@ -1,7 +1,7 @@
import { ChatItemSimpleType } from '@/types/chat';
import { modelToolMap } from '@/utils/chat';
import { modelToolMap } from '@/utils/plugin';
import type { ChatModelType } from '@/constants/model';
import { ChatRoleEnum, SYSTEM_PROMPT_HEADER } from '@/constants/chat';
import { ChatRoleEnum } from '@/constants/chat';
import { OpenAiChatEnum, ClaudeEnum } from '@/constants/model';
import { chatResponse, openAiStreamResponse } from './openai';
import { claudChat, claudStreamResponse } from './claude';
@@ -11,6 +11,7 @@ export type ChatCompletionType = {
apiKey: string;
temperature: number;
messages: ChatItemSimpleType[];
chatId?: string;
[key: string]: any;
};
export type ChatCompletionResponseType = {
@@ -23,7 +24,6 @@ export type StreamResponseType = {
chatResponse: any;
prompts: ChatItemSimpleType[];
res: NextApiResponse;
systemPrompt?: string;
[key: string]: any;
};
export type StreamResponseReturnType = {
@@ -129,7 +129,6 @@ export const resStreamResponse = async ({
model,
res,
chatResponse,
systemPrompt,
prompts
}: StreamResponseType & {
model: ChatModelType;
@@ -139,18 +138,14 @@ export const resStreamResponse = async ({
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
systemPrompt && res.setHeader(SYSTEM_PROMPT_HEADER, encodeURIComponent(systemPrompt));
const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[
model
].streamResponse({
chatResponse,
prompts,
res,
systemPrompt
res
});
res.end();
return { responseContent, totalTokens, finishMessages };
};

View File

@@ -1,13 +1,11 @@
import { Configuration, OpenAIApi } from 'openai';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { axiosConfig } from '../tools';
import { ChatModelMap, embeddingModel, OpenAiChatEnum } from '@/constants/model';
import { pushGenerateVectorBill } from '../../events/pushBill';
import { adaptChatItem_openAI } from '@/utils/chat/openai';
import { modelToolMap } from '@/utils/chat';
import { ChatModelMap, OpenAiChatEnum } from '@/constants/model';
import { adaptChatItem_openAI } from '@/utils/plugin/openai';
import { modelToolMap } from '@/utils/plugin';
import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index';
import { ChatRoleEnum } from '@/constants/chat';
import { getSystemOpenAiKey } from '../auth';
export const getOpenAIApi = () =>
new OpenAIApi(
@@ -16,51 +14,6 @@ export const getOpenAIApi = () =>
})
);
/**
 * Create embeddings for a batch of texts via the OpenAI API.
 *
 * Uses the caller's own key when provided, otherwise the system key, and bills
 * the platform account only when the system key was used (`isPay: !userOpenAiKey`).
 *
 * @param userOpenAiKey - optional user-owned OpenAI key
 * @param userId        - account to attribute the vector-generation bill to
 * @param textArr       - texts to embed, one vector returned per entry
 * @returns the embedding vectors plus the configured OpenAI client instance
 */
export const openaiCreateEmbedding = async ({
userOpenAiKey,
userId,
textArr
}: {
userOpenAiKey?: string;
userId: string;
textArr: string[];
}) => {
const systemAuthKey = getSystemOpenAiKey();
// Build the OpenAI client.
const chatAPI = getOpenAIApi();
// Convert the input texts into vectors (60s timeout; user key takes precedence).
const res = await chatAPI
.createEmbedding(
{
model: embeddingModel,
input: textArr
},
{
timeout: 60000,
...axiosConfig(userOpenAiKey || systemAuthKey)
}
)
.then((res) => ({
tokenLen: res.data.usage.total_tokens || 0,
vectors: res.data.data.map((item) => item.embedding)
}));
// Record usage; charged (isPay) only when the system key was used.
// NOTE(review): fire-and-forget — the bill promise/result is not awaited here.
pushGenerateVectorBill({
isPay: !userOpenAiKey,
userId,
text: textArr.join(''),
tokenLen: res.tokenLen
});
return {
vectors: res.vectors,
chatAPI
};
};
/* 模型对话 */
export const chatResponse = async ({
model,
@@ -127,7 +80,7 @@ export const openAiStreamResponse = async ({
const content: string = json?.choices?.[0].delta.content || '';
responseContent += content;
res.writable && content && res.write(content);
!res.closed && content && res.write(content);
} catch (error) {
error;
}
@@ -137,8 +90,7 @@ export const openAiStreamResponse = async ({
const decoder = new TextDecoder();
const parser = createParser(onParse);
for await (const chunk of chatResponse.data as any) {
if (!res.writable) {
// 流被中断了,直接忽略后面的内容
if (res.closed) {
break;
}
parser.feed(decoder.decode(chunk, { stream: true }));