Mirror of https://github.com/EthanMarti/infio-copilot.git (synced 2026-05-09 16:38:19 +00:00)

Commit: init
95
src/core/autocomplete/context-detection.ts
Normal file
@@ -0,0 +1,95 @@
import { generateRandomString } from "./utils";

const UNIQUE_CURSOR = `${generateRandomString(16)}`;
const HEADER_REGEX = `^#+\\s.*${UNIQUE_CURSOR}.*$`;
const UNORDERED_LIST_REGEX = `^\\s*(-|\\*)\\s.*${UNIQUE_CURSOR}.*$`;
const TASK_LIST_REGEX = `^\\s*(-|[0-9]+\\.) +\\[.\\]\\s.*${UNIQUE_CURSOR}.*$`;
const BLOCK_QUOTES_REGEX = `^\\s*>.*${UNIQUE_CURSOR}.*$`;
const NUMBERED_LIST_REGEX = `^\\s*\\d+\\.\\s.*${UNIQUE_CURSOR}.*$`;
const MATH_BLOCK_REGEX = /\$\$[\s\S]*?\$\$/g;
const INLINE_MATH_BLOCK_REGEX = /\$[\s\S]*?\$/g;
const CODE_BLOCK_REGEX = /```[\s\S]*?```/g;
const INLINE_CODE_BLOCK_REGEX = /`.*`/g;

enum Context {
	Text = "Text",
	Heading = "Heading",
	BlockQuotes = "BlockQuotes",
	UnorderedList = "UnorderedList",
	NumberedList = "NumberedList",
	CodeBlock = "CodeBlock",
	MathBlock = "MathBlock",
	TaskList = "TaskList",
}

// eslint-disable-next-line @typescript-eslint/no-namespace
namespace Context {
	export function values(): Array<Context> {
		return Object.values(Context).filter(
			(value) => typeof value === "string"
		) as Array<Context>;
	}

	export function getContext(prefix: string, suffix: string): Context {
		if (new RegExp(HEADER_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.Heading;
		}
		if (new RegExp(BLOCK_QUOTES_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.BlockQuotes;
		}

		if (new RegExp(TASK_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.TaskList;
		}

		if (
			isCursorInRegexBlock(prefix, suffix, MATH_BLOCK_REGEX) ||
			isCursorInRegexBlock(prefix, suffix, INLINE_MATH_BLOCK_REGEX)
		) {
			return Context.MathBlock;
		}

		if (isCursorInRegexBlock(prefix, suffix, CODE_BLOCK_REGEX) || isCursorInRegexBlock(prefix, suffix, INLINE_CODE_BLOCK_REGEX)) {
			return Context.CodeBlock;
		}
		if (new RegExp(NUMBERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.NumberedList;
		}
		if (new RegExp(UNORDERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.UnorderedList;
		}

		return Context.Text;
	}

	export function get(value: string) {
		for (const context of Context.values()) {
			if (value === context) {
				return context;
			}
		}
		return undefined;
	}
}

function isCursorInRegexBlock(
	prefix: string,
	suffix: string,
	regex: RegExp
): boolean {
	const text = prefix + UNIQUE_CURSOR + suffix;
	const codeBlocks = extractBlocks(text, regex);
	for (const block of codeBlocks) {
		if (block.includes(UNIQUE_CURSOR)) {
			return true;
		}
	}
	return false;
}

function extractBlocks(text: string, regex: RegExp) {
	const codeBlocks = text.match(regex);
	return codeBlocks ? codeBlocks.map((block) => block.trim()) : [];
}

export default Context;
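For orientation, a minimal usage sketch (illustrative, not part of the commit): the classifier joins prefix and suffix around a random sentinel string standing in for the cursor, and the first matching pattern wins, so the more specific contexts (heading, quote, task list) are checked before the generic ones.

import Context from "./context-detection";

// Cursor inside "## My hea|ding" -> the heading pattern matches first.
console.log(Context.getContext("## My hea", "ding")); // "Heading"

// Cursor inside a fenced block is found by extracting the block and
// checking whether it contains the sentinel.
console.log(Context.getContext("```ts\nconst x = ", ";\n```")); // "CodeBlock"

// Nothing matches -> plain text.
console.log(Context.getContext("Just a sentence ", "here.")); // "Text"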
273
src/core/autocomplete/index.ts
Normal file
@@ -0,0 +1,273 @@
import * as Handlebars from "handlebars";
import { err, ok, Result } from "neverthrow";

import { FewShotExample } from "../../settings/versions";
import { CustomLLMModel } from "../../types/llm/model";
import { RequestMessage } from '../../types/llm/request';
import { InfioSettings } from "../../types/settings";
import LLMManager from '../llm/manager';

import Context from "./context-detection";
import RemoveCodeIndicators from "./post-processors/remove-code-indicators";
import RemoveMathIndicators from "./post-processors/remove-math-indicators";
import RemoveOverlap from "./post-processors/remove-overlap";
import RemoveWhitespace from "./post-processors/remove-whitespace";
import DataViewRemover from "./pre-processors/data-view-remover";
import LengthLimiter from "./pre-processors/length-limiter";
import {
	AutocompleteService,
	ChatMessage,
	PostProcessor,
	PreProcessor,
	UserMessageFormatter,
	UserMessageFormattingInputs
} from "./types";

class LLMClient {
	private llm: LLMManager;
	private model: CustomLLMModel;

	constructor(llm: LLMManager, model: CustomLLMModel) {
		this.llm = llm;
		this.model = model;
	}

	async queryChatModel(messages: RequestMessage[]): Promise<Result<string, Error>> {
		const data = await this.llm.generateResponse(this.model, {
			model: this.model.name,
			messages: messages,
			stream: false,
		})
		return ok(data.choices[0].message.content);
	}
}

class AutoComplete implements AutocompleteService {
	private readonly client: LLMClient;
	private readonly systemMessage: string;
	private readonly userMessageFormatter: UserMessageFormatter;
	private readonly removePreAnswerGenerationRegex: string;
	private readonly preProcessors: PreProcessor[];
	private readonly postProcessors: PostProcessor[];
	private readonly fewShotExamples: FewShotExample[];
	private debugMode: boolean;

	private constructor(
		client: LLMClient,
		systemMessage: string,
		userMessageFormatter: UserMessageFormatter,
		removePreAnswerGenerationRegex: string,
		preProcessors: PreProcessor[],
		postProcessors: PostProcessor[],
		fewShotExamples: FewShotExample[],
		debugMode: boolean,
	) {
		this.client = client;
		this.systemMessage = systemMessage;
		this.userMessageFormatter = userMessageFormatter;
		this.removePreAnswerGenerationRegex = removePreAnswerGenerationRegex;
		this.preProcessors = preProcessors;
		this.postProcessors = postProcessors;
		this.fewShotExamples = fewShotExamples;
		this.debugMode = debugMode;
	}

	public static fromSettings(settings: InfioSettings): AutocompleteService {
		const formatter = Handlebars.compile<UserMessageFormattingInputs>(
			settings.userMessageTemplate,
			{ noEscape: true, strict: true }
		);
		const preProcessors: PreProcessor[] = [];
		if (settings.dontIncludeDataviews) {
			preProcessors.push(new DataViewRemover());
		}
		preProcessors.push(
			new LengthLimiter(
				settings.maxPrefixCharLimit,
				settings.maxSuffixCharLimit
			)
		);

		const postProcessors: PostProcessor[] = [];
		if (settings.removeDuplicateMathBlockIndicator) {
			postProcessors.push(new RemoveMathIndicators());
		}
		if (settings.removeDuplicateCodeBlockIndicator) {
			postProcessors.push(new RemoveCodeIndicators());
		}

		postProcessors.push(new RemoveOverlap());
		postProcessors.push(new RemoveWhitespace());

		const llm_manager = new LLMManager({
			deepseek: settings.deepseekApiKey,
			openai: settings.openAIApiKey,
			anthropic: settings.anthropicApiKey,
			gemini: settings.geminiApiKey,
			groq: settings.groqApiKey,
			infio: settings.infioApiKey,
		})
		const model: CustomLLMModel = settings.activeModels.find(
			(option) => option.name === settings.chatModelId,
		)
		const llm = new LLMClient(llm_manager, model);

		return new AutoComplete(
			llm,
			settings.systemMessage,
			formatter,
			settings.chainOfThoughRemovalRegex,
			preProcessors,
			postProcessors,
			settings.fewShotExamples,
			settings.debugMode,
		);
	}

	async fetchPredictions(
		prefix: string,
		suffix: string
	): Promise<Result<string, Error>> {
		const context: Context = Context.getContext(prefix, suffix);

		for (const preProcessor of this.preProcessors) {
			if (preProcessor.removesCursor(prefix, suffix)) {
				return ok("");
			}

			({ prefix, suffix } = preProcessor.process(
				prefix,
				suffix,
				context
			));
		}

		const examples = this.fewShotExamples.filter(
			(example) => example.context === context
		);
		const fewShotExamplesChatMessages =
			fewShotExamplesToChatMessages(examples);

		const messages: RequestMessage[] = [
			{
				content: this.getSystemMessageFor(context),
				role: "system"
			},
			...fewShotExamplesChatMessages,
			{
				role: "user",
				content: this.userMessageFormatter({
					suffix,
					prefix,
				}),
			},
		];

		if (this.debugMode) {
			console.log("Copilot messages sent:\n", messages);
		}

		let result = await this.client.queryChatModel(messages);
		if (this.debugMode && result.isOk()) {
			console.log("Copilot response:\n", result.value);
		}

		result = this.extractAnswerFromChainOfThoughts(result);

		for (const postProcessor of this.postProcessors) {
			result = result.map((r) => postProcessor.process(prefix, suffix, r, context));
		}

		result = this.checkAgainstGuardRails(result);

		return result;
	}

	private getSystemMessageFor(context: Context): string {
		if (context === Context.Text) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a paragraph. Your answer must complete this paragraph or sentence in a way that fits the surrounding text without overlapping with it. It must be in the same language as the paragraph.";
		}
		if (context === Context.Heading) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a Markdown heading. Your answer must complete this title in a way that fits the content of the paragraph and be in the same language as the paragraph.";
		}

		if (context === Context.BlockQuotes) {
			return this.systemMessage + "\n\n" + "The <mask/> is located within a quote. Your answer must complete this quote in a way that fits the context of the paragraph.";
		}
		if (context === Context.UnorderedList) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in an unordered list. Your answer must include one or more list items that fit with the surrounding list without overlapping with it.";
		}

		if (context === Context.NumberedList) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a numbered list. Your answer must include one or more list items that fit the sequence and context of the surrounding list without overlapping with it.";
		}

		if (context === Context.CodeBlock) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a code block. Your answer must complete this code block in the same programming language and support the surrounding code and text outside of the code block.";
		}
		if (context === Context.MathBlock) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a math block. Your answer must only contain LaTeX code that captures the math discussed in the surrounding text. No text or explanation, only LaTeX math code.";
		}
		if (context === Context.TaskList) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a task list. Your answer must include one or more (sub)tasks that are logical given the other tasks and the surrounding text.";
		}

		return this.systemMessage;
	}

	private extractAnswerFromChainOfThoughts(
		result: Result<string, Error>
	): Result<string, Error> {
		if (result.isErr()) {
			return result;
		}
		const chainOfThoughts = result.value;

		const regex = new RegExp(this.removePreAnswerGenerationRegex, "gm");
		const match = regex.exec(chainOfThoughts);
		if (match === null) {
			return err(new Error("No match found"));
		}
		return ok(chainOfThoughts.replace(regex, ""));
	}

	private checkAgainstGuardRails(
		result: Result<string, Error>
	): Result<string, Error> {
		if (result.isErr()) {
			return result;
		}
		if (result.value.length === 0) {
			return err(new Error("Empty result"));
		}
		if (result.value.includes("<mask/>")) {
			return err(new Error("Mask in result"));
		}

		return result;
	}
}

function fewShotExamplesToChatMessages(
	examples: FewShotExample[]
): ChatMessage[] {
	return examples
		.map((example): ChatMessage[] => {
			return [
				{
					role: "user",
					content: example.input,
				},
				{
					role: "assistant",
					content: example.answer,
				},
			];
		})
		.flat();
}

export default AutoComplete;
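A usage sketch for the service above, assuming `settings` is the plugin's already-loaded `InfioSettings` (illustrative, not part of the commit):

import AutoComplete from "./index";

const autocomplete = AutoComplete.fromSettings(settings);

// prefix = text before the cursor, suffix = text after it.
const result = await autocomplete.fetchPredictions(
	"The capital of France is ",
	". It lies on the Seine."
);

// neverthrow's Result forces explicit handling of both outcomes.
result.match(
	(completion) => console.log("suggestion:", completion),
	(error) => console.error("prediction failed:", error)
);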
22
src/core/autocomplete/post-processors/remove-code-indicators.ts
Normal file
@@ -0,0 +1,22 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";

class RemoveCodeIndicators implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		if (context === Context.CodeBlock) {
			completion = completion.replace(/```[a-zA-Z]+[ \t]*\n?/g, "");
			completion = completion.replace(/\n?```[ \t]*\n?/g, "");
			completion = completion.replace(/`/g, "");
		}

		return completion;
	}
}

export default RemoveCodeIndicators;
20
src/core/autocomplete/post-processors/remove-math-indicators.ts
Normal file
@@ -0,0 +1,20 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";

class RemoveMathIndicators implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		if (context === Context.MathBlock) {
			completion = completion.replace(/\n?\$\$\n?/g, "");
			completion = completion.replace(/\$/g, "");
		}

		return completion;
	}
}

export default RemoveMathIndicators;
96
src/core/autocomplete/post-processors/remove-overlap.ts
Normal file
@@ -0,0 +1,96 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";

class RemoveOverlap implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		completion = removeWordOverlapPrefix(prefix, completion);
		completion = removeWordOverlapSuffix(completion, suffix);
		completion = removeWhiteSpaceOverlapPrefix(prefix, completion);
		completion = removeWhiteSpaceOverlapSuffix(completion, suffix);

		return completion;
	}
}

function removeWhiteSpaceOverlapPrefix(prefix: string, completion: string): string {
	let prefixIdx = prefix.length - 1;
	while (completion.length > 0 && completion[0] === prefix[prefixIdx]) {
		completion = completion.slice(1);
		prefixIdx--;
	}
	return completion;
}

function removeWhiteSpaceOverlapSuffix(completion: string, suffix: string): string {
	let suffixIdx = 0;
	while (completion.length > 0 && completion[completion.length - 1] === suffix[suffixIdx]) {
		completion = completion.slice(0, -1);
		suffixIdx++;
	}
	return completion;
}

function removeWordOverlapPrefix(prefix: string, completion: string): string {
	const rightTrimmed = completion.trimStart();

	const startIdxOfEachWord = startLocationOfEachWord(prefix);

	while (startIdxOfEachWord.length > 0) {
		const idx = startIdxOfEachWord.pop();
		const leftSubstring = prefix.slice(idx);
		if (rightTrimmed.startsWith(leftSubstring)) {
			return rightTrimmed.replace(leftSubstring, "");
		}
	}

	return completion;
}

function removeWordOverlapSuffix(completion: string, suffix: string): string {
	const suffixTrimmed = removeLeadingWhiteSpace(suffix);

	const startIdxOfEachWord = startLocationOfEachWord(completion);

	while (startIdxOfEachWord.length > 0) {
		const idx = startIdxOfEachWord.pop();
		const suffixSubstring = completion.slice(idx);
		if (suffixTrimmed.startsWith(suffixSubstring)) {
			return completion.replace(suffixSubstring, "");
		}
	}

	return completion;
}

function removeLeadingWhiteSpace(completion: string): string {
	return completion.replace(/^[ \t\f\r\v]+/, "");
}

function startLocationOfEachWord(text: string): number[] {
	const locations: number[] = [];
	if (text.length > 0 && !isWhiteSpaceChar(text[0])) {
		locations.push(0);
	}

	for (let i = 1; i < text.length; i++) {
		if (isWhiteSpaceChar(text[i - 1]) && !isWhiteSpaceChar(text[i])) {
			locations.push(i);
		}
	}

	return locations;
}

function isWhiteSpaceChar(char: string | undefined): boolean {
	return char !== undefined && char.match(/\s/) !== null;
}

export default RemoveOverlap;
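Illustrative behaviour of the overlap removal (a sketch, not part of the commit): models often re-emit the word being typed, and the word-level pass strips that shared prefix so only new text is inserted.

import Context from "../context-detection";
import RemoveOverlap from "./remove-overlap";

const postProcessor = new RemoveOverlap();
const out = postProcessor.process(
	"The quick brown", // prefix ends with "brown"
	" jumps over",     // suffix
	"brown fox",       // completion re-emits "brown"
	Context.Text
);
console.log(JSON.stringify(out)); // " fox"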
24
src/core/autocomplete/post-processors/remove-whitespace.ts
Normal file
@@ -0,0 +1,24 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";

class RemoveWhitespace implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		if (context === Context.Text || context === Context.Heading || context === Context.MathBlock || context === Context.TaskList || context === Context.NumberedList || context === Context.UnorderedList) {
			if (prefix.endsWith(" ") || suffix.endsWith("\n")) {
				completion = completion.trimStart();
			}
			if (suffix.startsWith(" ")) {
				completion = completion.trimEnd();
			}
		}

		return completion;
	}
}

export default RemoveWhitespace;
34
src/core/autocomplete/pre-processors/data-view-remover.ts
Normal file
@@ -0,0 +1,34 @@
import { generateRandomString } from "../utils";
import Context from "../context-detection";
import { PrefixAndSuffix, PreProcessor } from "../types";

const DATA_VIEW_REGEX = /```dataview(js){0,1}(.|\n)*?```/gm;
const UNIQUE_CURSOR = `${generateRandomString(16)}`;

class DataViewRemover implements PreProcessor {
	process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
		let text = prefix + UNIQUE_CURSOR + suffix;
		text = text.replace(DATA_VIEW_REGEX, "");
		const [prefixNew, suffixNew] = text.split(UNIQUE_CURSOR);

		return { prefix: prefixNew, suffix: suffixNew };
	}

	removesCursor(prefix: string, suffix: string): boolean {
		const text = prefix + UNIQUE_CURSOR + suffix;
		const dataviewAreasWithCursor = text
			.match(DATA_VIEW_REGEX)
			?.filter((dataviewArea) => dataviewArea.includes(UNIQUE_CURSOR));

		if (
			dataviewAreasWithCursor !== undefined &&
			dataviewAreasWithCursor.length > 0
		) {
			return true;
		}

		return false;
	}
}

export default DataViewRemover;
24
src/core/autocomplete/pre-processors/length-limiter.ts
Normal file
@@ -0,0 +1,24 @@
import Context from "../context-detection";
import { PrefixAndSuffix, PreProcessor } from "../types";

class LengthLimiter implements PreProcessor {
	private readonly maxPrefixChars: number;
	private readonly maxSuffixChars: number;

	constructor(maxPrefixChars: number, maxSuffixChars: number) {
		this.maxPrefixChars = maxPrefixChars;
		this.maxSuffixChars = maxSuffixChars;
	}

	process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
		prefix = prefix.slice(-this.maxPrefixChars);
		suffix = suffix.slice(0, this.maxSuffixChars);
		return { prefix, suffix };
	}

	removesCursor(prefix: string, suffix: string): boolean {
		return false;
	}
}

export default LengthLimiter;
35
src/core/autocomplete/states/disabled-file-specific-state.ts
Normal file
@@ -0,0 +1,35 @@
import { TFile } from "obsidian";

import { InfioSettings } from "../../../types/settings";

import State from "./state";

class DisabledFileSpecificState extends State {
	getStatusBarText(): string {
		return "Disabled for this file";
	}

	handleSettingChanged(settings: InfioSettings) {
		if (!this.context.settings.autocompleteEnabled) {
			this.context.transitionToDisabledManualState();
		}
		if (!this.context.isCurrentFilePathIgnored() && !this.context.currentFileContainsIgnoredTag()) {
			this.context.transitionToIdleState();
		}
	}

	handleFileChange(file: TFile): void {
		if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
			return;
		}

		if (this.context.settings.autocompleteEnabled) {
			this.context.transitionToIdleState();
		} else {
			this.context.transitionToDisabledManualState();
		}
	}
}

export default DisabledFileSpecificState;
25
src/core/autocomplete/states/disabled-invalid-settings-state.ts
Normal file
@@ -0,0 +1,25 @@
import { InfioSettings } from "../../../types/settings";
import { checkForErrors } from "../../../utils/auto-complete";

import State from "./state";

class DisabledInvalidSettingsState extends State {
	getStatusBarText(): string {
		return "Disabled: invalid settings";
	}

	handleSettingChanged(settings: InfioSettings) {
		const settingErrors = checkForErrors(settings);
		if (settingErrors.size > 0) {
			return;
		}
		if (this.context.settings.autocompleteEnabled) {
			this.context.transitionToIdleState();
		} else {
			this.context.transitionToDisabledManualState();
		}
	}
}

export default DisabledInvalidSettingsState;
21
src/core/autocomplete/states/disabled-manual-state.ts
Normal file
@@ -0,0 +1,21 @@
import { TFile } from "obsidian";

import { InfioSettings } from "../../../types/settings";

import State from "./state";

class DisabledManualState extends State {
	getStatusBarText(): string {
		return "Disabled";
	}

	handleSettingChanged(settings: InfioSettings): void {
		if (this.context.settings.autocompleteEnabled) {
			this.context.transitionToIdleState();
		}
	}

	handleFileChange(file: TFile): void { }
}

export default DisabledManualState;
46
src/core/autocomplete/states/idle-state.ts
Normal file
@@ -0,0 +1,46 @@
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";

import State from "./state";

class IdleState extends State {
	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
		if (
			!documentChanges.isDocInFocus()
			|| !documentChanges.hasDocChanged()
			|| documentChanges.hasUserDeleted()
			|| documentChanges.hasMultipleCursors()
			|| documentChanges.hasSelection()
			|| documentChanges.hasUserUndone()
			|| documentChanges.hasUserRedone()
		) {
			return;
		}

		const cachedSuggestion = this.context.getCachedSuggestionFor(documentChanges.getPrefix(), documentChanges.getSuffix());
		const isThereCachedSuggestion = cachedSuggestion !== undefined && cachedSuggestion.trim().length > 0;

		if (this.context.settings.cacheSuggestions && isThereCachedSuggestion) {
			this.context.transitionToSuggestingState(cachedSuggestion, documentChanges.getPrefix(), documentChanges.getSuffix());
			return;
		}

		if (this.context.containsTriggerCharacters(documentChanges)) {
			this.context.transitionToQueuedState(documentChanges.getPrefix(), documentChanges.getSuffix());
		}
	}

	handlePredictCommand(prefix: string, suffix: string): void {
		this.context.transitionToPredictingState(prefix, suffix);
	}

	getStatusBarText(): string {
		return "Idle";
	}
}

export default IdleState;
38
src/core/autocomplete/states/init-state.ts
Normal file
@@ -0,0 +1,38 @@
import { TFile } from "obsidian";

import { InfioSettings } from "../../../types/settings";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";

import { EventHandler } from "./types";

class InitState implements EventHandler {
	async handleDocumentChange(documentChanges: DocumentChanges): Promise<void> { }

	handleSettingChanged(settings: InfioSettings): void { }

	handleAcceptKeyPressed(): boolean {
		return false;
	}

	handlePartialAcceptKeyPressed(): boolean {
		return false;
	}

	handleCancelKeyPressed(): boolean {
		return false;
	}

	handlePredictCommand(): void { }

	handleAcceptCommand(): void { }

	getStatusBarText(): string {
		return "Initializing...";
	}

	handleFileChange(file: TFile): void { }
}

export default InitState;
94
src/core/autocomplete/states/predicting-state.ts
Normal file
@@ -0,0 +1,94 @@
import { Notice } from "obsidian";

import Context from "../context-detection";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";

import State from "./state";

class PredictingState extends State {
	private predictionPromise: Promise<void> | null = null;
	private isStillNeeded = true;
	private readonly prefix: string;
	private readonly suffix: string;

	constructor(context: EventListener, prefix: string, suffix: string) {
		super(context);
		this.prefix = prefix;
		this.suffix = suffix;
	}

	static createAndStartPredicting(
		context: EventListener,
		prefix: string,
		suffix: string
	): PredictingState {
		const predictingState = new PredictingState(context, prefix, suffix);
		predictingState.startPredicting();
		context.setContext(Context.getContext(prefix, suffix));
		return predictingState;
	}

	handleCancelKeyPressed(): boolean {
		this.cancelPrediction();
		return true;
	}

	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
		if (
			documentChanges.hasCursorMoved() ||
			documentChanges.hasUserTyped() ||
			documentChanges.hasUserDeleted() ||
			documentChanges.isTextAdded()
		) {
			this.cancelPrediction();
		}
	}

	private cancelPrediction(): void {
		this.isStillNeeded = false;
		this.context.transitionToIdleState();
	}

	startPredicting(): void {
		this.predictionPromise = this.predict();
	}

	private async predict(): Promise<void> {
		const result =
			await this.context.autocomplete?.fetchPredictions(
				this.prefix,
				this.suffix
			);

		if (!this.isStillNeeded) {
			return;
		}

		if (result.isErr()) {
			new Notice(
				`Copilot: Something went wrong; cannot make a prediction. The full error is available in the dev console. Please check your settings.`
			);
			console.error(result.error);
			this.context.transitionToIdleState();
			return;
		}

		const prediction = result.unwrapOr("");

		if (prediction === "") {
			this.context.transitionToIdleState();
			return;
		}
		this.context.transitionToSuggestingState(prediction, this.prefix, this.suffix);
	}

	getStatusBarText(): string {
		return `Predicting for ${this.context.context}`;
	}
}

export default PredictingState;
84
src/core/autocomplete/states/queued-state.ts
Normal file
@@ -0,0 +1,84 @@
import Context from "../context-detection";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";

import State from "./state";

class QueuedState extends State {
	private timer: ReturnType<typeof setTimeout> | null = null;
	private readonly prefix: string;
	private readonly suffix: string;

	private constructor(
		context: EventListener,
		prefix: string,
		suffix: string
	) {
		super(context);
		this.prefix = prefix;
		this.suffix = suffix;
	}

	static createAndStartTimer(
		context: EventListener,
		prefix: string,
		suffix: string
	): QueuedState {
		const state = new QueuedState(context, prefix, suffix);
		state.startTimer();
		context.setContext(Context.getContext(prefix, suffix));
		return state;
	}

	handleCancelKeyPressed(): boolean {
		this.cancelTimer();
		this.context.transitionToIdleState();
		return true;
	}

	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
		if (
			documentChanges.isDocInFocus() &&
			documentChanges.isTextAdded() &&
			this.context.containsTriggerCharacters(documentChanges)
		) {
			this.cancelTimer();
			this.context.transitionToQueuedState(documentChanges.getPrefix(), documentChanges.getSuffix());
			return;
		}
		if (
			documentChanges.hasCursorMoved() ||
			documentChanges.hasUserTyped() ||
			documentChanges.hasUserDeleted() ||
			documentChanges.isTextAdded() ||
			!documentChanges.isDocInFocus()
		) {
			this.cancelTimer();
			this.context.transitionToIdleState();
		}
	}

	startTimer(): void {
		this.cancelTimer();
		this.timer = setTimeout(() => {
			this.context.transitionToPredictingState(this.prefix, this.suffix);
		}, this.context.settings.delay);
	}

	private cancelTimer(): void {
		if (this.timer !== null) {
			clearTimeout(this.timer);
			this.timer = null;
		}
	}

	getStatusBarText(): string {
		return `Queued (${this.context.settings.delay} ms)`;
	}
}

export default QueuedState;
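QueuedState implements a debounce: every qualifying keystroke restarts the timer, and only an uninterrupted pause of `settings.delay` milliseconds reaches the predicting state. The same pattern in isolation (a sketch, not part of the commit):

function debounce<T extends unknown[]>(fn: (...args: T) => void, delayMs: number) {
	let timer: ReturnType<typeof setTimeout> | null = null;
	return (...args: T) => {
		if (timer !== null) clearTimeout(timer); // restart on every call
		timer = setTimeout(() => fn(...args), delayMs);
	};
}

// Analogous to the state above: typing restarts the countdown,
// and only a pause actually requests a prediction.
const requestPrediction = debounce((prefix: string, suffix: string) => {
	console.log("predicting for", prefix, suffix);
}, 500);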
67
src/core/autocomplete/states/state.ts
Normal file
@@ -0,0 +1,67 @@
import { Notice, TFile } from "obsidian";

import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
// import { Settings } from "../settings/versions";
import { InfioSettings } from "../../../types/settings";
import { checkForErrors } from "../../../utils/auto-complete";

import { EventHandler } from "./types";

abstract class State implements EventHandler {
	protected readonly context: EventListener;

	constructor(context: EventListener) {
		this.context = context;
	}

	handleSettingChanged(settings: InfioSettings): void {
		const settingErrors = checkForErrors(settings);
		if (!settings.autocompleteEnabled) {
			new Notice("Copilot is now disabled.");
			this.context.transitionToDisabledManualState();
		} else if (settingErrors.size > 0) {
			new Notice(
				`Copilot: There are ${settingErrors.size} errors in your settings. The plugin will be disabled until they are fixed.`
			);
			this.context.transitionToDisabledInvalidSettingsState();
		} else if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
			this.context.transitionToDisabledFileSpecificState();
		}
	}

	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
	}

	handleAcceptKeyPressed(): boolean {
		return false;
	}

	handlePartialAcceptKeyPressed(): boolean {
		return false;
	}

	handleCancelKeyPressed(): boolean {
		return false;
	}

	handlePredictCommand(prefix: string, suffix: string): void {
	}

	handleAcceptCommand(): void {
	}

	abstract getStatusBarText(): string;

	handleFileChange(file: TFile): void {
		if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
			this.context.transitionToDisabledFileSpecificState();
		} else if (this.context.isDisabled()) {
			this.context.transitionToIdleState();
		}
	}
}

export default State;
177
src/core/autocomplete/states/suggesting-state.ts
Normal file
@@ -0,0 +1,177 @@
import { Settings } from "../../../settings/versions";
import { extractNextWordAndRemaining } from "../utils";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";

import State from "./state";

class SuggestingState extends State {
	private readonly suggestion: string;
	private readonly prefix: string;
	private readonly suffix: string;

	constructor(context: EventListener, suggestion: string, prefix: string, suffix: string) {
		super(context);
		this.suggestion = suggestion;
		this.prefix = prefix;
		this.suffix = suffix;
	}

	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
		if (
			documentChanges.hasCursorMoved()
			|| documentChanges.hasUserUndone()
			|| documentChanges.hasUserDeleted()
			|| documentChanges.hasUserRedone()
			|| !documentChanges.isDocInFocus()
			|| documentChanges.hasSelection()
			|| documentChanges.hasMultipleCursors()
		) {
			this.clearPrediction();
			return;
		}

		if (
			documentChanges.noUserEvents()
			|| !documentChanges.hasDocChanged()
		) {
			return;
		}

		if (this.hasUserAddedPartOfSuggestion(documentChanges)) {
			this.acceptPartialAddedText(documentChanges);
			return;
		}

		const currentPrefix = documentChanges.getPrefix();
		const currentSuffix = documentChanges.getSuffix();
		const suggestion = this.context.getCachedSuggestionFor(currentPrefix, currentSuffix);
		const isThereCachedSuggestion = suggestion !== undefined;
		const isCachedSuggestionDifferent = suggestion !== this.suggestion;

		if (!isCachedSuggestionDifferent) {
			return;
		}

		if (isThereCachedSuggestion) {
			this.context.transitionToSuggestingState(suggestion, currentPrefix, currentSuffix);
			return;
		}
		this.clearPrediction();
	}

	hasUserAddedPartOfSuggestion(documentChanges: DocumentChanges): boolean {
		const addedPrefixText = documentChanges.getAddedPrefixText();
		const addedSuffixText = documentChanges.getAddedSuffixText();

		return addedPrefixText !== undefined
			&& addedSuffixText !== undefined
			&& this.suggestion.toLowerCase().startsWith(addedPrefixText.toLowerCase())
			&& this.suggestion.toLowerCase().endsWith(addedSuffixText.toLowerCase());
	}

	acceptPartialAddedText(documentChanges: DocumentChanges): void {
		const addedPrefixText = documentChanges.getAddedPrefixText();
		const addedSuffixText = documentChanges.getAddedSuffixText();
		if (addedSuffixText === undefined || addedPrefixText === undefined) {
			return;
		}

		const startIdx = addedPrefixText.length;
		const endIdx = this.suggestion.length - addedSuffixText.length;
		const remainingSuggestion = this.suggestion.substring(startIdx, endIdx);

		if (remainingSuggestion.trim() === "") {
			this.clearPrediction();
		} else {
			this.context.transitionToSuggestingState(remainingSuggestion, documentChanges.getPrefix(), documentChanges.getSuffix());
		}
	}

	private clearPrediction(): void {
		this.context.transitionToIdleState();
	}

	handleAcceptKeyPressed(): boolean {
		this.accept();
		return true;
	}

	private accept() {
		this.addPartialSuggestionCaches(this.suggestion);
		this.context.insertCurrentSuggestion(this.suggestion);
		this.context.transitionToIdleState();
	}

	handlePartialAcceptKeyPressed(): boolean {
		this.acceptNextWord();
		return true;
	}

	private acceptNextWord() {
		const [nextWord, remaining] = extractNextWordAndRemaining(this.suggestion);

		if (nextWord !== undefined && remaining !== undefined) {
			const updatedPrefix = this.prefix + nextWord;

			this.addPartialSuggestionCaches(nextWord, remaining);
			this.context.insertCurrentSuggestion(nextWord);
			this.context.transitionToSuggestingState(remaining, updatedPrefix, this.suffix, false);
		} else {
			this.accept();
		}
	}

	private addPartialSuggestionCaches(acceptSuggestion: string, remainingSuggestion = "") {
		// store the sub-suggestions in the cache
		// so that we can have partial suggestions if the user edits a part
		for (let i = 0; i < acceptSuggestion.length; i++) {
			const prefix = this.prefix + acceptSuggestion.substring(0, i);
			const suggestion = acceptSuggestion.substring(i) + remainingSuggestion;
			this.context.addSuggestionToCache(prefix, this.suffix, suggestion);
		}
	}

	private getNextWordAndRemaining(): [string | undefined, string | undefined] {
		const words = this.suggestion.split(" ");
		if (words.length === 0) {
			return ["", ""];
		}

		if (words.length === 1) {
			return [words[0] + " ", ""];
		}

		return [words[0] + " ", words.slice(1).join(" ")];
	}

	handleCancelKeyPressed(): boolean {
		this.context.clearSuggestionsCache();
		this.clearPrediction();
		return true;
	}

	handleAcceptCommand() {
		this.accept();
	}

	getStatusBarText(): string {
		return `Suggesting for ${this.context.context}`;
	}

	handleSettingChanged(settings: Settings): void {
		if (!settings.cacheSuggestions) {
			this.clearPrediction();
		}
	}
}

export default SuggestingState;
25
src/core/autocomplete/states/types.ts
Normal file
@@ -0,0 +1,25 @@
import { TFile } from "obsidian";

import { InfioSettings } from "../../../types/settings";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";

export interface EventHandler {
	handleSettingChanged(settings: InfioSettings): void;

	handleDocumentChange(documentChanges: DocumentChanges): Promise<void>;

	handleAcceptKeyPressed(): boolean;

	handlePartialAcceptKeyPressed(): boolean;

	handleCancelKeyPressed(): boolean;

	handlePredictCommand(prefix: string, suffix: string): void;

	handleAcceptCommand(): void;

	getStatusBarText(): string;

	handleFileChange(file: TFile): void;
}
57
src/core/autocomplete/types.ts
Normal file
@@ -0,0 +1,57 @@
import { Result } from "neverthrow";

import Context from "./context-detection";

export interface AutocompleteService {
	fetchPredictions(
		prefix: string,
		suffix: string
	): Promise<Result<string, Error>>;
}

export interface PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string;
}

export interface PreProcessor {
	process(prefix: string, suffix: string, context: Context): PrefixAndSuffix;
	removesCursor(prefix: string, suffix: string): boolean;
}

export interface PrefixAndSuffix {
	prefix: string;
	suffix: string;
}

export interface ChatMessage {
	content: string;
	role: "user" | "assistant" | "system";
}

export interface UserMessageFormattingInputs {
	prefix: string;
	suffix: string;
}

export type UserMessageFormatter = (
	inputs: UserMessageFormattingInputs
) => string;

export interface ApiClient {
	queryChatModel(messages: ChatMessage[]): Promise<Result<string, Error>>;
	checkIfConfiguredCorrectly?(): Promise<string[]>;
}

export interface ModelOptions {
	temperature: number;
	top_p: number;
	frequency_penalty: number;
	presence_penalty: number;
	max_tokens: number;
}
68
src/core/autocomplete/utils.ts
Normal file
@@ -0,0 +1,68 @@
import * as mm from "micromatch";

export function sleep(ms: number) {
	return new Promise((resolve) => setTimeout(resolve, ms));
}

export function enumKeys<O extends object, K extends keyof O = keyof O>(
	obj: O
): K[] {
	return Object.keys(obj).filter((k) => Number.isNaN(+k)) as K[];
}

export function generateRandomString(n: number): string {
	let result = '';
	const characters = '0123456789abcdef';

	for (let i = 0; i < n; i++) {
		const randomIndex = Math.floor(Math.random() * characters.length);
		result += characters[randomIndex];
	}

	return result;
}

export function isMatchBetweenPathAndPatterns(
	path: string,
	patterns: string[],
): boolean {
	patterns = patterns
		.map(p => p.trim())
		.filter((p) => p.length > 0);
	if (patterns.length === 0) {
		return false;
	}

	const exclusionPatterns = patterns.filter((p) => p.startsWith('!')).map(p => p.slice(1));
	const inclusionPatterns = patterns.filter((p) => !p.startsWith('!'));

	return mm.some(path, inclusionPatterns) && !mm.some(path, exclusionPatterns);
}

export function extractNextWordAndRemaining(suggestion: string): [string | undefined, string | undefined] {
	const leadingWhitespacesMatch = suggestion.match(/^(\s*)/);
	const leadingWhitespaces = leadingWhitespacesMatch ? leadingWhitespacesMatch[0] : '';
	const trimmedSuggestion = suggestion.slice(leadingWhitespaces.length);

	let nextWord: string | undefined;
	let remaining: string | undefined = undefined;

	const whitespaceAfterNextWordMatch = trimmedSuggestion.match(/\s+/);
	if (!whitespaceAfterNextWordMatch) {
		nextWord = trimmedSuggestion || undefined;
	} else {
		const whitespaceAfterNextWordStartingIndex = whitespaceAfterNextWordMatch.index!;
		const whitespaceAfterNextWord = whitespaceAfterNextWordMatch[0];
		const whitespaceLength = whitespaceAfterNextWord.length;
		const startOfWhitespaceAfterNextWordIndex = whitespaceAfterNextWordStartingIndex + whitespaceLength;

		nextWord = trimmedSuggestion.substring(0, whitespaceAfterNextWordStartingIndex);
		if (startOfWhitespaceAfterNextWordIndex < trimmedSuggestion.length) {
			remaining = trimmedSuggestion.slice(startOfWhitespaceAfterNextWordIndex);
			nextWord += whitespaceAfterNextWord;
		}
	}

	return [nextWord ? leadingWhitespaces + nextWord : undefined, remaining];
}
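Behaviour of `extractNextWordAndRemaining`, which backs the partial-accept feature (illustrative, not part of the commit): the next word keeps its surrounding whitespace so it can be inserted verbatim.

import { extractNextWordAndRemaining } from "./utils";

console.log(extractNextWordAndRemaining("hello world")); // ["hello ", "world"]
console.log(extractNextWordAndRemaining("  hello"));     // ["  hello", undefined]
console.log(extractNextWordAndRemaining(""));            // [undefined, undefined]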
44
src/core/edit/inline-edit-processor.ts
Normal file
@@ -0,0 +1,44 @@
import { MarkdownPostProcessorContext, Plugin } from "obsidian";
import React from "react";
import { createRoot } from "react-dom/client";

import { InlineEdit as InlineEditComponent } from "../../components/inline-edit/InlineEdit";
import { InfioSettings } from "../../types/settings";

export class InlineEdit {
	plugin: Plugin;
	settings: InfioSettings;

	constructor(plugin: Plugin, settings: InfioSettings) {
		this.plugin = plugin;
		this.settings = settings;
	}

	/**
	 * Markdown post-processor entry point
	 */
	Processor(source: string, el: HTMLElement, ctx: MarkdownPostProcessorContext) {
		const sec = ctx.getSectionInfo(el);
		if (!sec) return;

		const container = createDiv();
		const root = createRoot(container);

		root.render(
			React.createElement(InlineEditComponent, {
				source: source,
				secStartLine: sec.lineStart,
				secEndLine: sec.lineEnd,
				plugin: this.plugin,
				settings: this.settings
			})
		);

		// Remove the parent element's code-block styling
		const parent = el.parentElement;
		if (parent) {
			parent.addClass("infio-ai-block");
		}
		el.replaceWith(container);
	}
}
323
src/core/llm/anthropic.ts
Normal file
@@ -0,0 +1,323 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
ImageBlockParam,
|
||||
MessageParam,
|
||||
MessageStreamEvent,
|
||||
TextBlockParam,
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
RequestMessage,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
ResponseUsage,
|
||||
} from '../../types/llm/response'
|
||||
import { parseImageDataUrl } from '../../utils/image'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
} from './exception'
|
||||
|
||||
export class AnthropicProvider implements BaseLLMProvider {
|
||||
private client: Anthropic
|
||||
|
||||
private static readonly DEFAULT_MAX_TOKENS = 8192
|
||||
|
||||
constructor(apiKey: string) {
|
||||
this.client = new Anthropic({ apiKey, dangerouslyAllowBrowser: true })
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'Anthropic API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new Anthropic({
|
||||
baseURL: model.baseUrl,
|
||||
apiKey: model.apiKey,
|
||||
dangerouslyAllowBrowser: true
|
||||
})
|
||||
}
|
||||
|
||||
const systemMessage = AnthropicProvider.validateSystemMessages(
|
||||
request.messages,
|
||||
)
|
||||
|
||||
try {
|
||||
const response = await this.client.messages.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages
|
||||
.filter((m) => m.role !== 'system')
|
||||
.filter((m) => !AnthropicProvider.isMessageEmpty(m))
|
||||
.map((m) => AnthropicProvider.parseRequestMessage(m)),
|
||||
system: systemMessage,
|
||||
max_tokens:
|
||||
request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
return AnthropicProvider.parseNonStreamingResponse(response)
|
||||
} catch (error) {
|
||||
if (error instanceof Anthropic.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'Anthropic API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'Anthropic API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new Anthropic({
|
||||
baseURL: model.baseUrl,
|
||||
apiKey: model.apiKey,
|
||||
dangerouslyAllowBrowser: true
|
||||
})
|
||||
}
|
||||
|
||||
const systemMessage = AnthropicProvider.validateSystemMessages(
|
||||
request.messages,
|
||||
)
|
||||
|
||||
try {
|
||||
const stream = await this.client.messages.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages
|
||||
.filter((m) => m.role !== 'system')
|
||||
.filter((m) => !AnthropicProvider.isMessageEmpty(m))
|
||||
.map((m) => AnthropicProvider.parseRequestMessage(m)),
|
||||
system: systemMessage,
|
||||
max_tokens:
|
||||
request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
// eslint-disable-next-line no-inner-declarations
|
||||
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
|
||||
let messageId = ''
|
||||
let model = ''
|
||||
let usage: ResponseUsage = {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.type === 'message_start') {
|
||||
messageId = chunk.message.id
|
||||
model = chunk.message.model
|
||||
usage = {
|
||||
prompt_tokens: chunk.message.usage.input_tokens,
|
||||
completion_tokens: chunk.message.usage.output_tokens,
|
||||
total_tokens:
|
||||
chunk.message.usage.input_tokens +
|
||||
chunk.message.usage.output_tokens,
|
||||
}
|
||||
} else if (chunk.type === 'content_block_delta') {
|
||||
yield AnthropicProvider.parseStreamingResponseChunk(
|
||||
chunk,
|
||||
messageId,
|
||||
model,
|
||||
)
|
||||
} else if (chunk.type === 'message_delta') {
|
||||
usage = {
|
||||
prompt_tokens: usage.prompt_tokens,
|
||||
completion_tokens:
|
||||
usage.completion_tokens + chunk.usage.output_tokens,
|
||||
total_tokens: usage.total_tokens + chunk.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// After the stream is complete, yield the final usage
|
||||
yield {
|
||||
id: messageId,
|
||||
choices: [],
|
||||
object: 'chat.completion.chunk',
|
||||
model: model,
|
||||
usage: usage,
        }
      }

      return streamResponse()
    } catch (error) {
      if (error instanceof Anthropic.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Anthropic API key is invalid. Please update it in settings menu.',
        )
      }

      throw error
    }
  }

  static parseRequestMessage(message: RequestMessage): MessageParam {
    if (message.role !== 'user' && message.role !== 'assistant') {
      throw new Error(`Anthropic does not support role: ${message.role}`)
    }

    if (message.role === 'user' && Array.isArray(message.content)) {
      const content = message.content.map(
        (part): TextBlockParam | ImageBlockParam => {
          switch (part.type) {
            case 'text':
              return { type: 'text', text: part.text }
            case 'image_url': {
              const { mimeType, base64Data } = parseImageDataUrl(
                part.image_url.url,
              )
              AnthropicProvider.validateImageType(mimeType)
              return {
                type: 'image',
                source: {
                  data: base64Data,
                  media_type:
                    mimeType as ImageBlockParam['source']['media_type'],
                  type: 'base64',
                },
              }
            }
          }
        },
      )
      return { role: 'user', content }
    }

    return {
      role: message.role,
      content: message.content as string,
    }
  }

  static parseNonStreamingResponse(
    response: Anthropic.Message,
  ): LLMResponseNonStreaming {
    if (response.content[0].type === 'tool_use') {
      throw new Error('Unsupported content type: tool_use')
    }
    return {
      id: response.id,
      choices: [
        {
          finish_reason: response.stop_reason,
          message: {
            content: response.content[0].text,
            role: response.role,
          },
        },
      ],
      model: response.model,
      object: 'chat.completion',
      usage: {
        prompt_tokens: response.usage.input_tokens,
        completion_tokens: response.usage.output_tokens,
        total_tokens:
          response.usage.input_tokens + response.usage.output_tokens,
      },
    }
  }

  static parseStreamingResponseChunk(
    chunk: MessageStreamEvent,
    messageId: string,
    model: string,
  ): LLMResponseStreaming {
    if (chunk.type !== 'content_block_delta') {
      throw new Error('Unsupported chunk type')
    }
    if (chunk.delta.type === 'input_json_delta') {
      throw new Error('Unsupported content type: input_json_delta')
    }
    return {
      id: messageId,
      choices: [
        {
          finish_reason: null,
          delta: {
            content: chunk.delta.text,
          },
        },
      ],
      object: 'chat.completion.chunk',
      model: model,
    }
  }

  private static validateSystemMessages(
    messages: RequestMessage[],
  ): string | undefined {
    const systemMessages = messages.filter((m) => m.role === 'system')
    if (systemMessages.length > 1) {
      throw new Error(`Anthropic does not support more than one system message`)
    }
    const systemMessage =
      systemMessages.length > 0 ? systemMessages[0].content : undefined
    if (systemMessage && typeof systemMessage !== 'string') {
      throw new Error(
        `Anthropic only supports string content for system messages`,
      )
    }
    return systemMessage
  }

  private static isMessageEmpty(message: RequestMessage) {
    if (typeof message.content === 'string') {
      return message.content.trim() === ''
    }
    return message.content.length === 0
  }

  private static validateImageType(mimeType: string) {
    const SUPPORTED_IMAGE_TYPES = [
      'image/jpeg',
      'image/png',
      'image/gif',
      'image/webp',
    ]
    if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
      throw new Error(
        `Anthropic does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
          ', ',
        )}`,
      )
    }
  }
}
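
For orientation, a minimal sketch of the conversion parseRequestMessage above performs; the message literal and data URL are hypothetical, and the object may need a RequestMessage cast under strict typing.

import { AnthropicProvider } from './anthropic'

// Hypothetical multimodal user message in this plugin's RequestMessage shape.
const message = {
  role: 'user' as const,
  content: [
    { type: 'text' as const, text: 'Describe this image.' },
    {
      type: 'image_url' as const,
      image_url: { url: 'data:image/png;base64,iVBORw0KGgo=' },
    },
  ],
}

// Text parts stay text blocks; image_url parts become base64 image source
// blocks after the data URL is split and the MIME type is validated.
const param = AnthropicProvider.parseRequestMessage(message)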

23
src/core/llm/base.ts
Normal file
@@ -0,0 +1,23 @@
import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

export type BaseLLMProvider = {
  generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming>
  streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>>
}
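
Every provider in this commit satisfies this contract. As a minimal sketch, a hypothetical echo provider could look like the following; the response shapes are assumed from the parse helpers elsewhere in this commit, not from a published spec.

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'

// Hypothetical no-op provider: echoes the last message back to the caller.
export class EchoProvider implements BaseLLMProvider {
  async generateResponse(
    _model: CustomLLMModel,
    request: LLMRequestNonStreaming,
  ): Promise<LLMResponseNonStreaming> {
    const last = request.messages[request.messages.length - 1]
    return {
      id: 'echo',
      choices: [
        {
          finish_reason: 'stop',
          message: { content: String(last?.content ?? ''), role: 'assistant' },
        },
      ],
      model: request.model,
      object: 'chat.completion',
    }
  }

  async streamResponse(
    _model: CustomLLMModel,
    request: LLMRequestStreaming,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    const last = request.messages[request.messages.length - 1]
    // A single-chunk stream is enough to satisfy the AsyncIterable contract.
    async function* once(): AsyncIterable<LLMResponseStreaming> {
      yield {
        id: 'echo',
        choices: [
          {
            finish_reason: 'stop',
            delta: { content: String(last?.content ?? '') },
          },
        ],
        model: request.model,
        object: 'chat.completion.chunk',
      }
    }
    return once()
  }
}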

34
src/core/llm/exception.ts
Normal file
@@ -0,0 +1,34 @@
export class LLMAPIKeyNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMAPIKeyNotSetException'
  }
}

export class LLMAPIKeyInvalidException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMAPIKeyInvalidException'
  }
}

export class LLMBaseUrlNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMBaseUrlNotSetException'
  }
}

export class LLMModelNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMModelNotSetException'
  }
}

export class LLMRateLimitExceededException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMRateLimitExceededException'
  }
}
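
These subclasses exist so callers can branch on error type rather than string-match messages. A small hypothetical helper showing the intended pattern:

import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
  LLMBaseUrlNotSetException,
  LLMModelNotSetException,
  LLMRateLimitExceededException,
} from './exception'

// Hypothetical helper: turn a thrown provider error into a short notice string.
export function describeLLMError(error: unknown): string {
  if (
    error instanceof LLMAPIKeyNotSetException ||
    error instanceof LLMAPIKeyInvalidException ||
    error instanceof LLMBaseUrlNotSetException ||
    error instanceof LLMModelNotSetException ||
    error instanceof LLMRateLimitExceededException
  ) {
    // Each subclass already carries a user-facing message.
    return `${error.name}: ${error.message}`
  }
  return error instanceof Error ? error.message : String(error)
}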

299
src/core/llm/gemini.ts
Normal file
@@ -0,0 +1,299 @@
import {
  Content,
  EnhancedGenerateContentResponse,
  GenerateContentResult,
  GenerateContentStreamResult,
  GoogleGenerativeAI,
} from '@google/generative-ai'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

/**
 * Note on OpenAI Compatibility API:
 * Gemini provides an OpenAI-compatible endpoint (https://ai.google.dev/gemini-api/docs/openai)
 * which allows using the OpenAI SDK with Gemini models. However, there are currently CORS issues
 * preventing its use in Obsidian. Consider switching to this endpoint in the future once these
 * issues are resolved.
 */
export class GeminiProvider implements BaseLLMProvider {
  private client: GoogleGenerativeAI
  private apiKey: string

  constructor(apiKey: string) {
    this.apiKey = apiKey
    this.client = new GoogleGenerativeAI(apiKey)
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }

    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined

    try {
      const model = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })

      const result = await model.generateContent(
        {
          systemInstruction: systemInstruction,
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )

      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return GeminiProvider.parseNonStreamingResponse(
        result,
        request.model,
        messageId,
      )
    } catch (error) {
      const isInvalidApiKey =
        error.message?.includes('API_KEY_INVALID') ||
        error.message?.includes('API key not valid')

      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in settings menu.`,
        )
      }

      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }

    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined

    try {
      const model = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })

      const stream = await model.generateContentStream(
        {
          systemInstruction: systemInstruction,
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )

      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return this.streamResponseGenerator(stream, request.model, messageId)
    } catch (error) {
      const isInvalidApiKey =
        error.message?.includes('API_KEY_INVALID') ||
        error.message?.includes('API key not valid')

      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in settings menu.`,
        )
      }

      throw error
    }
  }

  private async *streamResponseGenerator(
    stream: GenerateContentStreamResult,
    model: string,
    messageId: string,
  ): AsyncIterable<LLMResponseStreaming> {
    for await (const chunk of stream.stream) {
      yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
    }
  }

  static parseRequestMessage(message: RequestMessage): Content | null {
    if (message.role === 'system') {
      return null
    }

    if (Array.isArray(message.content)) {
      return {
        role: message.role === 'user' ? 'user' : 'model',
        parts: message.content.map((part) => {
          switch (part.type) {
            case 'text':
              return { text: part.text }
            case 'image_url': {
              const { mimeType, base64Data } = parseImageDataUrl(
                part.image_url.url,
              )
              GeminiProvider.validateImageType(mimeType)

              return {
                inlineData: {
                  data: base64Data,
                  mimeType,
                },
              }
            }
          }
        }),
      }
    }

    return {
      role: message.role === 'user' ? 'user' : 'model',
      parts: [
        {
          text: message.content,
        },
      ],
    }
  }

  static parseNonStreamingResponse(
    response: GenerateContentResult,
    model: string,
    messageId: string,
  ): LLMResponseNonStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason:
            response.response.candidates?.[0]?.finishReason ?? null,
          message: {
            content: response.response.text(),
            role: 'assistant',
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion',
      usage: response.response.usageMetadata
        ? {
            prompt_tokens: response.response.usageMetadata.promptTokenCount,
            completion_tokens:
              response.response.usageMetadata.candidatesTokenCount,
            total_tokens: response.response.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }

  static parseStreamingResponseChunk(
    chunk: EnhancedGenerateContentResponse,
    model: string,
    messageId: string,
  ): LLMResponseStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
          delta: {
            content: chunk.text(),
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion.chunk',
      usage: chunk.usageMetadata
        ? {
            prompt_tokens: chunk.usageMetadata.promptTokenCount,
            completion_tokens: chunk.usageMetadata.candidatesTokenCount,
            total_tokens: chunk.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }

  private static validateImageType(mimeType: string) {
    const SUPPORTED_IMAGE_TYPES = [
      'image/png',
      'image/jpeg',
      'image/webp',
      'image/heic',
      'image/heif',
    ]
    if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
      throw new Error(
        `Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
          ', ',
        )}`,
      )
    }
  }
}
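
For orientation, the role mapping performed by parseRequestMessage above, with hypothetical inputs (strict typing may require a RequestMessage cast):

// 'assistant' maps to Gemini's 'model' role.
GeminiProvider.parseRequestMessage({ role: 'assistant', content: 'Hi' })
// -> { role: 'model', parts: [{ text: 'Hi' }] }

// 'system' messages return null here; they are passed separately
// through systemInstruction instead.
GeminiProvider.parseRequestMessage({ role: 'system', content: 'Be brief' })
// -> null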

200
src/core/llm/groq.ts
Normal file
@@ -0,0 +1,200 @@
import Groq from 'groq-sdk'
import {
  ChatCompletion,
  ChatCompletionChunk,
  ChatCompletionContentPart,
  ChatCompletionMessageParam,
} from 'groq-sdk/resources/chat/completions'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

export class GroqProvider implements BaseLLMProvider {
  private client: Groq

  constructor(apiKey: string) {
    this.client = new Groq({
      apiKey,
      dangerouslyAllowBrowser: true,
    })
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Groq API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Groq({
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true,
      })
    }

    try {
      const response = await this.client.chat.completions.create(
        {
          model: request.model,
          messages: request.messages.map((m) =>
            GroqProvider.parseRequestMessage(m),
          ),
          max_tokens: request.max_tokens,
          temperature: request.temperature,
          top_p: request.top_p,
        },
        {
          signal: options?.signal,
        },
      )
      return GroqProvider.parseNonStreamingResponse(response)
    } catch (error) {
      if (error instanceof Groq.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Groq API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Groq API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Groq({
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true,
      })
    }

    try {
      const stream = await this.client.chat.completions.create(
        {
          model: request.model,
          messages: request.messages.map((m) =>
            GroqProvider.parseRequestMessage(m),
          ),
          max_tokens: request.max_tokens,
          temperature: request.temperature,
          top_p: request.top_p,
          stream: true,
        },
        {
          signal: options?.signal,
        },
      )

      // eslint-disable-next-line no-inner-declarations
      async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
        for await (const chunk of stream) {
          yield GroqProvider.parseStreamingResponseChunk(chunk)
        }
      }

      return streamResponse()
    } catch (error) {
      if (error instanceof Groq.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Groq API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  static parseRequestMessage(
    message: RequestMessage,
  ): ChatCompletionMessageParam {
    switch (message.role) {
      case 'user': {
        const content = Array.isArray(message.content)
          ? message.content.map((part): ChatCompletionContentPart => {
              switch (part.type) {
                case 'text':
                  return { type: 'text', text: part.text }
                case 'image_url':
                  return { type: 'image_url', image_url: part.image_url }
              }
            })
          : message.content
        return { role: 'user', content }
      }
      case 'assistant': {
        if (Array.isArray(message.content)) {
          throw new Error('Assistant message should be a string')
        }
        return { role: 'assistant', content: message.content }
      }
      case 'system': {
        if (Array.isArray(message.content)) {
          throw new Error('System message should be a string')
        }
        return { role: 'system', content: message.content }
      }
    }
  }

  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      usage: response.usage,
    }
  }

  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
    }
  }
}

252
src/core/llm/infio.ts
Normal file
@@ -0,0 +1,252 @@
import OpenAI from 'openai'
import {
  ChatCompletion,
  ChatCompletionChunk,
} from 'openai/resources/chat/completions'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

export interface RangeFilter {
  gte?: number;
  lte?: number;
}

export interface ChunkFilter {
  field: string;
  match_all?: string[];
  range?: RangeFilter;
}

/**
 * Interface for making requests to the Infio API
 */
export interface InfioRequest {
  /** Required: The content of the user message to attach to the topic and then generate an assistant message in response to */
  messages: RequestMessage[];
  // /** Required: The ID of the topic to attach the message to */
  // topic_id: string;
  /** Optional: URLs to include */
  links?: string[];
  /** Optional: Files to include */
  files?: string[];
  /** Optional: Whether to highlight results in chunk_html. Default is true */
  highlight_results?: boolean;
  /** Optional: Delimiters for highlighting citations. Default is [".", "!", "?", "\n", "\t", ","] */
  highlight_delimiters?: string[];
  /** Optional: Search type - "semantic", "fulltext", or "hybrid". Default is "hybrid" */
  search_type?: string;
  /** Optional: Filters for chunk filtering */
  filters?: ChunkFilter;
  /** Optional: Whether to use web search API. Default is false */
  use_web_search?: boolean;
  /** Optional: LLM model to use */
  llm_model?: string;
  /** Optional: Force source */
  force_source?: string;
  /** Optional: Whether completion should come before chunks in stream. Default is false */
  completion_first?: boolean;
  /** Optional: Whether to stream the response. Default is true */
  stream_response?: boolean;
  /** Optional: Sampling temperature between 0 and 2. Default is 0.5 */
  temperature?: number;
  /** Optional: Frequency penalty between -2.0 and 2.0. Default is 0.7 */
  frequency_penalty?: number;
  /** Optional: Presence penalty between -2.0 and 2.0. Default is 0.7 */
  presence_penalty?: number;
  /** Optional: Maximum tokens to generate */
  max_tokens?: number;
  /** Optional: Stop tokens (up to 4 sequences) */
  stop_tokens?: string[];
}

export class InfioProvider implements BaseLLMProvider {
  // private adapter: OpenAIMessageAdapter
  // private client: OpenAI
  private apiKey: string
  private baseUrl: string

  constructor(apiKey: string) {
    // this.client = new OpenAI({ apiKey, dangerouslyAllowBrowser: true })
    // this.adapter = new OpenAIMessageAdapter()
    this.apiKey = apiKey
    this.baseUrl = 'https://api.infio.com/api/raw_message'
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'Infio API key is missing. Please set it in settings menu.',
      )
    }
    try {
      const req: InfioRequest = {
        messages: request.messages,
        stream_response: false,
        temperature: request.temperature,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        max_tokens: request.max_tokens,
      }
      // Named `fetchOptions` rather than `options` so the method's `options`
      // parameter is not shadowed and its abort signal can be forwarded.
      const fetchOptions = {
        method: 'POST',
        headers: {
          Authorization: this.apiKey,
          'TR-Dataset': '74aaec22-0cf0-4cba-80a5-ae5c0518344e',
          'Content-Type': 'application/json',
        },
        body: JSON.stringify(req),
        signal: options?.signal,
      };

      const response = await fetch(this.baseUrl, fetchOptions);
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const data = await response.json() as ChatCompletion;
      return InfioProvider.parseNonStreamingResponse(data);
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Infio API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'Infio API key is missing. Please set it in settings menu.',
      )
    }

    try {
      const req: InfioRequest = {
        llm_model: request.model,
        messages: request.messages,
        stream_response: true,
        temperature: request.temperature,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        max_tokens: request.max_tokens,
      }
      // See generateResponse: avoid shadowing `options`, forward the signal.
      const fetchOptions = {
        method: 'POST',
        headers: {
          Authorization: this.apiKey,
          'TR-Dataset': '74aaec22-0cf0-4cba-80a5-ae5c0518344e',
          'Content-Type': 'application/json',
        },
        body: JSON.stringify(req),
        signal: options?.signal,
      };

      const response = await fetch(this.baseUrl, fetchOptions);
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      if (!response.body) {
        throw new Error('Response body is null');
      }

      const reader = response.body.getReader();
      const decoder = new TextDecoder();

      return {
        [Symbol.asyncIterator]: async function* () {
          try {
            while (true) {
              const { done, value } = await reader.read();
              if (done) break;

              const chunk = decoder.decode(value);
              const lines = chunk.split('\n').filter(line => line.trim());

              for (const line of lines) {
                if (line.startsWith('data: ')) {
                  const jsonData = JSON.parse(line.slice(6)) as ChatCompletionChunk;
                  if (!jsonData || typeof jsonData !== 'object' || !('choices' in jsonData)) {
                    throw new Error('Invalid chunk format received');
                  }
                  yield InfioProvider.parseStreamingResponseChunk(jsonData);
                }
              }
            }
          } finally {
            reader.releaseLock();
          }
        }
      };
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Infio API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      system_fingerprint: response.system_fingerprint,
      usage: response.usage,
    }
  }

  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
      system_fingerprint: chunk.system_fingerprint,
      usage: chunk.usage ?? undefined,
    }
  }
}
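
A minimal sketch of how a caller might shape a streaming request for this endpoint; the field values are illustrative and the key is a placeholder.

import { InfioProvider, InfioRequest } from './infio'

// Hypothetical request body following the InfioRequest interface above.
const req: InfioRequest = {
  messages: [{ role: 'user', content: 'Summarize my note on TypeScript.' }],
  search_type: 'hybrid', // "semantic" | "fulltext" | "hybrid"
  use_web_search: false,
  stream_response: true,
  temperature: 0.5,
  max_tokens: 1024,
}

const provider = new InfioProvider('INFIO_API_KEY') // placeholder key
// Streaming then follows the SSE loop in streamResponse above: each
// `data: ...` line is parsed as a ChatCompletionChunk and yielded.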

142
src/core/llm/manager.ts
Normal file
@@ -0,0 +1,142 @@
import { DEEPSEEK_BASE_URL } from '../../constants'
import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { AnthropicProvider } from './anthropic'
import { GeminiProvider } from './gemini'
import { GroqProvider } from './groq'
import { InfioProvider } from './infio'
import { OllamaProvider } from './ollama'
import { OpenAIAuthenticatedProvider } from './openai'
import { OpenAICompatibleProvider } from './openai-compatible-provider'

export type LLMManagerInterface = {
  generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming>
  streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>>
}

class LLMManager implements LLMManagerInterface {
  private openaiProvider: OpenAIAuthenticatedProvider
  private deepseekProvider: OpenAICompatibleProvider
  private anthropicProvider: AnthropicProvider
  private googleProvider: GeminiProvider
  private groqProvider: GroqProvider
  private infioProvider: InfioProvider
  private ollamaProvider: OllamaProvider
  private isInfioEnabled: boolean

  constructor(apiKeys: {
    deepseek?: string
    openai?: string
    anthropic?: string
    gemini?: string
    groq?: string
    infio?: string
  }) {
    this.deepseekProvider = new OpenAICompatibleProvider(apiKeys.deepseek ?? '', DEEPSEEK_BASE_URL)
    this.openaiProvider = new OpenAIAuthenticatedProvider(apiKeys.openai ?? '')
    this.anthropicProvider = new AnthropicProvider(apiKeys.anthropic ?? '')
    this.googleProvider = new GeminiProvider(apiKeys.gemini ?? '')
    this.groqProvider = new GroqProvider(apiKeys.groq ?? '')
    this.infioProvider = new InfioProvider(apiKeys.infio ?? '')
    this.ollamaProvider = new OllamaProvider()
    this.isInfioEnabled = !!apiKeys.infio
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (this.isInfioEnabled) {
      return await this.infioProvider.generateResponse(
        model,
        request,
        options,
      )
    }
    // use custom provider
    switch (model.provider) {
      case 'deepseek':
        return await this.deepseekProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'openai':
        return await this.openaiProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'anthropic':
        return await this.anthropicProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'google':
        return await this.googleProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'groq':
        return await this.groqProvider.generateResponse(model, request, options)
      case 'ollama':
        return await this.ollamaProvider.generateResponse(
          model,
          request,
          options,
        )
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (this.isInfioEnabled) {
      return await this.infioProvider.streamResponse(model, request, options)
    }
    // use custom provider
    switch (model.provider) {
      case 'deepseek':
        return await this.deepseekProvider.streamResponse(model, request, options)
      case 'openai':
        return await this.openaiProvider.streamResponse(model, request, options)
      case 'anthropic':
        return await this.anthropicProvider.streamResponse(
          model,
          request,
          options,
        )
      case 'google':
        return await this.googleProvider.streamResponse(model, request, options)
      case 'groq':
        return await this.groqProvider.streamResponse(model, request, options)
      case 'ollama':
        return await this.ollamaProvider.streamResponse(model, request, options)
    }
  }
}

export default LLMManager
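
As a usage sketch (all values hypothetical): routing is decided once per call, so one manager instance can serve every provider.

import LLMManager from './manager'

const manager = new LLMManager({
  openai: 'OPENAI_API_KEY', // placeholder
  anthropic: '',            // unset providers fall back to per-model keys
})

// With an Infio key present, isInfioEnabled short-circuits the switch and
// every request goes to InfioProvider; otherwise model.provider decides.
// const stream = await manager.streamResponse(model, request)
// for await (const chunk of stream) {
//   render(chunk.choices[0]?.delta.content ?? '')
// }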

104
src/core/llm/ollama.ts
Normal file
@@ -0,0 +1,104 @@
/**
 * This provider is nearly identical to OpenAICompatibleProvider, but uses a custom OpenAI client
 * (NoStainlessOpenAI) to work around CORS issues specific to Ollama.
 */

import OpenAI from 'openai'
import { FinalRequestOptions } from 'openai/core'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException, LLMModelNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'

export class NoStainlessOpenAI extends OpenAI {
  defaultHeaders() {
    return {
      Accept: 'application/json',
      'Content-Type': 'application/json',
    }
  }

  buildRequest<Req>(
    options: FinalRequestOptions<Req>,
    { retryCount = 0 }: { retryCount?: number } = {},
  ): { req: RequestInit; url: string; timeout: number } {
    const req = super.buildRequest(options, { retryCount })
    const headers = req.req.headers as Record<string, string>
    Object.keys(headers).forEach((k) => {
      if (k.startsWith('x-stainless')) {
        // eslint-disable-next-line @typescript-eslint/no-dynamic-delete
        delete headers[k]
      }
    })
    return req
  }
}

export class OllamaProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter

  constructor() {
    this.adapter = new OpenAIMessageAdapter()
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!model.baseUrl) {
      throw new LLMBaseUrlNotSetException(
        'Ollama base URL is missing. Please set it in settings menu.',
      )
    }

    if (!model.name) {
      throw new LLMModelNotSetException(
        'Ollama model is missing. Please set it in settings menu.',
      )
    }

    const client = new NoStainlessOpenAI({
      baseURL: `${model.baseUrl}/v1`,
      apiKey: '',
      dangerouslyAllowBrowser: true,
    })
    return this.adapter.generateResponse(client, request, options)
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!model.baseUrl) {
      throw new LLMBaseUrlNotSetException(
        'Ollama base URL is missing. Please set it in settings menu.',
      )
    }

    if (!model.name) {
      throw new LLMModelNotSetException(
        'Ollama model is missing. Please set it in settings menu.',
      )
    }

    const client = new NoStainlessOpenAI({
      baseURL: `${model.baseUrl}/v1`,
      apiKey: '',
      dangerouslyAllowBrowser: true,
    })
    return this.adapter.streamResponse(client, request, options)
  }
}
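
The workaround above can be reused anywhere an OpenAI-compatible server rejects the SDK's extra headers. A minimal sketch, assuming the default local Ollama address:

const client = new NoStainlessOpenAI({
  baseURL: 'http://localhost:11434/v1', // assumed default Ollama address
  apiKey: '', // Ollama ignores the key, but the SDK requires a string
  dangerouslyAllowBrowser: true,
})
// buildRequest() strips every `x-stainless-*` header before the request is
// sent, which is what avoids the CORS preflight failure in Obsidian.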

62
src/core/llm/openai-compatible-provider.ts
Normal file
@@ -0,0 +1,62 @@
import OpenAI from 'openai'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'

export class OpenAICompatibleProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter
  private client: OpenAI
  private apiKey: string
  private baseURL: string

  constructor(apiKey: string, baseURL: string) {
    this.adapter = new OpenAIMessageAdapter()
    this.client = new OpenAI({
      apiKey: apiKey,
      baseURL: baseURL,
      dangerouslyAllowBrowser: true,
    })
    this.apiKey = apiKey
    this.baseURL = baseURL
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.baseURL || !this.apiKey) {
      throw new LLMBaseUrlNotSetException(
        'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
      )
    }

    return this.adapter.generateResponse(this.client, request, options)
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.baseURL || !this.apiKey) {
      throw new LLMBaseUrlNotSetException(
        'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
      )
    }

    return this.adapter.streamResponse(this.client, request, options)
  }
}

155
src/core/llm/openai-message-adapter.ts
Normal file
@@ -0,0 +1,155 @@
import OpenAI from 'openai'
import {
  ChatCompletion,
  ChatCompletionChunk,
  ChatCompletionContentPart,
  ChatCompletionMessageParam,
} from 'openai/resources/chat/completions'

import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

export class OpenAIMessageAdapter {
  async generateResponse(
    client: OpenAI,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    const response = await client.chat.completions.create(
      {
        model: request.model,
        messages: request.messages.map((m) =>
          OpenAIMessageAdapter.parseRequestMessage(m),
        ),
        max_tokens: request.max_tokens,
        temperature: request.temperature,
        top_p: request.top_p,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        logit_bias: request.logit_bias,
        prediction: request.prediction,
      },
      {
        signal: options?.signal,
      },
    )
    return OpenAIMessageAdapter.parseNonStreamingResponse(response)
  }

  async streamResponse(
    client: OpenAI,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    const stream = await client.chat.completions.create(
      {
        model: request.model,
        messages: request.messages.map((m) =>
          OpenAIMessageAdapter.parseRequestMessage(m),
        ),
        max_completion_tokens: request.max_tokens,
        temperature: request.temperature,
        top_p: request.top_p,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        logit_bias: request.logit_bias,
        stream: true,
        stream_options: {
          include_usage: true,
        },
      },
      {
        signal: options?.signal,
      },
    )

    // eslint-disable-next-line no-inner-declarations
    async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
      for await (const chunk of stream) {
        yield OpenAIMessageAdapter.parseStreamingResponseChunk(chunk)
      }
    }

    return streamResponse()
  }

  static parseRequestMessage(
    message: RequestMessage,
  ): ChatCompletionMessageParam {
    switch (message.role) {
      case 'user': {
        const content = Array.isArray(message.content)
          ? message.content.map((part): ChatCompletionContentPart => {
              switch (part.type) {
                case 'text':
                  return { type: 'text', text: part.text }
                case 'image_url':
                  return { type: 'image_url', image_url: part.image_url }
              }
            })
          : message.content
        return { role: 'user', content }
      }
      case 'assistant': {
        if (Array.isArray(message.content)) {
          throw new Error('Assistant message should be a string')
        }
        return { role: 'assistant', content: message.content }
      }
      case 'system': {
        if (Array.isArray(message.content)) {
          throw new Error('System message should be a string')
        }
        return { role: 'system', content: message.content }
      }
    }
  }

  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      system_fingerprint: response.system_fingerprint,
      usage: response.usage,
    }
  }

  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
      system_fingerprint: chunk.system_fingerprint,
      usage: chunk.usage ?? undefined,
    }
  }
}
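
The adapter deliberately takes the client as an argument rather than owning one, which is why the OpenAI, Ollama, and OpenAI-compatible providers above can all share it. A hypothetical direct use:

import OpenAI from 'openai'
import { OpenAIMessageAdapter } from './openai-message-adapter'

const adapter = new OpenAIMessageAdapter()
const client = new OpenAI({
  apiKey: 'OPENAI_API_KEY', // placeholder key
  dangerouslyAllowBrowser: true,
})
// const response = await adapter.generateResponse(client, request)
// const stream = await adapter.streamResponse(client, request)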

91
src/core/llm/openai.ts
Normal file
@@ -0,0 +1,91 @@
import OpenAI from 'openai'

import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'

export class OpenAIAuthenticatedProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter
  private client: OpenAI

  constructor(apiKey: string) {
    this.client = new OpenAI({
      apiKey,
      dangerouslyAllowBrowser: true,
    })
    this.adapter = new OpenAIMessageAdapter()
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.client.apiKey) {
      if (!model.baseUrl) {
        throw new LLMAPIKeyNotSetException(
          'OpenAI API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new OpenAI({
        apiKey: model.apiKey,
        baseURL: model.baseUrl,
        dangerouslyAllowBrowser: true,
      })
    }
    try {
      // `await` here so AuthenticationError rejections are caught below
      // instead of propagating to the caller unhandled.
      return await this.adapter.generateResponse(this.client, request, options)
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'OpenAI API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.client.apiKey) {
      if (!model.baseUrl) {
        throw new LLMAPIKeyNotSetException(
          'OpenAI API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new OpenAI({
        apiKey: model.apiKey,
        baseURL: model.baseUrl,
        dangerouslyAllowBrowser: true,
      })
    }

    try {
      // Same as above: await so the catch block can translate auth errors.
      return await this.adapter.streamResponse(this.client, request, options)
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'OpenAI API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }
}

151
src/core/rag/embedding.ts
Normal file
@@ -0,0 +1,151 @@
import { GoogleGenerativeAI } from '@google/generative-ai'
import { OpenAI } from 'openai'

import { EmbeddingModel } from '../../types/embedding'
import {
  LLMAPIKeyNotSetException,
  LLMBaseUrlNotSetException,
  LLMRateLimitExceededException,
} from '../llm/exception'
import { NoStainlessOpenAI } from '../llm/ollama'

export const getEmbeddingModel = (
  embeddingModelId: string,
  apiKeys: {
    openAIApiKey: string
    geminiApiKey: string
  },
  ollamaBaseUrl: string,
): EmbeddingModel => {
  switch (embeddingModelId) {
    case 'text-embedding-3-small': {
      const openai = new OpenAI({
        apiKey: apiKeys.openAIApiKey,
        dangerouslyAllowBrowser: true,
      })
      return {
        id: 'text-embedding-3-small',
        dimension: 1536,
        getEmbedding: async (text: string) => {
          try {
            if (!openai.apiKey) {
              throw new LLMAPIKeyNotSetException(
                'OpenAI API key is missing. Please set it in settings menu.',
              )
            }
            const embedding = await openai.embeddings.create({
              model: 'text-embedding-3-small',
              input: text,
            })
            return embedding.data[0].embedding
          } catch (error) {
            if (
              error.status === 429 &&
              error.message.toLowerCase().includes('rate limit')
            ) {
              throw new LLMRateLimitExceededException(
                'OpenAI API rate limit exceeded. Please try again later.',
              )
            }
            throw error
          }
        },
      }
    }
    case 'text-embedding-004': {
      const client = new GoogleGenerativeAI(apiKeys.geminiApiKey)
      const model = client.getGenerativeModel({ model: 'text-embedding-004' })
      return {
        id: 'text-embedding-004',
        dimension: 768,
        getEmbedding: async (text: string) => {
          try {
            const response = await model.embedContent(text)
            return response.embedding.values
          } catch (error) {
            if (
              error.status === 429 &&
              error.message.includes('RATE_LIMIT_EXCEEDED')
            ) {
              throw new LLMRateLimitExceededException(
                'Gemini API rate limit exceeded. Please try again later.',
              )
            }
            throw error
          }
        },
      }
    }
    case 'nomic-embed-text': {
      const openai = new NoStainlessOpenAI({
        apiKey: '',
        dangerouslyAllowBrowser: true,
        baseURL: `${ollamaBaseUrl}/v1`,
      })
      return {
        id: 'nomic-embed-text',
        dimension: 768,
        getEmbedding: async (text: string) => {
          if (!ollamaBaseUrl) {
            throw new LLMBaseUrlNotSetException(
              'Ollama Address is missing. Please set it in settings menu.',
            )
          }
          const embedding = await openai.embeddings.create({
            model: 'nomic-embed-text',
            input: text,
          })
          return embedding.data[0].embedding
        },
      }
    }
    case 'mxbai-embed-large': {
      const openai = new NoStainlessOpenAI({
        apiKey: '',
        dangerouslyAllowBrowser: true,
        baseURL: `${ollamaBaseUrl}/v1`,
      })
      return {
        id: 'mxbai-embed-large',
        dimension: 1024,
        getEmbedding: async (text: string) => {
          if (!ollamaBaseUrl) {
            throw new LLMBaseUrlNotSetException(
              'Ollama Address is missing. Please set it in settings menu.',
            )
          }
          const embedding = await openai.embeddings.create({
            model: 'mxbai-embed-large',
            input: text,
          })
          return embedding.data[0].embedding
        },
      }
    }
    case 'bge-m3': {
      const openai = new NoStainlessOpenAI({
        apiKey: '',
        dangerouslyAllowBrowser: true,
        baseURL: `${ollamaBaseUrl}/v1`,
      })
      return {
        id: 'bge-m3',
        dimension: 1024,
        getEmbedding: async (text: string) => {
          if (!ollamaBaseUrl) {
            throw new LLMBaseUrlNotSetException(
              'Ollama Address is missing. Please set it in settings menu.',
            )
          }
          const embedding = await openai.embeddings.create({
            model: 'bge-m3',
            input: text,
          })
          return embedding.data[0].embedding
        },
      }
    }
    default:
      throw new Error('Invalid embedding model')
  }
}
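
A usage sketch (values assumed): the returned object pairs the model id with its output dimension, which must match the vector column it is indexed into.

import { getEmbeddingModel } from './embedding'

const embeddingModel = getEmbeddingModel(
  'nomic-embed-text',
  { openAIApiKey: '', geminiApiKey: '' },
  'http://localhost:11434', // assumed local Ollama address
)
// const vector = await embeddingModel.getEmbedding('some note text')
// vector.length === embeddingModel.dimension // 768 for nomic-embed-text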

124
src/core/rag/rag-engine.ts
Normal file
@@ -0,0 +1,124 @@
import { App } from 'obsidian'

import { QueryProgressState } from '../../components/chat-view/QueryProgress'
import { DBManager } from '../../database/database-manager'
import { VectorManager } from '../../database/modules/vector/vector-manager'
import { SelectVector } from '../../database/schema'
import { EmbeddingModel } from '../../types/embedding'
import { InfioSettings } from '../../types/settings'

import { getEmbeddingModel } from './embedding'

export class RAGEngine {
  private app: App
  private settings: InfioSettings
  private vectorManager: VectorManager
  private embeddingModel: EmbeddingModel | null = null

  constructor(
    app: App,
    settings: InfioSettings,
    dbManager: DBManager,
  ) {
    this.app = app
    this.settings = settings
    this.vectorManager = dbManager.getVectorManager()
    this.embeddingModel = getEmbeddingModel(
      settings.embeddingModelId,
      {
        openAIApiKey: settings.openAIApiKey,
        geminiApiKey: settings.geminiApiKey,
      },
      settings.ollamaEmbeddingModel.baseUrl,
    )
  }

  setSettings(settings: InfioSettings) {
    this.settings = settings
    this.embeddingModel = getEmbeddingModel(
      settings.embeddingModelId,
      {
        openAIApiKey: settings.openAIApiKey,
        geminiApiKey: settings.geminiApiKey,
      },
      settings.ollamaEmbeddingModel.baseUrl,
    )
  }

  // TODO: Implement automatic vault re-indexing when settings are changed.
  // Currently, users must manually re-index the vault.
  async updateVaultIndex(
    options: { reindexAll: boolean } = {
      reindexAll: false,
    },
    onQueryProgressChange?: (queryProgress: QueryProgressState) => void,
  ): Promise<void> {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    await this.vectorManager.updateVaultIndex(
      this.embeddingModel,
      {
        chunkSize: this.settings.ragOptions.chunkSize,
        excludePatterns: this.settings.ragOptions.excludePatterns,
        includePatterns: this.settings.ragOptions.includePatterns,
        reindexAll: options.reindexAll,
      },
      (indexProgress) => {
        onQueryProgressChange?.({
          type: 'indexing',
          indexProgress,
        })
      },
    )
  }

  async processQuery({
    query,
    scope,
    onQueryProgressChange,
  }: {
    query: string
    scope?: {
      files: string[]
      folders: string[]
    }
    onQueryProgressChange?: (queryProgress: QueryProgressState) => void
  }): Promise<
    (Omit<SelectVector, 'embedding'> & {
      similarity: number
    })[]
  > {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    // TODO: Decide the vault index update strategy.
    // Current approach: Update on every query.
    await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
    const queryEmbedding = await this.getQueryEmbedding(query)
    onQueryProgressChange?.({
      type: 'querying',
    })
    const queryResult = await this.vectorManager.performSimilaritySearch(
      queryEmbedding,
      this.embeddingModel,
      {
        minSimilarity: this.settings.ragOptions.minSimilarity,
        limit: this.settings.ragOptions.limit,
        scope,
      },
    )
    onQueryProgressChange?.({
      type: 'querying-done',
      queryResult,
    })
    return queryResult
  }

  private async getQueryEmbedding(query: string): Promise<number[]> {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    return this.embeddingModel.getEmbedding(query)
  }
}
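
A sketch of the query flow above, assuming an already constructed engine: the vault index is refreshed first, then the query is embedded and matched against it.

// Hypothetical caller; `ragEngine` is an instance of RAGEngine.
// const results = await ragEngine.processQuery({
//   query: 'meeting notes about Q3 planning',
//   scope: { files: [], folders: ['Work/Meetings'] },
//   onQueryProgressChange: (progress) => console.log(progress.type),
// })
// `results` are vector rows (minus the embedding itself) with a `similarity`
// score, filtered by minSimilarity and capped by the configured limit.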