diff --git a/client/src/constants/plugin.ts b/client/src/constants/plugin.ts index 187a7b7a2..3ab989044 100644 --- a/client/src/constants/plugin.ts +++ b/client/src/constants/plugin.ts @@ -6,3 +6,5 @@ export const TrainingTypeMap = { [TrainingModeEnum.qa]: 'qa', [TrainingModeEnum.index]: 'index' }; + +export const PgTrainingTableName = 'modeldata'; diff --git a/client/src/pages/api/openapi/kb/delDataById.ts b/client/src/pages/api/openapi/kb/delDataById.ts index 7cad054ab..a0e7ac9ef 100644 --- a/client/src/pages/api/openapi/kb/delDataById.ts +++ b/client/src/pages/api/openapi/kb/delDataById.ts @@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response'; import { authUser } from '@/service/utils/auth'; import { PgClient } from '@/service/pg'; import { withNextCors } from '@/service/utils/tools'; +import { PgTrainingTableName } from '@/constants/plugin'; export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -17,7 +18,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex // 凭证校验 const { userId } = await authUser({ req }); - await PgClient.delete('modelData', { + await PgClient.delete(PgTrainingTableName, { where: [['user_id', userId], 'AND', ['id', dataId]] }); diff --git a/client/src/pages/api/openapi/kb/pushData.ts b/client/src/pages/api/openapi/kb/pushData.ts index 53336f9c4..a668637d0 100644 --- a/client/src/pages/api/openapi/kb/pushData.ts +++ b/client/src/pages/api/openapi/kb/pushData.ts @@ -4,7 +4,7 @@ import { connectToDatabase, TrainingData } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { authKb } from '@/service/utils/auth'; import { withNextCors } from '@/service/utils/tools'; -import { TrainingModeEnum } from '@/constants/plugin'; +import { PgTrainingTableName, TrainingModeEnum } from '@/constants/plugin'; import { startQueue } from '@/service/utils/tools'; import { PgClient } from '@/service/pg'; import { modelToolMap } from '@/utils/plugin'; @@ -129,7 +129,7 @@ export async function pushDataToKb({ try { const { rows } = await PgClient.query(` SELECT COUNT(*) > 0 AS exists - FROM modelData + FROM ${PgTrainingTableName} WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}' `); const exists = rows[0]?.exists || false; diff --git a/client/src/pages/api/openapi/kb/searchTest.ts b/client/src/pages/api/openapi/kb/searchTest.ts index fb4bcc9b8..fc43960d7 100644 --- a/client/src/pages/api/openapi/kb/searchTest.ts +++ b/client/src/pages/api/openapi/kb/searchTest.ts @@ -5,6 +5,7 @@ import { PgClient } from '@/service/pg'; import { withNextCors } from '@/service/utils/tools'; import { getVector } from '../plugin/vector'; import type { KbTestItemType } from '@/types/plugin'; +import { PgTrainingTableName } from '@/constants/plugin'; export type Props = { model: string; @@ -39,7 +40,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex SET LOCAL ivfflat.probes = ${global.systemEnv.pgIvfflatProbe || 10}; select id,q,a,source,(vector <#> '[${ vectors[0] - }]') * -1 AS score from modelData where kb_id='${kbId}' AND user_id='${userId}' order by vector <#> '[${ + }]') * -1 AS score from ${PgTrainingTableName} where kb_id='${kbId}' AND user_id='${userId}' order by vector <#> '[${ vectors[0] }]' limit 12; COMMIT;` diff --git a/client/src/pages/api/openapi/kb/updateData.ts b/client/src/pages/api/openapi/kb/updateData.ts index acea5308f..1c6043330 100644 --- a/client/src/pages/api/openapi/kb/updateData.ts +++ b/client/src/pages/api/openapi/kb/updateData.ts @@ -5,6 +5,7 @@ import { PgClient } from '@/service/pg'; import { withNextCors } from '@/service/utils/tools'; import { KB, connectToDatabase } from '@/service/mongo'; import { getVector } from '../plugin/vector'; +import { PgTrainingTableName } from '@/constants/plugin'; export type Props = { dataId: string; @@ -46,7 +47,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex })(); // 更新 pg 内容.仅修改a,不需要更新向量。 - await PgClient.update('modelData', { + await PgClient.update(PgTrainingTableName, { where: [['id', dataId], 'AND', ['user_id', userId]], values: [ { key: 'source', value: '手动修改' }, diff --git a/client/src/pages/api/plugins/kb/data/exportModelData.ts b/client/src/pages/api/plugins/kb/data/exportModelData.ts index 3bf6cd5b9..9b3bc90a4 100644 --- a/client/src/pages/api/plugins/kb/data/exportModelData.ts +++ b/client/src/pages/api/plugins/kb/data/exportModelData.ts @@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response'; import { connectToDatabase, User } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { PgClient } from '@/service/pg'; +import { PgTrainingTableName } from '@/constants/plugin'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -38,16 +39,19 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< } // 统计数据 - const count = await PgClient.count('modelData', { + const count = await PgClient.count(PgTrainingTableName, { where: [['kb_id', kbId], 'AND', ['user_id', userId]] }); // 从 pg 中获取所有数据 - const pgData = await PgClient.select<{ q: string; a: string; source: string }>('modelData', { - where: [['kb_id', kbId], 'AND', ['user_id', userId]], - fields: ['q', 'a', 'source'], - order: [{ field: 'id', mode: 'DESC' }], - limit: count - }); + const pgData = await PgClient.select<{ q: string; a: string; source: string }>( + PgTrainingTableName, + { + where: [['kb_id', kbId], 'AND', ['user_id', userId]], + fields: ['q', 'a', 'source'], + order: [{ field: 'id', mode: 'DESC' }], + limit: count + } + ); const data: [string, string, string][] = pgData.rows.map((item) => [ item.q.replace(/\n/g, '\\n'), diff --git a/client/src/pages/api/plugins/kb/data/getDataById.ts b/client/src/pages/api/plugins/kb/data/getDataById.ts index f7ee9db9e..5d6ea143d 100644 --- a/client/src/pages/api/plugins/kb/data/getDataById.ts +++ b/client/src/pages/api/plugins/kb/data/getDataById.ts @@ -4,6 +4,7 @@ import { connectToDatabase } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { PgClient } from '@/service/pg'; import type { KbDataItemType } from '@/types/plugin'; +import { PgTrainingTableName } from '@/constants/plugin'; export type Response = { id: string; @@ -28,7 +29,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< const where: any = [['user_id', userId], 'AND', ['id', dataId]]; - const searchRes = await PgClient.select('modelData', { + const searchRes = await PgClient.select(PgTrainingTableName, { fields: ['kb_id', 'id', 'q', 'a', 'source'], where, limit: 1 diff --git a/client/src/pages/api/plugins/kb/data/getDataList.ts b/client/src/pages/api/plugins/kb/data/getDataList.ts index acbe902a7..daf9b529d 100644 --- a/client/src/pages/api/plugins/kb/data/getDataList.ts +++ b/client/src/pages/api/plugins/kb/data/getDataList.ts @@ -4,6 +4,7 @@ import { connectToDatabase } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { PgClient } from '@/service/pg'; import type { KbDataItemType } from '@/types/plugin'; +import { PgTrainingTableName } from '@/constants/plugin'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -41,14 +42,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< ]; const [searchRes, total] = await Promise.all([ - PgClient.select('modelData', { + PgClient.select(PgTrainingTableName, { fields: ['id', 'q', 'a', 'source'], where, order: [{ field: 'id', mode: 'DESC' }], limit: pageSize, offset: pageSize * (pageNum - 1) }), - PgClient.count('modelData', { + PgClient.count(PgTrainingTableName, { fields: ['id'], where }) diff --git a/client/src/pages/api/plugins/kb/delete.ts b/client/src/pages/api/plugins/kb/delete.ts index a956cb0e0..defbef9d8 100644 --- a/client/src/pages/api/plugins/kb/delete.ts +++ b/client/src/pages/api/plugins/kb/delete.ts @@ -4,6 +4,7 @@ import { connectToDatabase, KB, App, TrainingData } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { PgClient } from '@/service/pg'; import { Types } from 'mongoose'; +import { PgTrainingTableName } from '@/constants/plugin'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -21,7 +22,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); // delete all pg data - await PgClient.delete('modelData', { + await PgClient.delete(PgTrainingTableName, { where: [['user_id', userId], 'AND', ['kb_id', id]] }); diff --git a/client/src/service/moduleDispatch/kb/search.ts b/client/src/service/moduleDispatch/kb/search.ts index 54e88f3be..8f63ba36c 100644 --- a/client/src/service/moduleDispatch/kb/search.ts +++ b/client/src/service/moduleDispatch/kb/search.ts @@ -5,6 +5,7 @@ import { getVector } from '@/pages/api/openapi/plugin/vector'; import { countModelPrice } from '@/service/events/pushBill'; import type { SelectedKbType } from '@/types/plugin'; import type { QuoteItemType } from '@/types/chat'; +import { PgTrainingTableName } from '@/constants/plugin'; type KBSearchProps = { kbList: SelectedKbType; @@ -48,7 +49,7 @@ export async function dispatchKBSearch(props: Record): Promise `'${item.kbId}'`) .join(',')}) AND vector <#> '[${vectors[0]}]' < -${similarity} order by vector <#> '[${ vectors[0] diff --git a/client/src/service/mongo.ts b/client/src/service/mongo.ts index 37995def3..e2e9d75e5 100644 --- a/client/src/service/mongo.ts +++ b/client/src/service/mongo.ts @@ -6,6 +6,7 @@ import { User } from './models/user'; import { PRICE_SCALE } from '@/constants/common'; import { connectPg, PgClient } from './pg'; import { createHashPassword } from '@/utils/tools'; +import { PgTrainingTableName } from '@/constants/plugin'; /** * connect MongoDB and init data @@ -92,7 +93,7 @@ async function initPg() { await connectPg(); await PgClient.query(` CREATE EXTENSION IF NOT EXISTS vector; - CREATE TABLE IF NOT EXISTS modelData ( + CREATE TABLE IF NOT EXISTS ${PgTrainingTableName} ( id BIGSERIAL PRIMARY KEY, vector VECTOR(1536) NOT NULL, user_id VARCHAR(50) NOT NULL, @@ -101,9 +102,9 @@ async function initPg() { q TEXT NOT NULL, a TEXT NOT NULL ); - CREATE INDEX IF NOT EXISTS modelData_userId_index ON modelData USING HASH (user_id); - CREATE INDEX IF NOT EXISTS modelData_kbId_index ON modelData USING HASH (kb_id); - CREATE INDEX IF NOT EXISTS idx_model_data_md5_q_a_user_id_kb_id ON modelData (md5(q), md5(a), user_id, kb_id); + CREATE INDEX IF NOT EXISTS modelData_userId_index ON ${PgTrainingTableName} USING HASH (user_id); + CREATE INDEX IF NOT EXISTS modelData_kbId_index ON ${PgTrainingTableName} USING HASH (kb_id); + CREATE INDEX IF NOT EXISTS idx_model_data_md5_q_a_user_id_kb_id ON ${PgTrainingTableName} (md5(q), md5(a), user_id, kb_id); `); console.log('init pg successful'); } catch (error) { diff --git a/client/src/service/pg.ts b/client/src/service/pg.ts index 4123506fa..65452976a 100644 --- a/client/src/service/pg.ts +++ b/client/src/service/pg.ts @@ -1,5 +1,6 @@ import { Pool } from 'pg'; import type { QueryResultRow } from 'pg'; +import { PgTrainingTableName } from '@/constants/plugin'; export const connectPg = async () => { if (global.pgClient) { @@ -173,7 +174,7 @@ export const insertKbItem = ({ source?: string; }[]; }) => { - return PgClient.insert('modelData', { + return PgClient.insert(PgTrainingTableName, { values: data.map((item) => [ { key: 'user_id', value: userId }, { key: 'kb_id', value: kbId },