feat: 替换redis搜索

This commit is contained in:
archer
2023-04-19 12:00:28 +08:00
parent 867d69659f
commit 1e5714da1b
12 changed files with 147 additions and 228 deletions

View File

@@ -7,12 +7,15 @@ import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } fr
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { ChatModelNameEnum, modelList, ChatModelNameMap } from '@/constants/model';
import {
ChatModelNameEnum,
modelList,
ChatModelNameMap,
ModelVectorSearchModeMap
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import { PgClient } from '@/service/pg';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -46,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
await connectToDatabase();
const redis = await connectRedis();
let startTime = Date.now();
/* 凭证校验 */
@@ -144,39 +146,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容
const prompts = [prompt];
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataPrefix}:hash`,
`@modelId:{${String(model._id)}}=>[KNN 20 @vector $blob AS score]`,
'RETURN',
'1',
'text',
'SORTBY',
'score',
'PARAMS',
'2',
'blob',
vectorToBuffer(promptVector),
'DIALECT',
'2'
]);
// 相似度搜索
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
where: [
['model_id', model._id],
'AND',
['user_id', userId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
limit: 30
});
// 格式化响应值,获取 qa
const formatRedisPrompt: string[] = [];
for (let i = 2; i < 42; i += 2) {
const text = redisData[i]?.[1];
if (text) {
formatRedisPrompt.push(text);
}
}
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
prompts.unshift({
obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库内容是最新的知识库内容为: "${systemPrompt}"`
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
});
// 控制在 tokens 数量,防止超出