feat: default model (#3662)

* move model config

* feat: default model
This commit is contained in:
Archer
2025-01-24 18:44:43 +08:00
committed by archer
parent 2015bbe9a9
commit 51fac7431f
167 changed files with 2999 additions and 2899 deletions

View File

@@ -11,8 +11,7 @@ import {
ReRankModelItemType
} from '@fastgpt/global/core/ai/model.d';
import { debounce } from 'lodash';
type FolderBaseType = `${ModelTypeEnum}`;
import { ModelProviderType } from '@fastgpt/global/core/ai/provider';
/*
TODO: 分优先级读取:
@@ -20,9 +19,9 @@ type FolderBaseType = `${ModelTypeEnum}`;
2. 没有外部挂载目录,则读取本地的。然后试图拉取云端的进行覆盖。
*/
export const loadSystemModels = async (init = false) => {
const getModelNameList = (base: FolderBaseType) => {
const getProviderList = () => {
const currentFileUrl = new URL(import.meta.url);
const modelsPath = path.join(path.dirname(currentFileUrl.pathname), base);
const modelsPath = path.join(path.dirname(currentFileUrl.pathname), 'provider');
return fs.readdirSync(modelsPath) as string[];
};
@@ -35,18 +34,33 @@ export const loadSystemModels = async (init = false) => {
if (model.type === ModelTypeEnum.llm) {
global.llmModelMap.set(model.model, model);
global.llmModelMap.set(model.name, model);
if (model.isDefault) {
global.systemDefaultModel.llm = model;
}
} else if (model.type === ModelTypeEnum.embedding) {
global.embeddingModelMap.set(model.model, model);
global.embeddingModelMap.set(model.name, model);
if (model.isDefault) {
global.systemDefaultModel.embedding = model;
}
} else if (model.type === ModelTypeEnum.tts) {
global.ttsModelMap.set(model.model, model);
global.ttsModelMap.set(model.name, model);
if (model.isDefault) {
global.systemDefaultModel.tts = model;
}
} else if (model.type === ModelTypeEnum.stt) {
global.sttModelMap.set(model.model, model);
global.sttModelMap.set(model.name, model);
if (model.isDefault) {
global.systemDefaultModel.stt = model;
}
} else if (model.type === ModelTypeEnum.rerank) {
global.reRankModelMap.set(model.model, model);
global.reRankModelMap.set(model.name, model);
if (model.isDefault) {
global.systemDefaultModel.rerank = model;
}
}
}
};
@@ -60,40 +74,34 @@ export const loadSystemModels = async (init = false) => {
global.ttsModelMap = new Map<string, TTSModelType>();
global.sttModelMap = new Map<string, STTModelType>();
global.reRankModelMap = new Map<string, ReRankModelItemType>();
// @ts-ignore
global.systemDefaultModel = {};
try {
const dbModels = await MongoSystemModel.find({}).lean();
const baseList: FolderBaseType[] = [
ModelTypeEnum.llm,
ModelTypeEnum.embedding,
ModelTypeEnum.tts,
ModelTypeEnum.stt,
ModelTypeEnum.rerank
];
const providerList = getProviderList();
// System model
await Promise.all(
baseList.map(async (base) => {
const modelList = getModelNameList(base);
const nameList = modelList.map((name) => `${base}/${name}`);
providerList.map(async (name) => {
const fileContent = (await import(`./provider/${name}`))?.default as {
provider: ModelProviderType;
list: SystemModelItemType[];
};
await Promise.all(
nameList.map(async (name) => {
const fileContent = (await import(`./${name}`))?.default as SystemModelItemType;
fileContent.list.forEach((fileModel) => {
const dbModel = dbModels.find((item) => item.model === fileModel.model);
const dbModel = dbModels.find((item) => item.model === fileContent.model);
const modelData: any = {
...fileModel,
...dbModel?.metadata,
provider: fileContent.provider,
type: dbModel?.metadata?.type || fileModel.type,
isCustom: false
};
const model: any = {
...fileContent,
...dbModel?.metadata,
type: dbModel?.metadata?.type || base,
isCustom: false
};
pushModel(model);
})
);
pushModel(modelData);
});
})
);
@@ -107,6 +115,23 @@ export const loadSystemModels = async (init = false) => {
});
});
// Default model check
if (!global.systemDefaultModel.llm) {
global.systemDefaultModel.llm = Array.from(global.llmModelMap.values())[0];
}
if (!global.systemDefaultModel.embedding) {
global.systemDefaultModel.embedding = Array.from(global.embeddingModelMap.values())[0];
}
if (!global.systemDefaultModel.tts) {
global.systemDefaultModel.tts = Array.from(global.ttsModelMap.values())[0];
}
if (!global.systemDefaultModel.stt) {
global.systemDefaultModel.stt = Array.from(global.sttModelMap.values())[0];
}
if (!global.systemDefaultModel.rerank) {
global.systemDefaultModel.rerank = Array.from(global.reRankModelMap.values())[0];
}
console.log('Load models success', JSON.stringify(global.systemActiveModelList, null, 2));
} catch (error) {
console.error('Load models error', error);