Compare commits

..

15 Commits
v2.2 ... v2.4

Author SHA1 Message Date
archer
e08e8aa00b feat: 修改模型数据可修改问题 2023-04-04 13:15:34 +08:00
archer
becee69d6a perf: 发送区域样式 2023-04-03 17:28:35 +08:00
archer
042b0c535a perf: 发送按键 2023-04-03 17:14:46 +08:00
archer
f97c29b41e feat: lafgpt请求;fix: 修复发送按键 2023-04-03 16:35:48 +08:00
archer
4d6616cbfa fix: ts 2023-04-03 11:03:51 +08:00
archer
cf37992b5c feat: 封装向量生成和账单 2023-04-03 10:59:32 +08:00
archer
6c4026ccef perf: 文件结构 2023-04-03 10:20:17 +08:00
archer
caf31faf31 perf: 生成qa prompt 2023-04-03 01:39:00 +08:00
archer
a0832af14b perf: 数据集刷新导致页面抖动 2023-04-03 00:51:53 +08:00
archer
677e61416d perf: 版本文案 2023-04-03 00:48:56 +08:00
archer
56ba6fa5f7 feat: 拆分数据自定义prompt 2023-04-03 00:37:40 +08:00
archer
16a31de1c7 feat: 数据集导出 2023-04-03 00:18:21 +08:00
archer
05b2e9e99c feat: 拆分测试环境 2023-04-02 23:38:28 +08:00
archer
ae4243b522 perf: 知识库数据结构 2023-04-01 22:31:56 +08:00
archer
5759cbeae0 perf: 知识库录入 2023-03-31 18:23:07 +08:00
51 changed files with 1420 additions and 791 deletions

View File

@@ -107,5 +107,6 @@ echo "Restart clash"
```bash
# 索引
# FT.CREATE idx:model:data ON JSON PREFIX 1 model:data: SCHEMA $.modelId AS modelId TAG $.dataId AS dataId TAG $.vector AS vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32
FT.CREATE idx:model:data:hash ON HASH PREFIX 1 model:data: SCHEMA modelId TAG dataId TAG vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32
# FT.CREATE idx:model:data:hash ON HASH PREFIX 1 model:data: SCHEMA modelId TAG dataId TAG vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32
FT.CREATE idx:model:data ON HASH PREFIX 1 model:data: SCHEMA modelId TAG userId TAG q TEXT text TEXT vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32
```

Binary file not shown.

Before

Width:  |  Height:  |  Size: 320 KiB

BIN
public/imgs/wxcode300.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

View File

@@ -38,18 +38,26 @@ type GetModelDataListProps = RequestPaging & {
export const getModelDataList = (props: GetModelDataListProps) =>
GET(`/model/data/getModelData?${Obj2Query(props)}`);
export const getExportDataList = (modelId: string) =>
GET<string>(`/model/data/exportModelData?modelId=${modelId}`);
export const getModelSplitDataList = (modelId: string) =>
GET<ModelSplitDataSchema[]>(`/model/data/getSplitData?modelId=${modelId}`);
export const postModelDataInput = (data: {
modelId: string;
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}) => POST(`/model/data/pushModelDataInput`, data);
}) => POST<number>(`/model/data/pushModelDataInput`, data);
export const postModelDataFileText = (modelId: string, text: string) =>
POST(`/model/data/splitData`, { modelId, text });
export const postModelDataFileText = (data: { modelId: string; text: string; prompt: string }) =>
POST(`/model/data/splitData`, data);
export const putModelDataById = (data: { dataId: string; text: string }) =>
export const postModelDataJsonData = (
modelId: string,
jsonData: { prompt: string; completion: string; vector?: number[] }[]
) => POST(`/model/data/pushModelDataJson`, { modelId, data: jsonData });
export const putModelDataById = (data: { dataId: string; text: string; q?: string }) =>
PUT('/model/data/putModelData', data);
export const delOneModelData = (dataId: string) =>
DELETE(`/model/data/delModelDataById?dataId=${dataId}`);

View File

