This commit is contained in:
duanfuxiang
2025-01-05 11:51:39 +08:00
commit 0c7ee142cb
215 changed files with 20611 additions and 0 deletions

53
src/ApplyView.tsx Normal file
View File

@@ -0,0 +1,53 @@
import { TFile, View, WorkspaceLeaf } from 'obsidian'
import { Root, createRoot } from 'react-dom/client'
import ApplyViewRoot from './components/apply-view/ApplyViewRoot'
import { APPLY_VIEW_TYPE } from './constants'
import { AppProvider } from './contexts/AppContext'
export type ApplyViewState = {
file: TFile
originalContent: string
newContent: string
}
/**
 * Obsidian view that hosts the React-based "apply changes" diff UI.
 *
 * Lifecycle: Obsidian calls `onOpen` (mount React root) before `setState`
 * (receive the diff payload), so rendering is driven from `setState`.
 */
export class ApplyView extends View {
  // React root mounted on this view's container element.
  private root: Root | null = null
  // Diff payload (file + original/new content); null until setState runs.
  private state: ApplyViewState | null = null

  constructor(leaf: WorkspaceLeaf) {
    super(leaf)
  }

  getViewType(): string {
    return APPLY_VIEW_TYPE
  }

  getDisplayText(): string {
    const fileName = this.state?.file?.name ?? ''
    return `Applying: ${fileName}`
  }

  async onOpen(): Promise<void> {
    // Mount the React root once; actual rendering waits for state.
    this.root = createRoot(this.containerEl)
  }

  async onClose(): Promise<void> {
    this.root?.unmount()
  }

  async setState(state: ApplyViewState): Promise<void> {
    this.state = state
    // Must render here: onOpen is called before setState, so the root
    // exists but has nothing to show until the state arrives.
    this.render()
  }

  async render(): Promise<void> {
    if (!this.root || !this.state) return
    this.root.render(
      <AppProvider app={this.app}>
        <ApplyViewRoot state={this.state} close={() => this.leaf.detach()} />
      </AppProvider>,
    )
  }
}

117
src/ChatView.tsx Normal file
View File

@@ -0,0 +1,117 @@
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { ItemView, WorkspaceLeaf } from 'obsidian'
import React from 'react'
import { Root, createRoot } from 'react-dom/client'
import Chat, { ChatProps, ChatRef } from './components/chat-view/Chat'
import { CHAT_VIEW_TYPE } from './constants'
import { AppProvider } from './contexts/AppContext'
import { DarkModeProvider } from './contexts/DarkModeContext'
import { DatabaseProvider } from './contexts/DatabaseContext'
import { DialogProvider } from './contexts/DialogContext'
import { LLMProvider } from './contexts/LLMContext'
import { RAGProvider } from './contexts/RAGContext'
import { SettingsProvider } from './contexts/SettingsContext'
import InfioPlugin from './main'
import { MentionableBlockData } from './types/mentionable'
import { InfioSettings } from './types/settings'
/**
 * Obsidian item view that hosts the React chat UI.
 *
 * Wires plugin services (settings, database, RAG engine, LLM) into a nested
 * React context-provider tree and renders <Chat> at its center. The provider
 * nesting order matters: inner providers consume contexts from outer ones.
 */
export class ChatView extends ItemView {
  // React root mounted into this view's content element; created lazily in render().
  private root: Root | null = null
  // Plugin settings captured at construction time.
  private settings: InfioSettings
  // Initial props for <Chat> (e.g. a pre-selected block); consumed on first open only.
  private initialChatProps?: ChatProps
  // Imperative handle into the mounted <Chat> (new chat / add selection / focus).
  private chatRef: React.RefObject<ChatRef> = React.createRef()

  constructor(
    leaf: WorkspaceLeaf,
    private plugin: InfioPlugin,
  ) {
    super(leaf)
    this.settings = plugin.settings
    this.initialChatProps = plugin.initChatProps
  }

  getViewType() {
    return CHAT_VIEW_TYPE
  }

  getIcon() {
    return 'wand-sparkles'
  }

  getDisplayText() {
    return 'Smart composer chat'
  }

  async onOpen() {
    await this.render()
    // Consume chatProps: they are one-shot initial props and must not be
    // re-applied if the view re-renders later.
    this.initialChatProps = undefined
  }

  async onClose() {
    this.root?.unmount()
  }

  async render() {
    if (!this.root) {
      // containerEl.children[1] is Obsidian's view content element
      // (children[0] is the header).
      this.root = createRoot(this.containerEl.children[1])
    }
    // NOTE(review): a fresh QueryClient is built on every render() call,
    // though render() currently only runs once from onOpen.
    const queryClient = new QueryClient({
      defaultOptions: {
        queries: {
          gcTime: 0, // Immediately garbage collect queries. It prevents memory leak on ChatView close.
        },
        mutations: {
          gcTime: 0, // Immediately garbage collect mutations. It prevents memory leak on ChatView close.
        },
      },
    })
    this.root.render(
      <AppProvider app={this.app}>
        <SettingsProvider
          settings={this.settings}
          setSettings={(newSettings) => this.plugin.setSettings(newSettings)}
          addSettingsChangeListener={(listener) =>
            this.plugin.addSettingsListener(listener)
          }
        >
          <DarkModeProvider>
            <LLMProvider>
              <DatabaseProvider
                getDatabaseManager={() => this.plugin.getDbManager()}
              >
                <RAGProvider getRAGEngine={() => this.plugin.getRAGEngine()}>
                  <QueryClientProvider client={queryClient}>
                    <React.StrictMode>
                      <DialogProvider
                        container={this.containerEl.children[1] as HTMLElement}
                      >
                        <Chat ref={this.chatRef} {...this.initialChatProps} />
                      </DialogProvider>
                    </React.StrictMode>
                  </QueryClientProvider>
                </RAGProvider>
              </DatabaseProvider>
            </LLMProvider>
          </DarkModeProvider>
        </SettingsProvider>
      </AppProvider>,
    )
  }

  // Start a fresh conversation, optionally seeding it with a selected block.
  openNewChat(selectedBlock?: MentionableBlockData) {
    this.chatRef.current?.openNewChat(selectedBlock)
  }

  // Attach the current editor selection to the focused chat message.
  addSelectionToChat(selectedBlock: MentionableBlockData) {
    this.chatRef.current?.addSelectionToChat(selectedBlock)
  }

  // Move keyboard focus into the chat input.
  focusMessage() {
    this.chatRef.current?.focusMessage()
  }
}

View File

@@ -0,0 +1,142 @@
import { Change, diffLines } from 'diff'
import { CheckIcon, X } from 'lucide-react'
import { getIcon } from 'obsidian'
import { useState } from 'react'
import { ApplyViewState } from '../../ApplyView'
import { useApp } from '../../contexts/AppContext'
/**
 * Root component of the apply view: shows a line diff between the original
 * and proposed file content and lets the user accept/reject the whole change
 * or accept/exclude individual diff hunks before writing to the vault.
 */
export default function ApplyViewRoot({
  state,
  close,
}: {
  state: ApplyViewState
  close: () => void
}) {
  // Obsidian icons are only used as truthiness guards before rendering the
  // lucide equivalents below; kept for parity with the rest of the plugin UI.
  const acceptIcon = getIcon('check')
  const rejectIcon = getIcon('x')
  const excludeIcon = getIcon('x')

  const app = useApp()

  const [diff, setDiff] = useState<Change[]>(
    diffLines(state.originalContent, state.newContent),
  )

  // Write the resolved content (everything not marked "removed") to the file.
  const handleAccept = async () => {
    const newContent = diff
      .filter((change) => !change.removed)
      .map((change) => change.value)
      .join('')
    await app.vault.modify(state.file, newContent)
    close()
  }

  const handleReject = async () => {
    close()
  }

  // Reject a single hunk: drop an added hunk, or keep a removed hunk as
  // unchanged text. Fixed to avoid mutating Change objects held in previous
  // React state (the old code set `change.removed = false` in place, which
  // breaks referential-equality assumptions and StrictMode double-rendering).
  const excludeDiffLine = (index: number) => {
    setDiff((prevDiff) =>
      prevDiff.flatMap((change, i) => {
        if (i !== index) return [change]
        if (change.added) {
          // Remove the entry entirely if it's an added hunk
          return []
        }
        if (change.removed) {
          // Keep the original text: clear the removed flag on a copy
          return [{ ...change, removed: false }]
        }
        return [change]
      }),
    )
  }

  // Accept a single hunk: keep an added hunk as unchanged text, or drop a
  // removed hunk entirely. Immutable for the same reason as excludeDiffLine.
  const acceptDiffLine = (index: number) => {
    setDiff((prevDiff) =>
      prevDiff.flatMap((change, i) => {
        if (i !== index) return [change]
        if (change.added) {
          // Keep the new text: clear the added flag on a copy
          return [{ ...change, added: false }]
        }
        if (change.removed) {
          // Remove the entry entirely if it's a removed hunk
          return []
        }
        return [change]
      }),
    )
  }

  return (
    <div id="infio-apply-view">
      <div className="view-header">
        <div className="view-header-left">
          <div className="view-header-nav-buttons"></div>
        </div>
        <div className="view-header-title-container mod-at-start">
          <div className="view-header-title">
            Applying: {state?.file?.name ?? ''}
          </div>
          <div className="view-actions">
            <button
              className="clickable-icon view-action infio-approve-button"
              aria-label="Accept changes"
              onClick={handleAccept}
            >
              {acceptIcon && <CheckIcon size={14} />}
              Accept
            </button>
            <button
              className="clickable-icon view-action infio-reject-button"
              aria-label="Reject changes"
              onClick={handleReject}
            >
              {rejectIcon && <X size={14} />}
              Reject
            </button>
          </div>
        </div>
      </div>
      <div className="view-content">
        <div className="markdown-source-view cm-s-obsidian mod-cm6 node-insert-event is-readable-line-width is-live-preview is-folding show-properties">
          <div className="cm-editor">
            <div className="cm-scroller">
              <div className="cm-sizer">
                <div className="infio-inline-title">
                  {state?.file?.name
                    ? state.file.name.replace(/\.[^/.]+$/, '')
                    : ''}
                </div>
                {diff.map((part, index) => (
                  <div
                    key={index}
                    className={`infio-diff-line ${part.added ? 'added' : part.removed ? 'removed' : ''}`}
                  >
                    <div style={{ width: '100%' }}>{part.value}</div>
                    {(part.added || part.removed) && (
                      <div className="infio-diff-line-actions">
                        <button
                          aria-label="Accept line"
                          onClick={() => acceptDiffLine(index)}
                          className="infio-accept"
                        >
                          {acceptIcon && 'Y'}
                        </button>
                        <button
                          aria-label="Exclude line"
                          onClick={() => excludeDiffLine(index)}
                          className="infio-exclude"
                        >
                          {excludeIcon && 'N'}
                        </button>
                      </div>
                    )}
                  </div>
                ))}
              </div>
            </div>
          </div>
        </div>
      </div>
    </div>
  )
}

View File

@@ -0,0 +1,90 @@
import * as Tooltip from '@radix-ui/react-tooltip'
import { Check, CopyIcon } from 'lucide-react'
import { useMemo, useState } from 'react'
import { ChatAssistantMessage } from '../../types/chat'
import { calculateLLMCost } from '../../utils/price-calculator'
import LLMResponseInfoPopover from './LLMResponseInfoPopover'
/**
 * Copy-to-clipboard button for an assistant message, with a transient
 * "copied" checkmark and a tooltip.
 *
 * Fix: onClick now lives on the <button> itself rather than on the CopyIcon,
 * so the full button hit area works and clicks are still handled while the
 * checkmark is showing.
 */
function CopyButton({ message }: { message: ChatAssistantMessage }) {
  const [copied, setCopied] = useState(false)

  const handleCopy = async () => {
    await navigator.clipboard.writeText(message.content)
    setCopied(true)
    // Revert to the copy icon after a short confirmation flash.
    setTimeout(() => {
      setCopied(false)
    }, 1500)
  }

  return (
    <Tooltip.Provider delayDuration={0}>
      <Tooltip.Root>
        <Tooltip.Trigger asChild>
          <button onClick={handleCopy}>
            {copied ? (
              <Check
                size={12}
                className="infio-chat-message-actions-icon--copied"
              />
            ) : (
              <CopyIcon size={12} />
            )}
          </button>
        </Tooltip.Trigger>
        <Tooltip.Portal>
          <Tooltip.Content className="infio-tooltip-content">
            Copy message
          </Tooltip.Content>
        </Tooltip.Portal>
      </Tooltip.Root>
    </Tooltip.Provider>
  )
}
/**
 * Info button showing token usage and estimated cost for an assistant
 * message via a popover.
 *
 * Fix: when model/usage metadata is missing, the cost is now reported as
 * null ("unknown") instead of 0, matching the declared `number | null` type —
 * 0 would falsely claim the response was free.
 */
function LLMResponesInfoButton({ message }: { message: ChatAssistantMessage }) {
  const cost = useMemo<number | null>(() => {
    if (!message.metadata?.model || !message.metadata?.usage) {
      // No metadata yet (e.g. mid-stream): cost is unknown, not zero.
      return null
    }
    return calculateLLMCost({
      model: message.metadata.model,
      usage: message.metadata.usage,
    })
  }, [message])

  return (
    <Tooltip.Provider delayDuration={0}>
      <Tooltip.Root>
        <Tooltip.Trigger asChild>
          <div>
            <LLMResponseInfoPopover
              usage={message.metadata?.usage}
              estimatedPrice={cost}
              model={message.metadata?.model?.name}
            />
          </div>
        </Tooltip.Trigger>
        <Tooltip.Portal>
          <Tooltip.Content className="infio-tooltip-content">
            View details
          </Tooltip.Content>
        </Tooltip.Portal>
      </Tooltip.Root>
    </Tooltip.Provider>
  )
}
/**
 * Action row rendered beneath an assistant chat message: a response-info
 * popover (usage/cost) followed by a copy-to-clipboard button.
 */
export default function AssistantMessageActions(props: {
  message: ChatAssistantMessage
}) {
  const { message } = props
  return (
    <div className="infio-chat-message-actions">
      <LLMResponesInfoButton message={message} />
      <CopyButton message={message} />
    </div>
  )
}

View File

@@ -0,0 +1,737 @@
import { useMutation } from '@tanstack/react-query'
import { CircleStop, History, Plus } from 'lucide-react'
import { App, Notice } from 'obsidian'
import {
forwardRef,
useCallback,
useEffect,
useImperativeHandle,
useMemo,
useRef,
useState,
} from 'react'
import { v4 as uuidv4 } from 'uuid'
import { ApplyViewState } from '../../ApplyView'
import { APPLY_VIEW_TYPE } from '../../constants'
import { useApp } from '../../contexts/AppContext'
import { useLLM } from '../../contexts/LLMContext'
import { useRAG } from '../../contexts/RAGContext'
import { useSettings } from '../../contexts/SettingsContext'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
LLMBaseUrlNotSetException,
LLMModelNotSetException,
} from '../../core/llm/exception'
import { useChatHistory } from '../../hooks/use-chat-history'
import { ChatMessage, ChatUserMessage } from '../../types/chat'
import {
MentionableBlock,
MentionableBlockData,
MentionableCurrentFile,
} from '../../types/mentionable'
import { manualApplyChangesToFile } from '../../utils/apply'
import {
getMentionableKey,
serializeMentionable,
} from '../../utils/mentionable'
import { readTFileContent } from '../../utils/obsidian'
import { openSettingsModalWithError } from '../../utils/open-settings-modal'
import { PromptGenerator } from '../../utils/prompt-generator'
import AssistantMessageActions from './AssistantMessageActions'
import ChatUserInput, { ChatUserInputRef } from './chat-input/ChatUserInput'
import { editorStateToPlainText } from './chat-input/utils/editor-state-to-plain-text'
import { ChatListDropdown } from './ChatListDropdown'
import QueryProgress, { QueryProgressState } from './QueryProgress'
import ReactMarkdown from './ReactMarkdown'
import ShortcutInfo from './ShortcutInfo'
import SimilaritySearchResults from './SimilaritySearchResults'
// --- Chat helpers ---
/**
 * Builds a fresh, empty user input message pre-populated with the currently
 * active file as a "current-file" mentionable.
 */
const getNewInputMessage = (app: App): ChatUserMessage => ({
  role: 'user',
  content: null,
  promptContent: null,
  id: uuidv4(),
  mentionables: [
    {
      type: 'current-file',
      file: app.workspace.getActiveFile(),
    },
  ],
})
// Imperative handle exposed by <Chat> via useImperativeHandle; used by
// ChatView to drive the chat from outside React.
export type ChatRef = {
  openNewChat: (selectedBlock?: MentionableBlockData) => void
  addSelectionToChat: (selectedBlock: MentionableBlockData) => void
  focusMessage: () => void
}

// Props accepted by <Chat>: an optional block to mention on first render.
export type ChatProps = {
  selectedBlock?: MentionableBlockData
}
/**
 * Main chat component: manages conversation state, streams LLM responses,
 * persists chat history, and exposes an imperative handle (ChatRef) for
 * the hosting Obsidian view.
 *
 * Fixes in this revision:
 * - removed leftover debug logging after the streaming loop
 * - `createOrUpdateConversation` is now awaited so save failures actually
 *   reach the surrounding try/catch and surface a Notice
 * - AbortError detection guarded with `instanceof Error`
 */
const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
  const app = useApp()
  const { settings } = useSettings()
  const { getRAGEngine } = useRAG()

  const {
    createOrUpdateConversation,
    deleteConversation,
    getChatMessagesById,
    updateConversationTitle,
    chatList,
  } = useChatHistory()
  const { streamResponse, chatModel } = useLLM()

  const promptGenerator = useMemo(() => {
    return new PromptGenerator(getRAGEngine, app, settings)
  }, [getRAGEngine, app, settings])

  // The in-progress user message at the bottom of the chat, optionally
  // seeded with a selected block passed in via props.
  const [inputMessage, setInputMessage] = useState<ChatUserMessage>(() => {
    const newMessage = getNewInputMessage(app)
    if (props.selectedBlock) {
      newMessage.mentionables = [
        ...newMessage.mentionables,
        {
          type: 'block',
          ...props.selectedBlock,
        },
      ]
    }
    return newMessage
  })
  const [addedBlockKey, setAddedBlockKey] = useState<string | null>(
    props.selectedBlock
      ? getMentionableKey(
          serializeMentionable({
            type: 'block',
            ...props.selectedBlock,
          }),
        )
      : null,
  )
  const [chatMessages, setChatMessages] = useState<ChatMessage[]>([])
  const [focusedMessageId, setFocusedMessageId] = useState<string | null>(null)
  const [currentConversationId, setCurrentConversationId] =
    useState<string>(uuidv4())
  const [queryProgress, setQueryProgress] = useState<QueryProgressState>({
    type: 'idle',
  })

  // Scroll bookkeeping: suppress auto-scroll while the user has scrolled up,
  // and ignore scroll events caused by our own programmatic scrolling.
  const preventAutoScrollRef = useRef(false)
  const lastProgrammaticScrollRef = useRef<number>(0)

  const activeStreamAbortControllersRef = useRef<AbortController[]>([])
  const chatUserInputRefs = useRef<Map<string, ChatUserInputRef>>(new Map())
  const chatMessagesRef = useRef<HTMLDivElement>(null)

  const registerChatUserInputRef = (
    id: string,
    ref: ChatUserInputRef | null,
  ) => {
    if (ref) {
      chatUserInputRefs.current.set(id, ref)
    } else {
      chatUserInputRefs.current.delete(id)
    }
  }

  useEffect(() => {
    const scrollContainer = chatMessagesRef.current
    if (!scrollContainer) return

    const handleScroll = () => {
      // If the scroll event happened very close to our programmatic scroll, ignore it
      if (Date.now() - lastProgrammaticScrollRef.current < 50) {
        return
      }
      // The user is "scrolled up" when more than 20px from the bottom.
      preventAutoScrollRef.current =
        scrollContainer.scrollHeight -
          scrollContainer.scrollTop -
          scrollContainer.clientHeight >
        20
    }

    scrollContainer.addEventListener('scroll', handleScroll)
    return () => scrollContainer.removeEventListener('scroll', handleScroll)
  }, [chatMessages])

  const handleScrollToBottom = () => {
    if (chatMessagesRef.current) {
      const scrollContainer = chatMessagesRef.current
      if (scrollContainer.scrollTop !== scrollContainer.scrollHeight) {
        lastProgrammaticScrollRef.current = Date.now()
        scrollContainer.scrollTop = scrollContainer.scrollHeight
      }
    }
  }

  // Cancel all in-flight LLM streams.
  const abortActiveStreams = () => {
    for (const abortController of activeStreamAbortControllersRef.current) {
      abortController.abort()
    }
    activeStreamAbortControllersRef.current = []
  }

  const handleLoadConversation = async (conversationId: string) => {
    try {
      abortActiveStreams()
      const conversation = await getChatMessagesById(conversationId)
      if (!conversation) {
        throw new Error('Conversation not found')
      }
      setCurrentConversationId(conversationId)
      setChatMessages(conversation)
      const newInputMessage = getNewInputMessage(app)
      setInputMessage(newInputMessage)
      setFocusedMessageId(newInputMessage.id)
      setQueryProgress({
        type: 'idle',
      })
    } catch (error) {
      new Notice('Failed to load conversation')
      console.error('Failed to load conversation', error)
    }
  }

  const handleNewChat = (selectedBlock?: MentionableBlockData) => {
    setCurrentConversationId(uuidv4())
    setChatMessages([])
    const newInputMessage = getNewInputMessage(app)
    if (selectedBlock) {
      const mentionableBlock: MentionableBlock = {
        type: 'block',
        ...selectedBlock,
      }
      newInputMessage.mentionables = [
        ...newInputMessage.mentionables,
        mentionableBlock,
      ]
      setAddedBlockKey(
        getMentionableKey(serializeMentionable(mentionableBlock)),
      )
    }
    setInputMessage(newInputMessage)
    setFocusedMessageId(newInputMessage.id)
    setQueryProgress({
      type: 'idle',
    })
    abortActiveStreams()
  }

  // Sends the chat history to the LLM and streams the assistant response
  // into a placeholder message, updating it chunk by chunk.
  const submitMutation = useMutation({
    mutationFn: async ({
      newChatHistory,
      useVaultSearch,
    }: {
      newChatHistory: ChatMessage[]
      useVaultSearch?: boolean
    }) => {
      abortActiveStreams()
      setQueryProgress({
        type: 'idle',
      })

      const responseMessageId = uuidv4()
      setChatMessages([
        ...newChatHistory,
        {
          role: 'assistant',
          content: '',
          id: responseMessageId,
          metadata: {
            usage: undefined,
            model: undefined,
          },
        },
      ])

      try {
        const abortController = new AbortController()
        activeStreamAbortControllersRef.current.push(abortController)

        const { requestMessages, compiledMessages } =
          await promptGenerator.generateRequestMessages({
            messages: newChatHistory,
            useVaultSearch,
            onQueryProgressChange: setQueryProgress,
          })
        setQueryProgress({
          type: 'idle',
        })

        // Replace the history with the compiled (prompt-expanded) messages
        // plus the still-empty assistant placeholder.
        setChatMessages([
          ...compiledMessages,
          {
            role: 'assistant',
            content: '',
            id: responseMessageId,
            metadata: {
              usage: undefined,
              model: undefined,
            },
          },
        ])

        const stream = await streamResponse(
          chatModel,
          {
            model: chatModel.name,
            messages: requestMessages,
            stream: true,
          },
          {
            signal: abortController.signal,
          },
        )

        for await (const chunk of stream) {
          const content = chunk.choices[0]?.delta?.content ?? ''
          setChatMessages((prevChatHistory) =>
            prevChatHistory.map((message) =>
              message.role === 'assistant' && message.id === responseMessageId
                ? {
                    ...message,
                    content: message.content + content,
                    metadata: {
                      ...message.metadata,
                      usage: chunk.usage ?? message.metadata?.usage, // Keep existing usage if chunk has no usage data
                      model: chatModel,
                    },
                  }
                : message,
            ),
          )
          if (!preventAutoScrollRef.current) {
            handleScrollToBottom()
          }
        }
      } catch (error) {
        // Aborted streams are expected (user pressed stop / started new chat).
        if (error instanceof Error && error.name === 'AbortError') {
          return
        } else {
          throw error
        }
      }
    },
    onError: (error) => {
      setQueryProgress({
        type: 'idle',
      })
      if (
        error instanceof LLMAPIKeyNotSetException ||
        error instanceof LLMAPIKeyInvalidException ||
        error instanceof LLMBaseUrlNotSetException ||
        error instanceof LLMModelNotSetException
      ) {
        openSettingsModalWithError(app, error.message)
      } else {
        new Notice(error.message)
        console.error('Failed to generate response', error)
      }
    },
  })

  const handleSubmit = (
    newChatHistory: ChatMessage[],
    useVaultSearch?: boolean,
  ) => {
    submitMutation.mutate({ newChatHistory, useVaultSearch })
  }

  // Applies an assistant-suggested code/text block to the active file and
  // opens the ApplyView showing the resulting diff.
  const applyMutation = useMutation({
    mutationFn: async ({
      blockInfo,
    }: {
      blockInfo: {
        content: string
        filename?: string
        startLine?: number
        endLine?: number
      }
    }) => {
      const activeFile = app.workspace.getActiveFile()
      if (!activeFile) {
        throw new Error(
          'No file is currently open to apply changes. Please open a file and try again.',
        )
      }
      const activeFileContent = await readTFileContent(activeFile, app.vault)

      const updatedFileContent = await manualApplyChangesToFile(
        blockInfo.content,
        activeFile,
        activeFileContent,
        blockInfo.startLine,
        blockInfo.endLine,
      )
      if (!updatedFileContent) {
        throw new Error('Failed to apply changes')
      }

      await app.workspace.getLeaf(true).setViewState({
        type: APPLY_VIEW_TYPE,
        active: true,
        state: {
          file: activeFile,
          originalContent: activeFileContent,
          newContent: updatedFileContent,
        } satisfies ApplyViewState,
      })
    },
    onError: (error) => {
      if (
        error instanceof LLMAPIKeyNotSetException ||
        error instanceof LLMAPIKeyInvalidException ||
        error instanceof LLMBaseUrlNotSetException ||
        error instanceof LLMModelNotSetException
      ) {
        openSettingsModalWithError(app, error.message)
      } else {
        new Notice(error.message)
        console.error('Failed to apply changes', error)
      }
    },
  })

  const handleApply = useCallback(
    (blockInfo: {
      content: string
      filename?: string
      startLine?: number
      endLine?: number
    }) => {
      applyMutation.mutate({ blockInfo })
    },
    [applyMutation],
  )

  useEffect(() => {
    setFocusedMessageId(inputMessage.id)
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  // Persist the conversation whenever messages change.
  useEffect(() => {
    const updateConversationAsync = async () => {
      try {
        if (chatMessages.length > 0) {
          // Await so save failures are caught here rather than rejecting
          // unobserved.
          await createOrUpdateConversation(currentConversationId, chatMessages)
        }
      } catch (error) {
        new Notice('Failed to save chat history')
        console.error('Failed to save chat history', error)
      }
    }
    updateConversationAsync()
  }, [currentConversationId, chatMessages, createOrUpdateConversation])

  // Updates the currentFile of the focused message (input or chat history)
  // This happens when active file changes or focused message changes
  const handleActiveLeafChange = useCallback(() => {
    const activeFile = app.workspace.getActiveFile()
    if (!activeFile) return

    const mentionable: Omit<MentionableCurrentFile, 'id'> = {
      type: 'current-file',
      file: activeFile,
    }

    if (!focusedMessageId) return
    if (inputMessage.id === focusedMessageId) {
      setInputMessage((prevInputMessage) => ({
        ...prevInputMessage,
        mentionables: [
          mentionable,
          ...prevInputMessage.mentionables.filter(
            (mentionable) => mentionable.type !== 'current-file',
          ),
        ],
      }))
    } else {
      setChatMessages((prevChatHistory) =>
        prevChatHistory.map((message) =>
          message.id === focusedMessageId && message.role === 'user'
            ? {
                ...message,
                mentionables: [
                  mentionable,
                  ...message.mentionables.filter(
                    (mentionable) => mentionable.type !== 'current-file',
                  ),
                ],
              }
            : message,
        ),
      )
    }
  }, [app.workspace, focusedMessageId, inputMessage.id])

  useEffect(() => {
    app.workspace.on('active-leaf-change', handleActiveLeafChange)
    return () => {
      app.workspace.off('active-leaf-change', handleActiveLeafChange)
    }
  }, [app.workspace, handleActiveLeafChange])

  useImperativeHandle(ref, () => ({
    openNewChat: (selectedBlock?: MentionableBlockData) =>
      handleNewChat(selectedBlock),
    addSelectionToChat: (selectedBlock: MentionableBlockData) => {
      const mentionable: Omit<MentionableBlock, 'id'> = {
        type: 'block',
        ...selectedBlock,
      }

      setAddedBlockKey(getMentionableKey(serializeMentionable(mentionable)))

      if (focusedMessageId === inputMessage.id) {
        setInputMessage((prevInputMessage) => {
          const mentionableKey = getMentionableKey(
            serializeMentionable(mentionable),
          )
          // Check if mentionable already exists
          if (
            prevInputMessage.mentionables.some(
              (m) =>
                getMentionableKey(serializeMentionable(m)) === mentionableKey,
            )
          ) {
            return prevInputMessage
          }
          return {
            ...prevInputMessage,
            mentionables: [...prevInputMessage.mentionables, mentionable],
          }
        })
      } else {
        setChatMessages((prevChatHistory) =>
          prevChatHistory.map((message) => {
            if (message.id === focusedMessageId && message.role === 'user') {
              const mentionableKey = getMentionableKey(
                serializeMentionable(mentionable),
              )
              // Check if mentionable already exists
              if (
                message.mentionables.some(
                  (m) =>
                    getMentionableKey(serializeMentionable(m)) ===
                    mentionableKey,
                )
              ) {
                return message
              }
              return {
                ...message,
                mentionables: [...message.mentionables, mentionable],
              }
            }
            return message
          }),
        )
      }
    },
    focusMessage: () => {
      if (!focusedMessageId) return
      chatUserInputRefs.current.get(focusedMessageId)?.focus()
    },
  }))

  return (
    <div className="infio-chat-container">
      <div className="infio-chat-header">
        <h1 className="infio-chat-header-title"> CHAT </h1>
        <div className="infio-chat-header-buttons">
          <button
            onClick={() => handleNewChat()}
            className="infio-chat-list-dropdown"
          >
            <Plus size={18} />
          </button>
          <ChatListDropdown
            chatList={chatList}
            currentConversationId={currentConversationId}
            onSelect={async (conversationId) => {
              if (conversationId === currentConversationId) return
              await handleLoadConversation(conversationId)
            }}
            onDelete={async (conversationId) => {
              await deleteConversation(conversationId)
              if (conversationId === currentConversationId) {
                const nextConversation = chatList.find(
                  (chat) => chat.id !== conversationId,
                )
                if (nextConversation) {
                  void handleLoadConversation(nextConversation.id)
                } else {
                  handleNewChat()
                }
              }
            }}
            onUpdateTitle={async (conversationId, newTitle) => {
              await updateConversationTitle(conversationId, newTitle)
            }}
            className="infio-chat-list-dropdown"
          >
            <History size={18} />
          </ChatListDropdown>
        </div>
      </div>
      <div className="infio-chat-messages" ref={chatMessagesRef}>
        {
          // If the chat is empty, show a message to start a new chat
          chatMessages.length === 0 && (
            <div
              style={{
                display: 'flex',
                justifyContent: 'center',
                alignItems: 'center',
                height: '100%',
                width: '100%',
              }}
            >
              <ShortcutInfo />
            </div>
          )
        }
        {chatMessages.map((message, index) =>
          message.role === 'user' ? (
            <div key={message.id} className="infio-chat-messages-user">
              <ChatUserInput
                ref={(ref) => registerChatUserInputRef(message.id, ref)}
                initialSerializedEditorState={message.content}
                onChange={(content) => {
                  setChatMessages((prevChatHistory) =>
                    prevChatHistory.map((msg) =>
                      msg.role === 'user' && msg.id === message.id
                        ? {
                            ...msg,
                            content,
                          }
                        : msg,
                    ),
                  )
                }}
                onSubmit={(content, useVaultSearch) => {
                  if (editorStateToPlainText(content).trim() === '') return
                  // Re-submitting an earlier message truncates everything after it.
                  handleSubmit(
                    [
                      ...chatMessages.slice(0, index),
                      {
                        role: 'user',
                        content: content,
                        promptContent: null,
                        id: message.id,
                        mentionables: message.mentionables,
                      },
                    ],
                    useVaultSearch,
                  )
                  chatUserInputRefs.current.get(inputMessage.id)?.focus()
                }}
                onFocus={() => {
                  setFocusedMessageId(message.id)
                }}
                mentionables={message.mentionables}
                setMentionables={(mentionables) => {
                  setChatMessages((prevChatHistory) =>
                    prevChatHistory.map((msg) =>
                      msg.id === message.id ? { ...msg, mentionables } : msg,
                    ),
                  )
                }}
              />
              {message.similaritySearchResults && (
                <SimilaritySearchResults
                  similaritySearchResults={message.similaritySearchResults}
                />
              )}
            </div>
          ) : (
            <div key={message.id} className="infio-chat-messages-assistant">
              <ReactMarkdownItem
                handleApply={handleApply}
                isApplying={applyMutation.isPending}
              >
                {message.content}
              </ReactMarkdownItem>
              {message.content && <AssistantMessageActions message={message} />}
            </div>
          ),
        )}
        <QueryProgress state={queryProgress} />
        {submitMutation.isPending && (
          <button onClick={abortActiveStreams} className="infio-stop-gen-btn">
            <CircleStop size={16} />
            <div>Stop Generation</div>
          </button>
        )}
      </div>
      <ChatUserInput
        key={inputMessage.id} // this is needed to clear the editor when the user submits a new message
        ref={(ref) => registerChatUserInputRef(inputMessage.id, ref)}
        initialSerializedEditorState={inputMessage.content}
        onChange={(content) => {
          setInputMessage((prevInputMessage) => ({
            ...prevInputMessage,
            content,
          }))
        }}
        onSubmit={(content, useVaultSearch) => {
          if (editorStateToPlainText(content).trim() === '') return
          handleSubmit(
            [...chatMessages, { ...inputMessage, content }],
            useVaultSearch,
          )
          setInputMessage(getNewInputMessage(app))
          preventAutoScrollRef.current = false
          handleScrollToBottom()
        }}
        onFocus={() => {
          setFocusedMessageId(inputMessage.id)
        }}
        mentionables={inputMessage.mentionables}
        setMentionables={(mentionables) => {
          setInputMessage((prevInputMessage) => ({
            ...prevInputMessage,
            mentionables,
          }))
        }}
        autoFocus
        addedBlockKey={addedBlockKey}
      />
    </div>
  )
})
// Shape of the "apply this block to the active file" callback payload.
type ReactMarkdownItemProps = {
  handleApply: (blockInfo: {
    content: string
    filename?: string
    startLine?: number
    endLine?: number
  }) => void
  isApplying: boolean
  children: string
}

/**
 * Thin adapter mapping the chat-level apply callback onto the
 * <ReactMarkdown> renderer's props.
 */
function ReactMarkdownItem({
  handleApply,
  isApplying,
  children,
}: ReactMarkdownItemProps) {
  return (
    <ReactMarkdown onApply={handleApply} isApplying={isApplying}>
      {children}
    </ReactMarkdown>
  )
}
// forwardRef components are anonymous by default; set a displayName so the
// component shows up legibly in React DevTools.
Chat.displayName = 'Chat'

export default Chat

View File

@@ -0,0 +1,202 @@
import * as Popover from '@radix-ui/react-popover'
import { Pencil, Trash2 } from 'lucide-react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { ChatConversationMeta } from '../../types/chat'
/**
 * Controlled text input used to rename a conversation inline.
 * Submits the current value when Enter is pressed.
 */
function TitleInput({
  title,
  onSubmit,
}: {
  title: string
  onSubmit: (title: string) => Promise<void>
}) {
  const [value, setValue] = useState(title)
  const inputRef = useRef<HTMLInputElement>(null)

  useEffect(() => {
    const input = inputRef.current
    if (!input) return
    // Pre-select the whole title and keep its start visible.
    input.select()
    input.scrollLeft = 0
  }, [])

  const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
    // Keep keystrokes from reaching the dropdown's list navigation.
    e.stopPropagation()
    if (e.key === 'Enter') {
      void onSubmit(value)
    }
  }

  return (
    <input
      ref={inputRef}
      type="text"
      value={value}
      className="infio-chat-list-dropdown-item-title-input"
      onClick={(e) => e.stopPropagation()}
      onChange={(e) => setValue(e.target.value)}
      onKeyDown={handleKeyDown}
      autoFocus
      maxLength={100}
    />
  )
}
function ChatListItem({
title,
isFocused,
isEditing,
onMouseEnter,
onSelect,
onDelete,
onStartEdit,
onFinishEdit,
}: {
title: string
isFocused: boolean
isEditing: boolean
onMouseEnter: () => void
onSelect: () => Promise<void>
onDelete: () => Promise<void>
onStartEdit: () => void
onFinishEdit: (title: string) => Promise<void>
}) {
const itemRef = useRef<HTMLLIElement>(null)
useEffect(() => {
if (isFocused && itemRef.current) {
itemRef.current.scrollIntoView({
block: 'nearest',
})
}
}, [isFocused])
return (
<li
ref={itemRef}
onClick={onSelect}
onMouseEnter={onMouseEnter}
className={isFocused ? 'selected' : ''}
>
{isEditing ? (
<TitleInput title={title} onSubmit={onFinishEdit} />
) : (
<div className="infio-chat-list-dropdown-item-title">{title}</div>
)}
<div className="infio-chat-list-dropdown-item-actions">
<div
onClick={(e) => {
e.stopPropagation()
onStartEdit()
}}
className="infio-chat-list-dropdown-item-icon"
>
<Pencil size={14} />
</div>
<div
onClick={async (e) => {
e.stopPropagation()
await onDelete()
}}
className="infio-chat-list-dropdown-item-icon"
>
<Trash2 size={14} />
</div>
</div>
</li>
)
}
/**
 * Popover dropdown listing saved conversations with keyboard navigation
 * (ArrowUp/ArrowDown/Enter), inline rename, and delete.
 *
 * Fix: the Enter handler previously dereferenced `chatList[focusedIndex].id`
 * unconditionally and crashed when the list was empty (or the index was
 * stale); it now guards the lookup first.
 */
export function ChatListDropdown({
  chatList,
  currentConversationId,
  onSelect,
  onDelete,
  onUpdateTitle,
  className,
  children,
}: {
  chatList: ChatConversationMeta[]
  currentConversationId: string
  onSelect: (conversationId: string) => Promise<void>
  onDelete: (conversationId: string) => Promise<void>
  onUpdateTitle: (conversationId: string, newTitle: string) => Promise<void>
  className?: string
  children: React.ReactNode
}) {
  const [open, setOpen] = useState(false)
  const [focusedIndex, setFocusedIndex] = useState<number>(0)
  const [editingId, setEditingId] = useState<string | null>(null)

  // On open, focus the current conversation and cancel any pending rename.
  useEffect(() => {
    if (open) {
      const currentIndex = chatList.findIndex(
        (chat) => chat.id === currentConversationId,
      )
      setFocusedIndex(currentIndex === -1 ? 0 : currentIndex)
      setEditingId(null)
    }
    // NOTE(review): deps intentionally limited to `open` so the focus only
    // resets when the popover opens, not whenever chatList updates.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [open])

  const handleKeyDown = useCallback(
    (e: React.KeyboardEvent) => {
      if (e.key === 'ArrowUp') {
        setFocusedIndex(Math.max(0, focusedIndex - 1))
      } else if (e.key === 'ArrowDown') {
        setFocusedIndex(Math.min(chatList.length - 1, focusedIndex + 1))
      } else if (e.key === 'Enter') {
        // Guard: the list may be empty or the index stale.
        const focused = chatList[focusedIndex]
        if (!focused) return
        void onSelect(focused.id)
        setOpen(false)
      }
    },
    [chatList, focusedIndex, setFocusedIndex, onSelect],
  )

  return (
    <Popover.Root open={open} onOpenChange={setOpen}>
      <Popover.Trigger asChild>
        <button className={className}>{children}</button>
      </Popover.Trigger>
      <Popover.Portal>
        <Popover.Content
          className="infio-popover infio-chat-list-dropdown-content"
          onKeyDown={handleKeyDown}
        >
          <ul>
            {chatList.length === 0 ? (
              <li className="infio-chat-list-dropdown-empty">
                No conversations
              </li>
            ) : (
              chatList.map((chat, index) => (
                <ChatListItem
                  key={chat.id}
                  title={chat.title}
                  isFocused={focusedIndex === index}
                  isEditing={editingId === chat.id}
                  onMouseEnter={() => {
                    setFocusedIndex(index)
                  }}
                  onSelect={async () => {
                    await onSelect(chat.id)
                    setOpen(false)
                  }}
                  onDelete={async () => {
                    await onDelete(chat.id)
                  }}
                  onStartEdit={() => {
                    setEditingId(chat.id)
                  }}
                  onFinishEdit={async (title) => {
                    await onUpdateTitle(chat.id, title)
                    setEditingId(null)
                  }}
                />
              ))
            )}
          </ul>
        </Popover.Content>
      </Popover.Portal>
    </Popover.Root>
  )
}

View File

@@ -0,0 +1,127 @@
import { $generateNodesFromSerializedNodes } from '@lexical/clipboard'
import { BaseSerializedNode } from '@lexical/clipboard/clipboard'
import { InitialEditorStateType } from '@lexical/react/LexicalComposer'
import * as Dialog from '@radix-ui/react-dialog'
import { $insertNodes, LexicalEditor } from 'lexical'
import { X } from 'lucide-react'
import { Notice } from 'obsidian'
import { useRef, useState } from 'react'
import { useDatabase } from '../../contexts/DatabaseContext'
import { useDialogContainer } from '../../contexts/DialogContext'
import { DuplicateTemplateException } from '../../database/exception'
import LexicalContentEditable from './chat-input/LexicalContentEditable'
/*
* This component must be used inside <Dialog.Root modal={false}>
* The modal={false} prop is required because modal mode blocks pointer events across the entire page,
* which would conflict with lexical editor popovers
*/
/**
 * Dialog content for creating a chat template, optionally seeded with lexical
 * nodes copied from the current chat-input selection. Persists the template
 * through the database's template manager and reports results via Notices.
 *
 * NOTE(review): per the comment above, this must be rendered inside
 * <Dialog.Root modal={false}> so lexical popovers keep receiving pointer events.
 */
export default function CreateTemplateDialogContent({
  selectedSerializedNodes,
  onClose,
}: {
  // Serialized nodes from the chat-input selection used to pre-fill the editor.
  selectedSerializedNodes?: BaseSerializedNode[] | null
  onClose: () => void
}) {
  const container = useDialogContainer()
  const { getTemplateManager } = useDatabase()
  const [templateName, setTemplateName] = useState('')
  const editorRef = useRef<LexicalEditor | null>(null)
  const contentEditableRef = useRef<HTMLDivElement>(null)

  // Seeds the lexical editor with the serialized selection, if any.
  const initialEditorState: InitialEditorStateType = (
    editor: LexicalEditor,
  ) => {
    if (!selectedSerializedNodes) return
    editor.update(() => {
      const parsedNodes = $generateNodesFromSerializedNodes(
        selectedSerializedNodes,
      )
      $insertNodes(parsedNodes)
    })
  }

  // Validates name + content, stores the template, then resets and closes.
  // Duplicate names and storage failures are surfaced as Notices, not thrown.
  const onSubmit = async () => {
    try {
      if (!editorRef.current) return
      const serializedEditorState = editorRef.current.toJSON()
      const nodes = serializedEditorState.editorState.root.children
      if (nodes.length === 0) {
        new Notice('Please enter a content for your template')
        return
      }
      if (templateName.trim().length === 0) {
        new Notice('Please enter a name for your template')
        return
      }
      await (
        await getTemplateManager()
      ).createTemplate({
        name: templateName,
        content: { nodes },
      })
      new Notice(`Template created: ${templateName}`)
      setTemplateName('')
      onClose()
    } catch (error) {
      if (error instanceof DuplicateTemplateException) {
        new Notice('A template with this name already exists')
      } else {
        console.error(error)
        new Notice('Failed to create template')
      }
    }
  }

  return (
    <Dialog.Portal container={container}>
      <Dialog.Content className="infio-chat-dialog-content">
        <div className="infio-dialog-header">
          <Dialog.Title className="infio-dialog-title">
            Create template
          </Dialog.Title>
          <Dialog.Description className="infio-dialog-description">
            Create template from selected content
          </Dialog.Description>
        </div>
        <div className="infio-dialog-input">
          <label>Name</label>
          <input
            type="text"
            value={templateName}
            onChange={(e) => setTemplateName(e.target.value)}
            onKeyDown={(e) => {
              // Enter submits the form instead of inserting a newline;
              // stopPropagation keeps the event away from dialog handlers.
              if (e.key === 'Enter') {
                e.stopPropagation()
                e.preventDefault()
                onSubmit()
              }
            }}
          />
        </div>
        <div className="infio-chat-user-input-container">
          <LexicalContentEditable
            initialEditorState={initialEditorState}
            editorRef={editorRef}
            contentEditableRef={contentEditableRef}
            onEnter={onSubmit}
          />
        </div>
        <div className="infio-dialog-footer">
          <button onClick={onSubmit}>Create template</button>
        </div>
        <Dialog.Close className="infio-dialog-close" asChild>
          <X size={16} />
        </Dialog.Close>
      </Dialog.Content>
    </Dialog.Portal>
  )
}

View File

@@ -0,0 +1,84 @@
import * as Popover from '@radix-ui/react-popover'
import {
ArrowDown,
ArrowRightLeft,
ArrowUp,
Coins,
Cpu,
Info,
} from 'lucide-react'
import { ResponseUsage } from '../../types/llm/response'
type LLMResponseInfoProps = {
  // Token usage reported by the provider; undefined when the model
  // does not return usage statistics.
  usage?: ResponseUsage
  // Estimated cost in USD; null when pricing is unknown for the model.
  estimatedPrice: number | null
  model?: string
}

/**
 * Info-icon popover showing token usage, estimated price, and model name for
 * an LLM response. Falls back to a "not available" message when usage is
 * missing.
 */
export default function LLMResponseInfoPopover({
  usage,
  estimatedPrice,
  model,
}: LLMResponseInfoProps) {
  return (
    <Popover.Root>
      <Popover.Trigger asChild>
        <button>
          <Info className="infio-llm-info-icon--trigger" size={12} />
        </button>
      </Popover.Trigger>
      {usage ? (
        <Popover.Content className="infio-chat-popover-content infio-llm-info-content">
          <div className="infio-llm-info-header">LLM Response Information</div>
          <div className="infio-llm-info-tokens">
            <div className="infio-llm-info-tokens-header">Token Count</div>
            <div className="infio-llm-info-tokens-grid">
              <div className="infio-llm-info-token-row">
                <ArrowUp className="infio-llm-info-icon--input" />
                <span>Input:</span>
                <span className="infio-llm-info-token-value">
                  {usage.prompt_tokens}
                </span>
              </div>
              <div className="infio-llm-info-token-row">
                <ArrowDown className="infio-llm-info-icon--output" />
                <span>Output:</span>
                <span className="infio-llm-info-token-value">
                  {usage.completion_tokens}
                </span>
              </div>
              <div className="infio-llm-info-token-row infio-llm-info-token-total">
                <ArrowRightLeft className="infio-llm-info-icon--total" />
                <span>Total:</span>
                <span className="infio-llm-info-token-value">
                  {usage.total_tokens}
                </span>
              </div>
            </div>
          </div>
          <div className="infio-llm-info-footer-row">
            <Coins className="infio-llm-info-icon--footer" />
            <span>Estimated Price:</span>
            <span className="infio-llm-info-footer-value">
              {estimatedPrice === null
                ? 'Not available'
                : `$${estimatedPrice.toFixed(4)}`}
            </span>
          </div>
          <div className="infio-llm-info-footer-row">
            <Cpu className="infio-llm-info-icon--footer" />
            <span>Model:</span>
            <span className="infio-llm-info-footer-value infio-llm-info-model">
              {model ?? 'Not available'}
            </span>
          </div>
        </Popover.Content>
      ) : (
        <Popover.Content className="infio-chat-popover-content">
          <div>Usage statistics are not available for this model</div>
        </Popover.Content>
      )}
    </Popover.Root>
  )
}

View File

@@ -0,0 +1,122 @@
import { Check, CopyIcon, Loader2 } from 'lucide-react'
import { PropsWithChildren, useMemo, useState } from 'react'
import { useDarkModeContext } from '../../contexts/DarkModeContext'
import { MemoizedSyntaxHighlighterWrapper } from './SyntaxHighlighterWrapper'
/**
 * Renders a fenced code block inside chat markdown with a header toolbar:
 * copy-to-clipboard always, plus an Apply button for `action === 'edit'`
 * blocks and an Insert button for `action === 'new'` blocks. The body is
 * syntax-highlighted via MemoizedSyntaxHighlighterWrapper.
 *
 * @param onApply   Invoked with the block content and location metadata when
 *                  the user applies/inserts the block.
 * @param isApplying Disables the action button and shows a spinner while an
 *                  apply operation is in flight.
 */
export default function MarkdownCodeComponent({
  onApply,
  isApplying,
  language,
  filename,
  startLine,
  endLine,
  action,
  children,
}: PropsWithChildren<{
  onApply: (blockInfo: {
    content: string
    filename?: string
    startLine?: number
    endLine?: number
  }) => void
  isApplying: boolean
  language?: string
  filename?: string
  startLine?: number
  endLine?: number
  action?: 'edit' | 'new' | 'reference'
}>) {
  const [copied, setCopied] = useState(false)
  const { isDarkMode } = useDarkModeContext()

  // Compute the block content once; `children` is reused in several handlers.
  const content = String(children)

  // Wrap long lines for prose-like languages; code keeps horizontal scroll.
  const wrapLines = useMemo(() => {
    return !language || ['markdown'].includes(language)
  }, [language])

  // Copies the block content and shows transient "Copied" feedback for 2s.
  const handleCopy = async () => {
    try {
      await navigator.clipboard.writeText(content)
      setCopied(true)
      setTimeout(() => setCopied(false), 2000)
    } catch (err) {
      console.error('Failed to copy text: ', err)
    }
  }

  return (
    <div className={`infio-chat-code-block ${filename ? 'has-filename' : ''} ${action ? `type-${action}` : ''}`}>
      <div className={'infio-chat-code-block-header'}>
        {filename && (
          // FIX: previously rendered a literal placeholder instead of the
          // actual filename.
          <div className={'infio-chat-code-block-header-filename'}>{filename}</div>
        )}
        <div className={'infio-chat-code-block-header-button'}>
          <button
            onClick={() => {
              handleCopy()
            }}
          >
            {copied ? (
              <>
                <Check size={10} /> Copied
              </>
            ) : (
              <>
                <CopyIcon size={10} /> Copy
              </>
            )}
          </button>
          {action === 'edit' && (
            <button
              onClick={() => {
                onApply({
                  content,
                  filename,
                  startLine,
                  endLine
                })
              }}
              disabled={isApplying}
            >
              {isApplying ? (
                <>
                  <Loader2 className="spinner" size={14} /> Applying...
                </>
              ) : (
                'Apply'
              )}
            </button>
          )}
          {action === 'new' && (
            <button
              onClick={() => {
                onApply({
                  content,
                  filename
                })
              }}
              disabled={isApplying}
            >
              {isApplying ? (
                <>
                  <Loader2 className="spinner" size={14} /> Inserting...
                </>
              ) : (
                'Insert'
              )}
            </button>
          )}
        </div>
      </div>
      <MemoizedSyntaxHighlighterWrapper
        isDarkMode={isDarkMode}
        language={language}
        hasFilename={!!filename}
        wrapLines={wrapLines}
      >
        {content}
      </MemoizedSyntaxHighlighterWrapper>
    </div>
  )
}

View File

@@ -0,0 +1,75 @@
import { PropsWithChildren, useEffect, useMemo, useState } from 'react'
import { useApp } from '../../contexts/AppContext'
import { useDarkModeContext } from '../../contexts/DarkModeContext'
import { openMarkdownFile, readTFileContent } from '../../utils/obsidian'
import { MemoizedSyntaxHighlighterWrapper } from './SyntaxHighlighterWrapper'
/**
 * Renders a read-only excerpt (startLine..endLine, 1-based inclusive) of a
 * vault file as a highlighted code block. Clicking the block opens the file
 * at startLine. Renders nothing while loading or when the file is missing.
 */
export default function MarkdownReferenceBlock({
  filename,
  startLine,
  endLine,
  language,
}: PropsWithChildren<{
  filename: string
  startLine: number
  endLine: number
  language?: string
}>) {
  const app = useApp()
  const { isDarkMode } = useDarkModeContext()
  const [blockContent, setBlockContent] = useState<string | null>(null)

  // Wrap long lines for prose-like languages; code keeps horizontal scroll.
  const wrapLines = useMemo(() => {
    return !language || ['markdown'].includes(language)
  }, [language])

  useEffect(() => {
    // Guard against setting state after unmount or after the deps changed
    // while a read was still in flight (stale async result).
    let cancelled = false
    async function fetchBlockContent() {
      const file = app.vault.getFileByPath(filename)
      if (!file) {
        if (!cancelled) setBlockContent(null)
        return
      }
      const fileContent = await readTFileContent(file, app.vault)
      // slice() is end-exclusive, so (startLine - 1, endLine) keeps
      // lines startLine..endLine inclusive.
      const content = fileContent
        .split('\n')
        .slice(startLine - 1, endLine)
        .join('\n')
      if (!cancelled) setBlockContent(content)
    }
    fetchBlockContent()
    return () => {
      cancelled = true
    }
  }, [filename, startLine, endLine, app.vault])

  const handleClick = () => {
    openMarkdownFile(app, filename, startLine)
  }

  // TODO: Update styles
  return (
    blockContent && (
      <div
        className={`infio-chat-code-block ${filename ? 'has-filename' : ''}`}
        onClick={handleClick}
      >
        <div className={'infio-chat-code-block-header'}>
          {filename && (
            // FIX: previously rendered a literal placeholder instead of the
            // actual filename.
            <div className={'infio-chat-code-block-header-filename'}>
              {filename}
            </div>
          )}
        </div>
        <MemoizedSyntaxHighlighterWrapper
          isDarkMode={isDarkMode}
          language={language}
          hasFilename={!!filename}
          wrapLines={wrapLines}
        >
          {blockContent}
        </MemoizedSyntaxHighlighterWrapper>
      </div>
    )
  )
}

View File

@@ -0,0 +1,85 @@
import { SelectVector } from '../../database/schema'
// Discriminated union describing the lifecycle of a vault RAG query,
// from reading mentioned files through indexing and similarity search.
export type QueryProgressState =
  | {
      type: 'reading-mentionables'
    }
  | {
      type: 'indexing'
      indexProgress: IndexProgress
    }
  | {
      type: 'querying'
    }
  | {
      // Query finished; carries the similarity-ranked chunks (embeddings
      // stripped to keep the payload light).
      type: 'querying-done'
      queryResult: (Omit<SelectVector, 'embedding'> & { similarity: number })[]
    }
  | {
      type: 'idle'
    }

// Progress counters for the embedding/indexing phase.
export type IndexProgress = {
  completedChunks: number
  totalChunks: number
  totalFiles: number
}
// TODO: Update style
export default function QueryProgress({
state,
}: {
state: QueryProgressState
}) {
switch (state.type) {
case 'idle':
return null
case 'reading-mentionables':
return (
<div className="infio-query-progress">
<p>
Reading mentioned files
<DotLoader />
</p>
</div>
)
case 'indexing':
return (
<div className="infio-query-progress">
<p>
{`Indexing ${state.indexProgress.totalFiles} file`}
<DotLoader />
</p>
<p className="infio-query-progress-detail">{`${state.indexProgress.completedChunks}/${state.indexProgress.totalChunks} chunks indexed`}</p>
</div>
)
case 'querying':
return (
<div className="infio-query-progress">
<p>
Querying the vault
<DotLoader />
</p>
</div>
)
case 'querying-done':
return (
<div className="infio-query-progress">
<p>
Reading related files
<DotLoader />
</p>
{state.queryResult.map((result) => (
<div key={result.path}>
<p>{result.path}</p>
<p>{result.similarity}</p>
</div>
))}
</div>
)
}
}
// Animated "..." indicator (CSS-driven). Kept as a function declaration so the
// hoisted name is usable by QueryProgress above this definition.
function DotLoader() {
  return <span className="infio-dot-loader" aria-label="Loading"></span>
}

View File

@@ -0,0 +1,64 @@
import React, { useMemo } from 'react'
import Markdown from 'react-markdown'
import {
ParsedinfioBlock,
parseinfioBlocks,
} from '../../utils/parse-infio-block'
import MarkdownCodeComponent from './MarkdownCodeComponent'
import MarkdownReferenceBlock from './MarkdownReferenceBlock'
/**
 * Renders LLM output by splitting it into infio blocks: plain markdown
 * segments go through react-markdown, `reference` blocks (with complete
 * location metadata) render as live file excerpts, and all other code blocks
 * render with apply/insert controls.
 *
 * Memoized (see export) because chat re-renders frequently while streaming.
 */
function ReactMarkdown({
  onApply,
  isApplying,
  children,
}: {
  onApply: (blockInfo: {
    content: string
    filename?: string
    startLine?: number
    endLine?: number
  }) => void
  children: string
  isApplying: boolean
}) {
  // Parsing is comparatively expensive; only re-parse when the text changes.
  const blocks: ParsedinfioBlock[] = useMemo(
    () => parseinfioBlocks(children),
    [children],
  )

  return (
    <>
      {blocks.map((block, index) =>
        block.type === 'string' ? (
          <Markdown key={index} className="infio-markdown">
            {block.content}
          </Markdown>
        ) : block.startLine && block.endLine && block.filename && block.action === 'reference' ? (
          <MarkdownReferenceBlock
            key={index}
            filename={block.filename}
            startLine={block.startLine}
            endLine={block.endLine}
            // FIX: forward the block's language — it was silently dropped,
            // so reference excerpts always rendered without highlighting.
            language={block.language}
          />
        ) : (
          <MarkdownCodeComponent
            key={index}
            onApply={onApply}
            isApplying={isApplying}
            language={block.language}
            filename={block.filename}
            startLine={block.startLine}
            endLine={block.endLine}
            action={block.action}
          >
            {block.content}
          </MarkdownCodeComponent>
        ),
      )}
    </>
  )
}

View File

@@ -0,0 +1,38 @@
import { Platform } from 'obsidian';
import React from 'react';
/**
 * Static table of the plugin's keyboard shortcuts, labeled with the
 * platform-appropriate modifier key (Cmd on macOS, Ctrl elsewhere).
 */
const ShortcutInfo: React.FC = () => {
  const mod = Platform.isMacOS ? 'Cmd' : 'Ctrl';

  const rows: { label: string; shortcut: string }[] = [
    { label: 'Edit inline', shortcut: `${mod}+Shift+K` },
    { label: 'Chat with select', shortcut: `${mod}+Shift+L` },
    { label: 'Submit with vault', shortcut: `${mod}+Shift+Enter` },
  ];

  return (
    <div className="infio-shortcut-info">
      <table className="infio-shortcut-table">
        <tbody>
          {rows.map((row, idx) => (
            <tr key={idx} className="infio-shortcut-item">
              <td className="infio-shortcut-label">{row.label}</td>
              <td className="infio-shortcut-key"><kbd>{row.shortcut}</kbd></td>
            </tr>
          ))}
        </tbody>
      </table>
    </div>
  );
};

export default ShortcutInfo;

View File

@@ -0,0 +1,71 @@
import path from 'path'
import { ChevronDown, ChevronRight } from 'lucide-react'
import { useState } from 'react'
import { useApp } from '../../contexts/AppContext'
import { SelectVector } from '../../database/schema'
import { openMarkdownFile } from '../../utils/obsidian'
// Single similarity-search result row: score, basename, and line range.
// Clicking opens the source file at the chunk's start line.
// NOTE(review): name is misspelled ("Similiarty"); kept because the sibling
// component below references it by this exact name.
function SimiliartySearchItem({
  chunk,
}: {
  chunk: Omit<SelectVector, 'embedding'> & {
    similarity: number
  }
}) {
  const app = useApp()
  const handleClick = () => {
    openMarkdownFile(app, chunk.path, chunk.metadata.startLine)
  }
  return (
    <div onClick={handleClick} className="infio-similarity-search-item">
      <div className="infio-similarity-search-item__similarity">
        {chunk.similarity.toFixed(3)}
      </div>
      <div className="infio-similarity-search-item__path">
        {path.basename(chunk.path)}
      </div>
      <div className="infio-similarity-search-item__line-numbers">
        {`${chunk.metadata.startLine} - ${chunk.metadata.endLine}`}
      </div>
    </div>
  )
}
/**
 * Collapsible list of the documents referenced by a vault query. Starts
 * collapsed; the trigger row toggles visibility and shows the result count.
 */
export default function SimilaritySearchResults({
  similaritySearchResults,
}: {
  similaritySearchResults: (Omit<SelectVector, 'embedding'> & {
    similarity: number
  })[]
}) {
  const [expanded, setExpanded] = useState(false)

  const toggleExpanded = () => setExpanded((prev) => !prev)
  const Chevron = expanded ? ChevronDown : ChevronRight

  return (
    <div className="infio-similarity-search-results">
      <div
        onClick={toggleExpanded}
        className="infio-similarity-search-results__trigger"
      >
        <Chevron size={16} />
        <div>Show Referenced Documents ({similaritySearchResults.length})</div>
      </div>
      {expanded && (
        <div
          style={{
            display: 'flex',
            flexDirection: 'column',
          }}
        >
          {similaritySearchResults.map((chunk) => (
            <SimiliartySearchItem key={chunk.id} chunk={chunk} />
          ))}
        </div>
      )}
    </div>
  )
}

View File

@@ -0,0 +1,51 @@
import { memo } from 'react'
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'
import {
oneDark,
oneLight,
} from 'react-syntax-highlighter/dist/esm/styles/prism'
/**
 * Thin wrapper around react-syntax-highlighter (Prism) that applies the
 * plugin's theming: one-dark/one-light by mode, rounded corners that flatten
 * at the top when a filename header sits above the block, and optional
 * soft-wrapping.
 */
function SyntaxHighlighterWrapper({
  isDarkMode,
  language,
  hasFilename,
  wrapLines,
  children,
}: {
  isDarkMode: boolean
  language: string | undefined
  hasFilename: boolean
  wrapLines: boolean
  children: string
}) {
  return (
    <SyntaxHighlighter
      language={language}
      style={isDarkMode ? oneDark : oneLight}
      customStyle={{
        // Square the top corners when a filename header is rendered above.
        borderRadius: hasFilename
          ? '0 0 var(--radius-s) var(--radius-s)'
          : 'var(--radius-s)',
        margin: 0,
        padding: 'var(--size-4-2)',
        fontSize: 'var(--font-ui-small)',
        // Markdown reads better in the UI font; code keeps the mono default.
        fontFamily:
          language === 'markdown' ? 'var(--font-interface)' : 'inherit',
      }}
      wrapLines={wrapLines}
      lineProps={
        // Wrapping should work without lineProps, but Obsidian's default CSS seems to override SyntaxHighlighter's styles.
        // We manually override the white-space property to ensure proper wrapping.
        wrapLines
          ? {
              style: { whiteSpace: 'pre-wrap' },
            }
          : undefined
      }
    >
      {children}
    </SyntaxHighlighter>
  )
}

// Memoized: highlighting is expensive and props are stable during streaming.
export const MemoizedSyntaxHighlighterWrapper = memo(SyntaxHighlighterWrapper)

View File

@@ -0,0 +1,374 @@
import { useQuery } from '@tanstack/react-query'
import { $nodesOfType, LexicalEditor, SerializedEditorState } from 'lexical'
import {
forwardRef,
useCallback,
useEffect,
useImperativeHandle,
useMemo,
useRef,
useState,
} from 'react'
import { useApp } from '../../../contexts/AppContext'
import { useDarkModeContext } from '../../../contexts/DarkModeContext'
import {
Mentionable,
MentionableImage,
SerializedMentionable,
} from '../../../types/mentionable'
import { fileToMentionableImage } from '../../../utils/image'
import {
deserializeMentionable,
getMentionableKey,
serializeMentionable,
} from '../../../utils/mentionable'
import { openMarkdownFile, readTFileContent } from '../../../utils/obsidian'
import { MemoizedSyntaxHighlighterWrapper } from '../SyntaxHighlighterWrapper'
import { ImageUploadButton } from './ImageUploadButton'
import LexicalContentEditable from './LexicalContentEditable'
import MentionableBadge from './MentionableBadge'
import { ModelSelect } from './ModelSelect'
import { MentionNode } from './plugins/mention/MentionNode'
import { NodeMutations } from './plugins/on-mutation/OnMutationPlugin'
import { SubmitButton } from './SubmitButton'
import { VaultChatButton } from './VaultChatButton'
// Imperative handle exposed by ChatUserInput (focuses the content editable).
export type ChatUserInputRef = {
  focus: () => void
}

export type ChatUserInputProps = {
  // Editor state to restore when reopening a saved conversation; null starts empty.
  initialSerializedEditorState: SerializedEditorState | null
  onChange: (content: SerializedEditorState) => void
  // useVaultSearch=true submits with RAG over the whole vault.
  onSubmit: (content: SerializedEditorState, useVaultSearch?: boolean) => void
  onFocus: () => void
  // Mentioned files/folders/images attached to the message (controlled by parent).
  mentionables: Mentionable[]
  setMentionables: (mentionables: Mentionable[]) => void
  autoFocus?: boolean
  // Key of a mentionable just added externally; its preview is shown.
  addedBlockKey?: string | null
}
/**
 * Chat message composer: a lexical rich-text input plus mentionable badges,
 * a content preview, model select, image upload, and submit controls.
 *
 * Keeps the parent-owned `mentionables` list in sync with the MentionNodes
 * present in the editor (nodes created/destroyed by typing, paste, or badge
 * deletion), and tracks which mentionable's preview is displayed.
 */
const ChatUserInput = forwardRef<ChatUserInputRef, ChatUserInputProps>(
  (
    {
      initialSerializedEditorState,
      onChange,
      onSubmit,
      onFocus,
      mentionables,
      setMentionables,
      autoFocus = false,
      addedBlockKey,
    },
    ref,
  ) => {
    const app = useApp()
    const editorRef = useRef<LexicalEditor | null>(null)
    const contentEditableRef = useRef<HTMLDivElement>(null)
    const containerRef = useRef<HTMLDivElement>(null)
    // Key of the mentionable whose content preview is currently shown.
    const [displayedMentionableKey, setDisplayedMentionableKey] = useState<
      string | null
    >(addedBlockKey ?? null)
    useEffect(() => {
      if (addedBlockKey) {
        setDisplayedMentionableKey(addedBlockKey)
      }
    }, [addedBlockKey])
    useImperativeHandle(ref, () => ({
      focus: () => {
        contentEditableRef.current?.focus()
      },
    }))

    // Mirrors MentionNode create/destroy mutations into the mentionables list.
    const handleMentionNodeMutation = (
      mutations: NodeMutations<MentionNode>,
    ) => {
      const destroyedMentionableKeys: string[] = []
      const addedMentionables: SerializedMentionable[] = []
      mutations.forEach((mutation) => {
        const mentionable = mutation.node.getMentionable()
        const mentionableKey = getMentionableKey(mentionable)
        if (mutation.mutation === 'destroyed') {
          // The same mentionable can appear in multiple nodes; only drop it
          // when no surviving node still references it.
          const nodeWithSameMentionable = editorRef.current?.read(() =>
            $nodesOfType(MentionNode).find(
              (node) =>
                getMentionableKey(node.getMentionable()) === mentionableKey,
            ),
          )
          if (!nodeWithSameMentionable) {
            // remove mentionable only if it's not present in the editor state
            destroyedMentionableKeys.push(mentionableKey)
          }
        } else if (mutation.mutation === 'created') {
          if (
            mentionables.some(
              (m) =>
                getMentionableKey(serializeMentionable(m)) === mentionableKey,
            ) ||
            addedMentionables.some(
              (m) => getMentionableKey(m) === mentionableKey,
            )
          ) {
            // do nothing if mentionable is already added
            return
          }
          addedMentionables.push(mentionable)
        }
      })
      setMentionables(
        mentionables
          .filter(
            (m) =>
              !destroyedMentionableKeys.includes(
                getMentionableKey(serializeMentionable(m)),
              ),
          )
          .concat(
            addedMentionables
              .map((m) => deserializeMentionable(m, app))
              .filter((v) => !!v),
          ),
      )
      if (addedMentionables.length > 0) {
        // Preview the most recently added mentionable.
        setDisplayedMentionableKey(
          getMentionableKey(addedMentionables[addedMentionables.length - 1]),
        )
      }
    }

    // Adds pasted/dropped images as mentionables, skipping duplicates,
    // and previews the last one added.
    const handleCreateImageMentionables = useCallback(
      (mentionableImages: MentionableImage[]) => {
        const newMentionableImages = mentionableImages.filter(
          (m) =>
            !mentionables.some(
              (mentionable) =>
                getMentionableKey(serializeMentionable(mentionable)) ===
                getMentionableKey(serializeMentionable(m)),
            ),
        )
        if (newMentionableImages.length === 0) return
        setMentionables([...mentionables, ...newMentionableImages])
        setDisplayedMentionableKey(
          getMentionableKey(
            serializeMentionable(
              newMentionableImages[newMentionableImages.length - 1],
            ),
          ),
        )
      },
      [mentionables, setMentionables],
    )

    // Removes a mentionable from both the list and the editor's MentionNodes.
    const handleMentionableDelete = (mentionable: Mentionable) => {
      const mentionableKey = getMentionableKey(
        serializeMentionable(mentionable),
      )
      setMentionables(
        mentionables.filter(
          (m) => getMentionableKey(serializeMentionable(m)) !== mentionableKey,
        ),
      )
      editorRef.current?.update(() => {
        $nodesOfType(MentionNode).forEach((node) => {
          if (getMentionableKey(node.getMentionable()) === mentionableKey) {
            node.remove()
          }
        })
      })
    }

    const handleUploadImages = async (images: File[]) => {
      const mentionableImages = await Promise.all(
        images.map((image) => fileToMentionableImage(image)),
      )
      handleCreateImageMentionables(mentionableImages)
    }

    // Submits the current editor state; no-op when the editor is empty/unset.
    const handleSubmit = (options: { useVaultSearch?: boolean } = {}) => {
      const content = editorRef.current?.getEditorState()?.toJSON()
      content && onSubmit(content, options.useVaultSearch)
    }

    return (
      <div className="infio-chat-user-input-container" ref={containerRef}>
        {mentionables.length > 0 && (
          <div className="infio-chat-user-input-files">
            {mentionables.map((m) => (
              <MentionableBadge
                key={getMentionableKey(serializeMentionable(m))}
                mentionable={m}
                onDelete={() => handleMentionableDelete(m)}
                onClick={() => {
                  const mentionableKey = getMentionableKey(
                    serializeMentionable(m),
                  )
                  if (
                    (m.type === 'current-file' ||
                      m.type === 'file' ||
                      m.type === 'block') &&
                    m.file &&
                    mentionableKey === displayedMentionableKey
                  ) {
                    // open file on click again
                    openMarkdownFile(
                      app,
                      m.file.path,
                      m.type === 'block' ? m.startLine : undefined,
                    )
                  } else {
                    setDisplayedMentionableKey(mentionableKey)
                  }
                }}
                isFocused={
                  getMentionableKey(serializeMentionable(m)) ===
                  displayedMentionableKey
                }
              />
            ))}
          </div>
        )}
        <MentionableContentPreview
          displayedMentionableKey={displayedMentionableKey}
          mentionables={mentionables}
        />
        <LexicalContentEditable
          initialEditorState={(editor) => {
            if (initialSerializedEditorState) {
              editor.setEditorState(
                editor.parseEditorState(initialSerializedEditorState),
              )
            }
          }}
          editorRef={editorRef}
          contentEditableRef={contentEditableRef}
          onChange={onChange}
          onEnter={() => handleSubmit({ useVaultSearch: false })}
          onFocus={onFocus}
          onMentionNodeMutation={handleMentionNodeMutation}
          onCreateImageMentionables={handleCreateImageMentionables}
          autoFocus={autoFocus}
          plugins={{
            onEnter: {
              onVaultChat: () => {
                handleSubmit({ useVaultSearch: true })
              },
            },
            templatePopover: {
              anchorElement: containerRef.current,
            },
          }}
        />
        <div className="infio-chat-user-input-controls">
          <div className="infio-chat-user-input-controls__model-select-container">
            <ModelSelect />
            <ImageUploadButton onUpload={handleUploadImages} />
          </div>
          <div className="infio-chat-user-input-controls__buttons">
            <SubmitButton onClick={() => handleSubmit()} />
            {/* <VaultChatButton
              onClick={() => {
                handleSubmit({ useVaultSearch: true })
              }}
            /> */}
          </div>
        </div>
      </div>
    )
  },
)
/**
 * Preview pane under the badge row: shows the focused mentionable's file
 * contents (or the selected line range for 'block' mentionables) as
 * highlighted markdown, or the image itself for image mentionables.
 * Renders nothing when there is nothing to preview.
 */
function MentionableContentPreview({
  displayedMentionableKey,
  mentionables,
}: {
  displayedMentionableKey: string | null
  mentionables: Mentionable[]
}) {
  const app = useApp()
  const { isDarkMode } = useDarkModeContext()
  const displayedMentionable: Mentionable | null = useMemo(() => {
    return (
      mentionables.find(
        (m) =>
          getMentionableKey(serializeMentionable(m)) ===
          displayedMentionableKey,
      ) ?? null
    )
  }, [displayedMentionableKey, mentionables])

  // File content is fetched lazily and only for file-like mentionables.
  const { data: displayFileContent } = useQuery({
    enabled:
      !!displayedMentionable &&
      ['file', 'current-file', 'block'].includes(displayedMentionable.type),
    queryKey: [
      'file',
      displayedMentionableKey,
      mentionables.map((m) => getMentionableKey(serializeMentionable(m))), // should be updated when mentionables change (especially on delete)
    ],
    queryFn: async () => {
      if (!displayedMentionable) return null
      if (
        displayedMentionable.type === 'file' ||
        displayedMentionable.type === 'current-file'
      ) {
        if (!displayedMentionable.file) return null
        return await readTFileContent(displayedMentionable.file, app.vault)
      } else if (displayedMentionable.type === 'block') {
        const fileContent = await readTFileContent(
          displayedMentionable.file,
          app.vault,
        )
        // 1-based inclusive range -> slice(start-1, end).
        return fileContent
          .split('\n')
          .slice(
            displayedMentionable.startLine - 1,
            displayedMentionable.endLine,
          )
          .join('\n')
      }
      return null
    },
  })

  const displayImage: MentionableImage | null = useMemo(() => {
    return displayedMentionable?.type === 'image' ? displayedMentionable : null
  }, [displayedMentionable])

  return displayFileContent ? (
    <div className="infio-chat-user-input-file-content-preview">
      <MemoizedSyntaxHighlighterWrapper
        isDarkMode={isDarkMode}
        language="markdown"
        hasFilename={false}
        wrapLines={false}
      >
        {displayFileContent}
      </MemoizedSyntaxHighlighterWrapper>
    </div>
  ) : displayImage ? (
    <div className="infio-chat-user-input-file-content-preview">
      <img src={displayImage.data} alt={displayImage.name} />
    </div>
  ) : null
}

ChatUserInput.displayName = 'ChatUserInput'

export default ChatUserInput

View File

@@ -0,0 +1,30 @@
import { ImageIcon } from 'lucide-react'
/**
 * Label-styled button that opens a file picker restricted to images and
 * forwards the selected files to `onUpload`.
 */
export function ImageUploadButton({
  onUpload,
}: {
  onUpload: (files: File[]) => void
}) {
  const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    const files = Array.from(event.target.files ?? [])
    // FIX: clear the input's value so choosing the same file(s) again fires
    // another change event — browsers suppress it when the value is unchanged.
    event.target.value = ''
    if (files.length > 0) {
      onUpload(files)
    }
  }
  return (
    <label className="infio-chat-user-input-submit-button">
      <input
        type="file"
        accept="image/*"
        multiple
        onChange={handleFileChange}
        style={{ display: 'none' }}
      />
      <div className="infio-chat-user-input-submit-button-icons">
        <ImageIcon size={12} />
      </div>
      <div>Image</div>
    </label>
  )
}

View File

@@ -0,0 +1,153 @@
import {
InitialConfigType,
InitialEditorStateType,
LexicalComposer,
} from '@lexical/react/LexicalComposer'
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
import { EditorRefPlugin } from '@lexical/react/LexicalEditorRefPlugin'
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin'
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin'
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
import { LexicalEditor, SerializedEditorState } from 'lexical'
import { RefObject, useCallback, useEffect } from 'react'
import { useApp } from '../../../contexts/AppContext'
import { MentionableImage } from '../../../types/mentionable'
import { fuzzySearch } from '../../../utils/fuzzy-search'
import DragDropPaste from './plugins/image/DragDropPastePlugin'
import ImagePastePlugin from './plugins/image/ImagePastePlugin'
import AutoLinkMentionPlugin from './plugins/mention/AutoLinkMentionPlugin'
import { MentionNode } from './plugins/mention/MentionNode'
import MentionPlugin from './plugins/mention/MentionPlugin'
import NoFormatPlugin from './plugins/no-format/NoFormatPlugin'
import OnEnterPlugin from './plugins/on-enter/OnEnterPlugin'
import OnMutationPlugin, {
NodeMutations,
} from './plugins/on-mutation/OnMutationPlugin'
import CreateTemplatePopoverPlugin from './plugins/template/CreateTemplatePopoverPlugin'
import TemplatePlugin from './plugins/template/TemplatePlugin'
export type LexicalContentEditableProps = {
  // Receives the created LexicalEditor instance (via EditorRefPlugin).
  editorRef: RefObject<LexicalEditor>
  // Receives the underlying contenteditable div (for focus management).
  contentEditableRef: RefObject<HTMLDivElement>
  onChange?: (content: SerializedEditorState) => void
  onEnter?: (evt: KeyboardEvent) => void
  onFocus?: () => void
  // Fired when MentionNodes are created/destroyed in the editor.
  onMentionNodeMutation?: (mutations: NodeMutations<MentionNode>) => void
  // Fired when images are pasted or drag-and-dropped into the editor.
  onCreateImageMentionables?: (mentionables: MentionableImage[]) => void
  initialEditorState?: InitialEditorStateType
  autoFocus?: boolean
  // Optional per-plugin configuration; a plugin is mounted only when its
  // entry is provided (e.g. templatePopover needs an anchor element).
  plugins?: {
    onEnter?: {
      onVaultChat: () => void
    }
    templatePopover?: {
      anchorElement: HTMLElement | null
    }
  }
}
/**
 * Shared lexical editor used by the chat input and template dialog. Wires up
 * the plugin stack: mentions (with fuzzy search), image paste/drop, enter
 * handling, mutation observation, templates, and format stripping on paste.
 */
export default function LexicalContentEditable({
  editorRef,
  contentEditableRef,
  onChange,
  onEnter,
  onFocus,
  onMentionNodeMutation,
  onCreateImageMentionables,
  initialEditorState,
  autoFocus = false,
  plugins,
}: LexicalContentEditableProps) {
  const app = useApp()

  const initialConfig: InitialConfigType = {
    namespace: 'LexicalContentEditable',
    theme: {
      root: 'infio-chat-lexical-content-editable-root',
      paragraph: 'infio-chat-lexical-content-editable-paragraph',
    },
    nodes: [MentionNode],
    editorState: initialEditorState,
    onError: (error) => {
      console.error(error)
    },
  }

  // Fuzzy-searches the vault for mention suggestions.
  const searchResultByQuery = useCallback(
    (query: string) => fuzzySearch(app, query),
    [app],
  )

  /*
   * Using requestAnimationFrame for autoFocus instead of using editor.focus()
   * due to known issues with editor.focus() when initialConfig.editorState is set
   * See: https://github.com/facebook/lexical/issues/4460
   */
  useEffect(() => {
    if (autoFocus) {
      requestAnimationFrame(() => {
        contentEditableRef.current?.focus()
      })
    }
  }, [autoFocus, contentEditableRef])

  return (
    <LexicalComposer initialConfig={initialConfig}>
      {/*
        There was two approach to make mentionable node copy and pasteable.
        1. use RichTextPlugin and reset text format when paste
          - so I implemented NoFormatPlugin to reset text format when paste
        2. use PlainTextPlugin and override paste command
          - PlainTextPlugin only pastes text, so we need to implement custom paste handler.
          - https://github.com/facebook/lexical/discussions/5112
      */}
      <RichTextPlugin
        contentEditable={
          <ContentEditable
            className="obsidian-default-textarea"
            style={{
              background: 'transparent',
            }}
            onFocus={onFocus}
            ref={contentEditableRef}
          />
        }
        ErrorBoundary={LexicalErrorBoundary}
      />
      <HistoryPlugin />
      <MentionPlugin searchResultByQuery={searchResultByQuery} />
      <OnChangePlugin
        onChange={(editorState) => {
          onChange?.(editorState.toJSON())
        }}
      />
      {onEnter && (
        <OnEnterPlugin
          onEnter={onEnter}
          onVaultChat={plugins?.onEnter?.onVaultChat}
        />
      )}
      <OnMutationPlugin
        nodeClass={MentionNode}
        onMutation={(mutations) => {
          onMentionNodeMutation?.(mutations)
        }}
      />
      <EditorRefPlugin editorRef={editorRef} />
      <NoFormatPlugin />
      <AutoLinkMentionPlugin />
      <ImagePastePlugin onCreateImageMentionables={onCreateImageMentionables} />
      <DragDropPaste onCreateImageMentionables={onCreateImageMentionables} />
      <TemplatePlugin />
      {plugins?.templatePopover && (
        <CreateTemplatePopoverPlugin
          anchorElement={plugins.templatePopover.anchorElement}
          contentEditableElement={contentEditableRef.current}
        />
      )}
    </LexicalComposer>
  )
}

View File

@@ -0,0 +1,319 @@
import { X } from 'lucide-react'
import { PropsWithChildren } from 'react'
import {
Mentionable,
MentionableBlock,
MentionableCurrentFile,
MentionableFile,
MentionableFolder,
MentionableImage,
MentionableUrl,
MentionableVault,
} from '../../../types/mentionable'
import { getMentionableIcon } from './utils/get-metionable-icon'
// Shared badge chrome: focus styling, click-to-preview, and a stop-propagating
// delete (X) control used by every mentionable badge variant below.
function BadgeBase({
  children,
  onDelete,
  onClick,
  isFocused,
}: PropsWithChildren<{
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}>) {
  return (
    <div
      className={`infio-chat-user-input-file-badge ${isFocused ? 'infio-chat-user-input-file-badge-focused' : ''}`}
      onClick={onClick}
    >
      {children}
      <div
        className="infio-chat-user-input-file-badge-delete"
        onClick={(evt) => {
          // Keep the delete click from also triggering the badge's onClick.
          evt.stopPropagation()
          onDelete()
        }}
      >
        <X size={10} />
      </div>
    </div>
  )
}

// Badge for a single mentioned file: type icon + file name.
function FileBadge({
  mentionable,
  onDelete,
  onClick,
  isFocused,
}: {
  mentionable: MentionableFile
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}) {
  const Icon = getMentionableIcon(mentionable)
  return (
    <BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
      <div className="infio-chat-user-input-file-badge-name">
        {Icon && (
          <Icon
            size={10}
            className="infio-chat-user-input-file-badge-name-icon"
          />
        )}
        <span>{mentionable.file.name}</span>
      </div>
    </BadgeBase>
  )
}

// Badge for a mentioned folder: type icon + folder name.
function FolderBadge({
  mentionable,
  onDelete,
  onClick,
  isFocused,
}: {
  mentionable: MentionableFolder
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}) {
  const Icon = getMentionableIcon(mentionable)
  return (
    <BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
      <div className="infio-chat-user-input-file-badge-name">
        {Icon && (
          <Icon
            size={10}
            className="infio-chat-user-input-file-badge-name-icon"
          />
        )}
        <span>{mentionable.folder.name}</span>
      </div>
    </BadgeBase>
  )
}

// Badge for whole-vault mentions; label is fixed so the prop is unused.
function VaultBadge({
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  mentionable,
  onDelete,
  onClick,
  isFocused,
}: {
  mentionable: MentionableVault
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}) {
  const Icon = getMentionableIcon(mentionable)
  return (
    <BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
      {/* TODO: Update style */}
      <div className="infio-chat-user-input-file-badge-name">
        {Icon && (
          <Icon
            size={10}
            className="infio-chat-user-input-file-badge-name-icon"
          />
        )}
        <span>Vault</span>
      </div>
    </BadgeBase>
  )
}

// Badge for the active editor file; hidden when no file is open.
function CurrentFileBadge({
  mentionable,
  onDelete,
  onClick,
  isFocused,
}: {
  mentionable: MentionableCurrentFile
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}) {
  const Icon = getMentionableIcon(mentionable)
  return mentionable.file ? (
    <BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
      <div className="infio-chat-user-input-file-badge-name">
        {Icon && (
          <Icon
            size={10}
            className="infio-chat-user-input-file-badge-name-icon"
          />
        )}
        <span>{mentionable.file.name}</span>
      </div>
      <div className="infio-chat-user-input-file-badge-name-block-suffix">
        {' (Current File)'}
      </div>
    </BadgeBase>
  ) : null
}

// Badge for a line-range selection: file name plus (start:end) suffix.
function BlockBadge({
  mentionable,
  onDelete,
  onClick,
  isFocused,
}: {
  mentionable: MentionableBlock
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}) {
  const Icon = getMentionableIcon(mentionable)
  return (
    <BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
      <div className="infio-chat-user-input-file-badge-name-block-name">
        {Icon && (
          <Icon
            size={10}
            className="infio-chat-user-input-file-badge-name-block-name-icon"
          />
        )}
        <span>{mentionable.file.name}</span>
      </div>
      <div className="infio-chat-user-input-file-badge-name-block-suffix">
        {` (${mentionable.startLine}:${mentionable.endLine})`}
      </div>
    </BadgeBase>
  )
}

// Badge for a mentioned URL; shows the raw URL text.
function UrlBadge({
  mentionable,
  onDelete,
  onClick,
  isFocused,
}: {
  mentionable: MentionableUrl
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}) {
  const Icon = getMentionableIcon(mentionable)
  return (
    <BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
      <div className="infio-chat-user-input-file-badge-name">
        {Icon && (
          <Icon
            size={10}
            className="infio-chat-user-input-file-badge-name-icon"
          />
        )}
        <span>{mentionable.url}</span>
      </div>
    </BadgeBase>
  )
}

// Badge for an attached image; shows the image's display name.
function ImageBadge({
  mentionable,
  onDelete,
  onClick,
  isFocused,
}: {
  mentionable: MentionableImage
  onDelete: () => void
  onClick: () => void
  isFocused: boolean
}) {
  const Icon = getMentionableIcon(mentionable)
  return (
    <BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
      <div className="infio-chat-user-input-file-badge-name">
        {Icon && (
          <Icon
            size={10}
            className="infio-chat-user-input-file-badge-name-icon"
          />
        )}
        <span>{mentionable.name}</span>
      </div>
    </BadgeBase>
  )
}
export default function MentionableBadge({
mentionable,
onDelete,
onClick,
isFocused = false,
}: {
mentionable: Mentionable
onDelete: () => void
onClick: () => void
isFocused?: boolean
}) {
switch (mentionable.type) {
case 'file':
return (
<FileBadge
mentionable={mentionable}
onDelete={onDelete}
onClick={onClick}
isFocused={isFocused}
/>
)
case 'folder':
return (
<FolderBadge
mentionable={mentionable}
onDelete={onDelete}
onClick={onClick}
isFocused={isFocused}
/>
)
case 'vault':
return (
<VaultBadge
mentionable={mentionable}
onDelete={onDelete}
onClick={onClick}
isFocused={isFocused}
/>
)
case 'current-file':
return (
<CurrentFileBadge
mentionable={mentionable}
onDelete={onDelete}
onClick={onClick}
isFocused={isFocused}
/>
)
case 'block':
return (
<BlockBadge
mentionable={mentionable}
onDelete={onDelete}
onClick={onClick}
isFocused={isFocused}
/>
)
case 'url':
return (
<UrlBadge
mentionable={mentionable}
onDelete={onDelete}
onClick={onClick}
isFocused={isFocused}
/>
)
case 'image':
return (
<ImageBadge
mentionable={mentionable}
onDelete={onDelete}
onClick={onClick}
isFocused={isFocused}
/>
)
}
}

View File

@@ -0,0 +1,51 @@
import * as DropdownMenu from '@radix-ui/react-dropdown-menu'
import { ChevronDown, ChevronUp } from 'lucide-react'
import { useState } from 'react'
import { useSettings } from '../../../contexts/SettingsContext'
// Dropdown selector for the chat model, backed by plugin settings.
export function ModelSelect() {
  const { settings, setSettings } = useSettings()
  // Open state is tracked only so the trigger can flip its chevron icon.
  const [isOpen, setIsOpen] = useState(false)
  // Only models the user has enabled are offered.
  const activeModels = settings.activeModels.filter((model) => model.enabled)
  return (
    <DropdownMenu.Root open={isOpen} onOpenChange={setIsOpen}>
      <DropdownMenu.Trigger className="infio-chat-input-model-select">
        <div className="infio-chat-input-model-select__icon">
          {isOpen ? <ChevronUp size={12} /> : <ChevronDown size={12} />}
        </div>
        <div className="infio-chat-input-model-select__model-name">
          {
            // NOTE(review): compares each model's *name* with settings.chatModelId —
            // presumably chatModelId stores the name; confirm this is intended.
            activeModels.find(
              (option) => option.name === settings.chatModelId,
            )?.name
          }
        </div>
      </DropdownMenu.Trigger>
      <DropdownMenu.Portal>
        <DropdownMenu.Content
          className="infio-popover">
          <ul>
            {activeModels.map((model) => (
              <DropdownMenu.Item
                key={model.name}
                onSelect={() => {
                  // Persist the chosen model in settings.
                  setSettings({
                    ...settings,
                    chatModelId: model.name,
                  })
                }}
                asChild
              >
                <li>{model.name}</li>
              </DropdownMenu.Item>
            ))}
          </ul>
        </DropdownMenu.Content>
      </DropdownMenu.Portal>
    </DropdownMenu.Root>
  )
}

View File

@@ -0,0 +1,12 @@
import { CornerDownLeftIcon } from 'lucide-react'
export function SubmitButton({ onClick }: { onClick: () => void }) {
return (
<button className="infio-chat-user-input-submit-button" onClick={onClick}>
<div>submit</div>
<div className="infio-chat-user-input-submit-button-icons">
<CornerDownLeftIcon size={12} />
</div>
</button>
)
}

View File

@@ -0,0 +1,42 @@
import * as Tooltip from '@radix-ui/react-tooltip'
import {
ArrowBigUp,
ChevronUp,
Command,
CornerDownLeftIcon,
} from 'lucide-react'
import { Platform } from 'obsidian'
// Button that triggers a vault-wide chat. The icons hint at the platform
// shortcut (Command glyph on macOS, chevron elsewhere) followed by Enter.
export function VaultChatButton({ onClick }: { onClick: () => void }) {
  const modifierIcon = Platform.isMacOS ? (
    <Command size={10} />
  ) : (
    <ChevronUp size={12} />
  )
  return (
    <Tooltip.Provider delayDuration={0}>
      <Tooltip.Root>
        <Tooltip.Trigger asChild>
          <button
            className="infio-chat-user-input-vault-button"
            onClick={onClick}
          >
            <div>vault</div>
            <div className="infio-chat-user-input-vault-button-icons">
              {modifierIcon}
              {/* TODO: Replace with a custom icon */}
              {/* <ArrowBigUp size={12} /> */}
              <CornerDownLeftIcon size={12} />
            </div>
          </button>
        </Tooltip.Trigger>
        <Tooltip.Portal>
          <Tooltip.Content className="infio-tooltip-content" sideOffset={5}>
            Chat with your entire vault
          </Tooltip.Content>
        </Tooltip.Portal>
      </Tooltip.Root>
    </Tooltip.Provider>
  )
}

View File

@@ -0,0 +1,34 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { DRAG_DROP_PASTE } from '@lexical/rich-text'
import { COMMAND_PRIORITY_LOW } from 'lexical'
import { useEffect } from 'react'
import { MentionableImage } from '../../../../../types/mentionable'
import { fileToMentionableImage } from '../../../../../utils/image'
export default function DragDropPaste({
onCreateImageMentionables,
}: {
onCreateImageMentionables?: (mentionables: MentionableImage[]) => void
}): null {
const [editor] = useLexicalComposerContext()
useEffect(() => {
return editor.registerCommand(
DRAG_DROP_PASTE, // dispatched in RichTextPlugin
(files) => {
; (async () => {
const images = files.filter((file) => file.type.startsWith('image/'))
const mentionableImages = await Promise.all(
images.map(async (image) => await fileToMentionableImage(image)),
)
onCreateImageMentionables?.(mentionableImages)
})()
return true
},
COMMAND_PRIORITY_LOW,
)
}, [editor, onCreateImageMentionables])
return null
}

View File

@@ -0,0 +1,42 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { COMMAND_PRIORITY_LOW, PASTE_COMMAND, PasteCommandType } from 'lexical'
import { useEffect } from 'react'
import { MentionableImage } from '../../../../../types/mentionable'
import { fileToMentionableImage } from '../../../../../utils/image'
// Intercepts clipboard pastes containing image files and turns them into
// image mentionables; non-image pastes fall through to default handling.
export default function ImagePastePlugin({
  onCreateImageMentionables,
}: {
  onCreateImageMentionables?: (mentionables: MentionableImage[]) => void
}) {
  const [editor] = useLexicalComposerContext()
  useEffect(() => {
    return editor.registerCommand(
      PASTE_COMMAND,
      (event: PasteCommandType) => {
        if (!(event instanceof ClipboardEvent) || !event.clipboardData) {
          return false
        }
        const pastedImages = Array.from(event.clipboardData.files).filter(
          (file) => file.type.startsWith('image/'),
        )
        if (pastedImages.length === 0) {
          return false
        }
        Promise.all(pastedImages.map((image) => fileToMentionableImage(image))).then(
          (mentionableImages) => {
            onCreateImageMentionables?.(mentionableImages)
          },
        )
        return true
      },
      COMMAND_PRIORITY_LOW,
    )
  }, [editor, onCreateImageMentionables])
  return null
}

View File

@@ -0,0 +1,178 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import {
$createTextNode,
$getSelection,
$isRangeSelection,
COMMAND_PRIORITY_LOW,
PASTE_COMMAND,
PasteCommandType,
TextNode,
} from 'lexical'
import { useEffect } from 'react'
import { Mentionable, MentionableUrl } from '../../../../../types/mentionable'
import {
getMentionableName,
serializeMentionable,
} from '../../../../../utils/mentionable'
import { $createMentionNode } from './MentionNode'
const URL_MATCHER =
  /^((https?:\/\/(www\.)?)|(www\.))[-a-zA-Z0-9@:%._+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_+.~#?&//=]*)$/
type URLMatch = {
  index: number
  length: number
  text: string
  url: string
}
// Scans space-separated tokens and reports every one that looks like a URL,
// with its character offset into the original text. Tokens without an
// explicit scheme are given an https:// prefix.
function findURLs(text: string): URLMatch[] {
  const found: URLMatch[] = []
  let offset = 0
  text.split(' ').forEach((word) => {
    if (URL_MATCHER.test(word)) {
      found.push({
        index: offset,
        length: word.length,
        text: word,
        url: word.startsWith('http') ? word : `https://${word}`,
      })
    }
    // +1 accounts for the single space the token was split on.
    offset += word.length + 1
  })
  return found
}
// Node transform that replaces a URL typed into a simple TextNode with a
// mention node — unless the caret is still sitting at the end of the URL,
// in which case the user may still be typing it.
function $textNodeTransform(node: TextNode) {
  if (!node.isSimpleText()) {
    return
  }
  const text = node.getTextContent()
  // Find only 1st occurrence as transform will be re-run anyway for the rest
  // because newly inserted nodes are considered to be dirty
  const urlMatches = findURLs(text)
  if (urlMatches.length === 0) {
    return
  }
  const urlMatch = urlMatches[0]
  // Get the current selection
  const selection = $getSelection()
  // Check if the selection is a RangeSelection and the cursor is at the end of the URL
  if (
    $isRangeSelection(selection) &&
    selection.anchor.key === node.getKey() &&
    selection.focus.key === node.getKey() &&
    selection.anchor.offset === urlMatch.index + urlMatch.length &&
    selection.focus.offset === urlMatch.index + urlMatch.length
  ) {
    // If the cursor is at the end of the URL, don't transform
    return
  }
  // Split the text node so targetNode contains exactly the URL text.
  let targetNode
  if (urlMatch.index === 0) {
    // First text chunk within string, splitting into 2 parts
    ;[targetNode] = node.splitText(urlMatch.index + urlMatch.length)
  } else {
    // In the middle of a string
    ;[, targetNode] = node.splitText(
      urlMatch.index,
      urlMatch.index + urlMatch.length,
    )
  }
  // Swap the URL text for a mention node and restore the caret after it.
  const mentionable: MentionableUrl = {
    type: 'url',
    url: urlMatch.url,
  }
  const mentionNode = $createMentionNode(
    getMentionableName(mentionable),
    serializeMentionable(mentionable),
  )
  targetNode.replace(mentionNode)
  const spaceNode = $createTextNode(' ')
  mentionNode.insertAfter(spaceNode)
  spaceNode.select()
}
/**
 * PASTE_COMMAND handler: when pasted plain text contains URLs, inserts the
 * text as alternating TextNodes and mention nodes (one per URL) and consumes
 * the event. Returns false — letting the default paste proceed — when the
 * clipboard has no plain text URLs or the selection is not a range selection.
 *
 * Fix: removed the `addedMentionables` accumulator, which was populated but
 * never read anywhere in the function (dead code).
 */
function $handlePaste(event: PasteCommandType) {
  const clipboardData =
    event instanceof ClipboardEvent ? event.clipboardData : null
  if (!clipboardData) return false
  const text = clipboardData.getData('text/plain')
  const urlMatches = findURLs(text)
  if (urlMatches.length === 0) {
    return false
  }
  const selection = $getSelection()
  if (!$isRangeSelection(selection)) {
    return false
  }
  const nodes = []
  let lastIndex = 0
  urlMatches.forEach((urlMatch) => {
    // Add text node for unmatched part
    if (urlMatch.index > lastIndex) {
      nodes.push($createTextNode(text.slice(lastIndex, urlMatch.index)))
    }
    const mentionable: MentionableUrl = {
      type: 'url',
      url: urlMatch.url,
    }
    // Add mention node
    nodes.push(
      $createMentionNode(urlMatch.text, serializeMentionable(mentionable)),
    )
    lastIndex = urlMatch.index + urlMatch.length
    // Add space node after mention if next character is not space or end of string
    if (lastIndex >= text.length || text[lastIndex] !== ' ') {
      nodes.push($createTextNode(' '))
    }
  })
  // Add remaining text if any
  if (lastIndex < text.length) {
    nodes.push($createTextNode(text.slice(lastIndex)))
  }
  selection.insertNodes(nodes)
  return true
}
/**
 * Registers URL handling on the editor: a paste handler that converts pasted
 * URLs into mention nodes, and a text transform that converts typed URLs.
 *
 * Fix: the original discarded the unregister functions returned by
 * registerCommand/registerNodeTransform, leaking a command listener and a
 * node transform on unmount (or whenever `editor` changed). They are now
 * torn down in the effect cleanup.
 */
export default function AutoLinkMentionPlugin() {
  const [editor] = useLexicalComposerContext()
  useEffect(() => {
    const removePasteListener = editor.registerCommand(
      PASTE_COMMAND,
      $handlePaste,
      COMMAND_PRIORITY_LOW,
    )
    const removeTransform = editor.registerNodeTransform(
      TextNode,
      $textNodeTransform,
    )
    return () => {
      removePasteListener()
      removeTransform()
    }
  }, [editor])
  return null
}

View File

@@ -0,0 +1,176 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license.
* Original source: https://github.com/facebook/lexical
*
* Modified from the original code
*/
import {
$applyNodeReplacement,
DOMConversionMap,
DOMConversionOutput,
DOMExportOutput,
type EditorConfig,
type LexicalNode,
type NodeKey,
type SerializedTextNode,
type Spread,
TextNode,
} from 'lexical'
import { SerializedMentionable } from '../../../../../types/mentionable'
// Node type tag plus the data-* attributes used to round-trip mentions
// through DOM export/import.
export const MENTION_NODE_TYPE = 'mention'
export const MENTION_NODE_ATTRIBUTE = 'data-lexical-mention'
export const MENTION_NODE_MENTION_NAME_ATTRIBUTE = 'data-lexical-mention-name'
export const MENTION_NODE_MENTIONABLE_ATTRIBUTE = 'data-lexical-mentionable'
// Serialized form of a MentionNode: the base TextNode fields plus the
// mention's display name and payload.
export type SerializedMentionNode = Spread<
  {
    mentionName: string
    mentionable: SerializedMentionable
  },
  SerializedTextNode
>
/**
 * DOM → Lexical conversion for mention <span> elements produced by
 * exportDOM. Reads the mention name and serialized mentionable payload from
 * the element's data attributes.
 *
 * Fix: JSON.parse of the mentionable attribute was unguarded, so malformed
 * attribute JSON (hand-edited or foreign HTML) would throw and abort the
 * whole DOM import. Parse failures now skip the element by returning null.
 */
function $convertMentionElement(
  domNode: HTMLElement,
): DOMConversionOutput | null {
  const textContent = domNode.textContent
  const mentionName =
    domNode.getAttribute(MENTION_NODE_MENTION_NAME_ATTRIBUTE) ??
    domNode.textContent ??
    ''
  let mentionable: SerializedMentionable
  try {
    mentionable = JSON.parse(
      domNode.getAttribute(MENTION_NODE_MENTIONABLE_ATTRIBUTE) ?? '{}',
    ) as SerializedMentionable
  } catch {
    return null
  }
  if (textContent !== null) {
    const node = $createMentionNode(mentionName, mentionable)
    return {
      node,
    }
  }
  return null
}
/**
 * Lexical TextNode subclass representing an inline "@mention". Carries the
 * serialized mentionable payload so it survives JSON and DOM round-trips.
 */
export class MentionNode extends TextNode {
  // Display name (rendered as `@<name>`).
  __mentionName: string
  // Serialized mention payload attached to this node.
  __mentionable: SerializedMentionable
  static getType(): string {
    return MENTION_NODE_TYPE
  }
  static clone(node: MentionNode): MentionNode {
    return new MentionNode(node.__mentionName, node.__mentionable, node.__key)
  }
  // Rebuild a node from its serialized JSON form, restoring text state.
  static importJSON(serializedNode: SerializedMentionNode): MentionNode {
    const node = $createMentionNode(
      serializedNode.mentionName,
      serializedNode.mentionable,
    )
    node.setTextContent(serializedNode.text)
    node.setFormat(serializedNode.format)
    node.setDetail(serializedNode.detail)
    node.setMode(serializedNode.mode)
    node.setStyle(serializedNode.style)
    return node
  }
  constructor(
    mentionName: string,
    mentionable: SerializedMentionable,
    key?: NodeKey,
  ) {
    // Visible text is the mention name prefixed with '@'.
    super(`@${mentionName}`, key)
    this.__mentionName = mentionName
    this.__mentionable = mentionable
  }
  exportJSON(): SerializedMentionNode {
    return {
      ...super.exportJSON(),
      mentionName: this.__mentionName,
      mentionable: this.__mentionable,
      type: MENTION_NODE_TYPE,
      version: 1,
    }
  }
  createDOM(config: EditorConfig): HTMLElement {
    const dom = super.createDOM(config)
    dom.className = MENTION_NODE_TYPE
    return dom
  }
  // Export as a <span> carrying the mention data attributes so importDOM
  // can reconstruct the node from exported/pasted HTML.
  exportDOM(): DOMExportOutput {
    const element = document.createElement('span')
    element.setAttribute(MENTION_NODE_ATTRIBUTE, 'true')
    element.setAttribute(
      MENTION_NODE_MENTION_NAME_ATTRIBUTE,
      this.__mentionName,
    )
    element.setAttribute(
      MENTION_NODE_MENTIONABLE_ATTRIBUTE,
      JSON.stringify(this.__mentionable),
    )
    element.textContent = this.__text
    return { element }
  }
  static importDOM(): DOMConversionMap | null {
    return {
      span: (domNode: HTMLElement) => {
        // Only claim spans that carry all three mention attributes.
        if (
          !domNode.hasAttribute(MENTION_NODE_ATTRIBUTE) ||
          !domNode.hasAttribute(MENTION_NODE_MENTION_NAME_ATTRIBUTE) ||
          !domNode.hasAttribute(MENTION_NODE_MENTIONABLE_ATTRIBUTE)
        ) {
          return null
        }
        return {
          conversion: $convertMentionElement,
          priority: 1,
        }
      },
    }
  }
  isTextEntity(): true {
    return true
  }
  // Typing directly before/after a mention must not extend its text.
  canInsertTextBefore(): boolean {
    return false
  }
  canInsertTextAfter(): boolean {
    return false
  }
  getMentionable(): SerializedMentionable {
    return this.__mentionable
  }
}
export function $createMentionNode(
mentionName: string,
mentionable: SerializedMentionable,
): MentionNode {
const mentionNode = new MentionNode(mentionName, mentionable)
mentionNode.setMode('token').toggleDirectionless()
return $applyNodeReplacement(mentionNode)
}
// Type guard: true when the given node is a MentionNode.
export function $isMentionNode(
  node: LexicalNode | null | undefined,
): node is MentionNode {
  return node instanceof MentionNode
}

View File

@@ -0,0 +1,273 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license.
* Original source: https://github.com/facebook/lexical
*
* Modified from the original code
*/
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { $createTextNode, COMMAND_PRIORITY_NORMAL, TextNode } from 'lexical'
import { useCallback, useMemo, useState } from 'react'
import { createPortal } from 'react-dom'
import { Mentionable } from '../../../../../types/mentionable'
import { SearchableMentionable } from '../../../../../utils/fuzzy-search'
import {
getMentionableName,
serializeMentionable,
} from '../../../../../utils/mentionable'
import { getMentionableIcon } from '../../utils/get-metionable-icon'
import { MenuOption, MenuTextMatch } from '../shared/LexicalMenu'
import {
LexicalTypeaheadMenuPlugin,
useBasicTypeaheadTriggerMatch,
} from '../typeahead-menu/LexicalTypeaheadMenuPlugin'
import { $createMentionNode } from './MentionNode'
// Escaped punctuation characters that terminate or join mention queries.
const PUNCTUATION =
  '\\.,\\+\\*\\?\\$\\@\\|#{}\\(\\)\\^\\-\\[\\]\\\\/!%\'"~=<>_:;'
const NAME = '\\b[A-Z][^\\s' + PUNCTUATION + ']'
const DocumentMentionsRegex = {
  NAME,
  PUNCTUATION,
}
const PUNC = DocumentMentionsRegex.PUNCTUATION
// Characters that start a mention query.
const TRIGGERS = ['@'].join('')
// Chars we expect to see in a mention (non-space, non-punctuation).
const VALID_CHARS = '[^' + TRIGGERS + PUNC + '\\s]'
// Non-standard series of chars. Each series must be preceded and followed by
// a valid char.
const VALID_JOINS =
  '(?:' +
  '\\.[ |$]|' + // E.g. "r. " in "Mr. Smith"
  ' |' + // E.g. " " in "Josh Duck"
  '[' +
  PUNC +
  ']|' + // E.g. "-' in "Salier-Hellendag"
  ')'
// Maximum number of characters a mention query may span.
const LENGTH_LIMIT = 75
const AtSignMentionsRegex = new RegExp(
  `(^|\\s|\\()([${TRIGGERS}]((?:${VALID_CHARS}${VALID_JOINS}){0,${LENGTH_LIMIT}}))$`,
)
// 50 is the longest alias length limit.
const ALIAS_LENGTH_LIMIT = 50
// Regex used to match alias.
const AtSignMentionsRegexAliasRegex = new RegExp(
  `(^|\\s|\\()([${TRIGGERS}]((?:${VALID_CHARS}){0,${ALIAS_LENGTH_LIMIT}}))$`,
)
// At most, 20 suggestions are shown in the popup.
const SUGGESTION_LIST_LENGTH_LIMIT = 20
// Tries the primary mention regex, then the alias regex, against the text
// before the caret. Returns null when nothing matches or the query is
// shorter than the requested minimum.
function checkForAtSignMentions(
  text: string,
  minMatchLength: number,
): MenuTextMatch | null {
  const match =
    AtSignMentionsRegex.exec(text) ?? AtSignMentionsRegexAliasRegex.exec(text)
  if (match === null) {
    return null
  }
  // match[1] is any leading whitespace/paren; its length shifts the offset.
  const maybeLeadingWhitespace = match[1]
  const matchingString = match[3]
  if (matchingString.length < minMatchLength) {
    return null
  }
  return {
    leadOffset: match.index + maybeLeadingWhitespace.length,
    matchingString,
    replaceableString: match[2],
  }
}
// Entry point used by the typeahead: zero-length minimum so the menu opens
// as soon as '@' is typed.
function getPossibleQueryMatch(text: string): MenuTextMatch | null {
  return checkForAtSignMentions(text, 0)
}
// Menu option wrapping a search result. The option key is the unique path
// ('vault' for the whole-vault entry); `name` is the display label.
class MentionTypeaheadOption extends MenuOption {
  name: string
  mentionable: Mentionable
  icon: React.ReactNode
  constructor(result: SearchableMentionable) {
    // super() is called inside each branch because the key differs per variant.
    switch (result.type) {
      case 'file':
        super(result.file.path)
        this.name = result.file.name
        this.mentionable = result
        break
      case 'folder':
        super(result.folder.path)
        this.name = result.folder.name
        this.mentionable = result
        break
      case 'vault':
        super('vault')
        this.name = 'Vault'
        this.mentionable = result
        break
    }
  }
}
function MentionsTypeaheadMenuItem({
index,
isSelected,
onClick,
onMouseEnter,
option,
}: {
index: number
isSelected: boolean
onClick: () => void
onMouseEnter: () => void
option: MentionTypeaheadOption
}) {
let className = 'item'
if (isSelected) {
className += ' selected'
}
const Icon = getMentionableIcon(option.mentionable)
return (
<li
key={option.key}
tabIndex={-1}
className={className}
ref={(el) => option.setRefElement(el)}
role="option"
aria-selected={isSelected}
id={`typeahead-item-${index}`}
onMouseEnter={onMouseEnter}
onClick={onClick}
>
{Icon && <Icon size={14} className="infio-popover-item-icon" />}
<span className="text">{option.name}</span>
</li>
)
}
/**
 * Lexical plugin providing the '@' mention typeahead. Queries are resolved
 * through the injected `searchResultByQuery`; selecting an option replaces
 * the typed query with a mention node followed by a space.
 */
export default function NewMentionsPlugin({
  searchResultByQuery,
}: {
  searchResultByQuery: (query: string) => SearchableMentionable[]
}): JSX.Element | null {
  const [editor] = useLexicalComposerContext()
  // Current '@' query text, or null when no trigger is active.
  const [queryString, setQueryString] = useState<string | null>(null)
  const results = useMemo(() => {
    if (queryString == null) return []
    return searchResultByQuery(queryString)
  }, [queryString, searchResultByQuery])
  // Used only to suppress mentions while a '/' trigger is active.
  const checkForSlashTriggerMatch = useBasicTypeaheadTriggerMatch('/', {
    minLength: 0,
  })
  const options = useMemo(
    () =>
      results
        .map((result) => new MentionTypeaheadOption(result))
        .slice(0, SUGGESTION_LIST_LENGTH_LIMIT),
    [results],
  )
  const onSelectOption = useCallback(
    (
      selectedOption: MentionTypeaheadOption,
      nodeToReplace: TextNode | null,
      closeMenu: () => void,
    ) => {
      editor.update(() => {
        const mentionNode = $createMentionNode(
          getMentionableName(selectedOption.mentionable),
          serializeMentionable(selectedOption.mentionable),
        )
        if (nodeToReplace) {
          nodeToReplace.replace(mentionNode)
        }
        // Insert a trailing space and move the caret after it.
        const spaceNode = $createTextNode(' ')
        mentionNode.insertAfter(spaceNode)
        spaceNode.select()
        closeMenu()
      })
    },
    [editor],
  )
  const checkForMentionMatch = useCallback(
    (text: string) => {
      // Slash triggers take precedence over '@' mentions.
      const slashMatch = checkForSlashTriggerMatch(text, editor)
      if (slashMatch !== null) {
        return null
      }
      return getPossibleQueryMatch(text)
    },
    [checkForSlashTriggerMatch, editor],
  )
  return (
    <LexicalTypeaheadMenuPlugin<MentionTypeaheadOption>
      onQueryChange={setQueryString}
      onSelectOption={onSelectOption}
      triggerFn={checkForMentionMatch}
      options={options}
      commandPriority={COMMAND_PRIORITY_NORMAL}
      menuRenderFn={(
        anchorElementRef,
        { selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
      ) =>
        anchorElementRef.current && results.length
          ? createPortal(
            <div
              className="infio-popover"
              style={{
                position: 'fixed',
              }}
            >
              <ul>
                {options.map((option, i: number) => (
                  <MentionsTypeaheadMenuItem
                    index={i}
                    isSelected={selectedIndex === i}
                    onClick={() => {
                      setHighlightedIndex(i)
                      selectOptionAndCleanUp(option)
                    }}
                    onMouseEnter={() => {
                      setHighlightedIndex(i)
                    }}
                    key={option.key}
                    option={option}
                  />
                ))}
              </ul>
            </div>,
            anchorElementRef.current,
          )
          : null
      }
    />
  )
}

View File

@@ -0,0 +1,17 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { TextNode } from 'lexical'
import { useEffect } from 'react'
/**
 * Keeps the input plain-text: any TextNode that acquires a non-zero format
 * bitmask is immediately reset to 0.
 *
 * Fix: the original discarded the unregister function returned by
 * registerNodeTransform, so the transform leaked on unmount / editor change;
 * it is now returned as the effect cleanup.
 */
export default function NoFormatPlugin() {
  const [editor] = useLexicalComposerContext()
  useEffect(() => {
    return editor.registerNodeTransform(TextNode, (node) => {
      if (node.getFormat() !== 0) {
        node.setFormat(0)
      }
    })
  }, [editor])
  return null
}

View File

@@ -0,0 +1,46 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { COMMAND_PRIORITY_LOW, KEY_ENTER_COMMAND } from 'lexical'
import { Platform } from 'obsidian'
import { useEffect } from 'react'
/**
 * Intercepts Enter in the editor:
 * - Cmd+Enter (macOS) / Ctrl+Enter (elsewhere): trigger vault chat when
 *   `onVaultChat` is provided.
 * - Shift+Enter: return false so default handling (newline) runs.
 * - Plain Enter: consume the event and call `onEnter`.
 *
 * Fix: removed a leftover `console.log('onEnter', evt)` debug statement
 * that logged every Enter keystroke.
 */
export default function OnEnterPlugin({
  onEnter,
  onVaultChat,
}: {
  onEnter: (evt: KeyboardEvent) => void
  onVaultChat?: () => void
}) {
  const [editor] = useLexicalComposerContext()
  useEffect(() => {
    const removeListener = editor.registerCommand(
      KEY_ENTER_COMMAND,
      (evt: KeyboardEvent) => {
        if (
          onVaultChat &&
          (Platform.isMacOS ? evt.metaKey : evt.ctrlKey)
        ) {
          evt.preventDefault()
          evt.stopPropagation()
          onVaultChat()
          return true
        }
        if (evt.shiftKey) {
          return false
        }
        evt.preventDefault()
        evt.stopPropagation()
        onEnter(evt)
        return true
      },
      COMMAND_PRIORITY_LOW,
    )
    return () => {
      removeListener()
    }
  }, [editor, onEnter, onVaultChat])
  return null
}

View File

@@ -0,0 +1,46 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { Klass, LexicalNode, NodeKey, NodeMutation } from 'lexical'
import { useEffect } from 'react'
export type NodeMutations<T> = Map<NodeKey, { mutation: NodeMutation; node: T }>
// Subscribes to mutations of a given node class and forwards them to
// `onMutation`, pairing each mutation with the affected node instance.
export default function OnMutationPlugin<T extends LexicalNode>({
  nodeClass,
  onMutation,
}: {
  nodeClass: Klass<T>
  onMutation: (mutations: NodeMutations<T>) => void
}) {
  const [editor] = useLexicalComposerContext()
  useEffect(() => {
    return editor.registerMutationListener(
      nodeClass,
      (mutatedNodes, payload) => {
        const currentState = editor.getEditorState()
        const mutations: NodeMutations<T> = new Map()
        mutatedNodes.forEach((mutation, key) => {
          // Destroyed nodes only exist in the previous editor state.
          const sourceState =
            mutation === 'destroyed' ? payload.prevEditorState : currentState
          mutations.set(key, {
            mutation,
            node: sourceState._nodeMap.get(key) as T,
          })
        })
        onMutation(mutations)
      },
    )
  }, [editor, nodeClass, onMutation])
  return null
}

View File

@@ -0,0 +1,597 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license.
* Original source: https://github.com/facebook/lexical
*
* Modified from the original code
* - Added custom positioning logic for menu placement
*/
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { mergeRegister } from '@lexical/utils'
import {
$getSelection,
$isRangeSelection,
COMMAND_PRIORITY_LOW,
CommandListenerPriority,
KEY_ARROW_DOWN_COMMAND,
KEY_ARROW_UP_COMMAND,
KEY_ENTER_COMMAND,
KEY_ESCAPE_COMMAND,
KEY_TAB_COMMAND,
LexicalCommand,
LexicalEditor,
TextNode,
createCommand,
} from 'lexical'
import {
MutableRefObject,
ReactPortal,
useCallback,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from 'react'
// Text-match info for an active typeahead trigger in the current line.
export type MenuTextMatch = {
  leadOffset: number
  matchingString: string
  replaceableString: string
}
// Where (and for which match, if any) the menu should be anchored.
export type MenuResolution = {
  match?: MenuTextMatch
  getRect: () => DOMRect
}
export const PUNCTUATION =
  '\\.,\\+\\*\\?\\$\\@\\|#{}\\(\\)\\^\\-\\[\\]\\\\/!%\'"~=<>_:;'
// A single selectable entry in the typeahead menu. `ref` points at the
// rendered element so the menu can scroll it into view.
export class MenuOption {
  key: string
  ref?: MutableRefObject<HTMLElement | null>
  constructor(key: string) {
    this.key = key
    this.ref = { current: null }
    this.setRefElement = this.setRefElement.bind(this)
  }
  setRefElement(element: HTMLElement | null) {
    this.ref = { current: element }
  }
}
// Render prop consumers use to draw the menu around the anchor element.
export type MenuRenderFn<TOption extends MenuOption> = (
  anchorElementRef: MutableRefObject<HTMLElement | null>,
  itemProps: {
    selectedIndex: number | null
    selectOptionAndCleanUp: (option: TOption) => void
    setHighlightedIndex: (index: number) => void
    options: TOption[]
  },
  matchingString: string | null,
) => ReactPortal | JSX.Element | null
// Brings the typeahead container back on screen if it overflows the
// viewport, then scrolls the target option into view.
const scrollIntoViewIfNeeded = (target: HTMLElement) => {
  const typeaheadContainerNode = document.getElementById('typeahead-menu')
  if (!typeaheadContainerNode) {
    return
  }
  const typeaheadRect = typeaheadContainerNode.getBoundingClientRect()
  // Container extends below the viewport — recenter it.
  if (typeaheadRect.top + typeaheadRect.height > window.innerHeight) {
    typeaheadContainerNode.scrollIntoView({
      block: 'center',
    })
  }
  // Container extends above the viewport — recenter it.
  if (typeaheadRect.top < 0) {
    typeaheadContainerNode.scrollIntoView({
      block: 'center',
    })
  }
  target.scrollIntoView({ block: 'nearest' })
}
/**
 * Walk backwards along user input and forward through entity title to try
 * and replace more of the user's text with entity.
 *
 * Returns the largest i >= offset such that the last i characters of
 * `documentText` equal the first i characters of `entryText`.
 *
 * Fix: replaced deprecated String.prototype.substr with slice
 * (substr(-i) ≡ slice(-i); substr(0, i) ≡ slice(0, i)).
 */
function getFullMatchOffset(
  documentText: string,
  entryText: string,
  offset: number,
): number {
  let triggerOffset = offset
  for (let i = triggerOffset; i <= entryText.length; i++) {
    if (documentText.slice(-i) === entryText.slice(0, i)) {
      triggerOffset = i
    }
  }
  return triggerOffset
}
/**
 * Split Lexical TextNode and return a new TextNode only containing matched text.
 * Common use cases include: removing the node, replacing with a new node.
 *
 * Returns null when the selection is not a collapsed range inside simple
 * text, or when the computed split range would start before the node.
 */
function $splitNodeContainingQuery(match: MenuTextMatch): TextNode | null {
  const selection = $getSelection()
  if (!$isRangeSelection(selection) || !selection.isCollapsed()) {
    return null
  }
  const anchor = selection.anchor
  if (anchor.type !== 'text') {
    return null
  }
  const anchorNode = anchor.getNode()
  if (!anchorNode.isSimpleText()) {
    return null
  }
  const selectionOffset = anchor.offset
  // Text before the caret; the query must end exactly at the caret.
  const textContent = anchorNode.getTextContent().slice(0, selectionOffset)
  const characterOffset = match.replaceableString.length
  const queryOffset = getFullMatchOffset(
    textContent,
    match.matchingString,
    characterOffset,
  )
  const startOffset = selectionOffset - queryOffset
  if (startOffset < 0) {
    return null
  }
  let newNode
  if (startOffset === 0) {
    ;[newNode] = anchorNode.splitText(selectionOffset)
  } else {
    ;[, newNode] = anchorNode.splitText(startOffset, selectionOffset)
  }
  return newNode
}
// Got from https://stackoverflow.com/a/42543908/2013580
// Finds the nearest ancestor that scrolls (overflow auto/scroll, optionally
// hidden); falls back to document.body.
export function getScrollParent(
  element: HTMLElement,
  includeHidden: boolean,
): HTMLElement | HTMLBodyElement {
  let style = getComputedStyle(element)
  const excludeStaticParent = style.position === 'absolute'
  const overflowRegex = includeHidden ? /(auto|scroll|hidden)/ : /(auto|scroll)/
  if (style.position === 'fixed') {
    return document.body
  }
  for (
    let parent: HTMLElement | null = element;
    (parent = parent.parentElement);
  ) {
    style = getComputedStyle(parent)
    // Absolutely-positioned elements skip statically-positioned ancestors.
    if (excludeStaticParent && style.position === 'static') {
      continue
    }
    if (
      overflowRegex.test(style.overflow + style.overflowY + style.overflowX)
    ) {
      return parent
    }
  }
  return document.body
}
// True when the trigger's top edge lies inside the container's client rect.
function isTriggerVisibleInNearestScrollContainer(
  targetElement: HTMLElement,
  containerElement: HTMLElement,
): boolean {
  const tRect = targetElement.getBoundingClientRect()
  const cRect = containerElement.getBoundingClientRect()
  return tRect.top > cRect.top && tRect.top < cRect.bottom
}
// Reposition the menu on scroll, window resize, and element resize.
// Also reports visibility transitions of the trigger element via
// onVisibilityChange. No-op until both resolution and targetElement exist.
export function useDynamicPositioning(
  resolution: MenuResolution | null,
  targetElement: HTMLElement | null,
  onReposition: () => void,
  onVisibilityChange?: (isInView: boolean) => void,
) {
  const [editor] = useLexicalComposerContext()
  useEffect(() => {
    if (targetElement != null && resolution != null) {
      const rootElement = editor.getRootElement()
      const rootScrollParent =
        rootElement != null
          ? getScrollParent(rootElement, false)
          : document.body
      // rAF-throttled scroll handler: at most one reposition per frame.
      let ticking = false
      let previousIsInView = isTriggerVisibleInNearestScrollContainer(
        targetElement,
        rootScrollParent,
      )
      const handleScroll = function () {
        if (!ticking) {
          window.requestAnimationFrame(function () {
            onReposition()
            ticking = false
          })
          ticking = true
        }
        const isInView = isTriggerVisibleInNearestScrollContainer(
          targetElement,
          rootScrollParent,
        )
        if (isInView !== previousIsInView) {
          previousIsInView = isInView
          if (onVisibilityChange != null) {
            onVisibilityChange(isInView)
          }
        }
      }
      const resizeObserver = new ResizeObserver(onReposition)
      window.addEventListener('resize', onReposition)
      document.addEventListener('scroll', handleScroll, {
        capture: true,
        passive: true,
      })
      resizeObserver.observe(targetElement)
      return () => {
        resizeObserver.unobserve(targetElement)
        window.removeEventListener('resize', onReposition)
        document.removeEventListener('scroll', handleScroll, true)
      }
    }
  }, [targetElement, editor, onVisibilityChange, onReposition, resolution])
}
// Command dispatched so the menu can scroll a newly-highlighted option into view.
export const SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND: LexicalCommand<{
  index: number
  option: MenuOption
}> = createCommand('SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND')
export function LexicalMenu<TOption extends MenuOption>({
close,
editor,
anchorElementRef,
resolution,
options,
menuRenderFn,
onSelectOption,
shouldSplitNodeWithQuery = false,
commandPriority = COMMAND_PRIORITY_LOW,
}: {
close: () => void
editor: LexicalEditor
anchorElementRef: MutableRefObject<HTMLElement>
resolution: MenuResolution
options: TOption[]
shouldSplitNodeWithQuery?: boolean
menuRenderFn: MenuRenderFn<TOption>
onSelectOption: (
option: TOption,
textNodeContainingQuery: TextNode | null,
closeMenu: () => void,
matchingString: string,
) => void
commandPriority?: CommandListenerPriority
}): JSX.Element | null {
const [selectedIndex, setHighlightedIndex] = useState<null | number>(null)
const matchingString = resolution.match?.matchingString
useEffect(() => {
setHighlightedIndex(0)
}, [matchingString])
const selectOptionAndCleanUp = useCallback(
(selectedEntry: TOption) => {
editor.update(() => {
const textNodeContainingQuery =
resolution.match != null && shouldSplitNodeWithQuery
? $splitNodeContainingQuery(resolution.match)
: null
onSelectOption(
selectedEntry,
textNodeContainingQuery,
close,
resolution.match ? resolution.match.matchingString : '',
)
})
},
[editor, shouldSplitNodeWithQuery, resolution.match, onSelectOption, close],
)
const updateSelectedIndex = useCallback(
(index: number) => {
const rootElem = editor.getRootElement()
if (rootElem !== null) {
rootElem.setAttribute(
'aria-activedescendant',
`typeahead-item-${index}`,
)
setHighlightedIndex(index)
}
},
[editor],
)
useEffect(() => {
return () => {
const rootElem = editor.getRootElement()
if (rootElem !== null) {
rootElem.removeAttribute('aria-activedescendant')
}
}
}, [editor])
useLayoutEffect(() => {
if (options === null) {
setHighlightedIndex(null)
} else if (selectedIndex === null) {
updateSelectedIndex(0)
}
}, [options, selectedIndex, updateSelectedIndex])
useEffect(() => {
return mergeRegister(
editor.registerCommand(
SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND,
({ option }) => {
if (option.ref?.current != null) {
scrollIntoViewIfNeeded(option.ref.current)
return true
}
return false
},
commandPriority,
),
)
}, [editor, updateSelectedIndex, commandPriority])
useEffect(() => {
return mergeRegister(
editor.registerCommand<KeyboardEvent>(
KEY_ARROW_DOWN_COMMAND,
(payload) => {
const event = payload
if (options?.length && selectedIndex !== null) {
const newSelectedIndex =
selectedIndex !== options.length - 1 ? selectedIndex + 1 : 0
updateSelectedIndex(newSelectedIndex)
const option = options[newSelectedIndex]
if (option.ref?.current != null) {
editor.dispatchCommand(
SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND,
{
index: newSelectedIndex,
option,
},
)
}
event.preventDefault()
event.stopImmediatePropagation()
}
return true
},
commandPriority,
),
editor.registerCommand<KeyboardEvent>(
KEY_ARROW_UP_COMMAND,
(payload) => {
const event = payload
if (options?.length && selectedIndex !== null) {
const newSelectedIndex =
selectedIndex !== 0 ? selectedIndex - 1 : options.length - 1
updateSelectedIndex(newSelectedIndex)
const option = options[newSelectedIndex]
if (option.ref?.current != null) {
scrollIntoViewIfNeeded(option.ref.current)
}
event.preventDefault()
event.stopImmediatePropagation()
}
return true
},
commandPriority,
),
editor.registerCommand<KeyboardEvent>(
KEY_ESCAPE_COMMAND,
(payload) => {
const event = payload
event.preventDefault()
event.stopImmediatePropagation()
close()
return true
},
commandPriority,
),
editor.registerCommand<KeyboardEvent>(
KEY_TAB_COMMAND,
(payload) => {
const event = payload
if (
options === null ||
selectedIndex === null ||
options[selectedIndex] == null
) {
return false
}
event.preventDefault()
event.stopImmediatePropagation()
selectOptionAndCleanUp(options[selectedIndex])
return true
},
commandPriority,
),
editor.registerCommand(
KEY_ENTER_COMMAND,
(event: KeyboardEvent | null) => {
if (
options === null ||
selectedIndex === null ||
options[selectedIndex] == null
) {
return false
}
if (event !== null) {
event.preventDefault()
event.stopImmediatePropagation()
}
selectOptionAndCleanUp(options[selectedIndex])
return true
},
commandPriority,
),
)
}, [
selectOptionAndCleanUp,
close,
editor,
options,
selectedIndex,
updateSelectedIndex,
commandPriority,
])
const listItemProps = useMemo(
() => ({
options,
selectOptionAndCleanUp,
selectedIndex,
setHighlightedIndex,
}),
[selectOptionAndCleanUp, selectedIndex, options],
)
return menuRenderFn(
anchorElementRef,
listItemProps,
resolution.match ? resolution.match.matchingString : '',
)
}
/**
 * Creates and positions the floating anchor element that hosts a typeahead menu.
 *
 * The anchor is a detached <div> appended to `parent` on first positioning and
 * removed when `resolution` becomes null. The menu itself is rendered into this
 * anchor by the caller (usually via a portal).
 *
 * @param resolution - current trigger match + rect getter, or null when the menu is closed
 * @param setResolution - setter used to close the menu when the anchor scrolls out of view
 * @param className - optional CSS class applied to the anchor container
 * @param parent - element the anchor is appended to (defaults to document.body)
 * @param shouldIncludePageYOffset__EXPERIMENTAL - whether to add window.pageYOffset to vertical positions
 */
export function useMenuAnchorRef(
  resolution: MenuResolution | null,
  setResolution: (r: MenuResolution | null) => void,
  className?: string,
  parent: HTMLElement = document.body,
  shouldIncludePageYOffset__EXPERIMENTAL = true,
): MutableRefObject<HTMLElement> {
  const [editor] = useLexicalComposerContext()
  // The anchor div lives for the lifetime of the hook; it is (re)attached on demand.
  const anchorElementRef = useRef<HTMLElement>(document.createElement('div'))
  const positionMenu = useCallback(() => {
    // NOTE(review): copies `bottom` into `top` before measuring — presumably a
    // reset carried over from upstream Lexical; confirm it is still needed.
    anchorElementRef.current.style.top = anchorElementRef.current.style.bottom
    const rootElement = editor.getRootElement()
    const containerDiv = anchorElementRef.current
    const menuEle = containerDiv.firstChild as HTMLElement
    if (rootElement !== null && resolution !== null) {
      const { left, top, width, height } = resolution.getRect()
      const anchorHeight = anchorElementRef.current.offsetHeight // use to position under anchor
      // Place the anchor just below the trigger rect.
      containerDiv.style.top = `${top +
        anchorHeight +
        3 +
        (shouldIncludePageYOffset__EXPERIMENTAL ? window.pageYOffset : 0)
        }px`
      containerDiv.style.left = `${left + window.pageXOffset}px`
      containerDiv.style.height = `${height}px`
      containerDiv.style.width = `${width}px`
      if (menuEle !== null) {
        // NOTE(review): `${top}` has no 'px' unit — likely a latent bug
        // inherited from the original source; confirm before changing.
        menuEle.style.top = `${top}`
        const menuRect = menuEle.getBoundingClientRect()
        const menuHeight = menuRect.height
        const menuWidth = menuRect.width
        const rootElementRect = rootElement.getBoundingClientRect()
        // Keep the menu inside the editor's right edge.
        if (left + menuWidth > rootElementRect.right) {
          containerDiv.style.left = `${rootElementRect.right - menuWidth + window.pageXOffset
            }px`
        }
        if (
          // If it exceeds the window height, it should always be displayed above, but the original code checks if it doesn't exceed the editor's top as well. So I modified it.
          // (top + menuHeight > window.innerHeight ||
          //   top + menuHeight > rootElementRect.bottom) &&
          // top - rootElementRect.top > menuHeight + height
          top + menuHeight >
          window.innerHeight
        ) {
          // Not enough room below: flip the menu above the trigger rect.
          containerDiv.style.top = `${top -
            menuHeight -
            height +
            (shouldIncludePageYOffset__EXPERIMENTAL ? window.pageYOffset : 0)
            }px`
        }
      }
      // Lazily attach the anchor to the DOM with listbox ARIA attributes.
      if (!containerDiv.isConnected) {
        if (className != null) {
          containerDiv.className = className
        }
        containerDiv.setAttribute('aria-label', 'Typeahead menu')
        containerDiv.setAttribute('id', 'typeahead-menu')
        containerDiv.setAttribute('role', 'listbox')
        containerDiv.style.display = 'block'
        containerDiv.style.position = 'absolute'
        parent.append(containerDiv)
      }
      anchorElementRef.current = containerDiv
      rootElement.setAttribute('aria-controls', 'typeahead-menu')
    }
  }, [
    editor,
    resolution,
    shouldIncludePageYOffset__EXPERIMENTAL,
    className,
    parent,
  ])
  // Position on open; detach the anchor and clear ARIA wiring on close/unmount.
  useEffect(() => {
    const rootElement = editor.getRootElement()
    if (resolution !== null) {
      positionMenu()
      return () => {
        if (rootElement !== null) {
          rootElement.removeAttribute('aria-controls')
        }
        const containerDiv = anchorElementRef.current
        if (containerDiv?.isConnected) {
          containerDiv.remove()
        }
      }
    }
  }, [editor, positionMenu, resolution])
  // Close the menu when the anchor scrolls out of the viewport.
  const onVisibilityChange = useCallback(
    (isInView: boolean) => {
      if (resolution !== null) {
        if (!isInView) {
          setResolution(null)
        }
      }
    },
    [resolution, setResolution],
  )
  useDynamicPositioning(
    resolution,
    anchorElementRef.current,
    positionMenu,
    onVisibilityChange,
  )
  return anchorElementRef
}
// Inspects the text before the caret and returns match details when a
// typeahead menu should open for it, or null when it should stay closed.
export type TriggerFn = (
  text: string,
  editor: LexicalEditor,
) => MenuTextMatch | null

View File

@@ -0,0 +1,146 @@
import { $generateJSONFromSelectedNodes } from '@lexical/clipboard'
import { BaseSerializedNode } from '@lexical/clipboard/clipboard'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import * as Dialog from '@radix-ui/react-dialog'
import {
$getSelection,
COMMAND_PRIORITY_LOW,
SELECTION_CHANGE_COMMAND,
} from 'lexical'
import { CSSProperties, useCallback, useEffect, useRef, useState } from 'react'
import CreateTemplateDialogContent from '../../../CreateTemplateDialog'
/**
 * Shows a floating "Create template" button next to the current editor
 * selection; clicking it opens a dialog pre-filled with the serialized
 * selected nodes.
 *
 * @param anchorElement - positioned ancestor the popover coordinates are relative to
 * @param contentEditableElement - the Lexical contenteditable; selections outside it are ignored
 */
export default function CreateTemplatePopoverPlugin({
  anchorElement,
  contentEditableElement,
}: {
  anchorElement: HTMLElement | null
  contentEditableElement: HTMLElement | null
}): JSX.Element | null {
  const [editor] = useLexicalComposerContext()
  const [popoverStyle, setPopoverStyle] = useState<CSSProperties | null>(null)
  const [isPopoverOpen, setIsPopoverOpen] = useState(false)
  const [isDialogOpen, setIsDialogOpen] = useState(false)
  const [selectedSerializedNodes, setSelectedSerializedNodes] = useState<
    BaseSerializedNode[] | null
  >(null)
  const popoverRef = useRef<HTMLButtonElement>(null)
  // Serialize the currently selected Lexical nodes; null when nothing usable
  // is selected.
  const getSelectedSerializedNodes = useCallback(():
    | BaseSerializedNode[]
    | null => {
    if (!editor) return null
    let selectedNodes: BaseSerializedNode[] | null = null
    editor.update(() => {
      const selection = $getSelection()
      if (!selection) return
      selectedNodes = $generateJSONFromSelectedNodes(editor, selection).nodes
    })
    // Fix: the original's `return null` for an empty node list was inside the
    // update callback, so it only exited the callback and the outer function
    // still returned []. Normalize empty selections to null here.
    return selectedNodes && selectedNodes.length > 0 ? selectedNodes : null
  }, [editor])
  // Recompute the popover position from the native DOM selection; hides the
  // popover when the selection is collapsed or outside the editor.
  const updatePopoverPosition = useCallback(() => {
    if (!anchorElement || !contentEditableElement) return
    const nativeSelection = document.getSelection()
    // Fix: Selection.getRangeAt(0) throws when rangeCount is 0, so guard on
    // rangeCount before reading the range.
    const range =
      nativeSelection && nativeSelection.rangeCount > 0
        ? nativeSelection.getRangeAt(0)
        : null
    if (!range || range.collapsed) {
      setIsPopoverOpen(false)
      return
    }
    if (!contentEditableElement.contains(range.commonAncestorContainer)) {
      setIsPopoverOpen(false)
      return
    }
    const rects = Array.from(range.getClientRects())
    if (rects.length === 0) {
      setIsPopoverOpen(false)
      return
    }
    const anchorRect = anchorElement.getBoundingClientRect()
    // Anchor the popover at the bottom-right of the last selection rect,
    // clamped so it never overflows the anchor's left edge.
    const idealLeft = rects[rects.length - 1].right - anchorRect.left
    const paddingX = 8
    const paddingY = 4
    const minLeft = (popoverRef.current?.offsetWidth ?? 0) + paddingX
    const finalLeft = Math.max(minLeft, idealLeft)
    setPopoverStyle({
      top: rects[rects.length - 1].bottom - anchorRect.top + paddingY,
      left: finalLeft,
      transform: 'translate(-100%, 0)',
    })
    setIsPopoverOpen(true)
  }, [anchorElement, contentEditableElement])
  // Track selection changes made inside the editor.
  useEffect(() => {
    const removeSelectionChangeListener = editor.registerCommand(
      SELECTION_CHANGE_COMMAND,
      () => {
        updatePopoverPosition()
        return false
      },
      COMMAND_PRIORITY_LOW,
    )
    return () => {
      removeSelectionChangeListener()
    }
  }, [editor, updatePopoverPosition])
  useEffect(() => {
    // Update popover position when the content is cleared
    // (Selection change event doesn't fire in this case)
    if (!isPopoverOpen) return
    const removeTextContentChangeListener = editor.registerTextContentListener(
      () => {
        updatePopoverPosition()
      },
    )
    return () => {
      removeTextContentChangeListener()
    }
  }, [editor, isPopoverOpen, updatePopoverPosition])
  // Keep the popover glued to the selection while the editor scrolls.
  useEffect(() => {
    if (!contentEditableElement) return
    const handleScroll = () => {
      updatePopoverPosition()
    }
    contentEditableElement.addEventListener('scroll', handleScroll)
    return () => {
      contentEditableElement.removeEventListener('scroll', handleScroll)
    }
  }, [contentEditableElement, updatePopoverPosition])
  return (
    <Dialog.Root
      modal={false}
      open={isDialogOpen}
      onOpenChange={(open) => {
        // Capture the selection at the moment the dialog opens; the popover is
        // hidden in both directions of the transition.
        if (open) {
          setSelectedSerializedNodes(getSelectedSerializedNodes())
        }
        setIsDialogOpen(open)
        setIsPopoverOpen(false)
      }}
    >
      <Dialog.Trigger asChild>
        <button
          ref={popoverRef}
          style={{
            position: 'absolute',
            visibility: isPopoverOpen ? 'visible' : 'hidden',
            ...popoverStyle,
          }}
        >
          Create template
        </button>
      </Dialog.Trigger>
      <CreateTemplateDialogContent
        selectedSerializedNodes={selectedSerializedNodes}
        onClose={() => setIsDialogOpen(false)}
      />
    </Dialog.Root>
  )
}

View File

@@ -0,0 +1,182 @@
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import clsx from 'clsx'
import {
$parseSerializedNode,
COMMAND_PRIORITY_NORMAL,
TextNode,
} from 'lexical'
import { Trash2 } from 'lucide-react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { createPortal } from 'react-dom'
import { useDatabase } from '../../../../../contexts/DatabaseContext'
import { SelectTemplate } from '../../../../../database/schema'
import { MenuOption } from '../shared/LexicalMenu'
import {
LexicalTypeaheadMenuPlugin,
useBasicTypeaheadTriggerMatch,
} from '../typeahead-menu/LexicalTypeaheadMenuPlugin'
// Typeahead menu entry wrapping a stored template row.
class TemplateTypeaheadOption extends MenuOption {
  // Display name; also forwarded to the MenuOption base constructor.
  name: string
  // The underlying template database record.
  template: SelectTemplate
  constructor(name: string, template: SelectTemplate) {
    super(name)
    this.name = name
    this.template = template
  }
}
function TemplateMenuItem({
index,
isSelected,
onClick,
onDelete,
onMouseEnter,
option,
}: {
index: number
isSelected: boolean
onClick: () => void
onDelete: () => void
onMouseEnter: () => void
option: TemplateTypeaheadOption
}) {
return (
<li
key={option.key}
tabIndex={-1}
className={clsx('item', isSelected && 'selected')}
ref={(el) => option.setRefElement(el)}
role="option"
aria-selected={isSelected}
id={`typeahead-item-${index}`}
onMouseEnter={onMouseEnter}
onClick={onClick}
>
<div className="infio-chat-template-menu-item">
<div className="text">{option.name}</div>
<div
onClick={(evt) => {
evt.stopPropagation()
evt.preventDefault()
onDelete()
}}
className="infio-chat-template-menu-item-delete"
>
<Trash2 size={12} />
</div>
</div>
</li>
)
}
/**
 * Lexical plugin that opens a template picker when the user types "/".
 * Selecting an entry replaces the "/query" text with the template's nodes;
 * each row also exposes a delete button.
 */
export default function TemplatePlugin() {
  const [editor] = useLexicalComposerContext()
  const { getTemplateManager } = useDatabase()
  const [queryString, setQueryString] = useState<string | null>(null)
  const [searchResults, setSearchResults] = useState<SelectTemplate[]>([])
  // Search templates whenever the query changes.
  // Fix: without the cancellation flag, a slow search for an old query could
  // resolve after a newer one and overwrite fresher results.
  useEffect(() => {
    if (queryString == null) return
    let cancelled = false
    getTemplateManager().then((templateManager) =>
      templateManager.searchTemplates(queryString).then((results) => {
        if (!cancelled) {
          setSearchResults(results)
        }
      }),
    )
    return () => {
      cancelled = true
    }
  }, [queryString, getTemplateManager])
  const options = useMemo(
    () =>
      searchResults.map(
        (result) => new TemplateTypeaheadOption(result.name, result),
      ),
    [searchResults],
  )
  // Menu opens on "/" with no minimum query length.
  const checkForTriggerMatch = useBasicTypeaheadTriggerMatch('/', {
    minLength: 0,
  })
  // Replace the "/query" text node with the template's content and move the
  // caret to the end of the inserted nodes.
  const onSelectOption = useCallback(
    (
      selectedOption: TemplateTypeaheadOption,
      nodeToRemove: TextNode | null,
      closeMenu: () => void,
    ) => {
      editor.update(() => {
        const parsedNodes = selectedOption.template.content.nodes.map((node) =>
          $parseSerializedNode(node),
        )
        if (nodeToRemove) {
          const parent = nodeToRemove.getParentOrThrow()
          parent.splice(nodeToRemove.getIndexWithinParent(), 1, parsedNodes)
          // Fix: guard against an empty template — the original crashed on
          // `lastNode.selectEnd()` when parsedNodes was empty.
          const lastNode = parsedNodes[parsedNodes.length - 1]
          lastNode?.selectEnd()
        }
        closeMenu()
      })
    },
    [editor],
  )
  // Delete a template, then refresh the current search results.
  const handleDelete = useCallback(
    async (option: TemplateTypeaheadOption) => {
      const templateManager = await getTemplateManager()
      await templateManager.deleteTemplate(option.template.id)
      if (queryString !== null) {
        setSearchResults(await templateManager.searchTemplates(queryString))
      }
    },
    [getTemplateManager, queryString],
  )
  return (
    <LexicalTypeaheadMenuPlugin<TemplateTypeaheadOption>
      onQueryChange={setQueryString}
      onSelectOption={onSelectOption}
      triggerFn={checkForTriggerMatch}
      options={options}
      commandPriority={COMMAND_PRIORITY_NORMAL}
      menuRenderFn={(
        anchorElementRef,
        { selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
      ) =>
        anchorElementRef.current && searchResults.length
          ? createPortal(
              <div
                className="infio-popover"
                style={{
                  position: 'fixed',
                }}
              >
                <ul>
                  {options.map((option, i: number) => (
                    <TemplateMenuItem
                      index={i}
                      isSelected={selectedIndex === i}
                      onClick={() => {
                        setHighlightedIndex(i)
                        selectOptionAndCleanUp(option)
                      }}
                      onDelete={() => {
                        // Fire-and-forget; failures surface on the console.
                        void handleDelete(option)
                      }}
                      onMouseEnter={() => {
                        setHighlightedIndex(i)
                      }}
                      key={option.key}
                      option={option}
                    />
                  ))}
                </ul>
              </div>,
              anchorElementRef.current,
            )
          : null
      }
    />
  )
}

View File

@@ -0,0 +1,297 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license.
* Original source: https://github.com/facebook/lexical
*
* Modified from the original code
*/
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import {
$getSelection,
$isRangeSelection,
$isTextNode,
COMMAND_PRIORITY_LOW,
CommandListenerPriority,
LexicalCommand,
LexicalEditor,
RangeSelection,
TextNode,
createCommand,
} from 'lexical'
import { startTransition, useCallback, useEffect, useState } from 'react'
import {
LexicalMenu,
MenuOption,
MenuRenderFn,
MenuResolution,
TriggerFn,
useMenuAnchorRef,
} from '../shared/LexicalMenu'
// Characters treated as punctuation (i.e. NOT part of a typeahead query),
// already escaped for direct embedding inside a RegExp character class.
export const PUNCTUATION =
  '\\.,\\+\\*\\?\\$\\@\\|#{}\\(\\)\\^\\-\\[\\]\\\\/!%\'"~=<>_:;'
// Returns the text content of the anchor node up to the caret, or null when
// the anchor is not inside a simple text node.
function getTextUpToAnchor(selection: RangeSelection): string | null {
  const { anchor } = selection
  if (anchor.type !== 'text') {
    return null
  }
  const node = anchor.getNode()
  if (!node.isSimpleText()) {
    return null
  }
  return node.getTextContent().slice(0, anchor.offset)
}
// Positions `range` over the query text (from leadOffset to the caret) inside
// the DOM selection's anchor node. Returns false when the selection is
// missing, not collapsed, or the offsets are invalid for the node.
function tryToPositionRange(
  leadOffset: number,
  range: Range,
  editorWindow: Window,
): boolean {
  const domSelection = editorWindow.getSelection()
  if (domSelection === null || !domSelection.isCollapsed) {
    return false
  }
  const { anchorNode, anchorOffset } = domSelection
  if (anchorNode == null || anchorOffset == null) {
    return false
  }
  try {
    range.setStart(anchorNode, leadOffset)
    range.setEnd(anchorNode, anchorOffset)
    return true
  } catch {
    // setStart/setEnd throw for out-of-range offsets; treat as "not positioned".
    return false
  }
}
// Reads the editor state and returns the text before the caret, or null when
// there is no collapsed range selection in a text node.
function getQueryTextForSearch(editor: LexicalEditor): string | null {
  let result: string | null = null
  editor.getEditorState().read(() => {
    const selection = $getSelection()
    if ($isRangeSelection(selection)) {
      result = getTextUpToAnchor(selection)
    }
  })
  return result
}
// True when the caret sits at offset 0 immediately after a text-entity node —
// used to suppress the typeahead right at an entity boundary.
function isSelectionOnEntityBoundary(
  editor: LexicalEditor,
  offset: number,
): boolean {
  if (offset !== 0) {
    return false
  }
  return editor.getEditorState().read(() => {
    const selection = $getSelection()
    if (!$isRangeSelection(selection)) {
      return false
    }
    const previousSibling = selection.anchor.getNode().getPreviousSibling()
    return $isTextNode(previousSibling) && previousSibling.isTextEntity()
  })
}
// Got from https://stackoverflow.com/a/42543908/2013580
// Walks up from `element` and returns the nearest scrollable ancestor
// (overflow auto/scroll, optionally hidden), falling back to document.body.
export function getScrollParent(
  element: HTMLElement,
  includeHidden: boolean,
): HTMLElement | HTMLBodyElement {
  const overflowRegex = includeHidden ? /(auto|scroll|hidden)/ : /(auto|scroll)/
  let style = getComputedStyle(element)
  if (style.position === 'fixed') {
    // Fixed elements scroll with the viewport, not any ancestor.
    return document.body
  }
  // Absolutely-positioned elements ignore statically-positioned ancestors.
  const skipStaticParents = style.position === 'absolute'
  let ancestor = element.parentElement
  while (ancestor !== null) {
    style = getComputedStyle(ancestor)
    const isStatic = skipStaticParents && style.position === 'static'
    if (
      !isStatic &&
      overflowRegex.test(style.overflow + style.overflowY + style.overflowX)
    ) {
      return ancestor
    }
    ancestor = ancestor.parentElement
  }
  return document.body
}
// Command dispatched to scroll a given typeahead option (by index) into view
// inside the open menu.
export const SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND: LexicalCommand<{
  index: number
  option: MenuOption
}> = createCommand('SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND')
/**
 * Builds a TriggerFn that matches `trigger` (e.g. "/") followed by up to
 * `maxLength` non-punctuation characters at the end of the text, requiring at
 * least `minLength` query characters before reporting a match.
 */
export function useBasicTypeaheadTriggerMatch(
  trigger: string,
  { minLength = 1, maxLength = 75 }: { minLength?: number; maxLength?: number },
): TriggerFn {
  return useCallback(
    (text: string) => {
      const validChars = '[^' + trigger + PUNCTUATION + '\\s]'
      const triggerPattern = new RegExp(
        `(^|\\s|\\()([${trigger}]((?:${validChars}){0,${maxLength}}))$`,
      )
      const result = triggerPattern.exec(text)
      if (result === null) {
        return null
      }
      // result[1] = optional leading boundary, result[2] = trigger + query,
      // result[3] = the query itself.
      const [, leadingBoundary, replaceable, query] = result
      if (query.length < minLength) {
        return null
      }
      return {
        leadOffset: result.index + leadingBoundary.length,
        matchingString: query,
        replaceableString: replaceable,
      }
    },
    [maxLength, minLength, trigger],
  )
}
// Props for LexicalTypeaheadMenuPlugin.
export type TypeaheadMenuPluginProps<TOption extends MenuOption> = {
  // Called with the current query text, or null when no trigger matches.
  onQueryChange: (matchingString: string | null) => void
  // Invoked when the user picks an option; may remove the query text node.
  onSelectOption: (
    option: TOption,
    textNodeContainingQuery: TextNode | null,
    closeMenu: () => void,
    matchingString: string,
  ) => void
  options: TOption[]
  // Renders the menu into the positioned anchor element.
  menuRenderFn: MenuRenderFn<TOption>
  // Decides whether (and where) the menu should open for the current text.
  triggerFn: TriggerFn
  onOpen?: (resolution: MenuResolution) => void
  onClose?: () => void
  anchorClassName?: string
  // Priority used when registering keyboard/scroll commands (default: LOW).
  commandPriority?: CommandListenerPriority
  // Element the floating anchor is appended to (default: document.body).
  parent?: HTMLElement
}
/**
 * Generic typeahead menu plugin: watches editor updates, runs `triggerFn` on
 * the text before the caret, and renders `menuRenderFn` into a positioned
 * anchor while a match is active.
 */
export function LexicalTypeaheadMenuPlugin<TOption extends MenuOption>({
  options,
  onQueryChange,
  onSelectOption,
  onOpen,
  onClose,
  menuRenderFn,
  triggerFn,
  anchorClassName,
  commandPriority = COMMAND_PRIORITY_LOW,
  parent,
}: TypeaheadMenuPluginProps<TOption>): JSX.Element | null {
  const [editor] = useLexicalComposerContext()
  // Non-null while the menu is open; carries the match and a rect getter.
  const [resolution, setResolution] = useState<MenuResolution | null>(null)
  const anchorElementRef = useMenuAnchorRef(
    resolution,
    setResolution,
    anchorClassName,
    parent,
  )
  const closeTypeahead = useCallback(() => {
    setResolution(null)
    // Only notify when transitioning from open to closed.
    if (onClose != null && resolution !== null) {
      onClose()
    }
  }, [onClose, resolution])
  const openTypeahead = useCallback(
    (res: MenuResolution) => {
      setResolution(res)
      // Only notify on the closed -> open transition.
      if (onOpen != null && resolution === null) {
        onOpen(res)
      }
    },
    [onOpen, resolution],
  )
  // Re-evaluate the trigger on every editor update: open the menu when the
  // caret text matches, close it otherwise.
  useEffect(() => {
    const updateListener = () => {
      editor.getEditorState().read(() => {
        const editorWindow = editor._window ?? window
        const range = editorWindow.document.createRange()
        const selection = $getSelection()
        const text = getQueryTextForSearch(editor)
        if (
          !$isRangeSelection(selection) ||
          !selection.isCollapsed() ||
          text === null ||
          range === null
        ) {
          closeTypeahead()
          return
        }
        const match = triggerFn(text, editor)
        onQueryChange(match ? match.matchingString : null)
        if (
          match !== null &&
          !isSelectionOnEntityBoundary(editor, match.leadOffset)
        ) {
          const isRangePositioned = tryToPositionRange(
            match.leadOffset,
            range,
            editorWindow,
          )
          // Fix: tryToPositionRange returns a boolean, so the original
          // `isRangePositioned !== null` was always true and opened the menu
          // even when the DOM range could not be positioned.
          if (isRangePositioned) {
            startTransition(() =>
              openTypeahead({
                getRect: () => range.getBoundingClientRect(),
                match,
              }),
            )
            return
          }
        }
        closeTypeahead()
      })
    }
    const removeUpdateListener = editor.registerUpdateListener(updateListener)
    return () => {
      removeUpdateListener()
    }
  }, [
    editor,
    triggerFn,
    onQueryChange,
    resolution,
    closeTypeahead,
    openTypeahead,
  ])
  return resolution === null || editor === null ? null : (
    <LexicalMenu
      close={closeTypeahead}
      resolution={resolution}
      editor={editor}
      anchorElementRef={anchorElementRef}
      options={options}
      menuRenderFn={menuRenderFn}
      shouldSplitNodeWithQuery={true}
      onSelectOption={onSelectOption}
      commandPriority={commandPriority}
    />
  )
}

View File

@@ -0,0 +1,45 @@
import {
SerializedEditorState,
SerializedElementNode,
SerializedTextNode,
} from 'lexical'
import { editorStateToPlainText } from './editor-state-to-plain-text'
// A single paragraph with one text node should serialize to its raw text.
describe('editorStateToPlainText', () => {
  it('should convert editor state to plain text', () => {
    // Minimal serialized Lexical state: root > paragraph > text node.
    const editorState: SerializedEditorState = {
      root: {
        children: [
          {
            children: [
              {
                detail: 0,
                format: 0,
                mode: 'normal',
                style: '',
                text: 'Hello, world!',
                type: 'text',
                version: 1,
              },
            ],
            direction: 'ltr',
            format: '',
            indent: 0,
            type: 'paragraph',
            version: 1,
            textFormat: 0,
            textStyle: '',
          } as SerializedElementNode<SerializedTextNode>,
        ],
        direction: 'ltr',
        format: '',
        indent: 0,
        type: 'root',
        version: 1,
      },
    }
    const plainText = editorStateToPlainText(editorState)
    expect(plainText).toBe('Hello, world!')
  })
})

View File

@@ -0,0 +1,21 @@
import { SerializedEditorState, SerializedLexicalNode } from 'lexical'
// Flattens a serialized Lexical editor state into plain text by walking the
// node tree from the root.
export function editorStateToPlainText(
  editorState: SerializedEditorState,
): string {
  return lexicalNodeToPlainText(editorState.root)
}
// Recursively converts one serialized node to plain text: element nodes
// concatenate their children, linebreaks become '\n', text nodes yield their
// text, and anything else contributes nothing.
// NOTE(review): sibling paragraphs are joined without a separator — confirm
// that multi-paragraph content is intended to collapse onto one line.
function lexicalNodeToPlainText(node: SerializedLexicalNode): string {
  if ('children' in node) {
    let text = ''
    for (const child of node.children as SerializedLexicalNode[]) {
      text += lexicalNodeToPlainText(child)
    }
    return text
  }
  if (node.type === 'linebreak') {
    return '\n'
  }
  if ('text' in node && typeof node.text === 'string') {
    return node.text
  }
  return ''
}

View File

@@ -0,0 +1,30 @@
import {
FileIcon,
FolderClosedIcon,
FoldersIcon,
ImageIcon,
LinkIcon,
} from 'lucide-react'
import { Mentionable } from '../../../../types/mentionable'
// Maps each mentionable kind to the lucide icon shown in the chat UI;
// returns null for unknown kinds.
export const getMentionableIcon = (mentionable: Mentionable) => {
  switch (mentionable.type) {
    // File-like mentionables all share the plain file icon.
    case 'file':
    case 'current-file':
    case 'block':
      return FileIcon
    case 'folder':
      return FolderClosedIcon
    case 'vault':
      return FoldersIcon
    case 'url':
      return LinkIcon
    case 'image':
      return ImageIcon
    default:
      return null
  }
}

View File

@@ -0,0 +1,271 @@
import { MarkdownView, Plugin } from "obsidian";
import React, { useEffect, useRef, useState } from "react";
import { APPLY_VIEW_TYPE } from "../../constants";
import LLMManager from "../../core/llm/manager";
import { InfioSettings } from "../../types/settings";
import { manualApplyChangesToFile } from "../../utils/apply";
import { removeAITags } from "../../utils/content-filter";
import { PromptGenerator } from "../../utils/prompt-generator";
interface InlineEditProps {
  // Raw text of the inline-edit code block (currently unused by the component).
  source: string;
  // First line of the block inside the note (used to remove it on close).
  secStartLine: number;
  // Last line of the block inside the note.
  secEndLine: number;
  plugin: Plugin;
  settings: InfioSettings;
}
interface InputAreaProps {
  // Current instruction text and its change handler (controlled textarea).
  value: string;
  onChange: (value: string) => void;
}
// Controlled textarea for the user's edit instruction.
const InputArea: React.FC<InputAreaProps> = ({ value, onChange }) => {
  const inputRef = useRef<HTMLTextAreaElement>(null);

  // Focus the textarea automatically once the component mounts.
  useEffect(() => {
    inputRef.current?.focus();
  }, []);

  return (
    <div className="infio-ai-block-input-wrapper">
      <textarea
        ref={inputRef}
        className="infio-ai-block-content"
        placeholder="Enter instruction"
        value={value}
        onChange={(event) => onChange(event.target.value)}
      />
    </div>
  );
};
interface ControlAreaProps {
  settings: InfioSettings;
  // Fired when the user clicks Submit.
  onSubmit: () => void;
  // Currently selected chat model name and its change handler.
  selectedModel: string;
  onModelChange: (model: string) => void;
  // Disables the controls while a request is in flight.
  isSubmitting: boolean;
}
// Model picker plus submit button; both are disabled while submitting.
const ControlArea: React.FC<ControlAreaProps> = ({
  settings,
  onSubmit,
  selectedModel,
  onModelChange,
  isSubmitting,
}) => {
  // Only enabled, non-embedding models are offered for inline edits.
  const chatModels = settings.activeModels.filter(
    (model) => !model.isEmbeddingModel && model.enabled
  );
  return (
    <div className="infio-ai-block-controls">
      <select
        className="infio-ai-block-model-select"
        value={selectedModel}
        onChange={(event) => onModelChange(event.target.value)}
        disabled={isSubmitting}
      >
        {chatModels.map((model) => (
          <option key={model.name} value={model.name}>
            {model.name}
          </option>
        ))}
      </select>
      <button
        className="infio-ai-block-submit-button"
        onClick={onSubmit}
        disabled={isSubmitting}
      >
        {isSubmitting ? "Submitting..." : "Submit"}
      </button>
    </div>
  );
};
/**
 * Inline AI-edit block rendered inside a note: the user types an instruction,
 * picks a model, and the selected note text is rewritten by the LLM; the
 * result is shown in the apply (diff) view.
 */
export const InlineEdit: React.FC<InlineEditProps> = ({
  // `source` (raw block text) is currently unused — kept for interface compatibility.
  source,
  secStartLine,
  secEndLine,
  plugin,
  settings,
}) => {
  const [instruction, setInstruction] = useState("");
  const [selectedModel, setSelectedModel] = useState(settings.chatModelId);
  const [isSubmitting, setIsSubmitting] = useState(false);

  // Fix: memoize — the original constructed a new LLMManager and
  // PromptGenerator on every render (i.e. on every keystroke in the textarea).
  const llmManager = React.useMemo(
    () =>
      new LLMManager({
        deepseek: settings.deepseekApiKey,
        openai: settings.openAIApiKey,
        anthropic: settings.anthropicApiKey,
        gemini: settings.geminiApiKey,
        groq: settings.groqApiKey,
        infio: settings.infioApiKey,
      }),
    [settings]
  );
  const promptGenerator = React.useMemo(
    () =>
      new PromptGenerator(
        async () => {
          // Inline edit never performs retrieval.
          throw new Error("RAG not needed for inline edit");
        },
        plugin.app,
        settings
      ),
    [plugin.app, settings]
  );

  // Remove this inline-edit block (lines secStartLine..secEndLine) from the note.
  const handleClose = () => {
    const activeView = plugin.app.workspace.getActiveViewOfType(MarkdownView);
    if (!activeView?.editor) return;
    activeView.editor.replaceRange(
      "",
      { line: secStartLine, ch: 0 },
      { line: secEndLine + 1, ch: 0 }
    );
  };

  // Collect the active file, editor and current selection; each missing piece
  // is logged and the partial context returned so the caller can bail out.
  const getActiveContext = async () => {
    const activeFile = plugin.app.workspace.getActiveFile();
    if (!activeFile) {
      console.error("No active file");
      return {};
    }
    const editor = plugin.app.workspace.getActiveViewOfType(MarkdownView)?.editor;
    if (!editor) {
      console.error("No active editor");
      return { activeFile };
    }
    const selection = editor.getSelection();
    if (!selection) {
      console.error("No text selected");
      return { activeFile, editor };
    }
    return { activeFile, editor, selection };
  };

  // Parse `<infio_block startLine=".." endLine="..">content</infio_block>`
  // from the LLM response; returns null when no block is present.
  const parseSmartComposeBlock = (content: string) => {
    const match = content.match(/<infio_block[^>]*>([\s\S]*?)<\/infio_block>/);
    if (!match) {
      return null;
    }
    const blockContent = match[1].trim();
    const attributes = match[0].match(/startLine="(\d+)"/);
    const startLine = attributes ? parseInt(attributes[1]) : undefined;
    const endLineMatch = match[0].match(/endLine="(\d+)"/);
    const endLine = endLineMatch ? parseInt(endLineMatch[1]) : undefined;
    return {
      startLine,
      endLine,
      content: blockContent,
    };
  };

  // Send the instruction + selection to the LLM and open the diff view with
  // the proposed changes. `finally` always clears the submitting flag, so the
  // early returns don't need to reset it themselves.
  const handleSubmit = async () => {
    setIsSubmitting(true);
    try {
      const { activeFile, editor, selection } = await getActiveContext();
      if (!activeFile || !editor || !selection) {
        return;
      }
      const chatModel = settings.activeModels.find(
        (model) => model.name === selectedModel
      );
      if (!chatModel) {
        throw new Error("Invalid chat model");
      }
      // Convert the editor's 0-based selection to 1-based line numbers.
      const from = editor.getCursor("from");
      const to = editor.getCursor("to");
      const defaultStartLine = from.line + 1;
      const defaultEndLine = to.line + 1;
      const requestMessages = await promptGenerator.generateEditMessages({
        currentFile: activeFile,
        selectedContent: selection,
        instruction: instruction,
        startLine: defaultStartLine,
        endLine: defaultEndLine,
      });
      const response = await llmManager.generateResponse(chatModel, {
        model: chatModel.name,
        messages: requestMessages,
        stream: false,
      });
      const responseContent = response.choices[0].message.content;
      if (!responseContent) {
        throw new Error("Empty response from LLM");
      }
      const parsedBlock = parseSmartComposeBlock(responseContent);
      const finalContent = parsedBlock?.content || responseContent;
      // Fix: use ?? for the numeric line values so only a missing attribute
      // (not a falsy parse) falls back to the selection's lines.
      const startLine = parsedBlock?.startLine ?? defaultStartLine;
      const endLine = parsedBlock?.endLine ?? defaultEndLine;
      // Fix: read the file once — the original read it again after computing
      // the diff, risking an originalContent inconsistent with the applied base.
      const originalContent = await plugin.app.vault.read(activeFile);
      const updatedContent = await manualApplyChangesToFile(
        finalContent,
        activeFile,
        originalContent,
        startLine,
        endLine
      );
      if (!updatedContent) {
        console.error("Failed to apply changes");
        return;
      }
      await plugin.app.workspace.getLeaf(true).setViewState({
        type: APPLY_VIEW_TYPE,
        active: true,
        state: {
          file: activeFile,
          originalContent: removeAITags(originalContent),
          newContent: removeAITags(updatedContent),
        },
      });
    } catch (error) {
      console.error("Error in inline edit:", error);
    } finally {
      setIsSubmitting(false);
    }
  };

  return (
    <div className="infio-ai-block-container"
      id="infio-ai-block-container"
      style={{ backgroundColor: 'var(--background-secondary)' }}
    >
      <InputArea value={instruction} onChange={setInstruction} />
      {/* Close ("x") button removes the block from the note. */}
      <button className="infio-ai-block-close-button" onClick={handleClose}>
        <svg
          width="14"
          height="14"
          viewBox="0 0 24 24"
          fill="none"
          stroke="currentColor"
          strokeWidth="2"
        >
          <line x1="18" y1="6" x2="6" y2="18"></line>
          <line x1="6" y1="6" x2="18" y2="18"></line>
        </svg>
      </button>
      <ControlArea
        settings={settings}
        onSubmit={handleSubmit}
        selectedModel={selectedModel}
        onModelChange={setSelectedModel}
        isSubmitting={isSubmitting}
      />
    </div>
  );
};

157
src/constants.ts Normal file
View File

@@ -0,0 +1,157 @@
import { CustomLLMModel } from './types/llm/model'
// View type identifiers registered with Obsidian's workspace.
export const CHAT_VIEW_TYPE = 'infio-chat-view'
export const APPLY_VIEW_TYPE = 'infio-apply-view'
// Built-in chat and embedding models offered by default; `enabled` controls
// whether a model appears in pickers, `dimension` is set for embedding models.
export const DEFAULT_MODELS: CustomLLMModel[] = [
  {
    name: 'claude-3.5-sonnet',
    provider: 'anthropic',
    enabled: true,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'o1-mini',
    provider: 'openai',
    enabled: true,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'o1-preview',
    provider: 'openai',
    enabled: false,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'gpt-4o',
    provider: 'openai',
    enabled: true,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'gpt-4o-mini',
    provider: 'openai',
    enabled: false,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'deepseek-chat',
    provider: 'deepseek',
    enabled: true,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'gemini-1.5-pro',
    provider: 'google',
    enabled: true,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'gemini-2.0-flash-exp',
    provider: 'google',
    enabled: true,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'gemini-2.0-flash-thinking-exp-1219',
    provider: 'google',
    enabled: false,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'llama-3.1-70b-versatile',
    provider: 'groq',
    enabled: true,
    isEmbeddingModel: false,
    isBuiltIn: true,
  },
  {
    name: 'text-embedding-3-small',
    provider: 'openai',
    dimension: 1536,
    enabled: true,
    isEmbeddingModel: true,
    isBuiltIn: true,
  },
  {
    name: 'text-embedding-004',
    provider: 'google',
    dimension: 768,
    enabled: true,
    isEmbeddingModel: true,
    isBuiltIn: true,
  },
  {
    name: 'nomic-embed-text',
    provider: 'ollama',
    dimension: 768,
    enabled: true,
    isEmbeddingModel: true,
    isBuiltIn: true,
  },
  {
    name: 'mxbai-embed-large',
    provider: 'ollama',
    dimension: 1024,
    enabled: true,
    isEmbeddingModel: true,
    isBuiltIn: true,
  },
  {
    name: 'bge-m3',
    provider: 'ollama',
    dimension: 1024,
    enabled: true,
    isEmbeddingModel: true,
    isBuiltIn: true,
  }
]
// Embedding dimensions supported by the vector database.
// NOTE(review): "SIMENTION" is a typo of "DIMENSION", but renaming the export
// would break importers — rename only in a coordinated change.
export const SUPPORT_EMBEDDING_SIMENTION: number[] = [
  384,
  512,
  768,
  1024,
  1536
]
export const DEEPSEEK_BASE_URL = 'https://api.deepseek.com'
// Pricing in dollars per million tokens
type ModelPricing = {
  input: number
  output: number
}
export const OPENAI_PRICES: Record<string, ModelPricing> = {
  'gpt-4o': { input: 2.5, output: 10 },
  'gpt-4o-mini': { input: 0.15, output: 0.6 },
  'deepseek-chat': { input: 0.16, output: 0.32 },
}
export const ANTHROPIC_PRICES: Record<string, ModelPricing> = {
  'claude-3-5-sonnet-latest': { input: 3, output: 15 },
  'claude-3-5-haiku-latest': { input: 1, output: 5 },
}
// Gemini is currently free for low rate limits
export const GEMINI_PRICES: Record<string, ModelPricing> = {
  'gemini-1.5-pro': { input: 0, output: 0 },
  'gemini-1.5-flash': { input: 0, output: 0 },
}
export const GROQ_PRICES: Record<string, ModelPricing> = {
  'llama-3.1-70b-versatile': { input: 0.59, output: 0.79 },
  'llama-3.1-8b-instant': { input: 0.05, output: 0.08 },
}
// Archive path (inside the vault) where the PGlite vector DB is persisted.
export const PGLITE_DB_PATH = '.infio_vector_db.tar.gz'

View File

@@ -0,0 +1,23 @@
import { App } from 'obsidian'
import React from 'react'
// App context
// Provides the Obsidian `App` instance to the React tree.
const AppContext = React.createContext<App | undefined>(undefined)
export const AppProvider = ({
  children,
  app,
}: {
  children: React.ReactNode
  app: App
}) => {
  return <AppContext.Provider value={app}>{children}</AppContext.Provider>
}
// Returns the Obsidian `App`; throws when used outside an <AppProvider>.
export const useApp = () => {
  const app = React.useContext(AppContext)
  if (!app) {
    throw new Error('useApp must be used within an AppProvider')
  }
  return app
}

View File

@@ -0,0 +1,46 @@
import {
ReactNode,
createContext,
useContext,
useEffect,
useState,
} from 'react'
import { useApp } from './AppContext'
// Value exposed by the dark-mode context.
type DarkModeContextType = {
  isDarkMode: boolean
}
// Undefined outside a <DarkModeProvider>, which lets the hook detect misuse.
const DarkModeContext = createContext<DarkModeContextType | undefined>(
  undefined,
)
export function DarkModeProvider({ children }: { children: ReactNode }) {
const [isDarkMode, setIsDarkMode] = useState(false)
const app = useApp()
useEffect(() => {
const handleDarkMode = () => {
setIsDarkMode(document.body.classList.contains('theme-dark'))
}
handleDarkMode()
app.workspace.on('css-change', handleDarkMode)
return () => app.workspace.off('css-change', handleDarkMode)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
return (
<DarkModeContext.Provider value={{ isDarkMode }}>
{children}
</DarkModeContext.Provider>
)
}
/** Accessor for the dark-mode flag; must be rendered under DarkModeProvider. */
export function useDarkModeContext() {
	const context = useContext(DarkModeContext)
	if (context !== undefined) {
		return context
	}
	throw new Error('useDarkModeContext must be used within a DarkModeProvider')
}

View File

@@ -0,0 +1,58 @@
import {
createContext,
useCallback,
useContext,
useEffect,
useMemo,
} from 'react'
import { DBManager } from '../database/database-manager'
import { TemplateManager } from '../database/modules/template/template-manager'
import { VectorManager } from '../database/modules/vector/vector-manager'
type DatabaseContextType = {
	getDatabaseManager: () => Promise<DBManager>
	getVectorManager: () => Promise<VectorManager>
	getTemplateManager: () => Promise<TemplateManager>
}

const DatabaseContext = createContext<DatabaseContextType | null>(null)

/**
 * Exposes lazy accessors for the database manager and its sub-managers.
 * Initialization is kicked off eagerly in the background on mount.
 */
export function DatabaseProvider({
	children,
	getDatabaseManager,
}: {
	children: React.ReactNode
	getDatabaseManager: () => Promise<DBManager>
}) {
	const getVectorManager = useCallback(
		async () => (await getDatabaseManager()).getVectorManager(),
		[getDatabaseManager],
	)

	const getTemplateManager = useCallback(
		async () => (await getDatabaseManager()).getTemplateManager(),
		[getDatabaseManager],
	)

	// start initialization of dbManager in the background
	useEffect(() => {
		void getDatabaseManager()
	}, [getDatabaseManager])

	const value = useMemo(
		() => ({ getDatabaseManager, getVectorManager, getTemplateManager }),
		[getDatabaseManager, getVectorManager, getTemplateManager],
	)

	return (
		<DatabaseContext.Provider value={value}>
			{children}
		</DatabaseContext.Provider>
	)
}
/** Accessor for the database context; throws outside a DatabaseProvider. */
export function useDatabase(): DatabaseContextType {
	const context = useContext(DatabaseContext)
	if (context === null) {
		throw new Error('useDatabase must be used within a DatabaseProvider')
	}
	return context
}

View File

@@ -0,0 +1,27 @@
import React, { createContext, useContext } from 'react'
// Holds the DOM element dialogs/portals should mount into (null when absent).
const DialogContext = createContext<HTMLElement | null>(null)

/** Provides the dialog mount container to descendant components. */
export function DialogProvider({
	children,
	container,
}: {
	children: React.ReactNode
	container: HTMLElement | null
}) {
	return (
		<DialogContext.Provider value={container}>
			{children}
		</DialogContext.Provider>
	)
}
/**
 * Returns the dialog mount container.
 * @throws when no provider is present or the container is still null.
 */
export function useDialogContainer() {
	const container = useContext(DialogContext)
	if (container === null) {
		throw new Error(
			'useDialogContainer must be used within a DialogContainerProvider',
		)
	}
	return container
}

135
src/contexts/LLMContext.tsx Normal file
View File

@@ -0,0 +1,135 @@
import {
PropsWithChildren,
createContext,
useCallback,
useContext,
useEffect,
useMemo,
useState,
} from 'react'
import LLMManager from '../core/llm/manager'
import { CustomLLMModel } from '../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../types/llm/response'
import { useSettings } from './SettingsContext'
export type LLMContextType = {
	// One-shot (non-streaming) completion request.
	generateResponse: (
		model: CustomLLMModel,
		request: LLMRequestNonStreaming,
		options?: LLMOptions,
	) => Promise<LLMResponseNonStreaming>
	// Streaming completion request; resolves to an async chunk iterator.
	streamResponse: (
		model: CustomLLMModel,
		request: LLMRequestStreaming,
		options?: LLMOptions,
	) => Promise<AsyncIterable<LLMResponseStreaming>>
	// Models currently selected in settings for chat and for applying edits.
	chatModel: CustomLLMModel
	applyModel: CustomLLMModel
}

const LLMContext = createContext<LLMContextType | null>(null)
/**
 * Owns the LLMManager instance (recreated whenever any provider API key
 * changes) and resolves the currently selected chat/apply models from
 * settings.
 */
export function LLMProvider({ children }: PropsWithChildren) {
	const [llmManager, setLLMManager] = useState<LLMManager | null>(null)
	const { settings } = useSettings()

	// Model selected for chat; throws during render if the configured ID no
	// longer matches any active model.
	const chatModel = useMemo((): CustomLLMModel => {
		const model = settings.activeModels.find(
			(option) => option.name === settings.chatModelId,
		)
		if (!model) {
			throw new Error('Invalid chat model ID')
		}
		return model
	}, [settings])

	// Model used for "apply" edits. For Ollama, the apply-specific
	// model/baseURL from settings is substituted.
	// NOTE(review): the Ollama branch returns an object without a `name`
	// field — presumably CustomLLMModel permits this shape; verify.
	const applyModel = useMemo((): CustomLLMModel => {
		const model = settings.activeModels.find(
			(option) => option.name === settings.applyModelId,
		)
		if (!model) {
			throw new Error('Invalid apply model ID')
		}
		if (model.provider === 'ollama') {
			return {
				provider: 'ollama',
				baseURL: settings.ollamaApplyModel.baseUrl,
				model: settings.ollamaApplyModel.model,
			}
		}
		return model
	}, [settings])

	// Rebuild the manager only when one of the API keys actually changes.
	useEffect(() => {
		const manager = new LLMManager({
			deepseek: settings.deepseekApiKey,
			openai: settings.openAIApiKey,
			anthropic: settings.anthropicApiKey,
			gemini: settings.geminiApiKey,
			groq: settings.groqApiKey,
			infio: settings.infioApiKey,
		})
		setLLMManager(manager)
	}, [
		settings.deepseekApiKey,
		settings.openAIApiKey,
		settings.anthropicApiKey,
		settings.geminiApiKey,
		settings.groqApiKey,
		settings.infioApiKey,
	])

	// Wrappers that fail loudly if invoked before the manager exists
	// (i.e. before the effect above has run once).
	const generateResponse = useCallback(
		async (
			model: CustomLLMModel,
			request: LLMRequestNonStreaming,
			options?: LLMOptions,
		) => {
			if (!llmManager) {
				throw new Error('LLMManager is not initialized')
			}
			return await llmManager.generateResponse(model, request, options)
		},
		[llmManager],
	)

	const streamResponse = useCallback(
		async (
			model: CustomLLMModel,
			request: LLMRequestStreaming,
			options?: LLMOptions,
		) => {
			if (!llmManager) {
				throw new Error('LLMManager is not initialized')
			}
			return await llmManager.streamResponse(model, request, options)
		},
		[llmManager],
	)

	return (
		<LLMContext.Provider
			value={{ generateResponse, streamResponse, chatModel, applyModel }}
		>
			{children}
		</LLMContext.Provider>
	)
}
/** Accessor for the LLM context; throws outside an LLMProvider. */
export function useLLM(): LLMContextType {
	const value = useContext(LLMContext)
	if (value === null) {
		throw new Error('useLLM must be used within an LLMProvider')
	}
	return value
}

View File

@@ -0,0 +1,39 @@
import {
PropsWithChildren,
createContext,
useContext,
useEffect,
useMemo,
} from 'react'
import { RAGEngine } from '../core/rag/rag-engine'
export type RAGContextType = {
	getRAGEngine: () => Promise<RAGEngine>
}

const RAGContext = createContext<RAGContextType | null>(null)

/** Shares a lazy accessor for the RAG engine and warms it up on mount. */
export function RAGProvider({
	getRAGEngine,
	children,
}: PropsWithChildren<{ getRAGEngine: () => Promise<RAGEngine> }>) {
	// start initialization of ragEngine in the background
	useEffect(() => {
		void getRAGEngine()
	}, [getRAGEngine])

	const value = useMemo(() => ({ getRAGEngine }), [getRAGEngine])

	return <RAGContext.Provider value={value}>{children}</RAGContext.Provider>
}
/** Accessor for the RAG context; throws outside a RAGProvider. */
export function useRAG(): RAGContextType {
	const value = useContext(RAGContext)
	if (value === null) {
		throw new Error('useRAG must be used within a RAGProvider')
	}
	return value
}

View File

@@ -0,0 +1,58 @@
import React, { useEffect, useMemo, useState } from 'react'
import { InfioSettings } from '../types/settings'
// Current settings plus the setter that persists them back to the plugin.
type SettingsContextType = {
	settings: InfioSettings
	setSettings: (newSettings: InfioSettings) => void
}

// Settings context
const SettingsContext = React.createContext<SettingsContextType | undefined>(
	undefined,
)
/**
 * Caches the latest plugin settings in React state and re-renders consumers
 * whenever the plugin reports a settings change.
 */
export const SettingsProvider = ({
	children,
	settings: initialSettings,
	setSettings,
	addSettingsChangeListener,
}: {
	children: React.ReactNode
	settings: InfioSettings
	setSettings: (newSettings: InfioSettings) => void
	addSettingsChangeListener: (
		listener: (newSettings: InfioSettings) => void,
	) => () => void
}) => {
	const [settingsCached, setSettingsCached] = useState(initialSettings)

	// Keep the cached copy in sync with changes made elsewhere (e.g. the
	// settings tab). Fixed: `setSettings` was removed from the dependency
	// list — the effect never uses it, and listing it forced a needless
	// unsubscribe/resubscribe whenever its identity changed.
	useEffect(() => {
		const removeListener = addSettingsChangeListener((newSettings) => {
			setSettingsCached(newSettings)
		})
		return () => {
			removeListener()
		}
	}, [addSettingsChangeListener])

	const value = useMemo(
		() => ({ settings: settingsCached, setSettings }),
		[settingsCached, setSettings],
	)

	return (
		<SettingsContext.Provider value={value}>
			{children}
		</SettingsContext.Provider>
	)
}
/** Accessor for the settings context; throws outside a SettingsProvider. */
export const useSettings = () => {
	const context = React.useContext(SettingsContext)
	if (context === undefined) {
		throw new Error('useSettings must be used within a SettingsProvider')
	}
	return context
}

View File

@@ -0,0 +1,95 @@
import { generateRandomString } from "./utils";
// Random marker inserted at the cursor position so the regexes below can
// locate the cursor inside the reconstructed document text.
const UNIQUE_CURSOR = `${generateRandomString(16)}`;
// Line-anchored patterns classifying the line that contains the cursor marker.
const HEADER_REGEX = `^#+\\s.*${UNIQUE_CURSOR}.*$`;
const UNORDERED_LIST_REGEX = `^\\s*(-|\\*)\\s.*${UNIQUE_CURSOR}.*$`;
const TASK_LIST_REGEX = `^\\s*(-|[0-9]+\\.) +\\[.\\]\\s.*${UNIQUE_CURSOR}.*$`;
const BLOCK_QUOTES_REGEX = `^\\s*>.*${UNIQUE_CURSOR}.*$`;
const NUMBERED_LIST_REGEX = `^\\s*\\d+\\.\\s.*${UNIQUE_CURSOR}.*$`
// Region patterns: the cursor is "inside" when it falls within a match.
const MATH_BLOCK_REGEX = /\$\$[\s\S]*?\$\$/g;
const INLINE_MATH_BLOCK_REGEX = /\$[\s\S]*?\$/g;
const CODE_BLOCK_REGEX = /```[\s\S]*?```/g;
const INLINE_CODE_BLOCK_REGEX = /`.*`/g;

// Markdown context the cursor currently sits in; drives prompt selection
// and few-shot example filtering for autocompletion.
enum Context {
	Text = "Text",
	Heading = "Heading",
	BlockQuotes = "BlockQuotes",
	UnorderedList = "UnorderedList",
	NumberedList = "NumberedList",
	CodeBlock = "CodeBlock",
	MathBlock = "MathBlock",
	TaskList = "TaskList",
}
// eslint-disable-next-line @typescript-eslint/no-namespace
namespace Context {
	// Helper members merged onto the Context enum.
	export function values(): Array<Context> {
		return Object.values(Context).filter(
			(value) => typeof value === "string"
		) as Array<Context>;
	}

	// Classifies the cursor position by rebuilding the document with the
	// cursor marker inserted. Order matters: line-based contexts
	// (heading/quote/task) win over region-based ones (math/code), which in
	// turn win over the list types; plain text is the fallback.
	export function getContext(prefix: string, suffix: string): Context {
		if (new RegExp(HEADER_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.Heading;
		}
		if (new RegExp(BLOCK_QUOTES_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.BlockQuotes;
		}
		if (new RegExp(TASK_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.TaskList;
		}
		if (
			isCursorInRegexBlock(prefix, suffix, MATH_BLOCK_REGEX) ||
			isCursorInRegexBlock(prefix, suffix, INLINE_MATH_BLOCK_REGEX)
		) {
			return Context.MathBlock;
		}
		if (isCursorInRegexBlock(prefix, suffix, CODE_BLOCK_REGEX) || isCursorInRegexBlock(prefix, suffix, INLINE_CODE_BLOCK_REGEX)) {
			return Context.CodeBlock;
		}
		if (new RegExp(NUMBERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.NumberedList;
		}
		if (new RegExp(UNORDERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
			return Context.UnorderedList;
		}
		return Context.Text;
	}

	// Parses a stored string back into a Context member, if it is one.
	export function get(value: string) {
		for (const context of Context.values()) {
			if (value === context) {
				return context;
			}
		}
		return undefined;
	}
}
/**
 * True when the cursor (marked by UNIQUE_CURSOR) falls inside any region of
 * `prefix + suffix` matched by `regex`.
 */
function isCursorInRegexBlock(
	prefix: string,
	suffix: string,
	regex: RegExp
): boolean {
	const text = prefix + UNIQUE_CURSOR + suffix;
	return extractBlocks(text, regex).some((block) =>
		block.includes(UNIQUE_CURSOR)
	);
}
/** Returns all trimmed matches of `regex` in `text` (empty list when none). */
function extractBlocks(text: string, regex: RegExp) {
	const matches = text.match(regex);
	return matches === null ? [] : matches.map((block) => block.trim());
}
export default Context;

View File

@@ -0,0 +1,273 @@
import * as Handlebars from "handlebars";
import { err, ok, Result } from "neverthrow";
import { FewShotExample } from "../../settings/versions";
import { CustomLLMModel } from "../../types/llm/model";
import { RequestMessage } from '../../types/llm/request';
import { InfioSettings } from "../../types/settings";
import LLMManager from '../llm/manager';
import Context from "./context-detection";
import RemoveCodeIndicators from "./post-processors/remove-code-indicators";
import RemoveMathIndicators from "./post-processors/remove-math-indicators";
import RemoveOverlap from "./post-processors/remove-overlap";
import RemoveWhitespace from "./post-processors/remove-whitespace";
import DataViewRemover from "./pre-processors/data-view-remover";
import LengthLimiter from "./pre-processors/length-limiter";
import {
AutocompleteService,
ChatMessage,
PostProcessor,
PreProcessor,
UserMessageFormatter,
UserMessageFormattingInputs
} from "./types";
class LLMClient {
private llm: LLMManager;
private model: CustomLLMModel;
constructor(llm: LLMManager, model: CustomLLMModel) {
this.llm = llm;
this.model = model;
}
async queryChatModel(messages: RequestMessage[]): Promise<Result<string, Error>> {
const data = await this.llm.generateResponse(this.model, {
model: this.model.name,
messages: messages,
stream: false,
})
return ok(data.choices[0].message.content);
}
}
/**
 * Orchestrates inline autocompletion: pre-processes the cursor context,
 * builds a few-shot chat prompt for the detected Markdown context, queries
 * the chat model, strips chain-of-thought, and post-processes the answer.
 */
class AutoComplete implements AutocompleteService {
	private readonly client: LLMClient;
	private readonly systemMessage: string;
	private readonly userMessageFormatter: UserMessageFormatter;
	// Regex (string form) matching the chain-of-thought preamble to strip.
	private readonly removePreAnswerGenerationRegex: string;
	private readonly preProcessors: PreProcessor[];
	private readonly postProcessors: PostProcessor[];
	private readonly fewShotExamples: FewShotExample[];
	private debugMode: boolean;

	private constructor(
		client: LLMClient,
		systemMessage: string,
		userMessageFormatter: UserMessageFormatter,
		removePreAnswerGenerationRegex: string,
		preProcessors: PreProcessor[],
		postProcessors: PostProcessor[],
		fewShotExamples: FewShotExample[],
		debugMode: boolean,
	) {
		this.client = client;
		this.systemMessage = systemMessage;
		this.userMessageFormatter = userMessageFormatter;
		this.removePreAnswerGenerationRegex = removePreAnswerGenerationRegex;
		this.preProcessors = preProcessors;
		this.postProcessors = postProcessors;
		this.fewShotExamples = fewShotExamples;
		this.debugMode = debugMode;
	}

	/** Builds a fully wired AutoComplete from plugin settings. */
	public static fromSettings(settings: InfioSettings): AutocompleteService {
		const formatter = Handlebars.compile<UserMessageFormattingInputs>(
			settings.userMessageTemplate,
			{ noEscape: true, strict: true }
		);

		const preProcessors: PreProcessor[] = [];
		if (settings.dontIncludeDataviews) {
			preProcessors.push(new DataViewRemover());
		}
		preProcessors.push(
			new LengthLimiter(
				settings.maxPrefixCharLimit,
				settings.maxSuffixCharLimit
			)
		);

		const postProcessors: PostProcessor[] = [];
		if (settings.removeDuplicateMathBlockIndicator) {
			postProcessors.push(new RemoveMathIndicators());
		}
		if (settings.removeDuplicateCodeBlockIndicator) {
			postProcessors.push(new RemoveCodeIndicators());
		}
		postProcessors.push(new RemoveOverlap());
		postProcessors.push(new RemoveWhitespace());

		const llm_manager = new LLMManager({
			deepseek: settings.deepseekApiKey,
			openai: settings.openAIApiKey,
			anthropic: settings.anthropicApiKey,
			gemini: settings.geminiApiKey,
			groq: settings.groqApiKey,
			infio: settings.infioApiKey,
		})
		// Fixed: `find` may return undefined; fail fast (same message as the
		// LLM context) instead of building a client around undefined.
		const model = settings.activeModels.find(
			(option) => option.name === settings.chatModelId,
		)
		if (!model) {
			throw new Error('Invalid chat model ID')
		}
		const llm = new LLMClient(llm_manager, model);

		return new AutoComplete(
			llm,
			settings.systemMessage,
			formatter,
			settings.chainOfThoughRemovalRegex,
			preProcessors,
			postProcessors,
			settings.fewShotExamples,
			settings.debugMode,
		);
	}

	/**
	 * Produces a completion for the text around the cursor.
	 * Returns ok("") when a pre-processor decides no prediction should run.
	 */
	async fetchPredictions(
		prefix: string,
		suffix: string
	): Promise<Result<string, Error>> {
		const context: Context = Context.getContext(prefix, suffix);

		for (const preProcessor of this.preProcessors) {
			if (preProcessor.removesCursor(prefix, suffix)) {
				return ok("");
			}
			({ prefix, suffix } = preProcessor.process(
				prefix,
				suffix,
				context
			));
		}

		// Only examples matching the detected context are useful as few-shots.
		const examples = this.fewShotExamples.filter(
			(example) => example.context === context
		);
		const fewShotExamplesChatMessages =
			fewShotExamplesToChatMessages(examples);

		const messages: RequestMessage[] = [
			{
				content: this.getSystemMessageFor(context),
				role: "system"
			},
			...fewShotExamplesChatMessages,
			{
				role: "user",
				content: this.userMessageFormatter({
					suffix,
					prefix,
				}),
			},
		];

		if (this.debugMode) {
			console.log("Copilot messages send:\n", messages);
		}

		let result = await this.client.queryChatModel(messages);

		if (this.debugMode && result.isOk()) {
			console.log("Copilot response:\n", result.value);
		}

		result = this.extractAnswerFromChainOfThoughts(result);

		for (const postProcessor of this.postProcessors) {
			result = result.map((r) => postProcessor.process(prefix, suffix, r, context));
		}

		result = this.checkAgainstGuardRails(result);

		return result;
	}

	// Appends context-specific guidance to the base system prompt.
	private getSystemMessageFor(context: Context): string {
		if (context === Context.Text) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a paragraph. Your answer must complete this paragraph or sentence in a way that fits the surrounding text without overlapping with it. It must be in the same language as the paragraph.";
		}
		if (context === Context.Heading) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in the Markdown heading. Your answer must complete this title in a way that fits the content of this paragraph and be in the same language as the paragraph.";
		}
		if (context === Context.BlockQuotes) {
			return this.systemMessage + "\n\n" + "The <mask/> is located within a quote. Your answer must complete this quote in a way that fits the context of the paragraph.";
		}
		if (context === Context.UnorderedList) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in an unordered list. Your answer must include one or more list items that fit with the surrounding list without overlapping with it.";
		}
		if (context === Context.NumberedList) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a numbered list. Your answer must include one or more list items that fit the sequence and context of the surrounding list without overlapping with it.";
		}
		if (context === Context.CodeBlock) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a code block. Your answer must complete this code block in the same programming language and support the surrounding code and text outside of the code block.";
		}
		if (context === Context.MathBlock) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a math block. Your answer must only contain LaTeX code that captures the math discussed in the surrounding text. No text or explaination only LaTex math code.";
		}
		if (context === Context.TaskList) {
			return this.systemMessage + "\n\n" + "The <mask/> is located in a task list. Your answer must include one or more (sub)tasks that are logical given the other tasks and the surrounding text.";
		}

		return this.systemMessage;
	}

	// Strips the chain-of-thought preamble; errors when the configured
	// marker regex never matches (the model ignored the answer format).
	private extractAnswerFromChainOfThoughts(
		result: Result<string, Error>
	): Result<string, Error> {
		if (result.isErr()) {
			return result;
		}
		const chainOfThoughts = result.value;

		const regex = new RegExp(this.removePreAnswerGenerationRegex, "gm");
		const match = regex.exec(chainOfThoughts);
		if (match === null) {
			return err(new Error("No match found"));
		}
		return ok(chainOfThoughts.replace(regex, ""));
	}

	// Rejects empty answers and answers that still contain the mask token.
	private checkAgainstGuardRails(
		result: Result<string, Error>
	): Result<string, Error> {
		if (result.isErr()) {
			return result;
		}
		if (result.value.length === 0) {
			return err(new Error("Empty result"));
		}
		// Fixed: use the standard `String.prototype.includes` rather than
		// Obsidian's non-standard `contains` prototype extension.
		if (result.value.includes("<mask/>")) {
			return err(new Error("Mask in result"));
		}

		return result;
	}
}
/** Expands few-shot examples into alternating user/assistant chat messages. */
function fewShotExamplesToChatMessages(
	examples: FewShotExample[]
): ChatMessage[] {
	return examples.flatMap((example): ChatMessage[] => [
		{ role: "user", content: example.input },
		{ role: "assistant", content: example.answer },
	]);
}
export default AutoComplete;

View File

@@ -0,0 +1,22 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
/**
 * Strips Markdown code-fence/backtick markers from a completion generated
 * inside a code block, since the surrounding document already supplies them.
 */
class RemoveCodeIndicators implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		if (context === Context.CodeBlock) {
			// Opening fence with a language tag, e.g. ```ts
			// Fixed: the original class `[a-zA-z]` also matched the ASCII
			// range between 'Z' and 'a' (the characters [ \ ] ^ _ `).
			completion = completion.replace(/```[a-zA-Z]+[ \t]*\n?/g, "");
			// Bare opening/closing fences.
			completion = completion.replace(/\n?```[ \t]*\n?/g, "");
			// Any leftover inline backticks.
			completion = completion.replace(/`/g, "");
		}
		return completion;
	}
}

export default RemoveCodeIndicators;

View File

@@ -0,0 +1,20 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
/**
 * Strips LaTeX math delimiters from a completion generated inside a math
 * block; the surrounding document already provides them.
 */
class RemoveMathIndicators implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		if (context !== Context.MathBlock) {
			return completion;
		}
		// Drop $$ block delimiters (with adjacent newlines), then any
		// remaining single $ signs.
		return completion
			.replace(/\n?\$\$\n?/g, "")
			.replace(/\$/g, "");
	}
}

export default RemoveMathIndicators;

View File

@@ -0,0 +1,96 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
/**
 * Removes text from the completion that overlaps with the document around
 * the cursor: whole words already present at the end of the prefix or the
 * start of the suffix, and whitespace duplicated at either boundary.
 */
class RemoveOverlap implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		completion = removeWordOverlapPrefix(prefix, completion);
		completion = removeWordOverlapSuffix(completion, suffix);
		// Fixed: the original passed `suffix` here, so whitespace shared
		// with the *prefix* was never removed.
		completion = removeWhiteSpaceOverlapPrefix(prefix, completion);
		completion = removeWhiteSpaceOverlapSuffix(completion, suffix);
		return completion;
	}
}

// Drops leading characters of the completion that mirror the trailing
// characters of the prefix, compared one-by-one from the boundary outwards.
function removeWhiteSpaceOverlapPrefix(prefix: string, completion: string): string {
	let prefixIdx = prefix.length - 1;
	while (completion.length > 0 && completion[0] === prefix[prefixIdx]) {
		completion = completion.slice(1);
		prefixIdx--;
	}
	return completion;
}

// Drops trailing characters of the completion that mirror the leading
// characters of the suffix.
function removeWhiteSpaceOverlapSuffix(completion: string, suffix: string): string {
	let suffixIdx = 0;
	while (completion.length > 0 && completion[completion.length - 1] === suffix[suffixIdx]) {
		completion = completion.slice(0, -1);
		suffixIdx++;
	}
	return completion;
}

// If the completion starts by repeating a word-aligned tail of the prefix,
// removes that repeated tail from the completion.
function removeWordOverlapPrefix(prefix: string, completion: string): string {
	const rightTrimmed = completion.trimStart();
	const startIdxOfEachWord = startLocationOfEachWord(prefix);

	// Try the shortest prefix tail first (word starts popped back-to-front).
	while (startIdxOfEachWord.length > 0) {
		const idx = startIdxOfEachWord.pop();
		const leftSubstring = prefix.slice(idx);

		if (rightTrimmed.startsWith(leftSubstring)) {
			return rightTrimmed.replace(leftSubstring, "");
		}
	}

	return completion;
}

// If the completion ends with a word-aligned run that the suffix begins
// with, removes that run from the completion.
function removeWordOverlapSuffix(completion: string, suffix: string): string {
	const suffixTrimmed = removeLeadingWhiteSpace(suffix);

	const startIdxOfEachWord = startLocationOfEachWord(completion);

	while (startIdxOfEachWord.length > 0) {
		const idx = startIdxOfEachWord.pop();
		const suffixSubstring = completion.slice(idx);

		if (suffixTrimmed.startsWith(suffixSubstring)) {
			return completion.replace(suffixSubstring, "");
		}
	}
	return completion;
}

// Strips leading horizontal whitespace (not newlines).
function removeLeadingWhiteSpace(completion: string): string {
	return completion.replace(/^[ \t\f\r\v]+/, "");
}

// Indices where each whitespace-separated word starts in `text`.
function startLocationOfEachWord(text: string): number[] {
	const locations: number[] = [];
	if (text.length > 0 && !isWhiteSpaceChar(text[0])) {
		locations.push(0);
	}

	for (let i = 1; i < text.length; i++) {
		if (isWhiteSpaceChar(text[i - 1]) && !isWhiteSpaceChar(text[i])) {
			locations.push(i);
		}
	}

	return locations;
}

function isWhiteSpaceChar(char: string | undefined): boolean {
	return char !== undefined && char.match(/\s/) !== null;
}

export default RemoveOverlap;

View File

@@ -0,0 +1,24 @@
import Context from "../context-detection";
import { PostProcessor } from "../types";
// Trims stray leading/trailing whitespace from completions in text-like
// contexts, where the editor already supplies separating whitespace.
class RemoveWhitespace implements PostProcessor {
	process(
		prefix: string,
		suffix: string,
		completion: string,
		context: Context
	): string {
		if (context === Context.Text || context === Context.Heading || context === Context.MathBlock || context === Context.TaskList || context === Context.NumberedList || context === Context.UnorderedList) {
			// NOTE(review): `suffix.endsWith("\n")` looks suspicious —
			// trimming the completion's *start* based on how the suffix ends
			// is odd; this was presumably meant to be `prefix.endsWith("\n")`.
			// Confirm against upstream before changing.
			if (prefix.endsWith(" ") || suffix.endsWith("\n")) {
				completion = completion.trimStart();
			}
			if (suffix.startsWith(" ")) {
				completion = completion.trimEnd();
			}
		}
		return completion;
	}
}

export default RemoveWhitespace;

View File

@@ -0,0 +1,34 @@
import { generateRandomString } from "../utils";
import Context from "../context-detection";
import { PrefixAndSuffix, PreProcessor } from "../types";
// Matches ```dataview and ```dataviewjs fenced blocks.
const DATA_VIEW_REGEX = /```dataview(js){0,1}(.|\n)*?```/gm;
// Random marker used to re-locate the cursor after text surgery.
const UNIQUE_CURSOR = `${generateRandomString(16)}`;

/**
 * Strips dataview/dataviewjs blocks from the context sent to the model, and
 * refuses to run at all when the cursor itself sits inside such a block.
 */
class DataViewRemover implements PreProcessor {
	process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
		// Mark the cursor, strip every dataview block, then split again.
		const cleaned = (prefix + UNIQUE_CURSOR + suffix).replace(
			DATA_VIEW_REGEX,
			""
		);
		const [prefixNew, suffixNew] = cleaned.split(UNIQUE_CURSOR);
		return { prefix: prefixNew, suffix: suffixNew };
	}

	removesCursor(prefix: string, suffix: string): boolean {
		const text = prefix + UNIQUE_CURSOR + suffix;
		const blocksWithCursor = text
			.match(DATA_VIEW_REGEX)
			?.filter((dataviewArea) => dataviewArea.includes(UNIQUE_CURSOR));
		return blocksWithCursor !== undefined && blocksWithCursor.length > 0;
	}
}

export default DataViewRemover;

View File

@@ -0,0 +1,24 @@
import Context from "../context-detection";
import { PrefixAndSuffix, PreProcessor } from "../types";
/**
 * Caps the context sent to the model: keeps only the last `maxPrefixChars`
 * of the prefix and the first `maxSuffixChars` of the suffix.
 */
class LengthLimiter implements PreProcessor {
	constructor(
		private readonly maxPrefixChars: number,
		private readonly maxSuffixChars: number
	) { }

	process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
		return {
			prefix: prefix.slice(-this.maxPrefixChars),
			suffix: suffix.slice(0, this.maxSuffixChars),
		};
	}

	removesCursor(prefix: string, suffix: string): boolean {
		// Truncation never deletes the cursor position itself.
		return false;
	}
}

export default LengthLimiter;

View File

@@ -0,0 +1,35 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import State from "./state";
/**
 * Active while autocomplete is disabled for the current file, either because
 * its path is ignored or because it contains an ignored tag.
 */
class DisabledFileSpecificState extends State {
	getStatusBarText(): string {
		return "Disabled for this file";
	}

	handleSettingChanged(settings: InfioSettings) {
		if (!this.context.settings.autocompleteEnabled) {
			// A global disable takes precedence over any per-file state.
			this.context.transitionToDisabledManualState();
			return;
		}
		// Leave this state only when *neither* per-file reason still applies.
		// (Fixed: the original used `||`, which re-enabled the plugin while
		// the file was still ignored by path or by tag, and the missing early
		// return above could trigger a second transition right after moving
		// to the disabled state.)
		if (!this.context.isCurrentFilePathIgnored() && !this.context.currentFileContainsIgnoredTag()) {
			this.context.transitionToIdleState();
		}
	}

	handleFileChange(file: TFile): void {
		if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
			return;
		}
		if (this.context.settings.autocompleteEnabled) {
			this.context.transitionToIdleState();
		} else {
			this.context.transitionToDisabledManualState();
		}
	}
}

export default DisabledFileSpecificState;

View File

@@ -0,0 +1,25 @@
import { InfioSettings } from "../../../types/settings";
import { checkForErrors } from "../../../utils/auto-complete";
import State from "./state";
/**
 * Active while plugin settings contain validation errors; leaves this state
 * only once the user fixes them.
 */
class DisabledInvalidSettingsState extends State {
	getStatusBarText(): string {
		return "Disabled invalid settings";
	}

	handleSettingChanged(settings: InfioSettings) {
		// Stay disabled while any validation error remains.
		if (checkForErrors(settings).size > 0) {
			return
		}
		if (this.context.settings.autocompleteEnabled) {
			this.context.transitionToIdleState();
		} else {
			this.context.transitionToDisabledManualState();
		}
	}
}

export default DisabledInvalidSettingsState;

View File

@@ -0,0 +1,21 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import State from "./state";
// Active while the user has switched autocomplete off; only a settings
// change can re-enable it.
class DisabledManualState extends State {
	getStatusBarText(): string {
		return "Disabled";
	}

	handleSettingChanged(settings: InfioSettings): void {
		if (this.context.settings.autocompleteEnabled) {
			this.context.transitionToIdleState();
		}
	}

	// Switching files never leaves a manual disable.
	handleFileChange(file: TFile): void { }
}

export default DisabledManualState;

View File

@@ -0,0 +1,46 @@
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
/**
 * Resting state: watches document changes and decides whether to surface a
 * cached suggestion or queue a fresh prediction.
 */
class IdleState extends State {
	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
		// React only to plain, focused, single-cursor edits; ignore
		// deletions, undo/redo, selections and unfocused documents.
		const shouldIgnore =
			!documentChanges.isDocInFocus() ||
			!documentChanges.hasDocChanged() ||
			documentChanges.hasUserDeleted() ||
			documentChanges.hasMultipleCursors() ||
			documentChanges.hasSelection() ||
			documentChanges.hasUserUndone() ||
			documentChanges.hasUserRedone();
		if (shouldIgnore) {
			return;
		}

		const prefix = documentChanges.getPrefix();
		const suffix = documentChanges.getSuffix();

		// Prefer a cached suggestion over a fresh (and costly) prediction.
		const cachedSuggestion = this.context.getCachedSuggestionFor(prefix, suffix);
		const hasCachedSuggestion =
			cachedSuggestion !== undefined && cachedSuggestion.trim().length > 0;
		if (this.context.settings.cacheSuggestions && hasCachedSuggestion) {
			this.context.transitionToSuggestingState(cachedSuggestion, prefix, suffix);
			return;
		}

		if (this.context.containsTriggerCharacters(documentChanges)) {
			this.context.transitionToQueuedState(prefix, suffix);
		}
	}

	handlePredictCommand(prefix: string, suffix: string): void {
		this.context.transitionToPredictingState(prefix, suffix);
	}

	getStatusBarText(): string {
		return "Idle";
	}
}

export default IdleState;

View File

@@ -0,0 +1,38 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import { EventHandler } from "./types";
/**
 * Placeholder state used while the plugin is still booting; every event is
 * ignored and no key press is consumed.
 */
class InitState implements EventHandler {
	async handleDocumentChange(documentChanges: DocumentChanges): Promise<void> { }

	handleSettingChanged(settings: InfioSettings): void { }

	handleFileChange(file: TFile): void {
	}

	// Returning false lets Obsidian handle the keystroke normally.
	handleAcceptKeyPressed(): boolean {
		return false;
	}

	handlePartialAcceptKeyPressed(): boolean {
		return false;
	}

	handleCancelKeyPressed(): boolean {
		return false;
	}

	handlePredictCommand(): void { }

	handleAcceptCommand(): void { }

	getStatusBarText(): string {
		return "Initializing...";
	}
}

export default InitState;

View File

@@ -0,0 +1,94 @@
import { Notice } from "obsidian";
import Context from "../context-detection";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
/**
 * State that owns an in-flight prediction request. The request itself cannot
 * be aborted, so cancellation is modeled by ignoring the response once it
 * arrives (`isStillNeeded`).
 */
class PredictingState extends State {
	private predictionPromise: Promise<void> | null = null;
	private isStillNeeded = true;
	private readonly prefix: string;
	private readonly suffix: string;

	constructor(context: EventListener, prefix: string, suffix: string) {
		super(context);
		this.prefix = prefix;
		this.suffix = suffix;
	}

	static createAndStartPredicting(
		context: EventListener,
		prefix: string,
		suffix: string
	): PredictingState {
		const predictingState = new PredictingState(context, prefix, suffix);
		predictingState.startPredicting();
		context.setContext(Context.getContext(prefix, suffix));
		return predictingState;
	}

	handleCancelKeyPressed(): boolean {
		this.cancelPrediction();
		return true;
	}

	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
		// Any further edit or cursor movement invalidates the pending request.
		if (
			documentChanges.hasCursorMoved() ||
			documentChanges.hasUserTyped() ||
			documentChanges.hasUserDeleted() ||
			documentChanges.isTextAdded()
		) {
			this.cancelPrediction();
		}
	}

	private cancelPrediction(): void {
		this.isStillNeeded = false;
		this.context.transitionToIdleState();
	}

	startPredicting(): void {
		this.predictionPromise = this.predict();
	}

	private async predict(): Promise<void> {
		const result =
			await this.context.autocomplete?.fetchPredictions(
				this.prefix,
				this.suffix
			);

		// The user moved on while we were waiting; drop the response.
		if (!this.isStillNeeded) {
			return;
		}

		// Fixed: `autocomplete` is accessed with `?.`, so `result` can be
		// undefined; the original then crashed on `result.isErr()`.
		if (result === undefined) {
			this.context.transitionToIdleState();
			return;
		}

		if (result.isErr()) {
			new Notice(
				`Copilot: Something went wrong cannot make a prediction. Full error is available in the dev console. Please check your settings. `
			);
			console.error(result.error);
			this.context.transitionToIdleState();
			// Fixed: bail out here — the original fell through and called
			// transitionToIdleState a second time via the empty-prediction path.
			return;
		}

		const prediction = result.unwrapOr("");
		if (prediction === "") {
			this.context.transitionToIdleState();
			return;
		}
		this.context.transitionToSuggestingState(prediction, this.prefix, this.suffix);
	}

	getStatusBarText(): string {
		return `Predicting for ${this.context.context}`;
	}
}

export default PredictingState;

View File

@@ -0,0 +1,84 @@
import Context from "../context-detection";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
// Debounce state: waits `settings.delay` ms after the last trigger before
// moving on to an actual prediction request.
class QueuedState extends State {
	private timer: ReturnType<typeof setTimeout> | null = null;
	private readonly prefix: string;
	private readonly suffix: string;

	private constructor(
		context: EventListener,
		prefix: string,
		suffix: string
	) {
		super(context);
		this.prefix = prefix;
		this.suffix = suffix;
	}

	static createAndStartTimer(
		context: EventListener,
		prefix: string,
		suffix: string
	): QueuedState {
		const state = new QueuedState(context, prefix, suffix);
		state.startTimer();
		context.setContext(Context.getContext(prefix, suffix));
		return state;
	}

	handleCancelKeyPressed(): boolean {
		this.cancelTimer();
		this.context.transitionToIdleState();
		return true;
	}

	async handleDocumentChange(
		documentChanges: DocumentChanges
	): Promise<void> {
		// New trigger characters restart the debounce with fresh context.
		if (
			documentChanges.isDocInFocus() &&
			documentChanges.isTextAdded() &&
			this.context.containsTriggerCharacters(documentChanges)
		) {
			this.cancelTimer();
			this.context.transitionToQueuedState(documentChanges.getPrefix(), documentChanges.getSuffix());
			return
		}
		// Any other edit, cursor move, or focus loss abandons the queued run.
		if (
			(documentChanges.hasCursorMoved() ||
				documentChanges.hasUserTyped() ||
				documentChanges.hasUserDeleted() ||
				documentChanges.isTextAdded() ||
				!documentChanges.isDocInFocus())
		) {
			this.cancelTimer();
			this.context.transitionToIdleState();
		}
	}

	startTimer(): void {
		// Restart from zero: only the most recent trigger counts.
		this.cancelTimer();
		this.timer = setTimeout(() => {
			this.context.transitionToPredictingState(this.prefix, this.suffix);
		}, this.context.settings.delay);
	}

	private cancelTimer(): void {
		if (this.timer !== null) {
			clearTimeout(this.timer);
			this.timer = null;
		}
	}

	getStatusBarText(): string {
		return `Queued (${this.context.settings.delay} ms)`;
	}
}

export default QueuedState;

View File

@@ -0,0 +1,67 @@
import { Notice, TFile } from "obsidian";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
// import { Settings } from "../settings/versions";
import { InfioSettings } from "../../../types/settings";
import { checkForErrors } from "../../../utils/auto-complete";
import { EventHandler } from "./types";
/**
 * Base class for all autocomplete state-machine states. Provides no-op
 * defaults for every event so concrete states only override what they need.
 * Boolean key handlers report whether the key press was consumed.
 */
abstract class State implements EventHandler {
  protected readonly context: EventListener;
  constructor(context: EventListener) {
    this.context = context;
  }
  // React to settings changes: disable the plugin when autocomplete is turned
  // off or the settings contain errors, and disable per-file when the current
  // file is ignored by path or by tag.
  handleSettingChanged(settings: InfioSettings): void {
    const settingErrors = checkForErrors(settings);
    if (!settings.autocompleteEnabled) {
      new Notice("Copilot is now disabled.");
      this.context.transitionToDisabledManualState()
    } else if (settingErrors.size > 0) {
      new Notice(
        `Copilot: There are ${settingErrors.size} errors in your settings. The plugin will be disabled until they are fixed.`
      );
      this.context.transitionToDisabledInvalidSettingsState();
    } else if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
      this.context.transitionToDisabledFileSpecificState();
    }
  }
  // Default: document changes are ignored.
  async handleDocumentChange(
    documentChanges: DocumentChanges
  ): Promise<void> {
  }
  // Defaults: key presses are not consumed (fall through to the editor).
  handleAcceptKeyPressed(): boolean {
    return false;
  }
  handlePartialAcceptKeyPressed(): boolean {
    return false;
  }
  handleCancelKeyPressed(): boolean {
    return false;
  }
  // Defaults: commands are no-ops.
  handlePredictCommand(prefix: string, suffix: string): void {
  }
  handleAcceptCommand(): void {
  }
  // Each state describes itself for the status bar.
  abstract getStatusBarText(): string;
  // On file switch, re-evaluate whether the new file is ignored; re-enable
  // (back to idle) when leaving a file-specific disabled state.
  handleFileChange(file: TFile): void {
    if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
      this.context.transitionToDisabledFileSpecificState();
    } else if (this.context.isDisabled()) {
      this.context.transitionToIdleState();
    }
  }
}
export default State;

View File

@@ -0,0 +1,177 @@
import { Settings } from "../../../settings/versions";
import { extractNextWordAndRemaining } from "../utils";
import EventListener from "../../../event-listener";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
import State from "./state";
/**
 * State active while a completion is rendered inline. Tracks the suggestion
 * together with the prefix/suffix it was generated for, handles full and
 * partial acceptance, and keeps the suggestion in sync as the user types.
 */
class SuggestingState extends State {
  private readonly suggestion: string;
  private readonly prefix: string;
  private readonly suffix: string;
  constructor(context: EventListener, suggestion: string, prefix: string, suffix: string) {
    super(context);
    this.suggestion = suggestion;
    this.prefix = prefix;
    this.suffix = suffix;
  }
  async handleDocumentChange(
    documentChanges: DocumentChanges
  ): Promise<void> {
    // Navigation, undo/redo, deletion, selection, multiple cursors, or focus
    // loss all invalidate the current suggestion.
    if (
      documentChanges.hasCursorMoved()
      || documentChanges.hasUserUndone()
      || documentChanges.hasUserDeleted()
      || documentChanges.hasUserRedone()
      || !documentChanges.isDocInFocus()
      || documentChanges.hasSelection()
      || documentChanges.hasMultipleCursors()
    ) {
      this.clearPrediction();
      return;
    }
    // Nothing user-driven changed: keep the suggestion as-is.
    if (
      documentChanges.noUserEvents()
      || !documentChanges.hasDocChanged()
    ) {
      return;
    }
    // The user typed exactly (part of) the suggestion: keep showing the rest.
    if (this.hasUserAddedPartOfSuggestion(documentChanges)) {
      this.acceptPartialAddedText(documentChanges);
      return
    }
    // Otherwise, try to serve a cached suggestion for the new cursor context.
    const currentPrefix = documentChanges.getPrefix();
    const currentSuffix = documentChanges.getSuffix();
    const suggestion = this.context.getCachedSuggestionFor(currentPrefix, currentSuffix);
    const isThereCachedSuggestion = suggestion !== undefined;
    const isCachedSuggestionDifferent = suggestion !== this.suggestion;
    if (!isCachedSuggestionDifferent) {
      return;
    }
    if (isThereCachedSuggestion) {
      this.context.transitionToSuggestingState(suggestion, currentPrefix, currentSuffix);
      return;
    }
    this.clearPrediction();
  }
  // True when the text added before/after the cursor matches the start/end of
  // the suggestion (case-insensitive).
  hasUserAddedPartOfSuggestion(documentChanges: DocumentChanges): boolean {
    const addedPrefixText = documentChanges.getAddedPrefixText();
    const addedSuffixText = documentChanges.getAddedSuffixText();
    return addedPrefixText !== undefined
      && addedSuffixText !== undefined
      && this.suggestion.toLowerCase().startsWith(addedPrefixText.toLowerCase())
      && this.suggestion.toLowerCase().endsWith(addedSuffixText.toLowerCase());
  }
  // Shrink the suggestion by the part the user already typed; clear it when
  // only whitespace remains.
  acceptPartialAddedText(documentChanges: DocumentChanges): void {
    const addedPrefixText = documentChanges.getAddedPrefixText();
    const addedSuffixText = documentChanges.getAddedSuffixText();
    if (addedSuffixText === undefined || addedPrefixText === undefined) {
      return;
    }
    const startIdx = addedPrefixText.length;
    const endIdx = this.suggestion.length - addedSuffixText.length
    const remainingSuggestion = this.suggestion.substring(startIdx, endIdx);
    if (remainingSuggestion.trim() === "") {
      this.clearPrediction();
    } else {
      this.context.transitionToSuggestingState(remainingSuggestion, documentChanges.getPrefix(), documentChanges.getSuffix());
    }
  }
  private clearPrediction(): void {
    this.context.transitionToIdleState();
  }
  handleAcceptKeyPressed(): boolean {
    this.accept();
    return true;
  }
  // Insert the whole suggestion and return to idle.
  private accept() {
    this.addPartialSuggestionCaches(this.suggestion);
    this.context.insertCurrentSuggestion(this.suggestion);
    this.context.transitionToIdleState();
  }
  handlePartialAcceptKeyPressed(): boolean {
    this.acceptNextWord();
    return true;
  }
  // Insert only the next word of the suggestion; when no split is possible,
  // fall back to accepting everything.
  private acceptNextWord() {
    const [nextWord, remaining] = extractNextWordAndRemaining(this.suggestion);
    if (nextWord !== undefined && remaining !== undefined) {
      const updatedPrefix = this.prefix + nextWord;
      this.addPartialSuggestionCaches(nextWord, remaining);
      this.context.insertCurrentSuggestion(nextWord);
      this.context.transitionToSuggestingState(remaining, updatedPrefix, this.suffix, false);
    } else {
      this.accept();
    }
  }
  // Cache every sub-suggestion obtained by slicing the accepted text, so that
  // a partial suggestion can be restored if the user edits part of it.
  private addPartialSuggestionCaches(acceptSuggestion: string, remainingSuggestion = "") {
    for (let i = 0; i < acceptSuggestion.length; i++) {
      const prefix = this.prefix + acceptSuggestion.substring(0, i);
      const suggestion = acceptSuggestion.substring(i) + remainingSuggestion;
      this.context.addSuggestionToCache(prefix, this.suffix, suggestion);
    }
  }
  // (Removed dead code: a private getNextWordAndRemaining() helper existed
  // here but was never called — acceptNextWord uses the shared
  // extractNextWordAndRemaining utility instead.)
  handleCancelKeyPressed(): boolean {
    this.context.clearSuggestionsCache();
    this.clearPrediction();
    return true;
  }
  handleAcceptCommand() {
    this.accept();
  }
  getStatusBarText(): string {
    return `Suggesting for ${this.context.context}`;
  }
  // NOTE(review): this override replaces State.handleSettingChanged entirely,
  // so the base class's disable/invalid-settings handling does not run while
  // suggesting; it also takes `Settings` where the base takes `InfioSettings`.
  // Confirm both are intentional.
  handleSettingChanged(settings: Settings): void {
    if (!settings.cacheSuggestions) {
      this.clearPrediction();
    }
  }
}
export default SuggestingState;

View File

@@ -0,0 +1,25 @@
import { TFile } from "obsidian";
import { InfioSettings } from "../../../types/settings";
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
/**
 * Contract implemented by every state of the autocomplete state machine.
 * Boolean-returning key handlers report whether the key press was consumed
 * (true) or should fall through to the editor (false).
 */
export interface EventHandler {
  /** React to a change in plugin settings. */
  handleSettingChanged(settings: InfioSettings): void;
  /** React to an editor/document change event. */
  handleDocumentChange(documentChanges: DocumentChanges): Promise<void>;
  handleAcceptKeyPressed(): boolean;
  handlePartialAcceptKeyPressed(): boolean;
  handleCancelKeyPressed(): boolean;
  /** Explicit command to request a prediction for the given cursor context. */
  handlePredictCommand(prefix: string, suffix: string): void;
  /** Explicit command to accept the current suggestion. */
  handleAcceptCommand(): void;
  /** Human-readable description of the state for the status bar. */
  getStatusBarText(): string;
  /** React to the active file changing. */
  handleFileChange(file: TFile): void;
}

View File

@@ -0,0 +1,57 @@
import { Result } from "neverthrow";
import Context from "./context-detection";
/** Service that produces a completion for the text around the cursor. */
export interface AutocompleteService {
  fetchPredictions(
    prefix: string,
    suffix: string
  ): Promise<Result<string, Error>>;
}
/** Transforms a raw model completion before it is shown to the user. */
export interface PostProcessor {
  process(
    prefix: string,
    suffix: string,
    completion: string,
    context: Context
  ): string;
}
/** Prepares the cursor context before it is sent to the model. */
export interface PreProcessor {
  process(prefix: string, suffix: string, context: Context): PrefixAndSuffix;
  // Whether processing removes the cursor from the given context —
  // presumably a cursor sentinel in the text; verify against implementations.
  removesCursor(prefix: string, suffix: string): boolean;
}
/** Text before and after the cursor position. */
export interface PrefixAndSuffix {
  prefix: string;
  suffix: string;
}
/** Single chat message in role/content form. */
export interface ChatMessage {
  content: string;
  role: "user" | "assistant" | "system";
}
/** Inputs handed to a UserMessageFormatter. */
export interface UserMessageFormattingInputs {
  prefix: string;
  suffix: string;
}
/** Builds the user prompt string from the cursor context. */
export type UserMessageFormatter = (
  inputs: UserMessageFormattingInputs
) => string;
/** Minimal chat-model client abstraction. */
export interface ApiClient {
  queryChatModel(messages: ChatMessage[]): Promise<Result<string, Error>>;
  /** Optional self-check returning a list of configuration problems. */
  checkIfConfiguredCorrectly?(): Promise<string[]>;
}
/** Sampling options forwarded to the model API. */
export interface ModelOptions {
  temperature: number;
  top_p: number;
  frequency_penalty: number;
  presence_penalty: number;
  max_tokens: number;
}

View File

@@ -0,0 +1,68 @@
import * as mm from "micromatch";
/** Awaitable pause: resolves after `ms` milliseconds. */
export function sleep(ms: number) {
  return new Promise((resolve) => {
    setTimeout(resolve, ms);
  });
}
/**
 * Keys of a (numeric) enum object, excluding the reverse numeric mappings
 * TypeScript generates (e.g. `Enum[0] === 'A'` adds a '0' key).
 */
export function enumKeys<O extends object, K extends keyof O = keyof O>(
  obj: O
): K[] {
  const names: K[] = [];
  for (const key of Object.keys(obj)) {
    // Numeric-looking keys are reverse mappings; keep only the name keys.
    if (Number.isNaN(+key)) {
      names.push(key as K);
    }
  }
  return names;
}
/** Random lowercase hexadecimal string of length `n`. */
export function generateRandomString(n: number): string {
  const HEX_DIGITS = '0123456789abcdef';
  const chars: string[] = [];
  for (let i = 0; i < n; i++) {
    chars.push(HEX_DIGITS.charAt(Math.floor(Math.random() * HEX_DIGITS.length)));
  }
  return chars.join('');
}
/**
 * True when `path` matches at least one inclusion pattern and no exclusion
 * pattern. Patterns starting with '!' are exclusions; blank patterns are
 * ignored, and an effectively empty pattern list never matches.
 */
export function isMatchBetweenPathAndPatterns(
  path: string,
  patterns: string[],
): boolean {
  const cleaned = patterns
    .map((pattern) => pattern.trim())
    .filter((pattern) => pattern.length > 0);
  if (cleaned.length === 0) {
    return false;
  }
  const exclusions = cleaned
    .filter((pattern) => pattern.startsWith('!'))
    .map((pattern) => pattern.slice(1));
  const inclusions = cleaned.filter((pattern) => !pattern.startsWith('!'));
  return mm.some(path, inclusions) && !mm.some(path, exclusions);
}
/**
 * Split a suggestion into its first word and the remaining text.
 * The first word keeps the suggestion's leading whitespace, and also keeps the
 * whitespace that follows it when more text remains after that whitespace.
 * Returns [undefined, undefined] for empty or whitespace-only input, and
 * [word, undefined] when the suggestion holds only a single word (possibly
 * with trailing whitespace, which is then dropped).
 */
export function extractNextWordAndRemaining(suggestion: string): [string | undefined, string | undefined] {
  const leadingMatch = /^\s*/.exec(suggestion);
  const leading = leadingMatch ? leadingMatch[0] : '';
  const body = suggestion.slice(leading.length);
  const gapMatch = /\s+/.exec(body);
  let firstWord: string | undefined;
  let rest: string | undefined;
  if (gapMatch === null) {
    // No whitespace after the first word: the whole body is the word.
    firstWord = body === '' ? undefined : body;
  } else {
    const gapStart = gapMatch.index;
    const afterGap = gapStart + gapMatch[0].length;
    firstWord = body.substring(0, gapStart);
    if (afterGap < body.length) {
      // Text remains past the gap: keep the gap attached to the first word.
      rest = body.slice(afterGap);
      firstWord += gapMatch[0];
    }
  }
  return [firstWord ? leading + firstWord : undefined, rest];
}

View File

@@ -0,0 +1,44 @@
import { MarkdownPostProcessorContext, Plugin } from "obsidian";
import React from "react";
import { createRoot } from "react-dom/client";
import { InlineEdit as InlineEditComponent } from "../../components/inline-edit/InlineEdit";
import { InfioSettings } from "../../types/settings";
/**
 * Replaces a rendered markdown code block with a React-based inline-edit
 * component via Obsidian's markdown post-processing pipeline.
 */
export class InlineEdit {
  plugin: Plugin;
  settings: InfioSettings;
  constructor(plugin: Plugin, settings: InfioSettings) {
    this.plugin = plugin;
    this.settings = settings;
  }
  /**
   * Markdown post-processor entry point: mounts the InlineEdit React
   * component in place of the rendered block.
   */
  Processor(source: string, el: HTMLElement, ctx: MarkdownPostProcessorContext) {
    const sec = ctx.getSectionInfo(el);
    if (!sec) return;
    // createDiv — presumably Obsidian's global DOM helper; TODO confirm.
    const container = createDiv();
    const root = createRoot(container);
    root.render(
      React.createElement(InlineEditComponent, {
        source: source,
        secStartLine: sec.lineStart,
        secEndLine: sec.lineEnd,
        plugin: this.plugin,
        settings: this.settings
      })
    );
    // Restyle the parent element so it no longer looks like a code block.
    const parent = el.parentElement;
    if (parent) {
      parent.addClass("infio-ai-block");
    }
    el.replaceWith(container);
  }
}

323
src/core/llm/anthropic.ts Normal file
View File

@@ -0,0 +1,323 @@
import Anthropic from '@anthropic-ai/sdk'
import {
ImageBlockParam,
MessageParam,
MessageStreamEvent,
TextBlockParam,
} from '@anthropic-ai/sdk/resources/messages'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
ResponseUsage,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
/**
 * LLM provider backed by the official Anthropic SDK. Converts between the
 * plugin's request/response types and Anthropic's Messages API types for both
 * blocking and streaming completions.
 */
export class AnthropicProvider implements BaseLLMProvider {
  private client: Anthropic
  // Anthropic requires max_tokens on every request; used when the caller omits it.
  private static readonly DEFAULT_MAX_TOKENS = 8192
  constructor(apiKey: string) {
    // dangerouslyAllowBrowser: the SDK runs inside Obsidian's browser-like environment.
    this.client = new Anthropic({ apiKey, dangerouslyAllowBrowser: true })
  }
  /**
   * Non-streaming chat completion. Lazily rebuilds the client from the
   * model's own credentials when the provider-level key is missing.
   * @throws LLMAPIKeyNotSetException when no API key is available at all
   * @throws LLMAPIKeyInvalidException when Anthropic rejects the key
   */
  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Anthropic API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Anthropic({
        baseURL: model.baseUrl,
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true
      })
    }
    const systemMessage = AnthropicProvider.validateSystemMessages(
      request.messages,
    )
    try {
      const response = await this.client.messages.create(
        {
          model: request.model,
          // The system prompt goes in a dedicated field, so system messages
          // are stripped from the list; empty messages are filtered out.
          messages: request.messages
            .filter((m) => m.role !== 'system')
            .filter((m) => !AnthropicProvider.isMessageEmpty(m))
            .map((m) => AnthropicProvider.parseRequestMessage(m)),
          system: systemMessage,
          max_tokens:
            request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
          temperature: request.temperature,
          top_p: request.top_p,
        },
        {
          signal: options?.signal,
        },
      )
      return AnthropicProvider.parseNonStreamingResponse(response)
    } catch (error) {
      if (error instanceof Anthropic.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Anthropic API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }
  /**
   * Streaming chat completion. Returns an async iterable of chunks; the final
   * chunk carries the accumulated token usage with an empty choices list.
   */
  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.client.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          'Anthropic API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new Anthropic({
        baseURL: model.baseUrl,
        apiKey: model.apiKey,
        dangerouslyAllowBrowser: true
      })
    }
    const systemMessage = AnthropicProvider.validateSystemMessages(
      request.messages,
    )
    try {
      const stream = await this.client.messages.create(
        {
          model: request.model,
          messages: request.messages
            .filter((m) => m.role !== 'system')
            .filter((m) => !AnthropicProvider.isMessageEmpty(m))
            .map((m) => AnthropicProvider.parseRequestMessage(m)),
          system: systemMessage,
          max_tokens:
            request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
          temperature: request.temperature,
          top_p: request.top_p,
          stream: true,
        },
        {
          signal: options?.signal,
        },
      )
      // Adapts Anthropic's event stream (message_start / content_block_delta /
      // message_delta) to the plugin's chunk format while accumulating usage.
      // eslint-disable-next-line no-inner-declarations
      async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
        let messageId = ''
        let model = ''
        let usage: ResponseUsage = {
          prompt_tokens: 0,
          completion_tokens: 0,
          total_tokens: 0,
        }
        for await (const chunk of stream) {
          if (chunk.type === 'message_start') {
            messageId = chunk.message.id
            model = chunk.message.model
            usage = {
              prompt_tokens: chunk.message.usage.input_tokens,
              completion_tokens: chunk.message.usage.output_tokens,
              total_tokens:
                chunk.message.usage.input_tokens +
                chunk.message.usage.output_tokens,
            }
          } else if (chunk.type === 'content_block_delta') {
            yield AnthropicProvider.parseStreamingResponseChunk(
              chunk,
              messageId,
              model,
            )
          } else if (chunk.type === 'message_delta') {
            usage = {
              prompt_tokens: usage.prompt_tokens,
              completion_tokens:
                usage.completion_tokens + chunk.usage.output_tokens,
              total_tokens: usage.total_tokens + chunk.usage.output_tokens,
            }
          }
        }
        // After the stream is complete, yield the final usage
        yield {
          id: messageId,
          choices: [],
          object: 'chat.completion.chunk',
          model: model,
          usage: usage,
        }
      }
      return streamResponse()
    } catch (error) {
      if (error instanceof Anthropic.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'Anthropic API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }
  /**
   * Convert a plugin request message to Anthropic's MessageParam.
   * Only user/assistant roles are accepted; user messages may carry text and
   * base64 image parts.
   */
  static parseRequestMessage(message: RequestMessage): MessageParam {
    if (message.role !== 'user' && message.role !== 'assistant') {
      throw new Error(`Anthropic does not support role: ${message.role}`)
    }
    if (message.role === 'user' && Array.isArray(message.content)) {
      const content = message.content.map(
        (part): TextBlockParam | ImageBlockParam => {
          switch (part.type) {
            case 'text':
              return { type: 'text', text: part.text }
            case 'image_url': {
              // Data-URL images are converted to Anthropic's base64 format.
              const { mimeType, base64Data } = parseImageDataUrl(
                part.image_url.url,
              )
              AnthropicProvider.validateImageType(mimeType)
              return {
                type: 'image',
                source: {
                  data: base64Data,
                  media_type:
                    mimeType as ImageBlockParam['source']['media_type'],
                  type: 'base64',
                },
              }
            }
          }
        },
      )
      return { role: 'user', content }
    }
    return {
      role: message.role,
      content: message.content as string,
    }
  }
  /** Convert a full Anthropic message into the plugin's response type. */
  static parseNonStreamingResponse(
    response: Anthropic.Message,
  ): LLMResponseNonStreaming {
    if (response.content[0].type === 'tool_use') {
      throw new Error('Unsupported content type: tool_use')
    }
    return {
      id: response.id,
      choices: [
        {
          finish_reason: response.stop_reason,
          message: {
            content: response.content[0].text,
            role: response.role,
          },
        },
      ],
      model: response.model,
      object: 'chat.completion',
      usage: {
        prompt_tokens: response.usage.input_tokens,
        completion_tokens: response.usage.output_tokens,
        total_tokens:
          response.usage.input_tokens + response.usage.output_tokens,
      },
    }
  }
  /** Convert a content_block_delta event into the plugin's streaming chunk type. */
  static parseStreamingResponseChunk(
    chunk: MessageStreamEvent,
    messageId: string,
    model: string,
  ): LLMResponseStreaming {
    if (chunk.type !== 'content_block_delta') {
      throw new Error('Unsupported chunk type')
    }
    if (chunk.delta.type === 'input_json_delta') {
      throw new Error('Unsupported content type: input_json_delta')
    }
    return {
      id: messageId,
      choices: [
        {
          finish_reason: null,
          delta: {
            content: chunk.delta.text,
          },
        },
      ],
      object: 'chat.completion.chunk',
      model: model,
    }
  }
  /**
   * Extract the single allowed system message (Anthropic accepts at most one,
   * string-only). Returns undefined when no system message is present.
   */
  private static validateSystemMessages(
    messages: RequestMessage[],
  ): string | undefined {
    const systemMessages = messages.filter((m) => m.role === 'system')
    if (systemMessages.length > 1) {
      throw new Error(`Anthropic does not support more than one system message`)
    }
    const systemMessage =
      systemMessages.length > 0 ? systemMessages[0].content : undefined
    if (systemMessage && typeof systemMessage !== 'string') {
      throw new Error(
        `Anthropic only supports string content for system messages`,
      )
    }
    return systemMessage
  }
  // A message is empty when its string content is blank or its part list is empty.
  private static isMessageEmpty(message: RequestMessage) {
    if (typeof message.content === 'string') {
      return message.content.trim() === ''
    }
    return message.content.length === 0
  }
  // Reject image MIME types Anthropic does not accept.
  private static validateImageType(mimeType: string) {
    const SUPPORTED_IMAGE_TYPES = [
      'image/jpeg',
      'image/png',
      'image/gif',
      'image/webp',
    ]
    if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
      throw new Error(
        `Anthropic does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
          ', ',
        )}`,
      )
    }
  }
}

23
src/core/llm/base.ts Normal file
View File

@@ -0,0 +1,23 @@
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
/**
 * Common interface implemented by every LLM provider in this module
 * (Anthropic, Gemini, Groq, ...).
 */
export type BaseLLMProvider = {
  /** Perform a blocking completion request and return the full response. */
  generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming>
  /** Perform a streaming completion request, yielding incremental chunks. */
  streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>>
}

34
src/core/llm/exception.ts Normal file
View File

@@ -0,0 +1,34 @@
/** Thrown when a request is attempted before an API key has been configured. */
export class LLMAPIKeyNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMAPIKeyNotSetException'
  }
}
/** Thrown when the provider rejects the configured API key. */
export class LLMAPIKeyInvalidException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMAPIKeyInvalidException'
  }
}
/** Thrown when a provider requiring a base URL has none configured. */
export class LLMBaseUrlNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMBaseUrlNotSetException'
  }
}
/** Thrown when no model has been selected/configured for a request. */
export class LLMModelNotSetException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMModelNotSetException'
  }
}
/** Thrown when the provider reports a rate-limit error. */
export class LLMRateLimitExceededException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'LLMRateLimitExceededException'
  }
}

299
src/core/llm/gemini.ts Normal file
View File

@@ -0,0 +1,299 @@
import {
Content,
EnhancedGenerateContentResponse,
GenerateContentResult,
GenerateContentStreamResult,
GoogleGenerativeAI,
} from '@google/generative-ai'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
/**
* Note on OpenAI Compatibility API:
* Gemini provides an OpenAI-compatible endpoint (https://ai.google.dev/gemini-api/docs/openai)
* which allows using the OpenAI SDK with Gemini models. However, there are currently CORS issues
* preventing its use in Obsidian. Consider switching to this endpoint in the future once these
* issues are resolved.
*/
/**
 * LLM provider backed by Google's Generative AI SDK. Converts between the
 * plugin's request/response types and the Gemini content API for both
 * blocking and streaming completions.
 */
export class GeminiProvider implements BaseLLMProvider {
  private client: GoogleGenerativeAI
  private apiKey: string
  constructor(apiKey: string) {
    this.apiKey = apiKey
    this.client = new GoogleGenerativeAI(apiKey)
  }
  /**
   * Non-streaming completion. Lazily rebuilds the client from the model's own
   * key when the provider-level key is missing.
   * @throws LLMAPIKeyNotSetException when no API key is available at all
   * @throws LLMAPIKeyInvalidException when Gemini rejects the key
   */
  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }
    // All system messages are merged into a single system instruction string.
    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined
    try {
      // NOTE(review): this local shadows the `model: CustomLLMModel`
      // parameter; from here on `model` refers to the SDK's GenerativeModel.
      const model = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })
      const result = await model.generateContent(
        {
          systemInstruction: systemInstruction,
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )
      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return GeminiProvider.parseNonStreamingResponse(
        result,
        request.model,
        messageId,
      )
    } catch (error) {
      // The SDK signals bad keys via error message text, not an error class.
      const isInvalidApiKey =
        error.message?.includes('API_KEY_INVALID') ||
        error.message?.includes('API key not valid')
      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in settings menu.`,
        )
      }
      throw error
    }
  }
  /**
   * Streaming completion. Returns an async iterable that adapts the SDK's
   * stream to the plugin's chunk format.
   */
  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }
    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined
    try {
      // NOTE(review): this local shadows the `model: CustomLLMModel`
      // parameter; from here on `model` refers to the SDK's GenerativeModel.
      const model = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })
      const stream = await model.generateContentStream(
        {
          systemInstruction: systemInstruction,
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )
      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return this.streamResponseGenerator(stream, request.model, messageId)
    } catch (error) {
      const isInvalidApiKey =
        error.message?.includes('API_KEY_INVALID') ||
        error.message?.includes('API key not valid')
      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in settings menu.`,
        )
      }
      throw error
    }
  }
  // Adapts each SDK chunk to the plugin's streaming response type.
  private async *streamResponseGenerator(
    stream: GenerateContentStreamResult,
    model: string,
    messageId: string,
  ): AsyncIterable<LLMResponseStreaming> {
    for await (const chunk of stream.stream) {
      yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
    }
  }
  /**
   * Convert a plugin request message to Gemini's Content type. System
   * messages return null (they are sent as the system instruction instead);
   * user content may mix text and base64 image parts.
   */
  static parseRequestMessage(message: RequestMessage): Content | null {
    if (message.role === 'system') {
      return null
    }
    if (Array.isArray(message.content)) {
      return {
        role: message.role === 'user' ? 'user' : 'model',
        parts: message.content.map((part) => {
          switch (part.type) {
            case 'text':
              return { text: part.text }
            case 'image_url': {
              const { mimeType, base64Data } = parseImageDataUrl(
                part.image_url.url,
              )
              GeminiProvider.validateImageType(mimeType)
              return {
                inlineData: {
                  data: base64Data,
                  mimeType,
                },
              }
            }
          }
        }),
      }
    }
    return {
      role: message.role === 'user' ? 'user' : 'model',
      parts: [
        {
          text: message.content,
        },
      ],
    }
  }
  /** Convert a full Gemini result into the plugin's response type. */
  static parseNonStreamingResponse(
    response: GenerateContentResult,
    model: string,
    messageId: string,
  ): LLMResponseNonStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason:
            response.response.candidates?.[0]?.finishReason ?? null,
          message: {
            content: response.response.text(),
            role: 'assistant',
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion',
      usage: response.response.usageMetadata
        ? {
            prompt_tokens: response.response.usageMetadata.promptTokenCount,
            completion_tokens:
              response.response.usageMetadata.candidatesTokenCount,
            total_tokens: response.response.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }
  /** Convert one streamed Gemini chunk into the plugin's chunk type. */
  static parseStreamingResponseChunk(
    chunk: EnhancedGenerateContentResponse,
    model: string,
    messageId: string,
  ): LLMResponseStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
          delta: {
            content: chunk.text(),
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion.chunk',
      usage: chunk.usageMetadata
        ? {
            prompt_tokens: chunk.usageMetadata.promptTokenCount,
            completion_tokens: chunk.usageMetadata.candidatesTokenCount,
            total_tokens: chunk.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }
  // Reject image MIME types Gemini does not accept.
  private static validateImageType(mimeType: string) {
    const SUPPORTED_IMAGE_TYPES = [
      'image/png',
      'image/jpeg',
      'image/webp',
      'image/heic',
      'image/heif',
    ]
    if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
      throw new Error(
        `Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
          ', ',
        )}`,
      )
    }
  }
}

200
src/core/llm/groq.ts Normal file
View File

@@ -0,0 +1,200 @@
import Groq from 'groq-sdk'
import {
ChatCompletion,
ChatCompletionChunk,
ChatCompletionContentPart,
ChatCompletionMessageParam,
} from 'groq-sdk/resources/chat/completions'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
export class GroqProvider implements BaseLLMProvider {
private client: Groq
constructor(apiKey: string) {
this.client = new Groq({
apiKey,
dangerouslyAllowBrowser: true,
})
}
async generateResponse(
model: CustomLLMModel,
request: LLMRequestNonStreaming,
options?: LLMOptions,
): Promise<LLMResponseNonStreaming> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
this.client = new Groq({
apiKey: model.apiKey,
dangerouslyAllowBrowser: true,
})
}
try {
const response = await this.client.chat.completions.create(
{
model: request.model,
messages: request.messages.map((m) =>
GroqProvider.parseRequestMessage(m),
),
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p,
},
{
signal: options?.signal,
},
)
return GroqProvider.parseNonStreamingResponse(response)
} catch (error) {
if (error instanceof Groq.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Groq API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
async streamResponse(
model: CustomLLMModel,
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
if (!this.client.apiKey) {
if (!model.apiKey) {
throw new LLMAPIKeyNotSetException(
'Groq API key is missing. Please set it in settings menu.',
)
}
this.client = new Groq({
apiKey: model.apiKey,
dangerouslyAllowBrowser: true,
})
}
try {
const stream = await this.client.chat.completions.create(
{
model: request.model,
messages: request.messages.map((m) =>
GroqProvider.parseRequestMessage(m),
),
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p,
stream: true,
},
{
signal: options?.signal,
},
)
// eslint-disable-next-line no-inner-declarations
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
for await (const chunk of stream) {
yield GroqProvider.parseStreamingResponseChunk(chunk)
}
}
return streamResponse()
} catch (error) {
if (error instanceof Groq.AuthenticationError) {
throw new LLMAPIKeyInvalidException(
'Groq API key is invalid. Please update it in settings menu.',
)
}
throw error
}
}
static parseRequestMessage(
message: RequestMessage,
): ChatCompletionMessageParam {
switch (message.role) {
case 'user': {
const content = Array.isArray(message.content)
? message.content.map((part): ChatCompletionContentPart => {
switch (part.type) {
case 'text':
return { type: 'text', text: part.text }
case 'image_url':
return { type: 'image_url', image_url: part.image_url }
}
})
: message.content
return { role: 'user', content }
}
case 'assistant': {
if (Array.isArray(message.content)) {
throw new Error('Assistant message should be a string')
}
return { role: 'assistant', content: message.content }
}
case 'system': {
if (Array.isArray(message.content)) {
throw new Error('System message should be a string')
}
return { role: 'system', content: message.content }
}
}
}
static parseNonStreamingResponse(
response: ChatCompletion,
): LLMResponseNonStreaming {
return {
id: response.id,
choices: response.choices.map((choice) => ({
finish_reason: choice.finish_reason,
message: {
content: choice.message.content,
role: choice.message.role,
},
})),
created: response.created,
model: response.model,
object: 'chat.completion',
usage: response.usage,
}
}
/**
 * Maps a streamed Groq ChatCompletionChunk to the plugin's
 * provider-agnostic streaming response shape.
 */
static parseStreamingResponseChunk(
  chunk: ChatCompletionChunk,
): LLMResponseStreaming {
  return {
    id: chunk.id,
    choices: chunk.choices.map((choice) => ({
      // Intermediate chunks may omit finish_reason; normalize to null.
      finish_reason: choice.finish_reason ?? null,
      delta: {
        content: choice.delta.content ?? null,
        role: choice.delta.role,
      },
    })),
    created: chunk.created,
    model: chunk.model,
    object: 'chat.completion.chunk',
    // NOTE(review): token usage is not propagated here, unlike the
    // non-streaming parser above — confirm whether Groq stream chunks
    // carry usage data that should be surfaced.
  }
}
}

252
src/core/llm/infio.ts Normal file
View File

@@ -0,0 +1,252 @@
import OpenAI from 'openai'
import {
ChatCompletion,
ChatCompletionChunk,
} from 'openai/resources/chat/completions'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
/** Numeric range constraint used inside a ChunkFilter. */
export interface RangeFilter {
  /** Inclusive lower bound. */
  gte?: number;
  /** Inclusive upper bound. */
  lte?: number;
}
/** Filter applied to retrieved chunks during search. */
export interface ChunkFilter {
  /** Name of the chunk field the filter targets. */
  field: string;
  /** All listed values must match on the field. */
  match_all?: string[];
  /** Numeric range the field value must fall into. */
  range?: RangeFilter;
}
/**
 * Interface for making requests to the Infio API
 */
export interface InfioRequest {
  /** Required: The content of the user message to attach to the topic and then generate an assistant message in response to */
  messages: RequestMessage[];
  // /** Required: The ID of the topic to attach the message to */
  // topic_id: string;
  /** Optional: URLs to include */
  links?: string[];
  /** Optional: Files to include */
  files?: string[];
  /** Optional: Whether to highlight results in chunk_html. Default is true */
  highlight_results?: boolean;
  /** Optional: Delimiters for highlighting citations. Default is [".", "!", "?", "\n", "\t", ","] */
  highlight_delimiters?: string[];
  /** Optional: Search type - "semantic", "fulltext", or "hybrid". Default is "hybrid" */
  search_type?: string;
  /** Optional: Filters for chunk filtering */
  filters?: ChunkFilter;
  /** Optional: Whether to use web search API. Default is false */
  use_web_search?: boolean;
  /** Optional: LLM model to use */
  llm_model?: string;
  /** Optional: Force source */
  force_source?: string;
  /** Optional: Whether completion should come before chunks in stream. Default is false */
  completion_first?: boolean;
  /** Optional: Whether to stream the response. Default is true */
  stream_response?: boolean;
  /** Optional: Sampling temperature between 0 and 2. Default is 0.5 */
  temperature?: number;
  /** Optional: Frequency penalty between -2.0 and 2.0. Default is 0.7 */
  frequency_penalty?: number;
  /** Optional: Presence penalty between -2.0 and 2.0. Default is 0.7 */
  presence_penalty?: number;
  /** Optional: Maximum tokens to generate */
  max_tokens?: number;
  /** Optional: Stop tokens (up to 4 sequences) */
  stop_tokens?: string[];
}
/**
 * LLM provider that talks to the Infio raw_message HTTP endpoint via fetch.
 *
 * Fixes over the previous revision:
 * - the local fetch-options object no longer shadows the `options?: LLMOptions`
 *   parameter, so the abort signal is actually forwarded to fetch;
 * - `generateResponse` now sends `llm_model` like `streamResponse` does;
 * - HTTP 401/403 are mapped to LLMAPIKeyInvalidException (the previous
 *   `instanceof OpenAI.AuthenticationError` check could never fire, because
 *   fetch does not throw OpenAI SDK errors);
 * - SSE lines are buffered across network chunks, and a `[DONE]` sentinel
 *   is skipped instead of crashing JSON.parse.
 */
export class InfioProvider implements BaseLLMProvider {
  private apiKey: string
  private baseUrl: string

  constructor(apiKey: string) {
    this.apiKey = apiKey
    this.baseUrl = 'https://api.infio.com/api/raw_message'
  }

  /** Builds the fetch options shared by both request modes. */
  private buildFetchOptions(
    req: InfioRequest,
    signal?: AbortSignal,
  ): RequestInit {
    return {
      method: 'POST',
      headers: {
        Authorization: this.apiKey,
        'TR-Dataset': '74aaec22-0cf0-4cba-80a5-ae5c0518344e',
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(req),
      signal,
    }
  }

  /** Maps HTTP failures to typed exceptions; no-op on success. */
  private static assertOk(response: Response): void {
    if (response.status === 401 || response.status === 403) {
      throw new LLMAPIKeyInvalidException(
        'Infio API key is invalid. Please update it in settings menu.',
      )
    }
    if (!response.ok) {
      throw new Error(`HTTP error! status: ${response.status}`)
    }
  }

  /**
   * Sends a single non-streaming completion request.
   * @throws LLMAPIKeyNotSetException when no API key is configured
   * @throws LLMAPIKeyInvalidException on HTTP 401/403
   */
  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'Infio API key is missing. Please set it in settings menu.',
      )
    }
    const req: InfioRequest = {
      llm_model: request.model,
      messages: request.messages,
      stream_response: false,
      temperature: request.temperature,
      frequency_penalty: request.frequency_penalty,
      presence_penalty: request.presence_penalty,
      max_tokens: request.max_tokens,
    }
    const response = await fetch(
      this.baseUrl,
      this.buildFetchOptions(req, options?.signal),
    )
    InfioProvider.assertOk(response)
    const data = (await response.json()) as ChatCompletion
    return InfioProvider.parseNonStreamingResponse(data)
  }

  /**
   * Sends a streaming completion request and yields parsed SSE chunks.
   * @throws LLMAPIKeyNotSetException when no API key is configured
   * @throws LLMAPIKeyInvalidException on HTTP 401/403
   */
  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.apiKey) {
      throw new LLMAPIKeyNotSetException(
        'Infio API key is missing. Please set it in settings menu.',
      )
    }
    const req: InfioRequest = {
      llm_model: request.model,
      messages: request.messages,
      stream_response: true,
      temperature: request.temperature,
      frequency_penalty: request.frequency_penalty,
      presence_penalty: request.presence_penalty,
      max_tokens: request.max_tokens,
    }
    const response = await fetch(
      this.baseUrl,
      this.buildFetchOptions(req, options?.signal),
    )
    InfioProvider.assertOk(response)
    if (!response.body) {
      throw new Error('Response body is null')
    }
    const reader = response.body.getReader()
    const decoder = new TextDecoder()
    return {
      [Symbol.asyncIterator]: async function* () {
        // Buffer partial SSE lines across network chunks: a single
        // "data: {...}" line may be split over multiple reads.
        let buffer = ''
        try {
          while (true) {
            const { done, value } = await reader.read()
            if (done) break
            buffer += decoder.decode(value, { stream: true })
            const lines = buffer.split('\n')
            buffer = lines.pop() ?? ''
            for (const line of lines) {
              const trimmed = line.trim()
              if (!trimmed.startsWith('data: ')) continue
              const payload = trimmed.slice(6)
              // Common SSE end-of-stream sentinel; not valid JSON.
              if (payload === '[DONE]') continue
              const jsonData = JSON.parse(payload) as ChatCompletionChunk
              if (
                !jsonData ||
                typeof jsonData !== 'object' ||
                !('choices' in jsonData)
              ) {
                throw new Error('Invalid chunk format received')
              }
              yield InfioProvider.parseStreamingResponseChunk(jsonData)
            }
          }
        } finally {
          reader.releaseLock()
        }
      },
    }
  }

  /** Maps a ChatCompletion to the plugin's provider-agnostic shape. */
  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      system_fingerprint: response.system_fingerprint,
      usage: response.usage,
    }
  }

  /** Maps a streamed ChatCompletionChunk to the plugin's shape. */
  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
      system_fingerprint: chunk.system_fingerprint,
      usage: chunk.usage ?? undefined,
    }
  }
}

142
src/core/llm/manager.ts Normal file
View File

@@ -0,0 +1,142 @@
import { DEEPSEEK_BASE_URL } from '../../constants'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { AnthropicProvider } from './anthropic'
import { GeminiProvider } from './gemini'
import { GroqProvider } from './groq'
import { InfioProvider } from './infio'
import { OllamaProvider } from './ollama'
import { OpenAIAuthenticatedProvider } from './openai'
import { OpenAICompatibleProvider } from './openai-compatible-provider'
/**
 * Contract for the LLM routing layer: a single entry point for
 * non-streaming and streaming completion requests, dispatched to the
 * concrete provider that matches the model.
 */
export type LLMManagerInterface = {
  /** Performs one non-streaming completion request. */
  generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming>
  /** Performs one streaming completion request, yielding chunks. */
  streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>>
}
/**
 * Routes LLM requests to the provider implementation matching
 * `model.provider`. When an Infio API key is configured, ALL requests
 * are routed to the Infio provider regardless of model.provider.
 *
 * Fix: both dispatch switches previously had no default branch, so an
 * unknown provider silently resolved to `undefined`; they now throw a
 * descriptive error instead.
 */
class LLMManager implements LLMManagerInterface {
  private openaiProvider: OpenAIAuthenticatedProvider
  private deepseekProvider: OpenAICompatibleProvider
  private anthropicProvider: AnthropicProvider
  private googleProvider: GeminiProvider
  private groqProvider: GroqProvider
  private infioProvider: InfioProvider
  private ollamaProvider: OllamaProvider
  private isInfioEnabled: boolean

  constructor(apiKeys: {
    deepseek?: string
    openai?: string
    anthropic?: string
    gemini?: string
    groq?: string
    infio?: string
  }) {
    this.deepseekProvider = new OpenAICompatibleProvider(
      apiKeys.deepseek ?? '',
      DEEPSEEK_BASE_URL,
    )
    this.openaiProvider = new OpenAIAuthenticatedProvider(apiKeys.openai ?? '')
    this.anthropicProvider = new AnthropicProvider(apiKeys.anthropic ?? '')
    this.googleProvider = new GeminiProvider(apiKeys.gemini ?? '')
    this.groqProvider = new GroqProvider(apiKeys.groq ?? '')
    this.infioProvider = new InfioProvider(apiKeys.infio ?? '')
    this.ollamaProvider = new OllamaProvider()
    // Presence of an Infio key short-circuits all provider routing.
    this.isInfioEnabled = !!apiKeys.infio
  }

  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (this.isInfioEnabled) {
      return await this.infioProvider.generateResponse(
        model,
        request,
        options,
      )
    }
    // use custom provider
    switch (model.provider) {
      case 'deepseek':
        return await this.deepseekProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'openai':
        return await this.openaiProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'anthropic':
        return await this.anthropicProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'google':
        return await this.googleProvider.generateResponse(
          model,
          request,
          options,
        )
      case 'groq':
        return await this.groqProvider.generateResponse(model, request, options)
      case 'ollama':
        return await this.ollamaProvider.generateResponse(
          model,
          request,
          options,
        )
      default:
        // Fail loudly instead of returning undefined for unknown providers.
        throw new Error(`Unsupported LLM provider: ${String(model.provider)}`)
    }
  }

  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (this.isInfioEnabled) {
      return await this.infioProvider.streamResponse(model, request, options)
    }
    // use custom provider
    switch (model.provider) {
      case 'deepseek':
        return await this.deepseekProvider.streamResponse(model, request, options)
      case 'openai':
        return await this.openaiProvider.streamResponse(model, request, options)
      case 'anthropic':
        return await this.anthropicProvider.streamResponse(
          model,
          request,
          options,
        )
      case 'google':
        return await this.googleProvider.streamResponse(model, request, options)
      case 'groq':
        return await this.groqProvider.streamResponse(model, request, options)
      case 'ollama':
        return await this.ollamaProvider.streamResponse(model, request, options)
      default:
        throw new Error(`Unsupported LLM provider: ${String(model.provider)}`)
    }
  }
}
export default LLMManager

104
src/core/llm/ollama.ts Normal file
View File

@@ -0,0 +1,104 @@
/**
* This provider is nearly identical to OpenAICompatibleProvider, but uses a custom OpenAI client
* (NoStainlessOpenAI) to work around CORS issues specific to Ollama.
*/
import OpenAI from 'openai'
import { FinalRequestOptions } from 'openai/core'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException, LLMModelNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
/**
 * OpenAI client subclass that strips Stainless SDK telemetry headers.
 * The `x-stainless-*` headers trigger CORS preflight failures against
 * Ollama's local server, so every header starting with that prefix is
 * removed before the request is issued.
 */
export class NoStainlessOpenAI extends OpenAI {
  // Replace the SDK defaults with a minimal header set.
  defaultHeaders() {
    return {
      Accept: 'application/json',
      'Content-Type': 'application/json',
    }
  }
  // Intercepts the fully-built request and deletes any x-stainless-*
  // headers the SDK injected elsewhere.
  buildRequest<Req>(
    options: FinalRequestOptions<Req>,
    { retryCount = 0 }: { retryCount?: number } = {},
  ): { req: RequestInit; url: string; timeout: number } {
    const req = super.buildRequest(options, { retryCount })
    const headers = req.req.headers as Record<string, string>
    Object.keys(headers).forEach((k) => {
      if (k.startsWith('x-stainless')) {
        // eslint-disable-next-line @typescript-eslint/no-dynamic-delete
        delete headers[k]
      }
    })
    return req
  }
}
/**
 * Provider for a locally-running Ollama server, speaking its
 * OpenAI-compatible /v1 API through the CORS-safe NoStainlessOpenAI client.
 *
 * Refactor: the identical validate-and-construct-client code in both
 * request paths is extracted into a single private helper.
 */
export class OllamaProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter

  constructor() {
    this.adapter = new OpenAIMessageAdapter()
  }

  /**
   * Validates model settings and builds a client for the Ollama endpoint.
   * @throws LLMBaseUrlNotSetException when no base URL is configured
   * @throws LLMModelNotSetException when no model name is configured
   */
  private createClient(model: CustomLLMModel): NoStainlessOpenAI {
    if (!model.baseUrl) {
      throw new LLMBaseUrlNotSetException(
        'Ollama base URL is missing. Please set it in settings menu.',
      )
    }
    if (!model.name) {
      throw new LLMModelNotSetException(
        'Ollama model is missing. Please set it in settings menu.',
      )
    }
    return new NoStainlessOpenAI({
      baseURL: `${model.baseUrl}/v1`,
      apiKey: '',
      dangerouslyAllowBrowser: true,
    })
  }

  /** Performs one non-streaming completion request against Ollama. */
  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    const client = this.createClient(model)
    return this.adapter.generateResponse(client, request, options)
  }

  /** Performs one streaming completion request against Ollama. */
  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    const client = this.createClient(model)
    return this.adapter.streamResponse(client, request, options)
  }
}

View File

@@ -0,0 +1,62 @@
import OpenAI from 'openai'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException } from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
/**
 * Provider for any endpoint exposing an OpenAI-compatible chat API
 * (e.g. DeepSeek), configured with an API key and base URL at
 * construction time.
 */
export class OpenAICompatibleProvider implements BaseLLMProvider {
  private readonly adapter: OpenAIMessageAdapter
  private readonly client: OpenAI

  constructor(
    private readonly apiKey: string,
    private readonly baseURL: string,
  ) {
    this.adapter = new OpenAIMessageAdapter()
    this.client = new OpenAI({
      apiKey: apiKey,
      baseURL: baseURL,
      dangerouslyAllowBrowser: true,
    })
  }

  /** Throws when either the base URL or the API key is unset. */
  private ensureConfigured(): void {
    if (!this.baseURL || !this.apiKey) {
      throw new LLMBaseUrlNotSetException(
        'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
      )
    }
  }

  /** Performs one non-streaming completion request. */
  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    this.ensureConfigured()
    return this.adapter.generateResponse(this.client, request, options)
  }

  /** Performs one streaming completion request. */
  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    this.ensureConfigured()
    return this.adapter.streamResponse(this.client, request, options)
  }
}

View File

@@ -0,0 +1,155 @@
import OpenAI from 'openai'
import {
ChatCompletion,
ChatCompletionChunk,
ChatCompletionContentPart,
ChatCompletionMessageParam,
} from 'openai/resources/chat/completions'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
RequestMessage,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
/**
 * Shared translation layer between the plugin's request/response types and
 * the OpenAI SDK's chat-completions API. Used by every provider that speaks
 * an OpenAI-compatible protocol (OpenAI, DeepSeek, Ollama).
 */
export class OpenAIMessageAdapter {
  /** Performs one non-streaming chat completion via the given client. */
  async generateResponse(
    client: OpenAI,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    const response = await client.chat.completions.create(
      {
        model: request.model,
        messages: request.messages.map((m) =>
          OpenAIMessageAdapter.parseRequestMessage(m),
        ),
        max_tokens: request.max_tokens,
        temperature: request.temperature,
        top_p: request.top_p,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        logit_bias: request.logit_bias,
        prediction: request.prediction,
      },
      {
        // Allows callers to abort the in-flight HTTP request.
        signal: options?.signal,
      },
    )
    return OpenAIMessageAdapter.parseNonStreamingResponse(response)
  }

  /** Performs one streaming chat completion, yielding parsed chunks. */
  async streamResponse(
    client: OpenAI,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    const stream = await client.chat.completions.create(
      {
        model: request.model,
        messages: request.messages.map((m) =>
          OpenAIMessageAdapter.parseRequestMessage(m),
        ),
        // NOTE(review): the streaming path sends max_completion_tokens while
        // the non-streaming path above sends max_tokens — confirm whether
        // this asymmetry is intentional (max_tokens is deprecated on newer
        // OpenAI models).
        max_completion_tokens: request.max_tokens,
        temperature: request.temperature,
        top_p: request.top_p,
        frequency_penalty: request.frequency_penalty,
        presence_penalty: request.presence_penalty,
        logit_bias: request.logit_bias,
        stream: true,
        stream_options: {
          // Ask the API to append a final chunk carrying token usage.
          include_usage: true,
        },
      },
      {
        signal: options?.signal,
      },
    )
    // eslint-disable-next-line no-inner-declarations
    async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
      for await (const chunk of stream) {
        yield OpenAIMessageAdapter.parseStreamingResponseChunk(chunk)
      }
    }
    return streamResponse()
  }

  /**
   * Converts an internal RequestMessage to the SDK's message shape.
   * Multi-part content is only valid for user messages.
   */
  static parseRequestMessage(
    message: RequestMessage,
  ): ChatCompletionMessageParam {
    switch (message.role) {
      case 'user': {
        const content = Array.isArray(message.content)
          ? message.content.map((part): ChatCompletionContentPart => {
            switch (part.type) {
              case 'text':
                return { type: 'text', text: part.text }
              case 'image_url':
                return { type: 'image_url', image_url: part.image_url }
            }
          })
          : message.content
        return { role: 'user', content }
      }
      case 'assistant': {
        if (Array.isArray(message.content)) {
          throw new Error('Assistant message should be a string')
        }
        return { role: 'assistant', content: message.content }
      }
      case 'system': {
        if (Array.isArray(message.content)) {
          throw new Error('System message should be a string')
        }
        return { role: 'system', content: message.content }
      }
    }
  }

  /** Maps a ChatCompletion to the plugin's provider-agnostic shape. */
  static parseNonStreamingResponse(
    response: ChatCompletion,
  ): LLMResponseNonStreaming {
    return {
      id: response.id,
      choices: response.choices.map((choice) => ({
        finish_reason: choice.finish_reason,
        message: {
          content: choice.message.content,
          role: choice.message.role,
        },
      })),
      created: response.created,
      model: response.model,
      object: 'chat.completion',
      system_fingerprint: response.system_fingerprint,
      usage: response.usage,
    }
  }

  /** Maps a streamed ChatCompletionChunk to the plugin's shape. */
  static parseStreamingResponseChunk(
    chunk: ChatCompletionChunk,
  ): LLMResponseStreaming {
    return {
      id: chunk.id,
      choices: chunk.choices.map((choice) => ({
        // Intermediate chunks may omit finish_reason; normalize to null.
        finish_reason: choice.finish_reason ?? null,
        delta: {
          content: choice.delta.content ?? null,
          role: choice.delta.role,
        },
      })),
      created: chunk.created,
      model: chunk.model,
      object: 'chat.completion.chunk',
      system_fingerprint: chunk.system_fingerprint,
      usage: chunk.usage ?? undefined,
    }
  }
}

91
src/core/llm/openai.ts Normal file
View File

@@ -0,0 +1,91 @@
import OpenAI from 'openai'
import { CustomLLMModel } from '../../types/llm/model'
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
} from './exception'
import { OpenAIMessageAdapter } from './openai-message-adapter'
/**
 * Provider for the official OpenAI API. If no API key was supplied at
 * construction time, it falls back to per-model credentials
 * (model.apiKey / model.baseUrl) on first use.
 *
 * Fix: both request methods previously did `return this.adapter...(...)`
 * inside try/catch WITHOUT awaiting, so a rejected promise escaped the
 * catch block and OpenAI.AuthenticationError was never translated into
 * LLMAPIKeyInvalidException. Both now `return await`.
 */
export class OpenAIAuthenticatedProvider implements BaseLLMProvider {
  private adapter: OpenAIMessageAdapter
  private client: OpenAI

  constructor(apiKey: string) {
    this.client = new OpenAI({
      apiKey,
      dangerouslyAllowBrowser: true,
    })
    this.adapter = new OpenAIMessageAdapter()
  }

  /**
   * Performs one non-streaming completion request.
   * @throws LLMAPIKeyNotSetException when no key is configured anywhere
   * @throws LLMAPIKeyInvalidException when the API rejects the key
   */
  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.client.apiKey) {
      // No global key: require per-model credentials and rebuild the client.
      if (!model.baseUrl) {
        throw new LLMAPIKeyNotSetException(
          'OpenAI API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new OpenAI({
        apiKey: model.apiKey,
        baseURL: model.baseUrl,
        dangerouslyAllowBrowser: true,
      })
    }
    try {
      // `await` is required so rejections are caught and translated below.
      return await this.adapter.generateResponse(this.client, request, options)
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'OpenAI API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }

  /**
   * Performs one streaming completion request.
   * @throws LLMAPIKeyNotSetException when no key is configured anywhere
   * @throws LLMAPIKeyInvalidException when the API rejects the key
   */
  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.client.apiKey) {
      if (!model.baseUrl) {
        throw new LLMAPIKeyNotSetException(
          'OpenAI API key is missing. Please set it in settings menu.',
        )
      }
      this.client = new OpenAI({
        apiKey: model.apiKey,
        baseURL: model.baseUrl,
        dangerouslyAllowBrowser: true,
      })
    }
    try {
      // `await` ensures stream-setup failures are caught here.
      return await this.adapter.streamResponse(this.client, request, options)
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) {
        throw new LLMAPIKeyInvalidException(
          'OpenAI API key is invalid. Please update it in settings menu.',
        )
      }
      throw error
    }
  }
}

151
src/core/rag/embedding.ts Normal file
View File

@@ -0,0 +1,151 @@
import { GoogleGenerativeAI } from '@google/generative-ai'
import { OpenAI } from 'openai'
import { EmbeddingModel } from '../../types/embedding'
import {
LLMAPIKeyNotSetException,
LLMBaseUrlNotSetException,
LLMRateLimitExceededException,
} from '../llm/exception'
import { NoStainlessOpenAI } from '../llm/ollama'
/**
 * Resolves an embedding model id to a concrete EmbeddingModel backed by
 * OpenAI, Gemini, or a local Ollama server.
 *
 * Fixes: the three Ollama cases (identical except for id/dimension) are
 * deduplicated through a local builder, and rate-limit detection no longer
 * dereferences `error.status`/`error.message` on an unknown error without
 * guarding.
 *
 * @param embeddingModelId one of the supported model ids
 * @param apiKeys provider API keys (only the relevant one is used)
 * @param ollamaBaseUrl base URL of the local Ollama server
 * @throws Error for an unknown model id
 */
export const getEmbeddingModel = (
  embeddingModelId: string,
  apiKeys: {
    openAIApiKey: string
    geminiApiKey: string
  },
  ollamaBaseUrl: string,
): EmbeddingModel => {
  // True when the error looks like an HTTP 429 with a matching message.
  const isRateLimitError = (error: unknown, needle: string): boolean => {
    const e = error as { status?: number; message?: unknown }
    return (
      e?.status === 429 &&
      typeof e?.message === 'string' &&
      e.message.toLowerCase().includes(needle)
    )
  }

  // Shared builder for models served by a local Ollama instance.
  const createOllamaModel = (id: string, dimension: number): EmbeddingModel => {
    const openai = new NoStainlessOpenAI({
      apiKey: '',
      dangerouslyAllowBrowser: true,
      baseURL: `${ollamaBaseUrl}/v1`,
    })
    return {
      id,
      dimension,
      getEmbedding: async (text: string) => {
        if (!ollamaBaseUrl) {
          throw new LLMBaseUrlNotSetException(
            'Ollama Address is missing. Please set it in settings menu.',
          )
        }
        const embedding = await openai.embeddings.create({
          model: id,
          input: text,
        })
        return embedding.data[0].embedding
      },
    }
  }

  switch (embeddingModelId) {
    case 'text-embedding-3-small': {
      const openai = new OpenAI({
        apiKey: apiKeys.openAIApiKey,
        dangerouslyAllowBrowser: true,
      })
      return {
        id: 'text-embedding-3-small',
        dimension: 1536,
        getEmbedding: async (text: string) => {
          try {
            if (!openai.apiKey) {
              throw new LLMAPIKeyNotSetException(
                'OpenAI API key is missing. Please set it in settings menu.',
              )
            }
            const embedding = await openai.embeddings.create({
              model: 'text-embedding-3-small',
              input: text,
            })
            return embedding.data[0].embedding
          } catch (error) {
            if (isRateLimitError(error, 'rate limit')) {
              throw new LLMRateLimitExceededException(
                'OpenAI API rate limit exceeded. Please try again later.',
              )
            }
            throw error
          }
        },
      }
    }
    case 'text-embedding-004': {
      const client = new GoogleGenerativeAI(apiKeys.geminiApiKey)
      const model = client.getGenerativeModel({ model: 'text-embedding-004' })
      return {
        id: 'text-embedding-004',
        dimension: 768,
        getEmbedding: async (text: string) => {
          try {
            const response = await model.embedContent(text)
            return response.embedding.values
          } catch (error) {
            // Gemini reports rate limits with an upper-case token; reuse the
            // guarded check with the lower-cased needle.
            if (isRateLimitError(error, 'rate_limit_exceeded')) {
              throw new LLMRateLimitExceededException(
                'Gemini API rate limit exceeded. Please try again later.',
              )
            }
            throw error
          }
        },
      }
    }
    case 'nomic-embed-text':
      return createOllamaModel('nomic-embed-text', 768)
    case 'mxbai-embed-large':
      return createOllamaModel('mxbai-embed-large', 1024)
    case 'bge-m3':
      return createOllamaModel('bge-m3', 1024)
    default:
      throw new Error('Invalid embedding model')
  }
}

124
src/core/rag/rag-engine.ts Normal file
View File

@@ -0,0 +1,124 @@
import { App } from 'obsidian'
import { QueryProgressState } from '../../components/chat-view/QueryProgress'
import { DBManager } from '../../database/database-manager'
import { VectorManager } from '../../database/modules/vector/vector-manager'
import { SelectVector } from '../../database/schema'
import { EmbeddingModel } from '../../types/embedding'
import { InfioSettings } from '../../types/settings'
import { getEmbeddingModel } from './embedding'
/**
 * Retrieval-augmented generation engine: maintains the vault's vector
 * index and answers similarity queries against it.
 *
 * Refactor: the embedding-model construction duplicated between the
 * constructor and setSettings is extracted into a private helper.
 */
export class RAGEngine {
  private app: App
  private settings: InfioSettings
  private vectorManager: VectorManager
  private embeddingModel: EmbeddingModel | null = null

  constructor(
    app: App,
    settings: InfioSettings,
    dbManager: DBManager,
  ) {
    this.app = app
    this.vectorManager = dbManager.getVectorManager()
    this.applySettings(settings)
  }

  /** Stores the settings and (re)builds the embedding model from them. */
  private applySettings(settings: InfioSettings): void {
    this.settings = settings
    this.embeddingModel = getEmbeddingModel(
      settings.embeddingModelId,
      {
        openAIApiKey: settings.openAIApiKey,
        geminiApiKey: settings.geminiApiKey,
      },
      settings.ollamaEmbeddingModel.baseUrl,
    )
  }

  setSettings(settings: InfioSettings) {
    this.applySettings(settings)
  }

  // TODO: Implement automatic vault re-indexing when settings are changed.
  // Currently, users must manually re-index the vault.
  /**
   * Updates (or fully rebuilds) the vault's vector index, reporting
   * progress through the optional callback.
   */
  async updateVaultIndex(
    options: { reindexAll: boolean } = {
      reindexAll: false,
    },
    onQueryProgressChange?: (queryProgress: QueryProgressState) => void,
  ): Promise<void> {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    await this.vectorManager.updateVaultIndex(
      this.embeddingModel,
      {
        chunkSize: this.settings.ragOptions.chunkSize,
        excludePatterns: this.settings.ragOptions.excludePatterns,
        includePatterns: this.settings.ragOptions.includePatterns,
        reindexAll: options.reindexAll,
      },
      (indexProgress) => {
        onQueryProgressChange?.({
          type: 'indexing',
          indexProgress,
        })
      },
    )
  }

  /**
   * Embeds the query, refreshes the index, and returns the most similar
   * chunks (optionally restricted to a file/folder scope).
   */
  async processQuery({
    query,
    scope,
    onQueryProgressChange,
  }: {
    query: string
    scope?: {
      files: string[]
      folders: string[]
    }
    onQueryProgressChange?: (queryProgress: QueryProgressState) => void
  }): Promise<
    (Omit<SelectVector, 'embedding'> & {
      similarity: number
    })[]
  > {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    // TODO: Decide the vault index update strategy.
    // Current approach: Update on every query.
    await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
    const queryEmbedding = await this.getQueryEmbedding(query)
    onQueryProgressChange?.({
      type: 'querying',
    })
    const queryResult = await this.vectorManager.performSimilaritySearch(
      queryEmbedding,
      this.embeddingModel,
      {
        minSimilarity: this.settings.ragOptions.minSimilarity,
        limit: this.settings.ragOptions.limit,
        scope,
      },
    )
    onQueryProgressChange?.({
      type: 'querying-done',
      queryResult,
    })
    return queryResult
  }

  /** Embeds the query text with the configured embedding model. */
  private async getQueryEmbedding(query: string): Promise<number[]> {
    if (!this.embeddingModel) {
      throw new Error('Embedding model is not set')
    }
    return this.embeddingModel.getEmbedding(query)
  }
}

View File

@@ -0,0 +1,180 @@
import { PGlite } from '@electric-sql/pglite'
import { type PGliteWithLive, live } from '@electric-sql/pglite/live'
// import { PgliteDatabase, drizzle } from 'drizzle-orm/pglite'
import { App, normalizePath } from 'obsidian'
import { PGLITE_DB_PATH } from '../constants'
import { ConversationManager } from './modules/conversation/conversation-manager'
import { TemplateManager } from './modules/template/template-manager'
import { VectorManager } from './modules/vector/vector-manager'
import { pgliteResources } from './pglite-resources'
import { migrations } from './sql'
/**
 * Owns the embedded PGlite (Postgres-in-WASM) database: loads it from a
 * gzip dump in the vault, runs SQL migrations, exposes the feature
 * managers built on top of it, and persists it back to disk.
 * Construct via the static `create` factory, not `new`.
 */
export class DBManager {
  private app: App
  private dbPath: string
  private db: PGliteWithLive | null = null
  // private db: PgliteDatabase | null = null
  // Assigned in create(); undefined if the instance is built with `new` alone.
  private vectorManager: VectorManager
  private templateManager: TemplateManager
  private conversationManager: ConversationManager
  constructor(app: App, dbPath: string) {
    this.app = app
    this.dbPath = dbPath
  }
  /**
   * Factory: loads the existing database (or creates a fresh one),
   * applies migrations, persists, and wires up the feature managers.
   */
  static async create(app: App): Promise<DBManager> {
    const dbManager = new DBManager(app, normalizePath(PGLITE_DB_PATH))
    await dbManager.loadExistingDatabase()
    if (!dbManager.db) {
      await dbManager.createNewDatabase()
    }
    await dbManager.migrateDatabase()
    await dbManager.save()
    dbManager.vectorManager = new VectorManager(app, dbManager)
    dbManager.templateManager = new TemplateManager(app, dbManager)
    dbManager.conversationManager = new ConversationManager(app, dbManager)
    console.log('Smart composer database initialized.')
    return dbManager
  }
  // getDb() {
  //   return this.db
  // }
  /** Raw PGlite client; null until create() has run. */
  getPgClient() {
    return this.db
  }
  getVectorManager() {
    return this.vectorManager
  }
  getTemplateManager() {
    return this.templateManager
  }
  getConversationManager() {
    return this.conversationManager
  }
  /** Creates an empty in-memory database with the pgvector extension. */
  private async createNewDatabase() {
    const { fsBundle, wasmModule, vectorExtensionBundlePath } =
      await this.loadPGliteResources()
    this.db = await PGlite.create({
      fsBundle: fsBundle,
      wasmModule: wasmModule,
      extensions: {
        vector: vectorExtensionBundlePath,
        live,
      },
    })
  }
  /**
   * Loads the gzip database dump from the vault, if present.
   * On any failure `this.db` stays null and the caller creates a new DB.
   */
  private async loadExistingDatabase() {
    try {
      const databaseFileExists = await this.app.vault.adapter.exists(
        this.dbPath,
      )
      if (!databaseFileExists) {
        return null
      }
      const fileBuffer = await this.app.vault.adapter.readBinary(this.dbPath)
      const fileBlob = new Blob([fileBuffer], { type: 'application/x-gzip' })
      const { fsBundle, wasmModule, vectorExtensionBundlePath } =
        await this.loadPGliteResources()
      this.db = await PGlite.create({
        loadDataDir: fileBlob,
        fsBundle: fsBundle,
        wasmModule: wasmModule,
        extensions: {
          vector: vectorExtensionBundlePath,
          live
        },
      })
      // return drizzle(this.pgClient)
    } catch (error) {
      console.error('Error loading database:', error)
      return null
    }
  }
  /**
   * Runs every registered SQL migration.
   * NOTE(review): statements are split on blank lines ('\n\n'), so a
   * migration containing a blank line inside one statement would break —
   * confirm the migration files follow this convention.
   */
  private async migrateDatabase(): Promise<void> {
    if (!this.db) {
      throw new Error('Database client not initialized');
    }
    try {
      // Execute SQL migrations
      for (const [_key, migration] of Object.entries(migrations)) {
        // Split SQL into individual commands and execute them one by one
        const commands = migration.sql.split('\n\n').filter(cmd => cmd.trim());
        for (const command of commands) {
          console.log('Executing SQL migration:', command);
          await this.db.query(command);
        }
      }
    } catch (error) {
      console.error('Error executing SQL migrations:', error);
      throw error;
    }
  }
  /** Dumps the database as gzip and writes it into the vault. */
  async save(): Promise<void> {
    if (!this.db) {
      return
    }
    try {
      const blob: Blob = await this.db.dumpDataDir('gzip')
      await this.app.vault.adapter.writeBinary(
        this.dbPath,
        Buffer.from(await blob.arrayBuffer()),
      )
    } catch (error) {
      console.error('Error saving database:', error)
    }
  }
  // Closes the client; NOTE(review): close() is not awaited — confirm
  // whether shutdown ordering matters for callers.
  async cleanup() {
    this.db?.close()
    this.db = null
  }
  /**
   * Decodes the bundled base64 PGlite artifacts (filesystem bundle, WASM
   * module, pgvector extension) into the objects PGlite.create expects.
   */
  private async loadPGliteResources(): Promise<{
    fsBundle: Blob
    wasmModule: WebAssembly.Module
    vectorExtensionBundlePath: URL
  }> {
    try {
      // Convert base64 to binary data
      const wasmBinary = Buffer.from(pgliteResources.wasmBase64, 'base64')
      const dataBinary = Buffer.from(pgliteResources.dataBase64, 'base64')
      const vectorBinary = Buffer.from(pgliteResources.vectorBase64, 'base64')
      // Create blobs from binary data
      const fsBundle = new Blob([dataBinary], {
        type: 'application/octet-stream',
      })
      const wasmModule = await WebAssembly.compile(wasmBinary)
      // Create a blob URL for the vector extension
      const vectorBlob = new Blob([vectorBinary], {
        type: 'application/gzip',
      })
      const vectorExtensionBundlePath = URL.createObjectURL(vectorBlob)
      return {
        fsBundle,
        wasmModule,
        vectorExtensionBundlePath: new URL(vectorExtensionBundlePath),
      }
    } catch (error) {
      console.error('Error loading PGlite resources:', error)
      throw error
    }
  }
}

20
src/database/exception.ts Normal file
View File

@@ -0,0 +1,20 @@
/** Base class for all database-layer errors raised by this plugin. */
export class DatabaseException extends Error {
  constructor(message: string) {
    super(message)
    this.name = 'DatabaseException'
  }
}
/** Thrown when an operation runs before the database is initialized. */
export class DatabaseNotInitializedException extends DatabaseException {
  constructor(message = 'Database not initialized') {
    super(message)
    this.name = 'DatabaseNotInitializedException'
  }
}
/** Thrown when creating a template whose name already exists. */
export class DuplicateTemplateException extends DatabaseException {
  constructor(templateName: string) {
    super(`Template with name "${templateName}" already exists`)
    this.name = 'DuplicateTemplateException'
  }
}

View File

@@ -0,0 +1,162 @@
import { SerializedEditorState } from 'lexical'
import { App } from 'obsidian'
import { editorStateToPlainText } from '../../../components/chat-view/chat-input/utils/editor-state-to-plain-text'
import { ChatAssistantMessage, ChatConversationMeta, ChatMessage, ChatUserMessage } from '../../../types/chat'
import { ContentPart } from '../../../types/llm/request'
import { Mentionable, SerializedMentionable } from '../../../types/mentionable'
import { deserializeMentionable, serializeMentionable } from '../../../utils/mentionable'
import { DBManager } from '../../database-manager'
import { InsertMessage } from '../../schema'
import { ConversationRepository } from './conversation-repository'
/**
 * Persists chat conversations and their messages.
 *
 * Sits on top of ConversationRepository and handles (de)serializing
 * ChatMessage objects to/from their row shape, plus flushing the PGlite
 * data dir (dbManager.save()) after every write.
 */
export class ConversationManager {
  private app: App
  private repository: ConversationRepository
  private dbManager: DBManager

  // Requires the PGlite client to already exist on the DBManager.
  constructor(app: App, dbManager: DBManager) {
    this.app = app
    this.dbManager = dbManager
    const db = dbManager.getPgClient()
    if (!db) throw new Error('Database not initialized')
    this.repository = new ConversationRepository(app, db)
  }

  /** Creates an empty conversation row, then persists the database. */
  async createConversation(id: string, title = 'New chat'): Promise<void> {
    const conversation = {
      id,
      title,
      createdAt: new Date(),
      updatedAt: new Date(),
    }
    await this.repository.create(conversation)
    await this.dbManager.save()
  }

  /**
   * Replaces all stored messages of a conversation with `messages`,
   * creating the conversation first if it does not exist yet (titled
   * from the first user message, truncated to 20 characters).
   */
  async saveConversation(id: string, messages: ChatMessage[]): Promise<void> {
    const conversation = await this.repository.findById(id)
    if (!conversation) {
      let title = 'New chat'
      if (messages.length > 0 && messages[0].role === 'user') {
        const query = editorStateToPlainText(messages[0].content)
        if (query.length > 20) {
          title = `${query.slice(0, 20)}...`
        } else {
          title = query
        }
      }
      await this.createConversation(id, title)
    }
    // Delete existing messages
    await this.repository.deleteAllMessagesFromConversation(id)
    // Insert new messages
    for (const message of messages) {
      const insertMessage = this.serializeMessage(message, id)
      await this.repository.createMessage(insertMessage)
    }
    // Update conversation timestamp
    await this.repository.update(id, { updatedAt: new Date() })
    await this.dbManager.save()
  }

  /** Returns the deserialized messages of a conversation, or null when it does not exist. */
  async findConversation(id: string): Promise<ChatMessage[] | null> {
    const conversation = await this.repository.findById(id)
    if (!conversation) {
      return null
    }
    const messages = await this.repository.findMessagesByConversationId(id)
    return messages.map(msg => this.deserializeMessage(msg))
  }

  /** Deletes a conversation (messages cascade via the FK), then persists. */
  async deleteConversation(id: string): Promise<void> {
    await this.repository.delete(id)
    await this.dbManager.save()
  }

  /**
   * Subscribes to a live query over all conversations and pushes every
   * change set to `callback` as ChatConversationMeta objects.
   * NOTE(review): rows come straight from SQL, so columns are likely
   * snake_case (created_at/updated_at); accessing conv.createdAt here may
   * be undefined — verify the row shape returned by live.query.
   * NOTE(review): ORDER BY updated_at is ascending — confirm oldest-first
   * is the intended ordering.
   */
  getAllConversations(callback: (conversations: ChatConversationMeta[]) => void): void {
    const db = this.dbManager.getPgClient()
    db?.live.query('SELECT * FROM conversations ORDER BY updated_at', [], (results) => {
      callback(results.rows.map(conv => ({
        id: conv.id,
        title: conv.title,
        schemaVersion: 2,
        createdAt: conv.createdAt instanceof Date ? conv.createdAt.getTime() : conv.createdAt,
        updatedAt: conv.updatedAt instanceof Date ? conv.updatedAt.getTime() : conv.updatedAt,
      })))
    })
  }

  /** Renames a conversation, then persists the database. */
  async updateConversationTitle(id: string, title: string): Promise<void> {
    await this.repository.update(id, { title })
    await this.dbManager.save()
  }

  /**
   * Converts a ChatMessage into its row shape. JSON-typed payloads
   * (content, promptContent, mentionables, similaritySearchResults,
   * metadata) are stored as JSON strings.
   * Note: createdAt is stamped at save time, not message-creation time.
   */
  private serializeMessage(message: ChatMessage, conversationId: string): InsertMessage {
    const base = {
      id: message.id,
      conversationId,
      role: message.role,
      createdAt: new Date(),
    }
    if (message.role === 'user') {
      const userMessage: ChatUserMessage = message
      return {
        ...base,
        content: userMessage.content ? JSON.stringify(userMessage.content) : null,
        promptContent: userMessage.promptContent
          ? typeof userMessage.promptContent === 'string'
            ? userMessage.promptContent
            : JSON.stringify(userMessage.promptContent)
          : null,
        mentionables: JSON.stringify(userMessage.mentionables.map(serializeMentionable)),
        similaritySearchResults: userMessage.similaritySearchResults
          ? JSON.stringify(userMessage.similaritySearchResults)
          : null,
      }
    } else {
      const assistantMessage: ChatAssistantMessage = message
      return {
        ...base,
        content: assistantMessage.content,
        metadata: assistantMessage.metadata ? JSON.stringify(assistantMessage.metadata) : null,
      }
    }
  }

  /**
   * Inverse of serializeMessage.
   * NOTE(review): promptContent is treated as JSON iff the stored string
   * starts with '{' — a JSON array would start with '[' and be returned
   * as a plain string; verify which shapes are actually stored.
   */
  private deserializeMessage(message: InsertMessage): ChatMessage {
    if (message.role === 'user') {
      return {
        id: message.id,
        role: 'user',
        content: message.content ? JSON.parse(message.content) as SerializedEditorState : null,
        promptContent: message.promptContent
          ? message.promptContent.startsWith('{')
            ? JSON.parse(message.promptContent) as ContentPart[]
            : message.promptContent
          : null,
        mentionables: message.mentionables
          ? (JSON.parse(message.mentionables) as SerializedMentionable[])
            .map(m => deserializeMentionable(m, this.app))
            .filter((m: Mentionable | null): m is Mentionable => m !== null)
          : [],
        similaritySearchResults: message.similaritySearchResults
          ? JSON.parse(message.similaritySearchResults)
          : undefined,
      }
    } else {
      return {
        id: message.id,
        role: 'assistant',
        content: message.content || '',
        metadata: message.metadata ? JSON.parse(message.metadata) : undefined,
      }
    }
  }
}

View File

@@ -0,0 +1,131 @@
import { PGliteInterface } from '@electric-sql/pglite'
import { App } from 'obsidian'
import {
InsertConversation,
InsertMessage,
SelectConversation,
SelectMessage,
} from '../../schema'
// Minimal shape of a PGlite query result (rows only).
type QueryResult<T> = {
  rows: T[]
}
/**
 * Low-level SQL access for conversations and messages.
 *
 * All methods use parameterized queries on the shared PGlite client.
 * Fixes vs. original: dropped the redundant `as QueryResult<…>` type
 * assertions (the driver's query<T>() already types the result — the
 * casts only bypassed checking) and typed query params as `unknown[]`
 * instead of `any[]`.
 */
export class ConversationRepository {
  private app: App
  private db: PGliteInterface

  constructor(app: App, db: PGliteInterface) {
    this.app = app
    this.db = db
  }

  /** Inserts a conversation row (timestamps default to now) and returns it. */
  async create(conversation: InsertConversation): Promise<SelectConversation> {
    const result = await this.db.query<SelectConversation>(
      `INSERT INTO conversations (id, title, created_at, updated_at)
VALUES ($1, $2, $3, $4)
RETURNING *`,
      [
        conversation.id,
        conversation.title,
        conversation.createdAt || new Date(),
        conversation.updatedAt || new Date(),
      ],
    )
    return result.rows[0]
  }

  /** Inserts a message row (createdAt defaults to now) and returns it. */
  async createMessage(message: InsertMessage): Promise<SelectMessage> {
    const result = await this.db.query<SelectMessage>(
      `INSERT INTO messages (
id, conversation_id, role, content,
prompt_content, metadata, mentionables,
similarity_search_results, created_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING *`,
      [
        message.id,
        message.conversationId,
        message.role,
        message.content,
        message.promptContent,
        message.metadata,
        message.mentionables,
        message.similaritySearchResults,
        message.createdAt || new Date(),
      ],
    )
    return result.rows[0]
  }

  /** Looks a conversation up by id; undefined when absent. */
  async findById(id: string): Promise<SelectConversation | undefined> {
    const result = await this.db.query<SelectConversation>(
      `SELECT * FROM conversations WHERE id = $1 LIMIT 1`,
      [id],
    )
    return result.rows[0]
  }

  /** Returns a conversation's messages in chronological order. */
  async findMessagesByConversationId(conversationId: string): Promise<SelectMessage[]> {
    const result = await this.db.query<SelectMessage>(
      `SELECT * FROM messages
WHERE conversation_id = $1
ORDER BY created_at`,
      [conversationId],
    )
    return result.rows
  }

  /** Returns all conversations, most recently updated first. */
  async findAll(): Promise<SelectConversation[]> {
    const result = await this.db.query<SelectConversation>(
      `SELECT * FROM conversations ORDER BY updated_at DESC`,
    )
    return result.rows
  }

  /**
   * Partially updates a conversation (currently only title); updated_at
   * is always refreshed to now. Returns the updated row.
   */
  async update(id: string, data: Partial<InsertConversation>): Promise<SelectConversation> {
    const setClauses: string[] = []
    const values: unknown[] = []
    let paramIndex = 1
    if (data.title !== undefined) {
      setClauses.push(`title = $${paramIndex}`)
      values.push(data.title)
      paramIndex++
    }
    // Always update updated_at
    setClauses.push(`updated_at = $${paramIndex}`)
    values.push(new Date())
    paramIndex++
    // Add id as the last parameter
    values.push(id)
    const result = await this.db.query<SelectConversation>(
      `UPDATE conversations
SET ${setClauses.join(', ')}
WHERE id = $${paramIndex}
RETURNING *`,
      values,
    )
    return result.rows[0]
  }

  /** Deletes a conversation by id; true when a row was removed. */
  async delete(id: string): Promise<boolean> {
    const result = await this.db.query<SelectConversation>(
      `DELETE FROM conversations WHERE id = $1 RETURNING *`,
      [id],
    )
    return result.rows.length > 0
  }

  /** Removes every message belonging to a conversation. */
  async deleteAllMessagesFromConversation(conversationId: string): Promise<void> {
    await this.db.query(
      `DELETE FROM messages WHERE conversation_id = $1`,
      [conversationId],
    )
  }
}

View File

@@ -0,0 +1,51 @@
import fuzzysort from 'fuzzysort'
import { App } from 'obsidian'
import { DBManager } from '../../database-manager'
import { DuplicateTemplateException } from '../../exception'
import { InsertTemplate, SelectTemplate } from '../../schema'
import { TemplateRepository } from './template-repository'
/**
 * Application-level operations on prompt templates: creation with
 * duplicate-name protection, listing, fuzzy search and deletion. Writes
 * are followed by a database save so the PGlite data dir stays current.
 */
export class TemplateManager {
  private app: App
  private repository: TemplateRepository
  private dbManager: DBManager

  constructor(app: App, dbManager: DBManager) {
    this.app = app
    this.dbManager = dbManager
    this.repository = new TemplateRepository(app, dbManager.getPgClient())
  }

  /**
   * Creates a template and returns the stored row.
   * @throws DuplicateTemplateException when the name is already taken
   */
  async createTemplate(template: InsertTemplate): Promise<SelectTemplate> {
    const duplicate = await this.repository.findByName(template.name)
    if (duplicate) {
      throw new DuplicateTemplateException(template.name)
    }
    const created = await this.repository.create(template)
    await this.dbManager.save()
    return created
  }

  /** Lists every stored template. */
  async findAllTemplates(): Promise<SelectTemplate[]> {
    return this.repository.findAll()
  }

  /** Fuzzy-searches templates by name; `all: true` makes an empty query match everything. */
  async searchTemplates(query: string): Promise<SelectTemplate[]> {
    const candidates = await this.findAllTemplates()
    const matches = fuzzysort.go(query, candidates, {
      keys: ['name'],
      threshold: 0.2,
      limit: 20,
      all: true,
    })
    return matches.map((match) => match.obj)
  }

  /** Deletes a template by id; returns whether a row was actually removed. */
  async deleteTemplate(id: string): Promise<boolean> {
    const removed = await this.repository.delete(id)
    await this.dbManager.save()
    return removed
  }
}

View File

@@ -0,0 +1,98 @@
import { PGliteInterface } from '@electric-sql/pglite'
import { App } from 'obsidian'
import { DatabaseNotInitializedException } from '../../exception'
import { type InsertTemplate, type SelectTemplate } from '../../schema'
/**
 * Direct SQL access to the "template" table. All queries are
 * parameterized; every method throws DatabaseNotInitializedException
 * when used before the PGlite client exists.
 * Fix vs. original: query params typed `unknown[]` instead of `any[]`
 * (values are only passed through to the driver).
 */
export class TemplateRepository {
  private app: App
  private db: PGliteInterface | null

  constructor(app: App, pgClient: PGliteInterface | null) {
    this.app = app
    this.db = pgClient
  }

  /** Inserts a template; id and timestamps are database-generated. */
  async create(template: InsertTemplate): Promise<SelectTemplate> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const result = await this.db.query<SelectTemplate>(
      `INSERT INTO "template" (name, content)
VALUES ($1, $2)
RETURNING *`,
      [template.name, template.content],
    )
    return result.rows[0]
  }

  /** Returns every template row. */
  async findAll(): Promise<SelectTemplate[]> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const result = await this.db.query<SelectTemplate>(
      `SELECT * FROM "template"`,
    )
    return result.rows
  }

  /** Looks a template up by its unique name; null when absent. */
  async findByName(name: string): Promise<SelectTemplate | null> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const result = await this.db.query<SelectTemplate>(
      `SELECT * FROM "template" WHERE name = $1`,
      [name],
    )
    return result.rows[0] ?? null
  }

  /**
   * Partially updates a template (name and/or content); updated_at is
   * always refreshed. Returns the updated row, or null for an unknown id.
   */
  async update(
    id: string,
    template: Partial<InsertTemplate>,
  ): Promise<SelectTemplate | null> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const setClauses: string[] = []
    const params: unknown[] = []
    let paramIndex = 1
    if (template.name !== undefined) {
      setClauses.push(`name = $${paramIndex}`)
      params.push(template.name)
      paramIndex++
    }
    if (template.content !== undefined) {
      setClauses.push(`content = $${paramIndex}`)
      params.push(template.content)
      paramIndex++
    }
    // updated_at uses now() server-side, so it consumes no placeholder;
    // the id is the final parameter ($paramIndex).
    setClauses.push(`updated_at = now()`)
    params.push(id)
    const result = await this.db.query<SelectTemplate>(
      `UPDATE "template"
SET ${setClauses.join(', ')}
WHERE id = $${paramIndex}
RETURNING *`,
      params,
    )
    return result.rows[0] ?? null
  }

  /** Deletes a template by id; true when a row was actually removed. */
  async delete(id: string): Promise<boolean> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const result = await this.db.query<SelectTemplate>(
      `DELETE FROM "template" WHERE id = $1 RETURNING *`,
      [id],
    )
    return result.rows.length > 0
  }
}

View File

@@ -0,0 +1,277 @@
import { backOff } from 'exponential-backoff'
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
import { minimatch } from 'minimatch'
import { App, Notice, TFile } from 'obsidian'
import pLimit from 'p-limit'
import { IndexProgress } from '../../../components/chat-view/QueryProgress'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
LLMBaseUrlNotSetException,
LLMRateLimitExceededException,
} from '../../../core/llm/exception'
import { InsertVector, SelectVector } from '../../../database/schema'
import { EmbeddingModel } from '../../../types/embedding'
import { openSettingsModalWithError } from '../../../utils/open-settings-modal'
import { DBManager } from '../../database-manager'
import { VectorRepository } from './vector-repository'
/**
 * Orchestrates embedding-index maintenance for the vault: chunking
 * markdown files, embedding the chunks with bounded concurrency and
 * retries, and delegating all storage to VectorRepository.
 */
export class VectorManager {
  private app: App
  private repository: VectorRepository
  private dbManager: DBManager

  constructor(app: App, dbManager: DBManager) {
    this.app = app
    this.dbManager = dbManager
    this.repository = new VectorRepository(app, dbManager.getPgClient())
  }

  /** Delegates a cosine-similarity search to the repository. */
  async performSimilaritySearch(
    queryVector: number[],
    embeddingModel: EmbeddingModel,
    options: {
      minSimilarity: number
      limit: number
      scope?: {
        files: string[]
        folders: string[]
      }
    },
  ): Promise<
    (Omit<SelectVector, 'embedding'> & {
      similarity: number
    })[]
  > {
    return await this.repository.performSimilaritySearch(
      queryVector,
      embeddingModel,
      options,
    )
  }

  /**
   * Brings the vault's embedding index up to date.
   *
   * With reindexAll the table is cleared and every matching file is
   * re-embedded; otherwise only new/changed files are processed (rows for
   * those files and for deleted files are removed first). Progress is
   * reported through updateProgress; the database is saved in all cases.
   */
  async updateVaultIndex(
    embeddingModel: EmbeddingModel,
    options: {
      chunkSize: number
      excludePatterns: string[]
      includePatterns: string[]
      reindexAll?: boolean
    },
    updateProgress?: (indexProgress: IndexProgress) => void,
  ): Promise<void> {
    let filesToIndex: TFile[]
    if (options.reindexAll) {
      filesToIndex = await this.getFilesToIndex({
        embeddingModel: embeddingModel,
        excludePatterns: options.excludePatterns,
        includePatterns: options.includePatterns,
        reindexAll: true,
      })
      await this.repository.clearAllVectors(embeddingModel)
    } else {
      await this.deleteVectorsForDeletedFiles(embeddingModel)
      filesToIndex = await this.getFilesToIndex({
        embeddingModel: embeddingModel,
        excludePatterns: options.excludePatterns,
        includePatterns: options.includePatterns,
      })
      // Drop stale rows for the files about to be re-embedded.
      await this.repository.deleteVectorsForMultipleFiles(
        filesToIndex.map((file) => file.path),
        embeddingModel,
      )
    }
    if (filesToIndex.length === 0) {
      return
    }
    // Character-based markdown splitter; chunk length counted in characters.
    const textSplitter = RecursiveCharacterTextSplitter.fromLanguage(
      'markdown',
      {
        chunkSize: options.chunkSize,
        // TODO: Use token-based chunking after migrating to WebAssembly-based tiktoken
        // Current token counting method is too slow for practical use
        // lengthFunction: async (text) => {
        //   return await tokenCount(text)
        // },
      },
    )
    // Split every file into chunks, keeping the source line span as metadata.
    const contentChunks: InsertVector[] = (
      await Promise.all(
        filesToIndex.map(async (file) => {
          const fileContent = await this.app.vault.cachedRead(file)
          const fileDocuments = await textSplitter.createDocuments([
            fileContent,
          ])
          return fileDocuments.map((chunk): InsertVector => {
            return {
              path: file.path,
              mtime: file.stat.mtime,
              content: chunk.pageContent,
              metadata: {
                startLine: chunk.metadata.loc.lines.from as number,
                endLine: chunk.metadata.loc.lines.to as number,
              },
            }
          })
        }),
      )
    ).flat()
    updateProgress?.({
      completedChunks: 0,
      totalChunks: contentChunks.length,
      totalFiles: filesToIndex.length,
    })
    // Embed with up to 50 concurrent tasks; completed embeddings are
    // appended to the shared embeddingChunks array and flushed in batches.
    // NOTE(review): embeddingChunks is appended to by concurrent tasks, so
    // batch boundaries (inserted..inserted+batchSize) depend on completion
    // order, and `inserted` advances by a full batchSize even for a final
    // partial batch — verify no chunks can be skipped or double-inserted.
    const embeddingProgress = { completed: 0, inserted: 0 }
    const embeddingChunks: InsertVector[] = []
    const batchSize = 100
    const limit = pLimit(50)
    const abortController = new AbortController()
    const tasks = contentChunks.map((chunk) =>
      limit(async () => {
        // A failed task aborts the controller so queued tasks bail out early.
        if (abortController.signal.aborted) {
          throw new Error('Operation was aborted')
        }
        try {
          // Retry transient embedding failures with exponential backoff.
          await backOff(
            async () => {
              const embedding = await embeddingModel.getEmbedding(chunk.content)
              const embeddedChunk = {
                path: chunk.path,
                mtime: chunk.mtime,
                content: chunk.content,
                embedding,
                metadata: chunk.metadata,
              }
              embeddingChunks.push(embeddedChunk)
              embeddingProgress.completed++
              updateProgress?.({
                completedChunks: embeddingProgress.completed,
                totalChunks: contentChunks.length,
                totalFiles: filesToIndex.length,
              })
              // Insert vectors in batches
              if (
                embeddingChunks.length >=
                embeddingProgress.inserted + batchSize ||
                embeddingChunks.length === contentChunks.length
              ) {
                await this.repository.insertVectors(
                  embeddingChunks.slice(
                    embeddingProgress.inserted,
                    embeddingProgress.inserted + batchSize,
                  ),
                  embeddingModel,
                )
                embeddingProgress.inserted += batchSize
              }
            },
            {
              numOfAttempts: 5,
              startingDelay: 1000,
              timeMultiple: 1.5,
              jitter: 'full',
            },
          )
        } catch (error) {
          abortController.abort()
          throw error
        }
      }),
    )
    try {
      await Promise.all(tasks)
    } catch (error) {
      // Configuration problems open the settings modal; rate limits only
      // notify; anything else is unexpected and rethrown.
      if (
        error instanceof LLMAPIKeyNotSetException ||
        error instanceof LLMAPIKeyInvalidException ||
        error instanceof LLMBaseUrlNotSetException
      ) {
        openSettingsModalWithError(this.app, (error as Error).message)
      } else if (error instanceof LLMRateLimitExceededException) {
        new Notice(error.message)
      } else {
        console.error('Error embedding chunks:', error)
        throw error
      }
    } finally {
      // Persist whatever was inserted, even after a failure.
      await this.dbManager.save()
    }
  }

  /** Removes index rows whose source file no longer exists in the vault. */
  private async deleteVectorsForDeletedFiles(embeddingModel: EmbeddingModel) {
    const indexedFilePaths =
      await this.repository.getIndexedFilePaths(embeddingModel)
    for (const filePath of indexedFilePaths) {
      if (!this.app.vault.getAbstractFileByPath(filePath)) {
        await this.repository.deleteVectorsForMultipleFiles(
          [filePath],
          embeddingModel,
        )
      }
    }
  }

  /**
   * Selects the markdown files that need (re)indexing: include/exclude
   * glob filtering first, then — unless reindexAll — keeps only files
   * that are unindexed (and non-empty) or whose mtime is newer than
   * their stored rows.
   */
  private async getFilesToIndex({
    embeddingModel,
    excludePatterns,
    includePatterns,
    reindexAll,
  }: {
    embeddingModel: EmbeddingModel
    excludePatterns: string[]
    includePatterns: string[]
    reindexAll?: boolean
  }): Promise<TFile[]> {
    let filesToIndex = this.app.vault.getMarkdownFiles()
    filesToIndex = filesToIndex.filter((file) => {
      return !excludePatterns.some((pattern) => minimatch(file.path, pattern))
    })
    if (includePatterns.length > 0) {
      filesToIndex = filesToIndex.filter((file) => {
        return includePatterns.some((pattern) => minimatch(file.path, pattern))
      })
    }
    if (reindexAll) {
      return filesToIndex
    }
    // Check for updated or new files
    filesToIndex = await Promise.all(
      filesToIndex.map(async (file) => {
        const fileChunks = await this.repository.getVectorsByFilePath(
          file.path,
          embeddingModel,
        )
        if (fileChunks.length === 0) {
          // File is not indexed, so we need to index it
          const fileContent = await this.app.vault.cachedRead(file)
          if (fileContent.length === 0) {
            // Ignore empty files
            return null
          }
          return file
        }
        const outOfDate = file.stat.mtime > fileChunks[0].mtime
        if (outOfDate) {
          // File has changed, so we need to re-index it
          return file
        }
        return null
      }),
    ).then((files) => files.filter(Boolean) as TFile[])
    return filesToIndex
  }
}

View File

@@ -0,0 +1,180 @@
import { PGliteInterface } from '@electric-sql/pglite'
import { App } from 'obsidian'
import { EmbeddingModel } from '../../../types/embedding'
import { DatabaseNotInitializedException } from '../../exception'
import { InsertVector, SelectVector, vectorTables } from '../../schema'
/**
 * Low-level storage for embedding vectors. Each embedding dimension has
 * its own table (see vectorTables); the table is resolved from the
 * model's dimension at call time.
 * Fixes vs. original: insertVectors guards against an empty input array
 * (an empty VALUES list is a SQL syntax error), comments translated to
 * English, and query params typed `unknown[]` instead of `any[]`.
 */
export class VectorRepository {
  private app: App
  private db: PGliteInterface | null

  constructor(app: App, pgClient: PGliteInterface | null) {
    this.app = app
    this.db = pgClient
  }

  /** Resolves the per-dimension table name for an embedding model. */
  private getTableName(embeddingModel: EmbeddingModel): string {
    const tableDefinition = vectorTables[embeddingModel.dimension]
    if (!tableDefinition) {
      throw new Error(`No table definition found for model: ${embeddingModel.id}`)
    }
    return tableDefinition.name
  }

  /** Returns the distinct file paths that currently have index rows. */
  async getIndexedFilePaths(embeddingModel: EmbeddingModel): Promise<string[]> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const tableName = this.getTableName(embeddingModel)
    const result = await this.db.query<{ path: string }>(
      `SELECT DISTINCT path FROM "${tableName}"`,
    )
    return result.rows.map((row: { path: string }) => row.path)
  }

  /** Returns every stored chunk row for one file. */
  async getVectorsByFilePath(
    filePath: string,
    embeddingModel: EmbeddingModel,
  ): Promise<SelectVector[]> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const tableName = this.getTableName(embeddingModel)
    const result = await this.db.query<SelectVector>(
      `SELECT * FROM "${tableName}" WHERE path = $1`,
      [filePath],
    )
    return result.rows
  }

  /** Deletes all chunk rows for one file. */
  async deleteVectorsForSingleFile(
    filePath: string,
    embeddingModel: EmbeddingModel,
  ): Promise<void> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const tableName = this.getTableName(embeddingModel)
    await this.db.query(
      `DELETE FROM "${tableName}" WHERE path = $1`,
      [filePath],
    )
  }

  /** Deletes all chunk rows for every listed file. */
  async deleteVectorsForMultipleFiles(
    filePaths: string[],
    embeddingModel: EmbeddingModel,
  ): Promise<void> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const tableName = this.getTableName(embeddingModel)
    await this.db.query(
      `DELETE FROM "${tableName}" WHERE path = ANY($1)`,
      [filePaths],
    )
  }

  /** Empties the model's entire embeddings table. */
  async clearAllVectors(embeddingModel: EmbeddingModel): Promise<void> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const tableName = this.getTableName(embeddingModel)
    await this.db.query(`DELETE FROM "${tableName}"`)
  }

  /** Bulk-inserts chunk rows via one multi-row INSERT statement. */
  async insertVectors(
    data: InsertVector[],
    embeddingModel: EmbeddingModel,
  ): Promise<void> {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    // An empty VALUES list is a SQL syntax error — treat as a no-op.
    if (data.length === 0) {
      return
    }
    const tableName = this.getTableName(embeddingModel)
    // Build the multi-row placeholder list: ($1..$5), ($6..$10), ...
    const values = data.map((vector, index) => {
      const offset = index * 5
      return `($${offset + 1}, $${offset + 2}, $${offset + 3}, $${offset + 4}, $${offset + 5})`
    }).join(',')
    const params = data.flatMap(vector => [
      vector.path,
      vector.mtime,
      vector.content,
      `[${vector.embedding.join(',')}]`, // serialize as a PostgreSQL vector literal
      vector.metadata,
    ])
    await this.db.query(
      `INSERT INTO "${tableName}" (path, mtime, content, embedding, metadata)
VALUES ${values}`,
      params,
    )
  }

  /**
   * Cosine-similarity search over the model's table. similarity is
   * 1 - cosine distance; rows below minSimilarity are filtered out and
   * an optional scope restricts matches to given files OR folders.
   */
  async performSimilaritySearch(
    queryVector: number[],
    embeddingModel: EmbeddingModel,
    options: {
      minSimilarity: number
      limit: number
      scope?: {
        files: string[]
        folders: string[]
      }
    },
  ): Promise<
    (Omit<SelectVector, 'embedding'> & {
      similarity: number
    })[]
  > {
    if (!this.db) {
      throw new DatabaseNotInitializedException()
    }
    const tableName = this.getTableName(embeddingModel)
    let scopeCondition = ''
    // $1 = query vector, $2 = min similarity, $3 = limit; scope params follow.
    const params: unknown[] = [`[${queryVector.join(',')}]`, options.minSimilarity, options.limit]
    let paramIndex = 4
    if (options.scope) {
      const conditions: string[] = []
      if (options.scope.files.length > 0) {
        conditions.push(`path = ANY($${paramIndex})`)
        params.push(options.scope.files)
        paramIndex++
      }
      if (options.scope.folders.length > 0) {
        // One LIKE per folder, matching any path under "<folder>/".
        const folderConditions = options.scope.folders.map((folder, idx) => {
          params.push(`${folder}/%`)
          return `path LIKE $${paramIndex + idx}`
        })
        conditions.push(`(${folderConditions.join(' OR ')})`)
        paramIndex += options.scope.folders.length
      }
      if (conditions.length > 0) {
        scopeCondition = `AND (${conditions.join(' OR ')})`
      }
    }
    const query = `
SELECT
id, path, mtime, content, metadata,
1 - (embedding <=> $1::vector) as similarity
FROM "${tableName}"
WHERE 1 - (embedding <=> $1::vector) > $2
${scopeCondition}
ORDER BY similarity DESC
LIMIT $3
`
    type SearchResult = Omit<SelectVector, 'embedding'> & { similarity: number }
    const result = await this.db.query<SearchResult>(query, params)
    return result.rows
  }
}

7
src/database/pglite-resources.d.ts vendored Normal file
View File

@@ -0,0 +1,7 @@
/**
 * Base64-encoded PGlite assets bundled into the plugin at build time and
 * decoded at runtime by DBManager.loadPGliteResources.
 */
export interface PgliteResources {
  wasmBase64: string;   // compiled via WebAssembly.compile
  dataBase64: string;   // filesystem bundle blob
  vectorBase64: string; // pgvector extension bundle (gzip)
}
export const pgliteResources: PgliteResources;

File diff suppressed because one or more lines are too long

156
src/database/schema.ts Normal file
View File

@@ -0,0 +1,156 @@
import { SerializedLexicalNode } from 'lexical'
import { SUPPORT_EMBEDDING_SIMENTION } from '../constants'
import { EmbeddingModelId } from '../types/embedding'
// PostgreSQL column types
/** Declarative description of a single PostgreSQL column. */
interface ColumnDefinition {
  type: string          // SQL type name, e.g. 'TEXT', 'UUID', 'VECTOR'
  notNull?: boolean
  primaryKey?: boolean
  defaultRandom?: boolean // random default, e.g. gen_random_uuid() for UUID
  unique?: boolean
  defaultNow?: boolean    // DEFAULT now() for timestamps
  dimensions?: number     // vector dimension for VECTOR columns
}

/** Declarative description of a table plus its optional indexes. */
interface TableDefinition {
  name: string
  columns: Record<string, ColumnDefinition>
  indices?: Record<string, {
    type: string      // index method, e.g. 'HNSW'
    columns: string[]
    options?: string  // e.g. operator class like 'vector_cosine_ops'
  }>
}
/* Vector Table */
/**
 * Builds the table definition for an embeddings table of the given
 * vector dimension. Tables over 2000 dimensions get no HNSW index
 * (pgvector's index dimension limit), so they are returned index-less.
 */
const createVectorTable = (dimension: number): TableDefinition => {
  const name = `embeddings_${dimension}`
  const columns = {
    id: { type: 'SERIAL', primaryKey: true },
    path: { type: 'TEXT', notNull: true },
    mtime: { type: 'BIGINT', notNull: true },
    content: { type: 'TEXT', notNull: true },
    embedding: { type: 'VECTOR', dimensions: dimension },
    metadata: { type: 'JSONB', notNull: true },
  }
  if (dimension > 2000) {
    return { name, columns }
  }
  return {
    name,
    columns,
    indices: {
      [`embeddingIndex_${dimension}`]: {
        type: 'HNSW',
        columns: ['embedding'],
        options: 'vector_cosine_ops',
      },
    },
  }
}
// One embeddings table definition per supported vector dimension,
// keyed by the dimension (see SUPPORT_EMBEDDING_SIMENTION in constants).
export const vectorTables = SUPPORT_EMBEDDING_SIMENTION.reduce<
  Record<number, TableDefinition>
>((acc, dimension) => {
  acc[dimension] = createVectorTable(dimension)
  return acc
}, {})
// Type definitions for vector table
/** A stored embedding chunk row. */
export interface VectorRecord {
  id: number
  path: string       // vault-relative path of the source file
  mtime: number      // source file's stat.mtime at indexing time
  content: string    // raw chunk text
  embedding: number[]
  metadata: VectorMetaData
}
export type SelectVector = VectorRecord
// Insert shape: id is SERIAL and assigned by the database.
export type InsertVector = Omit<VectorRecord, 'id'>
// Line span of the chunk inside its source file.
export type VectorMetaData = {
  startLine: number
  endLine: number
}
// // Export individual vector tables for reference
// export const vectorTable0 = vectorTables[EMBEDDING_MODEL_OPTIONS[0].id]
// export const vectorTable1 = vectorTables[EMBEDDING_MODEL_OPTIONS[1].id]
// export const vectorTable2 = vectorTables[EMBEDDING_MODEL_OPTIONS[2].id]
// export const vectorTable3 = vectorTables[EMBEDDING_MODEL_OPTIONS[3].id]
// export const vectorTable4 = vectorTables[EMBEDDING_MODEL_OPTIONS[4].id]
// export const vectorTable5 = vectorTables[EMBEDDING_MODEL_OPTIONS[5].id]
/* Template Table */
// Serialized Lexical editor nodes making up a template body.
export type TemplateContent = {
  nodes: SerializedLexicalNode[]
}
/** A stored prompt template row. */
export interface TemplateRecord {
  id: string
  name: string // unique
  content: TemplateContent
  createdAt: Date
  updatedAt: Date
}
export type SelectTemplate = TemplateRecord
// Insert shape: id and timestamps are database-generated (see templateTable defaults).
export type InsertTemplate = Omit<TemplateRecord, 'id' | 'createdAt' | 'updatedAt'>
// Declarative mirror of the SQL "template" table created in src/database/sql.ts.
export const templateTable: TableDefinition = {
  name: 'template',
  columns: {
    id: { type: 'UUID', primaryKey: true, defaultRandom: true },
    name: { type: 'TEXT', notNull: true, unique: true },
    content: { type: 'JSONB', notNull: true },
    createdAt: { type: 'TIMESTAMP', notNull: true, defaultNow: true },
    updatedAt: { type: 'TIMESTAMP', notNull: true, defaultNow: true }
  }
}
/** A chat conversation row; one per chat session. */
export interface Conversation {
  id: string // uuid
  title: string
  createdAt: Date
  updatedAt: Date
}
/**
 * A chat message row. JSON-typed payloads (promptContent, metadata,
 * mentionables, similaritySearchResults) are stored as pre-serialized
 * JSON strings — see ConversationManager.serializeMessage.
 */
export interface Message {
  id: string // uuid
  conversationId: string // uuid
  role: 'user' | 'assistant'
  content: string | null
  promptContent?: string | null
  metadata?: string | null
  mentionables?: string | null
  similaritySearchResults?: string | null
  createdAt: Date
}
// Insert shape: timestamps optional; defaulted by the repository/database.
export type InsertConversation = {
  id: string
  title: string
  createdAt?: Date
  updatedAt?: Date
}
export type SelectConversation = Conversation
// Insert shape for messages: createdAt optional; defaulted on insert.
export type InsertMessage = {
  id: string
  conversationId: string
  role: 'user' | 'assistant'
  content: string | null
  promptContent?: string | null
  metadata?: string | null
  mentionables?: string | null
  similaritySearchResults?: string | null
  createdAt?: Date
}
export type SelectMessage = Message

118
src/database/sql.ts Normal file
View File

@@ -0,0 +1,118 @@
/** One named schema migration: a human description plus its raw SQL. */
export interface SqlMigration {
  description: string;
  sql: string;
}

// Idempotent schema setup (CREATE ... IF NOT EXISTS throughout),
// executed by DBManager.migrateDatabase, which splits each `sql` string
// on blank lines ('\n\n') before running the pieces individually.
export const migrations: Record<string, SqlMigration> = {
  // Per-dimension embedding tables plus HNSW cosine-distance indexes.
  vector: {
    description: "Creates vector tables and indexes for different models",
    sql: `
-- Enable required extensions
CREATE EXTENSION IF NOT EXISTS vector;
-- Create vector tables for different models
CREATE TABLE IF NOT EXISTS "embeddings_1536" (
"id" serial PRIMARY KEY NOT NULL,
"path" text NOT NULL,
"mtime" bigint NOT NULL,
"content" text NOT NULL,
"embedding" vector(1536),
"metadata" jsonb NOT NULL
);
CREATE TABLE IF NOT EXISTS "embeddings_1024" (
"id" serial PRIMARY KEY NOT NULL,
"path" text NOT NULL,
"mtime" bigint NOT NULL,
"content" text NOT NULL,
"embedding" vector(1024),
"metadata" jsonb NOT NULL
);
CREATE TABLE IF NOT EXISTS "embeddings_768" (
"id" serial PRIMARY KEY NOT NULL,
"path" text NOT NULL,
"mtime" bigint NOT NULL,
"content" text NOT NULL,
"embedding" vector(768),
"metadata" jsonb NOT NULL
);
CREATE TABLE IF NOT EXISTS "embeddings_512" (
"id" serial PRIMARY KEY NOT NULL,
"path" text NOT NULL,
"mtime" bigint NOT NULL,
"content" text NOT NULL,
"embedding" vector(512),
"metadata" jsonb NOT NULL
);
CREATE TABLE IF NOT EXISTS "embeddings_384" (
"id" serial PRIMARY KEY NOT NULL,
"path" text NOT NULL,
"mtime" bigint NOT NULL,
"content" text NOT NULL,
"embedding" vector(384),
"metadata" jsonb NOT NULL
);
-- Create HNSW indexes for vector similarity search
CREATE INDEX IF NOT EXISTS "embeddingIndex_1536"
ON "embeddings_1536"
USING hnsw ("embedding" vector_cosine_ops);
CREATE INDEX IF NOT EXISTS "embeddingIndex_1024"
ON "embeddings_1024"
USING hnsw ("embedding" vector_cosine_ops);
CREATE INDEX IF NOT EXISTS "embeddingIndex_768"
ON "embeddings_768"
USING hnsw ("embedding" vector_cosine_ops);
CREATE INDEX IF NOT EXISTS "embeddingIndex_512"
ON "embeddings_512"
USING hnsw ("embedding" vector_cosine_ops);
CREATE INDEX IF NOT EXISTS "embeddingIndex_384"
ON "embeddings_384"
USING hnsw ("embedding" vector_cosine_ops);
`
  },
  // Prompt templates keyed by a unique name.
  template: {
    description: "Creates template table with UUID support",
    sql: `
-- Create template table
CREATE TABLE IF NOT EXISTS "template" (
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
"name" text NOT NULL,
"content" jsonb NOT NULL,
"created_at" timestamp DEFAULT now() NOT NULL,
"updated_at" timestamp DEFAULT now() NOT NULL,
CONSTRAINT "template_name_unique" UNIQUE("name")
);
`
  },
  // Chat history: conversations plus messages (cascade delete on FK).
  conversation: {
    description: "Creates conversations and messages tables",
    sql: `
CREATE TABLE IF NOT EXISTS "conversations" (
"id" uuid PRIMARY KEY NOT NULL,
"title" text NOT NULL,
"created_at" timestamp DEFAULT now() NOT NULL,
"updated_at" timestamp DEFAULT now() NOT NULL
);
CREATE TABLE IF NOT EXISTS "messages" (
"id" uuid PRIMARY KEY NOT NULL,
"conversation_id" uuid NOT NULL REFERENCES "conversations"("id") ON DELETE CASCADE,
"role" text NOT NULL,
"content" text,
"prompt_content" text,
"metadata" text,
"mentionables" text,
"similarity_search_results" text,
"created_at" timestamp DEFAULT now() NOT NULL
);
`
  }
};

302
src/event-listener.ts Normal file
View File

@@ -0,0 +1,302 @@
import { EditorView } from "@codemirror/view";
import { LRUCache } from "lru-cache";
import { App, TFile } from "obsidian";
import AutoComplete from "./core/autocomplete";
import Context from "./core/autocomplete/context-detection";
import DisabledFileSpecificState from "./core/autocomplete/states/disabled-file-specific-state";
import DisabledInvalidSettingsState from "./core/autocomplete/states/disabled-invalid-settings-state";
import DisabledManualState from "./core/autocomplete/states/disabled-manual-state";
import IdleState from "./core/autocomplete/states/idle-state";
import InitState from "./core/autocomplete/states/init-state";
import PredictingState from "./core/autocomplete/states/predicting-state";
import QueuedState from "./core/autocomplete/states/queued-state";
import State from "./core/autocomplete/states/state";
import SuggestingState from "./core/autocomplete/states/suggesting-state";
import { EventHandler } from "./core/autocomplete/states/types";
import { AutocompleteService } from "./core/autocomplete/types";
import { isMatchBetweenPathAndPatterns } from "./core/autocomplete/utils";
import { DocumentChanges } from "./render-plugin/document-changes-listener";
import { cancelSuggestion, insertSuggestion, updateSuggestion } from "./render-plugin/states";
import StatusBar from "./status-bar";
import { InfioSettings } from './types/settings';
import { checkForErrors } from "./utils/auto-complete";
// TTL for cached suggestions: 5 minutes.
const FIVE_MINUTES_IN_MS = 1000 * 60 * 5;
// Maximum entries in the suggestion LRU cache before eviction.
const MAX_N_ITEMS_IN_CACHE = 5000;
class EventListener implements EventHandler {
// Most recent CodeMirror view; null until the first view update.
private view: EditorView | null = null;
// Current node of the autocomplete state machine.
private state: EventHandler = new InitState();
private statusBar: StatusBar;
private app: App;
// Current context; changing it refreshes the status bar text (see setContext).
context: Context = Context.Text;
autocomplete: AutocompleteService;
settings: InfioSettings;
// Active file, tracked via handleFileChange; null before the first change.
private currentFile: TFile | null = null;
// LRU cache of suggestions, capped in size and expiring after 5 minutes.
private suggestionCache = new LRUCache<string, string>({ max: MAX_N_ITEMS_IN_CACHE, ttl: FIVE_MINUTES_IN_MS });
public static fromSettings(
settings: InfioSettings,
statusBar: StatusBar,
app: App
): EventListener {
const autocomplete = createPredictionService(settings);
const eventListener = new EventListener(
settings,
statusBar,
app,
autocomplete
);
const settingErrors = checkForErrors(settings);
if (settings.autocompleteEnabled) {
eventListener.transitionToIdleState()
} else if (settingErrors.size > 0) {
eventListener.transitionToDisabledInvalidSettingsState();
} else if (!settings.autocompleteEnabled) {
eventListener.transitionToDisabledManualState();
}
return eventListener;
}
private constructor(
settings: InfioSettings,
statusBar: StatusBar,
app: App,
autocomplete: AutocompleteService
) {
this.settings = settings;
this.statusBar = statusBar;
this.app = app;
this.autocomplete = autocomplete;
}
public setContext(context: Context): void {
if (context === this.context) {
return;
}
this.context = context;
this.updateStatusBarText();
}
public isSuggesting(): boolean {
return this.state instanceof SuggestingState;
}
public onViewUpdate(view: EditorView): void {
this.view = view;
}
public handleFileChange(file: TFile): void {
this.currentFile = file;
this.state.handleFileChange(file);
}
public isCurrentFilePathIgnored(): boolean {
if (this.currentFile === null) {
return false;
}
const patterns = this.settings.ignoredFilePatterns.split("\n");
return isMatchBetweenPathAndPatterns(this.currentFile.path, patterns);
}
public currentFileContainsIgnoredTag(): boolean {
if (this.currentFile === null) {
return false;
}
const ignoredTags = this.settings.ignoredTags.toLowerCase().split('\n');
const metadata = this.app.metadataCache.getFileCache(this.currentFile);
if (!metadata || !metadata.tags) {
return false;
}
const tags = metadata.tags.map(tag => tag.tag.replace(/#/g, '').toLowerCase());
return tags.some(tag => ignoredTags.includes(tag));
}
insertCurrentSuggestion(suggestion: string): void {
if (this.view === null) {
return;
}
insertSuggestion(this.view, suggestion);
}
cancelSuggestion(): void {
if (this.view === null) {
return;
}
cancelSuggestion(this.view);
}
private transitionTo(state: State): void {
this.state = state;
this.updateStatusBarText();
}
transitionToDisabledFileSpecificState(): void {
this.transitionTo(new DisabledFileSpecificState(this));
}
transitionToDisabledManualState(): void {
this.cancelSuggestion();
this.transitionTo(new DisabledManualState(this));
}
transitionToDisabledInvalidSettingsState(): void {
this.cancelSuggestion();
this.transitionTo(new DisabledInvalidSettingsState(this));
}
transitionToQueuedState(prefix: string, suffix: string): void {
this.transitionTo(
QueuedState.createAndStartTimer(
this,
prefix,
suffix
)
);
}
transitionToPredictingState(prefix: string, suffix: string): void {
this.transitionTo(PredictingState.createAndStartPredicting(
this,
prefix,
suffix
)
);
}
transitionToSuggestingState(
suggestion: string,
prefix: string,
suffix: string,
addToCache = true
): void {
if (this.view === null) {
return;
}
if (suggestion.trim().length === 0) {
this.transitionToIdleState();
return;
}
if (addToCache) {
this.addSuggestionToCache(prefix, suffix, suggestion);
}
this.transitionTo(new SuggestingState(this, suggestion, prefix, suffix));
updateSuggestion(this.view, suggestion);
}
public transitionToIdleState() {
const previousState = this.state;
this.transitionTo(new IdleState(this));
if (previousState instanceof SuggestingState) {
this.cancelSuggestion();
}
}
private updateStatusBarText(): void {
this.statusBar.updateText(this.getStatusBarText());
}
getStatusBarText(): string {
return `Copilot: ${this.state.getStatusBarText()}`;
}
handleSettingChanged(settings: InfioSettings): void {
this.settings = settings;
this.autocomplete = createPredictionService(settings);
if (!this.settings.cacheSuggestions) {
this.clearSuggestionsCache();
}
this.state.handleSettingChanged(settings);
}
async handleDocumentChange(
documentChanges: DocumentChanges
): Promise<void> {
await this.state.handleDocumentChange(documentChanges);
}
handleAcceptKeyPressed(): boolean {
return this.state.handleAcceptKeyPressed();
}
handlePartialAcceptKeyPressed(): boolean {
return this.state.handlePartialAcceptKeyPressed();
}
handleCancelKeyPressed(): boolean {
return this.state.handleCancelKeyPressed();
}
handlePredictCommand(prefix: string, suffix: string): void {
this.state.handlePredictCommand(prefix, suffix);
}
handleAcceptCommand(): void {
this.state.handleAcceptCommand();
}
containsTriggerCharacters(
documentChanges: DocumentChanges
): boolean {
for (const trigger of this.settings.triggers) {
if (trigger.type === "string" && documentChanges.getPrefix().endsWith(trigger.value)) {
return true;
}
if (trigger.type === "regex" && documentChanges.getPrefix().match(trigger.value)) {
return true;
}
}
return false;
}
public isDisabled(): boolean {
return this.state instanceof DisabledManualState || this.state instanceof DisabledInvalidSettingsState || this.state instanceof DisabledFileSpecificState;
}
public isIdle(): boolean {
return this.state instanceof IdleState;
}
public getCachedSuggestionFor(prefix: string, suffix: string): string | undefined {
return this.suggestionCache.get(this.getCacheKey(prefix, suffix));
}
private getCacheKey(prefix: string, suffix: string): string {
const nCharsToKeepPrefix = prefix.length;
const nCharsToKeepSuffix = suffix.length;
return `${prefix.substring(prefix.length - nCharsToKeepPrefix)}<mask/>${suffix.substring(0, nCharsToKeepSuffix)}`
}
public clearSuggestionsCache(): void {
this.suggestionCache.clear();
}
public addSuggestionToCache(prefix: string, suffix: string, suggestion: string): void {
if (!this.settings.cacheSuggestions) {
return;
}
this.suggestionCache.set(this.getCacheKey(prefix, suffix), suggestion);
}
}
/**
 * Factory for the autocomplete backend used by EventListener. Kept as a
 * free function so the service can be rebuilt whenever settings change.
 */
function createPredictionService(settings: InfioSettings) {
  const service = AutoComplete.fromSettings(settings);
  return service;
}

export default EventListener;

View File

@@ -0,0 +1,83 @@
import { useCallback, useEffect, useState } from 'react'
import { useApp } from '../contexts/AppContext'
import { useDatabase } from '../contexts/DatabaseContext'
import { DBManager } from '../database/database-manager'
import { ChatConversationMeta, ChatMessage } from '../types/chat'
type UseChatHistory = {
  createOrUpdateConversation: (
    id: string,
    messages: ChatMessage[],
  ) => Promise<void>
  deleteConversation: (id: string) => Promise<void>
  getChatMessagesById: (id: string) => Promise<ChatMessage[] | null>
  updateConversationTitle: (id: string, title: string) => Promise<void>
  chatList: ChatConversationMeta[]
}

/**
 * React hook exposing CRUD helpers for chat conversations plus a live
 * `chatList` snapshot of all conversation metadata.
 *
 * (The previous version called `useApp()` into an unused local; the dead
 * hook call is removed.)
 */
export function useChatHistory(): UseChatHistory {
  const { getDatabaseManager } = useDatabase()
  // Refreshing like this is a little verbose, but it keeps chatList
  // up to date in real time.
  const [chatList, setChatList] = useState<ChatConversationMeta[]>([])

  const getManager = useCallback(async (): Promise<DBManager> => {
    return await getDatabaseManager()
  }, [getDatabaseManager])

  const fetchChatList = useCallback(async () => {
    const dbManager = await getManager()
    // getAllConversations delivers its results through the setChatList callback.
    dbManager.getConversationManager().getAllConversations(setChatList)
  }, [getManager])

  useEffect(() => {
    void fetchChatList()
  }, [fetchChatList])

  // Insert-or-update the messages of a conversation.
  const createOrUpdateConversation = useCallback(
    async (id: string, messages: ChatMessage[]): Promise<void> => {
      const dbManager = await getManager()
      const conversationManager = dbManager.getConversationManager()
      await conversationManager.saveConversation(id, messages)
    },
    [getManager],
  )

  const deleteConversation = useCallback(
    async (id: string): Promise<void> => {
      const dbManager = await getManager()
      const conversationManager = dbManager.getConversationManager()
      await conversationManager.deleteConversation(id)
    },
    [getManager],
  )

  // Returns null when no conversation with the given id exists.
  const getChatMessagesById = useCallback(
    async (id: string): Promise<ChatMessage[] | null> => {
      const dbManager = await getManager()
      const conversationManager = dbManager.getConversationManager()
      return await conversationManager.findConversation(id)
    },
    [getManager],
  )

  const updateConversationTitle = useCallback(
    async (id: string, title: string): Promise<void> => {
      const dbManager = await getManager()
      const conversationManager = dbManager.getConversationManager()
      await conversationManager.updateConversationTitle(id, title)
    },
    [getManager],
  )

  return {
    createOrUpdateConversation,
    deleteConversation,
    getChatMessagesById,
    updateConversationTitle,
    chatList,
  }
}

429
src/main.ts Normal file
View File

@@ -0,0 +1,429 @@
import { EditorView } from '@codemirror/view'
import { Editor, MarkdownView, Notice, Plugin, TFile } from 'obsidian'
import { ApplyView } from './ApplyView'
import { ChatView } from './ChatView'
import { ChatProps } from './components/chat-view/Chat'
import { APPLY_VIEW_TYPE, CHAT_VIEW_TYPE } from './constants'
import { InlineEdit } from './core/edit/inline-edit-processor'
import { RAGEngine } from './core/rag/rag-engine'
import { DBManager } from './database/database-manager'
import EventListener from "./event-listener"
import CompletionKeyWatcher from "./render-plugin/completion-key-watcher"
import DocumentChangesListener, {
DocumentChanges,
getPrefix, getSuffix,
hasMultipleCursors,
hasSelection
} from "./render-plugin/document-changes-listener"
import RenderSuggestionPlugin from "./render-plugin/render-surgestion-plugin"
import { InlineSuggestionState } from "./render-plugin/states"
import { InfioSettingTab } from './settings/SettingTab'
import StatusBar from "./status-bar"
import {
InfioSettings,
parseInfioSettings,
} from './types/settings'
import { getMentionableBlockData } from './utils/obsidian'
// Remember to rename these classes and interfaces!
export default class InfioPlugin extends Plugin {
settings: InfioSettings
settingsListeners: ((newSettings: InfioSettings) => void)[] = []
initChatProps?: ChatProps
dbManager: DBManager | null = null
ragEngine: RAGEngine | null = null
inlineEdit: InlineEdit | null = null
private dbManagerInitPromise: Promise<DBManager> | null = null
private ragEngineInitPromise: Promise<RAGEngine> | null = null
async onload() {
await this.loadSettings()
// This creates an icon in the left ribbon.
this.addRibbonIcon('wand-sparkles', 'Open smart composer', () =>
this.openChatView(),
)
this.registerView(CHAT_VIEW_TYPE, (leaf) => new ChatView(leaf, this))
this.registerView(APPLY_VIEW_TYPE, (leaf) => new ApplyView(leaf))
// This adds a settings tab so the user can configure various aspects of the plugin
this.addSettingTab(new InfioSettingTab(this.app, this))
// Register markdown processor for ai blocks
this.inlineEdit = new InlineEdit(this, this.settings);
this.registerMarkdownCodeBlockProcessor("infioedit", (source, el, ctx) => {
this.inlineEdit?.Processor(source, el, ctx);
});
// Update inlineEdit when settings change
this.addSettingsListener((newSettings) => {
this.inlineEdit = new InlineEdit(this, newSettings);
});
// Setup event listener
const statusBar = StatusBar.fromApp(this);
const eventListener = EventListener.fromSettings(
this.settings,
statusBar,
this.app
);
this.addSettingsListener((newSettings) => {
eventListener.handleSettingChanged(newSettings)
});
// Setup render plugin
this.registerEditorExtension([
InlineSuggestionState,
CompletionKeyWatcher(
eventListener.handleAcceptKeyPressed.bind(eventListener) as () => boolean,
eventListener.handlePartialAcceptKeyPressed.bind(eventListener) as () => boolean,
eventListener.handleCancelKeyPressed.bind(eventListener) as () => boolean,
),
DocumentChangesListener(
eventListener.handleDocumentChange.bind(eventListener) as (documentChange: DocumentChanges) => Promise<void>
),
RenderSuggestionPlugin(),
]);
this.app.workspace.onLayoutReady(() => {
const view = this.app.workspace.getActiveViewOfType(MarkdownView);
if (view) {
// @ts-expect-error, not typed
const editorView = view.editor.cm as EditorView;
eventListener.onViewUpdate(editorView);
}
});
this.app.workspace.on("active-leaf-change", (leaf) => {
if (leaf?.view instanceof MarkdownView) {
// @ts-expect-error, not typed
const editorView = leaf.view.editor.cm as EditorView;
eventListener.onViewUpdate(editorView);
if (leaf.view.file) {
eventListener.handleFileChange(leaf.view.file);
}
}
});
this.app.metadataCache.on("changed", (file: TFile) => {
if (file) {
eventListener.handleFileChange(file);
}
});
// This adds a simple command that can be triggered anywhere
this.addCommand({
id: 'infio-open-new-chat',
name: 'Infio open new chat',
callback: () => this.openChatView(true),
})
this.addCommand({
id: 'infio-add-selection-to-chat',
name: 'Infio add selection to chat',
editorCallback: (editor: Editor, view: MarkdownView) => {
this.addSelectionToChat(editor, view)
},
hotkeys: [
{
modifiers: ['Mod', 'Shift'],
key: 'l',
},
],
})
this.addCommand({
id: 'infio-rebuild-vault-index',
name: 'Infio rebuild entire vault index',
callback: async () => {
const notice = new Notice('Rebuilding vault index...', 0)
try {
const ragEngine = await this.getRAGEngine()
await ragEngine.updateVaultIndex(
{ reindexAll: true },
(queryProgress) => {
if (queryProgress.type === 'indexing') {
const { completedChunks, totalChunks } =
queryProgress.indexProgress
notice.setMessage(
`Indexing chunks: ${completedChunks} / ${totalChunks}`,
)
}
},
)
notice.setMessage('Rebuilding vault index complete')
} catch (error) {
console.error(error)
notice.setMessage('Rebuilding vault index failed')
} finally {
setTimeout(() => {
notice.hide()
}, 1000)
}
},
})
this.addCommand({
id: 'infio-update-vault-index',
name: 'Infio update index for modified files',
callback: async () => {
const notice = new Notice('Updating vault index...', 0)
try {
const ragEngine = await this.getRAGEngine()
await ragEngine.updateVaultIndex(
{ reindexAll: false },
(queryProgress) => {
if (queryProgress.type === 'indexing') {
const { completedChunks, totalChunks } =
queryProgress.indexProgress
notice.setMessage(
`Indexing chunks: ${completedChunks} / ${totalChunks}`,
)
}
},
)
notice.setMessage('Vault index updated')
} catch (error) {
console.error(error)
notice.setMessage('Vault index update failed')
} finally {
setTimeout(() => {
notice.hide()
}, 1000)
}
},
})
this.addCommand({
id: 'infio-autocomplete-accept',
name: 'Infio Autocomplete Accept',
editorCheckCallback: (
checking: boolean,
editor: Editor,
view: MarkdownView
) => {
if (checking) {
return (
eventListener.isSuggesting()
);
}
eventListener.handleAcceptCommand();
return true;
},
})
this.addCommand({
id: 'infio-autocomplete-predict',
name: 'Infio Autocomplete Predict',
editorCheckCallback: (
checking: boolean,
editor: Editor,
view: MarkdownView
) => {
// @ts-expect-error, not typed
const editorView = editor.cm as EditorView;
const state = editorView.state;
if (checking) {
return eventListener.isIdle() && !hasMultipleCursors(state) && !hasSelection(state);
}
const prefix = getPrefix(state)
const suffix = getSuffix(state)
eventListener.handlePredictCommand(prefix, suffix);
return true;
},
});
this.addCommand({
id: "infio-autocomplete-toggle",
name: "Infio Autocomplete Toggle",
callback: () => {
const newValue = !this.settings.autocompleteEnabled;
this.setSettings({
...this.settings,
autocompleteEnabled: newValue,
})
},
});
this.addCommand({
id: "infio-autocomplete-enable",
name: "Infio Autocomplete Enable",
checkCallback: (checking) => {
if (checking) {
return !this.settings.autocompleteEnabled;
}
this.setSettings({
...this.settings,
autocompleteEnabled: true,
})
return true;
},
});
this.addCommand({
id: "infio-autocomplete-disable",
name: "Infio Autocomplete Disable",
checkCallback: (checking) => {
if (checking) {
return this.settings.autocompleteEnabled;
}
this.setSettings({
...this.settings,
autocompleteEnabled: false,
})
return true;
},
});
this.addCommand({
id: "infio-ai-inline-edit",
name: "infio Inline Edit",
hotkeys: [
{
modifiers: ['Mod', 'Shift'],
key: "k",
},
],
editorCallback: (editor: Editor) => {
const selection = editor.getSelection();
if (!selection) {
new Notice("Please select some text first");
return;
}
// Get the selection start position
const from = editor.getCursor("from");
// Create the position for inserting the block
const insertPos = { line: from.line, ch: 0 };
// Create the AI block with the selected text
const customBlock = "```infioedit\n```\n";
// Insert the block above the selection
editor.replaceRange(customBlock, insertPos);
},
});
}
onunload() {
this.dbManager?.cleanup()
this.dbManager = null
}
async loadSettings() {
this.settings = parseInfioSettings(await this.loadData())
await this.saveData(this.settings) // Save updated settings
}
async setSettings(newSettings: InfioSettings) {
this.settings = newSettings
await this.saveData(newSettings)
this.ragEngine?.setSettings(newSettings)
this.settingsListeners.forEach((listener) => listener(newSettings))
}
addSettingsListener(
listener: (newSettings: InfioSettings) => void,
) {
this.settingsListeners.push(listener)
return () => {
this.settingsListeners = this.settingsListeners.filter(
(l) => l !== listener,
)
}
}
async openChatView(openNewChat = false) {
const view = this.app.workspace.getActiveViewOfType(MarkdownView)
const editor = view?.editor
if (!view || !editor) {
this.activateChatView(undefined, openNewChat)
return
}
const selectedBlockData = await getMentionableBlockData(editor, view)
this.activateChatView(
{
selectedBlock: selectedBlockData ?? undefined,
},
openNewChat,
)
}
async activateChatView(chatProps?: ChatProps, openNewChat = false) {
// chatProps is consumed in ChatView.tsx
this.initChatProps = chatProps
const leaf = this.app.workspace.getLeavesOfType(CHAT_VIEW_TYPE)[0]
await (leaf ?? this.app.workspace.getRightLeaf(false))?.setViewState({
type: CHAT_VIEW_TYPE,
active: true,
})
if (openNewChat && leaf && leaf.view instanceof ChatView) {
leaf.view.openNewChat(chatProps?.selectedBlock)
}
this.app.workspace.revealLeaf(
this.app.workspace.getLeavesOfType(CHAT_VIEW_TYPE)[0],
)
}
async addSelectionToChat(editor: Editor, view: MarkdownView) {
const data = await getMentionableBlockData(editor, view)
if (!data) return
const leaves = this.app.workspace.getLeavesOfType(CHAT_VIEW_TYPE)
if (leaves.length === 0 || !(leaves[0].view instanceof ChatView)) {
await this.activateChatView({
selectedBlock: data,
})
return
}
// bring leaf to foreground (uncollapse sidebar if it's collapsed)
await this.app.workspace.revealLeaf(leaves[0])
const chatView = leaves[0].view
chatView.addSelectionToChat(data)
chatView.focusMessage()
}
async getDbManager(): Promise<DBManager> {
if (this.dbManager) {
return this.dbManager
}
if (!this.dbManagerInitPromise) {
this.dbManagerInitPromise = (async () => {
this.dbManager = await DBManager.create(this.app)
return this.dbManager
})()
}
// if initialization is running, wait for it to complete instead of creating a new initialization promise
return this.dbManagerInitPromise
}
async getRAGEngine(): Promise<RAGEngine> {
if (this.ragEngine) {
return this.ragEngine
}
if (!this.ragEngineInitPromise) {
this.ragEngineInitPromise = (async () => {
const dbManager = await this.getDbManager()
this.ragEngine = new RAGEngine(this.app, this.settings, dbManager)
return this.ragEngine
})()
}
// if initialization is running, wait for it to complete instead of creating a new initialization promise
return this.ragEngineInitPromise
}
}

View File

@@ -0,0 +1,17 @@
import { App, Modal, Setting } from 'obsidian'
/**
 * Minimal modal showing a single "Open settings" button. `onSubmit` runs
 * after the modal closes (typically to navigate to the settings tab).
 */
export class OpenSettingsModal extends Modal {
  constructor(app: App, title: string, onSubmit: () => void) {
    super(app)
    this.setTitle(title)
    const row = new Setting(this.contentEl)
    row.addButton((button) =>
      button.setButtonText('Open settings').onClick(() => {
        this.close()
        onSubmit()
      }),
    )
  }
}

View File

@@ -0,0 +1,27 @@
import { Prec } from "@codemirror/state";
import { keymap } from "@codemirror/view";
/**
 * Keymap extension wired at the highest precedence so the autocomplete
 * keys win over other handlers. Each handler returns true when it
 * consumed the keypress.
 */
function CompletionKeyWatcher(
  handleAcceptKey: () => boolean,
  handlePartialAcceptKey: () => boolean,
  handleCancelKey: () => boolean
) {
  const bindings = [
    { key: "Tab", run: handleAcceptKey },
    { key: "ArrowRight", run: handlePartialAcceptKey },
    { key: "Escape", run: handleCancelKey },
  ];
  return Prec.highest(keymap.of(bindings));
}

export default CompletionKeyWatcher;

View File

@@ -0,0 +1,201 @@
import { EditorState } from "@codemirror/state";
import { ViewPlugin, ViewUpdate } from "@codemirror/view";
import UserEvent from "./user-event";
/**
 * Read-only facade over a CodeMirror ViewUpdate that answers the questions
 * the autocomplete state machine asks: what the user did (typed, deleted,
 * undid/redid, moved the cursor) and how the text around the cursor changed
 * between the update's start state and end state.
 */
export class DocumentChanges {
  private update: ViewUpdate;

  constructor(update: ViewUpdate) {
    this.update = update;
  }

  /** True when the editor that produced this update currently has focus. */
  public isDocInFocus(): boolean {
    return this.update.view.hasFocus;
  }

  public noUserEvents(): boolean {
    return this.getUserEvents().length === 0;
  }

  // NOTE(review): `contains` on arrays is Obsidian's global Array extension
  // (alias of includes) — confirm it stays available in this context.
  public hasUserTyped(): boolean {
    const userEvents = this.getUserEvents();
    return userEvents.contains(UserEvent.INPUT_TYPE);
  }

  public hasUserUndone(): boolean {
    const userEvents = this.getUserEvents();
    return userEvents.contains(UserEvent.UNDO);
  }

  public hasUserRedone(): boolean {
    const userEvents = this.getUserEvents();
    return userEvents.contains(UserEvent.REDO);
  }

  public hasUserDeleted(): boolean {
    const userEvents = this.getUserEvents();
    return (
      userEvents.filter((event) => UserEvent.isDelete(event)).length > 0
    );
  }

  /** True when the document changed, via any edit, typing, or deletion. */
  public hasDocChanged(): boolean {
    return (
      this.update.docChanged ||
      this.hasUserTyped() ||
      this.hasUserDeleted()
    );
  }

  public hasCursorMoved(): boolean {
    return this.getUserEvents().contains(UserEvent.CURSOR_MOVED);
  }

  /** Map each transaction in this update to a UserEvent (when recognized). */
  public getUserEvents(): UserEvent[] {
    const userEvents: UserEvent[] = [];
    for (const transaction of this.update.transactions) {
      const event = UserEvent.fromTransaction(transaction);
      if (event) {
        userEvents.push(event);
      }
    }
    return userEvents;
  }

  isTextAdded(): boolean {
    return this.getAddedText().length > 0;
  }

  /** Concatenation of all text inserted by this update's changes. */
  getAddedText(): string {
    let addedText = "";
    this.update.changes.iterChanges((fromA, toA, fromB, toB, inserted) => {
      // `inserted` is a CodeMirror Text; string concatenation stringifies it.
      addedText += inserted;
    });
    return addedText;
  }

  /** Document text before the cursor, in the update's end state. */
  getPrefix(): string {
    return getPrefix(this.update.state);
  }

  /** Document text after the cursor, in the update's end state. */
  getSuffix(): string {
    return getSuffix(this.update.state);
  }

  /**
   * Text appended to the prefix by this update, "" if nothing grew, or
   * undefined when the doc is unfocused / the cursor moved (no meaningful
   * comparison possible).
   */
  getAddedPrefixText(): string | undefined {
    if (!this.isDocInFocus() || this.hasCursorMoved()) {
      return undefined;
    }
    const previousPrefix = this.getPreviousPrefix();
    const updatedPrefix = this.getPrefix();

    if (updatedPrefix.length > previousPrefix.length) {
      return updatedPrefix.substring(previousPrefix.length);
    }
    return "";
  }

  /** Prefix as it was in the update's start state. */
  getPreviousPrefix(): string {
    return getPrefix(this.update.startState);
  }

  /** Text prepended to the suffix by this update (same contract as prefix). */
  getAddedSuffixText(): string | undefined {
    if (!this.isDocInFocus() || this.hasCursorMoved()) {
      return undefined;
    }
    const previousSuffix = this.getPreviousSuffix();
    const updatedSuffix = this.getSuffix();

    if (updatedSuffix.length > previousSuffix.length) {
      return updatedSuffix.substring(0, updatedSuffix.length - previousSuffix.length);
    }
    return "";
  }

  /** Suffix as it was in the update's start state. */
  getPreviousSuffix(): string {
    return getSuffix(this.update.startState);
  }

  /** Text removed from the end of the prefix, "" if none, undefined as above. */
  getRemovedPrefixText(): string | undefined {
    if (!this.isDocInFocus() || this.hasCursorMoved()) {
      return undefined
    }
    const previousPrefix = this.getPreviousPrefix();
    const updatedPrefix = this.getPrefix();

    if (updatedPrefix.length < previousPrefix.length) {
      return previousPrefix.substring(updatedPrefix.length);
    }
    return "";
  }

  /** Text removed from the start of the suffix, "" if none, undefined as above. */
  getRemovedSuffixText(): string | undefined {
    if (!this.isDocInFocus() || this.hasCursorMoved()) {
      return undefined
    }
    const previousSuffix = this.getPreviousSuffix();
    const updatedSuffix = this.getSuffix();

    if (updatedSuffix.length < previousSuffix.length) {
      return previousSuffix.substring(0, previousSuffix.length - updatedSuffix.length);
    }
    return "";
  }

  hasSelection(): boolean {
    return hasSelection(this.update.state);
  }

  hasMultipleCursors(): boolean {
    return hasMultipleCursors(this.update.state);
  }
}
/**
 * ViewPlugin that forwards every editor update, wrapped in a
 * DocumentChanges facade, to the supplied async callback.
 */
const DocumentChangesListener = (
  onDocumentChange: (documentChange: DocumentChanges) => Promise<void>
) =>
  ViewPlugin.fromClass(
    class FetchPlugin {
      async update(update: ViewUpdate) {
        const changes = new DocumentChanges(update);
        await onDocumentChange(changes);
      }
    }
  );
/** Document text from the start of the doc up to the main cursor. */
export function getPrefix(state: EditorState): string {
  const cursor = getCursorLocation(state);
  return state.doc.sliceString(0, cursor);
}
/** Offset of the main selection's head (the cursor position). */
export function getCursorLocation(state: EditorState): number {
  const { head } = state.selection.main;
  return head;
}
/** Document text from the main cursor to the end of the doc. */
export function getSuffix(state: EditorState): string {
  const cursor = getCursorLocation(state);
  return state.doc.sliceString(cursor);
}
/** True when the selection contains more than one range (multi-cursor). */
export function hasMultipleCursors(state: EditorState): boolean {
  const { ranges } = state.selection;
  return ranges.length > 1;
}
/** True when any selection range is non-empty (actual text is selected). */
export function hasSelection(state: EditorState): boolean {
  return state.selection.ranges.some((range) => range.from !== range.to);
}

export default DocumentChangesListener;

View File

@@ -0,0 +1,105 @@
import { Prec } from "@codemirror/state";
import {
Decoration,
DecorationSet,
EditorView,
ViewPlugin,
ViewUpdate,
WidgetType,
} from "@codemirror/view";
import { cancelSuggestion, InlineSuggestionState } from "./states";
import { OptionalSuggestion, Suggestion } from "./types";
/**
 * ViewPlugin that turns the InlineSuggestionState field into a ghost-text
 * widget decoration at the cursor.
 */
const RenderSuggestionPlugin = () =>
  Prec.lowest(
    // must be lowest else you get infinite loop with state changes by our plugin
    ViewPlugin.fromClass(
      class RenderPlugin {
        decorations: DecorationSet;
        suggestion: Suggestion;

        // `view` is required by the ViewPlugin.fromClass signature even
        // though this plugin does not need it at construction time.
        constructor(view: EditorView) {
          this.decorations = Decoration.none;
          this.suggestion = {
            value: "",
            render: false,
          }
        }

        // FIX: removed the `async` modifier. CodeMirror view-plugin updates
        // are synchronous; nothing was awaited here, and `async` only turned
        // thrown errors into unhandled promise rejections.
        update(update: ViewUpdate) {
          const suggestion: OptionalSuggestion = update.state.field(
            InlineSuggestionState
          );
          // Keep the last non-null suggestion so it survives unrelated updates.
          if (suggestion !== null && suggestion !== undefined) {
            this.suggestion = suggestion;
          }
          this.decorations = inlineSuggestionDecoration(
            update.view,
            this.suggestion
          );
        }
      },
      {
        decorations: (v) => v.decorations,
      }
    )
  );
/**
 * Build the decoration set for a suggestion: a single ghost-text widget at
 * the main cursor, or no decorations when nothing should be rendered.
 */
function inlineSuggestionDecoration(
  view: EditorView,
  display_suggestion: Suggestion
) {
  if (!display_suggestion.render) {
    return Decoration.none;
  }
  const cursorPos = view.state.selection.main.head;
  try {
    const widget = new InlineSuggestionWidget(display_suggestion.value, view);
    return Decoration.set([
      Decoration.widget({ widget, side: 1 }).range(cursorPos),
    ]);
  } catch (e) {
    // Widget construction can fail; fail soft with no decoration.
    return Decoration.none;
  }
}
/**
 * Widget rendering the ghost-text suggestion inline at the cursor.
 * Clicking or selecting the ghost text dismisses the suggestion.
 */
class InlineSuggestionWidget extends WidgetType {
  // Parameter properties declare and assign these fields; the explicit
  // reassignments in the previous constructor body were redundant, and the
  // `destroy` override that only called super has been dropped.
  constructor(readonly display_suggestion: string, readonly view: EditorView) {
    super();
  }

  eq(other: InlineSuggestionWidget) {
    // Strict equality; both operands are strings.
    return other.display_suggestion === this.display_suggestion;
  }

  toDOM() {
    const span = document.createElement("span");
    span.textContent = this.display_suggestion;
    span.style.opacity = "0.4"; // TODO replace with css
    const dismiss = () => {
      cancelSuggestion(this.view);
    };
    span.onclick = dismiss;
    span.onselect = dismiss;
    return span;
  }
}
export default RenderSuggestionPlugin;

151
src/render-plugin/states.ts Normal file
View File

@@ -0,0 +1,151 @@
import {
EditorSelection,
EditorState,
SelectionRange,
StateEffect,
StateField,
Transaction,
TransactionSpec,
} from "@codemirror/state";
import { EditorView } from "@codemirror/view";
import { InlineSuggestion, OptionalSuggestion } from "./types";
// Effect used to push a new suggestion (or clear one) into editor state.
const InlineSuggestionEffect = StateEffect.define<InlineSuggestion>();

/**
 * State field holding the currently displayed inline suggestion, or null
 * when nothing should be rendered. Only the first InlineSuggestionEffect of
 * a transaction is considered.
 */
export const InlineSuggestionState = StateField.define<OptionalSuggestion>({
  create(): OptionalSuggestion {
    return null;
  },
  update(
    value: OptionalSuggestion,
    transaction: Transaction
  ): OptionalSuggestion {
    for (const effect of transaction.effects) {
      if (effect.is(InlineSuggestionEffect)) {
        return effect.value?.doc !== undefined ? effect.value.suggestion : null;
      }
    }
    return null;
  },
});
/**
 * Dispatch a renderable suggestion into the editor state. The dispatch is
 * deferred by a tick so it does not run inside the current update cycle.
 */
export const updateSuggestion = (
  view: EditorView,
  suggestion: string
) => {
  const doc = view.state.doc;
  void sleep(1).then(() => {
    const effect = InlineSuggestionEffect.of({
      suggestion: {
        value: suggestion,
        render: true,
      },
      doc,
    });
    view.dispatch({ effects: effect });
  });
};
/**
 * Clear any displayed suggestion by dispatching an empty, non-rendering
 * suggestion. Deferred by a tick, mirroring updateSuggestion.
 */
export const cancelSuggestion = (view: EditorView) => {
  const doc = view.state.doc;
  void sleep(1).then(() => {
    const effect = InlineSuggestionEffect.of({
      suggestion: {
        value: "",
        render: false,
      },
      doc,
    });
    view.dispatch({ effects: effect });
  });
};
/** Replace the main selection with the accepted suggestion text. */
export const insertSuggestion = (view: EditorView, suggestion: string) => {
  const { from, to } = view.state.selection.main;
  const spec = createInsertSuggestionTransaction(
    view.state,
    suggestion,
    from,
    to
  );
  view.dispatch({ ...spec });
};
/**
 * Build a TransactionSpec that inserts `text` over [from, to] at the main
 * selection and mirrors the insertion at every other cursor, skipping
 * cursors whose surrounding text has already diverged.
 */
function createInsertSuggestionTransaction(
  state: EditorState,
  text: string,
  from: number,
  to: number
): TransactionSpec {
  const docLength = state.doc.length;
  if (from < 0 || to > docLength || from > to) {
    // If the range is not valid, return an empty transaction spec.
    return { changes: [] };
  }

  const createInsertSuggestionTransactionFromSelectionRange = (
    range: SelectionRange
  ) => {
    if (range === state.selection.main) {
      // Primary cursor: replace [from, to] with the suggestion and place
      // the cursor right after the inserted text.
      return {
        changes: { from, to, insert: text },
        range: EditorSelection.cursor(to + text.length),
      };
    }
    const length = to - from;
    if (hasTextChanged(from, to, state, range)) {
      // Text near this secondary cursor no longer matches the primary
      // context; leave this cursor untouched.
      return { range };
    }
    // Mirror the insertion at the secondary cursor: replace the `length`
    // characters immediately before it, assumed to match the primary range.
    // NOTE(review): confirm this offset math for cursors located before
    // the primary selection.
    return {
      changes: {
        from: range.from - length,
        to: range.from,
        insert: text,
      },
      range: EditorSelection.cursor(range.from - length + text.length),
    };
  };

  return {
    ...state.changeByRange(
      createInsertSuggestionTransactionFromSelectionRange
    ),
    userEvent: "input.complete",
  };
}
/**
 * Heuristic used by createInsertSuggestionTransaction: returns true when
 * the text immediately before `changeRange` no longer matches the primary
 * range's text, i.e. this secondary cursor's context has diverged.
 */
function hasTextChanged(
  from: number,
  to: number,
  state: EditorState,
  changeRange: SelectionRange
) {
  if (changeRange.empty) {
    return false;
  }
  const length = to - from;
  if (length <= 0) {
    return false;
  }
  const overlaps = changeRange.to > from && changeRange.from < to;
  if (!overlaps) {
    return false;
  }
  // check out of bound
  if (changeRange.from < 0 || changeRange.to > state.doc.length) {
    return false;
  }
  const beforeChange = state.sliceDoc(changeRange.from - length, changeRange.from);
  const primaryText = state.sliceDoc(from, to);
  return beforeChange !== primaryText;
}

Some files were not shown because too many files have changed in this diff Show More