training queue

This commit is contained in:
archer
2023-05-26 23:08:25 +08:00
parent 69f32a0861
commit dc1c1d1355
32 changed files with 528 additions and 493 deletions

View File

@@ -53,10 +53,11 @@ function responseError(err: any) {
}
/* 创建请求实例 */
const instance = axios.create({
export const instance = axios.create({
timeout: 60000, // 超时时间
baseURL: `http://localhost:${process.env.PORT || 3000}/api`,
headers: {
'content-type': 'application/json'
rootkey: process.env.ROOT_KEY
}
});
@@ -75,7 +76,6 @@ function request(url: string, data: any, config: ConfigType, method: Method): an
return instance
.request({
baseURL: `http://localhost:${process.env.PORT || 3000}/api`,
url,
method,
data: method === 'GET' ? null : data,
@@ -93,18 +93,30 @@ function request(url: string, data: any, config: ConfigType, method: Method): an
* @param {Object} config
* @returns
*/
export function GET<T>(url: string, params = {}, config: ConfigType = {}): Promise<T> {
export function GET<T = { data: any }>(
url: string,
params = {},
config: ConfigType = {}
): Promise<T> {
return request(url, params, config, 'GET');
}
export function POST<T>(url: string, data = {}, config: ConfigType = {}): Promise<T> {
export function POST<T = { data: any }>(
url: string,
data = {},
config: ConfigType = {}
): Promise<T> {
return request(url, data, config, 'POST');
}
export function PUT<T>(url: string, data = {}, config: ConfigType = {}): Promise<T> {
export function PUT<T = { data: any }>(
url: string,
data = {},
config: ConfigType = {}
): Promise<T> {
return request(url, data, config, 'PUT');
}
export function DELETE<T>(url: string, config: ConfigType = {}): Promise<T> {
export function DELETE<T = { data: any }>(url: string, config: ConfigType = {}): Promise<T> {
return request(url, {}, config, 'DELETE');
}

View File

@@ -1,75 +1,55 @@
import { SplitData } from '@/service/mongo';
import { TrainingData } from '@/service/mongo';
import { getApiKey } from '../utils/auth';
import { OpenAiChatEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { openaiError2 } from '../errorCode';
import { insertKbItem } from '@/service/pg';
import { SplitDataSchema } from '@/types/mongoSchema';
import { modelServiceToolMap } from '../utils/chat';
import { ChatRoleEnum } from '@/constants/chat';
import { getErrText } from '@/utils/tools';
import { BillTypeEnum } from '@/constants/user';
import { pushDataToKb } from '@/pages/api/openapi/kb/pushData';
import { ERROR_ENUM } from '../errorCode';
export async function generateQA(next = false): Promise<any> {
if (process.env.queueTask !== '1') {
try {
fetch(process.env.parentUrl || '');
} catch (error) {
console.log('parentUrl fetch error', error);
}
return;
}
if (global.generatingQA === true && !next) return;
global.generatingQA = true;
let dataId = null;
// 每次最多选 1 组
const listLen = 1;
export async function generateQA(trainingId: string): Promise<any> {
try {
// 找出一个需要生成的 dataItem
const data = await SplitData.aggregate([
{ $match: { textList: { $exists: true, $ne: [] } } },
{ $sample: { size: 1 } }
]);
// 找出一个需要生成的 dataItem (4分钟锁)
const data = await TrainingData.findOneAndUpdate(
{
_id: trainingId,
lockTime: { $lte: Date.now() - 4 * 60 * 1000 }
},
{
lockTime: new Date()
}
);
const dataItem: SplitDataSchema = data[0];
if (!dataItem) {
console.log('没有需要生成 QA 的数据');
global.generatingQA = false;
return;
}
dataId = dataItem._id;
// 获取 5 个源文本
const textList: string[] = dataItem.textList.slice(-5);
// 获取 openapi Key
let userOpenAiKey = '',
systemAuthKey = '';
try {
const key = await getApiKey({ model: OpenAiChatEnum.GPT35, userId: dataItem.userId });
userOpenAiKey = key.userOpenAiKey;
systemAuthKey = key.systemAuthKey;
} catch (err: any) {
// 余额不够了, 清空该记录
await SplitData.findByIdAndUpdate(dataItem._id, {
textList: [],
errorText: getErrText(err, '获取 OpenAi Key 失败')
if (!data || data.qaList.length === 0) {
await TrainingData.findOneAndDelete({
_id: trainingId,
qaList: [],
vectorList: []
});
generateQA(true);
return;
}
console.log(`正在生成一组QA, 包含 ${textList.length} 组文本。ID: ${dataItem._id}`);
const qaList: string[] = data.qaList.slice(-listLen);
// 余额校验并获取 openapi Key
const { userOpenAiKey, systemAuthKey } = await getApiKey({
model: OpenAiChatEnum.GPT35,
userId: data.userId,
type: 'training'
});
console.log(`正在生成一组QA, 包含 ${qaList.length} 组文本。ID: ${data._id}`);
const startTime = Date.now();
// 请求 chatgpt 获取回答
const response = await Promise.allSettled(
textList.map((text) =>
const response = await Promise.all(
qaList.map((text) =>
modelServiceToolMap[OpenAiChatEnum.GPT35]
.chatCompletion({
apiKey: userOpenAiKey || systemAuthKey,
@@ -78,7 +58,7 @@ export async function generateQA(next = false): Promise<any> {
{
obj: ChatRoleEnum.System,
value: `你是出题人
${dataItem.prompt || '下面是"一段长文本"'}
${data.prompt || '下面是"一段长文本"'}
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
A1:
Q2:
@@ -98,7 +78,7 @@ A2:
// 计费
pushSplitDataBill({
isPay: !userOpenAiKey && result.length > 0,
userId: dataItem.userId,
userId: data.userId,
type: BillTypeEnum.QA,
textLen: responseMessages.map((item) => item.value).join('').length,
totalTokens
@@ -116,57 +96,59 @@ A2:
)
);
// 获取成功的回答
const successResponse: {
rawContent: string;
result: {
q: string;
a: string;
}[];
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
const responseList = response.map((item) => item.result).flat();
const resultList = successResponse.map((item) => item.result).flat();
// 创建 向量生成 队列
pushDataToKb({
kbId: data.kbId,
data: responseList,
userId: data.userId
});
await Promise.allSettled([
// 删掉后5个数据
SplitData.findByIdAndUpdate(dataItem._id, {
textList: dataItem.textList.slice(0, -5)
}),
// 生成的内容插入 pg
insertKbItem({
userId: dataItem.userId,
kbId: dataItem.kbId,
data: resultList
})
]);
console.log('生成QA成功time:', `${(Date.now() - startTime) / 1000}s`);
generateQA(true);
generateVector();
} catch (error: any) {
// log
if (error?.response) {
console.log('openai error: 生成QA错误');
console.log(error.response?.status, error.response?.statusText, error.response?.data);
// 删除 QA 队列。如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个
if (data.vectorList.length <= listLen) {
await TrainingData.findByIdAndDelete(data._id);
} else {
console.log('生成QA错误:', error);
await TrainingData.findByIdAndUpdate(data._id, {
qaList: data.qaList.slice(0, -listLen),
lockTime: new Date('2000/1/1')
});
}
// 没有余额或者凭证错误时,拒绝任务
if (dataId && openaiError2[error?.response?.data?.error?.type]) {
console.log(openaiError2[error?.response?.data?.error?.type], '删除QA任务');
console.log('生成QA成功time:', `${(Date.now() - startTime) / 1000}s`);
await SplitData.findByIdAndUpdate(dataId, {
textList: [],
errorText: 'api 余额不足'
});
generateQA(trainingId);
} catch (err: any) {
// log
if (err?.response) {
console.log('openai error: 生成QA错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
} else {
console.log('生成QA错误:', err);
}
generateQA(true);
// openai 账号异常或者账号余额不足,删除任务
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
console.log('余额不足,删除向量生成任务');
await TrainingData.findByIdAndDelete(trainingId);
return;
}
// unlock
await TrainingData.findByIdAndUpdate(trainingId, {
lockTime: new Date('2000/1/1')
});
// 频率限制
if (err?.response?.statusText === 'Too Many Requests') {
console.log('生成向量次数限制30s后尝试');
return setTimeout(() => {
generateQA(trainingId);
}, 30000);
}
setTimeout(() => {
generateQA(true);
generateQA(trainingId);
}, 1000);
}
}

View File

@@ -1,107 +1,137 @@
import { getApiKey } from '../utils/auth';
import { openaiError2 } from '../errorCode';
import { PgClient } from '@/service/pg';
import { getErrText } from '@/utils/tools';
import { insertKbItem, PgClient } from '@/service/pg';
import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
import { TrainingData } from '../models/trainingData';
import { ERROR_ENUM } from '../errorCode';
export async function generateVector(next = false): Promise<any> {
if (process.env.queueTask !== '1') {
try {
fetch(process.env.parentUrl || '');
} catch (error) {
console.log('parentUrl fetch error', error);
}
return;
}
if (global.generatingVector && !next) return;
global.generatingVector = true;
let dataId = null;
// 每次最多选 5 组
const listLen = 5;
/* 索引生成队列。每导入一次,就是一个单独的线程 */
export async function generateVector(trainingId: string): Promise<any> {
try {
// 找出一个 status = waiting 的数据
const searchRes = await PgClient.select('modelData', {
fields: ['id', 'q', 'user_id'],
where: [['status', 'waiting']],
limit: 1
});
if (searchRes.rowCount === 0) {
console.log('没有需要生成 【向量】 的数据');
global.generatingVector = false;
return;
}
const dataItem: { id: string; q: string; userId: string } = {
id: searchRes.rows[0].id,
q: searchRes.rows[0].q,
userId: searchRes.rows[0].user_id
};
dataId = dataItem.id;
// 获取 openapi Key
try {
await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
} catch (err: any) {
await PgClient.delete('modelData', {
where: [['id', dataId]]
});
getErrText(err, '获取 OpenAi Key 失败');
return generateVector(true);
}
// 生成词向量
const vectors = await openaiEmbedding({
input: [dataItem.q],
userId: dataItem.userId
});
// 更新 pg 向量和状态数据
await PgClient.update('modelData', {
values: [
{ key: 'vector', value: `[${vectors[0]}]` },
{ key: 'status', value: `ready` }
],
where: [['id', dataId]]
});
console.log(`生成向量成功: ${dataItem.id}`);
generateVector(true);
} catch (error: any) {
// log
if (error?.response) {
console.log('openai error: 生成向量错误');
console.log(error.response?.status, error.response?.statusText, error.response?.data);
} else {
console.log('生成向量错误:', error);
}
// 没有余额或者凭证错误时,拒绝任务
if (dataId && openaiError2[error?.response?.data?.error?.type]) {
console.log('删除向量生成任务记录');
try {
await PgClient.delete('modelData', {
where: [['id', dataId]]
});
} catch (error) {
error;
// 找出一个需要生成的 dataItem (2分钟锁)
const data = await TrainingData.findOneAndUpdate(
{
_id: trainingId,
lockTime: { $lte: Date.now() - 2 * 60 * 1000 }
},
{
lockTime: new Date()
}
generateVector(true);
);
if (!data) {
await TrainingData.findOneAndDelete({
_id: trainingId,
qaList: [],
vectorList: []
});
return;
}
if (error?.response?.statusText === 'Too Many Requests') {
console.log('生成向量次数限制1分钟后尝试');
// 限制次数1分钟后再试
setTimeout(() => {
generateVector(true);
}, 60000);
const userId = String(data.userId);
const kbId = String(data.kbId);
const dataItems: { q: string; a: string }[] = data.vectorList.slice(-listLen).map((item) => ({
q: item.q,
a: item.a
}));
// 过滤重复的 qa 内容
const searchRes = await Promise.allSettled(
dataItems.map(async ({ q, a = '' }) => {
if (!q) {
return Promise.reject('q为空');
}
q = q.replace(/\\n/g, '\n');
a = a.replace(/\\n/g, '\n');
// Exactly the same data, not push
try {
const count = await PgClient.count('modelData', {
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
});
if (count > 0) {
return Promise.reject('已经存在');
}
} catch (error) {
error;
}
return Promise.resolve({
q,
a
});
})
);
const filterData = searchRes
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
if (filterData.length > 0) {
// 生成词向量
const vectors = await openaiEmbedding({
input: filterData.map((item) => item.q),
userId,
type: 'training'
});
// 生成结果插入到 pg
await insertKbItem({
userId,
kbId,
data: vectors.map((vector, i) => ({
q: filterData[i].q,
a: filterData[i].a,
vector
}))
});
}
// 删除 mongo 训练队列. 如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个
if (data.vectorList.length <= listLen) {
await TrainingData.findByIdAndDelete(trainingId);
console.log(`全部向量生成完毕: ${trainingId}`);
} else {
await TrainingData.findByIdAndUpdate(trainingId, {
vectorList: data.vectorList.slice(0, -listLen),
lockTime: new Date('2000/1/1')
});
console.log(`生成向量成功: ${trainingId}`);
generateVector(trainingId);
}
} catch (err: any) {
// log
if (err?.response) {
console.log('openai error: 生成向量错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
} else {
console.log('生成向量错误:', err);
}
// openai 账号异常或者账号余额不足,删除任务
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
console.log('余额不足,删除向量生成任务');
await TrainingData.findByIdAndDelete(trainingId);
return;
}
// unlock
await TrainingData.findByIdAndUpdate(trainingId, {
lockTime: new Date('2000/1/1')
});
// 频率限制
if (err?.response?.statusText === 'Too Many Requests') {
console.log('生成向量次数限制30s后尝试');
return setTimeout(() => {
generateVector(trainingId);
}, 30000);
}
setTimeout(() => {
generateVector(true);
generateVector(trainingId);
}, 1000);
}
}

View File

@@ -2,7 +2,7 @@ import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSchema as ModelType } from '@/types/mongoSchema';
import {
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum,
appVectorSearchModeEnum,
ChatModelMap,
OpenAiChatEnum
} from '@/constants/model';
@@ -40,7 +40,7 @@ const ModelSchema = new Schema({
// knowledge base search mode
type: String,
enum: Object.keys(ModelVectorSearchModeMap),
default: ModelVectorSearchModeEnum.hightSimilarity
default: appVectorSearchModeEnum.hightSimilarity
},
systemPrompt: {
// 系统提示词

View File

@@ -1,32 +0,0 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { SplitDataSchema as SplitDataType } from '@/types/mongoSchema';
const SplitDataSchema = new Schema({
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
prompt: {
// 拆分时的提示词
type: String,
required: true
},
kbId: {
type: Schema.Types.ObjectId,
ref: 'kb',
required: true
},
textList: {
type: [String],
default: []
},
errorText: {
type: String,
default: ''
}
});
export const SplitData: MongoModel<SplitDataType> =
models['splitData'] || model('splitData', SplitDataSchema);

View File

@@ -0,0 +1,38 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { TrainingDataSchema as TrainingDateType } from '@/types/mongoSchema';
// pgList and vectorList, Only one of them will work
const TrainingDataSchema = new Schema({
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
kbId: {
type: Schema.Types.ObjectId,
ref: 'kb',
required: true
},
lockTime: {
type: Date,
default: () => new Date('2000/1/1')
},
vectorList: {
type: [{ q: String, a: String }],
default: []
},
prompt: {
// 拆分时的提示词
type: String,
default: ''
},
qaList: {
type: [String],
default: []
}
});
export const TrainingData: MongoModel<TrainingDateType> =
models['trainingData'] || model('trainingData', TrainingDataSchema);

View File

@@ -2,6 +2,7 @@ import mongoose from 'mongoose';
import { generateQA } from './events/generateQA';
import { generateVector } from './events/generateVector';
import tunnel from 'tunnel';
import { TrainingData } from './mongo';
/**
* 连接 MongoDB 数据库
@@ -27,9 +28,6 @@ export async function connectToDatabase(): Promise<void> {
global.mongodb = null;
}
generateQA();
generateVector();
// 创建代理对象
if (process.env.AXIOS_PROXY_HOST && process.env.AXIOS_PROXY_PORT) {
global.httpsAgent = tunnel.httpsOverHttp({
@@ -39,6 +37,34 @@ export async function connectToDatabase(): Promise<void> {
}
});
}
startTrain();
// 5 分钟后解锁不正常的数据,并触发开始训练
setTimeout(async () => {
await TrainingData.updateMany(
{
lockTime: { $lte: Date.now() - 5 * 60 * 1000 }
},
{
lockTime: new Date('2000/1/1')
}
);
startTrain();
}, 5 * 60 * 1000);
}
async function startTrain() {
const qa = await TrainingData.find({
qaList: { $exists: true, $ne: [] }
});
qa.map((item) => generateQA(String(item._id)));
const vector = await TrainingData.find({
vectorList: { $exists: true, $ne: [] }
});
vector.map((item) => generateVector(String(item._id)));
}
export * from './models/authCode';
@@ -47,7 +73,7 @@ export * from './models/model';
export * from './models/user';
export * from './models/bill';
export * from './models/pay';
export * from './models/splitData';
export * from './models/trainingData';
export * from './models/openapi';
export * from './models/promotionRecord';
export * from './models/collection';

View File

@@ -1,5 +1,6 @@
import { Pool } from 'pg';
import type { QueryResultRow } from 'pg';
import { ModelDataStatusEnum } from '@/constants/model';
export const connectPg = async () => {
if (global.pgClient) {
@@ -168,6 +169,7 @@ export const insertKbItem = ({
userId: string;
kbId: string;
data: {
vector: number[];
q: string;
a: string;
}[];
@@ -178,7 +180,8 @@ export const insertKbItem = ({
{ key: 'kb_id', value: kbId },
{ key: 'q', value: item.q },
{ key: 'a', value: item.a },
{ key: 'status', value: 'waiting' }
{ key: 'vector', value: `[${item.vector}]` },
{ key: 'status', value: ModelDataStatusEnum.ready }
])
});
};

View File

@@ -5,12 +5,14 @@ import { Chat, Model, OpenApi, User, ShareChat, KB } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema';
import type { ChatItemSimpleType } from '@/types/chat';
import mongoose from 'mongoose';
import { ClaudeEnum, defaultModel } from '@/constants/model';
import { ClaudeEnum, defaultModel, embeddingModel, EmbeddingModelType } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import { ChatModelType, OpenAiChatEnum } from '@/constants/model';
import { hashPassword } from '@/service/utils/tools';
export type ApiKeyType = 'training' | 'chat';
export const parseCookie = (cookie?: string): Promise<string> => {
return new Promise((resolve, reject) => {
// 获取 cookie
@@ -118,9 +120,15 @@ export const authUser = async ({
};
/* random get openai api key */
export const getSystemOpenAiKey = () => {
export const getSystemOpenAiKey = (type: ApiKeyType) => {
const keys = (() => {
if (type === 'training') {
return process.env.OPENAI_TRAINING_KEY?.split(',') || [];
}
return process.env.OPENAIKEY?.split(',') || [];
})();
// 纯字符串类型
const keys = process.env.OPENAIKEY?.split(',') || [];
const i = Math.floor(Math.random() * keys.length);
return keys[i] || (process.env.OPENAIKEY as string);
};
@@ -129,11 +137,13 @@ export const getSystemOpenAiKey = () => {
export const getApiKey = async ({
model,
userId,
mustPay = false
mustPay = false,
type = 'chat'
}: {
model: ChatModelType;
userId: string;
mustPay?: boolean;
type?: ApiKeyType;
}) => {
const user = await User.findById(userId);
if (!user) {
@@ -143,7 +153,7 @@ export const getApiKey = async ({
const keyMap = {
[OpenAiChatEnum.GPT35]: {
userOpenAiKey: user.openaiKey || '',
systemAuthKey: getSystemOpenAiKey() as string
systemAuthKey: getSystemOpenAiKey(type) as string
},
[OpenAiChatEnum.GPT4]: {
userOpenAiKey: user.openaiKey || '',