init
src/core/llm/anthropic.ts (new file, 323 lines)
@@ -0,0 +1,323 @@
import Anthropic from '@anthropic-ai/sdk'
import {
  ImageBlockParam,
  MessageParam,
  MessageStreamEvent,
  TextBlockParam,
} from '@anthropic-ai/sdk/resources/messages'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
  ResponseUsage,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

export class AnthropicProvider implements BaseLLMProvider {
  private client: Anthropic

  private static readonly DEFAULT_MAX_TOKENS = 8192

  constructor(apiKey: string) {
    this.client = new Anthropic({ apiKey, dangerouslyAllowBrowser: true })
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Anthropic API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Anthropic({
        baseURL: model.baseUrl,
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true,
      })
    }

    const systemMessage = AnthropicProvider.validateSystemMessages(
      request.messages,
    )

    try {
      const response = await this.client.messages.create(
        {
          model: request.model,
          messages: request.messages
            .filter((m) => m.role !== 'system')
            .filter((m) => !AnthropicProvider.isMessageEmpty(m))
            .map((m) => AnthropicProvider.parseRequestMessage(m)),
          system: systemMessage,
          max_tokens:
            request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
          temperature: request.temperature,
          top_p: request.top_p,
        },
        {
          signal: options?.signal,
        },
      )

      return AnthropicProvider.parseNonStreamingResponse(response)
    } catch (error) {
      if (error instanceof Anthropic.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Anthropic API key is invalid. Please update it in settings menu.',
        )
      }

      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Anthropic API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Anthropic({
        baseURL: model.baseUrl,
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true,
      })
    }

    const systemMessage = AnthropicProvider.validateSystemMessages(
      request.messages,
    )

    try {
      const stream = await this.client.messages.create(
        {
          model: request.model,
          messages: request.messages
            .filter((m) => m.role !== 'system')
            .filter((m) => !AnthropicProvider.isMessageEmpty(m))
            .map((m) => AnthropicProvider.parseRequestMessage(m)),
          system: systemMessage,
          max_tokens:
            request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
          temperature: request.temperature,
          top_p: request.top_p,
          stream: true,
        },
        {
          signal: options?.signal,
        },
      )

      // eslint-disable-next-line no-inner-declarations
      async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
        let messageId = ''
        let model = ''
        let usage: ResponseUsage = {
          prompt_tokens: 0,
          completion_tokens: 0,
          total_tokens: 0,
        }

        for await (const chunk of stream) {
          if (chunk.type === 'message_start') {
            messageId = chunk.message.id
            model = chunk.message.model
            usage = {
              prompt_tokens: chunk.message.usage.input_tokens,
              completion_tokens: chunk.message.usage.output_tokens,
              total_tokens:
                chunk.message.usage.input_tokens +
                chunk.message.usage.output_tokens,
            }
          } else if (chunk.type === 'content_block_delta') {
            yield AnthropicProvider.parseStreamingResponseChunk(
              chunk,
              messageId,
              model,
            )
          } else if (chunk.type === 'message_delta') {
            usage = {
              prompt_tokens: usage.prompt_tokens,
              completion_tokens:
                usage.completion_tokens + chunk.usage.output_tokens,
              total_tokens: usage.total_tokens + chunk.usage.output_tokens,
            }
          }
        }

        // After the stream is complete, yield the final usage
        yield {
          id: messageId,
          choices: [],
          object: 'chat.completion.chunk',
          model: model,
          usage: usage,
        }
      }

      return streamResponse()
    } catch (error) {
      if (error instanceof Anthropic.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Anthropic API key is invalid. Please update it in settings menu.',
        )
      }

      throw error
    }
  }

  static parseRequestMessage(message: RequestMessage): MessageParam {
    if (message.role !== 'user' && message.role !== 'assistant') {
      throw new Error(`Anthropic does not support role: ${message.role}`)
    }

    if (message.role === 'user' && Array.isArray(message.content)) {
      const content = message.content.map(
        (part): TextBlockParam | ImageBlockParam => {
          switch (part.type) {
            case 'text':
              return { type: 'text', text: part.text }
            case 'image_url': {
              const { mimeType, base64Data } = parseImageDataUrl(
                part.image_url.url,
              )
              AnthropicProvider.validateImageType(mimeType)
              return {
                type: 'image',
                source: {
                  data: base64Data,
                  media_type:
                    mimeType as ImageBlockParam['source']['media_type'],
                  type: 'base64',
                },
              }
            }
          }
        },
      )
      return { role: 'user', content }
    }

    return {
      role: message.role,
      content: message.content as string,
    }
  }

  static parseNonStreamingResponse(
    response: Anthropic.Message,
  ): LLMResponseNonStreaming {
    if (response.content[0].type === 'tool_use') {
      throw new Error('Unsupported content type: tool_use')
    }
    return {
      id: response.id,
      choices: [
        {
          finish_reason: response.stop_reason,
          message: {
            content: response.content[0].text,
            role: response.role,
          },
        },
      ],
      model: response.model,
      object: 'chat.completion',
      usage: {
        prompt_tokens: response.usage.input_tokens,
        completion_tokens: response.usage.output_tokens,
        total_tokens:
          response.usage.input_tokens + response.usage.output_tokens,
      },
    }
  }

  static parseStreamingResponseChunk(
    chunk: MessageStreamEvent,
    messageId: string,
    model: string,
  ): LLMResponseStreaming {
    if (chunk.type !== 'content_block_delta') {
      throw new Error('Unsupported chunk type')
    }
    if (chunk.delta.type === 'input_json_delta') {
      throw new Error('Unsupported content type: input_json_delta')
    }
    return {
      id: messageId,
      choices: [
        {
          finish_reason: null,
          delta: {
            content: chunk.delta.text,
          },
        },
      ],
      object: 'chat.completion.chunk',
      model: model,
    }
  }

  private static validateSystemMessages(
    messages: RequestMessage[],
  ): string | undefined {
    const systemMessages = messages.filter((m) => m.role === 'system')
    if (systemMessages.length > 1) {
      throw new Error(`Anthropic does not support more than one system message`)
    }
    const systemMessage =
      systemMessages.length > 0 ? systemMessages[0].content : undefined
    if (systemMessage && typeof systemMessage !== 'string') {
      throw new Error(
        `Anthropic only supports string content for system messages`,
      )
    }
    return systemMessage
  }

  private static isMessageEmpty(message: RequestMessage) {
    if (typeof message.content === 'string') {
      return message.content.trim() === ''
    }
    return message.content.length === 0
  }

  private static validateImageType(mimeType: string) {
    const SUPPORTED_IMAGE_TYPES = [
      'image/jpeg',
      'image/png',
      'image/gif',
      'image/webp',
    ]
    if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
      throw new Error(
        `Anthropic does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
          ', ',
        )}`,
      )
    }
  }
}
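A minimal sketch of how a caller might consume the streaming API above. This is not part of the commit; the exact CustomLLMModel and LLMRequestStreaming shapes are assumed from how the provider reads them, hence the casts, and the key is a placeholder. Note the trailing usage-only chunk yields choices: [], so the delta must be guarded:

// Hypothetical caller, not part of this commit.
import { AnthropicProvider } from './anthropic'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMRequestStreaming } from '../../types/llm/request'

async function printCompletion(): Promise<void> {
  const provider = new AnthropicProvider('sk-ant-...') // placeholder key
  const model = {
    provider: 'anthropic',
    name: 'claude-3-5-sonnet-latest',
  } as CustomLLMModel

  const stream = await provider.streamResponse(model, {
    model: model.name,
    messages: [{ role: 'user', content: 'Say hello.' }],
    stream: true,
  } as LLMRequestStreaming)

  let text = ''
  for await (const chunk of stream) {
    // Content deltas carry choices; the final chunk carries only usage
    // (choices: []), so guard before reading the delta.
    text += chunk.choices[0]?.delta?.content ?? ''
    if (chunk.usage) console.log(`total tokens: ${chunk.usage.total_tokens}`)
  }
  console.log(text)
}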
src/core/llm/base.ts (new file, 23 lines)
@@ -0,0 +1,23 @@
import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

export type BaseLLMProvider = {
  generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming>
  streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>>
}
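Because BaseLLMProvider is a structural type rather than an abstract class, any object with these two methods satisfies it. A hypothetical echo stub (e.g. for tests, not part of this commit; the response field names follow the OpenAI-style shapes used throughout, and the casts hedge the exact type definitions):

import { BaseLLMProvider } from './base'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

const echoProvider: BaseLLMProvider = {
  async generateResponse(model, request) {
    // Echo the last message back as a fake completion.
    const last = request.messages[request.messages.length - 1]
    return {
      id: 'echo',
      choices: [
        {
          finish_reason: 'stop',
          message: { content: String(last?.content ?? ''), role: 'assistant' },
        },
      ],
      model: request.model,
      object: 'chat.completion',
    } as LLMResponseNonStreaming
  },
  async streamResponse(model, request) {
    async function* chunks(): AsyncIterable<LLMResponseStreaming> {
      yield {
        id: 'echo',
        choices: [{ finish_reason: 'stop', delta: { content: 'echo' } }],
        model: request.model,
        object: 'chat.completion.chunk',
      } as LLMResponseStreaming
    }
    return chunks()
  },
}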
src/core/llm/exception.ts (new file, 34 lines)
@@ -0,0 +1,34 @@
export class LLMAPIKeyNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMAPIKeyNotSetException'
  }
}

export class LLMAPIKeyInvalidException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMAPIKeyInvalidException'
  }
}

export class LLMBaseUrlNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMBaseUrlNotSetException'
  }
}

export class LLMModelNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMModelNotSetException'
  }
}

export class LLMRateLimitExceededException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMRateLimitExceededException'
  }
}
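Distinct classes let callers branch on the failure mode; each also sets name explicitly, which remains stable even if a bundler renames the classes. A sketch of the intended catch pattern in a caller (hypothetical helper, not part of this commit):

import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
  LLMRateLimitExceededException,
} from './exception'

async function safeGenerate(run: () => Promise<string>): Promise<string> {
  try {
    return await run()
  } catch (error) {
    if (
      error instanceof LLMAPIKeyNotSetException ||
      error instanceof LLMAPIKeyInvalidException
    ) {
      return 'Check your API key in the settings menu.'
    }
    if (error instanceof LLMRateLimitExceededException) {
      return 'Rate limited; try again shortly.'
    }
    throw error
  }
}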
src/core/llm/gemini.ts (new file, 299 lines)
@@ -0,0 +1,299 @@
import {
  Content,
  EnhancedGenerateContentResponse,
  GenerateContentResult,
  GenerateContentStreamResult,
  GoogleGenerativeAI,
} from '@google/generative-ai'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

/**
 * Note on OpenAI Compatibility API:
 * Gemini provides an OpenAI-compatible endpoint (https://ai.google.dev/gemini-api/docs/openai)
 * which allows using the OpenAI SDK with Gemini models. However, there are currently CORS issues
 * preventing its use in Obsidian. Consider switching to this endpoint in the future once these
 * issues are resolved.
 */
export class GeminiProvider implements BaseLLMProvider {
  private client: GoogleGenerativeAI
  private apiKey: string

  constructor(apiKey: string) {
    this.apiKey = apiKey
    this.client = new GoogleGenerativeAI(apiKey)
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }

    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined

    try {
      const model = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })

      const result = await model.generateContent(
        {
          systemInstruction: systemInstruction,
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )

      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return GeminiProvider.parseNonStreamingResponse(
        result,
        request.model,
        messageId,
      )
    } catch (error) {
      const isInvalidApiKey =
        error.message?.includes('API_KEY_INVALID') ||
        error.message?.includes('API key not valid')

      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in settings menu.`,
        )
      }

      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }

    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined

    try {
      const model = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })

      const stream = await model.generateContentStream(
        {
          systemInstruction: systemInstruction,
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )

      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return this.streamResponseGenerator(stream, request.model, messageId)
    } catch (error) {
      const isInvalidApiKey =
        error.message?.includes('API_KEY_INVALID') ||
        error.message?.includes('API key not valid')

      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in settings menu.`,
        )
      }

      throw error
    }
  }

  private async *streamResponseGenerator(
    stream: GenerateContentStreamResult,
    model: string,
    messageId: string,
  ): AsyncIterable<LLMResponseStreaming> {
    for await (const chunk of stream.stream) {
      yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
    }
  }

  static parseRequestMessage(message: RequestMessage): Content | null {
    if (message.role === 'system') {
      return null
    }

    if (Array.isArray(message.content)) {
      return {
        role: message.role === 'user' ? 'user' : 'model',
        parts: message.content.map((part) => {
          switch (part.type) {
            case 'text':
              return { text: part.text }
            case 'image_url': {
              const { mimeType, base64Data } = parseImageDataUrl(
                part.image_url.url,
              )
              GeminiProvider.validateImageType(mimeType)

              return {
                inlineData: {
                  data: base64Data,
                  mimeType,
                },
              }
            }
          }
        }),
      }
    }

    return {
      role: message.role === 'user' ? 'user' : 'model',
      parts: [
        {
          text: message.content,
        },
      ],
    }
  }

  static parseNonStreamingResponse(
    response: GenerateContentResult,
    model: string,
    messageId: string,
  ): LLMResponseNonStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason:
            response.response.candidates?.[0]?.finishReason ?? null,
          message: {
            content: response.response.text(),
            role: 'assistant',
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion',
      usage: response.response.usageMetadata
        ? {
            prompt_tokens: response.response.usageMetadata.promptTokenCount,
            completion_tokens:
              response.response.usageMetadata.candidatesTokenCount,
            total_tokens: response.response.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }

  static parseStreamingResponseChunk(
    chunk: EnhancedGenerateContentResponse,
    model: string,
    messageId: string,
  ): LLMResponseStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
          delta: {
            content: chunk.text(),
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion.chunk',
      usage: chunk.usageMetadata
        ? {
            prompt_tokens: chunk.usageMetadata.promptTokenCount,
            completion_tokens: chunk.usageMetadata.candidatesTokenCount,
            total_tokens: chunk.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }

  private static validateImageType(mimeType: string) {
    const SUPPORTED_IMAGE_TYPES = [
      'image/png',
      'image/jpeg',
      'image/webp',
      'image/heic',
      'image/heif',
    ]
    if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
      throw new Error(
        `Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
          ', ',
        )}`,
      )
    }
  }
}
src/core/llm/groq.ts (new file, 200 lines)
@@ -0,0 +1,200 @@
import Groq from 'groq-sdk'
import {
  ChatCompletion,
  ChatCompletionChunk,
  ChatCompletionContentPart,
  ChatCompletionMessageParam,
} from 'groq-sdk/resources/chat/completions'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

export class GroqProvider implements BaseLLMProvider {
  private client: Groq

  constructor(apiKey: string) {
    this.client = new Groq({
      apiKey,
      dangerouslyAllowBrowser: true,
    })
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Groq API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Groq({
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true,
      })
    }

    try {
      const response = await this.client.chat.completions.create(
        {
          model: request.model,
          messages: request.messages.map((m) =>
            GroqProvider.parseRequestMessage(m),
          ),
          max_tokens: request.max_tokens,
          temperature: request.temperature,
          top_p: request.top_p,
        },
        {
          signal: options?.signal,
        },
      )
      return GroqProvider.parseNonStreamingResponse(response)
    } catch (error) {
      if (error instanceof Groq.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Groq API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Groq API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Groq({
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true,
      })
    }

    try {
      const stream = await this.client.chat.completions.create(
        {
          model: request.model,
          messages: request.messages.map((m) =>
            GroqProvider.parseRequestMessage(m),
          ),
          max_tokens: request.max_tokens,
          temperature: request.temperature,
          top_p: request.top_p,
          stream: true,
        },
        {
          signal: options?.signal,
        },
      )

      // eslint-disable-next-line no-inner-declarations
      async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
        for await (const chunk of stream) {
          yield GroqProvider.parseStreamingResponseChunk(chunk)
        }
      }

      return streamResponse()
    } catch (error) {
      if (error instanceof Groq.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Groq API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  static parseRequestMessage(
    message: RequestMessage,
  ): ChatCompletionMessageParam {
    switch (message.role) {
      case 'user': {
        const content = Array.isArray(message.content)
          ? message.content.map((part): ChatCompletionContentPart => {
              switch (part.type) {
                case 'text':
                  return { type: 'text', text: part.text }
                case 'image_url':
                  return { type: 'image_url', image_url: part.image_url }
              }
            })
          : message.content
        return { role: 'user', content }
      }
      case 'assistant': {
        if (Array.isArray(message.content)) {
          throw new Error('Assistant message should be a string')
        }
        return { role: 'assistant', content: message.content }
      }
      case 'system': {
        if (Array.isArray(message.content)) {
          throw new Error('System message should be a string')
        }
        return { role: 'system', content: message.content }
      }
    }
  }

  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      usage: response.usage,
    }
  }

  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
    }
  }
}
src/core/llm/infio.ts (new file, 252 lines)
@@ -0,0 +1,252 @@
import OpenAI from 'openai'
import {
  ChatCompletion,
  ChatCompletionChunk,
} from 'openai/resources/chat/completions'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

export interface RangeFilter {
  gte?: number;
  lte?: number;
}

export interface ChunkFilter {
  field: string;
  match_all?: string[];
  range?: RangeFilter;
}

/**
 * Interface for making requests to the Infio API
 */
export interface InfioRequest {
  /** Required: The content of the user message to attach to the topic and then generate an assistant message in response to */
  messages: RequestMessage[];
  // /** Required: The ID of the topic to attach the message to */
  // topic_id: string;
  /** Optional: URLs to include */
  links?: string[];
  /** Optional: Files to include */
  files?: string[];
  /** Optional: Whether to highlight results in chunk_html. Default is true */
  highlight_results?: boolean;
  /** Optional: Delimiters for highlighting citations. Default is [".", "!", "?", "\n", "\t", ","] */
  highlight_delimiters?: string[];
  /** Optional: Search type - "semantic", "fulltext", or "hybrid". Default is "hybrid" */
  search_type?: string;
  /** Optional: Filters for chunk filtering */
  filters?: ChunkFilter;
  /** Optional: Whether to use web search API. Default is false */
  use_web_search?: boolean;
  /** Optional: LLM model to use */
  llm_model?: string;
  /** Optional: Force source */
  force_source?: string;
  /** Optional: Whether completion should come before chunks in stream. Default is false */
  completion_first?: boolean;
  /** Optional: Whether to stream the response. Default is true */
  stream_response?: boolean;
  /** Optional: Sampling temperature between 0 and 2. Default is 0.5 */
  temperature?: number;
  /** Optional: Frequency penalty between -2.0 and 2.0. Default is 0.7 */
  frequency_penalty?: number;
  /** Optional: Presence penalty between -2.0 and 2.0. Default is 0.7 */
  presence_penalty?: number;
  /** Optional: Maximum tokens to generate */
  max_tokens?: number;
  /** Optional: Stop tokens (up to 4 sequences) */
  stop_tokens?: string[];
}

export class InfioProvider implements BaseLLMProvider {
  // private adapter: OpenAIMessageAdapter
  // private client: OpenAI
  private apiKey: string
  private baseUrl: string

  constructor(apiKey: string) {
    // this.client = new OpenAI({ apiKey, dangerouslyAllowBrowser: true })
    // this.adapter = new OpenAIMessageAdapter()
    this.apiKey = apiKey
    this.baseUrl = 'https://api.infio.com/api/raw_message'
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'Infio API key is missing. Please set it in settings menu.',
      )
    }
    try {
      const req: InfioRequest = {
        messages: request.messages,
        stream_response: false,
        temperature: request.temperature,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        max_tokens: request.max_tokens,
      }
      const options = {
        method: 'POST',
        headers: {
          Authorization: this.apiKey,
          'TR-Dataset': '74aaec22-0cf0-4cba-80a5-ae5c0518344e',
          'Content-Type': 'application/json',
        },
        body: JSON.stringify(req),
      };

      const response = await fetch(this.baseUrl, options);
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const data = await response.json() as ChatCompletion;
      return InfioProvider.parseNonStreamingResponse(data);
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Infio API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'Infio API key is missing. Please set it in settings menu.',
      )
    }

    try {
      const req: InfioRequest = {
        llm_model: request.model,
        messages: request.messages,
        stream_response: true,
        temperature: request.temperature,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        max_tokens: request.max_tokens,
      }
      const options = {
        method: 'POST',
        headers: {
          Authorization: this.apiKey,
          'TR-Dataset': '74aaec22-0cf0-4cba-80a5-ae5c0518344e',
          'Content-Type': 'application/json',
        },
        body: JSON.stringify(req),
      };

      const response = await fetch(this.baseUrl, options);
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      if (!response.body) {
        throw new Error('Response body is null');
      }

      const reader = response.body.getReader();
      const decoder = new TextDecoder();

      return {
        [Symbol.asyncIterator]: async function* () {
          try {
            while (true) {
              const { done, value } = await reader.read();
              if (done) break;

              const chunk = decoder.decode(value);
              const lines = chunk.split('\n').filter(line => line.trim());

              for (const line of lines) {
                if (line.startsWith('data: ')) {
                  const jsonData = JSON.parse(line.slice(6)) as ChatCompletionChunk;
                  if (!jsonData || typeof jsonData !== 'object' || !('choices' in jsonData)) {
                    throw new Error('Invalid chunk format received');
                  }
                  yield InfioProvider.parseStreamingResponseChunk(jsonData);
                }
              }
            }
          } finally {
            reader.releaseLock();
          }
        }
      };
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Infio API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      system_fingerprint: response.system_fingerprint,
      usage: response.usage,
    }
  }

  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
      system_fingerprint: chunk.system_fingerprint,
      usage: chunk.usage ?? undefined,
    }
  }
}
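One caveat in the reader loop above: each decoder.decode(value) is split on newlines in isolation, so a data: line whose JSON payload straddles two network reads would fail JSON.parse. A buffered variant (a sketch under that assumption, not what this commit ships) carries the trailing partial line across reads:

// Hypothetical helper: buffers partial SSE lines across reads so a JSON
// payload split between two reads still parses.
import { ChatCompletionChunk } from 'openai/resources/chat/completions'

async function* sseChunks(
  body: ReadableStream<Uint8Array>,
): AsyncIterable<ChatCompletionChunk> {
  const reader = body.getReader()
  const decoder = new TextDecoder()
  let buffer = ''
  try {
    while (true) {
      const { done, value } = await reader.read()
      if (done) break
      buffer += decoder.decode(value, { stream: true })
      const lines = buffer.split('\n')
      buffer = lines.pop() ?? '' // keep the possibly-incomplete last line
      for (const line of lines) {
        if (line.startsWith('data: ')) {
          yield JSON.parse(line.slice(6)) as ChatCompletionChunk
        }
      }
    }
  } finally {
    reader.releaseLock()
  }
}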
src/core/llm/manager.ts (new file, 142 lines)
@@ -0,0 +1,142 @@
import { DEEPSEEK_BASE_URL } from '../../constants'
import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { AnthropicProvider } from './anthropic'
import { GeminiProvider } from './gemini'
import { GroqProvider } from './groq'
import { InfioProvider } from './infio'
import { OllamaProvider } from './ollama'
import { OpenAIAuthenticatedProvider } from './openai'
import { OpenAICompatibleProvider } from './openai-compatible-provider'

export type LLMManagerInterface = {
  generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming>
  streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>>
}

class LLMManager implements LLMManagerInterface {
  private openaiProvider: OpenAIAuthenticatedProvider
  private deepseekProvider: OpenAICompatibleProvider
  private anthropicProvider: AnthropicProvider
  private googleProvider: GeminiProvider
  private groqProvider: GroqProvider
  private infioProvider: InfioProvider
  private ollamaProvider: OllamaProvider
  private isInfioEnabled: boolean

  constructor(apiKeys: {
    deepseek?: string
    openai?: string
    anthropic?: string
    gemini?: string
    groq?: string
    infio?: string
  }) {
    this.deepseekProvider = new OpenAICompatibleProvider(apiKeys.deepseek ?? '', DEEPSEEK_BASE_URL)
    this.openaiProvider = new OpenAIAuthenticatedProvider(apiKeys.openai ?? '')
    this.anthropicProvider = new AnthropicProvider(apiKeys.anthropic ?? '')
    this.googleProvider = new GeminiProvider(apiKeys.gemini ?? '')
    this.groqProvider = new GroqProvider(apiKeys.groq ?? '')
    this.infioProvider = new InfioProvider(apiKeys.infio ?? '')
    this.ollamaProvider = new OllamaProvider()
    this.isInfioEnabled = !!apiKeys.infio
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (this.isInfioEnabled) {
      return await this.infioProvider.generateResponse(
        model,
        request,
        options,
      )
    }
    // use custom provider
    switch (model.provider) {
      case 'deepseek':
        return await this.deepseekProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'openai':
        return await this.openaiProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'anthropic':
        return await this.anthropicProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'google':
        return await this.googleProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'groq':
        return await this.groqProvider.generateResponse(model, request, options)
      case 'ollama':
        return await this.ollamaProvider.generateResponse(
          model,
          request,
          options,
        )
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (this.isInfioEnabled) {
      return await this.infioProvider.streamResponse(model, request, options)
    }
    // use custom provider
    switch (model.provider) {
      case 'deepseek':
        return await this.deepseekProvider.streamResponse(model, request, options)
      case 'openai':
        return await this.openaiProvider.streamResponse(model, request, options)
      case 'anthropic':
        return await this.anthropicProvider.streamResponse(
          model,
          request,
          options,
        )
      case 'google':
        return await this.googleProvider.streamResponse(model, request, options)
      case 'groq':
        return await this.groqProvider.streamResponse(model, request, options)
      case 'ollama':
        return await this.ollamaProvider.streamResponse(model, request, options)
    }
  }
}

export default LLMManager
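End to end, a caller builds one manager from whatever keys are configured and lets model.provider pick the route. A sketch (hypothetical caller; the CustomLLMModel fields are assumed from their use in the providers above, and the key value is a placeholder):

import LLMManager from './manager'
import { CustomLLMModel } from '../../types/llm/model'
import { LLMRequestNonStreaming } from '../../types/llm/request'

async function summarize(note: string): Promise<string | null> {
  const manager = new LLMManager({ anthropic: 'sk-ant-...' }) // placeholder key
  const model = {
    provider: 'anthropic',
    name: 'claude-3-5-sonnet-latest',
  } as CustomLLMModel

  const response = await manager.generateResponse(model, {
    model: model.name,
    messages: [{ role: 'user', content: `Summarize:\n${note}` }],
    stream: false,
  } as LLMRequestNonStreaming)

  return response.choices[0]?.message.content ?? null
}

Note that if an Infio key is present, isInfioEnabled short-circuits the switch and every request goes through InfioProvider regardless of model.provider.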
src/core/llm/ollama.ts (new file, 104 lines)
@@ -0,0 +1,104 @@
/**
 * This provider is nearly identical to OpenAICompatibleProvider, but uses a custom OpenAI client
 * (NoStainlessOpenAI) to work around CORS issues specific to Ollama.
 */

import OpenAI from 'openai'
import { FinalRequestOptions } from 'openai/core'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException, LLMModelNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'

export class NoStainlessOpenAI extends OpenAI {
  defaultHeaders() {
    return {
      Accept: 'application/json',
      'Content-Type': 'application/json',
    }
  }

  buildRequest<Req>(
    options: FinalRequestOptions<Req>,
    { retryCount = 0 }: { retryCount?: number } = {},
  ): { req: RequestInit; url: string; timeout: number } {
    const req = super.buildRequest(options, { retryCount })
    const headers = req.req.headers as Record<string, string>
    Object.keys(headers).forEach((k) => {
      if (k.startsWith('x-stainless')) {
        // eslint-disable-next-line @typescript-eslint/no-dynamic-delete
        delete headers[k]
      }
    })
    return req
  }
}

export class OllamaProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter

  constructor() {
    this.adapter = new OpenAIMessageAdapter()
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!model.baseUrl) {
      throw new LLMBaseUrlNotSetException(
        'Ollama base URL is missing. Please set it in settings menu.',
      )
    }

    if (!model.name) {
      throw new LLMModelNotSetException(
        'Ollama model is missing. Please set it in settings menu.',
      )
    }

    const client = new NoStainlessOpenAI({
      baseURL: `${model.baseUrl}/v1`,
      apiKey: '',
      dangerouslyAllowBrowser: true,
    })
    return this.adapter.generateResponse(client, request, options)
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!model.baseUrl) {
      throw new LLMBaseUrlNotSetException(
        'Ollama base URL is missing. Please set it in settings menu.',
      )
    }

    if (!model.name) {
      throw new LLMModelNotSetException(
        'Ollama model is missing. Please set it in settings menu.',
      )
    }

    const client = new NoStainlessOpenAI({
      baseURL: `${model.baseUrl}/v1`,
      apiKey: '',
      dangerouslyAllowBrowser: true,
    })
    return this.adapter.streamResponse(client, request, options)
  }
}
src/core/llm/openai-compatible-provider.ts (new file, 62 lines)
@@ -0,0 +1,62 @@
import OpenAI from 'openai'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'

export class OpenAICompatibleProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter
  private client: OpenAI
  private apiKey: string
  private baseURL: string

  constructor(apiKey: string, baseURL: string) {
    this.adapter = new OpenAIMessageAdapter()
    this.client = new OpenAI({
      apiKey: apiKey,
      baseURL: baseURL,
      dangerouslyAllowBrowser: true,
    })
    this.apiKey = apiKey
    this.baseURL = baseURL
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.baseURL || !this.apiKey) {
      throw new LLMBaseUrlNotSetException(
        'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
      )
    }

    return this.adapter.generateResponse(this.client, request, options)
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.baseURL || !this.apiKey) {
      throw new LLMBaseUrlNotSetException(
        'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
      )
    }

    return this.adapter.streamResponse(this.client, request, options)
  }
}
src/core/llm/openai-message-adapter.ts (new file, 155 lines)
@@ -0,0 +1,155 @@
import OpenAI from 'openai'
import {
  ChatCompletion,
  ChatCompletionChunk,
  ChatCompletionContentPart,
  ChatCompletionMessageParam,
} from 'openai/resources/chat/completions'

import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

export class OpenAIMessageAdapter {
  async generateResponse(
    client: OpenAI,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    const response = await client.chat.completions.create(
      {
        model: request.model,
        messages: request.messages.map((m) =>
          OpenAIMessageAdapter.parseRequestMessage(m),
        ),
        max_tokens: request.max_tokens,
        temperature: request.temperature,
        top_p: request.top_p,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        logit_bias: request.logit_bias,
        prediction: request.prediction,
      },
      {
        signal: options?.signal,
      },
    )
    return OpenAIMessageAdapter.parseNonStreamingResponse(response)
  }

  async streamResponse(
    client: OpenAI,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    const stream = await client.chat.completions.create(
      {
        model: request.model,
        messages: request.messages.map((m) =>
          OpenAIMessageAdapter.parseRequestMessage(m),
        ),
        max_completion_tokens: request.max_tokens,
        temperature: request.temperature,
        top_p: request.top_p,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        logit_bias: request.logit_bias,
        stream: true,
        stream_options: {
          include_usage: true,
        },
      },
      {
        signal: options?.signal,
      },
    )

    // eslint-disable-next-line no-inner-declarations
    async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
      for await (const chunk of stream) {
        yield OpenAIMessageAdapter.parseStreamingResponseChunk(chunk)
      }
    }

    return streamResponse()
  }

  static parseRequestMessage(
    message: RequestMessage,
  ): ChatCompletionMessageParam {
    switch (message.role) {
      case 'user': {
        const content = Array.isArray(message.content)
          ? message.content.map((part): ChatCompletionContentPart => {
              switch (part.type) {
                case 'text':
                  return { type: 'text', text: part.text }
                case 'image_url':
                  return { type: 'image_url', image_url: part.image_url }
              }
            })
          : message.content
        return { role: 'user', content }
      }
      case 'assistant': {
        if (Array.isArray(message.content)) {
          throw new Error('Assistant message should be a string')
        }
        return { role: 'assistant', content: message.content }
      }
      case 'system': {
        if (Array.isArray(message.content)) {
          throw new Error('System message should be a string')
        }
        return { role: 'system', content: message.content }
      }
    }
  }

  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      system_fingerprint: response.system_fingerprint,
      usage: response.usage,
    }
  }

  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
      system_fingerprint: chunk.system_fingerprint,
      usage: chunk.usage ?? undefined,
    }
  }
}
src/core/llm/openai.ts (new file, 91 lines)
@@ -0,0 +1,91 @@
import OpenAI from 'openai'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'

export class OpenAIAuthenticatedProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter
  private client: OpenAI

  constructor(apiKey: string) {
    this.client = new OpenAI({
      apiKey,
      dangerouslyAllowBrowser: true,
    })
    this.adapter = new OpenAIMessageAdapter()
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.client.apiKey) {
      if (!model.baseUrl) {
        throw new LLMAPIKeyNotSetException(
          'OpenAI API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new OpenAI({
        apiKey: model.apiKey,
        baseURL: model.baseUrl,
        dangerouslyAllowBrowser: true,
      })
    }
    try {
      // await here so AuthenticationError rejections are caught below rather
      // than propagating to the caller as an unhandled rejection
      return await this.adapter.generateResponse(this.client, request, options)
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'OpenAI API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.client.apiKey) {
      if (!model.baseUrl) {
        throw new LLMAPIKeyNotSetException(
          'OpenAI API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new OpenAI({
        apiKey: model.apiKey,
        baseURL: model.baseUrl,
        dangerouslyAllowBrowser: true,
      })
    }

    try {
      // await here so the catch below can translate auth failures raised
      // while opening the stream
      return await this.adapter.streamResponse(this.client, request, options)
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'OpenAI API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }
}