feat: lafClaude

This commit is contained in:
archer
2023-05-04 10:53:55 +08:00
parent 3c8f38799c
commit 0d6897e180
22 changed files with 327 additions and 231 deletions

View File

@@ -3,22 +3,20 @@ import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/
import { ModelSchema } from '@/types/mongoSchema';
import { openaiCreateEmbedding } from '../utils/chat/openai';
import { ChatRoleEnum } from '@/constants/chat';
import { sliceTextByToken } from '@/utils/chat';
import { modelToolMap } from '@/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
/**
* use openai embedding search kb
*/
export const searchKb = async ({
userApiKey,
systemApiKey,
userOpenAiKey,
prompts,
similarity = 0.2,
model,
userId
}: {
userApiKey?: string;
systemApiKey: string;
userOpenAiKey?: string;
prompts: ChatItemSimpleType[];
model: ModelSchema;
userId: string;
@@ -33,8 +31,7 @@ export const searchKb = async ({
async function search(textArr: string[] = []) {
// 获取提示词的向量
const { vectors: promptVectors } = await openaiCreateEmbedding({
userApiKey,
systemApiKey,
userOpenAiKey,
userId,
textArr
});
@@ -81,11 +78,24 @@ export const searchKb = async ({
].filter((item) => item);
const systemPrompts = await search(searchArr);
// filter system prompt
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
// filter system prompts.
const filterRateMap: Record<number, number[]> = {
1: [1],
2: [0.7, 0.3]
};
const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
const filterSystemPrompt = filterRate
.map((rate, i) =>
modelToolMap[model.chat.chatModel].sliceText({
text: systemPrompts[i],
length: Math.floor(modelConstantsData.systemMaxToken * rate)
})
)
.join('\n');
/* 高相似度+不回复 */
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
return {
code: 201,
searchPrompt: {
@@ -95,7 +105,7 @@ export const searchKb = async ({
};
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (systemPrompts.length === 0 && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
return {
code: 200,
searchPrompt: model.chat.systemPrompt
@@ -107,25 +117,7 @@ export const searchKb = async ({
};
}
/* 有匹配情况下system 添加知识库内容。 */
// filter system prompts. max 70% tokens
const filterRateMap: Record<number, number[]> = {
1: [0.7],
2: [0.5, 0.2]
};
const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
const filterSystemPrompt = filterRate
.map((rate, i) =>
sliceTextByToken({
model: model.chat.chatModel,
text: systemPrompts[i],
length: Math.floor(modelConstantsData.contextMaxToken * rate)
})
)
.join('\n');
/* 有匹配 */
return {
code: 200,
searchPrompt: {
@@ -133,9 +125,9 @@ export const searchKb = async ({
value: `
${model.chat.systemPrompt}
${
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : ''
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? '不回答知识库外的内容.' : ''
}
知识库内容为: ${filterSystemPrompt}'
知识库内容为: '${filterSystemPrompt}'
`
}
};