commit 0c7ee142cb
Author: duanfuxiang
Date: 2025-01-05 11:51:39 +08:00
215 changed files with 20611 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
import { generateRandomString } from "./utils";
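// A random sentinel string marks the cursor position: the prefix and suffix are joined
// around it so that line-based regexes can tell which Markdown construct (heading, list,
// quote, code block, math block, ...) currently contains the cursor.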
const UNIQUE_CURSOR = `${generateRandomString(16)}`;
const HEADER_REGEX = `^#+\\s.*${UNIQUE_CURSOR}.*$`;
const UNORDERED_LIST_REGEX = `^\\s*(-|\\*)\\s.*${UNIQUE_CURSOR}.*$`;
const TASK_LIST_REGEX = `^\\s*(-|[0-9]+\\.) +\\[.\\]\\s.*${UNIQUE_CURSOR}.*$`;
const BLOCK_QUOTES_REGEX = `^\\s*>.*${UNIQUE_CURSOR}.*$`;
const NUMBERED_LIST_REGEX = `^\\s*\\d+\\.\\s.*${UNIQUE_CURSOR}.*$`;
const MATH_BLOCK_REGEX = /\$\$[\s\S]*?\$\$/g;
const INLINE_MATH_BLOCK_REGEX = /\$[\s\S]*?\$/g;
const CODE_BLOCK_REGEX = /```[\s\S]*?```/g;
const INLINE_CODE_BLOCK_REGEX = /`.*`/g;
enum Context {
Text = "Text",
Heading = "Heading",
BlockQuotes = "BlockQuotes",
UnorderedList = "UnorderedList",
NumberedList = "NumberedList",
CodeBlock = "CodeBlock",
MathBlock = "MathBlock",
TaskList = "TaskList",
}
// eslint-disable-next-line @typescript-eslint/no-namespace
namespace Context {
export function values(): Array<Context> {
return Object.values(Context).filter(
(value) => typeof value === "string"
) as Array<Context>;
}
export function getContext(prefix: string, suffix: string): Context {
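// The prefix and suffix are joined around the sentinel and tested against each pattern in
// turn; the first construct whose line (or enclosing block) contains the sentinel wins,
// falling back to plain text.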
if (new RegExp(HEADER_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
return Context.Heading;
}
if (new RegExp(BLOCK_QUOTES_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
return Context.BlockQuotes;
}
if (new RegExp(TASK_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
return Context.TaskList;
}
if (
isCursorInRegexBlock(prefix, suffix, MATH_BLOCK_REGEX) ||
isCursorInRegexBlock(prefix, suffix, INLINE_MATH_BLOCK_REGEX)
) {
return Context.MathBlock;
}
if (isCursorInRegexBlock(prefix, suffix, CODE_BLOCK_REGEX) || isCursorInRegexBlock(prefix, suffix, INLINE_CODE_BLOCK_REGEX)) {
return Context.CodeBlock;
}
if (new RegExp(NUMBERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
return Context.NumberedList;
}
if (new RegExp(UNORDERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
return Context.UnorderedList;
}
return Context.Text;
}
export function get(value: string) {
for (const context of Context.values()) {
if (value === context) {
return context;
}
}
return undefined;
}
}
function isCursorInRegexBlock(
prefix: string,
suffix: string,
regex: RegExp
): boolean {
const text = prefix + UNIQUE_CURSOR + suffix;
const codeBlocks = extractBlocks(text, regex);
for (const block of codeBlocks) {
if (block.includes(UNIQUE_CURSOR)) {
return true;
}
}
return false;
}
function extractBlocks(text: string, regex: RegExp) {
const codeBlocks = text.match(regex);
return codeBlocks ? codeBlocks.map((block) => block.trim()) : [];
}
export default Context;

View File

@@ -0,0 +1,273 @@
import * as Handlebars from "handlebars";
import { err, ok, Result } from "neverthrow";
import { FewShotExample } from "../../settings/versions";
import { CustomLLMModel } from "../../types/llm/model";
import { RequestMessage } from '../../types/llm/request';
import { InfioSettings } from "../../types/settings";
import LLMManager from '../llm/manager';
import Context from "./context-detection";
import RemoveCodeIndicators from "./post-processors/remove-code-indicators";
import RemoveMathIndicators from "./post-processors/remove-math-indicators";
import RemoveOverlap from "./post-processors/remove-overlap";
import RemoveWhitespace from "./post-processors/remove-whitespace";
import DataViewRemover from "./pre-processors/data-view-remover";
import LengthLimiter from "./pre-processors/length-limiter";
import {
AutocompleteService,
ChatMessage,
PostProcessor,
PreProcessor,
UserMessageFormatter,
UserMessageFormattingInputs
} from "./types";
class LLMClient {
private llm: LLMManager;
private model: CustomLLMModel;
constructor(llm: LLMManager, model: CustomLLMModel) {
this.llm = llm;
this.model = model;
}
async queryChatModel(messages: RequestMessage[]): Promise<Result<string, Error>> {
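// Single non-streaming completion request; the first choice's content is wrapped in a
// Result so callers can handle errors without exceptions.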
const data = await this.llm.generateResponse(this.model, {
model: this.model.name,
messages: messages,
stream: false,
})
return ok(data.choices[0].message.content);
}
}
class AutoComplete implements AutocompleteService {
private readonly client: LLMClient;
private readonly systemMessage: string;
private readonly userMessageFormatter: UserMessageFormatter;
private readonly removePreAnswerGenerationRegex: string;
private readonly preProcessors: PreProcessor[];
private readonly postProcessors: PostProcessor[];
private readonly fewShotExamples: FewShotExample[];
private debugMode: boolean;
private constructor(
client: LLMClient,
systemMessage: string,
userMessageFormatter: UserMessageFormatter,
removePreAnswerGenerationRegex: string,
preProcessors: PreProcessor[],
postProcessors: PostProcessor[],
fewShotExamples: FewShotExample[],
debugMode: boolean,
) {
this.client = client;
this.systemMessage = systemMessage;
this.userMessageFormatter = userMessageFormatter;
this.removePreAnswerGenerationRegex = removePreAnswerGenerationRegex;
this.preProcessors = preProcessors;
this.postProcessors = postProcessors;
this.fewShotExamples = fewShotExamples;
this.debugMode = debugMode;
}
public static fromSettings(settings: InfioSettings): AutocompleteService {
const formatter = Handlebars.compile<UserMessageFormattingInputs>(
settings.userMessageTemplate,
{ noEscape: true, strict: true }
);
const preProcessors: PreProcessor[] = [];
if (settings.dontIncludeDataviews) {
preProcessors.push(new DataViewRemover());
}
preProcessors.push(
new LengthLimiter(
settings.maxPrefixCharLimit,
settings.maxSuffixCharLimit
)
);
const postProcessors: PostProcessor[] = [];
if (settings.removeDuplicateMathBlockIndicator) {
postProcessors.push(new RemoveMathIndicators());
}
if (settings.removeDuplicateCodeBlockIndicator) {
postProcessors.push(new RemoveCodeIndicators());
}
postProcessors.push(new RemoveOverlap());
postProcessors.push(new RemoveWhitespace());
const llm_manager = new LLMManager({
deepseek: settings.deepseekApiKey,
openai: settings.openAIApiKey,
anthropic: settings.anthropicApiKey,
gemini: settings.geminiApiKey,
groq: settings.groqApiKey,
infio: settings.infioApiKey,
})
const model: CustomLLMModel = settings.activeModels.find(
(option) => option.name === settings.chatModelId,
)
const llm = new LLMClient(llm_manager, model);
return new AutoComplete(
llm,
settings.systemMessage,
formatter,
settings.chainOfThoughRemovalRegex,
preProcessors,
postProcessors,
settings.fewShotExamples,
settings.debugMode,
);
}
async fetchPredictions(
prefix: string,
suffix: string
): Promise<Result<string, Error>> {
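// Pipeline: detect the cursor context, run pre-processors (which may abort the request),
// build the few-shot chat, query the model, strip any chain-of-thought, then post-process
// and apply guard rails.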
const context: Context = Context.getContext(prefix, suffix);
for (const preProcessor of this.preProcessors) {
if (preProcessor.removesCursor(prefix, suffix)) {
return ok("");
}
({ prefix, suffix } = preProcessor.process(
prefix,
suffix,
context
));
}
const examples = this.fewShotExamples.filter(
(example) => example.context === context
);
const fewShotExamplesChatMessages =
fewShotExamplesToChatMessages(examples);
const messages: RequestMessage[] = [
{
content: this.getSystemMessageFor(context),
role: "system"
},
...fewShotExamplesChatMessages,
{
role: "user",
content: this.userMessageFormatter({
suffix,
prefix,
}),
},
];
if (this.debugMode) {
console.log("Copilot messages send:\n", messages);
}
let result = await this.client.queryChatModel(messages);
if (this.debugMode && result.isOk()) {
console.log("Copilot response:\n", result.value);
}
result = this.extractAnswerFromChainOfThoughts(result);
for (const postProcessor of this.postProcessors) {
result = result.map((r) => postProcessor.process(prefix, suffix, r, context));
}
result = this.checkAgainstGuardRails(result);
return result;
}
private getSystemMessageFor(context: Context): string {
if (context === Context.Text) {
return this.systemMessage + "\n\n" + "The <mask/> is located in a paragraph. Your answer must complete this paragraph or sentence in a way that fits the surrounding text without overlapping with it. It must be in the same language as the paragraph.";
}
if (context === Context.Heading) {
return this.systemMessage + "\n\n" + "The <mask/> is located in a Markdown heading. Your answer must complete this heading in a way that fits the content of the section and is in the same language as the surrounding text.";
}
if (context === Context.BlockQuotes) {
return this.systemMessage + "\n\n" + "The <mask/> is located within a quote. Your answer must complete this quote in a way that fits the context of the paragraph.";
}
if (context === Context.UnorderedList) {
return this.systemMessage + "\n\n" + "The <mask/> is located in an unordered list. Your answer must include one or more list items that fit with the surrounding list without overlapping with it.";
}
if (context === Context.NumberedList) {
return this.systemMessage + "\n\n" + "The <mask/> is located in a numbered list. Your answer must include one or more list items that fit the sequence and context of the surrounding list without overlapping with it.";
}
if (context === Context.CodeBlock) {
return this.systemMessage + "\n\n" + "The <mask/> is located in a code block. Your answer must complete this code block in the same programming language and support the surrounding code and text outside of the code block.";
}
if (context === Context.MathBlock) {
return this.systemMessage + "\n\n" + "The <mask/> is located in a math block. Your answer must only contain LaTeX code that captures the math discussed in the surrounding text. No text or explanation, only LaTeX math code.";
}
if (context === Context.TaskList) {
return this.systemMessage + "\n\n" + "The <mask/> is located in a task list. Your answer must include one or more (sub)tasks that are logical given the other tasks and the surrounding text.";
}
return this.systemMessage;
}
private extractAnswerFromChainOfThoughts(
result: Result<string, Error>
): Result<string, Error> {
if (result.isErr()) {
return result;
}
const chainOfThoughts = result.value;
const regex = new RegExp(this.removePreAnswerGenerationRegex, "gm");
const match = regex.exec(chainOfThoughts);
if (match === null) {
return err(new Error("No match found"));
}
return ok(chainOfThoughts.replace(regex, ""));
}
private checkAgainstGuardRails(
result: Result<string, Error>
): Result<string, Error> {
if (result.isErr()) {
return result;
}
if (result.value.length === 0) {
return err(new Error("Empty result"));
}
if (result.value.includes("<mask/>")) {
return err(new Error("Mask in result"));
}
return result;
}
}
function fewShotExamplesToChatMessages(
examples: FewShotExample[]
): ChatMessage[] {
return examples
.map((example): ChatMessage[] => {
return [
{
role: "user",
content: example.input,
},
{
role: "assistant",
content: example.answer,
},
];
})
.flat();
}
export default AutoComplete;

View File

@@ -0,0 +1,22 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
class RemoveCodeIndicators implements PostProcessor {
process(
prefix: string,
suffix: string,
completion: string,
context: Context
): string {
if (context === Context.CodeBlock) {
completion = completion.replace(/```[a-zA-Z]+[ \t]*\n?/g, "");
completion = completion.replace(/\n?```[ \t]*\n?/g, "");
completion = completion.replace(/`/g, "");
}
return completion;
}
}
export default RemoveCodeIndicators;

View File

@@ -0,0 +1,20 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
class RemoveMathIndicators implements PostProcessor {
process(
prefix: string,
suffix: string,
completion: string,
context: Context
): string {
if (context === Context.MathBlock) {
completion = completion.replace(/\n?\$\$\n?/g, "");
completion = completion.replace(/\$/g, "");
}
return completion;
}
}
export default RemoveMathIndicators;

View File

@@ -0,0 +1,96 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
class RemoveOverlap implements PostProcessor {
process(
prefix: string,
suffix: string,
completion: string,
context: Context
): string {
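// Strip word-level overlap with the prefix and suffix first, then trim any remaining
// character-by-character overlap at the completion's edges.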
completion = removeWordOverlapPrefix(prefix, completion);
completion = removeWordOverlapSuffix(completion, suffix);
completion = removeWhiteSpaceOverlapPrefix(prefix, completion);
completion = removeWhiteSpaceOverlapSuffix(completion, suffix);
return completion;
}
}
function removeWhiteSpaceOverlapPrefix(prefix: string, completion: string): string {
let prefixIdx = prefix.length - 1;
while (completion.length > 0 && completion[0] === prefix[prefixIdx]) {
completion = completion.slice(1);
prefixIdx--;
}
return completion;
}
function removeWhiteSpaceOverlapSuffix(completion: string, suffix: string): string {
let suffixIdx = 0;
while (completion.length > 0 && completion[completion.length - 1] === suffix[suffixIdx]) {
completion = completion.slice(0, -1);
suffixIdx++;
}
return completion;
}
function removeWordOverlapPrefix(prefix: string, completion: string): string {
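// Starting from the last word of the prefix and working backwards, drop from the
// completion any word-aligned tail of the prefix that the completion repeats.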
const rightTrimmed = completion.trimStart();
const startIdxOfEachWord = startLocationOfEachWord(prefix);
while (startIdxOfEachWord.length > 0) {
const idx = startIdxOfEachWord.pop();
const leftSubstring = prefix.slice(idx);
if (rightTrimmed.startsWith(leftSubstring)) {
return rightTrimmed.replace(leftSubstring, "");
}
}
return completion;
}
function removeWordOverlapSuffix(completion: string, suffix: string): string {
const suffixTrimmed = removeLeadingWhiteSpace(suffix);
const startIdxOfEachWord = startLocationOfEachWord(completion);
while (startIdxOfEachWord.length > 0) {
const idx = startIdxOfEachWord.pop();
const suffixSubstring = completion.slice(idx);
if (suffixTrimmed.startsWith(suffixSubstring)) {
return completion.replace(suffixSubstring, "");
}
}
return completion;
}
function removeLeadingWhiteSpace(completion: string): string {
return completion.replace(/^[ \t\f\r\v]+/, "");
}
function startLocationOfEachWord(text: string): number[] {
const locations: number[] = [];
if (text.length > 0 && !isWhiteSpaceChar(text[0])) {
locations.push(0);
}
for (let i = 1; i < text.length; i++) {
if (isWhiteSpaceChar(text[i - 1]) && !isWhiteSpaceChar(text[i])) {
locations.push(i);
}
}
return locations;
}
function isWhiteSpaceChar(char: string | undefined): boolean {
return char !== undefined && char.match(/\s/) !== null;
}
export default RemoveOverlap;

View File

@@ -0,0 +1,24 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
class RemoveWhitespace implements PostProcessor {
process(
prefix: string,
suffix: string,
completion: string,
context: Context
): string {
if (context === Context.Text || context === Context.Heading || context === Context.MathBlock || context === Context.TaskList || context === Context.NumberedList || context === Context.UnorderedList) {
if (prefix.endsWith(" ") || suffix.endsWith("\n")) {
completion = completion.trimStart();
}
if (suffix.startsWith(" ")) {
completion = completion.trimEnd();
}
}
return completion;
}
}
export default RemoveWhitespace;

View File

@@ -0,0 +1,34 @@
import { generateRandomString } from "../utils";
import Context from "../context-detection";
import { PrefixAndSuffix, PreProcessor } from "../types";
const DATA_VIEW_REGEX = /```dataview(js){0,1}(.|\n)*?```/gm;
const UNIQUE_CURSOR = `${generateRandomString(16)}`;
class DataViewRemover implements PreProcessor {
process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
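// Remove dataview/dataviewjs code blocks from the note; the sentinel marks the cursor so
// the cleaned text can be split back into prefix and suffix.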
let text = prefix + UNIQUE_CURSOR + suffix;
text = text.replace(DATA_VIEW_REGEX, "");
const [prefixNew, suffixNew] = text.split(UNIQUE_CURSOR);
return { prefix: prefixNew, suffix: suffixNew };
}
removesCursor(prefix: string, suffix: string): boolean {
const text = prefix + UNIQUE_CURSOR + suffix;
const dataviewAreasWithCursor = text
.match(DATA_VIEW_REGEX)
?.filter((dataviewArea) => dataviewArea.includes(UNIQUE_CURSOR));
if (
dataviewAreasWithCursor !== undefined &&
dataviewAreasWithCursor.length > 0
) {
return true;
}
return false;
}
}
export default DataViewRemover;

View File

@@ -0,0 +1,24 @@
import Context from "../context-detection";
import { PrefixAndSuffix, PreProcessor } from "../types";
class LengthLimiter implements PreProcessor {
private readonly maxPrefixChars: number;
private readonly maxSuffixChars: number;
constructor(maxPrefixChars: number, maxSuffixChars: number) {
this.maxPrefixChars = maxPrefixChars;
this.maxSuffixChars = maxSuffixChars;
}
process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
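// Keep only the last maxPrefixChars of the prefix and the first maxSuffixChars of the suffix.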
prefix = prefix.slice(-this.maxPrefixChars);
suffix = suffix.slice(0, this.maxSuffixChars);
return { prefix, suffix };
}
removesCursor(prefix: string, suffix: string): boolean {
return false;
}
}
export default LengthLimiter;

View File

@@ -0,0 +1,35 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import State from "./state";
class DisabledFileSpecificState extends State {
getStatusBarText(): string {
return "Disabled for this file";
}
handleSettingChanged(settings: InfioSettings) {
if (!this.context.settings.autocompleteEnabled) {
this.context.transitionToDisabledManualState();
return;
}
if (!this.context.isCurrentFilePathIgnored() && !this.context.currentFileContainsIgnoredTag()) {
this.context.transitionToIdleState();
}
}
handleFileChange(file: TFile): void {
if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
return;
}
if (this.context.settings.autocompleteEnabled) {
this.context.transitionToIdleState();
} else {
this.context.transitionToDisabledManualState();
}
}
}
export default DisabledFileSpecificState;

View File

@@ -0,0 +1,25 @@
import { InfioSettings } from "../../../types/settings";
import { checkForErrors } from "../../../utils/auto-complete";
import State from "./state";
class DisabledInvalidSettingsState extends State {
getStatusBarText(): string {
return "Disabled invalid settings";
}
handleSettingChanged(settings: InfioSettings) {
const settingErrors = checkForErrors(settings);
if (settingErrors.size > 0) {
return
}
if (this.context.settings.autocompleteEnabled) {
this.context.transitionToIdleState();
} else {
this.context.transitionToDisabledManualState();
}
}
}
export default DisabledInvalidSettingsState;

View File

@@ -0,0 +1,21 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import State from "./state";
class DisabledManualState extends State {
getStatusBarText(): string {
return "Disabled";
}
handleSettingChanged(settings: InfioSettings): void {
if (this.context.settings.autocompleteEnabled) {
this.context.transitionToIdleState();
}
}
handleFileChange(file: TFile): void { }
}
export default DisabledManualState;

View File

@@ -0,0 +1,46 @@
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
class IdleState extends State {
async handleDocumentChange(
documentChanges: DocumentChanges
): Promise<void> {
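// Ignore events that should never trigger a suggestion: unfocused documents, deletions,
// multiple cursors, selections, and undo/redo.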
if (
!documentChanges.isDocInFocus()
|| !documentChanges.hasDocChanged()
|| documentChanges.hasUserDeleted()
|| documentChanges.hasMultipleCursors()
|| documentChanges.hasSelection()
|| documentChanges.hasUserUndone()
|| documentChanges.hasUserRedone()
) {
return;
}
const cachedSuggestion = this.context.getCachedSuggestionFor(documentChanges.getPrefix(), documentChanges.getSuffix());
const isThereCachedSuggestion = cachedSuggestion !== undefined && cachedSuggestion.trim().length > 0;
if (this.context.settings.cacheSuggestions && isThereCachedSuggestion) {
this.context.transitionToSuggestingState(cachedSuggestion, documentChanges.getPrefix(), documentChanges.getSuffix());
return;
}
if (this.context.containsTriggerCharacters(documentChanges)) {
this.context.transitionToQueuedState(documentChanges.getPrefix(), documentChanges.getSuffix());
}
}
handlePredictCommand(prefix: string, suffix: string): void {
this.context.transitionToPredictingState(prefix, suffix);
}
getStatusBarText(): string {
return "Idle";
}
}
export default IdleState;

View File

@@ -0,0 +1,38 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import { EventHandler } from "./types";
class InitState implements EventHandler {
async handleDocumentChange(documentChanges: DocumentChanges): Promise<void> { }
handleSettingChanged(settings: InfioSettings): void { }
handleAcceptKeyPressed(): boolean {
return false;
}
handlePartialAcceptKeyPressed(): boolean {
return false;
}
handleCancelKeyPressed(): boolean {
return false;
}
handlePredictCommand(): void { }
handleAcceptCommand(): void { }
getStatusBarText(): string {
return "Initializing...";
}
handleFileChange(file: TFile): void {
}
}
export default InitState;

View File

@@ -0,0 +1,94 @@
import { Notice } from "obsidian";
import Context from "../context-detection";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
class PredictingState extends State {
private predictionPromise: Promise<void> | null = null;
private isStillNeeded = true;
private readonly prefix: string;
private readonly suffix: string;
constructor(context: EventListener, prefix: string, suffix: string) {
super(context);
this.prefix = prefix;
this.suffix = suffix;
}
static createAndStartPredicting(
context: EventListener,
prefix: string,
suffix: string
): PredictingState {
const predictingState = new PredictingState(context, prefix, suffix);
predictingState.startPredicting();
context.setContext(Context.getContext(prefix, suffix));
return predictingState;
}
handleCancelKeyPressed(): boolean {
this.cancelPrediction();
return true;
}
async handleDocumentChange(
documentChanges: DocumentChanges
): Promise<void> {
if (
documentChanges.hasCursorMoved() ||
documentChanges.hasUserTyped() ||
documentChanges.hasUserDeleted() ||
documentChanges.isTextAdded()
) {
this.cancelPrediction();
}
}
private cancelPrediction(): void {
this.isStillNeeded = false;
this.context.transitionToIdleState();
}
startPredicting(): void {
this.predictionPromise = this.predict();
}
private async predict(): Promise<void> {
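// Ask the autocomplete service for a prediction; if the state was cancelled while
// waiting, the result is silently discarded.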
const result =
await this.context.autocomplete?.fetchPredictions(
this.prefix,
this.suffix
);
if (!this.isStillNeeded) {
return;
}
if (result === undefined || result.isErr()) {
new Notice(
`Copilot: Something went wrong; cannot make a prediction. The full error is available in the dev console. Please check your settings.`
);
if (result !== undefined) {
console.error(result.error);
}
this.context.transitionToIdleState();
return;
}
const prediction = result.unwrapOr("");
if (prediction === "") {
this.context.transitionToIdleState();
return;
}
this.context.transitionToSuggestingState(prediction, this.prefix, this.suffix);
}
getStatusBarText(): string {
return `Predicting for ${this.context.context}`;
}
}
export default PredictingState;

View File

@@ -0,0 +1,84 @@
import Context from "../context-detection";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
class QueuedState extends State {
private timer: ReturnType<typeof setTimeout> | null = null;
private readonly prefix: string;
private readonly suffix: string;
private constructor(
context: EventListener,
prefix: string,
suffix: string
) {
super(context);
this.prefix = prefix;
this.suffix = suffix;
}
static createAndStartTimer(
context: EventListener,
prefix: string,
suffix: string
): QueuedState {
const state = new QueuedState(context, prefix, suffix);
state.startTimer();
context.setContext(Context.getContext(prefix, suffix));
return state;
}
handleCancelKeyPressed(): boolean {
this.cancelTimer();
this.context.transitionToIdleState();
return true;
}
async handleDocumentChange(
documentChanges: DocumentChanges
): Promise<void> {
if (
documentChanges.isDocInFocus() &&
documentChanges.isTextAdded() &&
this.context.containsTriggerCharacters(documentChanges)
) {
this.cancelTimer();
this.context.transitionToQueuedState(documentChanges.getPrefix(), documentChanges.getSuffix());
return
}
if (
(documentChanges.hasCursorMoved() ||
documentChanges.hasUserTyped() ||
documentChanges.hasUserDeleted() ||
documentChanges.isTextAdded() ||
!documentChanges.isDocInFocus())
) {
this.cancelTimer();
this.context.transitionToIdleState();
}
}
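// Debounce: wait settings.delay milliseconds before predicting; any further edit cancels
// or restarts the timer.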
startTimer(): void {
this.cancelTimer();
this.timer = setTimeout(() => {
this.context.transitionToPredictingState(this.prefix, this.suffix);
}, this.context.settings.delay);
}
private cancelTimer(): void {
if (this.timer !== null) {
clearTimeout(this.timer);
this.timer = null;
}
}
getStatusBarText(): string {
return `Queued (${this.context.settings.delay} ms)`;
}
}
export default QueuedState;

View File

@@ -0,0 +1,67 @@
import { Notice, TFile } from "obsidian";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
// import { Settings } from "../settings/versions";
import { InfioSettings } from "../../../types/settings";
import { checkForErrors } from "../../../utils/auto-complete";
import { EventHandler } from "./types";
abstract class State implements EventHandler {
protected readonly context: EventListener;
constructor(context: EventListener) {
this.context = context;
}
handleSettingChanged(settings: InfioSettings): void {
const settingErrors = checkForErrors(settings);
if (!settings.autocompleteEnabled) {
new Notice("Copilot is now disabled.");
this.context.transitionToDisabledManualState()
} else if (settingErrors.size > 0) {
new Notice(
`Copilot: There are ${settingErrors.size} errors in your settings. The plugin will be disabled until they are fixed.`
);
this.context.transitionToDisabledInvalidSettingsState();
} else if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
this.context.transitionToDisabledFileSpecificState();
}
}
async handleDocumentChange(
documentChanges: DocumentChanges
): Promise<void> {
}
handleAcceptKeyPressed(): boolean {
return false;
}
handlePartialAcceptKeyPressed(): boolean {
return false;
}
handleCancelKeyPressed(): boolean {
return false;
}
handlePredictCommand(prefix: string, suffix: string): void {
}
handleAcceptCommand(): void {
}
abstract getStatusBarText(): string;
handleFileChange(file: TFile): void {
if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
this.context.transitionToDisabledFileSpecificState();
} else if (this.context.isDisabled()) {
this.context.transitionToIdleState();
}
}
}
export default State;

View File

@@ -0,0 +1,177 @@
import { Settings } from "../../../settings/versions";
import { extractNextWordAndRemaining } from "../utils";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
class SuggestingState extends State {
private readonly suggestion: string;
private readonly prefix: string;
private readonly suffix: string;
constructor(context: EventListener, suggestion: string, prefix: string, suffix: string) {
super(context);
this.suggestion = suggestion;
this.prefix = prefix;
this.suffix = suffix;
}
async handleDocumentChange(
documentChanges: DocumentChanges
): Promise<void> {
if (
documentChanges.hasCursorMoved()
|| documentChanges.hasUserUndone()
|| documentChanges.hasUserDeleted()
|| documentChanges.hasUserRedone()
|| !documentChanges.isDocInFocus()
|| documentChanges.hasSelection()
|| documentChanges.hasMultipleCursors()
) {
this.clearPrediction();
return;
}
if (
documentChanges.noUserEvents()
|| !documentChanges.hasDocChanged()
) {
return;
}
if (this.hasUserAddedPartOfSuggestion(documentChanges)) {
this.acceptPartialAddedText(documentChanges);
return
}
const currentPrefix = documentChanges.getPrefix();
const currentSuffix = documentChanges.getSuffix();
const suggestion = this.context.getCachedSuggestionFor(currentPrefix, currentSuffix);
const isThereCachedSuggestion = suggestion !== undefined;
const isCachedSuggestionDifferent = suggestion !== this.suggestion;
if (!isCachedSuggestionDifferent) {
return;
}
if (isThereCachedSuggestion) {
this.context.transitionToSuggestingState(suggestion, currentPrefix, currentSuffix);
return;
}
this.clearPrediction();
}
hasUserAddedPartOfSuggestion(documentChanges: DocumentChanges): boolean {
const addedPrefixText = documentChanges.getAddedPrefixText();
const addedSuffixText = documentChanges.getAddedSuffixText();
return addedPrefixText !== undefined
&& addedSuffixText !== undefined
&& this.suggestion.toLowerCase().startsWith(addedPrefixText.toLowerCase())
&& this.suggestion.toLowerCase().endsWith(addedSuffixText.toLowerCase());
}
acceptPartialAddedText(documentChanges: DocumentChanges): void {
const addedPrefixText = documentChanges.getAddedPrefixText();
const addedSuffixText = documentChanges.getAddedSuffixText();
if (addedSuffixText === undefined || addedPrefixText === undefined) {
return;
}
const startIdx = addedPrefixText.length;
const endIdx = this.suggestion.length - addedSuffixText.length
const remainingSuggestion = this.suggestion.substring(startIdx, endIdx);
if (remainingSuggestion.trim() === "") {
this.clearPrediction();
} else {
this.context.transitionToSuggestingState(remainingSuggestion, documentChanges.getPrefix(), documentChanges.getSuffix());
}
}
private clearPrediction(): void {
this.context.transitionToIdleState();
}
handleAcceptKeyPressed(): boolean {
this.accept();
return true;
}
private accept() {
this.addPartialSuggestionCaches(this.suggestion);
this.context.insertCurrentSuggestion(this.suggestion);
this.context.transitionToIdleState();
}
handlePartialAcceptKeyPressed(): boolean {
this.acceptNextWord();
return true;
}
private acceptNextWord() {
const [nextWord, remaining] = extractNextWordAndRemaining(this.suggestion);
if (nextWord !== undefined && remaining !== undefined) {
const updatedPrefix = this.prefix + nextWord;
this.addPartialSuggestionCaches(nextWord, remaining);
this.context.insertCurrentSuggestion(nextWord);
this.context.transitionToSuggestingState(remaining, updatedPrefix, this.suffix, false);
} else {
this.accept();
}
}
private addPartialSuggestionCaches(acceptSuggestion: string, remainingSuggestion = "") {
// store the sub-suggestions in the cache
// so that we can have partial suggestions if the user edits a part
for (let i = 0; i < acceptSuggestion.length; i++) {
const prefix = this.prefix + acceptSuggestion.substring(0, i);
const suggestion = acceptSuggestion.substring(i) + remainingSuggestion;
this.context.addSuggestionToCache(prefix, this.suffix, suggestion);
}
}
private getNextWordAndRemaining(): [string | undefined, string | undefined] {
const words = this.suggestion.split(" ");
if (words.length === 0) {
return ["", ""];
}
if (words.length === 1) {
return [words[0] + " ", ""];
}
return [words[0] + " ", words.slice(1).join(" ")];
}
handleCancelKeyPressed(): boolean {
this.context.clearSuggestionsCache();
this.clearPrediction();
return true;
}
handleAcceptCommand() {
this.accept();
}
getStatusBarText(): string {
return `Suggesting for ${this.context.context}`;
}
handleSettingChanged(settings: Settings): void {
if (!settings.cacheSuggestions) {
this.clearPrediction();
}
}
}
export default SuggestingState;

View File

@@ -0,0 +1,25 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
export interface EventHandler {
handleSettingChanged(settings: InfioSettings): void;
handleDocumentChange(documentChanges: DocumentChanges): Promise<void>;
handleAcceptKeyPressed(): boolean;
handlePartialAcceptKeyPressed(): boolean;
handleCancelKeyPressed(): boolean;
handlePredictCommand(prefix: string, suffix: string): void;
handleAcceptCommand(): void;
getStatusBarText(): string;
handleFileChange(file: TFile): void;
}

View File

@@ -0,0 +1,57 @@
import { Result } from "neverthrow";
import Context from "./context-detection";
export interface AutocompleteService {
fetchPredictions(
prefix: string,
suffix: string
): Promise<Result<string, Error>>;
}
export interface PostProcessor {
process(
prefix: string,
suffix: string,
completion: string,
context: Context
): string;
}
export interface PreProcessor {
process(prefix: string, suffix: string, context: Context): PrefixAndSuffix;
removesCursor(prefix: string, suffix: string): boolean;
}
export interface PrefixAndSuffix {
prefix: string;
suffix: string;
}
export interface ChatMessage {
content: string;
role: "user" | "assistant" | "system";
}
export interface UserMessageFormattingInputs {
prefix: string;
suffix: string;
}
export type UserMessageFormatter = (
inputs: UserMessageFormattingInputs
) => string;
export interface ApiClient {
queryChatModel(messages: ChatMessage[]): Promise<Result<string, Error>>;
checkIfConfiguredCorrectly?(): Promise<string[]>;
}
export interface ModelOptions {
temperature: number;
top_p: number;
frequency_penalty: number;
presence_penalty: number;
max_tokens: number;
}

View File

@@ -0,0 +1,68 @@
import * as mm from "micromatch";
export function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
export function enumKeys<O extends object, K extends keyof O = keyof O>(
obj: O
): K[] {
return Object.keys(obj).filter((k) => Number.isNaN(+k)) as K[];
}
export function generateRandomString(n: number): string {
let result = '';
const characters = '0123456789abcdef';
for (let i = 0; i < n; i++) {
const randomIndex = Math.floor(Math.random() * characters.length);
result += characters[randomIndex];
}
return result;
}
export function isMatchBetweenPathAndPatterns(
path: string,
patterns: string[],
): boolean {
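// micromatch-style ignore patterns; a pattern starting with '!' is an exclusion that
// overrides the inclusions.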
patterns = patterns
.map(p => p.trim())
.filter((p) => p.length > 0);
if (patterns.length === 0) {
return false;
}
const exclusionPatterns = patterns.filter((p) => p.startsWith('!')).map(p => p.slice(1));
const inclusionPatterns = patterns.filter((p) => !p.startsWith('!'));
return mm.some(path, inclusionPatterns) && !mm.some(path, exclusionPatterns);
}
export function extractNextWordAndRemaining(suggestion: string): [string | undefined, string | undefined] {
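// Split the suggestion into its first word (keeping leading whitespace, plus the
// whitespace that follows the word when more text remains) and the remaining text;
// either part may be undefined.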
const leadingWhitespacesMatch = suggestion.match(/^(\s*)/);
const leadingWhitespaces = leadingWhitespacesMatch ? leadingWhitespacesMatch[0] : '';
const trimmedSuggestion = suggestion.slice(leadingWhitespaces.length);
let nextWord: string | undefined;
let remaining: string | undefined = undefined;
const whitespaceAfterNextWordMatch = trimmedSuggestion.match(/\s+/);
if (!whitespaceAfterNextWordMatch) {
nextWord = trimmedSuggestion || undefined;
} else {
const whitespaceAfterNextWordStartingIndex = whitespaceAfterNextWordMatch.index!;
const whitespaceAfterNextWord = whitespaceAfterNextWordMatch[0];
const whitespaceLength = whitespaceAfterNextWord.length;
const startOfWhitespaceAfterNextWordIndex = whitespaceAfterNextWordStartingIndex + whitespaceLength;
nextWord = trimmedSuggestion.substring(0, whitespaceAfterNextWordStartingIndex);
if (startOfWhitespaceAfterNextWordIndex < trimmedSuggestion.length) {
remaining = trimmedSuggestion.slice(startOfWhitespaceAfterNextWordIndex);
nextWord += whitespaceAfterNextWord;
}
}
return [nextWord ? leadingWhitespaces + nextWord : undefined, remaining];
}

View File

@@ -0,0 +1,44 @@
import { MarkdownPostProcessorContext, Plugin } from "obsidian";
import React from "react";
import { createRoot } from "react-dom/client";
import { InlineEdit as InlineEditComponent } from "../../components/inline-edit/InlineEdit";
import { InfioSettings } from "../../types/settings";
export class InlineEdit {
plugin: Plugin;
settings: InfioSettings;
constructor(plugin: Plugin, settings: InfioSettings) {
this.plugin = plugin;
this.settings = settings;
}
/**
* Entry point for the Markdown post-processor.
*/
Processor(source: string, el: HTMLElement, ctx: MarkdownPostProcessorContext) {
const sec = ctx.getSectionInfo(el);
if (!sec) return;
const container = createDiv();
const root = createRoot(container);
root.render(
React.createElement(InlineEditComponent, {
source: source,
secStartLine: sec.lineStart,
secEndLine: sec.lineEnd,
plugin: this.plugin,
settings: this.settings
})
);
// Remove the code-block styling from the parent element
const parent = el.parentElement;
if (parent) {
parent.addClass("infio-ai-block");
}
el.replaceWith(container);
}
}

src/core/llm/anthropic.ts (new file, 323 lines)
View File

@@ -0,0 +1,323 @@
import Anthropic from '@anthropic-ai/sdk'
import {
ImageBlockParam,
MessageParam,
MessageStreamEvent,
TextBlockParam,
} from '@anthropic-ai/sdk/resources/messages'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
ResponseUsage,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
export class AnthropicProvider implements BaseLLMProvider {
private client: Anthropic
private static readonly DEFAULT_MAX_TOKENS = 8192
constructor(apiKey: string) {
this.client = new Anthropic({ apiKey, dangerouslyAllowBrowser: true })
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Anthropic API key is missing. Please set it in settings menu.',
)
}
this.client = new Anthropic({
baseURL: model.baseUrl,
apiKey: model.apiKey,
dangerouslyAllowBrowser: true
})
}
const systemMessage = AnthropicProvider.validateSystemMessages(
request.messages,
)
try {
const response = await this.client.messages.create(
{
model: request.model,
messages: request.messages
.filter((m) => m.role !== 'system')
.filter((m) => !AnthropicProvider.isMessageEmpty(m))
.map((m) => AnthropicProvider.parseRequestMessage(m)),
system: systemMessage,
max_tokens:
request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
temperature: request.temperature,
top_p: request.top_p,
},
{
signal: options?.signal,
},
)
return AnthropicProvider.parseNonStreamingResponse(response)
} catch (error) {
if (error instanceof Anthropic.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Anthropic API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Anthropic API key is missing. Please set it in settings menu.',
)
}
this.client = new Anthropic({
baseURL: model.baseUrl,
apiKey: model.apiKey,
dangerouslyAllowBrowser: true
})
}
const systemMessage = AnthropicProvider.validateSystemMessages(
request.messages,
)
try {
const stream = await this.client.messages.create(
{
model: request.model,
messages: request.messages
.filter((m) => m.role !== 'system')
.filter((m) => !AnthropicProvider.isMessageEmpty(m))
.map((m) => AnthropicProvider.parseRequestMessage(m)),
system: systemMessage,
max_tokens:
request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
temperature: request.temperature,
top_p: request.top_p,
stream: true,
},
{
signal: options?.signal,
},
)
// eslint-disable-next-line no-inner-declarations
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
let messageId = ''
let model = ''
let usage: ResponseUsage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
}
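// Anthropic reports usage incrementally: message_start carries the input-token count,
// and each message_delta adds output tokens.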
for await (const chunk of stream) {
if (chunk.type === 'message_start') {
messageId = chunk.message.id
model = chunk.message.model
usage = {
prompt_tokens: chunk.message.usage.input_tokens,
completion_tokens: chunk.message.usage.output_tokens,
total_tokens:
chunk.message.usage.input_tokens +
chunk.message.usage.output_tokens,
}
} else if (chunk.type === 'content_block_delta') {
yield AnthropicProvider.parseStreamingResponseChunk(
chunk,
messageId,
model,
)
} else if (chunk.type === 'message_delta') {
usage = {
prompt_tokens: usage.prompt_tokens,
completion_tokens:
usage.completion_tokens + chunk.usage.output_tokens,
total_tokens: usage.total_tokens + chunk.usage.output_tokens,
}
}
}
// After the stream is complete, yield the final usage
yield {
id: messageId,
choices: [],
object: 'chat.completion.chunk',
model: model,
usage: usage,
}
}
return streamResponse()
} catch (error) {
if (error instanceof Anthropic.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Anthropic API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
static parseRequestMessage(message: RequestMessage): MessageParam {
if (message.role !== 'user' && message.role !== 'assistant') {
throw new Error(`Anthropic does not support role: ${message.role}`)
}
if (message.role === 'user' && Array.isArray(message.content)) {
const content = message.content.map(
(part): TextBlockParam | ImageBlockParam => {
switch (part.type) {
case 'text':
return { type: 'text', text: part.text }
case 'image_url': {
const { mimeType, base64Data } = parseImageDataUrl(
part.image_url.url,
)
AnthropicProvider.validateImageType(mimeType)
return {
type: 'image',
source: {
data: base64Data,
media_type:
mimeType as ImageBlockParam['source']['media_type'],
type: 'base64',
},
}
}
}
},
)
return { role: 'user', content }
}
return {
role: message.role,
content: message.content as string,
}
}
static parseNonStreamingResponse(
response: Anthropic.Message,
): LLMResponseNonStreaming {
if (response.content[0].type === 'tool_use') {
throw new Error('Unsupported content type: tool_use')
}
return {
id: response.id,
choices: [
{
finish_reason: response.stop_reason,
message: {
content: response.content[0].text,
role: response.role,
},
},
],
model: response.model,
object: 'chat.completion',
usage: {
prompt_tokens: response.usage.input_tokens,
completion_tokens: response.usage.output_tokens,
total_tokens:
response.usage.input_tokens + response.usage.output_tokens,
},
}
}
static parseStreamingResponseChunk(
chunk: MessageStreamEvent,
messageId: string,
model: string,
): LLMResponseStreaming {
if (chunk.type !== 'content_block_delta') {
throw new Error('Unsupported chunk type')
}
if (chunk.delta.type === 'input_json_delta') {
throw new Error('Unsupported content type: input_json_delta')
}
return {
id: messageId,
choices: [
{
finish_reason: null,
delta: {
content: chunk.delta.text,
},
},
],
object: 'chat.completion.chunk',
model: model,
}
}
private static validateSystemMessages(
messages: RequestMessage[],
): string | undefined {
const systemMessages = messages.filter((m) => m.role === 'system')
if (systemMessages.length > 1) {
throw new Error(`Anthropic does not support more than one system message`)
}
const systemMessage =
systemMessages.length > 0 ? systemMessages[0].content : undefined
if (systemMessage && typeof systemMessage !== 'string') {
throw new Error(
`Anthropic only supports string content for system messages`,
)
}
return systemMessage
}
private static isMessageEmpty(message: RequestMessage) {
if (typeof message.content === 'string') {
return message.content.trim() === ''
}
return message.content.length === 0
}
private static validateImageType(mimeType: string) {
const SUPPORTED_IMAGE_TYPES = [
'image/jpeg',
'image/png',
'image/gif',
'image/webp',
]
if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
throw new Error(
`Anthropic does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
', ',
)}`,
)
}
}
}

src/core/llm/base.ts (new file, 23 lines)
View File

@@ -0,0 +1,23 @@
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
export type BaseLLMProvider = {
generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming>
streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>>
}

src/core/llm/exception.ts (new file, 34 lines)
View File

@@ -0,0 +1,34 @@
export class LLMAPIKeyNotSetException extends Error {
constructor(message: string) {
super(message)
this.name = 'LLMAPIKeyNotSetException'
}
}
export class LLMAPIKeyInvalidException extends Error {
constructor(message: string) {
super(message)
this.name = 'LLMAPIKeyInvalidException'
}
}
export class LLMBaseUrlNotSetException extends Error {
constructor(message: string) {
super(message)
this.name = 'LLMBaseUrlNotSetException'
}
}
export class LLMModelNotSetException extends Error {
constructor(message: string) {
super(message)
this.name = 'LLMModelNotSetException'
}
}
export class LLMRateLimitExceededException extends Error {
constructor(message: string) {
super(message)
this.name = 'LLMRateLimitExceededException'
}
}

src/core/llm/gemini.ts (new file, 299 lines)
View File

@@ -0,0 +1,299 @@
import {
Content,
EnhancedGenerateContentResponse,
GenerateContentResult,
GenerateContentStreamResult,
GoogleGenerativeAI,
} from '@google/generative-ai'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
/**
* Note on OpenAI Compatibility API:
* Gemini provides an OpenAI-compatible endpoint (https://ai.google.dev/gemini-api/docs/openai)
* which allows using the OpenAI SDK with Gemini models. However, there are currently CORS issues
* preventing its use in Obsidian. Consider switching to this endpoint in the future once these
* issues are resolved.
*/
export class GeminiProvider implements BaseLLMProvider {
private client: GoogleGenerativeAI
private apiKey: string
constructor(apiKey: string) {
this.apiKey = apiKey
this.client = new GoogleGenerativeAI(apiKey)
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
`Gemini API key is missing. Please set it in settings menu.`,
)
}
this.apiKey = model.apiKey
this.client = new GoogleGenerativeAI(model.apiKey)
}
const systemMessages = request.messages.filter((m) => m.role === 'system')
const systemInstruction: string | undefined =
systemMessages.length > 0
? systemMessages.map((m) => m.content).join('\n')
: undefined
try {
const model = this.client.getGenerativeModel({
model: request.model,
generationConfig: {
maxOutputTokens: request.max_tokens,
temperature: request.temperature,
topP: request.top_p,
presencePenalty: request.presence_penalty,
frequencyPenalty: request.frequency_penalty,
},
systemInstruction: systemInstruction,
})
const result = await model.generateContent(
{
systemInstruction: systemInstruction,
contents: request.messages
.map((message) => GeminiProvider.parseRequestMessage(message))
.filter((m): m is Content => m !== null),
},
{
signal: options?.signal,
},
)
const messageId = crypto.randomUUID() // Gemini does not return a message id
return GeminiProvider.parseNonStreamingResponse(
result,
request.model,
messageId,
)
} catch (error) {
const isInvalidApiKey =
error.message?.includes('API_KEY_INVALID') ||
error.message?.includes('API key not valid')
if (isInvalidApiKey) {
throw new LLMAPIKeyInvalidException(
`Gemini API key is invalid. Please update it in settings menu.`,
)
}
throw error
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
`Gemini API key is missing. Please set it in settings menu.`,
)
}
this.apiKey = model.apiKey
this.client = new GoogleGenerativeAI(model.apiKey)
}
const systemMessages = request.messages.filter((m) => m.role === 'system')
const systemInstruction: string | undefined =
systemMessages.length > 0
? systemMessages.map((m) => m.content).join('\n')
: undefined
try {
const model = this.client.getGenerativeModel({
model: request.model,
generationConfig: {
maxOutputTokens: request.max_tokens,
temperature: request.temperature,
topP: request.top_p,
presencePenalty: request.presence_penalty,
frequencyPenalty: request.frequency_penalty,
},
systemInstruction: systemInstruction,
})
const stream = await model.generateContentStream(
{
systemInstruction: systemInstruction,
contents: request.messages
.map((message) => GeminiProvider.parseRequestMessage(message))
.filter((m): m is Content => m !== null),
},
{
signal: options?.signal,
},
)
const messageId = crypto.randomUUID() // Gemini does not return a message id
return this.streamResponseGenerator(stream, request.model, messageId)
} catch (error) {
const isInvalidApiKey =
error.message?.includes('API_KEY_INVALID') ||
error.message?.includes('API key not valid')
if (isInvalidApiKey) {
throw new LLMAPIKeyInvalidException(
`Gemini API key is invalid. Please update it in settings menu.`,
)
}
throw error
}
}
private async *streamResponseGenerator(
stream: GenerateContentStreamResult,
model: string,
messageId: string,
): AsyncIterable<LLMResponseStreaming> {
for await (const chunk of stream.stream) {
yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
}
}
static parseRequestMessage(message: RequestMessage): Content | null {
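// Gemini uses 'user'/'model' roles; system messages are passed separately as
// systemInstruction, so they are dropped here.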
if (message.role === 'system') {
return null
}
if (Array.isArray(message.content)) {
return {
role: message.role === 'user' ? 'user' : 'model',
parts: message.content.map((part) => {
switch (part.type) {
case 'text':
return { text: part.text }
case 'image_url': {
const { mimeType, base64Data } = parseImageDataUrl(
part.image_url.url,
)
GeminiProvider.validateImageType(mimeType)
return {
inlineData: {
data: base64Data,
mimeType,
},
}
}
}
}),
}
}
return {
role: message.role === 'user' ? 'user' : 'model',
parts: [
{
text: message.content,
},
],
}
}
static parseNonStreamingResponse(
response: GenerateContentResult,
model: string,
messageId: string,
): LLMResponseNonStreaming {
return {
id: messageId,
choices: [
{
finish_reason:
response.response.candidates?.[0]?.finishReason ?? null,
message: {
content: response.response.text(),
role: 'assistant',
},
},
],
created: Date.now(),
model: model,
object: 'chat.completion',
usage: response.response.usageMetadata
? {
prompt_tokens: response.response.usageMetadata.promptTokenCount,
completion_tokens:
response.response.usageMetadata.candidatesTokenCount,
total_tokens: response.response.usageMetadata.totalTokenCount,
}
: undefined,
}
}
static parseStreamingResponseChunk(
chunk: EnhancedGenerateContentResponse,
model: string,
messageId: string,
): LLMResponseStreaming {
return {
id: messageId,
choices: [
{
finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
delta: {
content: chunk.text(),
},
},
],
created: Date.now(),
model: model,
object: 'chat.completion.chunk',
usage: chunk.usageMetadata
? {
prompt_tokens: chunk.usageMetadata.promptTokenCount,
completion_tokens: chunk.usageMetadata.candidatesTokenCount,
total_tokens: chunk.usageMetadata.totalTokenCount,
}
: undefined,
}
}
private static validateImageType(mimeType: string) {
const SUPPORTED_IMAGE_TYPES = [
'image/png',
'image/jpeg',
'image/webp',
'image/heic',
'image/heif',
]
if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
throw new Error(
`Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
', ',
)}`,
)
}
}
}

src/core/llm/groq.ts (new file, 200 lines)
View File

@@ -0,0 +1,200 @@
import Groq from 'groq-sdk'
import {
ChatCompletion,
ChatCompletionChunk,
ChatCompletionContentPart,
ChatCompletionMessageParam,
} from 'groq-sdk/resources/chat/completions'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
export class GroqProvider implements BaseLLMProvider {
private client: Groq
constructor(apiKey: string) {
this.client = new Groq({
apiKey,
dangerouslyAllowBrowser: true,
})
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
this.client = new Groq({
apiKey: model.apiKey,
dangerouslyAllowBrowser: true,
})
}
try {
const response = await this.client.chat.completions.create(
{
model: request.model,
messages: request.messages.map((m) =>
GroqProvider.parseRequestMessage(m),
),
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p,
},
{
signal: options?.signal,
},
)
return GroqProvider.parseNonStreamingResponse(response)
} catch (error) {
if (error instanceof Groq.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Groq API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
this.client = new Groq({
apiKey: model.apiKey,
dangerouslyAllowBrowser: true,
})
}
try {
const stream = await this.client.chat.completions.create(
{
model: request.model,
messages: request.messages.map((m) =>
GroqProvider.parseRequestMessage(m),
),
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p,
stream: true,
},
{
signal: options?.signal,
},
)
// eslint-disable-next-line no-inner-declarations
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
for await (const chunk of stream) {
yield GroqProvider.parseStreamingResponseChunk(chunk)
}
}
return streamResponse()
} catch (error) {
if (error instanceof Groq.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Groq API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
static parseRequestMessage(
message: RequestMessage,
): ChatCompletionMessageParam {
switch (message.role) {
case 'user': {
const content = Array.isArray(message.content)
? message.content.map((part): ChatCompletionContentPart => {
switch (part.type) {
case 'text':
return { type: 'text', text: part.text }
case 'image_url':
return { type: 'image_url', image_url: part.image_url }
}
})
: message.content
return { role: 'user', content }
}
case 'assistant': {
if (Array.isArray(message.content)) {
throw new Error('Assistant message should be a string')
}
return { role: 'assistant', content: message.content }
}
case 'system': {
if (Array.isArray(message.content)) {
throw new Error('System message should be a string')
}
return { role: 'system', content: message.content }
}
}
}
static parseNonStreamingResponse(
response: ChatCompletion,
): LLMResponseNonStreaming {
return {
id: response.id,
choices: response.choices.map((choice) => ({
finish_reason: choice.finish_reason,
message: {
content: choice.message.content,
role: choice.message.role,
},
})),
created: response.created,
model: response.model,
object: 'chat.completion',
usage: response.usage,
}
}
static parseStreamingResponseChunk(
chunk: ChatCompletionChunk,
): LLMResponseStreaming {
return {
id: chunk.id,
choices: chunk.choices.map((choice) => ({
finish_reason: choice.finish_reason ?? null,
delta: {
content: choice.delta.content ?? null,
role: choice.delta.role,
},
})),
created: chunk.created,
model: chunk.model,
object: 'chat.completion.chunk',
}
}
}

src/core/llm/infio.ts (new file, 252 lines)
View File

@@ -0,0 +1,252 @@
import OpenAI from 'openai'
import {
ChatCompletion,
ChatCompletionChunk,
} from 'openai/resources/chat/completions'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
export interface RangeFilter {
gte?: number;
lte?: number;
}
export interface ChunkFilter {
field: string;
match_all?: string[];
range?: RangeFilter;
}
/**
* Interface for making requests to the Infio API
*/
export interface InfioRequest {
/** Required: The content of the user message to attach to the topic and then generate an assistant message in response to */
messages: RequestMessage[];
// /** Required: The ID of the topic to attach the message to */
// topic_id: string;
/** Optional: URLs to include */
links?: string[];
/** Optional: Files to include */
files?: string[];
/** Optional: Whether to highlight results in chunk_html. Default is true */
highlight_results?: boolean;
/** Optional: Delimiters for highlighting citations. Default is [".", "!", "?", "\n", "\t", ","] */
highlight_delimiters?: string[];
/** Optional: Search type - "semantic", "fulltext", or "hybrid". Default is "hybrid" */
search_type?: string;
/** Optional: Filters for chunk filtering */
filters?: ChunkFilter;
/** Optional: Whether to use web search API. Default is false */
use_web_search?: boolean;
/** Optional: LLM model to use */
llm_model?: string;
/** Optional: Force source */
force_source?: string;
/** Optional: Whether completion should come before chunks in stream. Default is false */
completion_first?: boolean;
/** Optional: Whether to stream the response. Default is true */
stream_response?: boolean;
/** Optional: Sampling temperature between 0 and 2. Default is 0.5 */
temperature?: number;
/** Optional: Frequency penalty between -2.0 and 2.0. Default is 0.7 */
frequency_penalty?: number;
/** Optional: Presence penalty between -2.0 and 2.0. Default is 0.7 */
presence_penalty?: number;
/** Optional: Maximum tokens to generate */
max_tokens?: number;
/** Optional: Stop tokens (up to 4 sequences) */
stop_tokens?: string[];
}
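// Illustrative sketch of a minimal InfioRequest: only `messages` is required and every
// other field falls back to the documented default. All values below are placeholders.
//
//   const exampleRequest: InfioRequest = {
//     messages: [{ role: 'user', content: 'Summarize my latest note' }],
//     search_type: 'hybrid',
//     stream_response: true,
//     temperature: 0.5,
//     max_tokens: 1024,
//   }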
export class InfioProvider implements BaseLLMProvider {
// private adapter: OpenAIMessageAdapter
// private client: OpenAI
private apiKey: string
private baseUrl: string
constructor(apiKey: string) {
// this.client = new OpenAI({ apiKey, dangerouslyAllowBrowser: true })
// this.adapter = new OpenAIMessageAdapter()
this.apiKey = apiKey
this.baseUrl = 'https://api.infio.com/api/raw_message'
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.apiKey) {
throw new LLMAPIKeyNotSetException(
'Infio API key is missing. Please set it in settings menu.',
)
}
try {
const req: InfioRequest = {
messages: request.messages,
stream_response: false,
temperature: request.temperature,
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
max_tokens: request.max_tokens,
}
const fetchOptions: RequestInit = {
method: 'POST',
headers: {
Authorization: this.apiKey,
"TR-Dataset": "74aaec22-0cf0-4cba-80a5-ae5c0518344e",
'Content-Type': 'application/json'
},
body: JSON.stringify(req),
signal: options?.signal,
};
const response = await fetch(this.baseUrl, fetchOptions);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data = await response.json() as ChatCompletion;
return InfioProvider.parseNonStreamingResponse(data);
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Infio API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.apiKey) {
throw new LLMAPIKeyNotSetException(
'Infio API key is missing. Please set it in settings menu.',
)
}
try {
const req: InfioRequest = {
llm_model: request.model,
messages: request.messages,
stream_response: true,
temperature: request.temperature,
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
max_tokens: request.max_tokens,
}
const fetchOptions: RequestInit = {
method: 'POST',
headers: {
Authorization: this.apiKey,
"TR-Dataset": "74aaec22-0cf0-4cba-80a5-ae5c0518344e",
"Content-Type": "application/json"
},
body: JSON.stringify(req),
signal: options?.signal,
};
const response = await fetch(this.baseUrl, fetchOptions);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
if (!response.body) {
throw new Error('Response body is null');
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
return {
[Symbol.asyncIterator]: async function* () {
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n').filter(line => line.trim());
for (const line of lines) {
if (line.startsWith('data: ')) {
const jsonData = JSON.parse(line.slice(6)) as ChatCompletionChunk;
if (!jsonData || typeof jsonData !== 'object' || !('choices' in jsonData)) {
throw new Error('Invalid chunk format received');
}
yield InfioProvider.parseStreamingResponseChunk(jsonData);
}
}
}
} finally {
reader.releaseLock();
}
}
};
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Infio API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
static parseNonStreamingResponse(
response: ChatCompletion,
): LLMResponseNonStreaming {
return {
id: response.id,
choices: response.choices.map((choice) => ({
finish_reason: choice.finish_reason,
message: {
content: choice.message.content,
role: choice.message.role,
},
})),
created: response.created,
model: response.model,
object: 'chat.completion',
system_fingerprint: response.system_fingerprint,
usage: response.usage,
}
}
static parseStreamingResponseChunk(
chunk: ChatCompletionChunk,
): LLMResponseStreaming {
return {
id: chunk.id,
choices: chunk.choices.map((choice) => ({
finish_reason: choice.finish_reason ?? null,
delta: {
content: choice.delta.content ?? null,
role: choice.delta.role,
},
})),
created: chunk.created,
model: chunk.model,
object: 'chat.completion.chunk',
system_fingerprint: chunk.system_fingerprint,
usage: chunk.usage ?? undefined,
}
}
}
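// Usage sketch (illustrative, not part of this module): a non-streaming call through the
// Infio backend. `model` stands for the CustomLLMModel selected in settings; only the
// request fields this provider actually reads are shown, and all values are placeholders.
//
//   const infio = new InfioProvider('<INFIO_API_KEY>')
//   const response = await infio.generateResponse(model, {
//     model: 'gpt-4o-mini',
//     messages: [{ role: 'user', content: 'Hello' }],
//   })
//   console.log(response.choices[0]?.message.content)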

142
src/core/llm/manager.ts Normal file
View File

@@ -0,0 +1,142 @@
import { DEEPSEEK_BASE_URL } from '../../constants'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { AnthropicProvider } from './anthropic'
import { GeminiProvider } from './gemini'
import { GroqProvider } from './groq'
import { InfioProvider } from './infio'
import { OllamaProvider } from './ollama'
import { OpenAIAuthenticatedProvider } from './openai'
import { OpenAICompatibleProvider } from './openai-compatible-provider'
export type LLMManagerInterface = {
generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming>
streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>>
}
class LLMManager implements LLMManagerInterface {
private openaiProvider: OpenAIAuthenticatedProvider
private deepseekProvider: OpenAICompatibleProvider
private anthropicProvider: AnthropicProvider
private googleProvider: GeminiProvider
private groqProvider: GroqProvider
private infioProvider: InfioProvider
private ollamaProvider: OllamaProvider
private isInfioEnabled: boolean
constructor(apiKeys: {
deepseek?: string
openai?: string
anthropic?: string
gemini?: string
groq?: string
infio?: string
}) {
this.deepseekProvider = new OpenAICompatibleProvider(apiKeys.deepseek ?? '', DEEPSEEK_BASE_URL)
this.openaiProvider = new OpenAIAuthenticatedProvider(apiKeys.openai ?? '')
this.anthropicProvider = new AnthropicProvider(apiKeys.anthropic ?? '')
this.googleProvider = new GeminiProvider(apiKeys.gemini ?? '')
this.groqProvider = new GroqProvider(apiKeys.groq ?? '')
this.infioProvider = new InfioProvider(apiKeys.infio ?? '')
this.ollamaProvider = new OllamaProvider()
this.isInfioEnabled = !!apiKeys.infio
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (this.isInfioEnabled) {
return await this.infioProvider.generateResponse(
model,
request,
options,
)
}
// use custom provider
switch (model.provider) {
case 'deepseek':
return await this.deepseekProvider.generateResponse(
model,
request,
options,
)
case 'openai':
return await this.openaiProvider.generateResponse(
model,
request,
options,
)
case 'anthropic':
return await this.anthropicProvider.generateResponse(
model,
request,
options,
)
case 'google':
return await this.googleProvider.generateResponse(
model,
request,
options,
)
case 'groq':
return await this.groqProvider.generateResponse(model, request, options)
case 'ollama':
return await this.ollamaProvider.generateResponse(
model,
request,
options,
)
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (this.isInfioEnabled) {
return await this.infioProvider.streamResponse(model, request, options)
}
// use custom provider
switch (model.provider) {
case 'deepseek':
return await this.deepseekProvider.streamResponse(model, request, options)
case 'openai':
return await this.openaiProvider.streamResponse(model, request, options)
case 'anthropic':
return await this.anthropicProvider.streamResponse(
model,
request,
options,
)
case 'google':
return await this.googleProvider.streamResponse(model, request, options)
case 'groq':
return await this.groqProvider.streamResponse(model, request, options)
case 'ollama':
return await this.ollamaProvider.streamResponse(model, request, options)
}
}
}
export default LLMManager
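// Usage sketch (illustrative, not part of this module): when an Infio key is configured every
// call is routed to InfioProvider; otherwise the model's `provider` field selects the backend.
// Key values, `model` and `request` are placeholders.
//
//   const manager = new LLMManager({ openai: '<OPENAI_API_KEY>', groq: '<GROQ_API_KEY>' })
//   const stream = await manager.streamResponse(model, request)
//   for await (const chunk of stream) {
//     console.log(chunk.choices[0]?.delta.content ?? '')
//   }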

104
src/core/llm/ollama.ts Normal file
View File

@@ -0,0 +1,104 @@
/**
* This provider is nearly identical to OpenAICompatibleProvider, but uses a custom OpenAI client
* (NoStainlessOpenAI) to work around CORS issues specific to Ollama.
*/
import OpenAI from 'openai'
import { FinalRequestOptions } from 'openai/core'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException, LLMModelNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
export class NoStainlessOpenAI extends OpenAI {
defaultHeaders() {
return {
Accept: 'application/json',
'Content-Type': 'application/json',
}
}
buildRequest<Req>(
options: FinalRequestOptions<Req>,
{ retryCount = 0 }: { retryCount?: number } = {},
): { req: RequestInit; url: string; timeout: number } {
const req = super.buildRequest(options, { retryCount })
const headers = req.req.headers as Record<string, string>
Object.keys(headers).forEach((k) => {
if (k.startsWith('x-stainless')) {
// eslint-disable-next-line @typescript-eslint/no-dynamic-delete
delete headers[k]
}
})
return req
}
}
export class OllamaProvider implements BaseLLMProvider {
private adapter: OpenAIMessageAdapter
constructor() {
this.adapter = new OpenAIMessageAdapter()
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!model.baseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama base URL is missing. Please set it in settings menu.',
)
}
if (!model.name) {
throw new LLMModelNotSetException(
'Ollama model is missing. Please set it in settings menu.',
)
}
const client = new NoStainlessOpenAI({
baseURL: `${model.baseUrl}/v1`,
apiKey: '',
dangerouslyAllowBrowser: true,
})
return this.adapter.generateResponse(client, request, options)
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!model.baseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama base URL is missing. Please set it in settings menu.',
)
}
if (!model.name) {
throw new LLMModelNotSetException(
'Ollama model is missing. Please set it in settings menu.',
)
}
const client = new NoStainlessOpenAI({
baseURL: `${model.baseUrl}/v1`,
apiKey: '',
dangerouslyAllowBrowser: true,
})
return this.adapter.streamResponse(client, request, options)
}
}
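// Usage sketch (illustrative, not part of this module): the provider itself holds no state,
// so the base URL and model name must come from the CustomLLMModel. The cast and the field
// values below are placeholders for whatever the settings actually provide.
//
//   const ollama = new OllamaProvider()
//   const response = await ollama.generateResponse(
//     { provider: 'ollama', name: 'llama3.1', baseUrl: 'http://localhost:11434' } as CustomLLMModel,
//     { model: 'llama3.1', messages: [{ role: 'user', content: 'Hello' }] },
//   )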

62
src/core/llm/openai-compatible-provider.ts Normal file
View File

@@ -0,0 +1,62 @@
import OpenAI from 'openai'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
export class OpenAICompatibleProvider implements BaseLLMProvider {
private adapter: OpenAIMessageAdapter
private client: OpenAI
private apiKey: string
private baseURL: string
constructor(apiKey: string, baseURL: string) {
this.adapter = new OpenAIMessageAdapter()
this.client = new OpenAI({
apiKey: apiKey,
baseURL: baseURL,
dangerouslyAllowBrowser: true,
})
this.apiKey = apiKey
this.baseURL = baseURL
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.baseURL || !this.apiKey) {
throw new LLMBaseUrlNotSetException(
'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
)
}
return this.adapter.generateResponse(this.client, request, options)
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.baseURL || !this.apiKey) {
throw new LLMBaseUrlNotSetException(
'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
)
}
return this.adapter.streamResponse(this.client, request, options)
}
}
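// Usage sketch (illustrative, not part of this module): the same class serves any
// OpenAI-compatible endpoint; in this plugin LLMManager instantiates it with
// DEEPSEEK_BASE_URL. The key, URL, `model` and `request` below are placeholders.
//
//   const deepseek = new OpenAICompatibleProvider('<API_KEY>', 'https://api.example.com/v1')
//   const response = await deepseek.generateResponse(model, request)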

155
src/core/llm/openai-message-adapter.ts Normal file
View File

@@ -0,0 +1,155 @@
import OpenAI from 'openai'
import {
ChatCompletion,
ChatCompletionChunk,
ChatCompletionContentPart,
ChatCompletionMessageParam,
} from 'openai/resources/chat/completions'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
export class OpenAIMessageAdapter {
async generateResponse(
client: OpenAI,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
const response = await client.chat.completions.create(
{
model: request.model,
messages: request.messages.map((m) =>
OpenAIMessageAdapter.parseRequestMessage(m),
),
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p,
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
logit_bias: request.logit_bias,
prediction: request.prediction,
},
{
signal: options?.signal,
},
)
return OpenAIMessageAdapter.parseNonStreamingResponse(response)
}
async streamResponse(
client: OpenAI,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
const stream = await client.chat.completions.create(
{
model: request.model,
messages: request.messages.map((m) =>
OpenAIMessageAdapter.parseRequestMessage(m),
),
max_completion_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p,
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
logit_bias: request.logit_bias,
stream: true,
stream_options: {
include_usage: true,
},
},
{
signal: options?.signal,
},
)
// eslint-disable-next-line no-inner-declarations
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
for await (const chunk of stream) {
yield OpenAIMessageAdapter.parseStreamingResponseChunk(chunk)
}
}
return streamResponse()
}
static parseRequestMessage(
message: RequestMessage,
): ChatCompletionMessageParam {
switch (message.role) {
case 'user': {
const content = Array.isArray(message.content)
? message.content.map((part): ChatCompletionContentPart => {
switch (part.type) {
case 'text':
return { type: 'text', text: part.text }
case 'image_url':
return { type: 'image_url', image_url: part.image_url }
}
})
: message.content
return { role: 'user', content }
}
case 'assistant': {
if (Array.isArray(message.content)) {
throw new Error('Assistant message should be a string')
}
return { role: 'assistant', content: message.content }
}
case 'system': {
if (Array.isArray(message.content)) {
throw new Error('System message should be a string')
}
return { role: 'system', content: message.content }
}
}
}
static parseNonStreamingResponse(
response: ChatCompletion,
): LLMResponseNonStreaming {
return {
id: response.id,
choices: response.choices.map((choice) => ({
finish_reason: choice.finish_reason,
message: {
content: choice.message.content,
role: choice.message.role,
},
})),
created: response.created,
model: response.model,
object: 'chat.completion',
system_fingerprint: response.system_fingerprint,
usage: response.usage,
}
}
static parseStreamingResponseChunk(
chunk: ChatCompletionChunk,
): LLMResponseStreaming {
return {
id: chunk.id,
choices: chunk.choices.map((choice) => ({
finish_reason: choice.finish_reason ?? null,
delta: {
content: choice.delta.content ?? null,
role: choice.delta.role,
},
})),
created: chunk.created,
model: chunk.model,
object: 'chat.completion.chunk',
system_fingerprint: chunk.system_fingerprint,
usage: chunk.usage ?? undefined,
}
}
}
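// Usage sketch (illustrative, not part of this module): the adapter works against any
// configured OpenAI SDK client, which is how OpenAIAuthenticatedProvider,
// OpenAICompatibleProvider and OllamaProvider reuse it. `request` is a placeholder.
//
//   const adapter = new OpenAIMessageAdapter()
//   const client = new OpenAI({ apiKey: '<API_KEY>', dangerouslyAllowBrowser: true })
//   const response = await adapter.generateResponse(client, request)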

91
src/core/llm/openai.ts Normal file
View File

@@ -0,0 +1,91 @@
import OpenAI from 'openai'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
export class OpenAIAuthenticatedProvider implements BaseLLMProvider {
private adapter: OpenAIMessageAdapter
private client: OpenAI
constructor(apiKey: string) {
this.client = new OpenAI({
apiKey,
dangerouslyAllowBrowser: true,
})
this.adapter = new OpenAIMessageAdapter()
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
this.client = new OpenAI({
apiKey: model.apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
})
}
try {
return await this.adapter.generateResponse(this.client, request, options)
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
this.client = new OpenAI({
apiKey: model.apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
})
}
try {
return await this.adapter.streamResponse(this.client, request, options)
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'OpenAI API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
}
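// Usage sketch (illustrative, not part of this module): the key passed to the constructor takes
// precedence; when it is empty the per-model apiKey/baseUrl is used instead. `model` and
// `request` are placeholders.
//
//   const openaiProvider = new OpenAIAuthenticatedProvider('<OPENAI_API_KEY>')
//   const response = await openaiProvider.generateResponse(model, request)
//   console.log(response.choices[0]?.message.content)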

151
src/core/rag/embedding.ts Normal file
View File

@@ -0,0 +1,151 @@
import { GoogleGenerativeAI } from '@google/generative-ai'
import { OpenAI } from 'openai'
import { EmbeddingModel } from '../../types/embedding'
import {
LLMAPIKeyNotSetException,
LLMBaseUrlNotSetException,
LLMRateLimitExceededException,
} from '../llm/exception'
import { NoStainlessOpenAI } from '../llm/ollama'
export const getEmbeddingModel = (
embeddingModelId: string,
apiKeys: {
openAIApiKey: string
geminiApiKey: string
},
ollamaBaseUrl: string,
): EmbeddingModel => {
switch (embeddingModelId) {
case 'text-embedding-3-small': {
const openai = new OpenAI({
apiKey: apiKeys.openAIApiKey,
dangerouslyAllowBrowser: true,
})
return {
id: 'text-embedding-3-small',
dimension: 1536,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: 'text-embedding-3-small',
input: text,
})
return embedding.data[0].embedding
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'OpenAI API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case 'text-embedding-004': {
const client = new GoogleGenerativeAI(apiKeys.geminiApiKey)
const model = client.getGenerativeModel({ model: 'text-embedding-004' })
return {
id: 'text-embedding-004',
dimension: 768,
getEmbedding: async (text: string) => {
try {
const response = await model.embedContent(text)
return response.embedding.values
} catch (error) {
if (
error.status === 429 &&
error.message.includes('RATE_LIMIT_EXCEEDED')
) {
throw new LLMRateLimitExceededException(
'Gemini API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case 'nomic-embed-text': {
const openai = new NoStainlessOpenAI({
apiKey: '',
dangerouslyAllowBrowser: true,
baseURL: `${ollamaBaseUrl}/v1`,
})
return {
id: 'nomic-embed-text',
dimension: 768,
getEmbedding: async (text: string) => {
if (!ollamaBaseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama Address is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: 'nomic-embed-text',
input: text,
})
return embedding.data[0].embedding
},
}
}
case 'mxbai-embed-large': {
const openai = new NoStainlessOpenAI({
apiKey: '',
dangerouslyAllowBrowser: true,
baseURL: `${ollamaBaseUrl}/v1`,
})
return {
id: 'mxbai-embed-large',
dimension: 1024,
getEmbedding: async (text: string) => {
if (!ollamaBaseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama Address is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: 'mxbai-embed-large',
input: text,
})
return embedding.data[0].embedding
},
}
}
case 'bge-m3': {
const openai = new NoStainlessOpenAI({
apiKey: '',
dangerouslyAllowBrowser: true,
baseURL: `${ollamaBaseUrl}/v1`,
})
return {
id: 'bge-m3',
dimension: 1024,
getEmbedding: async (text: string) => {
if (!ollamaBaseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama Address is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: 'bge-m3',
input: text,
})
return embedding.data[0].embedding
},
}
}
default:
throw new Error('Invalid embedding model')
}
}
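// Usage sketch (illustrative, not part of this module): resolving an embedding model from
// settings and embedding a piece of text. Key values and the Ollama address are placeholders.
//
//   const embeddingModel = getEmbeddingModel(
//     'text-embedding-3-small',
//     { openAIApiKey: '<OPENAI_API_KEY>', geminiApiKey: '' },
//     'http://localhost:11434',
//   )
//   const vector = await embeddingModel.getEmbedding('Some note text')
//   console.log(vector.length) // 1536 for text-embedding-3-small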

124
src/core/rag/rag-engine.ts Normal file
View File

@@ -0,0 +1,124 @@
import { App } from 'obsidian'
import { QueryProgressState } from '../../components/chat-view/QueryProgress'
import { DBManager } from '../../database/database-manager'
import { VectorManager } from '../../database/modules/vector/vector-manager'
import { SelectVector } from '../../database/schema'
import { EmbeddingModel } from '../../types/embedding'
import { InfioSettings } from '../../types/settings'
import { getEmbeddingModel } from './embedding'
export class RAGEngine {
private app: App
private settings: InfioSettings
private vectorManager: VectorManager
private embeddingModel: EmbeddingModel | null = null
constructor(
app: App,
settings: InfioSettings,
dbManager: DBManager,
) {
this.app = app
this.settings = settings
this.vectorManager = dbManager.getVectorManager()
this.embeddingModel = getEmbeddingModel(
settings.embeddingModelId,
{
openAIApiKey: settings.openAIApiKey,
geminiApiKey: settings.geminiApiKey,
},
settings.ollamaEmbeddingModel.baseUrl,
)
}
setSettings(settings: InfioSettings) {
this.settings = settings
this.embeddingModel = getEmbeddingModel(
settings.embeddingModelId,
{
openAIApiKey: settings.openAIApiKey,
geminiApiKey: settings.geminiApiKey,
},
settings.ollamaEmbeddingModel.baseUrl,
)
}
// TODO: Implement automatic vault re-indexing when settings are changed.
// Currently, users must manually re-index the vault.
async updateVaultIndex(
options: { reindexAll: boolean } = {
reindexAll: false,
},
onQueryProgressChange?: (queryProgress: QueryProgressState) => void,
): Promise<void> {
if (!this.embeddingModel) {
throw new Error('Embedding model is not set')
}
await this.vectorManager.updateVaultIndex(
this.embeddingModel,
{
chunkSize: this.settings.ragOptions.chunkSize,
excludePatterns: this.settings.ragOptions.excludePatterns,
includePatterns: this.settings.ragOptions.includePatterns,
reindexAll: options.reindexAll,
},
(indexProgress) => {
onQueryProgressChange?.({
type: 'indexing',
indexProgress,
})
},
)
}
async processQuery({
query,
scope,
onQueryProgressChange,
}: {
query: string
scope?: {
files: string[]
folders: string[]
}
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
}): Promise<
(Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
> {
if (!this.embeddingModel) {
throw new Error('Embedding model is not set')
}
// TODO: Decide the vault index update strategy.
// Current approach: Update on every query.
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
const queryEmbedding = await this.getQueryEmbedding(query)
onQueryProgressChange?.({
type: 'querying',
})
const queryResult = await this.vectorManager.performSimilaritySearch(
queryEmbedding,
this.embeddingModel,
{
minSimilarity: this.settings.ragOptions.minSimilarity,
limit: this.settings.ragOptions.limit,
scope,
},
)
onQueryProgressChange?.({
type: 'querying-done',
queryResult,
})
return queryResult
}
private async getQueryEmbedding(query: string): Promise<number[]> {
if (!this.embeddingModel) {
throw new Error('Embedding model is not set')
}
return this.embeddingModel.getEmbedding(query)
}
}
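// Usage sketch (illustrative, not part of this module): answering a query against the vault
// index. `app`, `settings` and `dbManager` stand for the live Obsidian App, plugin settings
// and DBManager instances owned by the plugin; the query text is a placeholder.
//
//   const rag = new RAGEngine(app, settings, dbManager)
//   const results = await rag.processQuery({
//     query: 'meeting notes about the Q3 roadmap',
//     onQueryProgressChange: (progress) => console.log(progress.type),
//   })
//   console.log(`${results.length} chunks retrieved`)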