@@ -23,7 +23,7 @@ const WxConcat = ({ onClose }: { onClose: () => void }) => {
<ModalBody textAlign={'center'}>
<Image
style={{ margin: 'auto' }}
src={'/imgs/wxcode.jpg'}
src={'/imgs/wxcode300.jpg'}
width={200}
height={200}
alt=""

View File

@@ -10,6 +10,11 @@ export const introPage = `
[Git 仓库](https://github.com/c121914yu/FastGPT)
### 交流群/问题反馈
wx: YNyiqi
![](/imgs/wxcode300.jpg)
### 快速开始
1. 使用邮箱注册账号。
2. 进入账号页面,添加关联账号,目前只有 openai 的账号可以添加,直接去 openai 官网,把 API Key 粘贴过来。
@@ -31,30 +36,25 @@ export const introPage = `
4. 使用该模型对话。
注意使用知识库模型对话时tokens 消耗会加快。
### 其他问题
还有其他问题,可以加我 wx: YNyiqi拉个交流群大家一起聊聊。
`;
export const chatProblem = `
## 常见问题
**内容长度**
单次最长 4000 tokens, 上下文最长 8000 tokens, 上下文超长时会被截断。
**模型问题**
一般情况下,请直接选择 chatGPT 模型,价格低效果好
**删除和复制**
点击对话头像,可以选择复制或删除该条内容
**代理出错**
服务器代理不稳定,可以过一会儿再尝试。
**API key 问题**
请把 openai 的 API key 粘贴到账号里再创建对话。如果是使用分享的对话,不需要填写 API key。
`;
export const versionIntro = `
## Fast GPT V2.2
## Fast GPT V2.3
* 数据集导出功能,可用于知识库分享。
* 优化文件拆分功能,可自定义提示词。
* 定制知识库:创建模型时可以选择【知识库】模型, 可以手动导入知识点或者直接导入一个文件自动学习。
* 删除和复制功能:点击对话头像,可以选择复制或删除该条内容。
`;
export const shareHint = `

View File

@@ -1,15 +1,18 @@
import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema';
import type { RedisModelDataItemType } from '@/types/redis';
export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT',
GPT3 = 'text-davinci-003'
GPT3 = 'text-davinci-003',
VECTOR = 'text-embedding-ada-002'
}
export const ChatModelNameMap = {
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
[ChatModelNameEnum.GPT3]: 'text-davinci-003'
[ChatModelNameEnum.GPT3]: 'text-davinci-003',
[ChatModelNameEnum.VECTOR]: 'text-embedding-ada-002'
};
export type ModelConstantsData = {
@@ -93,9 +96,9 @@ export const formatModelStatus = {
}
};
export const ModelDataStatusMap = {
0: '训练完成',
1: '训练中'
export const ModelDataStatusMap: Record<RedisModelDataItemType['status'], string> = {
ready: '训练完成',
waiting: '训练中'
};
export const defaultModel: ModelSchema = {

View File

@@ -1 +1,6 @@
export const VecModelDataIndex = 'model:data';
export const VecModelDataPrefix = 'model:data';
export const VecModelDataIdx = `idx:${VecModelDataPrefix}:hash`;
export enum ModelDataStatusEnum {
ready = 'ready',
waiting = 'waiting'
}

View File

@@ -3,6 +3,7 @@ export enum BillTypeEnum {
splitData = 'splitData',
QA = 'QA',
abstract = 'abstract',
vector = 'vector',
return = 'return'
}
export enum PageTypeEnum {
@@ -16,5 +17,6 @@ export const BillTypeMap: Record<`${BillTypeEnum}`, string> = {
[BillTypeEnum.splitData]: 'QA拆分',
[BillTypeEnum.QA]: 'QA拆分',
[BillTypeEnum.abstract]: '摘要总结',
[BillTypeEnum.vector]: '索引生成',
[BillTypeEnum.return]: '退款'
};

View File

@@ -17,13 +17,10 @@ export const usePagination = <T = any,>({
const { toast } = useToast();
const [pageNum, setPageNum] = useState(1);
const [total, setTotal] = useState(0);
const [data, setData] = useState<T[]>([]);
const maxPage = useMemo(() => Math.ceil(total / pageSize), [pageSize, total]);
const {
mutate,
data = [],
isLoading
} = useMutation({
const { mutate, isLoading } = useMutation({
mutationFn: async (num: number = pageNum) => {
try {
const res: PagingData<T> = await api({
@@ -33,7 +30,7 @@ export const usePagination = <T = any,>({
});
setPageNum(num);
setTotal(res.total);
return res.data;
setData(res.data);
} catch (error: any) {
toast({
title: error?.message || '获取数据异常',
@@ -43,7 +40,6 @@ export const usePagination = <T = any,>({
}
}
});
useQuery(['init'], () => {
mutate(1);
return null;

View File

@@ -1,4 +1,5 @@
import type { AppProps, NextWebVitalsMetric } from 'next/app';
import { useEffect } from 'react';
import type { AppProps } from 'next/app';
import Script from 'next/script';
import Head from 'next/head';
import { ChakraProvider, ColorModeScript } from '@chakra-ui/react';
@@ -9,6 +10,7 @@ import NProgress from 'nprogress'; //nprogress module
import Router from 'next/router';
import 'nprogress/nprogress.css';
import '../styles/reset.scss';
import { useToast } from '@/hooks/useToast';
//Binding events.
Router.events.on('routeChangeStart', () => NProgress.start());
@@ -27,6 +29,17 @@ const queryClient = new QueryClient({
});
export default function App({ Component, pageProps }: AppProps) {
const { toast } = useToast();
// 校验是否支持 click 事件
useEffect(() => {
if (typeof document.createElement('div').click !== 'function') {
toast({
title: '你的浏览器版本过低',
status: 'warning'
});
}
}, [toast]);
return (
<>
<Head>

View File

@@ -2,7 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent } from '@/service/utils/tools';
import { httpsAgent, openaiChatFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
@@ -10,7 +10,6 @@ import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiChatFilter } from '@/service/utils/tools';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {

View File

@@ -0,0 +1,277 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding } from '@/service/utils/openai';
/* 发送提示词 */
/**
 * Streamed chat endpoint for the "laf gpt" model.
 *
 * Flow: (1) ask chatgpt once to decompose the user's requirement into
 * numbered implementation steps, (2) embed the augmented prompt and run a
 * redis vector search for related knowledge-base texts, (3) prepend the
 * matched knowledge as a system prompt, and (4) stream the final chat
 * completion back to the client as SSE, billing at the end.
 *
 * Body: { chatId, prompt } — conversation id and the new user message.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  let step = 0; // step = 1 means the streamed response has already started
  const stream = new PassThrough();
  stream.on('error', () => {
    console.log('error: ', 'stream error');
    stream.destroy();
  });
  res.on('close', () => {
    stream.destroy();
  });
  res.on('error', () => {
    console.log('error: ', 'request error');
    stream.destroy();
  });
  try {
    const { chatId, prompt } = req.body as {
      prompt: ChatItemType;
      chatId: string;
    };
    const { authorization } = req.headers;
    if (!chatId || !prompt) {
      throw new Error('缺少参数');
    }
    await connectToDatabase();
    const redis = await connectRedis();
    let startTime = Date.now();
    const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
    const model: ModelSchema = chat.modelId;
    const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
    if (!modelConstantsData) {
      throw new Error('模型加载异常');
    }
    // get the chat API client (the user's own key takes precedence over the platform key)
    const chatAPI = getOpenAIApi(userApiKey || systemKey);
    // first chatgpt call: decompose the requirement into numbered steps
    const promptResponse = await chatAPI.createChatCompletion(
      {
        model: model.service.chatModel,
        temperature: 0,
        // max_tokens: modelConstantsData.maxToken,
        messages: [
          {
            role: 'system',
            content: `服务端逻辑生成器。根据用户输入的需求,拆解成代码实现的步骤,并按格式返回: 1.\n2.\n3.\n ......
下面是一些例子:
实现一个手机号注册账号的方法,包含两个函数
* 发送手机验证码函数:
1. 从 query 中获取 phone
2. 校验手机号格式是否正确,不正确返回{error: "手机号格式错误"}
3. 给 phone 发送一个短信验证码验证码长度为6位字符串内容为你正在注册laf, 验证码为code
4. 数据库添加数据,表为"codes",内容为 {phone, code}
* 注册函数
1. 从 body 中获取 phone 和 code
2. 校验手机号格式是否正确,不正确返回{error: "手机号格式错误"}
2. 获取数据库数据,表为"codes",查找是否有符合 phone, code 等于body参数的记录没有的话返回 {error:"验证码不正确"}
4. 添加数据库数据,表为"users" ,内容为{phone, code, createTime}
5. 删除数据库数据,删除 code 记录
---------------
更新博客记录。传入blogIdblogTexttags还需要记录更新的时间
1. 从 body 中获取 blogIdblogText 和 tags
2. 校验 blogId 是否为空,为空则返回 {error: "博客ID不能为空"}
3. 校验 blogText 是否为空,为空则返回 {error: "博客内容不能为空"}
4. 校验 tags 是否为数组,不是则返回 {error: "标签必须为数组"}
5. 获取当前时间,记录为 updateTime
6. 更新数据库数据,表为"blogs",更新符合 blogId 的记录的内容为{blogText, tags, updateTime}
7. 返回结果 {message: "更新博客记录成功"}`
          },
          {
            role: 'user',
            content: prompt.value
          }
        ]
      },
      {
        timeout: 40000,
        httpsAgent
      }
    );
    const promptResolve = promptResponse.data.choices?.[0]?.message?.content || '';
    if (!promptResolve) {
      throw new Error('gpt 异常');
    }
    // append the decomposition to the user's prompt before embedding/search
    prompt.value += `\n${promptResolve}`;
    console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
    // get the embedding vector of the (augmented) prompt
    const { vector: promptVector } = await openaiCreateEmbedding({
      isPay: !userApiKey,
      apiKey: userApiKey || systemKey,
      userId,
      text: prompt.value
    });
    // conversation history plus the new message
    const prompts = [...chat.content, prompt];
    // vector search in redis: fetch knowledge texts similar to the prompt by cosine range
    const redisData: any[] = await redis.sendCommand([
      'FT.SEARCH',
      `idx:${VecModelDataPrefix}:hash`,
      `@modelId:{${String(
        chat.modelId._id
      )}} @vector:[VECTOR_RANGE 0.25 $blob]=>{$YIELD_DISTANCE_AS: score}`,
      // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
      'RETURN',
      '1',
      'text',
      'SORTBY',
      'score',
      'PARAMS',
      '2',
      'blob',
      vectorToBuffer(promptVector),
      'LIMIT',
      '0',
      '20',
      'DIALECT',
      '2'
    ]);
    // raw FT.SEARCH reply: indices 2, 4, … hold each document's field array,
    // and [i][1] is the returned 'text' value — collect the non-empty ones
    const formatRedisPrompt = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
      .map((i) => {
        if (!redisData[i]) return '';
        const text = (redisData[i][1] as string) || '';
        if (!text) return '';
        return text;
      })
      .filter((item) => item);
    if (formatRedisPrompt.length === 0) {
      throw new Error('对不起,我没有找到你的问题');
    }
    // trim the knowledge texts down to the token budget (3400)
    const systemPrompt = systemPromptFilter(formatRedisPrompt, 3400);
    prompts.unshift({
      obj: 'SYSTEM',
      value: `${model.systemPrompt} 知识库内容是最新的,知识库内容为: "${systemPrompt}"`
    });
    // cap total tokens so the request stays inside the model's context window
    const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
    // map stored roles to chatgpt message roles
    const map = {
      Human: ChatCompletionRequestMessageRoleEnum.User,
      AI: ChatCompletionRequestMessageRoleEnum.Assistant,
      SYSTEM: ChatCompletionRequestMessageRoleEnum.System
    };
    const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
      (item: ChatItemType) => ({
        role: map[item.obj],
        content: item.value
      })
    );
    console.log(formatPrompts);
    // scale temperature: model.temperature appears to be on a 0-10 scale — TODO confirm
    const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
    // second chatgpt call: the actual answer, streamed
    const chatResponse = await chatAPI.createChatCompletion(
      {
        model: model.service.chatModel,
        temperature: temperature,
        // max_tokens: modelConstantsData.maxToken,
        messages: formatPrompts,
        frequency_penalty: 0.5, // higher = less repetition
        presence_penalty: -0.5, // higher = more novel content
        stream: true
      },
      {
        timeout: 40000,
        responseType: 'stream',
        httpsAgent
      }
    );
    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
    // set up the SSE response
    res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
    res.setHeader('Access-Control-Allow-Origin', '*');
    res.setHeader('X-Accel-Buffering', 'no');
    res.setHeader('Cache-Control', 'no-cache, no-transform');
    step = 1;
    let responseContent = '';
    stream.pipe(res);
    // parse each SSE event from openai and forward the delta content to the client
    const onParse = async (event: ParsedEvent | ReconnectInterval) => {
      if (event.type !== 'event') return;
      const data = event.data;
      if (data === '[DONE]') return;
      try {
        const json = JSON.parse(data);
        const content: string = json?.choices?.[0].delta.content || '';
        if (!content || (responseContent === '' && content === '\n')) return;
        responseContent += content;
        // console.log('content:', content)
        !stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
      } catch (error) {
        error;
      }
    };
    const decoder = new TextDecoder();
    try {
      for await (const chunk of chatResponse.data as any) {
        if (stream.destroyed) {
          // the client aborted the stream; ignore the remaining chunks
          break;
        }
        const parser = createParser(onParse);
        parser.feed(decoder.decode(chunk));
      }
    } catch (error) {
      console.log('pipe error', error);
    }
    // close the stream
    !stream.destroyed && stream.push(null);
    stream.destroy();
    const promptsContent = formatPrompts.map((item) => item.content).join('');
    // only bill when the platform key was used (isPay = !userApiKey)
    pushChatBill({
      isPay: !userApiKey,
      modelName: model.service.modelName,
      userId,
      chatId,
      text: promptsContent + responseContent
    });
  } catch (err: any) {
    if (step === 1) {
      // the stream already started: just terminate it, headers are out
      console.log('error结束');
      stream.destroy();
    } else {
      res.status(500);
      jsonRes(res, {
        code: 500,
        error: err
      });
    }
  }
}

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
@@ -11,8 +11,9 @@ import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding } from '@/service/utils/openai';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -56,34 +57,25 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容
const prompts = [...chat.content, prompt];
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 获取提示词的向量
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({
isPay: !userApiKey,
apiKey: userApiKey || systemKey,
userId,
text: prompt.value
});
// 把输入的内容转成向量
const promptVector = await chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: prompt.value
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || []);
// 搜索系统提示词, 按相似度从 redis 中搜出前3条不同 dataId 的数据
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataIndex}:hash`,
`idx:${VecModelDataPrefix}:hash`,
`@modelId:{${String(
chat.modelId._id
)}} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}`,
)}} @vector:[VECTOR_RANGE 0.25 $blob]=>{$YIELD_DISTANCE_AS: score}`,
// `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
'RETURN',
'1',
'dataId',
'text',
'SORTBY',
'score',
'PARAMS',
@@ -97,42 +89,28 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
'2'
]);
// 格式化响应值,获取去重后的id
let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
// 格式化响应值,获取 qa
const formatRedisPrompt = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
.map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return redisData[i][1];
if (!redisData[i]) return '';
const text = (redisData[i][1] as string) || '';
if (!text) return '';
return text;
})
.filter((item) => item);
formatIds = Array.from(new Set(formatIds));
if (formatIds.length === 0) {
if (formatRedisPrompt.length === 0) {
throw new Error('对不起,我没有找到你的问题');
}
// 从 mongo 中取出原文作为提示词
const textArr = (
await Promise.all(
[2, 4, 6, 8, 10, 12, 14, 16, 18, 20].map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return ModelData.findById(redisData[i][1])
.select('text q')
.then((res) => {
if (!res) return '';
const questions = res.q.map((item) => item.text).join(' ');
const answer = res.text;
return `${questions} ${answer}`;
});
})
)
).filter((item) => item);
// textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(textArr, 2800);
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3400);
prompts.unshift({
obj: 'SYSTEM',
value: `根据下面的知识回答问题: ${systemPrompt}`
value: `${model.systemPrompt} 知识库内容是最新的,知识库内容为: "${systemPrompt}"`
});
// 控制在 tokens 数量,防止超出

View File

@@ -1,9 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -23,25 +21,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验
const userId = await authToken(authorization);
await connectToDatabase();
const redis = await connectRedis();
const data = await ModelData.findById(dataId);
await ModelData.deleteOne({
_id: dataId,
userId
});
// 删除 redis 数据
data?.q.forEach(async (item) => {
try {
await redis.json.del(`${VecModelDataIndex}:${item.id}`);
} catch (error) {
console.log(error);
}
});
// 校验是否为该用户的数据
const dataItemUserId = await redis.hGet(dataId, 'userId');
if (dataItemUserId !== userId) {
throw new Error('无权操作');
}
// 删除
await redis.del(dataId);
jsonRes(res);
} catch (err) {
console.log(err);

View File

@@ -0,0 +1,69 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataIdx } from '@/constants/redis';
import { BufferToVector } from '@/utils/tools';
/**
 * Export a model's knowledge-base entries as a JSON string.
 *
 * Query: { modelId }. Requires an Authorization header; only rows whose
 * `userId` matches the authenticated user are returned. Rows without a
 * valid JSON `rawVector` are silently skipped, matching the original
 * filter-then-map behavior.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
  try {
    // never reassigned — const instead of the original `let`
    const { modelId } = req.query as {
      modelId: string;
    };
    const { authorization } = req.headers;
    if (!authorization) {
      throw new Error('无权操作');
    }
    if (!modelId) {
      throw new Error('缺少参数');
    }
    // credential check
    const userId = await authToken(authorization);
    await connectToDatabase();
    const redis = await connectRedis();
    // fetch this user's rows for the model from redis
    const searchRes = await redis.ft.search(
      VecModelDataIdx,
      `@modelId:{${modelId}} @userId:{${userId}}`,
      {
        RETURN: ['q', 'text', 'rawVector'],
        LIMIT: {
          from: 0,
          size: 10000
        }
      }
    );
    // Parse each rawVector exactly once (the original parsed twice: once in
    // the validity filter and again in the map) and drop missing/corrupt rows.
    const data: { prompt: string; completion: string; vector: number[] }[] = [];
    for (const doc of searchRes.documents) {
      const value: any = doc?.value;
      if (!value?.rawVector) continue;
      try {
        data.push({
          prompt: value.q,
          completion: value.text,
          vector: JSON.parse(value.rawVector as string)
        });
      } catch (error) {
        // stored vector is not valid JSON — skip the row, as before
      }
    }
    jsonRes(res, {
      data: JSON.stringify(data)
    });
  } catch (err) {
    jsonRes(res, {
      code: 500,
      error: err
    });
  }
}

View File

@@ -1,7 +1,10 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataIdx } from '@/constants/redis';
import { SearchOptions } from 'redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -32,24 +35,34 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization);
await connectToDatabase();
const redis = await connectRedis();
const data = await ModelData.find({
modelId,
userId
})
.sort({ _id: -1 }) // 按照创建时间倒序排列
.skip((pageNum - 1) * pageSize)
.limit(pageSize);
// 从 redis 中获取数据
const searchRes = await redis.ft.search(
VecModelDataIdx,
`@modelId:{${modelId}} @userId:{${userId}}`,
{
RETURN: ['q', 'text', 'status'],
LIMIT: {
from: (pageNum - 1) * pageSize,
size: pageSize
},
SORTBY: {
BY: 'modelId',
DIRECTION: 'DESC'
}
}
);
jsonRes(res, {
data: {
pageNum,
pageSize,
data,
total: await ModelData.countDocuments({
modelId,
userId
})
data: searchRes.documents.map((item) => ({
id: item.id,
...item.value
})),
total: searchRes.total
}
});
} catch (err) {

View File

@@ -1,9 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo';
import { connectToDatabase, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema';
import { generateVector } from '@/service/events/generateVector';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -25,6 +27,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization);
await connectToDatabase();
const redis = await connectRedis();
// 验证是否是该用户的 model
const model = await Model.findOne({
@@ -36,19 +39,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作该模型');
}
// push data
await ModelData.insertMany(
data.map((item) => ({
...item,
modelId,
userId
}))
const insertRes = await Promise.allSettled(
data.map((item) => {
return redis.sendCommand([
'HMSET',
`${VecModelDataPrefix}:${item.q.id}`,
'userId',
userId,
'modelId',
modelId,
'q',
item.q.text,
'text',
item.text,
'status',
ModelDataStatusEnum.waiting
]);
})
);
generateVector(true);
jsonRes(res, {
data: model
data: insertRes.filter((item) => item.status === 'rejected').length
});
} catch (err) {
jsonRes(res, {

View File

@@ -0,0 +1,80 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateVector } from '@/service/events/generateVector';
import { vectorToBuffer, formatVector } from '@/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis';
import { customAlphabet } from 'nanoid';
// 12-char lowercase-alphanumeric id generator for new redis hash keys
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);

/**
 * Bulk-import QA pairs (optionally with precomputed vectors) into a model's
 * knowledge base stored as redis hashes.
 *
 * Body: { modelId, data: [{ prompt, completion, vector? }] }.
 * Rows with a vector are stored as 'ready'; rows without one are stored as
 * 'waiting' and picked up by the async vector generator.
 * Responds with the number of rows that failed to insert.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
  try {
    const { modelId, data } = req.body as {
      modelId: string;
      data: { prompt: string; completion: string; vector?: number[] }[];
    };
    const { authorization } = req.headers;
    if (!authorization) {
      throw new Error('无权操作');
    }
    if (!modelId || !Array.isArray(data)) {
      throw new Error('缺少参数');
    }
    // credential check
    const userId = await authToken(authorization);
    await connectToDatabase();
    const redis = await connectRedis();
    // verify the model belongs to this user
    const model = await Model.findOne({
      _id: modelId,
      userId
    });
    if (!model) {
      throw new Error('无权操作该模型');
    }
    // insert into redis; allSettled so one bad row doesn't abort the batch
    const insertRedisRes = await Promise.allSettled(
      data.map((item) => {
        const vector = item.vector;
        // HMSET takes flat field/value pairs; the vector fields are only
        // written when a precomputed vector was supplied
        return redis.sendCommand([
          'HMSET',
          `${VecModelDataPrefix}:${nanoid()}`,
          'userId',
          userId,
          'modelId',
          String(modelId),
          ...(vector
            ? ['vector', vectorToBuffer(formatVector(vector)), 'rawVector', JSON.stringify(vector)]
            : []),
          'q',
          item.prompt,
          'text',
          item.completion,
          'status',
          vector ? ModelDataStatusEnum.ready : ModelDataStatusEnum.waiting
        ]);
      })
    );
    // kick off async vector generation for the 'waiting' rows
    generateVector(true);
    jsonRes(res, {
      // report how many inserts were rejected
      data: insertRedisRes.filter((item) => item.status === 'rejected').length
    });
  } catch (err) {
    jsonRes(res, {
      code: 500,
      error: err
    });
  }
}

View File

@@ -1,57 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, DataItem, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
let { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string };
if (!dataIds) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
const dataItems = (
await Promise.all(
dataIds.map((dataId) =>
DataItem.find<{ _id: string; result: { q: string }[]; text: string }>(
{
userId,
dataId
},
'result text'
)
)
)
).flat();
// push data
await ModelData.insertMany(
dataItems.map((item) => ({
modelId: modelId,
userId,
text: item.text,
q: item.result.map((item) => ({
id: nanoid(),
text: item.q
}))
}))
);
jsonRes(res, {
data: dataItems
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -1,14 +1,13 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { ModelDataStatusEnum } from '@/constants/redis';
import { generateVector } from '@/service/events/generateVector';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
let { dataId, text } = req.body as {
dataId: string;
text: string;
};
const { dataId, text, q } = req.body as { dataId: string; text: string; q?: string };
const { authorization } = req.headers;
if (!authorization) {
@@ -22,17 +21,26 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验
const userId = await authToken(authorization);
await connectToDatabase();
const redis = await connectRedis();
await ModelData.updateOne(
{
_id: dataId,
userId
},
{
text
}
);
// 校验是否为该用户的数据
const dataItemUserId = await redis.hGet(dataId, 'userId');
if (dataItemUserId !== userId) {
throw new Error('无权操作');
}
// 更新
await redis.sendCommand([
'HMSET',
dataId,
...(q ? ['q', q, 'status', ModelDataStatusEnum.waiting] : []),
'text',
text
]);
if (q) {
generateVector();
}
jsonRes(res);
} catch (err) {

View File

@@ -8,8 +8,8 @@ import { encode } from 'gpt-token-utils';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { text, modelId } = req.body as { text: string; modelId: string };
if (!text || !modelId) {
const { text, modelId, prompt } = req.body as { text: string; modelId: string; prompt: string };
if (!text || !modelId || !prompt) {
throw new Error('参数错误');
}
await connectToDatabase();
@@ -31,17 +31,25 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const replaceText = text.replace(/(\\n|\n)+/g, ' ');
// 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
const chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
const textList: string[] = [];
let splitText = '';
/* 取 3k ~ 4K tokens 内容 */
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
if (tokens >= 980) {
const tokens = encode(splitText + chunk).length;
if (tokens >= 4000) {
// 超过 4000不要这块内容
textList.push(splitText);
splitText = chunk;
} else if (tokens >= 3000) {
// 超过 3000取内容
textList.push(splitText + chunk);
splitText = '';
} else {
//没超过 3000继续添加
splitText += chunk;
}
});
@@ -54,7 +62,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
userId,
modelId,
rawText: text,
textList
textList,
prompt
});
generateQA();

View File

@@ -1,13 +1,13 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
import { authToken, getUserApiOpenai } from '@/service/utils/tools';
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { getUserApiOpenai } from '@/service/utils/openai';
import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat';
import { TrainingItemType } from '@/types/training';
import { httpsAgent } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
import { VecModelDataIdx } from '@/constants/redis';
/* 获取我的模型 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
@@ -26,39 +26,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验
const userId = await authToken(authorization);
// 验证是否是该用户的 model
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权操作该模型');
}
await connectToDatabase();
const redis = await connectRedis();
const modelDataList = await ModelData.find({
// 获取 redis 中模型关联的所有数据
const searchRes = await redis.ft.search(
VecModelDataIdx,
`@modelId:{${modelId}} @userId:{${userId}}`,
{
LIMIT: {
from: 0,
size: 10000
}
}
);
// 删除 redis 内容
await Promise.all(searchRes.documents.map((item) => redis.del(item.id)));
// 删除对应的聊天
await Chat.deleteMany({
modelId
});
// 删除 redis
modelDataList?.forEach((modelData) =>
modelData.q.forEach(async (item) => {
try {
await redis.json.del(`${VecModelDataIndex}:${item.id}`);
} catch (error) {
console.log(error);
}
})
);
let requestQueue: any[] = [];
// 删除对应的聊天
requestQueue.push(
Chat.deleteMany({
modelId
})
);
// 删除数据集
requestQueue.push(
ModelData.deleteMany({
modelId
})
);
// 查看是否正在训练
const training: TrainingItemType | null = await Training.findOne({
modelId,
@@ -78,21 +77,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}
// 删除对应训练记录
requestQueue.push(
Training.deleteMany({
modelId
})
);
await Training.deleteMany({
modelId
});
// 删除模型
requestQueue.push(
Model.deleteOne({
_id: modelId,
userId
})
);
await Promise.all(requestQueue);
await Model.deleteOne({
_id: modelId,
userId
});
jsonRes(res);
} catch (err) {

View File

@@ -1,8 +1,8 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Model, Training } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat';
import { authToken, getUserApiOpenai } from '@/service/utils/tools';
import { authToken } from '@/service/utils/tools';
import { getUserApiOpenai } from '@/service/utils/openai';
import type { ModelSchema } from '@/types/mongoSchema';
import { TrainingItemType } from '@/types/training';
import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model';

View File

@@ -3,7 +3,8 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Model, Training } from '@/service/mongo';
import formidable from 'formidable';
import { authToken, getUserApiOpenai } from '@/service/utils/tools';
import { authToken } from '@/service/utils/tools';
import { getUserApiOpenai } from '@/service/utils/openai';
import { join } from 'path';
import fs from 'fs';
import type { ModelSchema } from '@/types/mongoSchema';

View File

@@ -1,68 +0,0 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Bill } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import type { BillSchema } from '@/types/mongoSchema';
import { VecModelDataIndex } from '@/constants/redis';
import { connectRedis } from '@/service/redis';
import { vectorToBuffer } from '@/utils/tools';
let vectorData = [
-0.025028639, -0.010407282, 0.026523087, -0.0107438695, -0.006967359, 0.010043768, -0.012043097,
0.008724345, -0.028919589, -0.0117738275, 0.0050690062, 0.02961969
].concat(new Array(1524).fill(0));
let vectorData2 = [
0.025028639, 0.010407282, 0.026523087, 0.0107438695, -0.006967359, 0.010043768, -0.012043097,
0.008724345, 0.028919589, 0.0117738275, 0.0050690062, 0.02961969
].concat(new Array(1524).fill(0));
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
if (process.env.NODE_ENV !== 'development') {
throw new Error('不是开发环境');
}
await connectToDatabase();
const redis = await connectRedis();
await redis.sendCommand([
'HMSET',
'model:data:333',
'vector',
vectorToBuffer(vectorData2),
'modelId',
'1133',
'dataId',
'safadfa'
]);
// search
const response = await redis.sendCommand([
'FT.SEARCH',
'idx:model:data:hash',
'@modelId:{1133} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}',
'RETURN',
'2',
'modelId',
'dataId',
'PARAMS',
'2',
'blob',
vectorToBuffer(vectorData2),
'SORTBY',
'score',
'DIALECT',
'2'
]);
jsonRes(res, {
data: response
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -4,8 +4,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, Training, Model } from '@/service/mongo';
import type { TrainingItemType } from '@/types/training';
import { TrainingStatusEnum, ModelStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat';
import { getUserApiOpenai } from '@/service/utils/tools';
import { getUserApiOpenai } from '@/service/utils/openai';
import { OpenAiTuneStatusEnum } from '@/service/constants/training';
import { sendTrainSucceed } from '@/service/utils/sendEmail';
import { httpsAgent } from '@/service/utils/tools';

View File

@@ -30,7 +30,6 @@ const Empty = ({ intro }: { intro: string }) => {
<Markdown source={versionIntro} />
</Card>
<Card p={4}>
<Header></Header>
<Markdown source={chatProblem} />
</Card>
</Box>

View File

@@ -120,6 +120,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
const urlMap: Record<string, string> = {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
// [ChatModelNameEnum.VECTOR_GPT]: '/api/chat/lafGpt',
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
};
@@ -191,14 +192,22 @@ const Chat = ({ chatId }: { chatId: string }) => {
* 发送一个内容
*/
const sendPrompt = useCallback(async () => {
if (isChatting) {
toast({
title: '正在聊天中...请等待结束',
status: 'warning'
});
return;
}
const storeInput = inputVal;
// 去除空行
const val = inputVal
.trim()
.split('\n')
.filter((val) => val)
.join('\n');
if (!chatData?.modelId || !val || !ChatBox.current || isChatting) {
const val = inputVal.trim().replace(/\n\s*/g, '\n');
if (!chatData?.modelId || !val) {
toast({
title: '内容为空',
status: 'warning'
});
return;
}
@@ -452,9 +461,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
</Box>
{/* 发送区 */}
<Box m={media('20px auto', '0 auto')} w={'100%'} maxW={media('min(750px, 100%)', 'auto')}>
<Flex
alignItems={'flex-end'}
py={5}
<Box
py={'18px'}
position={'relative'}
boxShadow={`0 0 15px rgba(0,0,0,0.1)`}
border={media('1px solid', '0')}
@@ -465,9 +473,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
{/* 输入框 */}
<Textarea
ref={TextareaDom}
flex={1}
w={0}
py={0}
pr={['45px', '55px']}
border={'none'}
_focusVisible={{
border: 'none'
@@ -481,6 +488,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
maxHeight={'150px'}
maxLength={-1}
overflowY={'auto'}
whiteSpace={'pre-wrap'}
wordBreak={'break-all'}
color={useColorModeValue('blackAlpha.700', 'white')}
onChange={(e) => {
const textarea = e.target;
@@ -500,27 +509,34 @@ const Chat = ({ chatId }: { chatId: string }) => {
}}
/>
{/* 发送和等待按键 */}
<Box px={4} onClick={sendPrompt}>
<Flex
alignItems={'center'}
justifyContent={'center'}
h={'30px'}
w={'30px'}
position={'absolute'}
right={['12px', '20px']}
bottom={'15px'}
onClick={sendPrompt}
>
{isChatting ? (
<Image
style={{ transform: 'translateY(4px)' }}
src={'/icon/chatting.svg'}
width={30}
height={30}
fill
alt={''}
/>
) : (
<Box cursor={'pointer'}>
<Icon
name={'chatSend'}
width={'20px'}
height={'20px'}
fill={useColorModeValue('#718096', 'white')}
></Icon>
</Box>
<Icon
name={'chatSend'}
width={['18px', '20px']}
height={['18px', '20px']}
cursor={'pointer'}
fill={useColorModeValue('#718096', 'white')}
></Icon>
)}
</Box>
</Flex>
</Flex>
</Box>
</Box>
</Flex>
</Flex>

View File

@@ -71,7 +71,6 @@ const Login = () => {
order={1}
flex={`0 0 ${isPc ? '400px' : '100%'}`}
height={'100%'}
maxH={'450px'}
border="1px"
borderColor="gray.200"
py={5}

View File

@@ -1,7 +1,6 @@
import React, { useState, useCallback } from 'react';
import {
Box,
IconButton,
Flex,
Button,
Modal,
@@ -9,66 +8,61 @@ import {
ModalContent,
ModalHeader,
ModalCloseButton,
Input,
Textarea
} from '@chakra-ui/react';
import { useForm, useFieldArray } from 'react-hook-form';
import { postModelDataInput } from '@/api/model';
import { useForm } from 'react-hook-form';
import { postModelDataInput, putModelDataById } from '@/api/model';
import { useToast } from '@/hooks/useToast';
import { DeleteIcon } from '@chakra-ui/icons';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
type FormData = { text: string; q: { val: string }[] };
export type FormData = { dataId?: string; text: string; q: string };
const InputDataModal = ({
onClose,
onSuccess,
modelId
modelId,
defaultValues = {
text: '',
q: ''
}
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
defaultValues?: FormData;
}) => {
const [importing, setImporting] = useState(false);
const { toast } = useToast();
const { register, handleSubmit, control } = useForm<FormData>({
defaultValues: {
text: '',
q: [{ val: '' }]
}
});
const {
fields: inputQ,
append: appendQ,
remove: removeQ
} = useFieldArray({
control,
name: 'q'
const { register, handleSubmit } = useForm<FormData>({
defaultValues
});
/**
* 确认导入新数据
*/
const sureImportData = useCallback(
async (e: FormData) => {
setImporting(true);
try {
await postModelDataInput({
const res = await postModelDataInput({
modelId: modelId,
data: [
{
text: e.text,
q: e.q.map((item) => ({
q: {
id: nanoid(),
text: item.val
}))
text: e.q
}
}
]
});
toast({
title: '导入数据成功,需要一段时间训练',
status: 'success'
title: res === 0 ? '导入数据成功,需要一段时间训练' : '数据导入异常',
status: res === 0 ? 'success' : 'warning'
});
onClose();
onSuccess();
@@ -80,56 +74,83 @@ const InputDataModal = ({
[modelId, onClose, onSuccess, toast]
);
const updateData = useCallback(
async (e: FormData) => {
if (!e.dataId) return;
if (e.text === defaultValues.text && e.q === defaultValues.q) return;
await putModelDataById({
dataId: e.dataId,
text: e.text,
q: e.q === defaultValues.q ? '' : e.q
});
toast({
title: '修改回答成功',
status: 'success'
});
onClose();
onSuccess();
},
[defaultValues.q, onClose, onSuccess, toast]
);
return (
<Modal isOpen={true} onClose={onClose}>
<Modal isOpen={true} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} maxH={'80vh'} position={'relative'}>
<ModalContent
m={0}
display={'flex'}
flexDirection={'column'}
h={'90vh'}
maxW={'90vw'}
position={'relative'}
>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<Box px={6} pb={2} overflowY={'auto'}>
<Box mb={2}>:</Box>
<Textarea
mb={4}
placeholder="知识点"
rows={3}
maxH={'200px'}
{...register(`text`, {
required: '知识点'
})}
/>
{inputQ.map((item, index) => (
<Box key={item.id} mb={5}>
<Box mb={2}>{index + 1}:</Box>
<Flex>
<Input
placeholder="问法"
{...register(`q.${index}.val`, {
required: '问法不能为空'
})}
></Input>
{inputQ.length > 1 && (
<IconButton
icon={<DeleteIcon />}
aria-label={'delete'}
colorScheme={'gray'}
variant={'unstyled'}
onClick={() => removeQ(index)}
/>
)}
</Flex>
</Box>
))}
<Box
display={['block', 'flex']}
flex={'1 0 0'}
h={['100%', 0]}
overflowY={'auto'}
px={6}
pb={2}
>
<Box flex={2} mr={[0, 4]} mb={[4, 0]} h={['230px', '100%']}>
<Box h={'30px'}></Box>
<Textarea
placeholder="相关问题,可以回车输入多个问法, 最多500字"
maxLength={500}
resize={'none'}
h={'calc(100% - 30px)'}
{...register(`q`, {
required: '相关问题,可以回车输入多个问法'
})}
/>
</Box>
<Box flex={3} h={['330px', '100%']}>
<Box h={'30px'}></Box>
<Textarea
placeholder="知识点,最多1000字"
maxLength={1000}
resize={'none'}
h={'calc(100% - 30px)'}
{...register(`text`, {
required: '知识点'
})}
/>
</Box>
</Box>
<Flex px={6} pt={2} pb={4}>
<Button alignSelf={'flex-start'} variant={'outline'} onClick={() => appendQ({ val: '' })}>
</Button>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={importing} onClick={handleSubmit(sureImportData)}>
<Button
isLoading={importing}
onClick={handleSubmit(defaultValues.dataId ? updateData : sureImportData)}
>
</Button>
</Flex>

View File

@@ -1,4 +1,4 @@
import React, { useCallback } from 'react';
import React, { useCallback, useState } from 'react';
import {
Box,
TableContainer,
@@ -12,33 +12,33 @@ import {
Flex,
Button,
useDisclosure,
Textarea,
Menu,
MenuButton,
MenuList,
MenuItem
} from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema';
import { ModelDataSchema } from '@/types/mongoSchema';
import type { RedisModelDataItemType } from '@/types/redis';
import { ModelDataStatusMap } from '@/constants/model';
import { usePagination } from '@/hooks/usePagination';
import {
getModelDataList,
delOneModelData,
putModelDataById,
getModelSplitDataList
getModelSplitDataList,
getExportDataList
} from '@/api/model';
import { DeleteIcon, RepeatIcon } from '@chakra-ui/icons';
import { DeleteIcon, RepeatIcon, EditIcon } from '@chakra-ui/icons';
import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import dynamic from 'next/dynamic';
import { useQuery } from '@tanstack/react-query';
import { useMutation, useQuery } from '@tanstack/react-query';
import type { FormData as InputDataType } from './InputDataModal';
const InputModel = dynamic(() => import('./InputDataModal'));
const SelectModel = dynamic(() => import('./SelectFileModal'));
const SelectFileModel = dynamic(() => import('./SelectFileModal'));
const SelectJsonModel = dynamic(() => import('./SelectJsonModal'));
const ModelDataCard = ({ model }: { model: ModelSchema }) => {
const { toast } = useToast();
const { Loading } = useLoading();
const {
@@ -48,40 +48,28 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
total,
getData,
pageNum
} = usePagination<ModelDataSchema>({
} = usePagination<RedisModelDataItemType>({
api: getModelDataList,
pageSize: 10,
pageSize: 8,
params: {
modelId: model._id
}
});
const updateAnswer = useCallback(
async (dataId: string, text: string) => {
await putModelDataById({
dataId,
text
});
toast({
title: '修改回答成功',
status: 'success'
});
},
[toast]
);
const [editInputData, setEditInputData] = useState<InputDataType>();
const {
isOpen: isOpenInputModal,
onOpen: onOpenInputModal,
onClose: onCloseInputModal
isOpen: isOpenSelectFileModal,
onOpen: onOpenSelectFileModal,
onClose: onCloseSelectFileModal
} = useDisclosure();
const {
isOpen: isOpenSelectModal,
onOpen: onOpenSelectModal,
onClose: onCloseSelectModal
isOpen: isOpenSelectJsonModal,
onOpen: onOpenSelectJsonModal,
onClose: onCloseSelectJsonModal
} = useDisclosure();
const { data, refetch } = useQuery(['getModelSplitDataList'], () =>
const { data: splitDataList, refetch } = useQuery(['getModelSplitDataList'], () =>
getModelSplitDataList(model._id)
);
@@ -93,10 +81,29 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
[getData, refetch]
);
// 获取所有的数据,并导出 json
const { mutate: onclickExport, isLoading: isLoadingExport } = useMutation({
mutationFn: () => getExportDataList(model._id),
onSuccess(res) {
// 导出为文件
const blob = new Blob([res], { type: 'application/json;charset=utf-8' });
// 创建下载链接
const downloadLink = document.createElement('a');
downloadLink.href = window.URL.createObjectURL(blob);
downloadLink.download = `data.json`;
// 添加链接到页面并触发下载
document.body.appendChild(downloadLink);
downloadLink.click();
document.body.removeChild(downloadLink);
}
});
return (
<>
<Flex>
<Box fontWeight={'bold'} fontSize={'lg'} flex={1}>
<Box fontWeight={'bold'} fontSize={'lg'} flex={1} mr={2}>
: {total}{' '}
<Box as={'span'} fontSize={'sm'}>
@@ -107,64 +114,84 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
aria-label={'refresh'}
variant={'outline'}
mr={4}
size={'sm'}
onClick={() => refetchData(pageNum)}
/>
<Button
variant={'outline'}
mr={2}
size={'sm'}
isLoading={isLoadingExport}
title={'v2.3之前版本的数据无法导出'}
onClick={() => onclickExport()}
>
</Button>
<Menu>
<MenuButton as={Button}></MenuButton>
<MenuButton as={Button} size={'sm'}>
</MenuButton>
<MenuList>
<MenuItem onClick={onOpenInputModal}></MenuItem>
<MenuItem onClick={onOpenSelectModal}></MenuItem>
<MenuItem
onClick={() =>
setEditInputData({
text: '',
q: ''
})
}
>
</MenuItem>
<MenuItem onClick={onOpenSelectFileModal}></MenuItem>
<MenuItem onClick={onOpenSelectJsonModal}>JSON导入</MenuItem>
</MenuList>
</Menu>
</Flex>
{data && data.length > 0 && <Box fontSize={'xs'}>{data.length}...</Box>}
{splitDataList && splitDataList.length > 0 && (
<Box fontSize={'xs'}>
{splitDataList.map((item) => item.textList).flat().length}...
</Box>
)}
<Box mt={4}>
<TableContainer h={'600px'} overflowY={'auto'}>
<TableContainer minH={'500px'}>
<Table variant={'simple'}>
<Thead>
<Tr>
<Th>Question</Th>
<Th>Text</Th>
<Th>Status</Th>
<Th></Th>
<Th></Th>
</Tr>
</Thead>
<Tbody>
{modelDataList.map((item) => (
<Tr key={item._id}>
<Td w={'350px'}>
{item.q.map((item, i) => (
<Box
key={item.id}
fontSize={'xs'}
w={'100%'}
whiteSpace={'pre-wrap'}
_notLast={{ mb: 1 }}
>
Q{i + 1}:{' '}
<Box as={'span'} userSelect={'all'}>
{item.text}
</Box>
</Box>
))}
<Tr key={item.id}>
<Td>
<Box fontSize={'xs'} w={'100%'} whiteSpace={'pre-wrap'}>
{item.q}
</Box>
</Td>
<Td minW={'200px'}>
<Textarea
w={'100%'}
h={'100%'}
defaultValue={item.text}
fontSize={'xs'}
resize={'both'}
onBlur={(e) => {
const oldVal = modelDataList.find((data) => item._id === data._id)?.text;
if (oldVal !== e.target.value) {
updateAnswer(item._id, e.target.value);
}
}}
></Textarea>
<Box w={'100%'} fontSize={'xs'} whiteSpace={'pre-wrap'}>
{item.text}
</Box>
</Td>
<Td w={'100px'}>{ModelDataStatusMap[item.status]}</Td>
<Td>{ModelDataStatusMap[item.status]}</Td>
<Td>
<IconButton
mr={5}
icon={<EditIcon />}
variant={'outline'}
aria-label={'delete'}
size={'sm'}
onClick={() =>
setEditInputData({
dataId: item.id,
q: item.q,
text: item.text
})
}
/>
<IconButton
icon={<DeleteIcon />}
variant={'outline'}
@@ -172,7 +199,7 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
aria-label={'delete'}
size={'sm'}
onClick={async () => {
await delOneModelData(item._id);
await delOneModelData(item.id);
refetchData(pageNum);
}}
/>
@@ -188,11 +215,27 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
</Box>
<Loading loading={isLoading} fixed={false} />
{isOpenInputModal && (
<InputModel modelId={model._id} onClose={onCloseInputModal} onSuccess={refetchData} />
{editInputData !== undefined && (
<InputModel
modelId={model._id}
defaultValues={editInputData}
onClose={() => setEditInputData(undefined)}
onSuccess={refetchData}
/>
)}
{isOpenSelectModal && (
<SelectModel modelId={model._id} onClose={onCloseSelectModal} onSuccess={refetchData} />
{isOpenSelectFileModal && (
<SelectFileModel
modelId={model._id}
onClose={onCloseSelectFileModal}
onSuccess={refetchData}
/>
)}
{isOpenSelectJsonModal && (
<SelectJsonModel
modelId={model._id}
onClose={onCloseSelectJsonModal}
onSuccess={refetchData}
/>
)}
</>
);

View File

@@ -108,7 +108,7 @@ const ModelEditForm = ({
<Slider
aria-label="slider-ex-1"
min={1}
min={0}
max={10}
step={1}
value={getValues('temperature')}
@@ -138,24 +138,17 @@ const ModelEditForm = ({
</Flex>
</FormControl>
<Box mt={4}>
{canTrain ? (
<Box fontWeight={'bold'}>
prompt
使 tokens
</Box>
) : (
<>
<Box mb={1}></Box>
<Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
</>
)}
<Box mb={1}></Box>
<Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
canTrain
? '训练的模型会根据知识库内容,生成一部分系统提示词,因此在对话时需要消耗更多的 tokens。你仍可以增加一些提示词让其效果更精确。'
: '模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
</Box>
</Card>
{/* <Card p={4}>

View File

@@ -8,7 +8,8 @@ import {
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody
ModalBody,
Input
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useSelectFile } from '@/hooks/useSelectFile';
@@ -34,6 +35,7 @@ const SelectFileModal = ({
}) => {
const [selecting, setSelecting] = useState(false);
const { toast } = useToast();
const [prompt, setPrompt] = useState('');
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [fileText, setFileText] = useState('');
const { openConfirm, ConfirmChild } = useConfirm({
@@ -83,7 +85,11 @@ const SelectFileModal = ({
const { mutate, isLoading } = useMutation({
mutationFn: async () => {
if (!fileText) return;
await postModelDataFileText(modelId, fileText);
await postModelDataFileText({
modelId,
text: fileText,
prompt: `下面是${prompt || '一段长文本'}`
});
toast({
title: '导入数据成功,需要一段拆解和训练',
status: 'success'
@@ -100,40 +106,54 @@ const SelectFileModal = ({
});
return (
<Modal isOpen={true} onClose={onClose}>
<Modal isOpen={true} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} position={'relative'}>
<ModalContent maxW={'min(900px, 90vw)'} m={0} position={'relative'} h={'90vh'}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex
flexDirection={'column'}
p={2}
h={'100%'}
alignItems={'center'}
justifyContent={'center'}
fontSize={'sm'}
>
<Button isLoading={selecting} onClick={onOpen}>
</Button>
<Box mt={2}> {fileExtension} . </Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
h={'300px'}
w={'100%'}
overflow={'auto'}
p={2}
backgroundColor={'blackAlpha.50'}
whiteSpace={'pre'}
fontSize={'xs'}
>
{fileText}
<ModalBody
display={'flex'}
flexDirection={'column'}
p={4}
h={'100%'}
alignItems={'center'}
justifyContent={'center'}
fontSize={'sm'}
>
<Button isLoading={selecting} onClick={onOpen}>
</Button>
<Box mt={2} maxW={['100%', '70%']}>
{fileExtension} QA
tokens0.04/1k tokens
</Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Flex w={'100%'} alignItems={'center'} my={4}>
<Box flex={'0 0 auto'} mr={2}>
</Box>
<Input
placeholder="提示词,例如: Laf的介绍/关于gpt4的论文/一段长文本"
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
size={'sm'}
/>
</Flex>
<Box
flex={'1 0 0'}
h={0}
w={'100%'}
overflowY={'auto'}
p={2}
backgroundColor={'blackAlpha.50'}
whiteSpace={'pre-wrap'}
fontSize={'xs'}
>
{fileText}
</Box>
</ModalBody>
<Flex px={6} pt={2} pb={4}>

View File

@@ -0,0 +1,145 @@
import React, { useState, useCallback } from 'react';
import {
Box,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useSelectFile } from '@/hooks/useSelectFile';
import { useConfirm } from '@/hooks/useConfirm';
import { readTxtContent } from '@/utils/tools';
import { useMutation } from '@tanstack/react-query';
import { postModelDataJsonData } from '@/api/model';
import Markdown from '@/components/Markdown';
const SelectJsonModal = ({
onClose,
onSuccess,
modelId
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
}) => {
const [selecting, setSelecting] = useState(false);
const { toast } = useToast();
const { File, onOpen } = useSelectFile({ fileType: '.json', multiple: true });
const [fileData, setFileData] = useState<
{ prompt: string; completion: string; vector?: number[] }[]
>([]);
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认导入该数据集?'
});
const onSelectFile = useCallback(
async (e: File[]) => {
setSelecting(true);
try {
const jsonData = (
await Promise.all(e.map((item) => readTxtContent(item).then((text) => JSON.parse(text))))
).flat();
// check 文件类型
for (let i = 0; i < jsonData.length; i++) {
if (!jsonData[i]?.prompt || !jsonData[i]?.completion) {
throw new Error('缺少 prompt 或 completion');
}
}
setFileData(jsonData);
} catch (error: any) {
console.log(error);
toast({
title: error?.message || 'JSON文件格式有误',
status: 'error'
});
}
setSelecting(false);
},
[setSelecting, toast]
);
const { mutate, isLoading } = useMutation({
mutationFn: async () => {
if (!fileData) return;
const res = await postModelDataJsonData(modelId, fileData);
console.log(res);
toast({
title: '导入数据成功,需要一段时间训练',
status: 'success'
});
onClose();
onSuccess();
},
onError() {
toast({
title: '导入文件失败',
status: 'error'
});
}
});
return (
<Modal isOpen={true} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent maxW={'90vw'} position={'relative'} m={0} h={'90vh'}>
<ModalHeader>JSON数据集</ModalHeader>
<ModalCloseButton />
<ModalBody h={'100%'} display={['block', 'flex']} fontSize={'sm'} overflowY={'auto'}>
<Box flex={'2 0 0'} w={['100%', 0]} mr={[0, 4]} mb={[4, 0]}>
<Markdown
source={`接受一个对象数组,每个对象必须包含 prompt 和 completion 格式可以包含vector。prompt 代表问题completion 代表回答的内容可以多个问题对应一个回答vector 为 prompt 的向量,如果没有讲有系统生成。例如:
~~~json
[
{
"prompt":"sealos是什么?\\n介绍下sealos\\nsealos有什么用",
"completion":"sealos是xxxxxx"
},
{
"prompt":"laf是什么?",
"completion":"laf是xxxxxx",
"vector":[-0.42,-0.4314314,0.43143]
}
]
~~~`}
/>
<Flex alignItems={'center'}>
<Button isLoading={selecting} onClick={onOpen}>
JSON
</Button>
<Box ml={4}> {fileData.length} </Box>
</Flex>
</Box>
<Box flex={'2 0 0'} h={'100%'} overflow={'auto'} p={2} backgroundColor={'blackAlpha.50'}>
{JSON.stringify(fileData)}
</Box>
</ModalBody>
<Flex px={6} pt={2} pb={4}>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button
isLoading={isLoading}
isDisabled={fileData.length === 0}
onClick={openConfirm(mutate)}
>
</Button>
</Flex>
</ModalContent>
<ConfirmChild />
<File onSelect={onSelectFile} />
</Modal>
);
};
export default SelectJsonModal;

View File

@@ -42,12 +42,12 @@ const ModelTable = ({
dataIndex: 'status',
render: (item: ModelSchema) => (
<Tag
colorScheme={formatModelStatus[item.status].colorTheme}
colorScheme={formatModelStatus[item.status]?.colorTheme}
variant="solid"
px={3}
size={'md'}
>
{formatModelStatus[item.status].text}
{formatModelStatus[item.status]?.text}
</Tag>
)
},

View File

@@ -1,14 +1,13 @@
import { DataItem } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import { httpsAgent } from '@/service/utils/tools';
import { getOpenApiKey } from '../utils/openai';
import type { ChatCompletionRequestMessage } from 'openai';
import { DataItemSchema } from '@/types/mongoSchema';
import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
export async function generateAbstract(next = false): Promise<any> {
if (process.env.NODE_ENV === 'development') return;
if (global.generatingAbstract && !next) return;
global.generatingAbstract = true;
@@ -85,36 +84,6 @@ export async function generateAbstract(next = false): Promise<any> {
const rawContent: string = abstractResponse?.data.choices[0].message?.content || '';
// 从 content 中提取摘要内容
const splitContents = splitText(rawContent);
// console.log(rawContent);
// 生成词向量
// const vectorResponse = await Promise.allSettled(
// splitContents.map((item) =>
// chatAPI.createEmbedding(
// {
// model: 'text-embedding-ada-002',
// input: item.abstract
// },
// {
// timeout: 120000,
// httpsAgent
// }
// )
// )
// );
// 筛选成功的向量请求
// const vectorSuccessResponse = vectorResponse
// .map((item: any, i) => {
// if (item.status !== 'fulfilled') {
// // 没有词向量的【摘要】不要
// console.log('获取词向量错误: ', item);
// return '';
// }
// return {
// abstract: splitContents[i].abstract,
// abstractVector: item?.value?.data?.data?.[0]?.embedding
// };
// })
// .filter((item) => item);
// 插入数据库,并修改状态
await DataItem.findByIdAndUpdate(dataItem._id, {

View File

@@ -1,10 +1,13 @@
import { SplitData, ModelData } from '@/service/mongo';
import { SplitData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import { httpsAgent } from '@/service/utils/tools';
import { getOpenApiKey } from '../utils/openai';
import type { ChatCompletionRequestMessage } from 'openai';
import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { connectRedis } from '../redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
@@ -12,12 +15,8 @@ export async function generateQA(next = false): Promise<any> {
if (global.generatingQA && !next) return;
global.generatingQA = true;
const systemPrompt: ChatCompletionRequestMessage = {
role: 'system',
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,并按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
};
try {
const redis = await connectRedis();
// 找出一个需要生成的 dataItem
const dataItem = await SplitData.findOne({
textList: { $exists: true, $ne: [] }
@@ -29,8 +28,10 @@ export async function generateQA(next = false): Promise<any> {
return;
}
// 源文本
const text = dataItem.textList[dataItem.textList.length - 1];
if (!text) {
await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }); // 弹出无效文本
throw new Error('无文本');
}
@@ -47,6 +48,7 @@ export async function generateQA(next = false): Promise<any> {
textList: [],
errorText: error.message
});
throw new Error('账号余额不足');
}
throw new Error('获取 openai key 失败');
@@ -58,12 +60,19 @@ export async function generateQA(next = false): Promise<any> {
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey);
const systemPrompt: ChatCompletionRequestMessage = {
role: 'system',
content: `${
dataItem.prompt || '下面是一段长文本'
},请从中提取出5至30个问题和答案,并按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
};
// 请求 chatgpt 获取回答
const response = await chatAPI
.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
temperature: 0.2,
temperature: 0.8,
n: 1,
messages: [
systemPrompt,
@@ -74,31 +83,34 @@ export async function generateQA(next = false): Promise<any> {
]
},
{
timeout: 120000,
timeout: 180000,
httpsAgent
}
)
.then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})); // 从 content 中提取 QA
rawContent: res?.data.choices[0].message?.content || '', // chatgpt原本的回复
result: splitText(res?.data.choices[0].message?.content || '') // 格式化后的QA对
}));
await Promise.allSettled([
SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }),
ModelData.insertMany(
response.result.map((item) => ({
modelId: dataItem.modelId,
userId: dataItem.userId,
text: item.a,
q: [
{
id: nanoid(),
text: item.q
}
],
status: 1
}))
)
SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }), // 弹出已经拆分的文本
...response.result.map((item) => {
// 插入 redis
return redis.sendCommand([
'HMSET',
`${VecModelDataPrefix}:${nanoid()}`,
'userId',
String(dataItem.userId),
'modelId',
String(dataItem.modelId),
'q',
item.q,
'text',
item.a,
'status',
'waiting'
]);
})
]);
console.log(
@@ -132,7 +144,7 @@ export async function generateQA(next = false): Promise<any> {
* 检查文本是否按格式返回
*/
function splitText(text: string) {
const regex = /Q\d+:(\s*)(.*)(\s*)A\d+:(\s*)(.*)(\s*)/g; // 匹配Q和A的正则表达式
const regex = /Q\d+:(\s*)(.*)(\s*)A\d+:(\s*)([\s\S]*?)(?=Q|$)/g; // 匹配Q和A的正则表达式
const matches = text.matchAll(regex); // 获取所有匹配到的结果
const result = []; // 存储最终的结果
@@ -140,7 +152,11 @@ function splitText(text: string) {
const q = match[2];
const a = match[5];
if (q && a) {
result.push({ q, a }); // 如果Q和A都存在就将其添加到结果中
// 如果Q和A都存在就将其添加到结果中
result.push({
q,
a: a.trim().replace(/\n\s*/g, '\n')
});
}
}

View File

@@ -1,9 +1,10 @@
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent } from '@/service/utils/tools';
import { ModelData } from '../models/modelData';
import { connectRedis } from '../redis';
import { VecModelDataIndex } from '@/constants/redis';
import { VecModelDataIdx } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { ModelDataStatusEnum } from '@/constants/redis';
import { openaiCreateEmbedding, getOpenApiKey } from '../utils/openai';
export async function generateVector(next = false): Promise<any> {
if (global.generatingVector && !next) return;
@@ -12,74 +13,65 @@ export async function generateVector(next = false): Promise<any> {
try {
const redis = await connectRedis();
// 找出一个需要生成的 dataItem
const dataItem = await ModelData.findOne({
status: { $ne: 0 }
});
// 找出一个 status = waiting 的数据
const searchRes = await redis.ft.search(
VecModelDataIdx,
`@status:{${ModelDataStatusEnum.waiting}}`,
{
RETURN: ['q', 'userId'],
LIMIT: {
from: 0,
size: 1
}
}
);
if (!dataItem) {
if (searchRes.total === 0) {
console.log('没有需要生成 【向量】 的数据');
global.generatingVector = false;
return;
}
const dataItem: { id: string; q: string; userId: string } = {
id: searchRes.documents[0].id,
q: String(searchRes.documents[0]?.value?.q || ''),
userId: String(searchRes.documents[0]?.value?.userId || '')
};
// 获取 openapi Key
const openAiKey = process.env.OPENAIKEY as string;
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(openAiKey);
const dataId = String(dataItem._id);
const { userApiKey, systemKey } = await getOpenApiKey(dataItem.userId);
// 生成词向量
const response = await Promise.allSettled(
dataItem.q.map((item, i) =>
chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: item.text
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || [])
.then((vector) =>
redis.sendCommand([
'HMSET',
`${VecModelDataIndex}:${item.id}`,
'vector',
vectorToBuffer(vector),
'modelId',
String(dataItem.modelId),
'dataId',
String(dataId)
])
)
)
);
if (response.filter((item) => item.status === 'fulfilled').length === 0) {
throw new Error(JSON.stringify(response));
}
// 修改该数据状态
await ModelData.findByIdAndUpdate(dataItem._id, {
status: 0
const { vector } = await openaiCreateEmbedding({
text: dataItem.q,
userId: dataItem.userId,
isPay: !userApiKey,
apiKey: userApiKey || systemKey
});
console.log(`生成向量成功: ${dataItem._id}`);
// 更新 redis 向量和状态数据
await redis.sendCommand([
'HMSET',
dataItem.id,
'vector',
vectorToBuffer(vector),
'rawVector',
JSON.stringify(vector),
'status',
ModelDataStatusEnum.ready
]);
console.log(`生成向量成功: ${dataItem.id}`);
setTimeout(() => {
generateVector(true);
}, 3000);
}, 2000);
} catch (error: any) {
console.log(error);
console.log('error: 生成向量错误', error?.response?.data);
console.log('error: 生成向量错误', error?.response?.statusText);
!error?.response && console.log(error);
if (error?.response?.statusText === 'Too Many Requests') {
console.log('次数限制1分钟后尝试');
console.log('生成向量次数限制1分钟后尝试');
// 限制次数1分钟后再试
setTimeout(() => {
generateVector(true);

View File

@@ -2,6 +2,7 @@ import { connectToDatabase, Bill, User } from '../mongo';
import { modelList, ChatModelNameEnum } from '@/constants/model';
import { encode } from 'gpt-token-utils';
import { formatPrice } from '@/utils/user';
import { BillTypeEnum } from '@/constants/user';
import type { DataType } from '@/types/data';
export const pushChatBill = async ({
@@ -23,8 +24,7 @@ export const pushChatBill = async ({
// 计算 token 数量
const tokens = encode(text);
console.log('text len: ', text.length);
console.log('token len:', tokens.length);
console.log(`chat generate success. text len: ${text.length}. token len: ${tokens.length}`);
if (isPay) {
await connectToDatabase();
@@ -34,7 +34,7 @@ export const pushChatBill = async ({
// 计算价格
const unitPrice = modelItem?.price || 5;
const price = unitPrice * tokens.length;
console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}`);
console.log(`unit price: ${unitPrice}, price: ${formatPrice(price)}`);
try {
// 插入 Bill 记录
@@ -82,8 +82,9 @@ export const pushSplitDataBill = async ({
// 计算 token 数量
const tokens = encode(text);
console.log('text len: ', text.length);
console.log('token len:', tokens.length);
console.log(
`splitData generate success. text len: ${text.length}. token len: ${tokens.length}`
);
if (isPay) {
try {
@@ -93,7 +94,7 @@ export const pushSplitDataBill = async ({
// 计算价格
const price = unitPrice * tokens.length;
console.log(`splitData bill, price: ${formatPrice(price)}`);
console.log(`price: ${formatPrice(price)}`);
// 插入 Bill 记录
const res = await Bill.create({
@@ -119,3 +120,55 @@ export const pushSplitDataBill = async ({
console.log(error);
}
};
/* Record a billing event for one embedding (vector) generation.
 * When isPay is true, creates a Bill record and deducts the price from the
 * user's balance; on failure the partially-created bill is rolled back. */
export const pushGenerateVectorBill = async ({
  isPay,
  userId,
  text
}: {
  isPay: boolean; // true when the platform key was used and the user must be charged
  userId: string;
  text: string;
}) => {
  await connectToDatabase();
  let billId;

  try {
    // token count of the embedded text (used for pricing and logging)
    const tokens = encode(text);

    console.log(`vector generate success. text len: ${text.length}. token len: ${tokens.length}`);

    if (isPay) {
      try {
        const unitPrice = 1;
        // price = unit price per token * token count
        const price = unitPrice * tokens.length;
        console.log(`price: ${formatPrice(price)}`);

        // insert the Bill record
        const res = await Bill.create({
          userId,
          type: BillTypeEnum.vector,
          modelName: ChatModelNameEnum.VECTOR,
          textLen: text.length,
          tokenLen: tokens.length,
          price
        });
        billId = res._id;

        // deduct the price from the user's account balance
        await User.findByIdAndUpdate(userId, {
          $inc: { balance: -price }
        });
      } catch (error) {
        console.log('创建账单失败:', error);
        // roll back the bill record. Must be awaited: a Mongoose query that is
        // never awaited/exec'd is never sent to the database at all.
        billId && (await Bill.findByIdAndDelete(billId));
      }
    }
  } catch (error) {
    console.log(error);
  }
};

View File

@@ -16,7 +16,7 @@ const BillSchema = new Schema({
},
modelName: {
type: String,
enum: modelList.map((item) => item.model),
enum: [...modelList.map((item) => item.model), 'text-embedding-ada-002'],
required: true
},
chatId: {

View File

@@ -1,37 +0,0 @@
/* Knowledge-base entries attached to a model (mongo schema; vectors themselves live in redis) */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelDataSchema as ModelDataType } from '@/types/mongoSchema';
const ModelDataSchema = new Schema({
  // owning model
  modelId: {
    type: Schema.Types.ObjectId,
    ref: 'model',
    required: true
  },
  // owning user
  userId: {
    type: Schema.Types.ObjectId,
    ref: 'user',
    required: true
  },
  // raw answer/content text for this entry
  text: {
    type: String,
    required: true
  },
  // generated questions for this text
  q: {
    type: [
      {
        id: String, // key of the matching vector record in redis
        text: String
      }
    ],
    default: []
  },
  status: {
    type: Number,
    enum: [0, 1], // 1 = training (vector pending), 0 = done — presumably; see generateVector which sets 0 on success
    default: 1
  }
});
export const ModelData: MongoModel<ModelDataType> =
  models['modelData'] || model('modelData', ModelDataSchema);

View File

@@ -8,6 +8,11 @@ const SplitDataSchema = new Schema({
ref: 'user',
required: true
},
prompt: {
// 拆分时的提示词
type: String,
required: true
},
modelId: {
type: Schema.Types.ObjectId,
ref: 'model',

View File

@@ -17,7 +17,7 @@ export async function connectToDatabase(): Promise<void> {
mongoose.set('strictQuery', true);
global.mongodb = await mongoose.connect(process.env.MONGODB_URI as string, {
bufferCommands: true,
dbName: 'doc_gpt',
dbName: process.env.NODE_ENV === 'development' ? 'doc_gpt_test' : 'doc_gpt',
maxPoolSize: 5,
minPoolSize: 1,
maxConnecting: 5
@@ -35,7 +35,6 @@ export async function connectToDatabase(): Promise<void> {
export * from './models/authCode';
export * from './models/chat';
export * from './models/model';
export * from './models/modelData';
export * from './models/user';
export * from './models/training';
export * from './models/bill';

View File

@@ -29,8 +29,8 @@ export const connectRedis = async () => {
await global.redisClient.connect();
// 0 - 测试库,1 - 正式
await global.redisClient.select(0);
// 1 - 测试库,0 - 正式
await global.redisClient.SELECT(0);
return global.redisClient;
} catch (error) {

View File

@@ -1,7 +1,8 @@
import { Configuration, OpenAIApi } from 'openai';
import { Chat } from '../mongo';
import type { ChatPopulate } from '@/types/mongoSchema';
import { authToken, getOpenApiKey } from './tools';
import { authToken } from './tools';
import { getOpenApiKey } from './openai';
export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({

132
src/service/utils/openai.ts Normal file
View File

@@ -0,0 +1,132 @@
import axios from 'axios';
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent } from './tools';
import { User } from '../models/user';
import { formatPrice } from '@/utils/user';
import { ChatModelNameEnum } from '@/constants/model';
import { pushGenerateVectorBill } from '../events/pushBill';
/* Check whether an OpenAI apiKey still has usable credit (more than $0.2 remaining). */
export const checkKeyGrant = async (apiKey: string) => {
  const res = await axios.get('https://api.openai.com/dashboard/billing/credit_grants', {
    headers: { Authorization: `Bearer ${apiKey}` },
    httpsAgent
  });
  // NOTE(review): if total_available is missing, the comparison is false and the
  // key is reported as still having credit — matches the original behavior.
  const exhausted = res.data?.total_available <= 0.2;
  return !exhausted;
};
/* Fetch the user's own OpenAI key and build an API client from it.
 * Rejects when the user has no key or the key's credit is exhausted. */
export const getUserApiOpenai = async (userId: string) => {
  const user = await User.findById(userId);
  // the user's personal OpenAI key, stored in the accounts list
  const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value;
  if (!userApiKey) {
    return Promise.reject('缺少ApiKey, 无法请求');
  }
  // credit check: reject when the key has no usable balance left
  const hasGrant = await checkKeyGrant(userApiKey);
  if (!hasGrant) {
    return Promise.reject({
      code: 501,
      message: 'API 余额不足'
    });
  }
  return {
    user,
    openai: getOpenAIApi(userApiKey),
    apiKey: userApiKey
  };
};
/* Resolve an OpenAI key for a user: prefer the user's own key; otherwise fall
 * back to the platform key (callers must bill platform-key usage separately).
 * Returns { user, userApiKey, systemKey } — exactly one of the two keys is set. */
export const getOpenApiKey = async (userId: string, checkGrant = false) => {
  const user = await User.findById(userId);
  if (!user) {
    return Promise.reject('找不到用户');
  }
  const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value;
  // the user has their own key
  if (userApiKey) {
    // optional credit check on the user's own key
    if (checkGrant) {
      const hasGrant = await checkKeyGrant(userApiKey);
      if (!hasGrant) {
        return Promise.reject({
          code: 501,
          message: 'API 余额不足'
        });
      }
    }
    return {
      user,
      userApiKey,
      systemKey: ''
    };
  }
  // platform-account balance check before handing out the system key
  if (formatPrice(user.balance) <= 0) {
    return Promise.reject({
      code: 501,
      message: '账号余额不足'
    });
  }
  return {
    user,
    userApiKey: '',
    systemKey: process.env.OPENAIKEY as string
  };
};
/* Create an embedding vector for `text` with the given apiKey and record the
 * billing event. Returns the vector plus the OpenAI client used. */
export const openaiCreateEmbedding = async ({
  isPay,
  userId,
  apiKey,
  text
}: {
  isPay: boolean; // true when the platform key is used and the user must be charged
  userId: string;
  apiKey: string;
  text: string;
}) => {
  // build an OpenAI client for the chosen key
  const chatAPI = getOpenAIApi(apiKey);

  // turn the input text into a vector; fall back to [] on a malformed response
  const vector = await chatAPI
    .createEmbedding(
      {
        model: ChatModelNameEnum.VECTOR,
        input: text
      },
      {
        timeout: 60000,
        httpsAgent
      }
    )
    .then((res) => res?.data?.data?.[0]?.embedding || []);

  // billing is intentionally fire-and-forget, but attach a catch so a billing
  // failure (e.g. the DB connect inside pushGenerateVectorBill rejecting before
  // its own try/catch) cannot surface as an unhandled promise rejection
  void pushGenerateVectorBill({
    isPay,
    userId,
    text
  }).catch((err) => console.log('pushGenerateVectorBill error:', err));

  return {
    vector,
    chatAPI
  };
};

View File

@@ -1,12 +1,8 @@
import crypto from 'crypto';
import jwt from 'jsonwebtoken';
import { User } from '../models/user';
import tunnel from 'tunnel';
import { formatPrice } from '@/utils/user';
import { ChatItemType } from '@/types/chat';
import { encode } from 'gpt-token-utils';
import { getOpenAIApi } from '@/service/utils/chat';
import axios from 'axios';
/* 密码加密 */
export const hashPassword = (psw: string) => {
@@ -56,90 +52,6 @@ export const httpsAgent =
})
: undefined;
/* 判断 apikey 是否还有余额 */
export const checkKeyGrant = async (apiKey: string) => {
const grant = await axios.get('https://api.openai.com/dashboard/billing/credit_grants', {
headers: {
Authorization: `Bearer ${apiKey}`
},
httpsAgent
});
if (grant.data?.total_available <= 0.2) {
return false;
}
return true;
};
/* 获取用户 api 的 openai 信息 */
export const getUserApiOpenai = async (userId: string) => {
const user = await User.findById(userId);
const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value;
if (!userApiKey) {
return Promise.reject('缺少ApiKey, 无法请求');
}
// 余额校验
const hasGrant = await checkKeyGrant(userApiKey);
if (!hasGrant) {
return Promise.reject({
code: 501,
message: 'API 余额不足'
});
}
return {
user,
openai: getOpenAIApi(userApiKey),
apiKey: userApiKey
};
};
/* 获取 open api key如果用户没有自己的key就用平台的用平台记得加账单 */
export const getOpenApiKey = async (userId: string, checkGrant = false) => {
const user = await User.findById(userId);
if (!user) {
return Promise.reject('找不到用户');
}
const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value;
// 有自己的key
if (userApiKey) {
// api 余额校验
if (checkGrant) {
const hasGrant = await checkKeyGrant(userApiKey);
if (!hasGrant) {
return Promise.reject({
code: 501,
message: 'API 余额不足'
});
}
}
return {
user,
userApiKey,
systemKey: ''
};
}
// 平台账号余额校验
if (formatPrice(user.balance) <= 0) {
return Promise.reject({
code: 501,
message: '账号余额不足'
});
}
return {
user,
userApiKey: '',
systemKey: process.env.OPENAIKEY as string
};
};
/* tokens 截断 */
export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) => {
let res: ChatItemType[] = [];

View File

@@ -60,7 +60,7 @@ export interface ModelDataSchema {
q: {
id: string;
text: string;
}[];
};
status: ModelDataType;
}
@@ -69,6 +69,7 @@ export interface ModelSplitDataSchema {
userId: string;
modelId: string;
rawText: string;
prompt: string;
errorText: string;
textList: string[];
}

View File

@@ -1,6 +1,7 @@
import { ModelDataStatusEnum } from '@/constants/redis';
export interface RedisModelDataItemType {
id: string;
vector: number[];
dataId: string;
modelId: string;
q: string;
text: string;
status: `${ModelDataStatusEnum}`;
}

View File

@@ -123,7 +123,26 @@ export const readDocContent = (file: File) =>
});
export const vectorToBuffer = (vector: number[]) => {
let npVector = new Float32Array(vector);
const npVector = new Float32Array(vector);
return Buffer.from(npVector.buffer);
const buffer = Buffer.from(npVector.buffer);
return buffer;
};
/* Inverse of vectorToBuffer: decode a 'binary'(latin1)-encoded string back into
 * the original float32 vector.
 * Fixes two bugs: (1) the original decoded the template literal `bufferStr`
 * (the literal text) instead of the parameter; (2) passing a Buffer directly to
 * the Float32Array constructor takes the array-like path, producing one float
 * per byte — the view must be built over the underlying ArrayBuffer instead. */
export const BufferToVector = (bufferStr: string) => {
  // 'binary' (latin1) maps each char code 0-255 to one byte — lossless inverse
  // of buf.toString('binary')
  const buffer = Buffer.from(bufferStr, 'binary');
  const npVector = new Float32Array(
    buffer.buffer, // view the underlying ArrayBuffer (buffers may be pooled, hence byteOffset)
    buffer.byteOffset,
    buffer.byteLength / Float32Array.BYTES_PER_ELEMENT
  );
  return Array.from(npVector);
};
/* Normalize a vector to exactly 1536 dimensions (the embedding dim used by the
 * redis index): truncate longer vectors, zero-pad shorter ones.
 * Fixes an inverted condition: the original only "padded" when length > 1536,
 * where the slice is already full and zero elements were appended, so short
 * vectors were never brought up to 1536. */
export function formatVector(vector: number[]) {
  const formatted = vector.slice(0, 1536); // truncate to at most 1536 elements
  if (formatted.length < 1536) {
    // zero-pad short vectors up to the fixed dimension
    return formatted.concat(Array(1536 - formatted.length).fill(0));
  }
  return formatted;
}