diff --git a/public/imgs/modelAvatar.png b/public/imgs/modelAvatar.png new file mode 100644 index 000000000..93670150a Binary files /dev/null and b/public/imgs/modelAvatar.png differ diff --git a/src/api/model.ts b/src/api/model.ts index fc4e575e4..4a092c67d 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -1,7 +1,7 @@ import { GET, POST, DELETE, PUT } from './request'; import type { ModelSchema, ModelDataSchema } from '@/types/mongoSchema'; import { ModelUpdateParams } from '@/types/model'; -import { RequestPaging } from '../types/index'; +import { PagingData, RequestPaging } from '../types/index'; import { Obj2Query } from '@/utils/tools'; /** @@ -93,3 +93,10 @@ export const putModelDataById = (data: { dataId: string; a: string; q?: string } */ export const delOneModelData = (dataId: string) => DELETE(`/model/data/delModelDataById?dataId=${dataId}`); + +/* 共享市场 */ +/** + * 获取共享市场模型 + */ +export const getShareModelList = (data: { searchText?: string } & RequestPaging) => + POST(`/model/share/getModels`, data); diff --git a/src/api/response/chat.d.ts b/src/api/response/chat.d.ts index baa154c55..21b5725bd 100644 --- a/src/api/response/chat.d.ts +++ b/src/api/response/chat.d.ts @@ -6,6 +6,7 @@ export type InitChatResponse = { modelId: string; name: string; avatar: string; + intro: string; chatModel: ModelSchema.service.chatModel; // 对话模型名 modelName: ModelSchema.service.modelName; // 底层模型 history: ChatItemType[]; diff --git a/src/components/Icon/icons/collectionLight.svg b/src/components/Icon/icons/collectionLight.svg new file mode 100644 index 000000000..72fb923ff --- /dev/null +++ b/src/components/Icon/icons/collectionLight.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/components/Icon/icons/collectionSolid.svg b/src/components/Icon/icons/collectionSolid.svg new file mode 100644 index 000000000..140c32999 --- /dev/null +++ b/src/components/Icon/icons/collectionSolid.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/components/Icon/icons/shareMarket.svg b/src/components/Icon/icons/shareMarket.svg new file mode 100644 index 000000000..c75ddb304 --- /dev/null +++ b/src/components/Icon/icons/shareMarket.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/components/Icon/index.tsx b/src/components/Icon/index.tsx index 9f9895db2..2aff7ec96 100644 --- a/src/components/Icon/index.tsx +++ b/src/components/Icon/index.tsx @@ -18,7 +18,10 @@ const map = { withdraw: require('./icons/withdraw.svg').default, dbModel: require('./icons/dbModel.svg').default, history: require('./icons/history.svg').default, - stop: require('./icons/stop.svg').default + stop: require('./icons/stop.svg').default, + shareMarket: require('./icons/shareMarket.svg').default, + collectionLight: require('./icons/collectionLight.svg').default, + collectionSolid: require('./icons/collectionSolid.svg').default }; export type IconName = keyof typeof map; diff --git a/src/components/Layout/index.tsx b/src/components/Layout/index.tsx index aea56cfd4..a9796ae93 100644 --- a/src/components/Layout/index.tsx +++ b/src/components/Layout/index.tsx @@ -26,6 +26,12 @@ const navbarList = [ link: '/model/list', activeLink: ['/model/list', '/model/detail'] }, + { + label: '共享', + icon: 'shareMarket', + link: '/model/share', + activeLink: ['/model/share'] + }, { label: '账号', icon: 'user', diff --git a/src/constants/model.ts b/src/constants/model.ts index 300630fc4..d71f0fe46 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -113,10 +113,10 @@ export const ModelVectorSearchModeMap: Record< }; export const defaultModel: ModelSchema = { - _id: '', - userId: '', + _id: 'modelId', + userId: 'userId', name: 'modelName', - avatar: '', + avatar: '/icon/logo.png', status: ModelStatusEnum.pending, updateTime: Date.now(), systemPrompt: '', @@ -124,6 +124,12 @@ export const defaultModel: ModelSchema = { search: { mode: ModelVectorSearchModeEnum.hightSimilarity }, + share: { + isShare: false, + isShareDetail: false, + intro: '', + collection: 0 + }, service: { chatModel: ModelNameEnum.GPT35, modelName: ModelNameEnum.GPT35 diff --git a/src/constants/theme.ts b/src/constants/theme.ts index 3c3c1cb76..e5c6410e2 100644 --- a/src/constants/theme.ts +++ b/src/constants/theme.ts @@ -90,6 +90,9 @@ export const theme = extendTheme({ fonts: { body: '-apple-system,BlinkMacSystemFont,"Segoe UI",Helvetica,Arial,sans-serif,"Apple Color Emoji","Segoe UI Emoji","Segoe UI Symbol"' }, + fontWeights: { + bold: 500 + }, breakpoints: { sm: '900px', md: '1200px', diff --git a/src/hooks/usePagination.tsx b/src/hooks/usePagination.tsx index ed1aff33c..cf1551bbf 100644 --- a/src/hooks/usePagination.tsx +++ b/src/hooks/usePagination.tsx @@ -91,7 +91,7 @@ export const usePagination = ({ useEffect(() => { mutate(1); - }, [mutate]); + }, []); return { pageNum, diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts index 9472f1851..29740c2dd 100644 --- a/src/pages/api/chat/init.ts +++ b/src/pages/api/chat/init.ts @@ -22,7 +22,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) await connectToDatabase(); // 获取 model 数据 - const { model } = await authModel(modelId, userId); + const { model } = await authModel({ modelId, userId, authUser: false, authOwner: false }); // 历史记录 let history: ChatItemType[] = []; @@ -53,6 +53,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) modelId: modelId, name: model.name, avatar: model.avatar, + intro: model.share.intro, modelName: model.service.modelName, chatModel: model.service.chatModel, history diff --git a/src/pages/api/chat/saveChat.ts b/src/pages/api/chat/saveChat.ts index 1dbc11a24..406aec0cd 100644 --- a/src/pages/api/chat/saveChat.ts +++ b/src/pages/api/chat/saveChat.ts @@ -27,9 +27,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) value: item.value })); + await authModel({ modelId, userId, authOwner: false }); + // 没有 chatId, 创建一个对话 if (!chatId) { - await authModel(modelId, userId); const { _id } = await Chat.create({ userId, modelId, diff --git a/src/pages/api/model/data/getModelData.ts b/src/pages/api/model/data/getModelData.ts index 9de535b1d..bc3a48f47 100644 --- a/src/pages/api/model/data/getModelData.ts +++ b/src/pages/api/model/data/getModelData.ts @@ -4,6 +4,7 @@ import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { PgClient } from '@/service/pg'; import type { PgModelDataItemType } from '@/types/pg'; +import { authModel } from '@/service/utils/auth'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -36,9 +37,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); + const { model } = await authModel({ + userId, + modelId, + authOwner: false + }); + const where: any = [ - ['user_id', userId], - 'AND', + ...(model.share.isShareDetail ? [] : [['user_id', userId], 'AND']), ['model_id', modelId], ...(searchText ? ['AND', `(q LIKE '%${searchText}%' OR a LIKE '%${searchText}%')`] : []) ]; diff --git a/src/pages/api/model/data/pushModelDataCsv.ts b/src/pages/api/model/data/pushModelDataCsv.ts index fc3373a04..4b65ef720 100644 --- a/src/pages/api/model/data/pushModelDataCsv.ts +++ b/src/pages/api/model/data/pushModelDataCsv.ts @@ -1,10 +1,11 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, Model } from '@/service/mongo'; +import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { generateVector } from '@/service/events/generateVector'; import { ModelDataStatusEnum } from '@/constants/model'; import { PgClient } from '@/service/pg'; +import { authModel } from '@/service/utils/auth'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -28,15 +29,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); // 验证是否是该用户的 model - const model = await Model.findOne({ - _id: modelId, - userId + await authModel({ + userId, + modelId }); - if (!model) { - throw new Error('无权操作该模型'); - } - // 去重 const searchRes = await Promise.allSettled( data.map(async ([q, a]) => { diff --git a/src/pages/api/model/data/pushModelDataInput.ts b/src/pages/api/model/data/pushModelDataInput.ts index 17e82a4d5..477296d12 100644 --- a/src/pages/api/model/data/pushModelDataInput.ts +++ b/src/pages/api/model/data/pushModelDataInput.ts @@ -1,10 +1,11 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, Model } from '@/service/mongo'; +import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { ModelDataSchema } from '@/types/mongoSchema'; import { generateVector } from '@/service/events/generateVector'; import { PgClient } from '@/service/pg'; +import { authModel } from '@/service/utils/auth'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -28,15 +29,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); // 验证是否是该用户的 model - const model = await Model.findOne({ - _id: modelId, - userId + await authModel({ + userId, + modelId }); - if (!model) { - throw new Error('无权操作该模型'); - } - // 插入记录 await PgClient.insert('modelData', { values: data.map((item) => [ diff --git a/src/pages/api/model/del.ts b/src/pages/api/model/del.ts index e2b3dea6d..df29c107c 100644 --- a/src/pages/api/model/del.ts +++ b/src/pages/api/model/del.ts @@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response'; import { Chat, Model, connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { PgClient } from '@/service/pg'; +import { authModel } from '@/service/utils/auth'; /* 获取我的模型 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -21,18 +22,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< // 凭证校验 const userId = await authToken(authorization); + await connectToDatabase(); + // 验证是否是该用户的 model - const model = await Model.findOne({ - _id: modelId, + await authModel({ + modelId, userId }); - if (!model) { - throw new Error('无权操作该模型'); - } - - await connectToDatabase(); - // 删除 pg 中所有该模型的数据 await PgClient.delete('modelData', { where: [['user_id', userId], 'AND', ['model_id', modelId]] diff --git a/src/pages/api/model/detail.tsx b/src/pages/api/model/detail.tsx index 8e3e31d6c..142d3e7e3 100644 --- a/src/pages/api/model/detail.tsx +++ b/src/pages/api/model/detail.tsx @@ -2,8 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; -import { Model } from '@/service/models/model'; -import type { ModelSchema } from '@/types/mongoSchema'; +import { authModel } from '@/service/utils/auth'; /* 获取我的模型 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -14,7 +13,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< throw new Error('无权操作'); } - const { modelId } = req.query; + const { modelId } = req.query as { modelId: string }; if (!modelId) { throw new Error('参数错误'); @@ -25,16 +24,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); - // 根据 userId 获取模型信息 - const model = await Model.findOne({ + const { model } = await authModel({ + modelId, userId, - _id: modelId + authOwner: false }); - if (!model) { - throw new Error('模型不存在'); - } - jsonRes(res, { data: model }); diff --git a/src/pages/api/model/list.ts b/src/pages/api/model/list.ts index 006722400..3329e37a0 100644 --- a/src/pages/api/model/list.ts +++ b/src/pages/api/model/list.ts @@ -4,7 +4,7 @@ import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { Model } from '@/service/models/model'; -/* 获取我的模型 */ +/* 获取模型列表 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { const { authorization } = req.headers; diff --git a/src/pages/api/model/share/getModels.ts b/src/pages/api/model/share/getModels.ts new file mode 100644 index 000000000..c5a4233fd --- /dev/null +++ b/src/pages/api/model/share/getModels.ts @@ -0,0 +1,54 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { connectToDatabase } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; +import { Model } from '@/service/models/model'; +import type { PagingData } from '@/types'; +import type { ShareModelItem } from '@/types/model'; + +/* 获取模型列表 */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + // 凭证校验 + await authToken(req.headers.authorization); + + const { + searchText = '', + pageNum = 1, + pageSize = 20 + } = req.body as { searchText: string; pageNum: number; pageSize: number }; + + await connectToDatabase(); + + const regex = new RegExp(searchText, 'i'); + + const where = { + $and: [ + { 'share.isShare': true }, + { $or: [{ name: { $regex: regex } }, { 'share.intro': { $regex: regex } }] } + ] + }; + + // 根据分享的模型 + const models = await Model.find(where, '_id avatar name userId share') + .sort({ + 'share.collection': -1 + }) + .limit(pageSize) + .skip((pageNum - 1) * pageSize); + + jsonRes>(res, { + data: { + pageNum, + pageSize, + data: models, + total: await Model.countDocuments(where) + } + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/api/model/update.ts b/src/pages/api/model/update.ts index 4b021d379..aeb8b90c4 100644 --- a/src/pages/api/model/update.ts +++ b/src/pages/api/model/update.ts @@ -4,11 +4,12 @@ import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { Model } from '@/service/models/model'; import type { ModelUpdateParams } from '@/types/model'; +import { authModel } from '@/service/utils/auth'; /* 获取我的模型 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { name, search, service, security, systemPrompt, temperature } = + const { name, search, share, service, security, systemPrompt, temperature } = req.body as ModelUpdateParams; const { modelId } = req.query as { modelId: string }; const { authorization } = req.headers; @@ -26,6 +27,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); + await authModel({ + modelId, + userId + }); + // 更新模型 await Model.updateOne( { @@ -36,6 +42,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< name, systemPrompt, temperature, + 'share.isShare': share.isShare, + 'share.isShareDetail': share.isShareDetail, + 'share.intro': share.intro, search, security } diff --git a/src/pages/chat/components/Empty.tsx b/src/pages/chat/components/Empty.tsx index d03d34300..295e73a20 100644 --- a/src/pages/chat/components/Empty.tsx +++ b/src/pages/chat/components/Empty.tsx @@ -3,12 +3,7 @@ import { Card, Box } from '@chakra-ui/react'; import { useMarkdown } from '@/hooks/useMarkdown'; import Markdown from '@/components/Markdown'; -const Empty = ({ intro }: { intro: string }) => { - const Header = ({ children }: { children: string }) => ( - - {children} - - ); +const Empty = ({ modelName, intro }: { modelName: string; intro: string }) => { const { data: chatProblem } = useMarkdown({ url: '/chatProblem.md' }); const { data: versionIntro } = useMarkdown({ url: '/versionIntro.md' }); @@ -24,7 +19,9 @@ const Empty = ({ intro }: { intro: string }) => { > {!!intro && ( -
模型介绍
+ + {modelName} 介绍 + {intro}
)} diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 12b4eece9..387574606 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -63,7 +63,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { chatId, modelId, name: '', - avatar: '', + avatar: '/icon/logo.png', + intro: '', chatModel: '', modelName: '', history: [] @@ -155,7 +156,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { isClosable: true, duration: 5000 }); - router.replace('/model/list'); + router.back(); } setLoading(false); return null; @@ -469,7 +470,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { /icon/logo.png { ))} - {chatData.history.length === 0 && } + {chatData.history.length === 0 && ( + + )} {/* 发送区 */} diff --git a/src/pages/model/detail/components/ModelDataCard.tsx b/src/pages/model/detail/components/ModelDataCard.tsx index d4679a435..c4ca40137 100644 --- a/src/pages/model/detail/components/ModelDataCard.tsx +++ b/src/pages/model/detail/components/ModelDataCard.tsx @@ -39,7 +39,7 @@ import InputModal, { FormData as InputDataType } from './InputDataModal'; const SelectFileModal = dynamic(() => import('./SelectFileModal')); const SelectCsvModal = dynamic(() => import('./SelectCsvModal')); -const ModelDataCard = ({ modelId }: { modelId: string }) => { +const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean }) => { const { Loading, setIsLoading } = useLoading(); const lastSearch = useRef(''); const [searchText, setSearchText] = useState(''); @@ -133,50 +133,53 @@ const ModelDataCard = ({ modelId }: { modelId: string }) => { 模型数据: {total}组 - - (测试版本) - - } - aria-label={'refresh'} - variant={'outline'} - mr={4} - size={'sm'} - onClick={() => refetchData(pageNum)} - /> - - - - 导入 - - - - setEditInputData({ - a: '', - q: '' - }) - } + {isOwner && ( + <> + } + aria-label={'refresh'} + variant={'outline'} + mr={4} + size={'sm'} + onClick={() => refetchData(pageNum)} + /> + + 导出 + + + + 导入 + + + + setEditInputData({ + a: '', + q: '' + }) + } + > + 手动输入 + + 文本/文件拆分 + csv 问答对导入 + + + + )} - {splitDataLen > 0 && {splitDataLen}条数据正在拆分,请耐心等待...} + {isOwner && splitDataLen > 0 && ( + {splitDataLen}条数据正在拆分,请耐心等待... + )} { {'匹配的知识点'} 补充知识 状态 - 操作 + {isOwner && 操作} @@ -220,33 +223,35 @@ const ModelDataCard = ({ modelId }: { modelId: string }) => { {item.a || '-'} {ModelDataStatusMap[item.status]} - - } - variant={'outline'} - aria-label={'delete'} - size={'sm'} - onClick={() => - setEditInputData({ - dataId: item.id, - q: item.q, - a: item.a - }) - } - /> - } - variant={'outline'} - colorScheme={'gray'} - aria-label={'delete'} - size={'sm'} - onClick={async () => { - await delOneModelData(item.id); - refetchData(pageNum); - }} - /> - + {isOwner && ( + + } + variant={'outline'} + aria-label={'delete'} + size={'sm'} + onClick={() => + setEditInputData({ + dataId: item.id, + q: item.q, + a: item.a + }) + } + /> + } + variant={'outline'} + colorScheme={'gray'} + aria-label={'delete'} + size={'sm'} + onClick={async () => { + await delOneModelData(item.id); + refetchData(pageNum); + }} + /> + + )} ))} diff --git a/src/pages/model/detail/components/ModelEditForm.tsx b/src/pages/model/detail/components/ModelEditForm.tsx index 22e14b6d8..7cdb77c2a 100644 --- a/src/pages/model/detail/components/ModelEditForm.tsx +++ b/src/pages/model/detail/components/ModelEditForm.tsx @@ -13,7 +13,9 @@ import { SliderMark, Tooltip, Button, - Select + Select, + Grid, + Switch } from '@chakra-ui/react'; import { QuestionOutlineIcon } from '@chakra-ui/icons'; import type { ModelSchema } from '@/types/mongoSchema'; @@ -25,10 +27,12 @@ import { useConfirm } from '@/hooks/useConfirm'; const ModelEditForm = ({ formHooks, canTrain, + isOwner, handleDelModel }: { formHooks: UseFormReturn; canTrain: boolean; + isOwner: boolean; handleDelModel: () => void; }) => { const { openConfirm, ConfirmChild } = useConfirm({ @@ -40,27 +44,26 @@ const ModelEditForm = ({ return ( <> - - 基本信息 - + 基本信息 名称: - - - modelId: - - {getValues('_id')} - + + + modelId: + + {getValues('_id')} + 模型类型: @@ -79,17 +82,25 @@ const ModelEditForm = ({ 元/1K tokens(包括上下文和回答) - - 删除模型和数据集 - + + + 收藏人数: + + {getValues('share.collection')}人 + {isOwner && ( + + 删除模型和知识库 + + + )} 模型效果 @@ -110,6 +121,7 @@ const ModelEditForm = ({ max={10} step={1} value={getValues('temperature')} + isDisabled={!isOwner} onChange={(e) => { setValue('temperature', e); setRefresh(!refresh); @@ -139,7 +151,10 @@ const ModelEditForm = ({ 搜索模式 - {Object.entries(ModelVectorSearchModeMap).map(([key, { text }]) => (