feat: rerank modal select and weight (#4164)
This commit is contained in:
@@ -8,7 +8,7 @@ import {
|
||||
EmbeddingModelItemType,
|
||||
TTSModelType,
|
||||
STTModelType,
|
||||
ReRankModelItemType
|
||||
RerankModelItemType
|
||||
} from '@fastgpt/global/core/ai/model.d';
|
||||
import { debounce } from 'lodash';
|
||||
import {
|
||||
@@ -94,7 +94,7 @@ export const loadSystemModels = async (init = false) => {
|
||||
global.embeddingModelMap = new Map<string, EmbeddingModelItemType>();
|
||||
global.ttsModelMap = new Map<string, TTSModelType>();
|
||||
global.sttModelMap = new Map<string, STTModelType>();
|
||||
global.reRankModelMap = new Map<string, ReRankModelItemType>();
|
||||
global.reRankModelMap = new Map<string, RerankModelItemType>();
|
||||
// @ts-ignore
|
||||
global.systemDefaultModel = {};
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ export function getSTTModel(model?: string) {
|
||||
}
|
||||
|
||||
export const getDefaultRerankModel = () => global?.systemDefaultModel.rerank!;
|
||||
export function getReRankModel(model?: string) {
|
||||
export function getRerankModel(model?: string) {
|
||||
if (!model) return getDefaultRerankModel();
|
||||
return global.reRankModelMap.get(model) || getDefaultRerankModel();
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { addLog } from '../../../common/system/log';
|
||||
import { POST } from '../../../common/api/serverRequest';
|
||||
import { getDefaultRerankModel } from '../model';
|
||||
import { getAxiosConfig } from '../config';
|
||||
import { ReRankModelItemType } from '@fastgpt/global/core/ai/model.d';
|
||||
import { RerankModelItemType } from '@fastgpt/global/core/ai/model.d';
|
||||
|
||||
type PostReRankResponse = {
|
||||
id: string;
|
||||
@@ -19,7 +19,7 @@ export function reRankRecall({
|
||||
documents,
|
||||
headers
|
||||
}: {
|
||||
model?: ReRankModelItemType;
|
||||
model?: RerankModelItemType;
|
||||
query: string;
|
||||
documents: { id: string; text: string }[];
|
||||
headers?: Record<string, string>;
|
||||
|
||||
8
packages/service/core/ai/type.d.ts
vendored
8
packages/service/core/ai/type.d.ts
vendored
@@ -1,7 +1,7 @@
|
||||
import { ModelTypeEnum } from '@fastgpt/global/core/ai/model';
|
||||
import {
|
||||
STTModelType,
|
||||
ReRankModelItemType,
|
||||
RerankModelItemType,
|
||||
TTSModelType,
|
||||
EmbeddingModelItemType,
|
||||
LLMModelItemType
|
||||
@@ -18,7 +18,7 @@ export type SystemModelItemType =
|
||||
| EmbeddingModelItemType
|
||||
| TTSModelType
|
||||
| STTModelType
|
||||
| ReRankModelItemType;
|
||||
| RerankModelItemType;
|
||||
|
||||
export type SystemDefaultModelType = {
|
||||
[ModelTypeEnum.llm]?: LLMModelItemType;
|
||||
@@ -28,7 +28,7 @@ export type SystemDefaultModelType = {
|
||||
[ModelTypeEnum.embedding]?: EmbeddingModelItemType;
|
||||
[ModelTypeEnum.tts]?: TTSModelType;
|
||||
[ModelTypeEnum.stt]?: STTModelType;
|
||||
[ModelTypeEnum.rerank]?: ReRankModelItemType;
|
||||
[ModelTypeEnum.rerank]?: RerankModelItemType;
|
||||
};
|
||||
|
||||
declare global {
|
||||
@@ -38,7 +38,7 @@ declare global {
|
||||
var embeddingModelMap: Map<string, EmbeddingModelItemType>;
|
||||
var ttsModelMap: Map<string, TTSModelType>;
|
||||
var sttModelMap: Map<string, STTModelType>;
|
||||
var reRankModelMap: Map<string, ReRankModelItemType>;
|
||||
var reRankModelMap: Map<string, RerankModelItemType>;
|
||||
|
||||
var systemActiveModelList: SystemModelItemType[];
|
||||
var systemDefaultModel: SystemDefaultModelType;
|
||||
|
||||
@@ -27,6 +27,7 @@ import { ChatItemType } from '@fastgpt/global/core/chat/type';
|
||||
import { POST } from '../../../common/api/plusRequest';
|
||||
import { NodeInputKeyEnum } from '@fastgpt/global/core/workflow/constants';
|
||||
import { datasetSearchQueryExtension } from './utils';
|
||||
import type { RerankModelItemType } from '@fastgpt/global/core/ai/model.d';
|
||||
|
||||
export type SearchDatasetDataProps = {
|
||||
histories: ChatItemType[];
|
||||
@@ -39,7 +40,10 @@ export type SearchDatasetDataProps = {
|
||||
[NodeInputKeyEnum.datasetSimilarity]?: number; // min distance
|
||||
[NodeInputKeyEnum.datasetMaxTokens]: number; // max Token limit
|
||||
[NodeInputKeyEnum.datasetSearchMode]?: `${DatasetSearchModeEnum}`;
|
||||
|
||||
[NodeInputKeyEnum.datasetSearchUsingReRank]?: boolean;
|
||||
[NodeInputKeyEnum.datasetSearchRerankModel]?: RerankModelItemType;
|
||||
[NodeInputKeyEnum.datasetSearchRerankWeight]?: number;
|
||||
|
||||
/*
|
||||
{
|
||||
@@ -75,13 +79,16 @@ export type SearchDatasetDataResponse = {
|
||||
};
|
||||
|
||||
export const datasetDataReRank = async ({
|
||||
rerankModel,
|
||||
data,
|
||||
query
|
||||
}: {
|
||||
rerankModel?: RerankModelItemType;
|
||||
data: SearchDataResponseItemType[];
|
||||
query: string;
|
||||
}): Promise<SearchDataResponseItemType[]> => {
|
||||
const results = await reRankRecall({
|
||||
model: rerankModel,
|
||||
query,
|
||||
documents: data.map((item) => ({
|
||||
id: item.id,
|
||||
@@ -155,6 +162,8 @@ export async function searchDatasetData(
|
||||
limit: maxTokens,
|
||||
searchMode = DatasetSearchModeEnum.embedding,
|
||||
usingReRank = false,
|
||||
rerankModel,
|
||||
rerankWeight = 0.5,
|
||||
datasetIds = [],
|
||||
collectionFilterMatch
|
||||
} = props;
|
||||
@@ -711,6 +720,7 @@ export async function searchDatasetData(
|
||||
});
|
||||
try {
|
||||
return await datasetDataReRank({
|
||||
rerankModel,
|
||||
query: reRankQuery,
|
||||
data: filterSameDataResults
|
||||
});
|
||||
@@ -721,11 +731,22 @@ export async function searchDatasetData(
|
||||
})();
|
||||
|
||||
// embedding recall and fullText recall rrf concat
|
||||
const rrfConcatResults = datasetSearchResultConcat([
|
||||
const rrfSearchResult = datasetSearchResultConcat([
|
||||
{ k: 60, list: embeddingRecallResults },
|
||||
{ k: 60, list: fullTextRecallResults },
|
||||
{ k: 58, list: reRankResults }
|
||||
{ k: 60, list: fullTextRecallResults }
|
||||
]);
|
||||
const rrfConcatResults = (() => {
|
||||
if (rerankWeight === 1) return reRankResults;
|
||||
|
||||
const baseK = 30;
|
||||
const searchK = Math.round(baseK / (1 - rerankWeight)); // 搜索结果的 k 值
|
||||
const rerankK = Math.round(baseK / rerankWeight); // rerank 结果的 k 值
|
||||
|
||||
return datasetSearchResultConcat([
|
||||
{ k: searchK, list: rrfSearchResult },
|
||||
{ k: rerankK, list: reRankResults }
|
||||
]);
|
||||
})();
|
||||
|
||||
// remove same q and a data
|
||||
set = new Set<string>();
|
||||
|
||||
@@ -6,7 +6,7 @@ import { formatModelChars2Points } from '../../../../support/wallet/usage/utils'
|
||||
import type { SelectedDatasetType } from '@fastgpt/global/core/workflow/api.d';
|
||||
import type { SearchDataResponseItemType } from '@fastgpt/global/core/dataset/type';
|
||||
import type { ModuleDispatchProps } from '@fastgpt/global/core/workflow/runtime/type';
|
||||
import { getEmbeddingModel } from '../../../ai/model';
|
||||
import { getEmbeddingModel, getRerankModel } from '../../../ai/model';
|
||||
import { deepRagSearch, defaultSearchDatasetData } from '../../../dataset/search/controller';
|
||||
import { NodeInputKeyEnum, NodeOutputKeyEnum } from '@fastgpt/global/core/workflow/constants';
|
||||
import { DispatchNodeResponseKeyEnum } from '@fastgpt/global/core/workflow/runtime/constants';
|
||||
@@ -24,7 +24,11 @@ type DatasetSearchProps = ModuleDispatchProps<{
|
||||
[NodeInputKeyEnum.datasetMaxTokens]: number;
|
||||
[NodeInputKeyEnum.datasetSearchMode]: `${DatasetSearchModeEnum}`;
|
||||
[NodeInputKeyEnum.userChatInput]?: string;
|
||||
|
||||
[NodeInputKeyEnum.datasetSearchUsingReRank]: boolean;
|
||||
[NodeInputKeyEnum.datasetSearchRerankModel]?: string;
|
||||
[NodeInputKeyEnum.datasetSearchRerankWeight]?: number;
|
||||
|
||||
[NodeInputKeyEnum.collectionFilterMatch]: string;
|
||||
[NodeInputKeyEnum.authTmbId]?: boolean;
|
||||
|
||||
@@ -53,12 +57,15 @@ export async function dispatchDatasetSearch(
|
||||
datasets = [],
|
||||
similarity,
|
||||
limit = 1500,
|
||||
usingReRank,
|
||||
searchMode,
|
||||
userChatInput = '',
|
||||
authTmbId = false,
|
||||
collectionFilterMatch,
|
||||
|
||||
usingReRank,
|
||||
rerankModel,
|
||||
rerankWeight,
|
||||
|
||||
datasetSearchUsingExtensionQuery,
|
||||
datasetSearchExtensionModel,
|
||||
datasetSearchExtensionBg,
|
||||
@@ -123,6 +130,8 @@ export async function dispatchDatasetSearch(
|
||||
datasetIds,
|
||||
searchMode,
|
||||
usingReRank: usingReRank && (await checkTeamReRankPermission(teamId)),
|
||||
rerankModel: getRerankModel(rerankModel),
|
||||
rerankWeight,
|
||||
collectionFilterMatch
|
||||
};
|
||||
const {
|
||||
|
||||
Reference in New Issue
Block a user