feat: lafClaude

This commit is contained in:
archer
2023-05-04 10:53:55 +08:00
parent 3c8f38799c
commit 0d6897e180
22 changed files with 327 additions and 231 deletions

View File

@@ -45,12 +45,12 @@ export async function generateQA(next = false): Promise<any> {
const textList: string[] = dataItem.textList.slice(-5);
// 获取 openapi Key
let userApiKey = '',
systemApiKey = '';
let userOpenAiKey = '',
systemAuthKey = '';
try {
const key = await getApiKey({ model: OpenAiChatEnum.GPT35, userId: dataItem.userId });
userApiKey = key.userApiKey;
systemApiKey = key.systemApiKey;
userOpenAiKey = key.userOpenAiKey;
systemAuthKey = key.systemAuthKey;
} catch (error: any) {
if (error?.code === 501) {
// 余额不够了, 清空该记录
@@ -73,18 +73,18 @@ export async function generateQA(next = false): Promise<any> {
textList.map((text) =>
modelServiceToolMap[OpenAiChatEnum.GPT35]
.chatCompletion({
apiKey: userApiKey || systemApiKey,
apiKey: userOpenAiKey || systemAuthKey,
temperature: 0.8,
messages: [
{
obj: ChatRoleEnum.System,
value: `你是出题人
${dataItem.prompt || '下面是"一段长文本"'}
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
A1:
Q2:
A2:
...`
${dataItem.prompt || '下面是"一段长文本"'}
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
A1:
Q2:
A2:
...`
},
{
obj: 'Human',
@@ -98,7 +98,7 @@ export async function generateQA(next = false): Promise<any> {
console.log(`split result length: `, result.length);
// 计费
pushSplitDataBill({
isPay: !userApiKey && result.length > 0,
isPay: !userOpenAiKey && result.length > 0,
userId: dataItem.userId,
type: 'QA',
textLen: responseMessages.map((item) => item.value).join('').length,

View File

@@ -2,7 +2,6 @@ import { openaiCreateEmbedding } from '../utils/chat/openai';
import { getApiKey } from '../utils/auth';
import { openaiError2 } from '../errorCode';
import { PgClient } from '@/service/pg';
import { embeddingModel } from '@/constants/model';
export async function generateVector(next = false): Promise<any> {
if (process.env.queueTask !== '1') {
@@ -42,11 +41,10 @@ export async function generateVector(next = false): Promise<any> {
dataId = dataItem.id;
// 获取 openapi Key
let userApiKey, systemApiKey;
let userOpenAiKey;
try {
const res = await getApiKey({ model: embeddingModel, userId: dataItem.userId });
userApiKey = res.userApiKey;
systemApiKey = res.systemApiKey;
const res = await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
userOpenAiKey = res.userOpenAiKey;
} catch (error: any) {
if (error?.code === 501) {
await PgClient.delete('modelData', {
@@ -63,8 +61,7 @@ export async function generateVector(next = false): Promise<any> {
const { vectors } = await openaiCreateEmbedding({
textArr: [dataItem.q],
userId: dataItem.userId,
userApiKey,
systemApiKey
userOpenAiKey
});
// 更新 pg 向量和状态数据

View File

@@ -3,22 +3,20 @@ import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/
import { ModelSchema } from '@/types/mongoSchema';
import { openaiCreateEmbedding } from '../utils/chat/openai';
import { ChatRoleEnum } from '@/constants/chat';
import { sliceTextByToken } from '@/utils/chat';
import { modelToolMap } from '@/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
/**
* use openai embedding search kb
*/
export const searchKb = async ({
userApiKey,
systemApiKey,
userOpenAiKey,
prompts,
similarity = 0.2,
model,
userId
}: {
userApiKey?: string;
systemApiKey: string;
userOpenAiKey?: string;
prompts: ChatItemSimpleType[];
model: ModelSchema;
userId: string;
@@ -33,8 +31,7 @@ export const searchKb = async ({
async function search(textArr: string[] = []) {
// 获取提示词的向量
const { vectors: promptVectors } = await openaiCreateEmbedding({
userApiKey,
systemApiKey,
userOpenAiKey,
userId,
textArr
});
@@ -81,11 +78,24 @@ export const searchKb = async ({
].filter((item) => item);
const systemPrompts = await search(searchArr);
// filter system prompt
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
// filter system prompts.
const filterRateMap: Record<number, number[]> = {
1: [1],
2: [0.7, 0.3]
};
const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
const filterSystemPrompt = filterRate
.map((rate, i) =>
modelToolMap[model.chat.chatModel].sliceText({
text: systemPrompts[i],
length: Math.floor(modelConstantsData.systemMaxToken * rate)
})
)
.join('\n');
/* 高相似度+不回复 */
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
return {
code: 201,
searchPrompt: {
@@ -95,7 +105,7 @@ export const searchKb = async ({
};
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (systemPrompts.length === 0 && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
return {
code: 200,
searchPrompt: model.chat.systemPrompt
@@ -107,25 +117,7 @@ export const searchKb = async ({
};
}
/* 有匹配情况下system 添加知识库内容。 */
// filter system prompts. max 70% tokens
const filterRateMap: Record<number, number[]> = {
1: [0.7],
2: [0.5, 0.2]
};
const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
const filterSystemPrompt = filterRate
.map((rate, i) =>
sliceTextByToken({
model: model.chat.chatModel,
text: systemPrompts[i],
length: Math.floor(modelConstantsData.contextMaxToken * rate)
})
)
.join('\n');
/* 有匹配 */
return {
code: 200,
searchPrompt: {
@@ -133,9 +125,9 @@ export const searchKb = async ({
value: `
${model.chat.systemPrompt}
${
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : ''
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? '不回答知识库外的内容.' : ''
}
知识库内容为: ${filterSystemPrompt}'
知识库内容为: '${filterSystemPrompt}'
`
}
};

View File

@@ -4,15 +4,10 @@ import { Chat, Model, OpenApi, User } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema';
import type { ChatItemSimpleType } from '@/types/chat';
import mongoose from 'mongoose';
import { defaultModel } from '@/constants/model';
import { ClaudeEnum, defaultModel } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import {
ChatModelType,
OpenAiChatEnum,
embeddingModel,
EmbeddingModelType
} from '@/constants/model';
import { ChatModelType, OpenAiChatEnum } from '@/constants/model';
/* 校验 token */
export const authToken = (token?: string): Promise<string> => {
@@ -34,13 +29,7 @@ export const authToken = (token?: string): Promise<string> => {
};
/* 获取 api 请求的 key */
export const getApiKey = async ({
model,
userId
}: {
model: ChatModelType | EmbeddingModelType;
userId: string;
}) => {
export const getApiKey = async ({ model, userId }: { model: ChatModelType; userId: string }) => {
const user = await User.findById(userId);
if (!user) {
return Promise.reject({
@@ -51,29 +40,29 @@ export const getApiKey = async ({
const keyMap = {
[OpenAiChatEnum.GPT35]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
userOpenAiKey: user.openaiKey || '',
systemAuthKey: process.env.OPENAIKEY as string
},
[OpenAiChatEnum.GPT4]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
userOpenAiKey: user.openaiKey || '',
systemAuthKey: process.env.OPENAIKEY as string
},
[OpenAiChatEnum.GPT432k]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
userOpenAiKey: user.openaiKey || '',
systemAuthKey: process.env.OPENAIKEY as string
},
[embeddingModel]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
[ClaudeEnum.Claude]: {
userOpenAiKey: '',
systemAuthKey: process.env.LAFKEY as string
}
};
// 有自己的key
if (keyMap[model].userApiKey) {
if (keyMap[model].userOpenAiKey) {
return {
user,
userApiKey: keyMap[model].userApiKey,
systemApiKey: ''
userOpenAiKey: keyMap[model].userOpenAiKey,
systemAuthKey: ''
};
}
@@ -87,8 +76,8 @@ export const getApiKey = async ({
return {
user,
userApiKey: '',
systemApiKey: keyMap[model].systemApiKey
userOpenAiKey: '',
systemAuthKey: keyMap[model].systemAuthKey
};
};
@@ -176,11 +165,11 @@ export const authChat = async ({
]);
}
// 获取 user 的 apiKey
const { userApiKey, systemApiKey } = await getApiKey({ model: model.chat.chatModel, userId });
const { userOpenAiKey, systemAuthKey } = await getApiKey({ model: model.chat.chatModel, userId });
return {
userApiKey,
systemApiKey,
userOpenAiKey,
systemAuthKey,
content,
userId,
model,

View File

@@ -0,0 +1,103 @@
import { modelToolMap } from '@/utils/chat';
import { ChatCompletionType, StreamResponseType } from './index';
import { ChatRoleEnum } from '@/constants/chat';
import axios from 'axios';
import mongoose from 'mongoose';
import { NEW_CHATID_HEADER } from '@/constants/chat';
import { ClaudeEnum } from '@/constants/model';
/* Model chat completion via the laf-hosted Claude proxy service. */
export const lafClaudChat = async ({
  apiKey,
  messages,
  stream,
  chatId,
  res
}: ChatCompletionType) => {
  // Reuse the caller-provided chatId, or mint a fresh conversation id.
  const conversationId = chatId || String(new mongoose.Types.ObjectId());
  // create a new chat: on the very first Human message of a brand-new chat,
  // surface the generated conversation id to the client via a response header.
  !chatId &&
    messages.filter((item) => item.obj === 'Human').length === 1 &&
    res?.setHeader(NEW_CHATID_HEADER, conversationId);

  // get system prompt: collapse all System messages into one text block.
  const systemPrompt = messages
    .filter((item) => item.obj === 'System')
    .map((item) => item.value)
    .join('\n');

  const systemPromptText = systemPrompt ? `这是我的知识:'${systemPrompt}'\n` : '';
  // Only the last message is forwarded; presumably the proxy keeps the prior
  // history keyed by conversationId — TODO confirm against the laf endpoint.
  const prompt = systemPromptText + messages[messages.length - 1].value;

  const lafResponse = await axios.post(
    'https://hnvacz.laf.run/claude-gpt',
    {
      prompt,
      stream,
      conversationId
    },
    {
      headers: {
        Authorization: apiKey
      },
      // Streamed requests use a shorter timeout; JSON responses get a longer budget.
      timeout: stream ? 40000 : 240000,
      responseType: stream ? 'stream' : 'json'
    }
  );

  let responseText = '';
  // NOTE(review): totalTokens is never computed here — it stays 0 on both the
  // stream and non-stream paths; confirm whether billing relies on this value.
  let totalTokens = 0;

  // In stream mode the text is consumed later by the stream-response handler,
  // so responseText stays empty here.
  if (!stream) {
    responseText = lafResponse.data?.text || '';
  }

  return {
    streamResponse: lafResponse,
    responseMessages: messages.concat({ obj: ChatRoleEnum.AI, value: responseText }),
    responseText,
    totalTokens
  };
};
/* laf Claude stream response: relay the proxied stream to the client and count tokens. */
export const lafClaudStreamResponse = async ({
  stream,
  chatResponse,
  prompts
}: StreamResponseType) => {
  try {
    let responseContent = '';

    try {
      const decoder = new TextDecoder();
      for await (const chunk of chatResponse.data as any) {
        if (stream.destroyed) {
          // The client aborted the stream; ignore the remaining chunks.
          break;
        }
        // { stream: true } keeps partial multi-byte UTF-8 sequences buffered
        // inside the decoder instead of emitting U+FFFD replacement characters
        // when a character is split across chunk boundaries (common with CJK text).
        const content = decoder.decode(chunk, { stream: true });
        responseContent += content;
        content && stream.push(content.replace(/\n/g, '<br/>'));
      }
      // Flush any bytes still buffered in the decoder.
      const tail = decoder.decode();
      if (tail) {
        responseContent += tail;
        !stream.destroyed && stream.push(tail.replace(/\n/g, '<br/>'));
      }
    } catch (error) {
      // Best-effort relay: a broken pipe should not prevent token accounting below.
      console.log('pipe error', error);
    }

    // count tokens: append the AI reply and bill for the full conversation.
    const finishMessages = prompts.concat({
      obj: ChatRoleEnum.AI,
      value: responseContent
    });
    const totalTokens = modelToolMap[ClaudeEnum.Claude].countTokens({
      messages: finishMessages
    });

    return {
      responseContent,
      totalTokens,
      finishMessages
    };
  } catch (error) {
    return Promise.reject(error);
  }
};

View File

@@ -1,19 +1,19 @@
import { ChatItemSimpleType } from '@/types/chat';
import { modelToolMap } from '@/utils/chat';
import type { ChatModelType } from '@/constants/model';
import { ChatRoleEnum, SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
import { OpenAiChatEnum } from '@/constants/model';
import { ChatRoleEnum, SYSTEM_PROMPT_HEADER } from '@/constants/chat';
import { OpenAiChatEnum, ClaudeEnum } from '@/constants/model';
import { chatResponse, openAiStreamResponse } from './openai';
import { lafClaudChat, lafClaudStreamResponse } from './claude';
import type { NextApiResponse } from 'next';
import type { PassThrough } from 'stream';
import delay from 'delay';
export type ChatCompletionType = {
apiKey: string;
temperature: number;
messages: ChatItemSimpleType[];
stream: boolean;
params?: any;
[key: string]: any;
};
export type ChatCompletionResponseType = {
streamResponse: any;
@@ -25,6 +25,9 @@ export type StreamResponseType = {
stream: PassThrough;
chatResponse: any;
prompts: ChatItemSimpleType[];
res: NextApiResponse;
systemPrompt?: string;
[key: string]: any;
};
export type StreamResponseReturnType = {
responseContent: string;
@@ -65,6 +68,10 @@ export const modelServiceToolMap: Record<
model: OpenAiChatEnum.GPT432k,
...data
})
},
[ClaudeEnum.Claude]: {
chatCompletion: lafClaudChat,
streamResponse: lafClaudStreamResponse
}
};
@@ -143,14 +150,13 @@ export const resStreamResponse = async ({
prompts
}: StreamResponseType & {
model: ChatModelType;
res: NextApiResponse;
systemPrompt?: string;
}) => {
// 创建响应流
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
systemPrompt && res.setHeader(SYSTEM_PROMPT_HEADER, encodeURIComponent(systemPrompt));
stream.pipe(res);
const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[
@@ -158,16 +164,11 @@ export const resStreamResponse = async ({
].streamResponse({
chatResponse,
stream,
prompts
prompts,
res,
systemPrompt
});
await delay(100);
// push system prompt
!stream.destroyed &&
systemPrompt &&
stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
// close stream
!stream.destroyed && stream.push(null);
stream.destroy();

View File

@@ -19,18 +19,18 @@ export const getOpenAIApi = (apiKey: string) => {
/* 获取向量 */
export const openaiCreateEmbedding = async ({
userApiKey,
systemApiKey,
userOpenAiKey,
userId,
textArr
}: {
userApiKey?: string;
systemApiKey: string;
userOpenAiKey?: string;
userId: string;
textArr: string[];
}) => {
const systemAuthKey = process.env.OPENAIKEY as string;
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemApiKey);
const chatAPI = getOpenAIApi(userOpenAiKey || systemAuthKey);
// 把输入的内容转成向量
const res = await chatAPI
@@ -50,7 +50,7 @@ export const openaiCreateEmbedding = async ({
}));
pushGenerateVectorBill({
isPay: !userApiKey,
isPay: !userOpenAiKey,
userId,
text: textArr.join(''),
tokenLen: res.tokenLen