perf: model framework
This commit is contained in:
@@ -1,23 +1,17 @@
|
||||
import { connectToDatabase, Bill, User } from '../mongo';
|
||||
import {
|
||||
modelList,
|
||||
ChatModelEnum,
|
||||
ModelNameEnum,
|
||||
Model2ChatModelMap,
|
||||
embeddingModel
|
||||
} from '@/constants/model';
|
||||
import { modelList, ChatModelEnum, embeddingModel } from '@/constants/model';
|
||||
import { BillTypeEnum } from '@/constants/user';
|
||||
import { countChatTokens } from '@/utils/tools';
|
||||
|
||||
export const pushChatBill = async ({
|
||||
isPay,
|
||||
modelName,
|
||||
chatModel,
|
||||
userId,
|
||||
chatId,
|
||||
messages
|
||||
}: {
|
||||
isPay: boolean;
|
||||
modelName: `${ModelNameEnum}`;
|
||||
chatModel: `${ChatModelEnum}`;
|
||||
userId: string;
|
||||
chatId?: '' | string;
|
||||
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
|
||||
@@ -26,7 +20,7 @@ export const pushChatBill = async ({
|
||||
|
||||
try {
|
||||
// 计算 token 数量
|
||||
const tokens = countChatTokens({ model: Model2ChatModelMap[modelName] as any, messages });
|
||||
const tokens = countChatTokens({ model: chatModel, messages });
|
||||
const text = messages.map((item) => item.content).join('');
|
||||
|
||||
console.log(
|
||||
@@ -37,7 +31,7 @@ export const pushChatBill = async ({
|
||||
await connectToDatabase();
|
||||
|
||||
// 获取模型单价格
|
||||
const modelItem = modelList.find((item) => item.model === modelName);
|
||||
const modelItem = modelList.find((item) => item.chatModel === chatModel);
|
||||
// 计算价格
|
||||
const unitPrice = modelItem?.price || 5;
|
||||
const price = unitPrice * tokens;
|
||||
@@ -47,7 +41,7 @@ export const pushChatBill = async ({
|
||||
const res = await Bill.create({
|
||||
userId,
|
||||
type: 'chat',
|
||||
modelName,
|
||||
modelName: chatModel,
|
||||
chatId: chatId ? chatId : undefined,
|
||||
textLen: text.length,
|
||||
tokenLen: tokens,
|
||||
@@ -94,7 +88,7 @@ export const pushSplitDataBill = async ({
|
||||
if (isPay) {
|
||||
try {
|
||||
// 获取模型单价格, 都是用 gpt35 拆分
|
||||
const modelItem = modelList.find((item) => item.model === ChatModelEnum.GPT35);
|
||||
const modelItem = modelList.find((item) => item.chatModel === ChatModelEnum.GPT35);
|
||||
const unitPrice = modelItem?.price || 3;
|
||||
// 计算价格
|
||||
const price = unitPrice * tokenLen;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Schema, model, models, Model } from 'mongoose';
|
||||
import { modelList } from '@/constants/model';
|
||||
import { ChatModelMap } from '@/constants/model';
|
||||
import { BillSchema as BillType } from '@/types/mongoSchema';
|
||||
import { BillTypeMap } from '@/constants/user';
|
||||
|
||||
@@ -16,7 +16,7 @@ const BillSchema = new Schema({
|
||||
},
|
||||
modelName: {
|
||||
type: String,
|
||||
enum: [...modelList.map((item) => item.model), 'text-embedding-ada-002'],
|
||||
enum: [...Object.keys(ChatModelMap), 'text-embedding-ada-002'],
|
||||
required: true
|
||||
},
|
||||
chatId: {
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import { Schema, model, models, Model as MongoModel } from 'mongoose';
|
||||
import { ModelSchema as ModelType } from '@/types/mongoSchema';
|
||||
import { ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
||||
import {
|
||||
ModelVectorSearchModeMap,
|
||||
ModelVectorSearchModeEnum,
|
||||
ChatModelMap,
|
||||
ChatModelEnum
|
||||
} from '@/constants/model';
|
||||
|
||||
const ModelSchema = new Schema({
|
||||
userId: {
|
||||
@@ -16,11 +21,6 @@ const ModelSchema = new Schema({
|
||||
type: String,
|
||||
default: '/icon/logo.png'
|
||||
},
|
||||
systemPrompt: {
|
||||
// 系统提示词
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
status: {
|
||||
type: String,
|
||||
required: true,
|
||||
@@ -30,17 +30,34 @@ const ModelSchema = new Schema({
|
||||
type: Date,
|
||||
default: () => new Date()
|
||||
},
|
||||
temperature: {
|
||||
type: Number,
|
||||
min: 0,
|
||||
max: 10,
|
||||
default: 4
|
||||
},
|
||||
search: {
|
||||
mode: {
|
||||
chat: {
|
||||
useKb: {
|
||||
// use knowledge base to search
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
searchMode: {
|
||||
// knowledge base search mode
|
||||
type: String,
|
||||
enum: Object.keys(ModelVectorSearchModeMap),
|
||||
default: ModelVectorSearchModeEnum.hightSimilarity
|
||||
},
|
||||
systemPrompt: {
|
||||
// 系统提示词
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
temperature: {
|
||||
type: Number,
|
||||
min: 0,
|
||||
max: 10,
|
||||
default: 0
|
||||
},
|
||||
chatModel: {
|
||||
// 聊天时使用的模型
|
||||
type: String,
|
||||
enum: Object.keys(ChatModelMap),
|
||||
default: ChatModelEnum.GPT35
|
||||
}
|
||||
},
|
||||
share: {
|
||||
@@ -63,18 +80,6 @@ const ModelSchema = new Schema({
|
||||
default: 0
|
||||
}
|
||||
},
|
||||
service: {
|
||||
chatModel: {
|
||||
// 聊天时使用的模型
|
||||
type: String,
|
||||
required: true
|
||||
},
|
||||
modelName: {
|
||||
// 底层模型的名称
|
||||
type: String,
|
||||
required: true
|
||||
}
|
||||
},
|
||||
security: {
|
||||
type: {
|
||||
domain: {
|
||||
@@ -100,8 +105,7 @@ const ModelSchema = new Schema({
|
||||
default: -1
|
||||
}
|
||||
},
|
||||
default: {},
|
||||
required: true
|
||||
default: {}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
47
src/service/tools/searchKb.ts
Normal file
47
src/service/tools/searchKb.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import { openaiCreateEmbedding } from '../utils/openai';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { ModelDataStatusEnum } from '@/constants/model';
|
||||
|
||||
/**
|
||||
* use openai embedding search kb
|
||||
*/
|
||||
export const searchKb_openai = async ({
|
||||
apiKey,
|
||||
isPay,
|
||||
text,
|
||||
similarity,
|
||||
modelId,
|
||||
userId
|
||||
}: {
|
||||
apiKey: string;
|
||||
isPay: boolean;
|
||||
text: string;
|
||||
modelId: string;
|
||||
userId: string;
|
||||
similarity: number;
|
||||
}) => {
|
||||
// 获取提示词的向量
|
||||
const { vector: promptVector } = await openaiCreateEmbedding({
|
||||
isPay,
|
||||
apiKey,
|
||||
userId,
|
||||
text
|
||||
});
|
||||
|
||||
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||
fields: ['id', 'q', 'a'],
|
||||
where: [
|
||||
['status', ModelDataStatusEnum.ready],
|
||||
'AND',
|
||||
['model_id', modelId],
|
||||
'AND',
|
||||
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||
],
|
||||
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||
limit: 20
|
||||
});
|
||||
|
||||
const systemPrompts: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||
|
||||
return { systemPrompts };
|
||||
};
|
||||
@@ -1,10 +1,33 @@
|
||||
import { Configuration, OpenAIApi } from 'openai';
|
||||
import { Chat, Model } from '../mongo';
|
||||
import type { NextApiRequest } from 'next';
|
||||
import jwt from 'jsonwebtoken';
|
||||
import { Chat, Model, OpenApi, User } from '../mongo';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { authToken } from './tools';
|
||||
import { getOpenApiKey } from './openai';
|
||||
import type { ChatItemType } from '@/types/chat';
|
||||
import mongoose from 'mongoose';
|
||||
import { defaultModel } from '@/constants/model';
|
||||
import { formatPrice } from '@/utils/user';
|
||||
import { ERROR_ENUM } from '../errorCode';
|
||||
|
||||
/* 校验 token */
|
||||
export const authToken = (token?: string): Promise<string> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (!token) {
|
||||
reject('缺少登录凭证');
|
||||
return;
|
||||
}
|
||||
const key = process.env.TOKEN_KEY as string;
|
||||
|
||||
jwt.verify(token, key, function (err, decoded: any) {
|
||||
if (err || !decoded?.userId) {
|
||||
reject('凭证无效');
|
||||
return;
|
||||
}
|
||||
resolve(decoded.userId);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
export const getOpenAIApi = (apiKey: string) => {
|
||||
const configuration = new Configuration({
|
||||
@@ -20,12 +43,14 @@ export const authModel = async ({
|
||||
modelId,
|
||||
userId,
|
||||
authUser = true,
|
||||
authOwner = true
|
||||
authOwner = true,
|
||||
reserveDetail = false
|
||||
}: {
|
||||
modelId: string;
|
||||
userId: string;
|
||||
authUser?: boolean;
|
||||
authOwner?: boolean;
|
||||
reserveDetail?: boolean; // focus reserve detail
|
||||
}) => {
|
||||
// 获取 model 数据
|
||||
const model = await Model.findById<ModelSchema>(modelId);
|
||||
@@ -33,15 +58,21 @@ export const authModel = async ({
|
||||
return Promise.reject('模型不存在');
|
||||
}
|
||||
|
||||
// 使用权限校验
|
||||
/*
|
||||
Access verification
|
||||
1. authOwner = true, or authUser = true and the model is not shared: only the owner can use
|
||||
2. otherwise (authOwner = false, and the model is shared or authUser = false): anyone can use
|
||||
*/
|
||||
if ((authOwner || (authUser && !model.share.isShare)) && userId !== String(model.userId)) {
|
||||
return Promise.reject('无权操作该模型');
|
||||
}
|
||||
|
||||
// detail 内容去除
|
||||
if (!model.share.isShareDetail && userId !== String(model.userId)) {
|
||||
model.systemPrompt = '';
|
||||
model.temperature = 0;
|
||||
// do not share detail info
|
||||
if (!reserveDetail && !model.share.isShareDetail && userId !== String(model.userId)) {
|
||||
model.chat = {
|
||||
...defaultModel.chat,
|
||||
chatModel: model.chat.chatModel
|
||||
};
|
||||
}
|
||||
|
||||
return { model };
|
||||
@@ -60,7 +91,7 @@ export const authChat = async ({
|
||||
const userId = await authToken(authorization);
|
||||
|
||||
// 获取 model 数据
|
||||
const { model } = await authModel({ modelId, userId, authOwner: false });
|
||||
const { model } = await authModel({ modelId, userId, authOwner: false, reserveDetail: true });
|
||||
|
||||
// 聊天内容
|
||||
let content: ChatItemType[] = [];
|
||||
@@ -91,3 +122,41 @@ export const authChat = async ({
|
||||
model
|
||||
};
|
||||
};
|
||||
|
||||
/* 校验 open api key */
|
||||
export const authOpenApiKey = async (req: NextApiRequest) => {
|
||||
const { apikey: apiKey } = req.headers;
|
||||
|
||||
if (!apiKey) {
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
|
||||
try {
|
||||
const openApi = await OpenApi.findOne({ apiKey });
|
||||
if (!openApi) {
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
const userId = String(openApi.userId);
|
||||
|
||||
// 余额校验
|
||||
const user = await User.findById(userId);
|
||||
if (!user) {
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
if (formatPrice(user.balance) <= 0) {
|
||||
return Promise.reject(ERROR_ENUM.insufficientQuota);
|
||||
}
|
||||
|
||||
// 更新使用的时间
|
||||
await OpenApi.findByIdAndUpdate(openApi._id, {
|
||||
lastUsedTime: new Date()
|
||||
});
|
||||
|
||||
return {
|
||||
apiKey: process.env.OPENAIKEY as string,
|
||||
userId
|
||||
};
|
||||
} catch (error) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import * as nodemailer from 'nodemailer';
|
||||
import { UserAuthTypeEnum } from '@/constants/common';
|
||||
import dayjs from 'dayjs';
|
||||
import Dysmsapi, * as dysmsapi from '@alicloud/dysmsapi20170525';
|
||||
// @ts-ignore
|
||||
import * as OpenApi from '@alicloud/openapi-client';
|
||||
@@ -48,25 +47,6 @@ export const sendEmailCode = (email: string, code: string, type: `${UserAuthType
|
||||
});
|
||||
};
|
||||
|
||||
export const sendTrainSucceed = (email: string, modelName: string) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const options = {
|
||||
from: `"FastGPT" ${myEmail}`,
|
||||
to: email,
|
||||
subject: '模型训练完成通知',
|
||||
html: `你的模型 ${modelName} 已于 ${dayjs().format('YYYY-MM-DD HH:mm')} 训练完成!`
|
||||
};
|
||||
mailTransport.sendMail(options, function (err, msg) {
|
||||
if (err) {
|
||||
console.log('send email error->', err);
|
||||
reject('邮箱异常');
|
||||
} else {
|
||||
resolve('');
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
export const sendPhoneCode = async (phone: string, code: string) => {
|
||||
const accessKeyId = process.env.aliAccessKeyId;
|
||||
const accessKeySecret = process.env.aliAccessKeySecret;
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import type { NextApiRequest } from 'next';
|
||||
import crypto from 'crypto';
|
||||
import jwt from 'jsonwebtoken';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { OpenApi, User } from '../mongo';
|
||||
import { formatPrice } from '@/utils/user';
|
||||
import { ERROR_ENUM } from '../errorCode';
|
||||
import { countChatTokens } from '@/utils/tools';
|
||||
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
|
||||
import { ChatModelEnum } from '@/constants/model';
|
||||
@@ -46,44 +42,6 @@ export const authToken = (token?: string): Promise<string> => {
|
||||
});
|
||||
};
|
||||
|
||||
/* 校验 open api key */
|
||||
export const authOpenApiKey = async (req: NextApiRequest) => {
|
||||
const { apikey: apiKey } = req.headers;
|
||||
|
||||
if (!apiKey) {
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
|
||||
try {
|
||||
const openApi = await OpenApi.findOne({ apiKey });
|
||||
if (!openApi) {
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
const userId = String(openApi.userId);
|
||||
|
||||
// 余额校验
|
||||
const user = await User.findById(userId);
|
||||
if (!user) {
|
||||
return Promise.reject(ERROR_ENUM.unAuthorization);
|
||||
}
|
||||
if (formatPrice(user.balance) <= 0) {
|
||||
return Promise.reject('Insufficient account balance');
|
||||
}
|
||||
|
||||
// 更新使用的时间
|
||||
await OpenApi.findByIdAndUpdate(openApi._id, {
|
||||
lastUsedTime: new Date()
|
||||
});
|
||||
|
||||
return {
|
||||
apiKey: process.env.OPENAIKEY as string,
|
||||
userId
|
||||
};
|
||||
} catch (error) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
};
|
||||
|
||||
/* openai axios config */
|
||||
export const axiosConfig = () => ({
|
||||
httpsAgent: global.httpsAgent,
|
||||
|
||||
Reference in New Issue
Block a user