feat: use last quote

This commit is contained in:
archer
2023-05-30 21:18:08 +08:00
parent 59ddf09b94
commit 0cde9a10a8
7 changed files with 86 additions and 81 deletions

View File

@@ -50,6 +50,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容
const prompts = [...content, prompt[0]];
const {
code = 200,
systemPrompts = [],
@@ -61,7 +62,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({
model,
userId,
prompts,
fixedQuote: content[content.length - 1]?.quote || [],
prompt: prompt[0],
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
});
@@ -114,7 +116,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
return res.end(response);
}
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
prompts.unshift(...systemPrompts);
// content check
await sensitiveCheck({

View File

@@ -47,7 +47,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { code, searchPrompts } = await appKbSearch({
model,
userId,
prompts,
fixedQuote: [],
prompt: prompts[prompts.length - 1],
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
});
@@ -74,7 +75,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
return res.send(systemPrompts[0]?.value);
}
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
prompts.unshift(...systemPrompts);
// content check
await sensitiveCheck({

View File

@@ -75,10 +75,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
// 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts } = await appKbSearch({
prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
model,
userId
userId,
fixedQuote: [],
prompt: prompts[prompts.length - 1],
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
});
// search result is empty
@@ -101,7 +102,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
];
}
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
prompts.unshift(...systemPrompts);
// content check
await sensitiveCheck({

View File

@@ -49,10 +49,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
});
const result = await appKbSearch({
model,
userId,
prompts,
similarity,
model
fixedQuote: [],
prompt: prompts[prompts.length - 1],
similarity
});
jsonRes<Response>(res, {
@@ -70,67 +71,53 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
export async function appKbSearch({
model,
userId,
prompts,
fixedQuote,
prompt,
similarity
}: {
userId: string;
prompts: ChatItemSimpleType[];
similarity: number;
model: ModelSchema;
userId: string;
fixedQuote: QuoteItemType[];
prompt: ChatItemSimpleType;
similarity: number;
}): Promise<Response> {
const modelConstantsData = ChatModelMap[model.chat.chatModel];
// search two times.
const userPrompts = prompts.filter((item) => item.obj === 'Human');
const input: string[] = [
userPrompts[userPrompts.length - 1].value,
userPrompts[userPrompts.length - 2]?.value
].filter((item) => item);
// get vector
const promptVectors = await openaiEmbedding({
const promptVector = await openaiEmbedding({
userId,
input,
input: [prompt.value],
type: 'chat'
});
// search kb
const searchRes = await Promise.all(
promptVectors.map((promptVector) =>
PgClient.select<QuoteItemType>('modelData', {
fields: ['id', 'q', 'a'],
where: [
`kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: promptVectors.length === 1 ? 15 : 10
}).then((res) => res.rows)
)
);
const { rows: searchRes } = await PgClient.select<QuoteItemType>('modelData', {
fields: ['id', 'q', 'a'],
where: [
`kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
'AND',
`vector <=> '[${promptVector[0]}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector[0]}]'` }],
limit: 8
});
// filter same search result
const idSet = new Set<string>();
const filterSearch = searchRes.map((search) =>
search.filter((item) => {
if (idSet.has(item.id)) {
return false;
}
idSet.add(item.id);
return true;
})
);
const filterSearch = [
...searchRes.slice(0, 3),
...fixedQuote.slice(0, 2),
...searchRes.slice(3),
...fixedQuote.slice(2, 5)
].filter((item) => {
if (idSet.has(item.id)) {
return false;
}
idSet.add(item.id);
return true;
});
// slice search result by rate.
const sliceRateMap: Record<number, number[]> = {
1: [1],
2: [0.7, 0.3]
};
const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0];
// 计算固定提示词的 token 数量
const guidePrompt = model.chat.systemPrompt // user system prompt
? {
obj: ChatRoleEnum.System,
@@ -154,24 +141,21 @@ export async function appKbSearch({
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
messages: [guidePrompt]
});
const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens;
const sliceResult = sliceRate.map((rate, i) =>
modelToolMap[model.chat.chatModel]
.tokenSlice({
maxToken: Math.round(maxTokens * rate),
messages: filterSearch[i].map((item) => ({
obj: ChatRoleEnum.System,
value: `${item.q}\n${item.a}`
}))
})
.map((item) => item.value)
);
const sliceResult = modelToolMap[model.chat.chatModel]
.tokenSlice({
maxToken: modelConstantsData.systemMaxToken - fixedSystemTokens,
messages: filterSearch.map((item) => ({
obj: ChatRoleEnum.System,
value: `${item.q}\n${item.a}`
}))
})
.map((item) => item.value);
// slice filterSearch
const sliceSearch = filterSearch.map((item, i) => item.slice(0, sliceResult[i].length)).flat();
const rawSearch = filterSearch.slice(0, sliceResult.length);
// system prompt
const systemPrompt = sliceResult.flat().join('\n').trim();
const systemPrompt = sliceResult.join('\n').trim();
/* 高相似度+不回复 */
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
@@ -206,7 +190,7 @@ export async function appKbSearch({
return {
code: 200,
rawSearch: sliceSearch,
rawSearch,
guidePrompt: guidePrompt.value || '',
searchPrompts: [
{