feat: chat quote
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import { openaiCreateEmbedding } from '../utils/chat/openai';
|
||||
import { getApiKey } from '../utils/auth';
|
||||
import { openaiError2 } from '../errorCode';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { getErrText } from '@/utils/tools';
|
||||
import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
|
||||
|
||||
export async function generateVector(next = false): Promise<any> {
|
||||
if (process.env.queueTask !== '1') {
|
||||
@@ -42,24 +42,20 @@ export async function generateVector(next = false): Promise<any> {
|
||||
dataId = dataItem.id;
|
||||
|
||||
// 获取 openapi Key
|
||||
let userOpenAiKey;
|
||||
try {
|
||||
const res = await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
|
||||
userOpenAiKey = res.userOpenAiKey;
|
||||
await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
|
||||
} catch (err: any) {
|
||||
await PgClient.delete('modelData', {
|
||||
where: [['id', dataId]]
|
||||
});
|
||||
generateVector(true);
|
||||
getErrText(err, '获取 OpenAi Key 失败');
|
||||
return;
|
||||
return generateVector(true);
|
||||
}
|
||||
|
||||
// 生成词向量
|
||||
const { vectors } = await openaiCreateEmbedding({
|
||||
textArr: [dataItem.q],
|
||||
userId: dataItem.userId,
|
||||
userOpenAiKey
|
||||
const vectors = await openaiEmbedding({
|
||||
input: [dataItem.q],
|
||||
userId: dataItem.userId
|
||||
});
|
||||
|
||||
// 更新 pg 向量和状态数据
|
||||
|
||||
@@ -47,10 +47,14 @@ const ChatSchema = new Schema({
|
||||
type: String,
|
||||
required: true
|
||||
},
|
||||
systemPrompt: {
|
||||
type: String,
|
||||
default: ''
|
||||
quote: {
|
||||
type: [{ id: String, q: String, a: String }],
|
||||
default: []
|
||||
}
|
||||
// systemPrompt: {
|
||||
// type: String,
|
||||
// default: ''
|
||||
// }
|
||||
}
|
||||
],
|
||||
default: []
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model';
|
||||
import { ModelSchema } from '@/types/mongoSchema';
|
||||
import { openaiCreateEmbedding } from '../utils/chat/openai';
|
||||
import { ChatRoleEnum } from '@/constants/chat';
|
||||
import { modelToolMap } from '@/utils/chat';
|
||||
import { ChatItemSimpleType } from '@/types/chat';
|
||||
|
||||
/**
|
||||
* use openai embedding search kb
|
||||
*/
|
||||
export const searchKb = async ({
|
||||
userOpenAiKey,
|
||||
prompts,
|
||||
similarity = 0.2,
|
||||
model,
|
||||
userId
|
||||
}: {
|
||||
userOpenAiKey?: string;
|
||||
prompts: ChatItemSimpleType[];
|
||||
model: ModelSchema;
|
||||
userId: string;
|
||||
similarity?: number;
|
||||
}): Promise<{
|
||||
code: 200 | 201;
|
||||
searchPrompts: {
|
||||
obj: ChatRoleEnum;
|
||||
value: string;
|
||||
}[];
|
||||
}> => {
|
||||
async function search(textArr: string[] = []) {
|
||||
const limitMap: Record<ModelVectorSearchModeEnum, number> = {
|
||||
[ModelVectorSearchModeEnum.hightSimilarity]: 15,
|
||||
[ModelVectorSearchModeEnum.noContext]: 15,
|
||||
[ModelVectorSearchModeEnum.lowSimilarity]: 20
|
||||
};
|
||||
// 获取提示词的向量
|
||||
const { vectors: promptVectors } = await openaiCreateEmbedding({
|
||||
userOpenAiKey,
|
||||
userId,
|
||||
textArr
|
||||
});
|
||||
|
||||
const searchRes = await Promise.all(
|
||||
promptVectors.map((promptVector) =>
|
||||
PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||
fields: ['id', 'q', 'a'],
|
||||
where: [
|
||||
['status', ModelDataStatusEnum.ready],
|
||||
'AND',
|
||||
`kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
|
||||
'AND',
|
||||
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||
],
|
||||
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||
limit: limitMap[model.chat.searchMode]
|
||||
}).then((res) => res.rows)
|
||||
)
|
||||
);
|
||||
|
||||
// Remove repeat record
|
||||
const idSet = new Set<string>();
|
||||
const filterSearch = searchRes.map((search) =>
|
||||
search.filter((item) => {
|
||||
if (idSet.has(item.id)) {
|
||||
return false;
|
||||
}
|
||||
idSet.add(item.id);
|
||||
return true;
|
||||
})
|
||||
);
|
||||
|
||||
return filterSearch.map((item) => item.map((item) => `${item.q}\n${item.a}`).join('\n'));
|
||||
}
|
||||
const modelConstantsData = ChatModelMap[model.chat.chatModel];
|
||||
|
||||
// search three times
|
||||
const userPrompts = prompts.filter((item) => item.obj === 'Human');
|
||||
|
||||
const searchArr: string[] = [
|
||||
userPrompts[userPrompts.length - 1].value,
|
||||
userPrompts[userPrompts.length - 2]?.value
|
||||
].filter((item) => item);
|
||||
const systemPrompts = await search(searchArr);
|
||||
|
||||
// filter system prompts.
|
||||
const filterRateMap: Record<number, number[]> = {
|
||||
1: [1],
|
||||
2: [0.7, 0.3]
|
||||
};
|
||||
const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
|
||||
|
||||
// 计算固定提示词的 token 数量
|
||||
const fixedPrompts = [
|
||||
...(model.chat.systemPrompt
|
||||
? [
|
||||
{
|
||||
obj: ChatRoleEnum.System,
|
||||
value: model.chat.systemPrompt
|
||||
}
|
||||
]
|
||||
: []),
|
||||
...(model.chat.searchMode === ModelVectorSearchModeEnum.noContext
|
||||
? [
|
||||
{
|
||||
obj: ChatRoleEnum.System,
|
||||
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
|
||||
}
|
||||
]
|
||||
: [
|
||||
{
|
||||
obj: ChatRoleEnum.System,
|
||||
value: `玩一个问答游戏,规则为:
|
||||
1.你完全忘记你已有的知识
|
||||
2.你只回答关于"${model.name}"的问题
|
||||
3.你只从知识库中选择内容进行回答
|
||||
4.如果问题不在知识库中,你会回答:"我不知道。"
|
||||
请务必遵守规则`
|
||||
}
|
||||
])
|
||||
];
|
||||
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
|
||||
messages: fixedPrompts
|
||||
});
|
||||
const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens;
|
||||
|
||||
const filterSystemPrompt = filterRate
|
||||
.map((rate, i) =>
|
||||
modelToolMap[model.chat.chatModel].sliceText({
|
||||
text: systemPrompts[i],
|
||||
length: Math.floor(maxTokens * rate)
|
||||
})
|
||||
)
|
||||
.join('\n')
|
||||
.trim();
|
||||
|
||||
/* 高相似度+不回复 */
|
||||
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
|
||||
return {
|
||||
code: 201,
|
||||
searchPrompts: [
|
||||
{
|
||||
obj: ChatRoleEnum.System,
|
||||
value: '对不起,你的问题不在知识库中。'
|
||||
}
|
||||
]
|
||||
};
|
||||
}
|
||||
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
|
||||
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
|
||||
return {
|
||||
code: 200,
|
||||
searchPrompts: model.chat.systemPrompt
|
||||
? [
|
||||
{
|
||||
obj: ChatRoleEnum.System,
|
||||
value: model.chat.systemPrompt
|
||||
}
|
||||
]
|
||||
: []
|
||||
};
|
||||
}
|
||||
|
||||
/* 有匹配 */
|
||||
return {
|
||||
code: 200,
|
||||
searchPrompts: [
|
||||
{
|
||||
obj: ChatRoleEnum.System,
|
||||
value: `知识库:${filterSystemPrompt}`
|
||||
},
|
||||
...fixedPrompts
|
||||
]
|
||||
};
|
||||
};
|
||||
@@ -38,12 +38,14 @@ export const authUser = async ({
|
||||
req,
|
||||
authToken = false,
|
||||
authOpenApi = false,
|
||||
authRoot = false
|
||||
authRoot = false,
|
||||
authBalance = false
|
||||
}: {
|
||||
req: NextApiRequest;
|
||||
authToken?: boolean;
|
||||
authOpenApi?: boolean;
|
||||
authRoot?: boolean;
|
||||
authBalance?: boolean;
|
||||
}) => {
|
||||
const parseOpenApiKey = async (apiKey?: string) => {
|
||||
if (!apiKey) {
|
||||
@@ -99,6 +101,17 @@ export const authUser = async ({
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
|
||||
if (authBalance) {
|
||||
const user = await User.findById(uid);
|
||||
if (!user) {
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
|
||||
if (!user.openaiKey && formatPrice(user.balance) <= 0) {
|
||||
return Promise.reject(ERROR_ENUM.insufficientQuota);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
userId: uid
|
||||
};
|
||||
@@ -226,7 +239,7 @@ export const authChat = async ({
|
||||
req
|
||||
}: {
|
||||
modelId: string;
|
||||
chatId: '' | string;
|
||||
chatId?: string;
|
||||
req: NextApiRequest;
|
||||
}) => {
|
||||
const { userId } = await authUser({ req, authToken: true });
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
import { ChatCompletionType, StreamResponseType } from './index';
|
||||
import { ChatRoleEnum } from '@/constants/chat';
|
||||
import axios from 'axios';
|
||||
import mongoose from 'mongoose';
|
||||
import { NEW_CHATID_HEADER } from '@/constants/chat';
|
||||
|
||||
/* 模型对话 */
|
||||
export const claudChat = async ({ apiKey, messages, stream, chatId, res }: ChatCompletionType) => {
|
||||
const conversationId = chatId || String(new mongoose.Types.ObjectId());
|
||||
// create a new chat
|
||||
!chatId &&
|
||||
messages.filter((item) => item.obj === 'Human').length === 1 &&
|
||||
res?.setHeader(NEW_CHATID_HEADER, conversationId);
|
||||
|
||||
export const claudChat = async ({ apiKey, messages, stream, chatId }: ChatCompletionType) => {
|
||||
// get system prompt
|
||||
const systemPrompt = messages
|
||||
.filter((item) => item.obj === 'System')
|
||||
@@ -26,7 +18,7 @@ export const claudChat = async ({ apiKey, messages, stream, chatId, res }: ChatC
|
||||
{
|
||||
prompt,
|
||||
stream,
|
||||
conversationId
|
||||
conversationId: chatId
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
@@ -55,8 +47,7 @@ export const claudStreamResponse = async ({ res, chatResponse, prompts }: Stream
|
||||
try {
|
||||
const decoder = new TextDecoder();
|
||||
for await (const chunk of chatResponse.data as any) {
|
||||
if (!res.writable) {
|
||||
// 流被中断了,直接忽略后面的内容
|
||||
if (res.closed) {
|
||||
break;
|
||||
}
|
||||
const content = decoder.decode(chunk);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { ChatItemSimpleType } from '@/types/chat';
|
||||
import { modelToolMap } from '@/utils/chat';
|
||||
import { modelToolMap } from '@/utils/plugin';
|
||||
import type { ChatModelType } from '@/constants/model';
|
||||
import { ChatRoleEnum, SYSTEM_PROMPT_HEADER } from '@/constants/chat';
|
||||
import { ChatRoleEnum } from '@/constants/chat';
|
||||
import { OpenAiChatEnum, ClaudeEnum } from '@/constants/model';
|
||||
import { chatResponse, openAiStreamResponse } from './openai';
|
||||
import { claudChat, claudStreamResponse } from './claude';
|
||||
@@ -11,6 +11,7 @@ export type ChatCompletionType = {
|
||||
apiKey: string;
|
||||
temperature: number;
|
||||
messages: ChatItemSimpleType[];
|
||||
chatId?: string;
|
||||
[key: string]: any;
|
||||
};
|
||||
export type ChatCompletionResponseType = {
|
||||
@@ -23,7 +24,6 @@ export type StreamResponseType = {
|
||||
chatResponse: any;
|
||||
prompts: ChatItemSimpleType[];
|
||||
res: NextApiResponse;
|
||||
systemPrompt?: string;
|
||||
[key: string]: any;
|
||||
};
|
||||
export type StreamResponseReturnType = {
|
||||
@@ -129,7 +129,6 @@ export const resStreamResponse = async ({
|
||||
model,
|
||||
res,
|
||||
chatResponse,
|
||||
systemPrompt,
|
||||
prompts
|
||||
}: StreamResponseType & {
|
||||
model: ChatModelType;
|
||||
@@ -139,18 +138,14 @@ export const resStreamResponse = async ({
|
||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||
systemPrompt && res.setHeader(SYSTEM_PROMPT_HEADER, encodeURIComponent(systemPrompt));
|
||||
|
||||
const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[
|
||||
model
|
||||
].streamResponse({
|
||||
chatResponse,
|
||||
prompts,
|
||||
res,
|
||||
systemPrompt
|
||||
res
|
||||
});
|
||||
|
||||
res.end();
|
||||
|
||||
return { responseContent, totalTokens, finishMessages };
|
||||
};
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import { Configuration, OpenAIApi } from 'openai';
|
||||
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
|
||||
import { axiosConfig } from '../tools';
|
||||
import { ChatModelMap, embeddingModel, OpenAiChatEnum } from '@/constants/model';
|
||||
import { pushGenerateVectorBill } from '../../events/pushBill';
|
||||
import { adaptChatItem_openAI } from '@/utils/chat/openai';
|
||||
import { modelToolMap } from '@/utils/chat';
|
||||
import { ChatModelMap, OpenAiChatEnum } from '@/constants/model';
|
||||
import { adaptChatItem_openAI } from '@/utils/plugin/openai';
|
||||
import { modelToolMap } from '@/utils/plugin';
|
||||
import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index';
|
||||
import { ChatRoleEnum } from '@/constants/chat';
|
||||
import { getSystemOpenAiKey } from '../auth';
|
||||
|
||||
export const getOpenAIApi = () =>
|
||||
new OpenAIApi(
|
||||
@@ -16,51 +14,6 @@ export const getOpenAIApi = () =>
|
||||
})
|
||||
);
|
||||
|
||||
/* 获取向量 */
|
||||
export const openaiCreateEmbedding = async ({
|
||||
userOpenAiKey,
|
||||
userId,
|
||||
textArr
|
||||
}: {
|
||||
userOpenAiKey?: string;
|
||||
userId: string;
|
||||
textArr: string[];
|
||||
}) => {
|
||||
const systemAuthKey = getSystemOpenAiKey();
|
||||
|
||||
// 获取 chatAPI
|
||||
const chatAPI = getOpenAIApi();
|
||||
|
||||
// 把输入的内容转成向量
|
||||
const res = await chatAPI
|
||||
.createEmbedding(
|
||||
{
|
||||
model: embeddingModel,
|
||||
input: textArr
|
||||
},
|
||||
{
|
||||
timeout: 60000,
|
||||
...axiosConfig(userOpenAiKey || systemAuthKey)
|
||||
}
|
||||
)
|
||||
.then((res) => ({
|
||||
tokenLen: res.data.usage.total_tokens || 0,
|
||||
vectors: res.data.data.map((item) => item.embedding)
|
||||
}));
|
||||
|
||||
pushGenerateVectorBill({
|
||||
isPay: !userOpenAiKey,
|
||||
userId,
|
||||
text: textArr.join(''),
|
||||
tokenLen: res.tokenLen
|
||||
});
|
||||
|
||||
return {
|
||||
vectors: res.vectors,
|
||||
chatAPI
|
||||
};
|
||||
};
|
||||
|
||||
/* 模型对话 */
|
||||
export const chatResponse = async ({
|
||||
model,
|
||||
@@ -127,7 +80,7 @@ export const openAiStreamResponse = async ({
|
||||
const content: string = json?.choices?.[0].delta.content || '';
|
||||
responseContent += content;
|
||||
|
||||
res.writable && content && res.write(content);
|
||||
!res.closed && content && res.write(content);
|
||||
} catch (error) {
|
||||
error;
|
||||
}
|
||||
@@ -137,8 +90,7 @@ export const openAiStreamResponse = async ({
|
||||
const decoder = new TextDecoder();
|
||||
const parser = createParser(onParse);
|
||||
for await (const chunk of chatResponse.data as any) {
|
||||
if (!res.writable) {
|
||||
// 流被中断了,直接忽略后面的内容
|
||||
if (res.closed) {
|
||||
break;
|
||||
}
|
||||
parser.feed(decoder.decode(chunk, { stream: true }));
|
||||
|
||||
Reference in New Issue
Block a user