mirror of
https://github.com/EthanMarti/infio-copilot.git
synced 2026-05-08 16:10:09 +00:00
init
This commit is contained in:
53
src/ApplyView.tsx
Normal file
53
src/ApplyView.tsx
Normal file
@@ -0,0 +1,53 @@
|
||||
import { TFile, View, WorkspaceLeaf } from 'obsidian'
|
||||
import { Root, createRoot } from 'react-dom/client'
|
||||
|
||||
import ApplyViewRoot from './components/apply-view/ApplyViewRoot'
|
||||
import { APPLY_VIEW_TYPE } from './constants'
|
||||
import { AppProvider } from './contexts/AppContext'
|
||||
|
||||
export type ApplyViewState = {
|
||||
file: TFile
|
||||
originalContent: string
|
||||
newContent: string
|
||||
}
|
||||
|
||||
export class ApplyView extends View {
|
||||
private root: Root | null = null
|
||||
|
||||
private state: ApplyViewState | null = null
|
||||
|
||||
constructor(leaf: WorkspaceLeaf) {
|
||||
super(leaf)
|
||||
}
|
||||
|
||||
getViewType() {
|
||||
return APPLY_VIEW_TYPE
|
||||
}
|
||||
|
||||
getDisplayText() {
|
||||
return `Applying: ${this.state?.file?.name ?? ''}`
|
||||
}
|
||||
|
||||
async setState(state: ApplyViewState) {
|
||||
this.state = state
|
||||
// Should render here because onOpen is called before setState
|
||||
this.render()
|
||||
}
|
||||
|
||||
async onOpen() {
|
||||
this.root = createRoot(this.containerEl)
|
||||
}
|
||||
|
||||
async onClose() {
|
||||
this.root?.unmount()
|
||||
}
|
||||
|
||||
async render() {
|
||||
if (!this.root || !this.state) return
|
||||
this.root.render(
|
||||
<AppProvider app={this.app}>
|
||||
<ApplyViewRoot state={this.state} close={() => this.leaf.detach()} />
|
||||
</AppProvider>,
|
||||
)
|
||||
}
|
||||
}
|
||||
117
src/ChatView.tsx
Normal file
117
src/ChatView.tsx
Normal file
@@ -0,0 +1,117 @@
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { ItemView, WorkspaceLeaf } from 'obsidian'
|
||||
import React from 'react'
|
||||
import { Root, createRoot } from 'react-dom/client'
|
||||
|
||||
import Chat, { ChatProps, ChatRef } from './components/chat-view/Chat'
|
||||
import { CHAT_VIEW_TYPE } from './constants'
|
||||
import { AppProvider } from './contexts/AppContext'
|
||||
import { DarkModeProvider } from './contexts/DarkModeContext'
|
||||
import { DatabaseProvider } from './contexts/DatabaseContext'
|
||||
import { DialogProvider } from './contexts/DialogContext'
|
||||
import { LLMProvider } from './contexts/LLMContext'
|
||||
import { RAGProvider } from './contexts/RAGContext'
|
||||
import { SettingsProvider } from './contexts/SettingsContext'
|
||||
import InfioPlugin from './main'
|
||||
import { MentionableBlockData } from './types/mentionable'
|
||||
import { InfioSettings } from './types/settings'
|
||||
|
||||
export class ChatView extends ItemView {
|
||||
private root: Root | null = null
|
||||
private settings: InfioSettings
|
||||
private initialChatProps?: ChatProps
|
||||
private chatRef: React.RefObject<ChatRef> = React.createRef()
|
||||
|
||||
constructor(
|
||||
leaf: WorkspaceLeaf,
|
||||
private plugin: InfioPlugin,
|
||||
) {
|
||||
super(leaf)
|
||||
this.settings = plugin.settings
|
||||
this.initialChatProps = plugin.initChatProps
|
||||
}
|
||||
|
||||
getViewType() {
|
||||
return CHAT_VIEW_TYPE
|
||||
}
|
||||
|
||||
getIcon() {
|
||||
return 'wand-sparkles'
|
||||
}
|
||||
|
||||
getDisplayText() {
|
||||
return 'Smart composer chat'
|
||||
}
|
||||
|
||||
async onOpen() {
|
||||
await this.render()
|
||||
|
||||
// Consume chatProps
|
||||
this.initialChatProps = undefined
|
||||
}
|
||||
|
||||
async onClose() {
|
||||
this.root?.unmount()
|
||||
}
|
||||
|
||||
async render() {
|
||||
if (!this.root) {
|
||||
this.root = createRoot(this.containerEl.children[1])
|
||||
}
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
gcTime: 0, // Immediately garbage collect queries. It prevents memory leak on ChatView close.
|
||||
},
|
||||
mutations: {
|
||||
gcTime: 0, // Immediately garbage collect mutations. It prevents memory leak on ChatView close.
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
this.root.render(
|
||||
<AppProvider app={this.app}>
|
||||
<SettingsProvider
|
||||
settings={this.settings}
|
||||
setSettings={(newSettings) => this.plugin.setSettings(newSettings)}
|
||||
addSettingsChangeListener={(listener) =>
|
||||
this.plugin.addSettingsListener(listener)
|
||||
}
|
||||
>
|
||||
<DarkModeProvider>
|
||||
<LLMProvider>
|
||||
<DatabaseProvider
|
||||
getDatabaseManager={() => this.plugin.getDbManager()}
|
||||
>
|
||||
<RAGProvider getRAGEngine={() => this.plugin.getRAGEngine()}>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<React.StrictMode>
|
||||
<DialogProvider
|
||||
container={this.containerEl.children[1] as HTMLElement}
|
||||
>
|
||||
<Chat ref={this.chatRef} {...this.initialChatProps} />
|
||||
</DialogProvider>
|
||||
</React.StrictMode>
|
||||
</QueryClientProvider>
|
||||
</RAGProvider>
|
||||
</DatabaseProvider>
|
||||
</LLMProvider>
|
||||
</DarkModeProvider>
|
||||
</SettingsProvider>
|
||||
</AppProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
openNewChat(selectedBlock?: MentionableBlockData) {
|
||||
this.chatRef.current?.openNewChat(selectedBlock)
|
||||
}
|
||||
|
||||
addSelectionToChat(selectedBlock: MentionableBlockData) {
|
||||
this.chatRef.current?.addSelectionToChat(selectedBlock)
|
||||
}
|
||||
|
||||
focusMessage() {
|
||||
this.chatRef.current?.focusMessage()
|
||||
}
|
||||
}
|
||||
142
src/components/apply-view/ApplyViewRoot.tsx
Normal file
142
src/components/apply-view/ApplyViewRoot.tsx
Normal file
@@ -0,0 +1,142 @@
|
||||
import { Change, diffLines } from 'diff'
|
||||
import { CheckIcon, X } from 'lucide-react'
|
||||
import { getIcon } from 'obsidian'
|
||||
import { useState } from 'react'
|
||||
|
||||
import { ApplyViewState } from '../../ApplyView'
|
||||
import { useApp } from '../../contexts/AppContext'
|
||||
|
||||
export default function ApplyViewRoot({
|
||||
state,
|
||||
close,
|
||||
}: {
|
||||
state: ApplyViewState
|
||||
close: () => void
|
||||
}) {
|
||||
const acceptIcon = getIcon('check')
|
||||
const rejectIcon = getIcon('x')
|
||||
const excludeIcon = getIcon('x')
|
||||
|
||||
const app = useApp()
|
||||
|
||||
const [diff, setDiff] = useState<Change[]>(
|
||||
diffLines(state.originalContent, state.newContent),
|
||||
)
|
||||
|
||||
const handleAccept = async () => {
|
||||
const newContent = diff
|
||||
.filter((change) => !change.removed)
|
||||
.map((change) => change.value)
|
||||
.join('')
|
||||
await app.vault.modify(state.file, newContent)
|
||||
close()
|
||||
}
|
||||
|
||||
const handleReject = async () => {
|
||||
close()
|
||||
}
|
||||
|
||||
const excludeDiffLine = (index: number) => {
|
||||
setDiff((prevDiff) => {
|
||||
const newDiff = [...prevDiff]
|
||||
const change = newDiff[index]
|
||||
if (change.added) {
|
||||
// Remove the entry if it's an added line
|
||||
return newDiff.filter((_, i) => i !== index)
|
||||
} else if (change.removed) {
|
||||
change.removed = false
|
||||
}
|
||||
return newDiff
|
||||
})
|
||||
}
|
||||
|
||||
const acceptDiffLine = (index: number) => {
|
||||
setDiff((prevDiff) => {
|
||||
const newDiff = [...prevDiff]
|
||||
const change = newDiff[index]
|
||||
if (change.added) {
|
||||
change.added = false
|
||||
} else if (change.removed) {
|
||||
// Remove the entry if it's a removed line
|
||||
return newDiff.filter((_, i) => i !== index)
|
||||
}
|
||||
return newDiff
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<div id="infio-apply-view">
|
||||
<div className="view-header">
|
||||
<div className="view-header-left">
|
||||
<div className="view-header-nav-buttons"></div>
|
||||
</div>
|
||||
<div className="view-header-title-container mod-at-start">
|
||||
<div className="view-header-title">
|
||||
Applying: {state?.file?.name ?? ''}
|
||||
</div>
|
||||
<div className="view-actions">
|
||||
<button
|
||||
className="clickable-icon view-action infio-approve-button"
|
||||
aria-label="Accept changes"
|
||||
onClick={handleAccept}
|
||||
>
|
||||
{acceptIcon && <CheckIcon size={14} />}
|
||||
Accept
|
||||
</button>
|
||||
<button
|
||||
className="clickable-icon view-action infio-reject-button"
|
||||
aria-label="Reject changes"
|
||||
onClick={handleReject}
|
||||
>
|
||||
{rejectIcon && <X size={14} />}
|
||||
Reject
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="view-content">
|
||||
<div className="markdown-source-view cm-s-obsidian mod-cm6 node-insert-event is-readable-line-width is-live-preview is-folding show-properties">
|
||||
<div className="cm-editor">
|
||||
<div className="cm-scroller">
|
||||
<div className="cm-sizer">
|
||||
<div className="infio-inline-title">
|
||||
{state?.file?.name
|
||||
? state.file.name.replace(/\.[^/.]+$/, '')
|
||||
: ''}
|
||||
</div>
|
||||
|
||||
{diff.map((part, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`infio-diff-line ${part.added ? 'added' : part.removed ? 'removed' : ''}`}
|
||||
>
|
||||
<div style={{ width: '100%' }}>{part.value}</div>
|
||||
{(part.added || part.removed) && (
|
||||
<div className="infio-diff-line-actions">
|
||||
<button
|
||||
aria-label="Accept line"
|
||||
onClick={() => acceptDiffLine(index)}
|
||||
className="infio-accept"
|
||||
>
|
||||
{acceptIcon && 'Y'}
|
||||
</button>
|
||||
<button
|
||||
aria-label="Exclude line"
|
||||
onClick={() => excludeDiffLine(index)}
|
||||
className="infio-exclude"
|
||||
>
|
||||
{excludeIcon && 'N'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
90
src/components/chat-view/AssistantMessageActions.tsx
Normal file
90
src/components/chat-view/AssistantMessageActions.tsx
Normal file
@@ -0,0 +1,90 @@
|
||||
import * as Tooltip from '@radix-ui/react-tooltip'
|
||||
import { Check, CopyIcon } from 'lucide-react'
|
||||
import { useMemo, useState } from 'react'
|
||||
|
||||
import { ChatAssistantMessage } from '../../types/chat'
|
||||
import { calculateLLMCost } from '../../utils/price-calculator'
|
||||
|
||||
import LLMResponseInfoPopover from './LLMResponseInfoPopover'
|
||||
|
||||
function CopyButton({ message }: { message: ChatAssistantMessage }) {
|
||||
const [copied, setCopied] = useState(false)
|
||||
|
||||
const handleCopy = async () => {
|
||||
await navigator.clipboard.writeText(message.content)
|
||||
setCopied(true)
|
||||
setTimeout(() => {
|
||||
setCopied(false)
|
||||
}, 1500)
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip.Provider delayDuration={0}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger asChild>
|
||||
<button>
|
||||
{copied ? (
|
||||
<Check
|
||||
size={12}
|
||||
className="infio-chat-message-actions-icon--copied"
|
||||
/>
|
||||
) : (
|
||||
<CopyIcon onClick={handleCopy} size={12} />
|
||||
)}
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
<Tooltip.Portal>
|
||||
<Tooltip.Content className="infio-tooltip-content">
|
||||
Copy message
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Portal>
|
||||
</Tooltip.Root>
|
||||
</Tooltip.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
function LLMResponesInfoButton({ message }: { message: ChatAssistantMessage }) {
|
||||
const cost = useMemo<number | null>(() => {
|
||||
if (!message.metadata?.model || !message.metadata?.usage) {
|
||||
return 0
|
||||
}
|
||||
return calculateLLMCost({
|
||||
model: message.metadata.model,
|
||||
usage: message.metadata.usage,
|
||||
})
|
||||
}, [message])
|
||||
|
||||
return (
|
||||
<Tooltip.Provider delayDuration={0}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger asChild>
|
||||
<div>
|
||||
<LLMResponseInfoPopover
|
||||
usage={message.metadata?.usage}
|
||||
estimatedPrice={cost}
|
||||
model={message.metadata?.model?.name}
|
||||
/>
|
||||
</div>
|
||||
</Tooltip.Trigger>
|
||||
<Tooltip.Portal>
|
||||
<Tooltip.Content className="infio-tooltip-content">
|
||||
View details
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Portal>
|
||||
</Tooltip.Root>
|
||||
</Tooltip.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export default function AssistantMessageActions({
|
||||
message,
|
||||
}: {
|
||||
message: ChatAssistantMessage
|
||||
}) {
|
||||
return (
|
||||
<div className="infio-chat-message-actions">
|
||||
<LLMResponesInfoButton message={message} />
|
||||
<CopyButton message={message} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
737
src/components/chat-view/Chat.tsx
Normal file
737
src/components/chat-view/Chat.tsx
Normal file
@@ -0,0 +1,737 @@
|
||||
import { useMutation } from '@tanstack/react-query'
|
||||
import { CircleStop, History, Plus } from 'lucide-react'
|
||||
import { App, Notice } from 'obsidian'
|
||||
import {
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useImperativeHandle,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import { ApplyViewState } from '../../ApplyView'
|
||||
import { APPLY_VIEW_TYPE } from '../../constants'
|
||||
import { useApp } from '../../contexts/AppContext'
|
||||
import { useLLM } from '../../contexts/LLMContext'
|
||||
import { useRAG } from '../../contexts/RAGContext'
|
||||
import { useSettings } from '../../contexts/SettingsContext'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
LLMBaseUrlNotSetException,
|
||||
LLMModelNotSetException,
|
||||
} from '../../core/llm/exception'
|
||||
import { useChatHistory } from '../../hooks/use-chat-history'
|
||||
import { ChatMessage, ChatUserMessage } from '../../types/chat'
|
||||
import {
|
||||
MentionableBlock,
|
||||
MentionableBlockData,
|
||||
MentionableCurrentFile,
|
||||
} from '../../types/mentionable'
|
||||
import { manualApplyChangesToFile } from '../../utils/apply'
|
||||
import {
|
||||
getMentionableKey,
|
||||
serializeMentionable,
|
||||
} from '../../utils/mentionable'
|
||||
import { readTFileContent } from '../../utils/obsidian'
|
||||
import { openSettingsModalWithError } from '../../utils/open-settings-modal'
|
||||
import { PromptGenerator } from '../../utils/prompt-generator'
|
||||
|
||||
import AssistantMessageActions from './AssistantMessageActions'
|
||||
import ChatUserInput, { ChatUserInputRef } from './chat-input/ChatUserInput'
|
||||
import { editorStateToPlainText } from './chat-input/utils/editor-state-to-plain-text'
|
||||
import { ChatListDropdown } from './ChatListDropdown'
|
||||
import QueryProgress, { QueryProgressState } from './QueryProgress'
|
||||
import ReactMarkdown from './ReactMarkdown'
|
||||
import ShortcutInfo from './ShortcutInfo'
|
||||
import SimilaritySearchResults from './SimilaritySearchResults'
|
||||
// Add an empty line here
|
||||
const getNewInputMessage = (app: App): ChatUserMessage => {
|
||||
return {
|
||||
role: 'user',
|
||||
content: null,
|
||||
promptContent: null,
|
||||
id: uuidv4(),
|
||||
mentionables: [
|
||||
{
|
||||
type: 'current-file',
|
||||
file: app.workspace.getActiveFile(),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
export type ChatRef = {
|
||||
openNewChat: (selectedBlock?: MentionableBlockData) => void
|
||||
addSelectionToChat: (selectedBlock: MentionableBlockData) => void
|
||||
focusMessage: () => void
|
||||
}
|
||||
|
||||
export type ChatProps = {
|
||||
selectedBlock?: MentionableBlockData
|
||||
}
|
||||
|
||||
const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
|
||||
const app = useApp()
|
||||
const { settings } = useSettings()
|
||||
const { getRAGEngine } = useRAG()
|
||||
|
||||
const {
|
||||
createOrUpdateConversation,
|
||||
deleteConversation,
|
||||
getChatMessagesById,
|
||||
updateConversationTitle,
|
||||
chatList,
|
||||
} = useChatHistory()
|
||||
const { streamResponse, chatModel } = useLLM()
|
||||
|
||||
const promptGenerator = useMemo(() => {
|
||||
return new PromptGenerator(getRAGEngine, app, settings)
|
||||
}, [getRAGEngine, app, settings])
|
||||
|
||||
const [inputMessage, setInputMessage] = useState<ChatUserMessage>(() => {
|
||||
const newMessage = getNewInputMessage(app)
|
||||
if (props.selectedBlock) {
|
||||
newMessage.mentionables = [
|
||||
...newMessage.mentionables,
|
||||
{
|
||||
type: 'block',
|
||||
...props.selectedBlock,
|
||||
},
|
||||
]
|
||||
}
|
||||
return newMessage
|
||||
})
|
||||
const [addedBlockKey, setAddedBlockKey] = useState<string | null>(
|
||||
props.selectedBlock
|
||||
? getMentionableKey(
|
||||
serializeMentionable({
|
||||
type: 'block',
|
||||
...props.selectedBlock,
|
||||
}),
|
||||
)
|
||||
: null,
|
||||
)
|
||||
const [chatMessages, setChatMessages] = useState<ChatMessage[]>([])
|
||||
const [focusedMessageId, setFocusedMessageId] = useState<string | null>(null)
|
||||
const [currentConversationId, setCurrentConversationId] =
|
||||
useState<string>(uuidv4())
|
||||
const [queryProgress, setQueryProgress] = useState<QueryProgressState>({
|
||||
type: 'idle',
|
||||
})
|
||||
|
||||
const preventAutoScrollRef = useRef(false)
|
||||
const lastProgrammaticScrollRef = useRef<number>(0)
|
||||
const activeStreamAbortControllersRef = useRef<AbortController[]>([])
|
||||
const chatUserInputRefs = useRef<Map<string, ChatUserInputRef>>(new Map())
|
||||
const chatMessagesRef = useRef<HTMLDivElement>(null)
|
||||
const registerChatUserInputRef = (
|
||||
id: string,
|
||||
ref: ChatUserInputRef | null,
|
||||
) => {
|
||||
if (ref) {
|
||||
chatUserInputRefs.current.set(id, ref)
|
||||
} else {
|
||||
chatUserInputRefs.current.delete(id)
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const scrollContainer = chatMessagesRef.current
|
||||
if (!scrollContainer) return
|
||||
|
||||
const handleScroll = () => {
|
||||
// If the scroll event happened very close to our programmatic scroll, ignore it
|
||||
if (Date.now() - lastProgrammaticScrollRef.current < 50) {
|
||||
return
|
||||
}
|
||||
|
||||
preventAutoScrollRef.current =
|
||||
scrollContainer.scrollHeight -
|
||||
scrollContainer.scrollTop -
|
||||
scrollContainer.clientHeight >
|
||||
20
|
||||
}
|
||||
|
||||
scrollContainer.addEventListener('scroll', handleScroll)
|
||||
return () => scrollContainer.removeEventListener('scroll', handleScroll)
|
||||
}, [chatMessages])
|
||||
|
||||
const handleScrollToBottom = () => {
|
||||
if (chatMessagesRef.current) {
|
||||
const scrollContainer = chatMessagesRef.current
|
||||
if (scrollContainer.scrollTop !== scrollContainer.scrollHeight) {
|
||||
lastProgrammaticScrollRef.current = Date.now()
|
||||
scrollContainer.scrollTop = scrollContainer.scrollHeight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const abortActiveStreams = () => {
|
||||
for (const abortController of activeStreamAbortControllersRef.current) {
|
||||
abortController.abort()
|
||||
}
|
||||
activeStreamAbortControllersRef.current = []
|
||||
}
|
||||
|
||||
const handleLoadConversation = async (conversationId: string) => {
|
||||
try {
|
||||
abortActiveStreams()
|
||||
const conversation = await getChatMessagesById(conversationId)
|
||||
if (!conversation) {
|
||||
throw new Error('Conversation not found')
|
||||
}
|
||||
setCurrentConversationId(conversationId)
|
||||
setChatMessages(conversation)
|
||||
const newInputMessage = getNewInputMessage(app)
|
||||
setInputMessage(newInputMessage)
|
||||
setFocusedMessageId(newInputMessage.id)
|
||||
setQueryProgress({
|
||||
type: 'idle',
|
||||
})
|
||||
} catch (error) {
|
||||
new Notice('Failed to load conversation')
|
||||
console.error('Failed to load conversation', error)
|
||||
}
|
||||
}
|
||||
|
||||
const handleNewChat = (selectedBlock?: MentionableBlockData) => {
|
||||
setCurrentConversationId(uuidv4())
|
||||
setChatMessages([])
|
||||
const newInputMessage = getNewInputMessage(app)
|
||||
if (selectedBlock) {
|
||||
const mentionableBlock: MentionableBlock = {
|
||||
type: 'block',
|
||||
...selectedBlock,
|
||||
}
|
||||
newInputMessage.mentionables = [
|
||||
...newInputMessage.mentionables,
|
||||
mentionableBlock,
|
||||
]
|
||||
setAddedBlockKey(
|
||||
getMentionableKey(serializeMentionable(mentionableBlock)),
|
||||
)
|
||||
}
|
||||
setInputMessage(newInputMessage)
|
||||
setFocusedMessageId(newInputMessage.id)
|
||||
setQueryProgress({
|
||||
type: 'idle',
|
||||
})
|
||||
abortActiveStreams()
|
||||
}
|
||||
|
||||
const submitMutation = useMutation({
|
||||
mutationFn: async ({
|
||||
newChatHistory,
|
||||
useVaultSearch,
|
||||
}: {
|
||||
newChatHistory: ChatMessage[]
|
||||
useVaultSearch?: boolean
|
||||
}) => {
|
||||
abortActiveStreams()
|
||||
setQueryProgress({
|
||||
type: 'idle',
|
||||
})
|
||||
|
||||
const responseMessageId = uuidv4()
|
||||
setChatMessages([
|
||||
...newChatHistory,
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
id: responseMessageId,
|
||||
metadata: {
|
||||
usage: undefined,
|
||||
model: undefined,
|
||||
},
|
||||
},
|
||||
])
|
||||
|
||||
try {
|
||||
const abortController = new AbortController()
|
||||
activeStreamAbortControllersRef.current.push(abortController)
|
||||
|
||||
const { requestMessages, compiledMessages } =
|
||||
await promptGenerator.generateRequestMessages({
|
||||
messages: newChatHistory,
|
||||
useVaultSearch,
|
||||
onQueryProgressChange: setQueryProgress,
|
||||
})
|
||||
setQueryProgress({
|
||||
type: 'idle',
|
||||
})
|
||||
|
||||
setChatMessages([
|
||||
...compiledMessages,
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
id: responseMessageId,
|
||||
metadata: {
|
||||
usage: undefined,
|
||||
model: undefined,
|
||||
},
|
||||
},
|
||||
])
|
||||
const stream = await streamResponse(
|
||||
chatModel,
|
||||
{
|
||||
model: chatModel.name,
|
||||
messages: requestMessages,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: abortController.signal,
|
||||
},
|
||||
)
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const content = chunk.choices[0]?.delta?.content ?? ''
|
||||
setChatMessages((prevChatHistory) =>
|
||||
prevChatHistory.map((message) =>
|
||||
message.role === 'assistant' && message.id === responseMessageId
|
||||
? {
|
||||
...message,
|
||||
content: message.content + content,
|
||||
metadata: {
|
||||
...message.metadata,
|
||||
usage: chunk.usage ?? message.metadata?.usage, // Keep existing usage if chunk has no usage data
|
||||
model: chatModel,
|
||||
},
|
||||
}
|
||||
: message,
|
||||
),
|
||||
)
|
||||
if (!preventAutoScrollRef.current) {
|
||||
handleScrollToBottom()
|
||||
}
|
||||
}
|
||||
// for debugging
|
||||
setChatMessages((prevChatHistory) => {
|
||||
const lastMessage = prevChatHistory[prevChatHistory.length - 1];
|
||||
console.log("Last complete message:", lastMessage?.content);
|
||||
return prevChatHistory;
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.name === 'AbortError') {
|
||||
return
|
||||
} else {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
},
|
||||
onError: (error) => {
|
||||
setQueryProgress({
|
||||
type: 'idle',
|
||||
})
|
||||
if (
|
||||
error instanceof LLMAPIKeyNotSetException ||
|
||||
error instanceof LLMAPIKeyInvalidException ||
|
||||
error instanceof LLMBaseUrlNotSetException ||
|
||||
error instanceof LLMModelNotSetException
|
||||
) {
|
||||
openSettingsModalWithError(app, error.message)
|
||||
} else {
|
||||
new Notice(error.message)
|
||||
console.error('Failed to generate response', error)
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
const handleSubmit = (
|
||||
newChatHistory: ChatMessage[],
|
||||
useVaultSearch?: boolean,
|
||||
) => {
|
||||
submitMutation.mutate({ newChatHistory, useVaultSearch })
|
||||
}
|
||||
|
||||
const applyMutation = useMutation({
|
||||
mutationFn: async ({
|
||||
blockInfo,
|
||||
}: {
|
||||
blockInfo: {
|
||||
content: string
|
||||
filename?: string
|
||||
startLine?: number
|
||||
endLine?: number
|
||||
}
|
||||
}) => {
|
||||
const activeFile = app.workspace.getActiveFile()
|
||||
if (!activeFile) {
|
||||
throw new Error(
|
||||
'No file is currently open to apply changes. Please open a file and try again.',
|
||||
)
|
||||
}
|
||||
const activeFileContent = await readTFileContent(activeFile, app.vault)
|
||||
|
||||
const updatedFileContent = await manualApplyChangesToFile(
|
||||
blockInfo.content,
|
||||
activeFile,
|
||||
activeFileContent,
|
||||
blockInfo.startLine,
|
||||
blockInfo.endLine
|
||||
)
|
||||
if (!updatedFileContent) {
|
||||
throw new Error('Failed to apply changes')
|
||||
}
|
||||
|
||||
await app.workspace.getLeaf(true).setViewState({
|
||||
type: APPLY_VIEW_TYPE,
|
||||
active: true,
|
||||
state: {
|
||||
file: activeFile,
|
||||
originalContent: activeFileContent,
|
||||
newContent: updatedFileContent,
|
||||
} satisfies ApplyViewState,
|
||||
})
|
||||
},
|
||||
onError: (error) => {
|
||||
if (
|
||||
error instanceof LLMAPIKeyNotSetException ||
|
||||
error instanceof LLMAPIKeyInvalidException ||
|
||||
error instanceof LLMBaseUrlNotSetException ||
|
||||
error instanceof LLMModelNotSetException
|
||||
) {
|
||||
openSettingsModalWithError(app, error.message)
|
||||
} else {
|
||||
new Notice(error.message)
|
||||
console.error('Failed to apply changes', error)
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
const handleApply = useCallback(
|
||||
(blockInfo: {
|
||||
content: string
|
||||
filename?: string
|
||||
startLine?: number
|
||||
endLine?: number
|
||||
}) => {
|
||||
applyMutation.mutate({ blockInfo })
|
||||
},
|
||||
[applyMutation],
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
setFocusedMessageId(inputMessage.id)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
//
|
||||
useEffect(() => {
|
||||
const updateConversationAsync = async () => {
|
||||
try {
|
||||
if (chatMessages.length > 0) {
|
||||
createOrUpdateConversation(currentConversationId, chatMessages)
|
||||
}
|
||||
} catch (error) {
|
||||
new Notice('Failed to save chat history')
|
||||
console.error('Failed to save chat history', error)
|
||||
}
|
||||
}
|
||||
updateConversationAsync()
|
||||
}, [currentConversationId, chatMessages, createOrUpdateConversation])
|
||||
|
||||
// Updates the currentFile of the focused message (input or chat history)
|
||||
// This happens when active file changes or focused message changes
|
||||
const handleActiveLeafChange = useCallback(() => {
|
||||
const activeFile = app.workspace.getActiveFile()
|
||||
if (!activeFile) return
|
||||
|
||||
const mentionable: Omit<MentionableCurrentFile, 'id'> = {
|
||||
type: 'current-file',
|
||||
file: activeFile,
|
||||
}
|
||||
|
||||
if (!focusedMessageId) return
|
||||
if (inputMessage.id === focusedMessageId) {
|
||||
setInputMessage((prevInputMessage) => ({
|
||||
...prevInputMessage,
|
||||
mentionables: [
|
||||
mentionable,
|
||||
...prevInputMessage.mentionables.filter(
|
||||
(mentionable) => mentionable.type !== 'current-file',
|
||||
),
|
||||
],
|
||||
}))
|
||||
} else {
|
||||
setChatMessages((prevChatHistory) =>
|
||||
prevChatHistory.map((message) =>
|
||||
message.id === focusedMessageId && message.role === 'user'
|
||||
? {
|
||||
...message,
|
||||
mentionables: [
|
||||
mentionable,
|
||||
...message.mentionables.filter(
|
||||
(mentionable) => mentionable.type !== 'current-file',
|
||||
),
|
||||
],
|
||||
}
|
||||
: message,
|
||||
),
|
||||
)
|
||||
}
|
||||
}, [app.workspace, focusedMessageId, inputMessage.id])
|
||||
|
||||
useEffect(() => {
|
||||
app.workspace.on('active-leaf-change', handleActiveLeafChange)
|
||||
return () => {
|
||||
app.workspace.off('active-leaf-change', handleActiveLeafChange)
|
||||
}
|
||||
}, [app.workspace, handleActiveLeafChange])
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
openNewChat: (selectedBlock?: MentionableBlockData) =>
|
||||
handleNewChat(selectedBlock),
|
||||
addSelectionToChat: (selectedBlock: MentionableBlockData) => {
|
||||
const mentionable: Omit<MentionableBlock, 'id'> = {
|
||||
type: 'block',
|
||||
...selectedBlock,
|
||||
}
|
||||
|
||||
setAddedBlockKey(getMentionableKey(serializeMentionable(mentionable)))
|
||||
|
||||
if (focusedMessageId === inputMessage.id) {
|
||||
setInputMessage((prevInputMessage) => {
|
||||
const mentionableKey = getMentionableKey(
|
||||
serializeMentionable(mentionable),
|
||||
)
|
||||
// Check if mentionable already exists
|
||||
if (
|
||||
prevInputMessage.mentionables.some(
|
||||
(m) =>
|
||||
getMentionableKey(serializeMentionable(m)) === mentionableKey,
|
||||
)
|
||||
) {
|
||||
return prevInputMessage
|
||||
}
|
||||
return {
|
||||
...prevInputMessage,
|
||||
mentionables: [...prevInputMessage.mentionables, mentionable],
|
||||
}
|
||||
})
|
||||
} else {
|
||||
setChatMessages((prevChatHistory) =>
|
||||
prevChatHistory.map((message) => {
|
||||
if (message.id === focusedMessageId && message.role === 'user') {
|
||||
const mentionableKey = getMentionableKey(
|
||||
serializeMentionable(mentionable),
|
||||
)
|
||||
// Check if mentionable already exists
|
||||
if (
|
||||
message.mentionables.some(
|
||||
(m) =>
|
||||
getMentionableKey(serializeMentionable(m)) ===
|
||||
mentionableKey,
|
||||
)
|
||||
) {
|
||||
return message
|
||||
}
|
||||
return {
|
||||
...message,
|
||||
mentionables: [...message.mentionables, mentionable],
|
||||
}
|
||||
}
|
||||
return message
|
||||
}),
|
||||
)
|
||||
}
|
||||
},
|
||||
focusMessage: () => {
|
||||
if (!focusedMessageId) return
|
||||
chatUserInputRefs.current.get(focusedMessageId)?.focus()
|
||||
},
|
||||
}))
|
||||
|
||||
return (
|
||||
<div className="infio-chat-container">
|
||||
<div className="infio-chat-header">
|
||||
<h1 className="infio-chat-header-title"> CHAT </h1>
|
||||
<div className="infio-chat-header-buttons">
|
||||
<button
|
||||
onClick={() => handleNewChat()}
|
||||
className="infio-chat-list-dropdown"
|
||||
>
|
||||
<Plus size={18} />
|
||||
</button>
|
||||
<ChatListDropdown
|
||||
chatList={chatList}
|
||||
currentConversationId={currentConversationId}
|
||||
onSelect={async (conversationId) => {
|
||||
if (conversationId === currentConversationId) return
|
||||
await handleLoadConversation(conversationId)
|
||||
}}
|
||||
onDelete={async (conversationId) => {
|
||||
await deleteConversation(conversationId)
|
||||
if (conversationId === currentConversationId) {
|
||||
const nextConversation = chatList.find(
|
||||
(chat) => chat.id !== conversationId,
|
||||
)
|
||||
if (nextConversation) {
|
||||
void handleLoadConversation(nextConversation.id)
|
||||
} else {
|
||||
handleNewChat()
|
||||
}
|
||||
}
|
||||
}}
|
||||
onUpdateTitle={async (conversationId, newTitle) => {
|
||||
await updateConversationTitle(conversationId, newTitle)
|
||||
}}
|
||||
className="infio-chat-list-dropdown"
|
||||
>
|
||||
<History size={18} />
|
||||
</ChatListDropdown>
|
||||
</div>
|
||||
</div>
|
||||
<div className="infio-chat-messages" ref={chatMessagesRef}>
|
||||
{
|
||||
// If the chat is empty, show a message to start a new chat
|
||||
chatMessages.length === 0 && (
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
height: '100%',
|
||||
width: '100%'
|
||||
}}>
|
||||
<ShortcutInfo />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{chatMessages.map((message, index) =>
|
||||
message.role === 'user' ? (
|
||||
<div key={message.id} className="infio-chat-messages-user">
|
||||
<ChatUserInput
|
||||
ref={(ref) => registerChatUserInputRef(message.id, ref)}
|
||||
initialSerializedEditorState={message.content}
|
||||
onChange={(content) => {
|
||||
setChatMessages((prevChatHistory) =>
|
||||
prevChatHistory.map((msg) =>
|
||||
msg.role === 'user' && msg.id === message.id
|
||||
? {
|
||||
...msg,
|
||||
content,
|
||||
}
|
||||
: msg,
|
||||
),
|
||||
)
|
||||
}}
|
||||
onSubmit={(content, useVaultSearch) => {
|
||||
if (editorStateToPlainText(content).trim() === '') return
|
||||
handleSubmit(
|
||||
[
|
||||
...chatMessages.slice(0, index),
|
||||
{
|
||||
role: 'user',
|
||||
content: content,
|
||||
promptContent: null,
|
||||
id: message.id,
|
||||
mentionables: message.mentionables,
|
||||
},
|
||||
],
|
||||
useVaultSearch,
|
||||
)
|
||||
chatUserInputRefs.current.get(inputMessage.id)?.focus()
|
||||
}}
|
||||
onFocus={() => {
|
||||
setFocusedMessageId(message.id)
|
||||
}}
|
||||
mentionables={message.mentionables}
|
||||
setMentionables={(mentionables) => {
|
||||
setChatMessages((prevChatHistory) =>
|
||||
prevChatHistory.map((msg) =>
|
||||
msg.id === message.id ? { ...msg, mentionables } : msg,
|
||||
),
|
||||
)
|
||||
}}
|
||||
/>
|
||||
{message.similaritySearchResults && (
|
||||
<SimilaritySearchResults
|
||||
similaritySearchResults={message.similaritySearchResults}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div key={message.id} className="infio-chat-messages-assistant">
|
||||
<ReactMarkdownItem
|
||||
handleApply={handleApply}
|
||||
isApplying={applyMutation.isPending}
|
||||
>
|
||||
{message.content}
|
||||
</ReactMarkdownItem>
|
||||
{message.content && <AssistantMessageActions message={message} />}
|
||||
</div>
|
||||
),
|
||||
)}
|
||||
<QueryProgress state={queryProgress} />
|
||||
{submitMutation.isPending && (
|
||||
<button onClick={abortActiveStreams} className="infio-stop-gen-btn">
|
||||
<CircleStop size={16} />
|
||||
<div>Stop Generation</div>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<ChatUserInput
|
||||
key={inputMessage.id} // this is needed to clear the editor when the user submits a new message
|
||||
ref={(ref) => registerChatUserInputRef(inputMessage.id, ref)}
|
||||
initialSerializedEditorState={inputMessage.content}
|
||||
onChange={(content) => {
|
||||
setInputMessage((prevInputMessage) => ({
|
||||
...prevInputMessage,
|
||||
content,
|
||||
}))
|
||||
}}
|
||||
onSubmit={(content, useVaultSearch) => {
|
||||
if (editorStateToPlainText(content).trim() === '') return
|
||||
handleSubmit(
|
||||
[...chatMessages, { ...inputMessage, content }],
|
||||
useVaultSearch,
|
||||
)
|
||||
setInputMessage(getNewInputMessage(app))
|
||||
preventAutoScrollRef.current = false
|
||||
handleScrollToBottom()
|
||||
}}
|
||||
onFocus={() => {
|
||||
setFocusedMessageId(inputMessage.id)
|
||||
}}
|
||||
mentionables={inputMessage.mentionables}
|
||||
setMentionables={(mentionables) => {
|
||||
setInputMessage((prevInputMessage) => ({
|
||||
...prevInputMessage,
|
||||
mentionables,
|
||||
}))
|
||||
}}
|
||||
autoFocus
|
||||
addedBlockKey={addedBlockKey}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
function ReactMarkdownItem({
|
||||
handleApply,
|
||||
isApplying,
|
||||
children,
|
||||
}: {
|
||||
handleApply: (blockInfo: {
|
||||
content: string
|
||||
filename?: string
|
||||
startLine?: number
|
||||
endLine?: number
|
||||
}) => void
|
||||
isApplying: boolean
|
||||
children: string
|
||||
}) {
|
||||
return (
|
||||
<ReactMarkdown onApply={handleApply} isApplying={isApplying}>
|
||||
{children}
|
||||
</ReactMarkdown>
|
||||
)
|
||||
}
|
||||
|
||||
Chat.displayName = 'Chat'
|
||||
|
||||
export default Chat
|
||||
202
src/components/chat-view/ChatListDropdown.tsx
Normal file
202
src/components/chat-view/ChatListDropdown.tsx
Normal file
@@ -0,0 +1,202 @@
|
||||
import * as Popover from '@radix-ui/react-popover'
|
||||
import { Pencil, Trash2 } from 'lucide-react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
|
||||
import { ChatConversationMeta } from '../../types/chat'
|
||||
|
||||
function TitleInput({
|
||||
title,
|
||||
onSubmit,
|
||||
}: {
|
||||
title: string
|
||||
onSubmit: (title: string) => Promise<void>
|
||||
}) {
|
||||
const [value, setValue] = useState(title)
|
||||
const inputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (inputRef.current) {
|
||||
inputRef.current.select()
|
||||
inputRef.current.scrollLeft = 0
|
||||
}
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<input
|
||||
ref={inputRef}
|
||||
type="text"
|
||||
value={value}
|
||||
className="infio-chat-list-dropdown-item-title-input"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
onChange={(e) => setValue(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
e.stopPropagation()
|
||||
if (e.key === 'Enter') {
|
||||
onSubmit(value)
|
||||
}
|
||||
}}
|
||||
autoFocus
|
||||
maxLength={100}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function ChatListItem({
|
||||
title,
|
||||
isFocused,
|
||||
isEditing,
|
||||
onMouseEnter,
|
||||
onSelect,
|
||||
onDelete,
|
||||
onStartEdit,
|
||||
onFinishEdit,
|
||||
}: {
|
||||
title: string
|
||||
isFocused: boolean
|
||||
isEditing: boolean
|
||||
onMouseEnter: () => void
|
||||
onSelect: () => Promise<void>
|
||||
onDelete: () => Promise<void>
|
||||
onStartEdit: () => void
|
||||
onFinishEdit: (title: string) => Promise<void>
|
||||
}) {
|
||||
const itemRef = useRef<HTMLLIElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (isFocused && itemRef.current) {
|
||||
itemRef.current.scrollIntoView({
|
||||
block: 'nearest',
|
||||
})
|
||||
}
|
||||
}, [isFocused])
|
||||
|
||||
return (
|
||||
<li
|
||||
ref={itemRef}
|
||||
onClick={onSelect}
|
||||
onMouseEnter={onMouseEnter}
|
||||
className={isFocused ? 'selected' : ''}
|
||||
>
|
||||
{isEditing ? (
|
||||
<TitleInput title={title} onSubmit={onFinishEdit} />
|
||||
) : (
|
||||
<div className="infio-chat-list-dropdown-item-title">{title}</div>
|
||||
)}
|
||||
<div className="infio-chat-list-dropdown-item-actions">
|
||||
<div
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
onStartEdit()
|
||||
}}
|
||||
className="infio-chat-list-dropdown-item-icon"
|
||||
>
|
||||
<Pencil size={14} />
|
||||
</div>
|
||||
<div
|
||||
onClick={async (e) => {
|
||||
e.stopPropagation()
|
||||
await onDelete()
|
||||
}}
|
||||
className="infio-chat-list-dropdown-item-icon"
|
||||
>
|
||||
<Trash2 size={14} />
|
||||
</div>
|
||||
</div>
|
||||
</li>
|
||||
)
|
||||
}
|
||||
|
||||
export function ChatListDropdown({
|
||||
chatList,
|
||||
currentConversationId,
|
||||
onSelect,
|
||||
onDelete,
|
||||
onUpdateTitle,
|
||||
className,
|
||||
children,
|
||||
}: {
|
||||
chatList: ChatConversationMeta[]
|
||||
currentConversationId: string
|
||||
onSelect: (conversationId: string) => Promise<void>
|
||||
onDelete: (conversationId: string) => Promise<void>
|
||||
onUpdateTitle: (conversationId: string, newTitle: string) => Promise<void>
|
||||
className?: string
|
||||
children: React.ReactNode
|
||||
}) {
|
||||
const [open, setOpen] = useState(false)
|
||||
const [focusedIndex, setFocusedIndex] = useState<number>(0)
|
||||
const [editingId, setEditingId] = useState<string | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
const currentIndex = chatList.findIndex(
|
||||
(chat) => chat.id === currentConversationId,
|
||||
)
|
||||
setFocusedIndex(currentIndex === -1 ? 0 : currentIndex)
|
||||
setEditingId(null)
|
||||
}
|
||||
}, [open])
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: React.KeyboardEvent) => {
|
||||
if (e.key === 'ArrowUp') {
|
||||
setFocusedIndex(Math.max(0, focusedIndex - 1))
|
||||
} else if (e.key === 'ArrowDown') {
|
||||
setFocusedIndex(Math.min(chatList.length - 1, focusedIndex + 1))
|
||||
} else if (e.key === 'Enter') {
|
||||
onSelect(chatList[focusedIndex].id)
|
||||
setOpen(false)
|
||||
}
|
||||
},
|
||||
[chatList, focusedIndex, setFocusedIndex, onSelect],
|
||||
)
|
||||
|
||||
return (
|
||||
<Popover.Root open={open} onOpenChange={setOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<button className={className}>{children}</button>
|
||||
</Popover.Trigger>
|
||||
|
||||
<Popover.Portal>
|
||||
<Popover.Content
|
||||
className="infio-popover infio-chat-list-dropdown-content"
|
||||
onKeyDown={handleKeyDown}
|
||||
>
|
||||
<ul>
|
||||
{chatList.length === 0 ? (
|
||||
<li className="infio-chat-list-dropdown-empty">
|
||||
No conversations
|
||||
</li>
|
||||
) : (
|
||||
chatList.map((chat, index) => (
|
||||
<ChatListItem
|
||||
key={chat.id}
|
||||
title={chat.title}
|
||||
isFocused={focusedIndex === index}
|
||||
isEditing={editingId === chat.id}
|
||||
onMouseEnter={() => {
|
||||
setFocusedIndex(index)
|
||||
}}
|
||||
onSelect={async () => {
|
||||
await onSelect(chat.id)
|
||||
setOpen(false)
|
||||
}}
|
||||
onDelete={async () => {
|
||||
await onDelete(chat.id)
|
||||
}}
|
||||
onStartEdit={() => {
|
||||
setEditingId(chat.id)
|
||||
}}
|
||||
onFinishEdit={async (title) => {
|
||||
await onUpdateTitle(chat.id, title)
|
||||
setEditingId(null)
|
||||
}}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</ul>
|
||||
</Popover.Content>
|
||||
</Popover.Portal>
|
||||
</Popover.Root>
|
||||
)
|
||||
}
|
||||
127
src/components/chat-view/CreateTemplateDialog.tsx
Normal file
127
src/components/chat-view/CreateTemplateDialog.tsx
Normal file
@@ -0,0 +1,127 @@
|
||||
import { $generateNodesFromSerializedNodes } from '@lexical/clipboard'
|
||||
import { BaseSerializedNode } from '@lexical/clipboard/clipboard'
|
||||
import { InitialEditorStateType } from '@lexical/react/LexicalComposer'
|
||||
import * as Dialog from '@radix-ui/react-dialog'
|
||||
import { $insertNodes, LexicalEditor } from 'lexical'
|
||||
import { X } from 'lucide-react'
|
||||
import { Notice } from 'obsidian'
|
||||
import { useRef, useState } from 'react'
|
||||
|
||||
import { useDatabase } from '../../contexts/DatabaseContext'
|
||||
import { useDialogContainer } from '../../contexts/DialogContext'
|
||||
import { DuplicateTemplateException } from '../../database/exception'
|
||||
|
||||
import LexicalContentEditable from './chat-input/LexicalContentEditable'
|
||||
|
||||
/*
|
||||
* This component must be used inside <Dialog.Root modal={false}>
|
||||
* The modal={false} prop is required because modal mode blocks pointer events across the entire page,
|
||||
* which would conflict with lexical editor popovers
|
||||
*/
|
||||
export default function CreateTemplateDialogContent({
|
||||
selectedSerializedNodes,
|
||||
onClose,
|
||||
}: {
|
||||
selectedSerializedNodes?: BaseSerializedNode[] | null
|
||||
onClose: () => void
|
||||
}) {
|
||||
const container = useDialogContainer()
|
||||
const { getTemplateManager } = useDatabase()
|
||||
|
||||
const [templateName, setTemplateName] = useState('')
|
||||
const editorRef = useRef<LexicalEditor | null>(null)
|
||||
const contentEditableRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const initialEditorState: InitialEditorStateType = (
|
||||
editor: LexicalEditor,
|
||||
) => {
|
||||
if (!selectedSerializedNodes) return
|
||||
editor.update(() => {
|
||||
const parsedNodes = $generateNodesFromSerializedNodes(
|
||||
selectedSerializedNodes,
|
||||
)
|
||||
$insertNodes(parsedNodes)
|
||||
})
|
||||
}
|
||||
|
||||
const onSubmit = async () => {
|
||||
try {
|
||||
if (!editorRef.current) return
|
||||
const serializedEditorState = editorRef.current.toJSON()
|
||||
const nodes = serializedEditorState.editorState.root.children
|
||||
if (nodes.length === 0) {
|
||||
new Notice('Please enter a content for your template')
|
||||
return
|
||||
}
|
||||
if (templateName.trim().length === 0) {
|
||||
new Notice('Please enter a name for your template')
|
||||
return
|
||||
}
|
||||
|
||||
await (
|
||||
await getTemplateManager()
|
||||
).createTemplate({
|
||||
name: templateName,
|
||||
content: { nodes },
|
||||
})
|
||||
new Notice(`Template created: ${templateName}`)
|
||||
setTemplateName('')
|
||||
onClose()
|
||||
} catch (error) {
|
||||
if (error instanceof DuplicateTemplateException) {
|
||||
new Notice('A template with this name already exists')
|
||||
} else {
|
||||
console.error(error)
|
||||
new Notice('Failed to create template')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog.Portal container={container}>
|
||||
<Dialog.Content className="infio-chat-dialog-content">
|
||||
<div className="infio-dialog-header">
|
||||
<Dialog.Title className="infio-dialog-title">
|
||||
Create template
|
||||
</Dialog.Title>
|
||||
<Dialog.Description className="infio-dialog-description">
|
||||
Create template from selected content
|
||||
</Dialog.Description>
|
||||
</div>
|
||||
|
||||
<div className="infio-dialog-input">
|
||||
<label>Name</label>
|
||||
<input
|
||||
type="text"
|
||||
value={templateName}
|
||||
onChange={(e) => setTemplateName(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.stopPropagation()
|
||||
e.preventDefault()
|
||||
onSubmit()
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="infio-chat-user-input-container">
|
||||
<LexicalContentEditable
|
||||
initialEditorState={initialEditorState}
|
||||
editorRef={editorRef}
|
||||
contentEditableRef={contentEditableRef}
|
||||
onEnter={onSubmit}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="infio-dialog-footer">
|
||||
<button onClick={onSubmit}>Create template</button>
|
||||
</div>
|
||||
|
||||
<Dialog.Close className="infio-dialog-close" asChild>
|
||||
<X size={16} />
|
||||
</Dialog.Close>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
)
|
||||
}
|
||||
84
src/components/chat-view/LLMResponseInfoPopover.tsx
Normal file
84
src/components/chat-view/LLMResponseInfoPopover.tsx
Normal file
@@ -0,0 +1,84 @@
|
||||
import * as Popover from '@radix-ui/react-popover'
|
||||
import {
|
||||
ArrowDown,
|
||||
ArrowRightLeft,
|
||||
ArrowUp,
|
||||
Coins,
|
||||
Cpu,
|
||||
Info,
|
||||
} from 'lucide-react'
|
||||
|
||||
import { ResponseUsage } from '../../types/llm/response'
|
||||
|
||||
type LLMResponseInfoProps = {
|
||||
usage?: ResponseUsage
|
||||
estimatedPrice: number | null
|
||||
model?: string
|
||||
}
|
||||
|
||||
export default function LLMResponseInfoPopover({
|
||||
usage,
|
||||
estimatedPrice,
|
||||
model,
|
||||
}: LLMResponseInfoProps) {
|
||||
return (
|
||||
<Popover.Root>
|
||||
<Popover.Trigger asChild>
|
||||
<button>
|
||||
<Info className="infio-llm-info-icon--trigger" size={12} />
|
||||
</button>
|
||||
</Popover.Trigger>
|
||||
{usage ? (
|
||||
<Popover.Content className="infio-chat-popover-content infio-llm-info-content">
|
||||
<div className="infio-llm-info-header">LLM Response Information</div>
|
||||
<div className="infio-llm-info-tokens">
|
||||
<div className="infio-llm-info-tokens-header">Token Count</div>
|
||||
<div className="infio-llm-info-tokens-grid">
|
||||
<div className="infio-llm-info-token-row">
|
||||
<ArrowUp className="infio-llm-info-icon--input" />
|
||||
<span>Input:</span>
|
||||
<span className="infio-llm-info-token-value">
|
||||
{usage.prompt_tokens}
|
||||
</span>
|
||||
</div>
|
||||
<div className="infio-llm-info-token-row">
|
||||
<ArrowDown className="infio-llm-info-icon--output" />
|
||||
<span>Output:</span>
|
||||
<span className="infio-llm-info-token-value">
|
||||
{usage.completion_tokens}
|
||||
</span>
|
||||
</div>
|
||||
<div className="infio-llm-info-token-row infio-llm-info-token-total">
|
||||
<ArrowRightLeft className="infio-llm-info-icon--total" />
|
||||
<span>Total:</span>
|
||||
<span className="infio-llm-info-token-value">
|
||||
{usage.total_tokens}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="infio-llm-info-footer-row">
|
||||
<Coins className="infio-llm-info-icon--footer" />
|
||||
<span>Estimated Price:</span>
|
||||
<span className="infio-llm-info-footer-value">
|
||||
{estimatedPrice === null
|
||||
? 'Not available'
|
||||
: `$${estimatedPrice.toFixed(4)}`}
|
||||
</span>
|
||||
</div>
|
||||
<div className="infio-llm-info-footer-row">
|
||||
<Cpu className="infio-llm-info-icon--footer" />
|
||||
<span>Model:</span>
|
||||
<span className="infio-llm-info-footer-value infio-llm-info-model">
|
||||
{model ?? 'Not available'}
|
||||
</span>
|
||||
</div>
|
||||
</Popover.Content>
|
||||
) : (
|
||||
<Popover.Content className="infio-chat-popover-content">
|
||||
<div>Usage statistics are not available for this model</div>
|
||||
</Popover.Content>
|
||||
)}
|
||||
</Popover.Root>
|
||||
)
|
||||
}
|
||||
122
src/components/chat-view/MarkdownCodeComponent.tsx
Normal file
122
src/components/chat-view/MarkdownCodeComponent.tsx
Normal file
@@ -0,0 +1,122 @@
|
||||
import { Check, CopyIcon, Loader2 } from 'lucide-react'
|
||||
import { PropsWithChildren, useMemo, useState } from 'react'
|
||||
|
||||
import { useDarkModeContext } from '../../contexts/DarkModeContext'
|
||||
|
||||
import { MemoizedSyntaxHighlighterWrapper } from './SyntaxHighlighterWrapper'
|
||||
|
||||
export default function MarkdownCodeComponent({
|
||||
onApply,
|
||||
isApplying,
|
||||
language,
|
||||
filename,
|
||||
startLine,
|
||||
endLine,
|
||||
action,
|
||||
children,
|
||||
}: PropsWithChildren<{
|
||||
onApply: (blockInfo: {
|
||||
content: string
|
||||
filename?: string
|
||||
startLine?: number
|
||||
endLine?: number
|
||||
}) => void
|
||||
isApplying: boolean
|
||||
language?: string
|
||||
filename?: string
|
||||
startLine?: number
|
||||
endLine?: number
|
||||
action?: 'edit' | 'new' | 'reference'
|
||||
}>) {
|
||||
const [copied, setCopied] = useState(false)
|
||||
const { isDarkMode } = useDarkModeContext()
|
||||
|
||||
const wrapLines = useMemo(() => {
|
||||
return !language || ['markdown'].includes(language)
|
||||
}, [language])
|
||||
|
||||
const handleCopy = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(String(children))
|
||||
setCopied(true)
|
||||
setTimeout(() => setCopied(false), 2000)
|
||||
} catch (err) {
|
||||
console.error('Failed to copy text: ', err)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`infio-chat-code-block ${filename ? 'has-filename' : ''} ${action ? `type-${action}` : ''}`}>
|
||||
<div className={'infio-chat-code-block-header'}>
|
||||
{filename && (
|
||||
<div className={'infio-chat-code-block-header-filename'}>{filename}</div>
|
||||
)}
|
||||
<div className={'infio-chat-code-block-header-button'}>
|
||||
<button
|
||||
onClick={() => {
|
||||
handleCopy()
|
||||
}}
|
||||
>
|
||||
{copied ? (
|
||||
<>
|
||||
<Check size={10} /> Copied
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<CopyIcon size={10} /> Copy
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
{action === 'edit' && (
|
||||
<button
|
||||
onClick={() => {
|
||||
onApply({
|
||||
content: String(children),
|
||||
filename,
|
||||
startLine,
|
||||
endLine
|
||||
})
|
||||
}}
|
||||
disabled={isApplying}
|
||||
>
|
||||
{isApplying ? (
|
||||
<>
|
||||
<Loader2 className="spinner" size={14} /> Applying...
|
||||
</>
|
||||
) : (
|
||||
'Apply'
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
{action === 'new' && (
|
||||
<button
|
||||
onClick={() => {
|
||||
onApply({
|
||||
content: String(children),
|
||||
filename
|
||||
})
|
||||
}}
|
||||
disabled={isApplying}
|
||||
>
|
||||
{isApplying ? (
|
||||
<>
|
||||
<Loader2 className="spinner" size={14} /> Inserting...
|
||||
</>
|
||||
) : (
|
||||
'Insert'
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<MemoizedSyntaxHighlighterWrapper
|
||||
isDarkMode={isDarkMode}
|
||||
language={language}
|
||||
hasFilename={!!filename}
|
||||
wrapLines={wrapLines}
|
||||
>
|
||||
{String(children)}
|
||||
</MemoizedSyntaxHighlighterWrapper>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
75
src/components/chat-view/MarkdownReferenceBlock.tsx
Normal file
75
src/components/chat-view/MarkdownReferenceBlock.tsx
Normal file
@@ -0,0 +1,75 @@
|
||||
import { PropsWithChildren, useEffect, useMemo, useState } from 'react'
|
||||
|
||||
import { useApp } from '../../contexts/AppContext'
|
||||
import { useDarkModeContext } from '../../contexts/DarkModeContext'
|
||||
import { openMarkdownFile, readTFileContent } from '../../utils/obsidian'
|
||||
|
||||
import { MemoizedSyntaxHighlighterWrapper } from './SyntaxHighlighterWrapper'
|
||||
|
||||
export default function MarkdownReferenceBlock({
|
||||
filename,
|
||||
startLine,
|
||||
endLine,
|
||||
language,
|
||||
}: PropsWithChildren<{
|
||||
filename: string
|
||||
startLine: number
|
||||
endLine: number
|
||||
language?: string
|
||||
}>) {
|
||||
const app = useApp()
|
||||
const { isDarkMode } = useDarkModeContext()
|
||||
const [blockContent, setBlockContent] = useState<string | null>(null)
|
||||
|
||||
const wrapLines = useMemo(() => {
|
||||
return !language || ['markdown'].includes(language)
|
||||
}, [language])
|
||||
|
||||
useEffect(() => {
|
||||
async function fetchBlockContent() {
|
||||
const file = app.vault.getFileByPath(filename)
|
||||
if (!file) {
|
||||
setBlockContent(null)
|
||||
return
|
||||
}
|
||||
const fileContent = await readTFileContent(file, app.vault)
|
||||
const content = fileContent
|
||||
.split('\n')
|
||||
.slice(startLine - 1, endLine)
|
||||
.join('\n')
|
||||
setBlockContent(content)
|
||||
}
|
||||
|
||||
fetchBlockContent()
|
||||
}, [filename, startLine, endLine, app.vault])
|
||||
|
||||
const handleClick = () => {
|
||||
openMarkdownFile(app, filename, startLine)
|
||||
}
|
||||
|
||||
// TODO: Update styles
|
||||
return (
|
||||
blockContent && (
|
||||
<div
|
||||
className={`infio-chat-code-block ${filename ? 'has-filename' : ''}`}
|
||||
onClick={handleClick}
|
||||
>
|
||||
<div className={'infio-chat-code-block-header'}>
|
||||
{filename && (
|
||||
<div className={'infio-chat-code-block-header-filename'}>
|
||||
{filename}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<MemoizedSyntaxHighlighterWrapper
|
||||
isDarkMode={isDarkMode}
|
||||
language={language}
|
||||
hasFilename={!!filename}
|
||||
wrapLines={wrapLines}
|
||||
>
|
||||
{blockContent}
|
||||
</MemoizedSyntaxHighlighterWrapper>
|
||||
</div>
|
||||
)
|
||||
)
|
||||
}
|
||||
85
src/components/chat-view/QueryProgress.tsx
Normal file
85
src/components/chat-view/QueryProgress.tsx
Normal file
@@ -0,0 +1,85 @@
|
||||
import { SelectVector } from '../../database/schema'
|
||||
|
||||
export type QueryProgressState =
|
||||
| {
|
||||
type: 'reading-mentionables'
|
||||
}
|
||||
| {
|
||||
type: 'indexing'
|
||||
indexProgress: IndexProgress
|
||||
}
|
||||
| {
|
||||
type: 'querying'
|
||||
}
|
||||
| {
|
||||
type: 'querying-done'
|
||||
queryResult: (Omit<SelectVector, 'embedding'> & { similarity: number })[]
|
||||
}
|
||||
| {
|
||||
type: 'idle'
|
||||
}
|
||||
|
||||
export type IndexProgress = {
|
||||
completedChunks: number
|
||||
totalChunks: number
|
||||
totalFiles: number
|
||||
}
|
||||
|
||||
// TODO: Update style
|
||||
export default function QueryProgress({
|
||||
state,
|
||||
}: {
|
||||
state: QueryProgressState
|
||||
}) {
|
||||
switch (state.type) {
|
||||
case 'idle':
|
||||
return null
|
||||
case 'reading-mentionables':
|
||||
return (
|
||||
<div className="infio-query-progress">
|
||||
<p>
|
||||
Reading mentioned files
|
||||
<DotLoader />
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
case 'indexing':
|
||||
return (
|
||||
<div className="infio-query-progress">
|
||||
<p>
|
||||
{`Indexing ${state.indexProgress.totalFiles} file`}
|
||||
<DotLoader />
|
||||
</p>
|
||||
<p className="infio-query-progress-detail">{`${state.indexProgress.completedChunks}/${state.indexProgress.totalChunks} chunks indexed`}</p>
|
||||
</div>
|
||||
)
|
||||
case 'querying':
|
||||
return (
|
||||
<div className="infio-query-progress">
|
||||
<p>
|
||||
Querying the vault
|
||||
<DotLoader />
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
case 'querying-done':
|
||||
return (
|
||||
<div className="infio-query-progress">
|
||||
<p>
|
||||
Reading related files
|
||||
<DotLoader />
|
||||
</p>
|
||||
{state.queryResult.map((result) => (
|
||||
<div key={result.path}>
|
||||
<p>{result.path}</p>
|
||||
<p>{result.similarity}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
function DotLoader() {
|
||||
return <span className="infio-dot-loader" aria-label="Loading"></span>
|
||||
}
|
||||
64
src/components/chat-view/ReactMarkdown.tsx
Normal file
64
src/components/chat-view/ReactMarkdown.tsx
Normal file
@@ -0,0 +1,64 @@
|
||||
import React, { useMemo } from 'react'
|
||||
import Markdown from 'react-markdown'
|
||||
|
||||
import {
|
||||
ParsedinfioBlock,
|
||||
parseinfioBlocks,
|
||||
} from '../../utils/parse-infio-block'
|
||||
|
||||
import MarkdownCodeComponent from './MarkdownCodeComponent'
|
||||
import MarkdownReferenceBlock from './MarkdownReferenceBlock'
|
||||
|
||||
function ReactMarkdown({
|
||||
onApply,
|
||||
isApplying,
|
||||
children,
|
||||
}: {
|
||||
onApply: (blockInfo: {
|
||||
content: string
|
||||
filename?: string
|
||||
startLine?: number
|
||||
endLine?: number
|
||||
}) => void
|
||||
children: string
|
||||
isApplying: boolean
|
||||
}) {
|
||||
const blocks: ParsedinfioBlock[] = useMemo(
|
||||
() => parseinfioBlocks(children),
|
||||
[children],
|
||||
)
|
||||
|
||||
return (
|
||||
<>
|
||||
{blocks.map((block, index) =>
|
||||
block.type === 'string' ? (
|
||||
<Markdown key={index} className="infio-markdown">
|
||||
{block.content}
|
||||
</Markdown>
|
||||
) : block.startLine && block.endLine && block.filename && block.action === 'reference' ? (
|
||||
<MarkdownReferenceBlock
|
||||
key={index}
|
||||
filename={block.filename}
|
||||
startLine={block.startLine}
|
||||
endLine={block.endLine}
|
||||
/>
|
||||
) : (
|
||||
<MarkdownCodeComponent
|
||||
key={index}
|
||||
onApply={onApply}
|
||||
isApplying={isApplying}
|
||||
language={block.language}
|
||||
filename={block.filename}
|
||||
startLine={block.startLine}
|
||||
endLine={block.endLine}
|
||||
action={block.action}
|
||||
>
|
||||
{block.content}
|
||||
</MarkdownCodeComponent>
|
||||
),
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(ReactMarkdown)
|
||||
38
src/components/chat-view/ShortcutInfo.tsx
Normal file
38
src/components/chat-view/ShortcutInfo.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import { Platform } from 'obsidian';
|
||||
import React from 'react';
|
||||
|
||||
const ShortcutInfo: React.FC = () => {
|
||||
const modKey = Platform.isMacOS ? 'Cmd' : 'Ctrl';
|
||||
|
||||
const shortcuts = [
|
||||
{
|
||||
label: 'Edit inline',
|
||||
shortcut: `${modKey}+Shift+K`,
|
||||
},
|
||||
{
|
||||
label: 'Chat with select',
|
||||
shortcut: `${modKey}+Shift+L`,
|
||||
},
|
||||
{
|
||||
label: 'Submit with vault',
|
||||
shortcut: `${modKey}+Shift+Enter`,
|
||||
}
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="infio-shortcut-info">
|
||||
<table className="infio-shortcut-table">
|
||||
<tbody>
|
||||
{shortcuts.map((item, index) => (
|
||||
<tr key={index} className="infio-shortcut-item">
|
||||
<td className="infio-shortcut-label">{item.label}</td>
|
||||
<td className="infio-shortcut-key"><kbd>{item.shortcut}</kbd></td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ShortcutInfo;
|
||||
71
src/components/chat-view/SimilaritySearchResults.tsx
Normal file
71
src/components/chat-view/SimilaritySearchResults.tsx
Normal file
@@ -0,0 +1,71 @@
|
||||
import path from 'path'
|
||||
|
||||
import { ChevronDown, ChevronRight } from 'lucide-react'
|
||||
import { useState } from 'react'
|
||||
|
||||
import { useApp } from '../../contexts/AppContext'
|
||||
import { SelectVector } from '../../database/schema'
|
||||
import { openMarkdownFile } from '../../utils/obsidian'
|
||||
|
||||
function SimiliartySearchItem({
|
||||
chunk,
|
||||
}: {
|
||||
chunk: Omit<SelectVector, 'embedding'> & {
|
||||
similarity: number
|
||||
}
|
||||
}) {
|
||||
const app = useApp()
|
||||
|
||||
const handleClick = () => {
|
||||
openMarkdownFile(app, chunk.path, chunk.metadata.startLine)
|
||||
}
|
||||
return (
|
||||
<div onClick={handleClick} className="infio-similarity-search-item">
|
||||
<div className="infio-similarity-search-item__similarity">
|
||||
{chunk.similarity.toFixed(3)}
|
||||
</div>
|
||||
<div className="infio-similarity-search-item__path">
|
||||
{path.basename(chunk.path)}
|
||||
</div>
|
||||
<div className="infio-similarity-search-item__line-numbers">
|
||||
{`${chunk.metadata.startLine} - ${chunk.metadata.endLine}`}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default function SimilaritySearchResults({
|
||||
similaritySearchResults,
|
||||
}: {
|
||||
similaritySearchResults: (Omit<SelectVector, 'embedding'> & {
|
||||
similarity: number
|
||||
})[]
|
||||
}) {
|
||||
const [isOpen, setIsOpen] = useState(false)
|
||||
|
||||
return (
|
||||
<div className="infio-similarity-search-results">
|
||||
<div
|
||||
onClick={() => {
|
||||
setIsOpen(!isOpen)
|
||||
}}
|
||||
className="infio-similarity-search-results__trigger"
|
||||
>
|
||||
{isOpen ? <ChevronDown size={16} /> : <ChevronRight size={16} />}
|
||||
<div>Show Referenced Documents ({similaritySearchResults.length})</div>
|
||||
</div>
|
||||
{isOpen && (
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
}}
|
||||
>
|
||||
{similaritySearchResults.map((chunk) => (
|
||||
<SimiliartySearchItem key={chunk.id} chunk={chunk} />
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
51
src/components/chat-view/SyntaxHighlighterWrapper.tsx
Normal file
51
src/components/chat-view/SyntaxHighlighterWrapper.tsx
Normal file
@@ -0,0 +1,51 @@
|
||||
import { memo } from 'react'
|
||||
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'
|
||||
import {
|
||||
oneDark,
|
||||
oneLight,
|
||||
} from 'react-syntax-highlighter/dist/esm/styles/prism'
|
||||
|
||||
function SyntaxHighlighterWrapper({
|
||||
isDarkMode,
|
||||
language,
|
||||
hasFilename,
|
||||
wrapLines,
|
||||
children,
|
||||
}: {
|
||||
isDarkMode: boolean
|
||||
language: string | undefined
|
||||
hasFilename: boolean
|
||||
wrapLines: boolean
|
||||
children: string
|
||||
}) {
|
||||
return (
|
||||
<SyntaxHighlighter
|
||||
language={language}
|
||||
style={isDarkMode ? oneDark : oneLight}
|
||||
customStyle={{
|
||||
borderRadius: hasFilename
|
||||
? '0 0 var(--radius-s) var(--radius-s)'
|
||||
: 'var(--radius-s)',
|
||||
margin: 0,
|
||||
padding: 'var(--size-4-2)',
|
||||
fontSize: 'var(--font-ui-small)',
|
||||
fontFamily:
|
||||
language === 'markdown' ? 'var(--font-interface)' : 'inherit',
|
||||
}}
|
||||
wrapLines={wrapLines}
|
||||
lineProps={
|
||||
// Wrapping should work without lineProps, but Obsidian's default CSS seems to override SyntaxHighlighter's styles.
|
||||
// We manually override the white-space property to ensure proper wrapping.
|
||||
wrapLines
|
||||
? {
|
||||
style: { whiteSpace: 'pre-wrap' },
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</SyntaxHighlighter>
|
||||
)
|
||||
}
|
||||
|
||||
export const MemoizedSyntaxHighlighterWrapper = memo(SyntaxHighlighterWrapper)
|
||||
374
src/components/chat-view/chat-input/ChatUserInput.tsx
Normal file
374
src/components/chat-view/chat-input/ChatUserInput.tsx
Normal file
@@ -0,0 +1,374 @@
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { $nodesOfType, LexicalEditor, SerializedEditorState } from 'lexical'
|
||||
import {
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useImperativeHandle,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
|
||||
import { useApp } from '../../../contexts/AppContext'
|
||||
import { useDarkModeContext } from '../../../contexts/DarkModeContext'
|
||||
import {
|
||||
Mentionable,
|
||||
MentionableImage,
|
||||
SerializedMentionable,
|
||||
} from '../../../types/mentionable'
|
||||
import { fileToMentionableImage } from '../../../utils/image'
|
||||
import {
|
||||
deserializeMentionable,
|
||||
getMentionableKey,
|
||||
serializeMentionable,
|
||||
} from '../../../utils/mentionable'
|
||||
import { openMarkdownFile, readTFileContent } from '../../../utils/obsidian'
|
||||
import { MemoizedSyntaxHighlighterWrapper } from '../SyntaxHighlighterWrapper'
|
||||
|
||||
import { ImageUploadButton } from './ImageUploadButton'
|
||||
import LexicalContentEditable from './LexicalContentEditable'
|
||||
import MentionableBadge from './MentionableBadge'
|
||||
import { ModelSelect } from './ModelSelect'
|
||||
import { MentionNode } from './plugins/mention/MentionNode'
|
||||
import { NodeMutations } from './plugins/on-mutation/OnMutationPlugin'
|
||||
import { SubmitButton } from './SubmitButton'
|
||||
import { VaultChatButton } from './VaultChatButton'
|
||||
|
||||
export type ChatUserInputRef = {
|
||||
focus: () => void
|
||||
}
|
||||
|
||||
export type ChatUserInputProps = {
|
||||
initialSerializedEditorState: SerializedEditorState | null
|
||||
onChange: (content: SerializedEditorState) => void
|
||||
onSubmit: (content: SerializedEditorState, useVaultSearch?: boolean) => void
|
||||
onFocus: () => void
|
||||
mentionables: Mentionable[]
|
||||
setMentionables: (mentionables: Mentionable[]) => void
|
||||
autoFocus?: boolean
|
||||
addedBlockKey?: string | null
|
||||
}
|
||||
|
||||
const ChatUserInput = forwardRef<ChatUserInputRef, ChatUserInputProps>(
|
||||
(
|
||||
{
|
||||
initialSerializedEditorState,
|
||||
onChange,
|
||||
onSubmit,
|
||||
onFocus,
|
||||
mentionables,
|
||||
setMentionables,
|
||||
autoFocus = false,
|
||||
addedBlockKey,
|
||||
},
|
||||
ref,
|
||||
) => {
|
||||
const app = useApp()
|
||||
|
||||
const editorRef = useRef<LexicalEditor | null>(null)
|
||||
const contentEditableRef = useRef<HTMLDivElement>(null)
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const [displayedMentionableKey, setDisplayedMentionableKey] = useState<
|
||||
string | null
|
||||
>(addedBlockKey ?? null)
|
||||
|
||||
useEffect(() => {
|
||||
if (addedBlockKey) {
|
||||
setDisplayedMentionableKey(addedBlockKey)
|
||||
}
|
||||
}, [addedBlockKey])
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
focus: () => {
|
||||
contentEditableRef.current?.focus()
|
||||
},
|
||||
}))
|
||||
|
||||
const handleMentionNodeMutation = (
|
||||
mutations: NodeMutations<MentionNode>,
|
||||
) => {
|
||||
const destroyedMentionableKeys: string[] = []
|
||||
const addedMentionables: SerializedMentionable[] = []
|
||||
mutations.forEach((mutation) => {
|
||||
const mentionable = mutation.node.getMentionable()
|
||||
const mentionableKey = getMentionableKey(mentionable)
|
||||
|
||||
if (mutation.mutation === 'destroyed') {
|
||||
const nodeWithSameMentionable = editorRef.current?.read(() =>
|
||||
$nodesOfType(MentionNode).find(
|
||||
(node) =>
|
||||
getMentionableKey(node.getMentionable()) === mentionableKey,
|
||||
),
|
||||
)
|
||||
|
||||
if (!nodeWithSameMentionable) {
|
||||
// remove mentionable only if it's not present in the editor state
|
||||
destroyedMentionableKeys.push(mentionableKey)
|
||||
}
|
||||
} else if (mutation.mutation === 'created') {
|
||||
if (
|
||||
mentionables.some(
|
||||
(m) =>
|
||||
getMentionableKey(serializeMentionable(m)) === mentionableKey,
|
||||
) ||
|
||||
addedMentionables.some(
|
||||
(m) => getMentionableKey(m) === mentionableKey,
|
||||
)
|
||||
) {
|
||||
// do nothing if mentionable is already added
|
||||
return
|
||||
}
|
||||
|
||||
addedMentionables.push(mentionable)
|
||||
}
|
||||
})
|
||||
|
||||
setMentionables(
|
||||
mentionables
|
||||
.filter(
|
||||
(m) =>
|
||||
!destroyedMentionableKeys.includes(
|
||||
getMentionableKey(serializeMentionable(m)),
|
||||
),
|
||||
)
|
||||
.concat(
|
||||
addedMentionables
|
||||
.map((m) => deserializeMentionable(m, app))
|
||||
.filter((v) => !!v),
|
||||
),
|
||||
)
|
||||
if (addedMentionables.length > 0) {
|
||||
setDisplayedMentionableKey(
|
||||
getMentionableKey(addedMentionables[addedMentionables.length - 1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const handleCreateImageMentionables = useCallback(
|
||||
(mentionableImages: MentionableImage[]) => {
|
||||
const newMentionableImages = mentionableImages.filter(
|
||||
(m) =>
|
||||
!mentionables.some(
|
||||
(mentionable) =>
|
||||
getMentionableKey(serializeMentionable(mentionable)) ===
|
||||
getMentionableKey(serializeMentionable(m)),
|
||||
),
|
||||
)
|
||||
if (newMentionableImages.length === 0) return
|
||||
setMentionables([...mentionables, ...newMentionableImages])
|
||||
setDisplayedMentionableKey(
|
||||
getMentionableKey(
|
||||
serializeMentionable(
|
||||
newMentionableImages[newMentionableImages.length - 1],
|
||||
),
|
||||
),
|
||||
)
|
||||
},
|
||||
[mentionables, setMentionables],
|
||||
)
|
||||
|
||||
const handleMentionableDelete = (mentionable: Mentionable) => {
|
||||
const mentionableKey = getMentionableKey(
|
||||
serializeMentionable(mentionable),
|
||||
)
|
||||
setMentionables(
|
||||
mentionables.filter(
|
||||
(m) => getMentionableKey(serializeMentionable(m)) !== mentionableKey,
|
||||
),
|
||||
)
|
||||
|
||||
editorRef.current?.update(() => {
|
||||
$nodesOfType(MentionNode).forEach((node) => {
|
||||
if (getMentionableKey(node.getMentionable()) === mentionableKey) {
|
||||
node.remove()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const handleUploadImages = async (images: File[]) => {
|
||||
const mentionableImages = await Promise.all(
|
||||
images.map((image) => fileToMentionableImage(image)),
|
||||
)
|
||||
handleCreateImageMentionables(mentionableImages)
|
||||
}
|
||||
|
||||
const handleSubmit = (options: { useVaultSearch?: boolean } = {}) => {
|
||||
const content = editorRef.current?.getEditorState()?.toJSON()
|
||||
content && onSubmit(content, options.useVaultSearch)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="infio-chat-user-input-container" ref={containerRef}>
|
||||
{mentionables.length > 0 && (
|
||||
<div className="infio-chat-user-input-files">
|
||||
{mentionables.map((m) => (
|
||||
<MentionableBadge
|
||||
key={getMentionableKey(serializeMentionable(m))}
|
||||
mentionable={m}
|
||||
onDelete={() => handleMentionableDelete(m)}
|
||||
onClick={() => {
|
||||
const mentionableKey = getMentionableKey(
|
||||
serializeMentionable(m),
|
||||
)
|
||||
if (
|
||||
(m.type === 'current-file' ||
|
||||
m.type === 'file' ||
|
||||
m.type === 'block') &&
|
||||
m.file &&
|
||||
mentionableKey === displayedMentionableKey
|
||||
) {
|
||||
// open file on click again
|
||||
openMarkdownFile(
|
||||
app,
|
||||
m.file.path,
|
||||
m.type === 'block' ? m.startLine : undefined,
|
||||
)
|
||||
} else {
|
||||
setDisplayedMentionableKey(mentionableKey)
|
||||
}
|
||||
}}
|
||||
isFocused={
|
||||
getMentionableKey(serializeMentionable(m)) ===
|
||||
displayedMentionableKey
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<MentionableContentPreview
|
||||
displayedMentionableKey={displayedMentionableKey}
|
||||
mentionables={mentionables}
|
||||
/>
|
||||
|
||||
<LexicalContentEditable
|
||||
initialEditorState={(editor) => {
|
||||
if (initialSerializedEditorState) {
|
||||
editor.setEditorState(
|
||||
editor.parseEditorState(initialSerializedEditorState),
|
||||
)
|
||||
}
|
||||
}}
|
||||
editorRef={editorRef}
|
||||
contentEditableRef={contentEditableRef}
|
||||
onChange={onChange}
|
||||
onEnter={() => handleSubmit({ useVaultSearch: false })}
|
||||
onFocus={onFocus}
|
||||
onMentionNodeMutation={handleMentionNodeMutation}
|
||||
onCreateImageMentionables={handleCreateImageMentionables}
|
||||
autoFocus={autoFocus}
|
||||
plugins={{
|
||||
onEnter: {
|
||||
onVaultChat: () => {
|
||||
handleSubmit({ useVaultSearch: true })
|
||||
},
|
||||
},
|
||||
templatePopover: {
|
||||
anchorElement: containerRef.current,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
|
||||
<div className="infio-chat-user-input-controls">
|
||||
<div className="infio-chat-user-input-controls__model-select-container">
|
||||
<ModelSelect />
|
||||
<ImageUploadButton onUpload={handleUploadImages} />
|
||||
</div>
|
||||
<div className="infio-chat-user-input-controls__buttons">
|
||||
<SubmitButton onClick={() => handleSubmit()} />
|
||||
{/* <VaultChatButton
|
||||
onClick={() => {
|
||||
handleSubmit({ useVaultSearch: true })
|
||||
}}
|
||||
/> */}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
function MentionableContentPreview({
|
||||
displayedMentionableKey,
|
||||
mentionables,
|
||||
}: {
|
||||
displayedMentionableKey: string | null
|
||||
mentionables: Mentionable[]
|
||||
}) {
|
||||
const app = useApp()
|
||||
const { isDarkMode } = useDarkModeContext()
|
||||
|
||||
const displayedMentionable: Mentionable | null = useMemo(() => {
|
||||
return (
|
||||
mentionables.find(
|
||||
(m) =>
|
||||
getMentionableKey(serializeMentionable(m)) ===
|
||||
displayedMentionableKey,
|
||||
) ?? null
|
||||
)
|
||||
}, [displayedMentionableKey, mentionables])
|
||||
|
||||
const { data: displayFileContent } = useQuery({
|
||||
enabled:
|
||||
!!displayedMentionable &&
|
||||
['file', 'current-file', 'block'].includes(displayedMentionable.type),
|
||||
queryKey: [
|
||||
'file',
|
||||
displayedMentionableKey,
|
||||
mentionables.map((m) => getMentionableKey(serializeMentionable(m))), // should be updated when mentionables change (especially on delete)
|
||||
],
|
||||
queryFn: async () => {
|
||||
if (!displayedMentionable) return null
|
||||
if (
|
||||
displayedMentionable.type === 'file' ||
|
||||
displayedMentionable.type === 'current-file'
|
||||
) {
|
||||
if (!displayedMentionable.file) return null
|
||||
return await readTFileContent(displayedMentionable.file, app.vault)
|
||||
} else if (displayedMentionable.type === 'block') {
|
||||
const fileContent = await readTFileContent(
|
||||
displayedMentionable.file,
|
||||
app.vault,
|
||||
)
|
||||
|
||||
return fileContent
|
||||
.split('\n')
|
||||
.slice(
|
||||
displayedMentionable.startLine - 1,
|
||||
displayedMentionable.endLine,
|
||||
)
|
||||
.join('\n')
|
||||
}
|
||||
|
||||
return null
|
||||
},
|
||||
})
|
||||
|
||||
const displayImage: MentionableImage | null = useMemo(() => {
|
||||
return displayedMentionable?.type === 'image' ? displayedMentionable : null
|
||||
}, [displayedMentionable])
|
||||
|
||||
return displayFileContent ? (
|
||||
<div className="infio-chat-user-input-file-content-preview">
|
||||
<MemoizedSyntaxHighlighterWrapper
|
||||
isDarkMode={isDarkMode}
|
||||
language="markdown"
|
||||
hasFilename={false}
|
||||
wrapLines={false}
|
||||
>
|
||||
{displayFileContent}
|
||||
</MemoizedSyntaxHighlighterWrapper>
|
||||
</div>
|
||||
) : displayImage ? (
|
||||
<div className="infio-chat-user-input-file-content-preview">
|
||||
<img src={displayImage.data} alt={displayImage.name} />
|
||||
</div>
|
||||
) : null
|
||||
}
|
||||
|
||||
ChatUserInput.displayName = 'ChatUserInput'
|
||||
|
||||
export default ChatUserInput
|
||||
30
src/components/chat-view/chat-input/ImageUploadButton.tsx
Normal file
30
src/components/chat-view/chat-input/ImageUploadButton.tsx
Normal file
@@ -0,0 +1,30 @@
|
||||
import { ImageIcon } from 'lucide-react'
|
||||
|
||||
export function ImageUploadButton({
|
||||
onUpload,
|
||||
}: {
|
||||
onUpload: (files: File[]) => void
|
||||
}) {
|
||||
const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const files = Array.from(event.target.files ?? [])
|
||||
if (files.length > 0) {
|
||||
onUpload(files)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<label className="infio-chat-user-input-submit-button">
|
||||
<input
|
||||
type="file"
|
||||
accept="image/*"
|
||||
multiple
|
||||
onChange={handleFileChange}
|
||||
style={{ display: 'none' }}
|
||||
/>
|
||||
<div className="infio-chat-user-input-submit-button-icons">
|
||||
<ImageIcon size={12} />
|
||||
</div>
|
||||
<div>Image</div>
|
||||
</label>
|
||||
)
|
||||
}
|
||||
153
src/components/chat-view/chat-input/LexicalContentEditable.tsx
Normal file
153
src/components/chat-view/chat-input/LexicalContentEditable.tsx
Normal file
@@ -0,0 +1,153 @@
|
||||
import {
|
||||
InitialConfigType,
|
||||
InitialEditorStateType,
|
||||
LexicalComposer,
|
||||
} from '@lexical/react/LexicalComposer'
|
||||
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
|
||||
import { EditorRefPlugin } from '@lexical/react/LexicalEditorRefPlugin'
|
||||
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
|
||||
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin'
|
||||
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin'
|
||||
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
|
||||
import { LexicalEditor, SerializedEditorState } from 'lexical'
|
||||
import { RefObject, useCallback, useEffect } from 'react'
|
||||
|
||||
import { useApp } from '../../../contexts/AppContext'
|
||||
import { MentionableImage } from '../../../types/mentionable'
|
||||
import { fuzzySearch } from '../../../utils/fuzzy-search'
|
||||
|
||||
import DragDropPaste from './plugins/image/DragDropPastePlugin'
|
||||
import ImagePastePlugin from './plugins/image/ImagePastePlugin'
|
||||
import AutoLinkMentionPlugin from './plugins/mention/AutoLinkMentionPlugin'
|
||||
import { MentionNode } from './plugins/mention/MentionNode'
|
||||
import MentionPlugin from './plugins/mention/MentionPlugin'
|
||||
import NoFormatPlugin from './plugins/no-format/NoFormatPlugin'
|
||||
import OnEnterPlugin from './plugins/on-enter/OnEnterPlugin'
|
||||
import OnMutationPlugin, {
|
||||
NodeMutations,
|
||||
} from './plugins/on-mutation/OnMutationPlugin'
|
||||
import CreateTemplatePopoverPlugin from './plugins/template/CreateTemplatePopoverPlugin'
|
||||
import TemplatePlugin from './plugins/template/TemplatePlugin'
|
||||
|
||||
export type LexicalContentEditableProps = {
|
||||
editorRef: RefObject<LexicalEditor>
|
||||
contentEditableRef: RefObject<HTMLDivElement>
|
||||
onChange?: (content: SerializedEditorState) => void
|
||||
onEnter?: (evt: KeyboardEvent) => void
|
||||
onFocus?: () => void
|
||||
onMentionNodeMutation?: (mutations: NodeMutations<MentionNode>) => void
|
||||
onCreateImageMentionables?: (mentionables: MentionableImage[]) => void
|
||||
initialEditorState?: InitialEditorStateType
|
||||
autoFocus?: boolean
|
||||
plugins?: {
|
||||
onEnter?: {
|
||||
onVaultChat: () => void
|
||||
}
|
||||
templatePopover?: {
|
||||
anchorElement: HTMLElement | null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default function LexicalContentEditable({
|
||||
editorRef,
|
||||
contentEditableRef,
|
||||
onChange,
|
||||
onEnter,
|
||||
onFocus,
|
||||
onMentionNodeMutation,
|
||||
onCreateImageMentionables,
|
||||
initialEditorState,
|
||||
autoFocus = false,
|
||||
plugins,
|
||||
}: LexicalContentEditableProps) {
|
||||
const app = useApp()
|
||||
|
||||
const initialConfig: InitialConfigType = {
|
||||
namespace: 'LexicalContentEditable',
|
||||
theme: {
|
||||
root: 'infio-chat-lexical-content-editable-root',
|
||||
paragraph: 'infio-chat-lexical-content-editable-paragraph',
|
||||
},
|
||||
nodes: [MentionNode],
|
||||
editorState: initialEditorState,
|
||||
onError: (error) => {
|
||||
console.error(error)
|
||||
},
|
||||
}
|
||||
|
||||
const searchResultByQuery = useCallback(
|
||||
(query: string) => fuzzySearch(app, query),
|
||||
[app],
|
||||
)
|
||||
|
||||
/*
|
||||
* Using requestAnimationFrame for autoFocus instead of using editor.focus()
|
||||
* due to known issues with editor.focus() when initialConfig.editorState is set
|
||||
* See: https://github.com/facebook/lexical/issues/4460
|
||||
*/
|
||||
useEffect(() => {
|
||||
if (autoFocus) {
|
||||
requestAnimationFrame(() => {
|
||||
contentEditableRef.current?.focus()
|
||||
})
|
||||
}
|
||||
}, [autoFocus, contentEditableRef])
|
||||
|
||||
return (
|
||||
<LexicalComposer initialConfig={initialConfig}>
|
||||
{/*
|
||||
There was two approach to make mentionable node copy and pasteable.
|
||||
1. use RichTextPlugin and reset text format when paste
|
||||
- so I implemented NoFormatPlugin to reset text format when paste
|
||||
2. use PlainTextPlugin and override paste command
|
||||
- PlainTextPlugin only pastes text, so we need to implement custom paste handler.
|
||||
- https://github.com/facebook/lexical/discussions/5112
|
||||
*/}
|
||||
<RichTextPlugin
|
||||
contentEditable={
|
||||
<ContentEditable
|
||||
className="obsidian-default-textarea"
|
||||
style={{
|
||||
background: 'transparent',
|
||||
}}
|
||||
onFocus={onFocus}
|
||||
ref={contentEditableRef}
|
||||
/>
|
||||
}
|
||||
ErrorBoundary={LexicalErrorBoundary}
|
||||
/>
|
||||
<HistoryPlugin />
|
||||
<MentionPlugin searchResultByQuery={searchResultByQuery} />
|
||||
<OnChangePlugin
|
||||
onChange={(editorState) => {
|
||||
onChange?.(editorState.toJSON())
|
||||
}}
|
||||
/>
|
||||
{onEnter && (
|
||||
<OnEnterPlugin
|
||||
onEnter={onEnter}
|
||||
onVaultChat={plugins?.onEnter?.onVaultChat}
|
||||
/>
|
||||
)}
|
||||
<OnMutationPlugin
|
||||
nodeClass={MentionNode}
|
||||
onMutation={(mutations) => {
|
||||
onMentionNodeMutation?.(mutations)
|
||||
}}
|
||||
/>
|
||||
<EditorRefPlugin editorRef={editorRef} />
|
||||
<NoFormatPlugin />
|
||||
<AutoLinkMentionPlugin />
|
||||
<ImagePastePlugin onCreateImageMentionables={onCreateImageMentionables} />
|
||||
<DragDropPaste onCreateImageMentionables={onCreateImageMentionables} />
|
||||
<TemplatePlugin />
|
||||
{plugins?.templatePopover && (
|
||||
<CreateTemplatePopoverPlugin
|
||||
anchorElement={plugins.templatePopover.anchorElement}
|
||||
contentEditableElement={contentEditableRef.current}
|
||||
/>
|
||||
)}
|
||||
</LexicalComposer>
|
||||
)
|
||||
}
|
||||
319
src/components/chat-view/chat-input/MentionableBadge.tsx
Normal file
319
src/components/chat-view/chat-input/MentionableBadge.tsx
Normal file
@@ -0,0 +1,319 @@
|
||||
import { X } from 'lucide-react'
|
||||
import { PropsWithChildren } from 'react'
|
||||
|
||||
import {
|
||||
Mentionable,
|
||||
MentionableBlock,
|
||||
MentionableCurrentFile,
|
||||
MentionableFile,
|
||||
MentionableFolder,
|
||||
MentionableImage,
|
||||
MentionableUrl,
|
||||
MentionableVault,
|
||||
} from '../../../types/mentionable'
|
||||
|
||||
import { getMentionableIcon } from './utils/get-metionable-icon'
|
||||
|
||||
function BadgeBase({
|
||||
children,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: PropsWithChildren<{
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}>) {
|
||||
return (
|
||||
<div
|
||||
className={`infio-chat-user-input-file-badge ${isFocused ? 'infio-chat-user-input-file-badge-focused' : ''}`}
|
||||
onClick={onClick}
|
||||
>
|
||||
{children}
|
||||
<div
|
||||
className="infio-chat-user-input-file-badge-delete"
|
||||
onClick={(evt) => {
|
||||
evt.stopPropagation()
|
||||
onDelete()
|
||||
}}
|
||||
>
|
||||
<X size={10} />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function FileBadge({
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: {
|
||||
mentionable: MentionableFile
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}) {
|
||||
const Icon = getMentionableIcon(mentionable)
|
||||
return (
|
||||
<BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
|
||||
<div className="infio-chat-user-input-file-badge-name">
|
||||
{Icon && (
|
||||
<Icon
|
||||
size={10}
|
||||
className="infio-chat-user-input-file-badge-name-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{mentionable.file.name}</span>
|
||||
</div>
|
||||
</BadgeBase>
|
||||
)
|
||||
}
|
||||
|
||||
function FolderBadge({
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: {
|
||||
mentionable: MentionableFolder
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}) {
|
||||
const Icon = getMentionableIcon(mentionable)
|
||||
return (
|
||||
<BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
|
||||
<div className="infio-chat-user-input-file-badge-name">
|
||||
{Icon && (
|
||||
<Icon
|
||||
size={10}
|
||||
className="infio-chat-user-input-file-badge-name-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{mentionable.folder.name}</span>
|
||||
</div>
|
||||
</BadgeBase>
|
||||
)
|
||||
}
|
||||
|
||||
function VaultBadge({
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: {
|
||||
mentionable: MentionableVault
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}) {
|
||||
const Icon = getMentionableIcon(mentionable)
|
||||
return (
|
||||
<BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
|
||||
{/* TODO: Update style */}
|
||||
<div className="infio-chat-user-input-file-badge-name">
|
||||
{Icon && (
|
||||
<Icon
|
||||
size={10}
|
||||
className="infio-chat-user-input-file-badge-name-icon"
|
||||
/>
|
||||
)}
|
||||
<span>Vault</span>
|
||||
</div>
|
||||
</BadgeBase>
|
||||
)
|
||||
}
|
||||
|
||||
function CurrentFileBadge({
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: {
|
||||
mentionable: MentionableCurrentFile
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}) {
|
||||
const Icon = getMentionableIcon(mentionable)
|
||||
return mentionable.file ? (
|
||||
<BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
|
||||
<div className="infio-chat-user-input-file-badge-name">
|
||||
{Icon && (
|
||||
<Icon
|
||||
size={10}
|
||||
className="infio-chat-user-input-file-badge-name-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{mentionable.file.name}</span>
|
||||
</div>
|
||||
<div className="infio-chat-user-input-file-badge-name-block-suffix">
|
||||
{' (Current File)'}
|
||||
</div>
|
||||
</BadgeBase>
|
||||
) : null
|
||||
}
|
||||
|
||||
function BlockBadge({
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: {
|
||||
mentionable: MentionableBlock
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}) {
|
||||
const Icon = getMentionableIcon(mentionable)
|
||||
return (
|
||||
<BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
|
||||
<div className="infio-chat-user-input-file-badge-name-block-name">
|
||||
{Icon && (
|
||||
<Icon
|
||||
size={10}
|
||||
className="infio-chat-user-input-file-badge-name-block-name-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{mentionable.file.name}</span>
|
||||
</div>
|
||||
<div className="infio-chat-user-input-file-badge-name-block-suffix">
|
||||
{` (${mentionable.startLine}:${mentionable.endLine})`}
|
||||
</div>
|
||||
</BadgeBase>
|
||||
)
|
||||
}
|
||||
|
||||
function UrlBadge({
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: {
|
||||
mentionable: MentionableUrl
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}) {
|
||||
const Icon = getMentionableIcon(mentionable)
|
||||
return (
|
||||
<BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
|
||||
<div className="infio-chat-user-input-file-badge-name">
|
||||
{Icon && (
|
||||
<Icon
|
||||
size={10}
|
||||
className="infio-chat-user-input-file-badge-name-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{mentionable.url}</span>
|
||||
</div>
|
||||
</BadgeBase>
|
||||
)
|
||||
}
|
||||
|
||||
function ImageBadge({
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused,
|
||||
}: {
|
||||
mentionable: MentionableImage
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused: boolean
|
||||
}) {
|
||||
const Icon = getMentionableIcon(mentionable)
|
||||
return (
|
||||
<BadgeBase onDelete={onDelete} onClick={onClick} isFocused={isFocused}>
|
||||
<div className="infio-chat-user-input-file-badge-name">
|
||||
{Icon && (
|
||||
<Icon
|
||||
size={10}
|
||||
className="infio-chat-user-input-file-badge-name-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{mentionable.name}</span>
|
||||
</div>
|
||||
</BadgeBase>
|
||||
)
|
||||
}
|
||||
|
||||
export default function MentionableBadge({
|
||||
mentionable,
|
||||
onDelete,
|
||||
onClick,
|
||||
isFocused = false,
|
||||
}: {
|
||||
mentionable: Mentionable
|
||||
onDelete: () => void
|
||||
onClick: () => void
|
||||
isFocused?: boolean
|
||||
}) {
|
||||
switch (mentionable.type) {
|
||||
case 'file':
|
||||
return (
|
||||
<FileBadge
|
||||
mentionable={mentionable}
|
||||
onDelete={onDelete}
|
||||
onClick={onClick}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
)
|
||||
case 'folder':
|
||||
return (
|
||||
<FolderBadge
|
||||
mentionable={mentionable}
|
||||
onDelete={onDelete}
|
||||
onClick={onClick}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
)
|
||||
case 'vault':
|
||||
return (
|
||||
<VaultBadge
|
||||
mentionable={mentionable}
|
||||
onDelete={onDelete}
|
||||
onClick={onClick}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
)
|
||||
case 'current-file':
|
||||
return (
|
||||
<CurrentFileBadge
|
||||
mentionable={mentionable}
|
||||
onDelete={onDelete}
|
||||
onClick={onClick}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
)
|
||||
case 'block':
|
||||
return (
|
||||
<BlockBadge
|
||||
mentionable={mentionable}
|
||||
onDelete={onDelete}
|
||||
onClick={onClick}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
)
|
||||
case 'url':
|
||||
return (
|
||||
<UrlBadge
|
||||
mentionable={mentionable}
|
||||
onDelete={onDelete}
|
||||
onClick={onClick}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
)
|
||||
case 'image':
|
||||
return (
|
||||
<ImageBadge
|
||||
mentionable={mentionable}
|
||||
onDelete={onDelete}
|
||||
onClick={onClick}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
)
|
||||
}
|
||||
}
|
||||
51
src/components/chat-view/chat-input/ModelSelect.tsx
Normal file
51
src/components/chat-view/chat-input/ModelSelect.tsx
Normal file
@@ -0,0 +1,51 @@
|
||||
import * as DropdownMenu from '@radix-ui/react-dropdown-menu'
|
||||
import { ChevronDown, ChevronUp } from 'lucide-react'
|
||||
import { useState } from 'react'
|
||||
|
||||
import { useSettings } from '../../../contexts/SettingsContext'
|
||||
|
||||
export function ModelSelect() {
|
||||
const { settings, setSettings } = useSettings()
|
||||
const [isOpen, setIsOpen] = useState(false)
|
||||
|
||||
const activeModels = settings.activeModels.filter((model) => model.enabled)
|
||||
|
||||
return (
|
||||
<DropdownMenu.Root open={isOpen} onOpenChange={setIsOpen}>
|
||||
<DropdownMenu.Trigger className="infio-chat-input-model-select">
|
||||
<div className="infio-chat-input-model-select__icon">
|
||||
{isOpen ? <ChevronUp size={12} /> : <ChevronDown size={12} />}
|
||||
</div>
|
||||
<div className="infio-chat-input-model-select__model-name">
|
||||
{
|
||||
activeModels.find(
|
||||
(option) => option.name === settings.chatModelId,
|
||||
)?.name
|
||||
}
|
||||
</div>
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Portal>
|
||||
<DropdownMenu.Content
|
||||
className="infio-popover">
|
||||
<ul>
|
||||
{activeModels.map((model) => (
|
||||
<DropdownMenu.Item
|
||||
key={model.name}
|
||||
onSelect={() => {
|
||||
setSettings({
|
||||
...settings,
|
||||
chatModelId: model.name,
|
||||
})
|
||||
}}
|
||||
asChild
|
||||
>
|
||||
<li>{model.name}</li>
|
||||
</DropdownMenu.Item>
|
||||
))}
|
||||
</ul>
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Portal>
|
||||
</DropdownMenu.Root>
|
||||
)
|
||||
}
|
||||
12
src/components/chat-view/chat-input/SubmitButton.tsx
Normal file
12
src/components/chat-view/chat-input/SubmitButton.tsx
Normal file
@@ -0,0 +1,12 @@
|
||||
import { CornerDownLeftIcon } from 'lucide-react'
|
||||
|
||||
export function SubmitButton({ onClick }: { onClick: () => void }) {
|
||||
return (
|
||||
<button className="infio-chat-user-input-submit-button" onClick={onClick}>
|
||||
<div>submit</div>
|
||||
<div className="infio-chat-user-input-submit-button-icons">
|
||||
<CornerDownLeftIcon size={12} />
|
||||
</div>
|
||||
</button>
|
||||
)
|
||||
}
|
||||
42
src/components/chat-view/chat-input/VaultChatButton.tsx
Normal file
42
src/components/chat-view/chat-input/VaultChatButton.tsx
Normal file
@@ -0,0 +1,42 @@
|
||||
import * as Tooltip from '@radix-ui/react-tooltip'
|
||||
import {
|
||||
ArrowBigUp,
|
||||
ChevronUp,
|
||||
Command,
|
||||
CornerDownLeftIcon,
|
||||
} from 'lucide-react'
|
||||
import { Platform } from 'obsidian'
|
||||
|
||||
export function VaultChatButton({ onClick }: { onClick: () => void }) {
|
||||
return (
|
||||
<>
|
||||
<Tooltip.Provider delayDuration={0}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger asChild>
|
||||
<button
|
||||
className="infio-chat-user-input-vault-button"
|
||||
onClick={onClick}
|
||||
>
|
||||
<div>vault</div>
|
||||
<div className="infio-chat-user-input-vault-button-icons">
|
||||
{Platform.isMacOS ? (
|
||||
<Command size={10} />
|
||||
) : (
|
||||
<ChevronUp size={12} />
|
||||
)}
|
||||
{/* TODO: Replace with a custom icon */}
|
||||
{/* <ArrowBigUp size={12} /> */}
|
||||
<CornerDownLeftIcon size={12} />
|
||||
</div>
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
<Tooltip.Portal>
|
||||
<Tooltip.Content className="infio-tooltip-content" sideOffset={5}>
|
||||
Chat with your entire vault
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Portal>
|
||||
</Tooltip.Root>
|
||||
</Tooltip.Provider>
|
||||
</>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { DRAG_DROP_PASTE } from '@lexical/rich-text'
|
||||
import { COMMAND_PRIORITY_LOW } from 'lexical'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
import { MentionableImage } from '../../../../../types/mentionable'
|
||||
import { fileToMentionableImage } from '../../../../../utils/image'
|
||||
|
||||
export default function DragDropPaste({
|
||||
onCreateImageMentionables,
|
||||
}: {
|
||||
onCreateImageMentionables?: (mentionables: MentionableImage[]) => void
|
||||
}): null {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
return editor.registerCommand(
|
||||
DRAG_DROP_PASTE, // dispatched in RichTextPlugin
|
||||
(files) => {
|
||||
; (async () => {
|
||||
const images = files.filter((file) => file.type.startsWith('image/'))
|
||||
const mentionableImages = await Promise.all(
|
||||
images.map(async (image) => await fileToMentionableImage(image)),
|
||||
)
|
||||
onCreateImageMentionables?.(mentionableImages)
|
||||
})()
|
||||
return true
|
||||
},
|
||||
COMMAND_PRIORITY_LOW,
|
||||
)
|
||||
}, [editor, onCreateImageMentionables])
|
||||
|
||||
return null
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { COMMAND_PRIORITY_LOW, PASTE_COMMAND, PasteCommandType } from 'lexical'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
import { MentionableImage } from '../../../../../types/mentionable'
|
||||
import { fileToMentionableImage } from '../../../../../utils/image'
|
||||
|
||||
export default function ImagePastePlugin({
|
||||
onCreateImageMentionables,
|
||||
}: {
|
||||
onCreateImageMentionables?: (mentionables: MentionableImage[]) => void
|
||||
}) {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
const handlePaste = (event: PasteCommandType) => {
|
||||
const clipboardData =
|
||||
event instanceof ClipboardEvent ? event.clipboardData : null
|
||||
if (!clipboardData) return false
|
||||
|
||||
const images = Array.from(clipboardData.files).filter((file) =>
|
||||
file.type.startsWith('image/'),
|
||||
)
|
||||
if (images.length === 0) return false
|
||||
|
||||
Promise.all(images.map((image) => fileToMentionableImage(image))).then(
|
||||
(mentionableImages) => {
|
||||
onCreateImageMentionables?.(mentionableImages)
|
||||
},
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
return editor.registerCommand(
|
||||
PASTE_COMMAND,
|
||||
handlePaste,
|
||||
COMMAND_PRIORITY_LOW,
|
||||
)
|
||||
}, [editor, onCreateImageMentionables])
|
||||
|
||||
return null
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import {
|
||||
$createTextNode,
|
||||
$getSelection,
|
||||
$isRangeSelection,
|
||||
COMMAND_PRIORITY_LOW,
|
||||
PASTE_COMMAND,
|
||||
PasteCommandType,
|
||||
TextNode,
|
||||
} from 'lexical'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
import { Mentionable, MentionableUrl } from '../../../../../types/mentionable'
|
||||
import {
|
||||
getMentionableName,
|
||||
serializeMentionable,
|
||||
} from '../../../../../utils/mentionable'
|
||||
|
||||
import { $createMentionNode } from './MentionNode'
|
||||
|
||||
const URL_MATCHER =
|
||||
/^((https?:\/\/(www\.)?)|(www\.))[-a-zA-Z0-9@:%._+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_+.~#?&//=]*)$/
|
||||
|
||||
type URLMatch = {
|
||||
index: number
|
||||
length: number
|
||||
text: string
|
||||
url: string
|
||||
}
|
||||
|
||||
function findURLs(text: string): URLMatch[] {
|
||||
const urls: URLMatch[] = []
|
||||
|
||||
let lastIndex = 0
|
||||
for (const word of text.split(' ')) {
|
||||
if (URL_MATCHER.test(word)) {
|
||||
urls.push({
|
||||
index: lastIndex,
|
||||
length: word.length,
|
||||
text: word,
|
||||
url: word.startsWith('http') ? word : `https://${word}`,
|
||||
// attributes: { rel: 'noreferrer', target: '_blank' }, // Optional link attributes
|
||||
})
|
||||
}
|
||||
|
||||
lastIndex += word.length + 1 // +1 for space
|
||||
}
|
||||
|
||||
return urls
|
||||
}
|
||||
|
||||
function $textNodeTransform(node: TextNode) {
|
||||
if (!node.isSimpleText()) {
|
||||
return
|
||||
}
|
||||
|
||||
const text = node.getTextContent()
|
||||
|
||||
// Find only 1st occurrence as transform will be re-run anyway for the rest
|
||||
// because newly inserted nodes are considered to be dirty
|
||||
const urlMatches = findURLs(text)
|
||||
if (urlMatches.length === 0) {
|
||||
return
|
||||
}
|
||||
const urlMatch = urlMatches[0]
|
||||
|
||||
// Get the current selection
|
||||
const selection = $getSelection()
|
||||
|
||||
// Check if the selection is a RangeSelection and the cursor is at the end of the URL
|
||||
if (
|
||||
$isRangeSelection(selection) &&
|
||||
selection.anchor.key === node.getKey() &&
|
||||
selection.focus.key === node.getKey() &&
|
||||
selection.anchor.offset === urlMatch.index + urlMatch.length &&
|
||||
selection.focus.offset === urlMatch.index + urlMatch.length
|
||||
) {
|
||||
// If the cursor is at the end of the URL, don't transform
|
||||
return
|
||||
}
|
||||
|
||||
let targetNode
|
||||
if (urlMatch.index === 0) {
|
||||
// First text chunk within string, splitting into 2 parts
|
||||
;[targetNode] = node.splitText(urlMatch.index + urlMatch.length)
|
||||
} else {
|
||||
// In the middle of a string
|
||||
;[, targetNode] = node.splitText(
|
||||
urlMatch.index,
|
||||
urlMatch.index + urlMatch.length,
|
||||
)
|
||||
}
|
||||
|
||||
const mentionable: MentionableUrl = {
|
||||
type: 'url',
|
||||
url: urlMatch.url,
|
||||
}
|
||||
|
||||
const mentionNode = $createMentionNode(
|
||||
getMentionableName(mentionable),
|
||||
serializeMentionable(mentionable),
|
||||
)
|
||||
|
||||
targetNode.replace(mentionNode)
|
||||
|
||||
const spaceNode = $createTextNode(' ')
|
||||
mentionNode.insertAfter(spaceNode)
|
||||
|
||||
spaceNode.select()
|
||||
}
|
||||
|
||||
function $handlePaste(event: PasteCommandType) {
|
||||
const clipboardData =
|
||||
event instanceof ClipboardEvent ? event.clipboardData : null
|
||||
|
||||
if (!clipboardData) return false
|
||||
|
||||
const text = clipboardData.getData('text/plain')
|
||||
|
||||
const urlMatches = findURLs(text)
|
||||
if (urlMatches.length === 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
const selection = $getSelection()
|
||||
if (!$isRangeSelection(selection)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const nodes = []
|
||||
const addedMentionables: Mentionable[] = []
|
||||
let lastIndex = 0
|
||||
|
||||
urlMatches.forEach((urlMatch) => {
|
||||
// Add text node for unmatched part
|
||||
if (urlMatch.index > lastIndex) {
|
||||
nodes.push($createTextNode(text.slice(lastIndex, urlMatch.index)))
|
||||
}
|
||||
|
||||
const mentionable: MentionableUrl = {
|
||||
type: 'url',
|
||||
url: urlMatch.url,
|
||||
}
|
||||
|
||||
// Add mention node
|
||||
nodes.push(
|
||||
$createMentionNode(urlMatch.text, serializeMentionable(mentionable)),
|
||||
)
|
||||
addedMentionables.push(mentionable)
|
||||
|
||||
lastIndex = urlMatch.index + urlMatch.length
|
||||
|
||||
// Add space node after mention if next character is not space or end of string
|
||||
if (lastIndex >= text.length || text[lastIndex] !== ' ') {
|
||||
nodes.push($createTextNode(' '))
|
||||
}
|
||||
})
|
||||
|
||||
// Add remaining text if any
|
||||
if (lastIndex < text.length) {
|
||||
nodes.push($createTextNode(text.slice(lastIndex)))
|
||||
}
|
||||
|
||||
selection.insertNodes(nodes)
|
||||
return true
|
||||
}
|
||||
|
||||
export default function AutoLinkMentionPlugin() {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
editor.registerCommand(PASTE_COMMAND, $handlePaste, COMMAND_PRIORITY_LOW)
|
||||
|
||||
editor.registerNodeTransform(TextNode, $textNodeTransform)
|
||||
}, [editor])
|
||||
|
||||
return null
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
/**
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license.
|
||||
* Original source: https://github.com/facebook/lexical
|
||||
*
|
||||
* Modified from the original code
|
||||
*/
|
||||
|
||||
import {
|
||||
$applyNodeReplacement,
|
||||
DOMConversionMap,
|
||||
DOMConversionOutput,
|
||||
DOMExportOutput,
|
||||
type EditorConfig,
|
||||
type LexicalNode,
|
||||
type NodeKey,
|
||||
type SerializedTextNode,
|
||||
type Spread,
|
||||
TextNode,
|
||||
} from 'lexical'
|
||||
|
||||
import { SerializedMentionable } from '../../../../../types/mentionable'
|
||||
|
||||
export const MENTION_NODE_TYPE = 'mention'
|
||||
export const MENTION_NODE_ATTRIBUTE = 'data-lexical-mention'
|
||||
export const MENTION_NODE_MENTION_NAME_ATTRIBUTE = 'data-lexical-mention-name'
|
||||
export const MENTION_NODE_MENTIONABLE_ATTRIBUTE = 'data-lexical-mentionable'
|
||||
|
||||
export type SerializedMentionNode = Spread<
|
||||
{
|
||||
mentionName: string
|
||||
mentionable: SerializedMentionable
|
||||
},
|
||||
SerializedTextNode
|
||||
>
|
||||
|
||||
function $convertMentionElement(
|
||||
domNode: HTMLElement,
|
||||
): DOMConversionOutput | null {
|
||||
const textContent = domNode.textContent
|
||||
const mentionName =
|
||||
domNode.getAttribute(MENTION_NODE_MENTION_NAME_ATTRIBUTE) ??
|
||||
domNode.textContent ??
|
||||
''
|
||||
const mentionable = JSON.parse(
|
||||
domNode.getAttribute(MENTION_NODE_MENTIONABLE_ATTRIBUTE) ?? '{}',
|
||||
)
|
||||
|
||||
if (textContent !== null) {
|
||||
const node = $createMentionNode(
|
||||
mentionName,
|
||||
mentionable as SerializedMentionable,
|
||||
)
|
||||
return {
|
||||
node,
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export class MentionNode extends TextNode {
|
||||
__mentionName: string
|
||||
__mentionable: SerializedMentionable
|
||||
|
||||
static getType(): string {
|
||||
return MENTION_NODE_TYPE
|
||||
}
|
||||
|
||||
static clone(node: MentionNode): MentionNode {
|
||||
return new MentionNode(node.__mentionName, node.__mentionable, node.__key)
|
||||
}
|
||||
static importJSON(serializedNode: SerializedMentionNode): MentionNode {
|
||||
const node = $createMentionNode(
|
||||
serializedNode.mentionName,
|
||||
serializedNode.mentionable,
|
||||
)
|
||||
node.setTextContent(serializedNode.text)
|
||||
node.setFormat(serializedNode.format)
|
||||
node.setDetail(serializedNode.detail)
|
||||
node.setMode(serializedNode.mode)
|
||||
node.setStyle(serializedNode.style)
|
||||
return node
|
||||
}
|
||||
|
||||
constructor(
|
||||
mentionName: string,
|
||||
mentionable: SerializedMentionable,
|
||||
key?: NodeKey,
|
||||
) {
|
||||
super(`@${mentionName}`, key)
|
||||
this.__mentionName = mentionName
|
||||
this.__mentionable = mentionable
|
||||
}
|
||||
|
||||
exportJSON(): SerializedMentionNode {
|
||||
return {
|
||||
...super.exportJSON(),
|
||||
mentionName: this.__mentionName,
|
||||
mentionable: this.__mentionable,
|
||||
type: MENTION_NODE_TYPE,
|
||||
version: 1,
|
||||
}
|
||||
}
|
||||
|
||||
createDOM(config: EditorConfig): HTMLElement {
|
||||
const dom = super.createDOM(config)
|
||||
dom.className = MENTION_NODE_TYPE
|
||||
return dom
|
||||
}
|
||||
|
||||
exportDOM(): DOMExportOutput {
|
||||
const element = document.createElement('span')
|
||||
element.setAttribute(MENTION_NODE_ATTRIBUTE, 'true')
|
||||
element.setAttribute(
|
||||
MENTION_NODE_MENTION_NAME_ATTRIBUTE,
|
||||
this.__mentionName,
|
||||
)
|
||||
element.setAttribute(
|
||||
MENTION_NODE_MENTIONABLE_ATTRIBUTE,
|
||||
JSON.stringify(this.__mentionable),
|
||||
)
|
||||
element.textContent = this.__text
|
||||
return { element }
|
||||
}
|
||||
|
||||
static importDOM(): DOMConversionMap | null {
|
||||
return {
|
||||
span: (domNode: HTMLElement) => {
|
||||
if (
|
||||
!domNode.hasAttribute(MENTION_NODE_ATTRIBUTE) ||
|
||||
!domNode.hasAttribute(MENTION_NODE_MENTION_NAME_ATTRIBUTE) ||
|
||||
!domNode.hasAttribute(MENTION_NODE_MENTIONABLE_ATTRIBUTE)
|
||||
) {
|
||||
return null
|
||||
}
|
||||
return {
|
||||
conversion: $convertMentionElement,
|
||||
priority: 1,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
isTextEntity(): true {
|
||||
return true
|
||||
}
|
||||
|
||||
canInsertTextBefore(): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
canInsertTextAfter(): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
getMentionable(): SerializedMentionable {
|
||||
return this.__mentionable
|
||||
}
|
||||
}
|
||||
|
||||
export function $createMentionNode(
|
||||
mentionName: string,
|
||||
mentionable: SerializedMentionable,
|
||||
): MentionNode {
|
||||
const mentionNode = new MentionNode(mentionName, mentionable)
|
||||
mentionNode.setMode('token').toggleDirectionless()
|
||||
return $applyNodeReplacement(mentionNode)
|
||||
}
|
||||
|
||||
export function $isMentionNode(
|
||||
node: LexicalNode | null | undefined,
|
||||
): node is MentionNode {
|
||||
return node instanceof MentionNode
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
/**
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license.
|
||||
* Original source: https://github.com/facebook/lexical
|
||||
*
|
||||
* Modified from the original code
|
||||
*/
|
||||
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { $createTextNode, COMMAND_PRIORITY_NORMAL, TextNode } from 'lexical'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { createPortal } from 'react-dom'
|
||||
|
||||
import { Mentionable } from '../../../../../types/mentionable'
|
||||
import { SearchableMentionable } from '../../../../../utils/fuzzy-search'
|
||||
import {
|
||||
getMentionableName,
|
||||
serializeMentionable,
|
||||
} from '../../../../../utils/mentionable'
|
||||
import { getMentionableIcon } from '../../utils/get-metionable-icon'
|
||||
import { MenuOption, MenuTextMatch } from '../shared/LexicalMenu'
|
||||
import {
|
||||
LexicalTypeaheadMenuPlugin,
|
||||
useBasicTypeaheadTriggerMatch,
|
||||
} from '../typeahead-menu/LexicalTypeaheadMenuPlugin'
|
||||
|
||||
import { $createMentionNode } from './MentionNode'
|
||||
|
||||
const PUNCTUATION =
|
||||
'\\.,\\+\\*\\?\\$\\@\\|#{}\\(\\)\\^\\-\\[\\]\\\\/!%\'"~=<>_:;'
|
||||
const NAME = '\\b[A-Z][^\\s' + PUNCTUATION + ']'
|
||||
|
||||
const DocumentMentionsRegex = {
|
||||
NAME,
|
||||
PUNCTUATION,
|
||||
}
|
||||
|
||||
const PUNC = DocumentMentionsRegex.PUNCTUATION
|
||||
|
||||
const TRIGGERS = ['@'].join('')
|
||||
|
||||
// Chars we expect to see in a mention (non-space, non-punctuation).
|
||||
const VALID_CHARS = '[^' + TRIGGERS + PUNC + '\\s]'
|
||||
|
||||
// Non-standard series of chars. Each series must be preceded and followed by
|
||||
// a valid char.
|
||||
const VALID_JOINS =
|
||||
'(?:' +
|
||||
'\\.[ |$]|' + // E.g. "r. " in "Mr. Smith"
|
||||
' |' + // E.g. " " in "Josh Duck"
|
||||
'[' +
|
||||
PUNC +
|
||||
']|' + // E.g. "-' in "Salier-Hellendag"
|
||||
')'
|
||||
|
||||
const LENGTH_LIMIT = 75
|
||||
|
||||
const AtSignMentionsRegex = new RegExp(
|
||||
`(^|\\s|\\()([${TRIGGERS}]((?:${VALID_CHARS}${VALID_JOINS}){0,${LENGTH_LIMIT}}))$`,
|
||||
)
|
||||
|
||||
// 50 is the longest alias length limit.
|
||||
const ALIAS_LENGTH_LIMIT = 50
|
||||
|
||||
// Regex used to match alias.
|
||||
const AtSignMentionsRegexAliasRegex = new RegExp(
|
||||
`(^|\\s|\\()([${TRIGGERS}]((?:${VALID_CHARS}){0,${ALIAS_LENGTH_LIMIT}}))$`,
|
||||
)
|
||||
|
||||
// At most, 20 suggestions are shown in the popup.
|
||||
const SUGGESTION_LIST_LENGTH_LIMIT = 20
|
||||
|
||||
function checkForAtSignMentions(
|
||||
text: string,
|
||||
minMatchLength: number,
|
||||
): MenuTextMatch | null {
|
||||
let match = AtSignMentionsRegex.exec(text)
|
||||
|
||||
if (match === null) {
|
||||
match = AtSignMentionsRegexAliasRegex.exec(text)
|
||||
}
|
||||
if (match !== null) {
|
||||
// The strategy ignores leading whitespace but we need to know it's
|
||||
// length to add it to the leadOffset
|
||||
const maybeLeadingWhitespace = match[1]
|
||||
|
||||
const matchingString = match[3]
|
||||
if (matchingString.length >= minMatchLength) {
|
||||
return {
|
||||
leadOffset: match.index + maybeLeadingWhitespace.length,
|
||||
matchingString,
|
||||
replaceableString: match[2],
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function getPossibleQueryMatch(text: string): MenuTextMatch | null {
|
||||
return checkForAtSignMentions(text, 0)
|
||||
}
|
||||
|
||||
class MentionTypeaheadOption extends MenuOption {
|
||||
name: string
|
||||
mentionable: Mentionable
|
||||
icon: React.ReactNode
|
||||
|
||||
constructor(result: SearchableMentionable) {
|
||||
switch (result.type) {
|
||||
case 'file':
|
||||
super(result.file.path)
|
||||
this.name = result.file.name
|
||||
this.mentionable = result
|
||||
break
|
||||
case 'folder':
|
||||
super(result.folder.path)
|
||||
this.name = result.folder.name
|
||||
this.mentionable = result
|
||||
break
|
||||
case 'vault':
|
||||
super('vault')
|
||||
this.name = 'Vault'
|
||||
this.mentionable = result
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function MentionsTypeaheadMenuItem({
|
||||
index,
|
||||
isSelected,
|
||||
onClick,
|
||||
onMouseEnter,
|
||||
option,
|
||||
}: {
|
||||
index: number
|
||||
isSelected: boolean
|
||||
onClick: () => void
|
||||
onMouseEnter: () => void
|
||||
option: MentionTypeaheadOption
|
||||
}) {
|
||||
let className = 'item'
|
||||
if (isSelected) {
|
||||
className += ' selected'
|
||||
}
|
||||
|
||||
const Icon = getMentionableIcon(option.mentionable)
|
||||
return (
|
||||
<li
|
||||
key={option.key}
|
||||
tabIndex={-1}
|
||||
className={className}
|
||||
ref={(el) => option.setRefElement(el)}
|
||||
role="option"
|
||||
aria-selected={isSelected}
|
||||
id={`typeahead-item-${index}`}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onClick={onClick}
|
||||
>
|
||||
{Icon && <Icon size={14} className="infio-popover-item-icon" />}
|
||||
<span className="text">{option.name}</span>
|
||||
</li>
|
||||
)
|
||||
}
|
||||
|
||||
export default function NewMentionsPlugin({
|
||||
searchResultByQuery,
|
||||
}: {
|
||||
searchResultByQuery: (query: string) => SearchableMentionable[]
|
||||
}): JSX.Element | null {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
const [queryString, setQueryString] = useState<string | null>(null)
|
||||
|
||||
const results = useMemo(() => {
|
||||
if (queryString == null) return []
|
||||
return searchResultByQuery(queryString)
|
||||
}, [queryString, searchResultByQuery])
|
||||
|
||||
const checkForSlashTriggerMatch = useBasicTypeaheadTriggerMatch('/', {
|
||||
minLength: 0,
|
||||
})
|
||||
|
||||
const options = useMemo(
|
||||
() =>
|
||||
results
|
||||
.map((result) => new MentionTypeaheadOption(result))
|
||||
.slice(0, SUGGESTION_LIST_LENGTH_LIMIT),
|
||||
[results],
|
||||
)
|
||||
|
||||
const onSelectOption = useCallback(
|
||||
(
|
||||
selectedOption: MentionTypeaheadOption,
|
||||
nodeToReplace: TextNode | null,
|
||||
closeMenu: () => void,
|
||||
) => {
|
||||
editor.update(() => {
|
||||
const mentionNode = $createMentionNode(
|
||||
getMentionableName(selectedOption.mentionable),
|
||||
serializeMentionable(selectedOption.mentionable),
|
||||
)
|
||||
if (nodeToReplace) {
|
||||
nodeToReplace.replace(mentionNode)
|
||||
}
|
||||
|
||||
const spaceNode = $createTextNode(' ')
|
||||
mentionNode.insertAfter(spaceNode)
|
||||
|
||||
spaceNode.select()
|
||||
closeMenu()
|
||||
})
|
||||
},
|
||||
[editor],
|
||||
)
|
||||
|
||||
const checkForMentionMatch = useCallback(
|
||||
(text: string) => {
|
||||
const slashMatch = checkForSlashTriggerMatch(text, editor)
|
||||
|
||||
if (slashMatch !== null) {
|
||||
return null
|
||||
}
|
||||
return getPossibleQueryMatch(text)
|
||||
},
|
||||
[checkForSlashTriggerMatch, editor],
|
||||
)
|
||||
|
||||
return (
|
||||
<LexicalTypeaheadMenuPlugin<MentionTypeaheadOption>
|
||||
onQueryChange={setQueryString}
|
||||
onSelectOption={onSelectOption}
|
||||
triggerFn={checkForMentionMatch}
|
||||
options={options}
|
||||
commandPriority={COMMAND_PRIORITY_NORMAL}
|
||||
menuRenderFn={(
|
||||
anchorElementRef,
|
||||
{ selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
|
||||
) =>
|
||||
anchorElementRef.current && results.length
|
||||
? createPortal(
|
||||
<div
|
||||
className="infio-popover"
|
||||
style={{
|
||||
position: 'fixed',
|
||||
}}
|
||||
>
|
||||
<ul>
|
||||
{options.map((option, i: number) => (
|
||||
<MentionsTypeaheadMenuItem
|
||||
index={i}
|
||||
isSelected={selectedIndex === i}
|
||||
onClick={() => {
|
||||
setHighlightedIndex(i)
|
||||
selectOptionAndCleanUp(option)
|
||||
}}
|
||||
onMouseEnter={() => {
|
||||
setHighlightedIndex(i)
|
||||
}}
|
||||
key={option.key}
|
||||
option={option}
|
||||
/>
|
||||
))}
|
||||
</ul>
|
||||
</div>,
|
||||
anchorElementRef.current,
|
||||
)
|
||||
: null
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { TextNode } from 'lexical'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
export default function NoFormatPlugin() {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
editor.registerNodeTransform(TextNode, (node) => {
|
||||
if (node.getFormat() !== 0) {
|
||||
node.setFormat(0)
|
||||
}
|
||||
})
|
||||
}, [editor])
|
||||
|
||||
return null
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { COMMAND_PRIORITY_LOW, KEY_ENTER_COMMAND } from 'lexical'
|
||||
import { Platform } from 'obsidian'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
export default function OnEnterPlugin({
|
||||
onEnter,
|
||||
onVaultChat,
|
||||
}: {
|
||||
onEnter: (evt: KeyboardEvent) => void
|
||||
onVaultChat?: () => void
|
||||
}) {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
const removeListener = editor.registerCommand(
|
||||
KEY_ENTER_COMMAND,
|
||||
(evt: KeyboardEvent) => {
|
||||
console.log('onEnter', evt)
|
||||
if (
|
||||
onVaultChat &&
|
||||
(Platform.isMacOS ? evt.metaKey : evt.ctrlKey)
|
||||
) {
|
||||
evt.preventDefault()
|
||||
evt.stopPropagation()
|
||||
onVaultChat()
|
||||
return true
|
||||
}
|
||||
if (evt.shiftKey) {
|
||||
return false
|
||||
}
|
||||
evt.preventDefault()
|
||||
evt.stopPropagation()
|
||||
onEnter(evt)
|
||||
return true
|
||||
},
|
||||
COMMAND_PRIORITY_LOW,
|
||||
)
|
||||
|
||||
return () => {
|
||||
removeListener()
|
||||
}
|
||||
}, [editor, onEnter, onVaultChat])
|
||||
|
||||
return null
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { Klass, LexicalNode, NodeKey, NodeMutation } from 'lexical'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
export type NodeMutations<T> = Map<NodeKey, { mutation: NodeMutation; node: T }>
|
||||
|
||||
export default function OnMutationPlugin<T extends LexicalNode>({
|
||||
nodeClass,
|
||||
onMutation,
|
||||
}: {
|
||||
nodeClass: Klass<T>
|
||||
onMutation: (mutations: NodeMutations<T>) => void
|
||||
}) {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
const removeListener = editor.registerMutationListener(
|
||||
nodeClass,
|
||||
(mutatedNodes, payload) => {
|
||||
const editorState = editor.getEditorState()
|
||||
|
||||
const mutations = new Map<
|
||||
NodeKey,
|
||||
{ mutation: NodeMutation; node: T }
|
||||
>()
|
||||
for (const [key, mutation] of mutatedNodes) {
|
||||
mutations.set(key, {
|
||||
mutation,
|
||||
node:
|
||||
mutation === 'destroyed'
|
||||
? (payload.prevEditorState._nodeMap.get(key) as T)
|
||||
: (editorState._nodeMap.get(key) as T),
|
||||
})
|
||||
}
|
||||
|
||||
onMutation(mutations)
|
||||
},
|
||||
)
|
||||
|
||||
return () => {
|
||||
removeListener()
|
||||
}
|
||||
}, [editor, nodeClass, onMutation])
|
||||
|
||||
return null
|
||||
}
|
||||
@@ -0,0 +1,597 @@
|
||||
/**
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license.
|
||||
* Original source: https://github.com/facebook/lexical
|
||||
*
|
||||
* Modified from the original code
|
||||
* - Added custom positioning logic for menu placement
|
||||
*/
|
||||
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { mergeRegister } from '@lexical/utils'
|
||||
import {
|
||||
$getSelection,
|
||||
$isRangeSelection,
|
||||
COMMAND_PRIORITY_LOW,
|
||||
CommandListenerPriority,
|
||||
KEY_ARROW_DOWN_COMMAND,
|
||||
KEY_ARROW_UP_COMMAND,
|
||||
KEY_ENTER_COMMAND,
|
||||
KEY_ESCAPE_COMMAND,
|
||||
KEY_TAB_COMMAND,
|
||||
LexicalCommand,
|
||||
LexicalEditor,
|
||||
TextNode,
|
||||
createCommand,
|
||||
} from 'lexical'
|
||||
import {
|
||||
MutableRefObject,
|
||||
ReactPortal,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
|
||||
export type MenuTextMatch = {
|
||||
leadOffset: number
|
||||
matchingString: string
|
||||
replaceableString: string
|
||||
}
|
||||
|
||||
export type MenuResolution = {
|
||||
match?: MenuTextMatch
|
||||
getRect: () => DOMRect
|
||||
}
|
||||
|
||||
export const PUNCTUATION =
|
||||
'\\.,\\+\\*\\?\\$\\@\\|#{}\\(\\)\\^\\-\\[\\]\\\\/!%\'"~=<>_:;'
|
||||
|
||||
export class MenuOption {
|
||||
key: string
|
||||
ref?: MutableRefObject<HTMLElement | null>
|
||||
|
||||
constructor(key: string) {
|
||||
this.key = key
|
||||
this.ref = { current: null }
|
||||
this.setRefElement = this.setRefElement.bind(this)
|
||||
}
|
||||
|
||||
setRefElement(element: HTMLElement | null) {
|
||||
this.ref = { current: element }
|
||||
}
|
||||
}
|
||||
|
||||
export type MenuRenderFn<TOption extends MenuOption> = (
|
||||
anchorElementRef: MutableRefObject<HTMLElement | null>,
|
||||
itemProps: {
|
||||
selectedIndex: number | null
|
||||
selectOptionAndCleanUp: (option: TOption) => void
|
||||
setHighlightedIndex: (index: number) => void
|
||||
options: TOption[]
|
||||
},
|
||||
matchingString: string | null,
|
||||
) => ReactPortal | JSX.Element | null
|
||||
|
||||
const scrollIntoViewIfNeeded = (target: HTMLElement) => {
|
||||
const typeaheadContainerNode = document.getElementById('typeahead-menu')
|
||||
if (!typeaheadContainerNode) {
|
||||
return
|
||||
}
|
||||
|
||||
const typeaheadRect = typeaheadContainerNode.getBoundingClientRect()
|
||||
|
||||
if (typeaheadRect.top + typeaheadRect.height > window.innerHeight) {
|
||||
typeaheadContainerNode.scrollIntoView({
|
||||
block: 'center',
|
||||
})
|
||||
}
|
||||
|
||||
if (typeaheadRect.top < 0) {
|
||||
typeaheadContainerNode.scrollIntoView({
|
||||
block: 'center',
|
||||
})
|
||||
}
|
||||
|
||||
target.scrollIntoView({ block: 'nearest' })
|
||||
}
|
||||
|
||||
/**
|
||||
* Walk backwards along user input and forward through entity title to try
|
||||
* and replace more of the user's text with entity.
|
||||
*/
|
||||
function getFullMatchOffset(
|
||||
documentText: string,
|
||||
entryText: string,
|
||||
offset: number,
|
||||
): number {
|
||||
let triggerOffset = offset
|
||||
for (let i = triggerOffset; i <= entryText.length; i++) {
|
||||
if (documentText.substr(-i) === entryText.substr(0, i)) {
|
||||
triggerOffset = i
|
||||
}
|
||||
}
|
||||
return triggerOffset
|
||||
}
|
||||
|
||||
/**
|
||||
* Split Lexical TextNode and return a new TextNode only containing matched text.
|
||||
* Common use cases include: removing the node, replacing with a new node.
|
||||
*/
|
||||
function $splitNodeContainingQuery(match: MenuTextMatch): TextNode | null {
|
||||
const selection = $getSelection()
|
||||
if (!$isRangeSelection(selection) || !selection.isCollapsed()) {
|
||||
return null
|
||||
}
|
||||
const anchor = selection.anchor
|
||||
if (anchor.type !== 'text') {
|
||||
return null
|
||||
}
|
||||
const anchorNode = anchor.getNode()
|
||||
if (!anchorNode.isSimpleText()) {
|
||||
return null
|
||||
}
|
||||
const selectionOffset = anchor.offset
|
||||
const textContent = anchorNode.getTextContent().slice(0, selectionOffset)
|
||||
const characterOffset = match.replaceableString.length
|
||||
const queryOffset = getFullMatchOffset(
|
||||
textContent,
|
||||
match.matchingString,
|
||||
characterOffset,
|
||||
)
|
||||
const startOffset = selectionOffset - queryOffset
|
||||
if (startOffset < 0) {
|
||||
return null
|
||||
}
|
||||
let newNode
|
||||
if (startOffset === 0) {
|
||||
;[newNode] = anchorNode.splitText(selectionOffset)
|
||||
} else {
|
||||
;[, newNode] = anchorNode.splitText(startOffset, selectionOffset)
|
||||
}
|
||||
|
||||
return newNode
|
||||
}
|
||||
|
||||
// Got from https://stackoverflow.com/a/42543908/2013580
|
||||
export function getScrollParent(
|
||||
element: HTMLElement,
|
||||
includeHidden: boolean,
|
||||
): HTMLElement | HTMLBodyElement {
|
||||
let style = getComputedStyle(element)
|
||||
const excludeStaticParent = style.position === 'absolute'
|
||||
const overflowRegex = includeHidden ? /(auto|scroll|hidden)/ : /(auto|scroll)/
|
||||
if (style.position === 'fixed') {
|
||||
return document.body
|
||||
}
|
||||
for (
|
||||
let parent: HTMLElement | null = element;
|
||||
(parent = parent.parentElement);
|
||||
|
||||
) {
|
||||
style = getComputedStyle(parent)
|
||||
if (excludeStaticParent && style.position === 'static') {
|
||||
continue
|
||||
}
|
||||
if (
|
||||
overflowRegex.test(style.overflow + style.overflowY + style.overflowX)
|
||||
) {
|
||||
return parent
|
||||
}
|
||||
}
|
||||
return document.body
|
||||
}
|
||||
|
||||
function isTriggerVisibleInNearestScrollContainer(
|
||||
targetElement: HTMLElement,
|
||||
containerElement: HTMLElement,
|
||||
): boolean {
|
||||
const tRect = targetElement.getBoundingClientRect()
|
||||
const cRect = containerElement.getBoundingClientRect()
|
||||
return tRect.top > cRect.top && tRect.top < cRect.bottom
|
||||
}
|
||||
|
||||
// Reposition the menu on scroll, window resize, and element resize.
|
||||
export function useDynamicPositioning(
|
||||
resolution: MenuResolution | null,
|
||||
targetElement: HTMLElement | null,
|
||||
onReposition: () => void,
|
||||
onVisibilityChange?: (isInView: boolean) => void,
|
||||
) {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
useEffect(() => {
|
||||
if (targetElement != null && resolution != null) {
|
||||
const rootElement = editor.getRootElement()
|
||||
const rootScrollParent =
|
||||
rootElement != null
|
||||
? getScrollParent(rootElement, false)
|
||||
: document.body
|
||||
let ticking = false
|
||||
let previousIsInView = isTriggerVisibleInNearestScrollContainer(
|
||||
targetElement,
|
||||
rootScrollParent,
|
||||
)
|
||||
const handleScroll = function () {
|
||||
if (!ticking) {
|
||||
window.requestAnimationFrame(function () {
|
||||
onReposition()
|
||||
ticking = false
|
||||
})
|
||||
ticking = true
|
||||
}
|
||||
const isInView = isTriggerVisibleInNearestScrollContainer(
|
||||
targetElement,
|
||||
rootScrollParent,
|
||||
)
|
||||
if (isInView !== previousIsInView) {
|
||||
previousIsInView = isInView
|
||||
if (onVisibilityChange != null) {
|
||||
onVisibilityChange(isInView)
|
||||
}
|
||||
}
|
||||
}
|
||||
const resizeObserver = new ResizeObserver(onReposition)
|
||||
window.addEventListener('resize', onReposition)
|
||||
document.addEventListener('scroll', handleScroll, {
|
||||
capture: true,
|
||||
passive: true,
|
||||
})
|
||||
resizeObserver.observe(targetElement)
|
||||
return () => {
|
||||
resizeObserver.unobserve(targetElement)
|
||||
window.removeEventListener('resize', onReposition)
|
||||
document.removeEventListener('scroll', handleScroll, true)
|
||||
}
|
||||
}
|
||||
}, [targetElement, editor, onVisibilityChange, onReposition, resolution])
|
||||
}
|
||||
|
||||
export const SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND: LexicalCommand<{
|
||||
index: number
|
||||
option: MenuOption
|
||||
}> = createCommand('SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND')
|
||||
|
||||
export function LexicalMenu<TOption extends MenuOption>({
|
||||
close,
|
||||
editor,
|
||||
anchorElementRef,
|
||||
resolution,
|
||||
options,
|
||||
menuRenderFn,
|
||||
onSelectOption,
|
||||
shouldSplitNodeWithQuery = false,
|
||||
commandPriority = COMMAND_PRIORITY_LOW,
|
||||
}: {
|
||||
close: () => void
|
||||
editor: LexicalEditor
|
||||
anchorElementRef: MutableRefObject<HTMLElement>
|
||||
resolution: MenuResolution
|
||||
options: TOption[]
|
||||
shouldSplitNodeWithQuery?: boolean
|
||||
menuRenderFn: MenuRenderFn<TOption>
|
||||
onSelectOption: (
|
||||
option: TOption,
|
||||
textNodeContainingQuery: TextNode | null,
|
||||
closeMenu: () => void,
|
||||
matchingString: string,
|
||||
) => void
|
||||
commandPriority?: CommandListenerPriority
|
||||
}): JSX.Element | null {
|
||||
const [selectedIndex, setHighlightedIndex] = useState<null | number>(null)
|
||||
|
||||
const matchingString = resolution.match?.matchingString
|
||||
|
||||
useEffect(() => {
|
||||
setHighlightedIndex(0)
|
||||
}, [matchingString])
|
||||
|
||||
const selectOptionAndCleanUp = useCallback(
|
||||
(selectedEntry: TOption) => {
|
||||
editor.update(() => {
|
||||
const textNodeContainingQuery =
|
||||
resolution.match != null && shouldSplitNodeWithQuery
|
||||
? $splitNodeContainingQuery(resolution.match)
|
||||
: null
|
||||
|
||||
onSelectOption(
|
||||
selectedEntry,
|
||||
textNodeContainingQuery,
|
||||
close,
|
||||
resolution.match ? resolution.match.matchingString : '',
|
||||
)
|
||||
})
|
||||
},
|
||||
[editor, shouldSplitNodeWithQuery, resolution.match, onSelectOption, close],
|
||||
)
|
||||
|
||||
const updateSelectedIndex = useCallback(
|
||||
(index: number) => {
|
||||
const rootElem = editor.getRootElement()
|
||||
if (rootElem !== null) {
|
||||
rootElem.setAttribute(
|
||||
'aria-activedescendant',
|
||||
`typeahead-item-${index}`,
|
||||
)
|
||||
setHighlightedIndex(index)
|
||||
}
|
||||
},
|
||||
[editor],
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
const rootElem = editor.getRootElement()
|
||||
if (rootElem !== null) {
|
||||
rootElem.removeAttribute('aria-activedescendant')
|
||||
}
|
||||
}
|
||||
}, [editor])
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (options === null) {
|
||||
setHighlightedIndex(null)
|
||||
} else if (selectedIndex === null) {
|
||||
updateSelectedIndex(0)
|
||||
}
|
||||
}, [options, selectedIndex, updateSelectedIndex])
|
||||
|
||||
useEffect(() => {
|
||||
return mergeRegister(
|
||||
editor.registerCommand(
|
||||
SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND,
|
||||
({ option }) => {
|
||||
if (option.ref?.current != null) {
|
||||
scrollIntoViewIfNeeded(option.ref.current)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
},
|
||||
commandPriority,
|
||||
),
|
||||
)
|
||||
}, [editor, updateSelectedIndex, commandPriority])
|
||||
|
||||
useEffect(() => {
|
||||
return mergeRegister(
|
||||
editor.registerCommand<KeyboardEvent>(
|
||||
KEY_ARROW_DOWN_COMMAND,
|
||||
(payload) => {
|
||||
const event = payload
|
||||
if (options?.length && selectedIndex !== null) {
|
||||
const newSelectedIndex =
|
||||
selectedIndex !== options.length - 1 ? selectedIndex + 1 : 0
|
||||
updateSelectedIndex(newSelectedIndex)
|
||||
const option = options[newSelectedIndex]
|
||||
if (option.ref?.current != null) {
|
||||
editor.dispatchCommand(
|
||||
SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND,
|
||||
{
|
||||
index: newSelectedIndex,
|
||||
option,
|
||||
},
|
||||
)
|
||||
}
|
||||
event.preventDefault()
|
||||
event.stopImmediatePropagation()
|
||||
}
|
||||
return true
|
||||
},
|
||||
commandPriority,
|
||||
),
|
||||
editor.registerCommand<KeyboardEvent>(
|
||||
KEY_ARROW_UP_COMMAND,
|
||||
(payload) => {
|
||||
const event = payload
|
||||
if (options?.length && selectedIndex !== null) {
|
||||
const newSelectedIndex =
|
||||
selectedIndex !== 0 ? selectedIndex - 1 : options.length - 1
|
||||
updateSelectedIndex(newSelectedIndex)
|
||||
const option = options[newSelectedIndex]
|
||||
if (option.ref?.current != null) {
|
||||
scrollIntoViewIfNeeded(option.ref.current)
|
||||
}
|
||||
event.preventDefault()
|
||||
event.stopImmediatePropagation()
|
||||
}
|
||||
return true
|
||||
},
|
||||
commandPriority,
|
||||
),
|
||||
editor.registerCommand<KeyboardEvent>(
|
||||
KEY_ESCAPE_COMMAND,
|
||||
(payload) => {
|
||||
const event = payload
|
||||
event.preventDefault()
|
||||
event.stopImmediatePropagation()
|
||||
close()
|
||||
return true
|
||||
},
|
||||
commandPriority,
|
||||
),
|
||||
editor.registerCommand<KeyboardEvent>(
|
||||
KEY_TAB_COMMAND,
|
||||
(payload) => {
|
||||
const event = payload
|
||||
if (
|
||||
options === null ||
|
||||
selectedIndex === null ||
|
||||
options[selectedIndex] == null
|
||||
) {
|
||||
return false
|
||||
}
|
||||
event.preventDefault()
|
||||
event.stopImmediatePropagation()
|
||||
selectOptionAndCleanUp(options[selectedIndex])
|
||||
return true
|
||||
},
|
||||
commandPriority,
|
||||
),
|
||||
editor.registerCommand(
|
||||
KEY_ENTER_COMMAND,
|
||||
(event: KeyboardEvent | null) => {
|
||||
if (
|
||||
options === null ||
|
||||
selectedIndex === null ||
|
||||
options[selectedIndex] == null
|
||||
) {
|
||||
return false
|
||||
}
|
||||
if (event !== null) {
|
||||
event.preventDefault()
|
||||
event.stopImmediatePropagation()
|
||||
}
|
||||
selectOptionAndCleanUp(options[selectedIndex])
|
||||
return true
|
||||
},
|
||||
commandPriority,
|
||||
),
|
||||
)
|
||||
}, [
|
||||
selectOptionAndCleanUp,
|
||||
close,
|
||||
editor,
|
||||
options,
|
||||
selectedIndex,
|
||||
updateSelectedIndex,
|
||||
commandPriority,
|
||||
])
|
||||
|
||||
const listItemProps = useMemo(
|
||||
() => ({
|
||||
options,
|
||||
selectOptionAndCleanUp,
|
||||
selectedIndex,
|
||||
setHighlightedIndex,
|
||||
}),
|
||||
[selectOptionAndCleanUp, selectedIndex, options],
|
||||
)
|
||||
|
||||
return menuRenderFn(
|
||||
anchorElementRef,
|
||||
listItemProps,
|
||||
resolution.match ? resolution.match.matchingString : '',
|
||||
)
|
||||
}
|
||||
|
||||
export function useMenuAnchorRef(
|
||||
resolution: MenuResolution | null,
|
||||
setResolution: (r: MenuResolution | null) => void,
|
||||
className?: string,
|
||||
parent: HTMLElement = document.body,
|
||||
shouldIncludePageYOffset__EXPERIMENTAL = true,
|
||||
): MutableRefObject<HTMLElement> {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
const anchorElementRef = useRef<HTMLElement>(document.createElement('div'))
|
||||
const positionMenu = useCallback(() => {
|
||||
anchorElementRef.current.style.top = anchorElementRef.current.style.bottom
|
||||
const rootElement = editor.getRootElement()
|
||||
const containerDiv = anchorElementRef.current
|
||||
|
||||
const menuEle = containerDiv.firstChild as HTMLElement
|
||||
if (rootElement !== null && resolution !== null) {
|
||||
const { left, top, width, height } = resolution.getRect()
|
||||
const anchorHeight = anchorElementRef.current.offsetHeight // use to position under anchor
|
||||
containerDiv.style.top = `${top +
|
||||
anchorHeight +
|
||||
3 +
|
||||
(shouldIncludePageYOffset__EXPERIMENTAL ? window.pageYOffset : 0)
|
||||
}px`
|
||||
containerDiv.style.left = `${left + window.pageXOffset}px`
|
||||
containerDiv.style.height = `${height}px`
|
||||
containerDiv.style.width = `${width}px`
|
||||
if (menuEle !== null) {
|
||||
menuEle.style.top = `${top}`
|
||||
const menuRect = menuEle.getBoundingClientRect()
|
||||
const menuHeight = menuRect.height
|
||||
const menuWidth = menuRect.width
|
||||
|
||||
const rootElementRect = rootElement.getBoundingClientRect()
|
||||
|
||||
if (left + menuWidth > rootElementRect.right) {
|
||||
containerDiv.style.left = `${rootElementRect.right - menuWidth + window.pageXOffset
|
||||
}px`
|
||||
}
|
||||
if (
|
||||
// If it exceeds the window height, it should always be displayed above, but the original code checks if it doesn't exceed the editor's top as well. So I modified it.
|
||||
// (top + menuHeight > window.innerHeight ||
|
||||
// top + menuHeight > rootElementRect.bottom) &&
|
||||
// top - rootElementRect.top > menuHeight + height
|
||||
top + menuHeight >
|
||||
window.innerHeight
|
||||
) {
|
||||
containerDiv.style.top = `${top -
|
||||
menuHeight -
|
||||
height +
|
||||
(shouldIncludePageYOffset__EXPERIMENTAL ? window.pageYOffset : 0)
|
||||
}px`
|
||||
}
|
||||
}
|
||||
|
||||
if (!containerDiv.isConnected) {
|
||||
if (className != null) {
|
||||
containerDiv.className = className
|
||||
}
|
||||
containerDiv.setAttribute('aria-label', 'Typeahead menu')
|
||||
containerDiv.setAttribute('id', 'typeahead-menu')
|
||||
containerDiv.setAttribute('role', 'listbox')
|
||||
containerDiv.style.display = 'block'
|
||||
containerDiv.style.position = 'absolute'
|
||||
parent.append(containerDiv)
|
||||
}
|
||||
anchorElementRef.current = containerDiv
|
||||
rootElement.setAttribute('aria-controls', 'typeahead-menu')
|
||||
}
|
||||
}, [
|
||||
editor,
|
||||
resolution,
|
||||
shouldIncludePageYOffset__EXPERIMENTAL,
|
||||
className,
|
||||
parent,
|
||||
])
|
||||
|
||||
useEffect(() => {
|
||||
const rootElement = editor.getRootElement()
|
||||
if (resolution !== null) {
|
||||
positionMenu()
|
||||
return () => {
|
||||
if (rootElement !== null) {
|
||||
rootElement.removeAttribute('aria-controls')
|
||||
}
|
||||
|
||||
const containerDiv = anchorElementRef.current
|
||||
if (containerDiv?.isConnected) {
|
||||
containerDiv.remove()
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [editor, positionMenu, resolution])
|
||||
|
||||
const onVisibilityChange = useCallback(
|
||||
(isInView: boolean) => {
|
||||
if (resolution !== null) {
|
||||
if (!isInView) {
|
||||
setResolution(null)
|
||||
}
|
||||
}
|
||||
},
|
||||
[resolution, setResolution],
|
||||
)
|
||||
|
||||
useDynamicPositioning(
|
||||
resolution,
|
||||
anchorElementRef.current,
|
||||
positionMenu,
|
||||
onVisibilityChange,
|
||||
)
|
||||
|
||||
return anchorElementRef
|
||||
}
|
||||
|
||||
export type TriggerFn = (
|
||||
text: string,
|
||||
editor: LexicalEditor,
|
||||
) => MenuTextMatch | null
|
||||
@@ -0,0 +1,146 @@
|
||||
import { $generateJSONFromSelectedNodes } from '@lexical/clipboard'
|
||||
import { BaseSerializedNode } from '@lexical/clipboard/clipboard'
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import * as Dialog from '@radix-ui/react-dialog'
|
||||
import {
|
||||
$getSelection,
|
||||
COMMAND_PRIORITY_LOW,
|
||||
SELECTION_CHANGE_COMMAND,
|
||||
} from 'lexical'
|
||||
import { CSSProperties, useCallback, useEffect, useRef, useState } from 'react'
|
||||
|
||||
import CreateTemplateDialogContent from '../../../CreateTemplateDialog'
|
||||
|
||||
export default function CreateTemplatePopoverPlugin({
|
||||
anchorElement,
|
||||
contentEditableElement,
|
||||
}: {
|
||||
anchorElement: HTMLElement | null
|
||||
contentEditableElement: HTMLElement | null
|
||||
}): JSX.Element | null {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
const [popoverStyle, setPopoverStyle] = useState<CSSProperties | null>(null)
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false)
|
||||
const [isDialogOpen, setIsDialogOpen] = useState(false)
|
||||
const [selectedSerializedNodes, setSelectedSerializedNodes] = useState<
|
||||
BaseSerializedNode[] | null
|
||||
>(null)
|
||||
|
||||
const popoverRef = useRef<HTMLButtonElement>(null)
|
||||
|
||||
const getSelectedSerializedNodes = useCallback(():
|
||||
| BaseSerializedNode[]
|
||||
| null => {
|
||||
if (!editor) return null
|
||||
let selectedNodes: BaseSerializedNode[] | null = null
|
||||
editor.update(() => {
|
||||
const selection = $getSelection()
|
||||
if (!selection) return
|
||||
selectedNodes = $generateJSONFromSelectedNodes(editor, selection).nodes
|
||||
if (selectedNodes.length === 0) return null
|
||||
})
|
||||
return selectedNodes
|
||||
}, [editor])
|
||||
|
||||
const updatePopoverPosition = useCallback(() => {
|
||||
if (!anchorElement || !contentEditableElement) return
|
||||
const nativeSelection = document.getSelection()
|
||||
const range = nativeSelection?.getRangeAt(0)
|
||||
if (!range || range.collapsed) {
|
||||
setIsPopoverOpen(false)
|
||||
return
|
||||
}
|
||||
if (!contentEditableElement.contains(range.commonAncestorContainer)) {
|
||||
setIsPopoverOpen(false)
|
||||
return
|
||||
}
|
||||
const rects = Array.from(range.getClientRects())
|
||||
if (rects.length === 0) {
|
||||
setIsPopoverOpen(false)
|
||||
return
|
||||
}
|
||||
const anchorRect = anchorElement.getBoundingClientRect()
|
||||
const idealLeft = rects[rects.length - 1].right - anchorRect.left
|
||||
const paddingX = 8
|
||||
const paddingY = 4
|
||||
const minLeft = (popoverRef.current?.offsetWidth ?? 0) + paddingX
|
||||
const finalLeft = Math.max(minLeft, idealLeft)
|
||||
setPopoverStyle({
|
||||
top: rects[rects.length - 1].bottom - anchorRect.top + paddingY,
|
||||
left: finalLeft,
|
||||
transform: 'translate(-100%, 0)',
|
||||
})
|
||||
setIsPopoverOpen(true)
|
||||
}, [anchorElement, contentEditableElement])
|
||||
|
||||
useEffect(() => {
|
||||
const removeSelectionChangeListener = editor.registerCommand(
|
||||
SELECTION_CHANGE_COMMAND,
|
||||
() => {
|
||||
updatePopoverPosition()
|
||||
return false
|
||||
},
|
||||
COMMAND_PRIORITY_LOW,
|
||||
)
|
||||
return () => {
|
||||
removeSelectionChangeListener()
|
||||
}
|
||||
}, [editor, updatePopoverPosition])
|
||||
|
||||
useEffect(() => {
|
||||
// Update popover position when the content is cleared
|
||||
// (Selection change event doesn't fire in this case)
|
||||
if (!isPopoverOpen) return
|
||||
const removeTextContentChangeListener = editor.registerTextContentListener(
|
||||
() => {
|
||||
updatePopoverPosition()
|
||||
},
|
||||
)
|
||||
return () => {
|
||||
removeTextContentChangeListener()
|
||||
}
|
||||
}, [editor, isPopoverOpen, updatePopoverPosition])
|
||||
|
||||
useEffect(() => {
|
||||
if (!contentEditableElement) return
|
||||
const handleScroll = () => {
|
||||
updatePopoverPosition()
|
||||
}
|
||||
contentEditableElement.addEventListener('scroll', handleScroll)
|
||||
return () => {
|
||||
contentEditableElement.removeEventListener('scroll', handleScroll)
|
||||
}
|
||||
}, [contentEditableElement, updatePopoverPosition])
|
||||
|
||||
return (
|
||||
<Dialog.Root
|
||||
modal={false}
|
||||
open={isDialogOpen}
|
||||
onOpenChange={(open) => {
|
||||
if (open) {
|
||||
setSelectedSerializedNodes(getSelectedSerializedNodes())
|
||||
}
|
||||
setIsDialogOpen(open)
|
||||
setIsPopoverOpen(false)
|
||||
}}
|
||||
>
|
||||
<Dialog.Trigger asChild>
|
||||
<button
|
||||
ref={popoverRef}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
visibility: isPopoverOpen ? 'visible' : 'hidden',
|
||||
...popoverStyle,
|
||||
}}
|
||||
>
|
||||
Create template
|
||||
</button>
|
||||
</Dialog.Trigger>
|
||||
<CreateTemplateDialogContent
|
||||
selectedSerializedNodes={selectedSerializedNodes}
|
||||
onClose={() => setIsDialogOpen(false)}
|
||||
/>
|
||||
</Dialog.Root>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import clsx from 'clsx'
|
||||
import {
|
||||
$parseSerializedNode,
|
||||
COMMAND_PRIORITY_NORMAL,
|
||||
TextNode,
|
||||
} from 'lexical'
|
||||
import { Trash2 } from 'lucide-react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { createPortal } from 'react-dom'
|
||||
|
||||
import { useDatabase } from '../../../../../contexts/DatabaseContext'
|
||||
import { SelectTemplate } from '../../../../../database/schema'
|
||||
import { MenuOption } from '../shared/LexicalMenu'
|
||||
import {
|
||||
LexicalTypeaheadMenuPlugin,
|
||||
useBasicTypeaheadTriggerMatch,
|
||||
} from '../typeahead-menu/LexicalTypeaheadMenuPlugin'
|
||||
|
||||
class TemplateTypeaheadOption extends MenuOption {
|
||||
name: string
|
||||
template: SelectTemplate
|
||||
|
||||
constructor(name: string, template: SelectTemplate) {
|
||||
super(name)
|
||||
this.name = name
|
||||
this.template = template
|
||||
}
|
||||
}
|
||||
|
||||
function TemplateMenuItem({
|
||||
index,
|
||||
isSelected,
|
||||
onClick,
|
||||
onDelete,
|
||||
onMouseEnter,
|
||||
option,
|
||||
}: {
|
||||
index: number
|
||||
isSelected: boolean
|
||||
onClick: () => void
|
||||
onDelete: () => void
|
||||
onMouseEnter: () => void
|
||||
option: TemplateTypeaheadOption
|
||||
}) {
|
||||
return (
|
||||
<li
|
||||
key={option.key}
|
||||
tabIndex={-1}
|
||||
className={clsx('item', isSelected && 'selected')}
|
||||
ref={(el) => option.setRefElement(el)}
|
||||
role="option"
|
||||
aria-selected={isSelected}
|
||||
id={`typeahead-item-${index}`}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onClick={onClick}
|
||||
>
|
||||
<div className="infio-chat-template-menu-item">
|
||||
<div className="text">{option.name}</div>
|
||||
<div
|
||||
onClick={(evt) => {
|
||||
evt.stopPropagation()
|
||||
evt.preventDefault()
|
||||
onDelete()
|
||||
}}
|
||||
className="infio-chat-template-menu-item-delete"
|
||||
>
|
||||
<Trash2 size={12} />
|
||||
</div>
|
||||
</div>
|
||||
</li>
|
||||
)
|
||||
}
|
||||
|
||||
export default function TemplatePlugin() {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
const { getTemplateManager } = useDatabase()
|
||||
|
||||
const [queryString, setQueryString] = useState<string | null>(null)
|
||||
const [searchResults, setSearchResults] = useState<SelectTemplate[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
if (queryString == null) return
|
||||
getTemplateManager().then((templateManager) =>
|
||||
templateManager.searchTemplates(queryString).then(setSearchResults),
|
||||
)
|
||||
}, [queryString, getTemplateManager])
|
||||
|
||||
const options = useMemo(
|
||||
() =>
|
||||
searchResults.map(
|
||||
(result) => new TemplateTypeaheadOption(result.name, result),
|
||||
),
|
||||
[searchResults],
|
||||
)
|
||||
|
||||
const checkForTriggerMatch = useBasicTypeaheadTriggerMatch('/', {
|
||||
minLength: 0,
|
||||
})
|
||||
|
||||
const onSelectOption = useCallback(
|
||||
(
|
||||
selectedOption: TemplateTypeaheadOption,
|
||||
nodeToRemove: TextNode | null,
|
||||
closeMenu: () => void,
|
||||
) => {
|
||||
editor.update(() => {
|
||||
const parsedNodes = selectedOption.template.content.nodes.map((node) =>
|
||||
$parseSerializedNode(node),
|
||||
)
|
||||
if (nodeToRemove) {
|
||||
const parent = nodeToRemove.getParentOrThrow()
|
||||
parent.splice(nodeToRemove.getIndexWithinParent(), 1, parsedNodes)
|
||||
const lastNode = parsedNodes[parsedNodes.length - 1]
|
||||
lastNode.selectEnd()
|
||||
}
|
||||
closeMenu()
|
||||
})
|
||||
},
|
||||
[editor],
|
||||
)
|
||||
|
||||
const handleDelete = useCallback(
|
||||
async (option: TemplateTypeaheadOption) => {
|
||||
await (await getTemplateManager()).deleteTemplate(option.template.id)
|
||||
if (queryString !== null) {
|
||||
const updatedResults = await (
|
||||
await getTemplateManager()
|
||||
).searchTemplates(queryString)
|
||||
setSearchResults(updatedResults)
|
||||
}
|
||||
},
|
||||
[getTemplateManager, queryString],
|
||||
)
|
||||
|
||||
return (
|
||||
<LexicalTypeaheadMenuPlugin<TemplateTypeaheadOption>
|
||||
onQueryChange={setQueryString}
|
||||
onSelectOption={onSelectOption}
|
||||
triggerFn={checkForTriggerMatch}
|
||||
options={options}
|
||||
commandPriority={COMMAND_PRIORITY_NORMAL}
|
||||
menuRenderFn={(
|
||||
anchorElementRef,
|
||||
{ selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
|
||||
) =>
|
||||
anchorElementRef.current && searchResults.length
|
||||
? createPortal(
|
||||
<div
|
||||
className="infio-popover"
|
||||
style={{
|
||||
position: 'fixed',
|
||||
}}
|
||||
>
|
||||
<ul>
|
||||
{options.map((option, i: number) => (
|
||||
<TemplateMenuItem
|
||||
index={i}
|
||||
isSelected={selectedIndex === i}
|
||||
onClick={() => {
|
||||
setHighlightedIndex(i)
|
||||
selectOptionAndCleanUp(option)
|
||||
}}
|
||||
onDelete={() => {
|
||||
handleDelete(option)
|
||||
}}
|
||||
onMouseEnter={() => {
|
||||
setHighlightedIndex(i)
|
||||
}}
|
||||
key={option.key}
|
||||
option={option}
|
||||
/>
|
||||
))}
|
||||
</ul>
|
||||
</div>,
|
||||
anchorElementRef.current,
|
||||
)
|
||||
: null
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
/**
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license.
|
||||
* Original source: https://github.com/facebook/lexical
|
||||
*
|
||||
* Modified from the original code
|
||||
*/
|
||||
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import {
|
||||
$getSelection,
|
||||
$isRangeSelection,
|
||||
$isTextNode,
|
||||
COMMAND_PRIORITY_LOW,
|
||||
CommandListenerPriority,
|
||||
LexicalCommand,
|
||||
LexicalEditor,
|
||||
RangeSelection,
|
||||
TextNode,
|
||||
createCommand,
|
||||
} from 'lexical'
|
||||
import { startTransition, useCallback, useEffect, useState } from 'react'
|
||||
|
||||
import {
|
||||
LexicalMenu,
|
||||
MenuOption,
|
||||
MenuRenderFn,
|
||||
MenuResolution,
|
||||
TriggerFn,
|
||||
useMenuAnchorRef,
|
||||
} from '../shared/LexicalMenu'
|
||||
|
||||
export const PUNCTUATION =
|
||||
'\\.,\\+\\*\\?\\$\\@\\|#{}\\(\\)\\^\\-\\[\\]\\\\/!%\'"~=<>_:;'
|
||||
|
||||
function getTextUpToAnchor(selection: RangeSelection): string | null {
|
||||
const anchor = selection.anchor
|
||||
if (anchor.type !== 'text') {
|
||||
return null
|
||||
}
|
||||
const anchorNode = anchor.getNode()
|
||||
if (!anchorNode.isSimpleText()) {
|
||||
return null
|
||||
}
|
||||
const anchorOffset = anchor.offset
|
||||
return anchorNode.getTextContent().slice(0, anchorOffset)
|
||||
}
|
||||
|
||||
function tryToPositionRange(
|
||||
leadOffset: number,
|
||||
range: Range,
|
||||
editorWindow: Window,
|
||||
): boolean {
|
||||
const domSelection = editorWindow.getSelection()
|
||||
if (domSelection === null || !domSelection.isCollapsed) {
|
||||
return false
|
||||
}
|
||||
const anchorNode = domSelection.anchorNode
|
||||
const startOffset = leadOffset
|
||||
const endOffset = domSelection.anchorOffset
|
||||
|
||||
if (anchorNode == null || endOffset == null) {
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
range.setStart(anchorNode, startOffset)
|
||||
range.setEnd(anchorNode, endOffset)
|
||||
} catch (error) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
function getQueryTextForSearch(editor: LexicalEditor): string | null {
|
||||
let text = null
|
||||
editor.getEditorState().read(() => {
|
||||
const selection = $getSelection()
|
||||
if (!$isRangeSelection(selection)) {
|
||||
return
|
||||
}
|
||||
text = getTextUpToAnchor(selection)
|
||||
})
|
||||
return text
|
||||
}
|
||||
|
||||
function isSelectionOnEntityBoundary(
|
||||
editor: LexicalEditor,
|
||||
offset: number,
|
||||
): boolean {
|
||||
if (offset !== 0) {
|
||||
return false
|
||||
}
|
||||
return editor.getEditorState().read(() => {
|
||||
const selection = $getSelection()
|
||||
if ($isRangeSelection(selection)) {
|
||||
const anchor = selection.anchor
|
||||
const anchorNode = anchor.getNode()
|
||||
const prevSibling = anchorNode.getPreviousSibling()
|
||||
return $isTextNode(prevSibling) && prevSibling.isTextEntity()
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
// Got from https://stackoverflow.com/a/42543908/2013580
|
||||
export function getScrollParent(
|
||||
element: HTMLElement,
|
||||
includeHidden: boolean,
|
||||
): HTMLElement | HTMLBodyElement {
|
||||
let style = getComputedStyle(element)
|
||||
const excludeStaticParent = style.position === 'absolute'
|
||||
const overflowRegex = includeHidden ? /(auto|scroll|hidden)/ : /(auto|scroll)/
|
||||
if (style.position === 'fixed') {
|
||||
return document.body
|
||||
}
|
||||
for (
|
||||
let parent: HTMLElement | null = element;
|
||||
(parent = parent.parentElement);
|
||||
|
||||
) {
|
||||
style = getComputedStyle(parent)
|
||||
if (excludeStaticParent && style.position === 'static') {
|
||||
continue
|
||||
}
|
||||
if (
|
||||
overflowRegex.test(style.overflow + style.overflowY + style.overflowX)
|
||||
) {
|
||||
return parent
|
||||
}
|
||||
}
|
||||
return document.body
|
||||
}
|
||||
|
||||
export const SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND: LexicalCommand<{
|
||||
index: number
|
||||
option: MenuOption
|
||||
}> = createCommand('SCROLL_TYPEAHEAD_OPTION_INTO_VIEW_COMMAND')
|
||||
|
||||
export function useBasicTypeaheadTriggerMatch(
|
||||
trigger: string,
|
||||
{ minLength = 1, maxLength = 75 }: { minLength?: number; maxLength?: number },
|
||||
): TriggerFn {
|
||||
return useCallback(
|
||||
(text: string) => {
|
||||
const validChars = '[^' + trigger + PUNCTUATION + '\\s]'
|
||||
const TypeaheadTriggerRegex = new RegExp(
|
||||
`(^|\\s|\\()([${trigger}]((?:${validChars}){0,${maxLength}}))$`,
|
||||
)
|
||||
const match = TypeaheadTriggerRegex.exec(text)
|
||||
if (match !== null) {
|
||||
const maybeLeadingWhitespace = match[1]
|
||||
const matchingString = match[3]
|
||||
if (matchingString.length >= minLength) {
|
||||
return {
|
||||
leadOffset: match.index + maybeLeadingWhitespace.length,
|
||||
matchingString,
|
||||
replaceableString: match[2],
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
},
|
||||
[maxLength, minLength, trigger],
|
||||
)
|
||||
}
|
||||
|
||||
export type TypeaheadMenuPluginProps<TOption extends MenuOption> = {
|
||||
onQueryChange: (matchingString: string | null) => void
|
||||
onSelectOption: (
|
||||
option: TOption,
|
||||
textNodeContainingQuery: TextNode | null,
|
||||
closeMenu: () => void,
|
||||
matchingString: string,
|
||||
) => void
|
||||
options: TOption[]
|
||||
menuRenderFn: MenuRenderFn<TOption>
|
||||
triggerFn: TriggerFn
|
||||
onOpen?: (resolution: MenuResolution) => void
|
||||
onClose?: () => void
|
||||
anchorClassName?: string
|
||||
commandPriority?: CommandListenerPriority
|
||||
parent?: HTMLElement
|
||||
}
|
||||
|
||||
export function LexicalTypeaheadMenuPlugin<TOption extends MenuOption>({
|
||||
options,
|
||||
onQueryChange,
|
||||
onSelectOption,
|
||||
onOpen,
|
||||
onClose,
|
||||
menuRenderFn,
|
||||
triggerFn,
|
||||
anchorClassName,
|
||||
commandPriority = COMMAND_PRIORITY_LOW,
|
||||
parent,
|
||||
}: TypeaheadMenuPluginProps<TOption>): JSX.Element | null {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
const [resolution, setResolution] = useState<MenuResolution | null>(null)
|
||||
const anchorElementRef = useMenuAnchorRef(
|
||||
resolution,
|
||||
setResolution,
|
||||
anchorClassName,
|
||||
parent,
|
||||
)
|
||||
|
||||
const closeTypeahead = useCallback(() => {
|
||||
setResolution(null)
|
||||
if (onClose != null && resolution !== null) {
|
||||
onClose()
|
||||
}
|
||||
}, [onClose, resolution])
|
||||
|
||||
const openTypeahead = useCallback(
|
||||
(res: MenuResolution) => {
|
||||
setResolution(res)
|
||||
if (onOpen != null && resolution === null) {
|
||||
onOpen(res)
|
||||
}
|
||||
},
|
||||
[onOpen, resolution],
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
const updateListener = () => {
|
||||
editor.getEditorState().read(() => {
|
||||
const editorWindow = editor._window ?? window
|
||||
const range = editorWindow.document.createRange()
|
||||
const selection = $getSelection()
|
||||
const text = getQueryTextForSearch(editor)
|
||||
|
||||
if (
|
||||
!$isRangeSelection(selection) ||
|
||||
!selection.isCollapsed() ||
|
||||
text === null ||
|
||||
range === null
|
||||
) {
|
||||
closeTypeahead()
|
||||
return
|
||||
}
|
||||
|
||||
const match = triggerFn(text, editor)
|
||||
onQueryChange(match ? match.matchingString : null)
|
||||
|
||||
if (
|
||||
match !== null &&
|
||||
!isSelectionOnEntityBoundary(editor, match.leadOffset)
|
||||
) {
|
||||
const isRangePositioned = tryToPositionRange(
|
||||
match.leadOffset,
|
||||
range,
|
||||
editorWindow,
|
||||
)
|
||||
if (isRangePositioned !== null) {
|
||||
startTransition(() =>
|
||||
openTypeahead({
|
||||
getRect: () => range.getBoundingClientRect(),
|
||||
match,
|
||||
}),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
closeTypeahead()
|
||||
})
|
||||
}
|
||||
|
||||
const removeUpdateListener = editor.registerUpdateListener(updateListener)
|
||||
|
||||
return () => {
|
||||
removeUpdateListener()
|
||||
}
|
||||
}, [
|
||||
editor,
|
||||
triggerFn,
|
||||
onQueryChange,
|
||||
resolution,
|
||||
closeTypeahead,
|
||||
openTypeahead,
|
||||
])
|
||||
|
||||
return resolution === null || editor === null ? null : (
|
||||
<LexicalMenu
|
||||
close={closeTypeahead}
|
||||
resolution={resolution}
|
||||
editor={editor}
|
||||
anchorElementRef={anchorElementRef}
|
||||
options={options}
|
||||
menuRenderFn={menuRenderFn}
|
||||
shouldSplitNodeWithQuery={true}
|
||||
onSelectOption={onSelectOption}
|
||||
commandPriority={commandPriority}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
import {
|
||||
SerializedEditorState,
|
||||
SerializedElementNode,
|
||||
SerializedTextNode,
|
||||
} from 'lexical'
|
||||
|
||||
import { editorStateToPlainText } from './editor-state-to-plain-text'
|
||||
|
||||
describe('editorStateToPlainText', () => {
|
||||
it('should convert editor state to plain text', () => {
|
||||
const editorState: SerializedEditorState = {
|
||||
root: {
|
||||
children: [
|
||||
{
|
||||
children: [
|
||||
{
|
||||
detail: 0,
|
||||
format: 0,
|
||||
mode: 'normal',
|
||||
style: '',
|
||||
text: 'Hello, world!',
|
||||
type: 'text',
|
||||
version: 1,
|
||||
},
|
||||
],
|
||||
direction: 'ltr',
|
||||
format: '',
|
||||
indent: 0,
|
||||
type: 'paragraph',
|
||||
version: 1,
|
||||
textFormat: 0,
|
||||
textStyle: '',
|
||||
} as SerializedElementNode<SerializedTextNode>,
|
||||
],
|
||||
direction: 'ltr',
|
||||
format: '',
|
||||
indent: 0,
|
||||
type: 'root',
|
||||
version: 1,
|
||||
},
|
||||
}
|
||||
const plainText = editorStateToPlainText(editorState)
|
||||
expect(plainText).toBe('Hello, world!')
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,21 @@
|
||||
import { SerializedEditorState, SerializedLexicalNode } from 'lexical'
|
||||
|
||||
export function editorStateToPlainText(
|
||||
editorState: SerializedEditorState,
|
||||
): string {
|
||||
return lexicalNodeToPlainText(editorState.root)
|
||||
}
|
||||
|
||||
function lexicalNodeToPlainText(node: SerializedLexicalNode): string {
|
||||
if ('children' in node) {
|
||||
// Process children recursively and join their results
|
||||
return (node.children as SerializedLexicalNode[])
|
||||
.map(lexicalNodeToPlainText)
|
||||
.join('')
|
||||
} else if (node.type === 'linebreak') {
|
||||
return '\n'
|
||||
} else if ('text' in node && typeof node.text === 'string') {
|
||||
return node.text
|
||||
}
|
||||
return ''
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
import {
|
||||
FileIcon,
|
||||
FolderClosedIcon,
|
||||
FoldersIcon,
|
||||
ImageIcon,
|
||||
LinkIcon,
|
||||
} from 'lucide-react'
|
||||
|
||||
import { Mentionable } from '../../../../types/mentionable'
|
||||
|
||||
export const getMentionableIcon = (mentionable: Mentionable) => {
|
||||
switch (mentionable.type) {
|
||||
case 'file':
|
||||
return FileIcon
|
||||
case 'folder':
|
||||
return FolderClosedIcon
|
||||
case 'vault':
|
||||
return FoldersIcon
|
||||
case 'current-file':
|
||||
return FileIcon
|
||||
case 'block':
|
||||
return FileIcon
|
||||
case 'url':
|
||||
return LinkIcon
|
||||
case 'image':
|
||||
return ImageIcon
|
||||
default:
|
||||
return null
|
||||
}
|
||||
}
|
||||
271
src/components/inline-edit/InlineEdit.tsx
Normal file
271
src/components/inline-edit/InlineEdit.tsx
Normal file
@@ -0,0 +1,271 @@
|
||||
import { MarkdownView, Plugin } from "obsidian";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
|
||||
import { APPLY_VIEW_TYPE } from "../../constants";
|
||||
import LLMManager from "../../core/llm/manager";
|
||||
import { InfioSettings } from "../../types/settings";
|
||||
import { manualApplyChangesToFile } from "../../utils/apply";
|
||||
import { removeAITags } from "../../utils/content-filter";
|
||||
import { PromptGenerator } from "../../utils/prompt-generator";
|
||||
|
||||
interface InlineEditProps {
|
||||
source: string;
|
||||
secStartLine: number;
|
||||
secEndLine: number;
|
||||
plugin: Plugin;
|
||||
settings: InfioSettings;
|
||||
}
|
||||
|
||||
interface InputAreaProps {
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
}
|
||||
|
||||
const InputArea: React.FC<InputAreaProps> = ({ value, onChange }) => {
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
// 组件挂载后自动聚焦到 textarea
|
||||
textareaRef.current?.focus();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="infio-ai-block-input-wrapper">
|
||||
<textarea
|
||||
ref={textareaRef}
|
||||
className="infio-ai-block-content"
|
||||
placeholder="Enter instruction"
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface ControlAreaProps {
|
||||
settings: InfioSettings;
|
||||
onSubmit: () => void;
|
||||
selectedModel: string;
|
||||
onModelChange: (model: string) => void;
|
||||
isSubmitting: boolean;
|
||||
}
|
||||
|
||||
const ControlArea: React.FC<ControlAreaProps> = ({
|
||||
settings,
|
||||
onSubmit,
|
||||
selectedModel,
|
||||
onModelChange,
|
||||
isSubmitting,
|
||||
}) => (
|
||||
<div className="infio-ai-block-controls">
|
||||
<select
|
||||
className="infio-ai-block-model-select"
|
||||
value={selectedModel}
|
||||
onChange={(e) => onModelChange(e.target.value)}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{settings.activeModels
|
||||
.filter((model) => !model.isEmbeddingModel && model.enabled)
|
||||
.map((model) => (
|
||||
<option key={model.name} value={model.name}>
|
||||
{model.name}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<button
|
||||
className="infio-ai-block-submit-button"
|
||||
onClick={onSubmit}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Submitting..." : "Submit"}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
|
||||
export const InlineEdit: React.FC<InlineEditProps> = ({
|
||||
source,
|
||||
secStartLine,
|
||||
secEndLine,
|
||||
plugin,
|
||||
settings,
|
||||
}) => {
|
||||
const [instruction, setInstruction] = useState("");
|
||||
const [selectedModel, setSelectedModel] = useState(settings.chatModelId);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
const llmManager = new LLMManager({
|
||||
deepseek: settings.deepseekApiKey,
|
||||
openai: settings.openAIApiKey,
|
||||
anthropic: settings.anthropicApiKey,
|
||||
gemini: settings.geminiApiKey,
|
||||
groq: settings.groqApiKey,
|
||||
infio: settings.infioApiKey,
|
||||
});
|
||||
|
||||
const promptGenerator = new PromptGenerator(
|
||||
async () => {
|
||||
throw new Error("RAG not needed for inline edit");
|
||||
},
|
||||
plugin.app,
|
||||
settings
|
||||
);
|
||||
|
||||
const handleClose = () => {
|
||||
const activeView = plugin.app.workspace.getActiveViewOfType(MarkdownView);
|
||||
if (!activeView?.editor) return;
|
||||
activeView.editor.replaceRange(
|
||||
"",
|
||||
{ line: secStartLine, ch: 0 },
|
||||
{ line: secEndLine + 1, ch: 0 }
|
||||
);
|
||||
};
|
||||
|
||||
const getActiveContext = async () => {
|
||||
const activeFile = plugin.app.workspace.getActiveFile();
|
||||
if (!activeFile) {
|
||||
console.error("No active file");
|
||||
return {};
|
||||
}
|
||||
|
||||
const editor = plugin.app.workspace.getActiveViewOfType(MarkdownView)?.editor;
|
||||
if (!editor) {
|
||||
console.error("No active editor");
|
||||
return { activeFile };
|
||||
}
|
||||
|
||||
const selection = editor.getSelection();
|
||||
if (!selection) {
|
||||
console.error("No text selected");
|
||||
return { activeFile, editor };
|
||||
}
|
||||
|
||||
return { activeFile, editor, selection };
|
||||
};
|
||||
|
||||
const parseSmartComposeBlock = (content: string) => {
|
||||
const match = content.match(/<infio_block[^>]*>([\s\S]*?)<\/infio_block>/);
|
||||
if (!match) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const blockContent = match[1].trim();
|
||||
const attributes = match[0].match(/startLine="(\d+)"/);
|
||||
const startLine = attributes ? parseInt(attributes[1]) : undefined;
|
||||
const endLineMatch = match[0].match(/endLine="(\d+)"/);
|
||||
const endLine = endLineMatch ? parseInt(endLineMatch[1]) : undefined;
|
||||
|
||||
return {
|
||||
startLine,
|
||||
endLine,
|
||||
content: blockContent,
|
||||
};
|
||||
};
|
||||
|
||||
const handleSubmit = async () => {
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
const { activeFile, editor, selection } = await getActiveContext();
|
||||
if (!activeFile || !editor || !selection) {
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const chatModel = settings.activeModels.find(
|
||||
(model) => model.name === selectedModel
|
||||
);
|
||||
if (!chatModel) {
|
||||
setIsSubmitting(false);
|
||||
throw new Error("Invalid chat model");
|
||||
}
|
||||
|
||||
const from = editor.getCursor("from");
|
||||
const to = editor.getCursor("to");
|
||||
const defaultStartLine = from.line + 1;
|
||||
const defaultEndLine = to.line + 1;
|
||||
|
||||
const requestMessages = await promptGenerator.generateEditMessages({
|
||||
currentFile: activeFile,
|
||||
selectedContent: selection,
|
||||
instruction: instruction,
|
||||
startLine: defaultStartLine,
|
||||
endLine: defaultEndLine,
|
||||
});
|
||||
|
||||
const response = await llmManager.generateResponse(chatModel, {
|
||||
model: chatModel.name,
|
||||
messages: requestMessages,
|
||||
stream: false,
|
||||
});
|
||||
|
||||
if (!response.choices[0].message.content) {
|
||||
setIsSubmitting(false);
|
||||
throw new Error("Empty response from LLM");
|
||||
}
|
||||
|
||||
const parsedBlock = parseSmartComposeBlock(
|
||||
response.choices[0].message.content
|
||||
);
|
||||
const finalContent = parsedBlock?.content || response.choices[0].message.content;
|
||||
const startLine = parsedBlock?.startLine || defaultStartLine;
|
||||
const endLine = parsedBlock?.endLine || defaultEndLine;
|
||||
|
||||
const updatedContent = await manualApplyChangesToFile(
|
||||
finalContent,
|
||||
activeFile,
|
||||
await plugin.app.vault.read(activeFile),
|
||||
startLine,
|
||||
endLine
|
||||
);
|
||||
|
||||
if (!updatedContent) {
|
||||
console.error("Failed to apply changes");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const originalContent = await plugin.app.vault.read(activeFile);
|
||||
await plugin.app.workspace.getLeaf(true).setViewState({
|
||||
type: APPLY_VIEW_TYPE,
|
||||
active: true,
|
||||
state: {
|
||||
file: activeFile,
|
||||
originalContent: removeAITags(originalContent),
|
||||
newContent: removeAITags(updatedContent),
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error in inline edit:", error);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="infio-ai-block-container"
|
||||
id="infio-ai-block-container"
|
||||
style={{ backgroundColor: 'var(--background-secondary)' }}
|
||||
>
|
||||
<InputArea value={instruction} onChange={setInstruction} />
|
||||
<button className="infio-ai-block-close-button" onClick={handleClose}>
|
||||
<svg
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
>
|
||||
<line x1="18" y1="6" x2="6" y2="18"></line>
|
||||
<line x1="6" y1="6" x2="18" y2="18"></line>
|
||||
</svg>
|
||||
</button>
|
||||
<ControlArea
|
||||
settings={settings}
|
||||
onSubmit={handleSubmit}
|
||||
selectedModel={selectedModel}
|
||||
onModelChange={setSelectedModel}
|
||||
isSubmitting={isSubmitting}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
157
src/constants.ts
Normal file
157
src/constants.ts
Normal file
@@ -0,0 +1,157 @@
|
||||
import { CustomLLMModel } from './types/llm/model'
|
||||
|
||||
export const CHAT_VIEW_TYPE = 'infio-chat-view'
|
||||
export const APPLY_VIEW_TYPE = 'infio-apply-view'
|
||||
|
||||
export const DEFAULT_MODELS: CustomLLMModel[] = [
|
||||
{
|
||||
name: 'claude-3.5-sonnet',
|
||||
provider: 'anthropic',
|
||||
enabled: true,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'o1-mini',
|
||||
provider: 'openai',
|
||||
enabled: true,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'o1-preview',
|
||||
provider: 'openai',
|
||||
enabled: false,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
enabled: true,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'gpt-4o-mini',
|
||||
provider: 'openai',
|
||||
enabled: false,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'deepseek-chat',
|
||||
provider: 'deepseek',
|
||||
enabled: true,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'gemini-1.5-pro',
|
||||
provider: 'google',
|
||||
enabled: true,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'gemini-2.0-flash-exp',
|
||||
provider: 'google',
|
||||
enabled: true,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'gemini-2.0-flash-thinking-exp-1219',
|
||||
provider: 'google',
|
||||
enabled: false,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'llama-3.1-70b-versatile',
|
||||
provider: 'groq',
|
||||
enabled: true,
|
||||
isEmbeddingModel: false,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'text-embedding-3-small',
|
||||
provider: 'openai',
|
||||
dimension: 1536,
|
||||
enabled: true,
|
||||
isEmbeddingModel: true,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'text-embedding-004',
|
||||
provider: 'google',
|
||||
dimension: 768,
|
||||
enabled: true,
|
||||
isEmbeddingModel: true,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'nomic-embed-text',
|
||||
provider: 'ollama',
|
||||
dimension: 768,
|
||||
enabled: true,
|
||||
isEmbeddingModel: true,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'mxbai-embed-large',
|
||||
provider: 'ollama',
|
||||
dimension: 1024,
|
||||
enabled: true,
|
||||
isEmbeddingModel: true,
|
||||
isBuiltIn: true,
|
||||
},
|
||||
{
|
||||
name: 'bge-m3',
|
||||
provider: 'ollama',
|
||||
dimension: 1024,
|
||||
enabled: true,
|
||||
isEmbeddingModel: true,
|
||||
isBuiltIn: true,
|
||||
}
|
||||
]
|
||||
|
||||
export const SUPPORT_EMBEDDING_SIMENTION: number[] = [
|
||||
384,
|
||||
512,
|
||||
768,
|
||||
1024,
|
||||
1536
|
||||
]
|
||||
|
||||
export const DEEPSEEK_BASE_URL = 'https://api.deepseek.com'
|
||||
|
||||
// Pricing in dollars per million tokens
|
||||
type ModelPricing = {
|
||||
input: number
|
||||
output: number
|
||||
}
|
||||
|
||||
export const OPENAI_PRICES: Record<string, ModelPricing> = {
|
||||
'gpt-4o': { input: 2.5, output: 10 },
|
||||
'gpt-4o-mini': { input: 0.15, output: 0.6 },
|
||||
'deepseek-chat': { input: 0.16, output: 0.32 },
|
||||
}
|
||||
|
||||
export const ANTHROPIC_PRICES: Record<string, ModelPricing> = {
|
||||
'claude-3-5-sonnet-latest': { input: 3, output: 15 },
|
||||
'claude-3-5-haiku-latest': { input: 1, output: 5 },
|
||||
}
|
||||
|
||||
// Gemini is currently free for low rate limits
|
||||
export const GEMINI_PRICES: Record<string, ModelPricing> = {
|
||||
'gemini-1.5-pro': { input: 0, output: 0 },
|
||||
'gemini-1.5-flash': { input: 0, output: 0 },
|
||||
}
|
||||
|
||||
export const GROQ_PRICES: Record<string, ModelPricing> = {
|
||||
'llama-3.1-70b-versatile': { input: 0.59, output: 0.79 },
|
||||
'llama-3.1-8b-instant': { input: 0.05, output: 0.08 },
|
||||
}
|
||||
|
||||
export const PGLITE_DB_PATH = '.infio_vector_db.tar.gz'
|
||||
23
src/contexts/AppContext.tsx
Normal file
23
src/contexts/AppContext.tsx
Normal file
@@ -0,0 +1,23 @@
|
||||
import { App } from 'obsidian'
|
||||
import React from 'react'
|
||||
|
||||
// App context
|
||||
const AppContext = React.createContext<App | undefined>(undefined)
|
||||
|
||||
export const AppProvider = ({
|
||||
children,
|
||||
app,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
app: App
|
||||
}) => {
|
||||
return <AppContext.Provider value={app}>{children}</AppContext.Provider>
|
||||
}
|
||||
|
||||
export const useApp = () => {
|
||||
const app = React.useContext(AppContext)
|
||||
if (!app) {
|
||||
throw new Error('useApp must be used within an AppProvider')
|
||||
}
|
||||
return app
|
||||
}
|
||||
46
src/contexts/DarkModeContext.tsx
Normal file
46
src/contexts/DarkModeContext.tsx
Normal file
@@ -0,0 +1,46 @@
|
||||
import {
|
||||
ReactNode,
|
||||
createContext,
|
||||
useContext,
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react'
|
||||
|
||||
import { useApp } from './AppContext'
|
||||
|
||||
type DarkModeContextType = {
|
||||
isDarkMode: boolean
|
||||
}
|
||||
|
||||
const DarkModeContext = createContext<DarkModeContextType | undefined>(
|
||||
undefined,
|
||||
)
|
||||
|
||||
export function DarkModeProvider({ children }: { children: ReactNode }) {
|
||||
const [isDarkMode, setIsDarkMode] = useState(false)
|
||||
const app = useApp()
|
||||
|
||||
useEffect(() => {
|
||||
const handleDarkMode = () => {
|
||||
setIsDarkMode(document.body.classList.contains('theme-dark'))
|
||||
}
|
||||
handleDarkMode()
|
||||
app.workspace.on('css-change', handleDarkMode)
|
||||
return () => app.workspace.off('css-change', handleDarkMode)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<DarkModeContext.Provider value={{ isDarkMode }}>
|
||||
{children}
|
||||
</DarkModeContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function useDarkModeContext() {
|
||||
const context = useContext(DarkModeContext)
|
||||
if (context === undefined) {
|
||||
throw new Error('useDarkModeContext must be used within a DarkModeProvider')
|
||||
}
|
||||
return context
|
||||
}
|
||||
58
src/contexts/DatabaseContext.tsx
Normal file
58
src/contexts/DatabaseContext.tsx
Normal file
@@ -0,0 +1,58 @@
|
||||
import {
|
||||
createContext,
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
|
||||
import { DBManager } from '../database/database-manager'
|
||||
import { TemplateManager } from '../database/modules/template/template-manager'
|
||||
import { VectorManager } from '../database/modules/vector/vector-manager'
|
||||
|
||||
type DatabaseContextType = {
|
||||
getDatabaseManager: () => Promise<DBManager>
|
||||
getVectorManager: () => Promise<VectorManager>
|
||||
getTemplateManager: () => Promise<TemplateManager>
|
||||
}
|
||||
|
||||
const DatabaseContext = createContext<DatabaseContextType | null>(null)
|
||||
|
||||
export function DatabaseProvider({
|
||||
children,
|
||||
getDatabaseManager,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
getDatabaseManager: () => Promise<DBManager>
|
||||
}) {
|
||||
const getVectorManager = useCallback(async () => {
|
||||
return (await getDatabaseManager()).getVectorManager()
|
||||
}, [getDatabaseManager])
|
||||
|
||||
const getTemplateManager = useCallback(async () => {
|
||||
return (await getDatabaseManager()).getTemplateManager()
|
||||
}, [getDatabaseManager])
|
||||
|
||||
useEffect(() => {
|
||||
// start initialization of dbManager in the background
|
||||
void getDatabaseManager()
|
||||
}, [getDatabaseManager])
|
||||
|
||||
const value = useMemo(() => {
|
||||
return { getDatabaseManager, getVectorManager, getTemplateManager }
|
||||
}, [getDatabaseManager, getVectorManager, getTemplateManager])
|
||||
|
||||
return (
|
||||
<DatabaseContext.Provider value={value}>
|
||||
{children}
|
||||
</DatabaseContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function useDatabase(): DatabaseContextType {
|
||||
const context = useContext(DatabaseContext)
|
||||
if (!context) {
|
||||
throw new Error('useDatabase must be used within a DatabaseProvider')
|
||||
}
|
||||
return context
|
||||
}
|
||||
27
src/contexts/DialogContext.tsx
Normal file
27
src/contexts/DialogContext.tsx
Normal file
@@ -0,0 +1,27 @@
|
||||
import React, { createContext, useContext } from 'react'
|
||||
|
||||
const DialogContext = createContext<HTMLElement | null>(null)
|
||||
|
||||
export function DialogProvider({
|
||||
children,
|
||||
container,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
container: HTMLElement | null
|
||||
}) {
|
||||
return (
|
||||
<DialogContext.Provider value={container}>
|
||||
{children}
|
||||
</DialogContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function useDialogContainer() {
|
||||
const context = useContext(DialogContext)
|
||||
if (!context) {
|
||||
throw new Error(
|
||||
'useDialogContainer must be used within a DialogContainerProvider',
|
||||
)
|
||||
}
|
||||
return context
|
||||
}
|
||||
135
src/contexts/LLMContext.tsx
Normal file
135
src/contexts/LLMContext.tsx
Normal file
@@ -0,0 +1,135 @@
|
||||
import {
|
||||
PropsWithChildren,
|
||||
createContext,
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react'
|
||||
|
||||
import LLMManager from '../core/llm/manager'
|
||||
import { CustomLLMModel } from '../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
} from '../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../types/llm/response'
|
||||
|
||||
import { useSettings } from './SettingsContext'
|
||||
|
||||
export type LLMContextType = {
|
||||
generateResponse: (
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
) => Promise<LLMResponseNonStreaming>
|
||||
streamResponse: (
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
) => Promise<AsyncIterable<LLMResponseStreaming>>
|
||||
chatModel: CustomLLMModel
|
||||
applyModel: CustomLLMModel
|
||||
}
|
||||
|
||||
const LLMContext = createContext<LLMContextType | null>(null)
|
||||
|
||||
export function LLMProvider({ children }: PropsWithChildren) {
|
||||
const [llmManager, setLLMManager] = useState<LLMManager | null>(null)
|
||||
const { settings } = useSettings()
|
||||
|
||||
const chatModel = useMemo((): CustomLLMModel => {
|
||||
const model = settings.activeModels.find(
|
||||
(option) => option.name === settings.chatModelId,
|
||||
)
|
||||
if (!model) {
|
||||
throw new Error('Invalid chat model ID')
|
||||
}
|
||||
return model
|
||||
}, [settings])
|
||||
|
||||
const applyModel = useMemo((): CustomLLMModel => {
|
||||
const model = settings.activeModels.find(
|
||||
(option) => option.name === settings.applyModelId,
|
||||
)
|
||||
if (!model) {
|
||||
throw new Error('Invalid apply model ID')
|
||||
}
|
||||
if (model.provider === 'ollama') {
|
||||
return {
|
||||
provider: 'ollama',
|
||||
baseURL: settings.ollamaApplyModel.baseUrl,
|
||||
model: settings.ollamaApplyModel.model,
|
||||
}
|
||||
}
|
||||
return model
|
||||
}, [settings])
|
||||
|
||||
useEffect(() => {
|
||||
const manager = new LLMManager({
|
||||
deepseek: settings.deepseekApiKey,
|
||||
openai: settings.openAIApiKey,
|
||||
anthropic: settings.anthropicApiKey,
|
||||
gemini: settings.geminiApiKey,
|
||||
groq: settings.groqApiKey,
|
||||
infio: settings.infioApiKey,
|
||||
})
|
||||
setLLMManager(manager)
|
||||
}, [
|
||||
settings.deepseekApiKey,
|
||||
settings.openAIApiKey,
|
||||
settings.anthropicApiKey,
|
||||
settings.geminiApiKey,
|
||||
settings.groqApiKey,
|
||||
settings.infioApiKey,
|
||||
])
|
||||
|
||||
const generateResponse = useCallback(
|
||||
async (
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
) => {
|
||||
if (!llmManager) {
|
||||
throw new Error('LLMManager is not initialized')
|
||||
}
|
||||
return await llmManager.generateResponse(model, request, options)
|
||||
},
|
||||
[llmManager],
|
||||
)
|
||||
|
||||
const streamResponse = useCallback(
|
||||
async (
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
) => {
|
||||
if (!llmManager) {
|
||||
throw new Error('LLMManager is not initialized')
|
||||
}
|
||||
return await llmManager.streamResponse(model, request, options)
|
||||
},
|
||||
[llmManager],
|
||||
)
|
||||
|
||||
return (
|
||||
<LLMContext.Provider
|
||||
value={{ generateResponse, streamResponse, chatModel, applyModel }}
|
||||
>
|
||||
{children}
|
||||
</LLMContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function useLLM() {
|
||||
const context = useContext(LLMContext)
|
||||
if (!context) {
|
||||
throw new Error('useLLM must be used within an LLMProvider')
|
||||
}
|
||||
return context
|
||||
}
|
||||
39
src/contexts/RAGContext.tsx
Normal file
39
src/contexts/RAGContext.tsx
Normal file
@@ -0,0 +1,39 @@
|
||||
import {
|
||||
PropsWithChildren,
|
||||
createContext,
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
|
||||
import { RAGEngine } from '../core/rag/rag-engine'
|
||||
|
||||
export type RAGContextType = {
|
||||
getRAGEngine: () => Promise<RAGEngine>
|
||||
}
|
||||
|
||||
const RAGContext = createContext<RAGContextType | null>(null)
|
||||
|
||||
export function RAGProvider({
|
||||
getRAGEngine,
|
||||
children,
|
||||
}: PropsWithChildren<{ getRAGEngine: () => Promise<RAGEngine> }>) {
|
||||
useEffect(() => {
|
||||
// start initialization of ragEngine in the background
|
||||
void getRAGEngine()
|
||||
}, [getRAGEngine])
|
||||
|
||||
const value = useMemo(() => {
|
||||
return { getRAGEngine }
|
||||
}, [getRAGEngine])
|
||||
|
||||
return <RAGContext.Provider value={value}>{children}</RAGContext.Provider>
|
||||
}
|
||||
|
||||
export function useRAG() {
|
||||
const context = useContext(RAGContext)
|
||||
if (!context) {
|
||||
throw new Error('useRAG must be used within a RAGProvider')
|
||||
}
|
||||
return context
|
||||
}
|
||||
58
src/contexts/SettingsContext.tsx
Normal file
58
src/contexts/SettingsContext.tsx
Normal file
@@ -0,0 +1,58 @@
|
||||
import React, { useEffect, useMemo, useState } from 'react'
|
||||
|
||||
import { InfioSettings } from '../types/settings'
|
||||
|
||||
type SettingsContextType = {
|
||||
settings: InfioSettings
|
||||
setSettings: (newSettings: InfioSettings) => void
|
||||
}
|
||||
|
||||
// Settings context
|
||||
const SettingsContext = React.createContext<SettingsContextType | undefined>(
|
||||
undefined,
|
||||
)
|
||||
|
||||
export const SettingsProvider = ({
|
||||
children,
|
||||
settings: initialSettings,
|
||||
setSettings,
|
||||
addSettingsChangeListener,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
settings: InfioSettings
|
||||
setSettings: (newSettings: InfioSettings) => void
|
||||
addSettingsChangeListener: (
|
||||
listener: (newSettings: InfioSettings) => void,
|
||||
) => () => void
|
||||
}) => {
|
||||
const [settingsCached, setSettingsCached] = useState(initialSettings)
|
||||
|
||||
useEffect(() => {
|
||||
const removeListener = addSettingsChangeListener((newSettings) => {
|
||||
setSettingsCached(newSettings)
|
||||
})
|
||||
|
||||
return () => {
|
||||
removeListener()
|
||||
}
|
||||
}, [addSettingsChangeListener, setSettings])
|
||||
|
||||
const value = useMemo(
|
||||
() => ({ settings: settingsCached, setSettings }),
|
||||
[settingsCached, setSettings],
|
||||
)
|
||||
|
||||
return (
|
||||
<SettingsContext.Provider value={value}>
|
||||
{children}
|
||||
</SettingsContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export const useSettings = () => {
|
||||
const settings = React.useContext(SettingsContext)
|
||||
if (!settings) {
|
||||
throw new Error('useSettings must be used within a SettingsProvider')
|
||||
}
|
||||
return settings
|
||||
}
|
||||
95
src/core/autocomplete/context-detection.ts
Normal file
95
src/core/autocomplete/context-detection.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
import { generateRandomString } from "./utils";
|
||||
|
||||
const UNIQUE_CURSOR = `${generateRandomString(16)}`;
|
||||
const HEADER_REGEX = `^#+\\s.*${UNIQUE_CURSOR}.*$`;
|
||||
const UNORDERED_LIST_REGEX = `^\\s*(-|\\*)\\s.*${UNIQUE_CURSOR}.*$`;
|
||||
const TASK_LIST_REGEX = `^\\s*(-|[0-9]+\\.) +\\[.\\]\\s.*${UNIQUE_CURSOR}.*$`;
|
||||
const BLOCK_QUOTES_REGEX = `^\\s*>.*${UNIQUE_CURSOR}.*$`;
|
||||
const NUMBERED_LIST_REGEX = `^\\s*\\d+\\.\\s.*${UNIQUE_CURSOR}.*$`
|
||||
const MATH_BLOCK_REGEX = /\$\$[\s\S]*?\$\$/g;
|
||||
const INLINE_MATH_BLOCK_REGEX = /\$[\s\S]*?\$/g;
|
||||
const CODE_BLOCK_REGEX = /```[\s\S]*?```/g;
|
||||
const INLINE_CODE_BLOCK_REGEX = /`.*`/g;
|
||||
|
||||
enum Context {
|
||||
Text = "Text",
|
||||
Heading = "Heading",
|
||||
BlockQuotes = "BlockQuotes",
|
||||
UnorderedList = "UnorderedList",
|
||||
NumberedList = "NumberedList",
|
||||
CodeBlock = "CodeBlock",
|
||||
MathBlock = "MathBlock",
|
||||
TaskList = "TaskList",
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-namespace
|
||||
namespace Context {
|
||||
export function values(): Array<Context> {
|
||||
return Object.values(Context).filter(
|
||||
(value) => typeof value === "string"
|
||||
) as Array<Context>;
|
||||
}
|
||||
|
||||
export function getContext(prefix: string, suffix: string): Context {
|
||||
if (new RegExp(HEADER_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
|
||||
return Context.Heading;
|
||||
}
|
||||
if (new RegExp(BLOCK_QUOTES_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
|
||||
return Context.BlockQuotes;
|
||||
}
|
||||
|
||||
if (new RegExp(TASK_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
|
||||
return Context.TaskList;
|
||||
}
|
||||
|
||||
if (
|
||||
isCursorInRegexBlock(prefix, suffix, MATH_BLOCK_REGEX) ||
|
||||
isCursorInRegexBlock(prefix, suffix, INLINE_MATH_BLOCK_REGEX)
|
||||
) {
|
||||
return Context.MathBlock;
|
||||
}
|
||||
|
||||
if (isCursorInRegexBlock(prefix, suffix, CODE_BLOCK_REGEX) || isCursorInRegexBlock(prefix, suffix, INLINE_CODE_BLOCK_REGEX)) {
|
||||
return Context.CodeBlock;
|
||||
}
|
||||
if (new RegExp(NUMBERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
|
||||
return Context.NumberedList;
|
||||
}
|
||||
if (new RegExp(UNORDERED_LIST_REGEX, "gm").test(prefix + UNIQUE_CURSOR + suffix)) {
|
||||
return Context.UnorderedList;
|
||||
}
|
||||
|
||||
return Context.Text;
|
||||
}
|
||||
|
||||
export function get(value: string) {
|
||||
for (const context of Context.values()) {
|
||||
if (value === context) {
|
||||
return context;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
function isCursorInRegexBlock(
|
||||
prefix: string,
|
||||
suffix: string,
|
||||
regex: RegExp
|
||||
): boolean {
|
||||
const text = prefix + UNIQUE_CURSOR + suffix;
|
||||
const codeBlocks = extractBlocks(text, regex);
|
||||
for (const block of codeBlocks) {
|
||||
if (block.includes(UNIQUE_CURSOR)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function extractBlocks(text: string, regex: RegExp) {
|
||||
const codeBlocks = text.match(regex);
|
||||
return codeBlocks ? codeBlocks.map((block) => block.trim()) : [];
|
||||
}
|
||||
|
||||
export default Context;
|
||||
273
src/core/autocomplete/index.ts
Normal file
273
src/core/autocomplete/index.ts
Normal file
@@ -0,0 +1,273 @@
|
||||
import * as Handlebars from "handlebars";
|
||||
import { err, ok, Result } from "neverthrow";
|
||||
|
||||
import { FewShotExample } from "../../settings/versions";
|
||||
import { CustomLLMModel } from "../../types/llm/model";
|
||||
import { RequestMessage } from '../../types/llm/request';
|
||||
import { InfioSettings } from "../../types/settings";
|
||||
import LLMManager from '../llm/manager';
|
||||
|
||||
import Context from "./context-detection";
|
||||
import RemoveCodeIndicators from "./post-processors/remove-code-indicators";
|
||||
import RemoveMathIndicators from "./post-processors/remove-math-indicators";
|
||||
import RemoveOverlap from "./post-processors/remove-overlap";
|
||||
import RemoveWhitespace from "./post-processors/remove-whitespace";
|
||||
import DataViewRemover from "./pre-processors/data-view-remover";
|
||||
import LengthLimiter from "./pre-processors/length-limiter";
|
||||
import {
|
||||
AutocompleteService,
|
||||
ChatMessage,
|
||||
PostProcessor,
|
||||
PreProcessor,
|
||||
UserMessageFormatter,
|
||||
UserMessageFormattingInputs
|
||||
} from "./types";
|
||||
|
||||
class LLMClient {
|
||||
private llm: LLMManager;
|
||||
private model: CustomLLMModel;
|
||||
|
||||
constructor(llm: LLMManager, model: CustomLLMModel) {
|
||||
this.llm = llm;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
async queryChatModel(messages: RequestMessage[]): Promise<Result<string, Error>> {
|
||||
const data = await this.llm.generateResponse(this.model, {
|
||||
model: this.model.name,
|
||||
messages: messages,
|
||||
stream: false,
|
||||
})
|
||||
return ok(data.choices[0].message.content);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class AutoComplete implements AutocompleteService {
|
||||
private readonly client: LLMClient;
|
||||
private readonly systemMessage: string;
|
||||
private readonly userMessageFormatter: UserMessageFormatter;
|
||||
private readonly removePreAnswerGenerationRegex: string;
|
||||
private readonly preProcessors: PreProcessor[];
|
||||
private readonly postProcessors: PostProcessor[];
|
||||
private readonly fewShotExamples: FewShotExample[];
|
||||
private debugMode: boolean;
|
||||
|
||||
private constructor(
|
||||
client: LLMClient,
|
||||
systemMessage: string,
|
||||
userMessageFormatter: UserMessageFormatter,
|
||||
removePreAnswerGenerationRegex: string,
|
||||
preProcessors: PreProcessor[],
|
||||
postProcessors: PostProcessor[],
|
||||
fewShotExamples: FewShotExample[],
|
||||
debugMode: boolean,
|
||||
) {
|
||||
this.client = client;
|
||||
this.systemMessage = systemMessage;
|
||||
this.userMessageFormatter = userMessageFormatter;
|
||||
this.removePreAnswerGenerationRegex = removePreAnswerGenerationRegex;
|
||||
this.preProcessors = preProcessors;
|
||||
this.postProcessors = postProcessors;
|
||||
this.fewShotExamples = fewShotExamples;
|
||||
this.debugMode = debugMode;
|
||||
}
|
||||
|
||||
public static fromSettings(settings: InfioSettings): AutocompleteService {
|
||||
const formatter = Handlebars.compile<UserMessageFormattingInputs>(
|
||||
settings.userMessageTemplate,
|
||||
{ noEscape: true, strict: true }
|
||||
);
|
||||
const preProcessors: PreProcessor[] = [];
|
||||
if (settings.dontIncludeDataviews) {
|
||||
preProcessors.push(new DataViewRemover());
|
||||
}
|
||||
preProcessors.push(
|
||||
new LengthLimiter(
|
||||
settings.maxPrefixCharLimit,
|
||||
settings.maxSuffixCharLimit
|
||||
)
|
||||
);
|
||||
|
||||
const postProcessors: PostProcessor[] = [];
|
||||
if (settings.removeDuplicateMathBlockIndicator) {
|
||||
postProcessors.push(new RemoveMathIndicators());
|
||||
}
|
||||
if (settings.removeDuplicateCodeBlockIndicator) {
|
||||
postProcessors.push(new RemoveCodeIndicators());
|
||||
}
|
||||
|
||||
postProcessors.push(new RemoveOverlap());
|
||||
postProcessors.push(new RemoveWhitespace());
|
||||
|
||||
const llm_manager = new LLMManager({
|
||||
deepseek: settings.deepseekApiKey,
|
||||
openai: settings.openAIApiKey,
|
||||
anthropic: settings.anthropicApiKey,
|
||||
gemini: settings.geminiApiKey,
|
||||
groq: settings.groqApiKey,
|
||||
infio: settings.infioApiKey,
|
||||
})
|
||||
const model: CustomLLMModel = settings.activeModels.find(
|
||||
(option) => option.name === settings.chatModelId,
|
||||
)
|
||||
const llm = new LLMClient(llm_manager, model);
|
||||
|
||||
return new AutoComplete(
|
||||
llm,
|
||||
settings.systemMessage,
|
||||
formatter,
|
||||
settings.chainOfThoughRemovalRegex,
|
||||
preProcessors,
|
||||
postProcessors,
|
||||
settings.fewShotExamples,
|
||||
settings.debugMode,
|
||||
);
|
||||
}
|
||||
|
||||
async fetchPredictions(
|
||||
prefix: string,
|
||||
suffix: string
|
||||
): Promise<Result<string, Error>> {
|
||||
const context: Context = Context.getContext(prefix, suffix);
|
||||
|
||||
for (const preProcessor of this.preProcessors) {
|
||||
if (preProcessor.removesCursor(prefix, suffix)) {
|
||||
return ok("");
|
||||
}
|
||||
|
||||
({ prefix, suffix } = preProcessor.process(
|
||||
prefix,
|
||||
suffix,
|
||||
context
|
||||
));
|
||||
}
|
||||
|
||||
const examples = this.fewShotExamples.filter(
|
||||
(example) => example.context === context
|
||||
);
|
||||
const fewShotExamplesChatMessages =
|
||||
fewShotExamplesToChatMessages(examples);
|
||||
|
||||
const messages: RequestMessage[] = [
|
||||
{
|
||||
content: this.getSystemMessageFor(context),
|
||||
role: "system"
|
||||
},
|
||||
...fewShotExamplesChatMessages,
|
||||
{
|
||||
role: "user",
|
||||
content: this.userMessageFormatter({
|
||||
suffix,
|
||||
prefix,
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
if (this.debugMode) {
|
||||
console.log("Copilot messages send:\n", messages);
|
||||
}
|
||||
|
||||
let result = await this.client.queryChatModel(messages);
|
||||
if (this.debugMode && result.isOk()) {
|
||||
console.log("Copilot response:\n", result.value);
|
||||
}
|
||||
|
||||
result = this.extractAnswerFromChainOfThoughts(result);
|
||||
|
||||
for (const postProcessor of this.postProcessors) {
|
||||
result = result.map((r) => postProcessor.process(prefix, suffix, r, context));
|
||||
}
|
||||
|
||||
result = this.checkAgainstGuardRails(result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private getSystemMessageFor(context: Context): string {
|
||||
if (context === Context.Text) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located in a paragraph. Your answer must complete this paragraph or sentence in a way that fits the surrounding text without overlapping with it. It must be in the same language as the paragraph.";
|
||||
}
|
||||
if (context === Context.Heading) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located in the Markdown heading. Your answer must complete this title in a way that fits the content of this paragraph and be in the same language as the paragraph.";
|
||||
}
|
||||
|
||||
if (context === Context.BlockQuotes) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located within a quote. Your answer must complete this quote in a way that fits the context of the paragraph.";
|
||||
}
|
||||
if (context === Context.UnorderedList) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located in an unordered list. Your answer must include one or more list items that fit with the surrounding list without overlapping with it.";
|
||||
}
|
||||
|
||||
if (context === Context.NumberedList) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located in a numbered list. Your answer must include one or more list items that fit the sequence and context of the surrounding list without overlapping with it.";
|
||||
}
|
||||
|
||||
if (context === Context.CodeBlock) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located in a code block. Your answer must complete this code block in the same programming language and support the surrounding code and text outside of the code block.";
|
||||
}
|
||||
if (context === Context.MathBlock) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located in a math block. Your answer must only contain LaTeX code that captures the math discussed in the surrounding text. No text or explaination only LaTex math code.";
|
||||
}
|
||||
if (context === Context.TaskList) {
|
||||
return this.systemMessage + "\n\n" + "The <mask/> is located in a task list. Your answer must include one or more (sub)tasks that are logical given the other tasks and the surrounding text.";
|
||||
}
|
||||
|
||||
|
||||
return this.systemMessage;
|
||||
|
||||
}
|
||||
|
||||
private extractAnswerFromChainOfThoughts(
|
||||
result: Result<string, Error>
|
||||
): Result<string, Error> {
|
||||
if (result.isErr()) {
|
||||
return result;
|
||||
}
|
||||
const chainOfThoughts = result.value;
|
||||
|
||||
const regex = new RegExp(this.removePreAnswerGenerationRegex, "gm");
|
||||
const match = regex.exec(chainOfThoughts);
|
||||
if (match === null) {
|
||||
return err(new Error("No match found"));
|
||||
}
|
||||
return ok(chainOfThoughts.replace(regex, ""));
|
||||
}
|
||||
|
||||
private checkAgainstGuardRails(
|
||||
result: Result<string, Error>
|
||||
): Result<string, Error> {
|
||||
if (result.isErr()) {
|
||||
return result;
|
||||
}
|
||||
if (result.value.length === 0) {
|
||||
return err(new Error("Empty result"));
|
||||
}
|
||||
if (result.value.contains("<mask/>")) {
|
||||
return err(new Error("Mask in result"));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
function fewShotExamplesToChatMessages(
|
||||
examples: FewShotExample[]
|
||||
): ChatMessage[] {
|
||||
return examples
|
||||
.map((example): ChatMessage[] => {
|
||||
return [
|
||||
{
|
||||
role: "user",
|
||||
content: example.input,
|
||||
},
|
||||
{
|
||||
role: "assistant",
|
||||
content: example.answer,
|
||||
},
|
||||
];
|
||||
})
|
||||
.flat();
|
||||
}
|
||||
|
||||
export default AutoComplete;
|
||||
@@ -0,0 +1,22 @@
|
||||
import Context from "../context-detection";
|
||||
import { PostProcessor } from "../types";
|
||||
|
||||
class RemoveCodeIndicators implements PostProcessor {
|
||||
process(
|
||||
prefix: string,
|
||||
suffix: string,
|
||||
completion: string,
|
||||
context: Context
|
||||
): string {
|
||||
if (context === Context.CodeBlock) {
|
||||
completion = completion.replace(/```[a-zA-z]+[ \t]*\n?/g, "");
|
||||
completion = completion.replace(/\n?```[ \t]*\n?/g, "");
|
||||
completion = completion.replace(/`/g, "");
|
||||
}
|
||||
|
||||
|
||||
return completion;
|
||||
}
|
||||
}
|
||||
|
||||
export default RemoveCodeIndicators;
|
||||
@@ -0,0 +1,20 @@
|
||||
import Context from "../context-detection";
|
||||
import { PostProcessor } from "../types";
|
||||
|
||||
class RemoveMathIndicators implements PostProcessor {
|
||||
process(
|
||||
prefix: string,
|
||||
suffix: string,
|
||||
completion: string,
|
||||
context: Context
|
||||
): string {
|
||||
if (context === Context.MathBlock) {
|
||||
completion = completion.replace(/\n?\$\$\n?/g, "");
|
||||
completion = completion.replace(/\$/g, "");
|
||||
}
|
||||
|
||||
return completion;
|
||||
}
|
||||
}
|
||||
|
||||
export default RemoveMathIndicators;
|
||||
96
src/core/autocomplete/post-processors/remove-overlap.ts
Normal file
96
src/core/autocomplete/post-processors/remove-overlap.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import Context from "../context-detection";
|
||||
import { PostProcessor } from "../types";
|
||||
|
||||
|
||||
class RemoveOverlap implements PostProcessor {
|
||||
process(
|
||||
prefix: string,
|
||||
suffix: string,
|
||||
completion: string,
|
||||
context: Context
|
||||
): string {
|
||||
completion = removeWordOverlapPrefix(prefix, completion);
|
||||
completion = removeWordOverlapSuffix(completion, suffix);
|
||||
completion = removeWhiteSpaceOverlapPrefix(suffix, completion);
|
||||
completion = removeWhiteSpaceOverlapSuffix(completion, suffix);
|
||||
|
||||
return completion;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function removeWhiteSpaceOverlapPrefix(prefix: string, completion: string): string {
|
||||
let prefixIdx = prefix.length - 1;
|
||||
while (completion.length > 0 && completion[0] === prefix[prefixIdx]) {
|
||||
completion = completion.slice(1);
|
||||
prefixIdx--;
|
||||
}
|
||||
return completion;
|
||||
}
|
||||
|
||||
|
||||
function removeWhiteSpaceOverlapSuffix(completion: string, suffix: string): string {
|
||||
let suffixIdx = 0;
|
||||
while (completion.length > 0 && completion[completion.length - 1] === suffix[suffixIdx]) {
|
||||
completion = completion.slice(0, -1);
|
||||
suffixIdx++;
|
||||
}
|
||||
return completion;
|
||||
}
|
||||
|
||||
function removeWordOverlapPrefix(prefix: string, completion: string): string {
|
||||
const rightTrimmed = completion.trimStart();
|
||||
|
||||
const startIdxOfEachWord = startLocationOfEachWord(prefix);
|
||||
|
||||
while (startIdxOfEachWord.length > 0) {
|
||||
const idx = startIdxOfEachWord.pop();
|
||||
const leftSubstring = prefix.slice(idx);
|
||||
if (rightTrimmed.startsWith(leftSubstring)) {
|
||||
return rightTrimmed.replace(leftSubstring, "");
|
||||
}
|
||||
}
|
||||
|
||||
return completion;
|
||||
}
|
||||
|
||||
function removeWordOverlapSuffix(completion: string, suffix: string): string {
|
||||
const suffixTrimmed = removeLeadingWhiteSpace(suffix);
|
||||
|
||||
const startIdxOfEachWord = startLocationOfEachWord(completion);
|
||||
|
||||
while (startIdxOfEachWord.length > 0) {
|
||||
const idx = startIdxOfEachWord.pop();
|
||||
const suffixSubstring = completion.slice(idx);
|
||||
if (suffixTrimmed.startsWith(suffixSubstring)) {
|
||||
return completion.replace(suffixSubstring, "");
|
||||
}
|
||||
}
|
||||
|
||||
return completion;
|
||||
}
|
||||
|
||||
function removeLeadingWhiteSpace(completion: string): string {
|
||||
return completion.replace(/^[ \t\f\r\v]+/, "");
|
||||
}
|
||||
|
||||
function startLocationOfEachWord(text: string): number[] {
|
||||
const locations: number[] = [];
|
||||
if (text.length > 0 && !isWhiteSpaceChar(text[0])) {
|
||||
locations.push(0);
|
||||
}
|
||||
|
||||
for (let i = 1; i < text.length; i++) {
|
||||
if (isWhiteSpaceChar(text[i - 1]) && !isWhiteSpaceChar(text[i])) {
|
||||
locations.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
return locations;
|
||||
}
|
||||
|
||||
function isWhiteSpaceChar(char: string | undefined): boolean {
|
||||
return char !== undefined && char.match(/\s/) !== null;
|
||||
}
|
||||
|
||||
export default RemoveOverlap;
|
||||
24
src/core/autocomplete/post-processors/remove-whitespace.ts
Normal file
24
src/core/autocomplete/post-processors/remove-whitespace.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import Context from "../context-detection";
|
||||
import { PostProcessor } from "../types";
|
||||
|
||||
class RemoveWhitespace implements PostProcessor {
|
||||
process(
|
||||
prefix: string,
|
||||
suffix: string,
|
||||
completion: string,
|
||||
context: Context
|
||||
): string {
|
||||
if (context === Context.Text || context === Context.Heading || context === Context.MathBlock || context === Context.TaskList || context === Context.NumberedList || context === Context.UnorderedList) {
|
||||
if (prefix.endsWith(" ") || suffix.endsWith("\n")) {
|
||||
completion = completion.trimStart();
|
||||
}
|
||||
if (suffix.startsWith(" ")) {
|
||||
completion = completion.trimEnd();
|
||||
}
|
||||
}
|
||||
|
||||
return completion;
|
||||
}
|
||||
}
|
||||
|
||||
export default RemoveWhitespace;
|
||||
34
src/core/autocomplete/pre-processors/data-view-remover.ts
Normal file
34
src/core/autocomplete/pre-processors/data-view-remover.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import { generateRandomString } from "../utils";
|
||||
import Context from "../context-detection";
|
||||
import { PrefixAndSuffix, PreProcessor } from "../types";
|
||||
|
||||
const DATA_VIEW_REGEX = /```dataview(js){0,1}(.|\n)*?```/gm;
|
||||
const UNIQUE_CURSOR = `${generateRandomString(16)}`;
|
||||
|
||||
class DataViewRemover implements PreProcessor {
|
||||
process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
|
||||
let text = prefix + UNIQUE_CURSOR + suffix;
|
||||
text = text.replace(DATA_VIEW_REGEX, "");
|
||||
const [prefixNew, suffixNew] = text.split(UNIQUE_CURSOR);
|
||||
|
||||
return { prefix: prefixNew, suffix: suffixNew };
|
||||
}
|
||||
|
||||
removesCursor(prefix: string, suffix: string): boolean {
|
||||
const text = prefix + UNIQUE_CURSOR + suffix;
|
||||
const dataviewAreasWithCursor = text
|
||||
.match(DATA_VIEW_REGEX)
|
||||
?.filter((dataviewArea) => dataviewArea.includes(UNIQUE_CURSOR));
|
||||
|
||||
if (
|
||||
dataviewAreasWithCursor !== undefined &&
|
||||
dataviewAreasWithCursor.length > 0
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export default DataViewRemover;
|
||||
24
src/core/autocomplete/pre-processors/length-limiter.ts
Normal file
24
src/core/autocomplete/pre-processors/length-limiter.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import Context from "../context-detection";
|
||||
import { PrefixAndSuffix, PreProcessor } from "../types";
|
||||
|
||||
class LengthLimiter implements PreProcessor {
|
||||
private readonly maxPrefixChars: number;
|
||||
private readonly maxSuffixChars: number;
|
||||
|
||||
constructor(maxPrefixChars: number, maxSuffixChars: number) {
|
||||
this.maxPrefixChars = maxPrefixChars;
|
||||
this.maxSuffixChars = maxSuffixChars;
|
||||
}
|
||||
|
||||
process(prefix: string, suffix: string, context: Context): PrefixAndSuffix {
|
||||
prefix = prefix.slice(-this.maxPrefixChars);
|
||||
suffix = suffix.slice(0, this.maxSuffixChars);
|
||||
return { prefix, suffix };
|
||||
}
|
||||
|
||||
removesCursor(prefix: string, suffix: string): boolean {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export default LengthLimiter;
|
||||
35
src/core/autocomplete/states/disabled-file-specific-state.ts
Normal file
35
src/core/autocomplete/states/disabled-file-specific-state.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { TFile } from "obsidian";
|
||||
|
||||
import { InfioSettings } from "../../../types/settings";
|
||||
|
||||
import State from "./state";
|
||||
|
||||
|
||||
class DisabledFileSpecificState extends State {
|
||||
getStatusBarText(): string {
|
||||
return "Disabled for this file";
|
||||
}
|
||||
|
||||
handleSettingChanged(settings: InfioSettings) {
|
||||
if (!this.context.settings.autocompleteEnabled) {
|
||||
this.context.transitionToDisabledManualState();
|
||||
}
|
||||
if (!this.context.isCurrentFilePathIgnored() || !this.context.currentFileContainsIgnoredTag()) {
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
}
|
||||
|
||||
handleFileChange(file: TFile): void {
|
||||
if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.context.settings.autocompleteEnabled) {
|
||||
this.context.transitionToIdleState();
|
||||
} else {
|
||||
this.context.transitionToDisabledManualState();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default DisabledFileSpecificState;
|
||||
@@ -0,0 +1,25 @@
|
||||
import { InfioSettings } from "../../../types/settings";
|
||||
import { checkForErrors } from "../../../utils/auto-complete";
|
||||
|
||||
import State from "./state";
|
||||
|
||||
class DisabledInvalidSettingsState extends State {
|
||||
getStatusBarText(): string {
|
||||
return "Disabled invalid settings";
|
||||
}
|
||||
|
||||
handleSettingChanged(settings: InfioSettings) {
|
||||
const settingErrors = checkForErrors(settings);
|
||||
if (settingErrors.size > 0) {
|
||||
return
|
||||
}
|
||||
if (this.context.settings.autocompleteEnabled) {
|
||||
this.context.transitionToIdleState();
|
||||
} else {
|
||||
this.context.transitionToDisabledManualState();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
export default DisabledInvalidSettingsState;
|
||||
21
src/core/autocomplete/states/disabled-manual-state.ts
Normal file
21
src/core/autocomplete/states/disabled-manual-state.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import { TFile } from "obsidian";
|
||||
|
||||
import { InfioSettings } from "../../../types/settings";
|
||||
|
||||
import State from "./state";
|
||||
|
||||
class DisabledManualState extends State {
|
||||
getStatusBarText(): string {
|
||||
return "Disabled";
|
||||
}
|
||||
|
||||
handleSettingChanged(settings: InfioSettings): void {
|
||||
if (this.context.settings.autocompleteEnabled) {
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
}
|
||||
|
||||
handleFileChange(file: TFile): void { }
|
||||
}
|
||||
|
||||
export default DisabledManualState;
|
||||
46
src/core/autocomplete/states/idle-state.ts
Normal file
46
src/core/autocomplete/states/idle-state.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
|
||||
|
||||
import State from "./state";
|
||||
|
||||
|
||||
class IdleState extends State {
|
||||
|
||||
async handleDocumentChange(
|
||||
documentChanges: DocumentChanges
|
||||
): Promise<void> {
|
||||
if (
|
||||
!documentChanges.isDocInFocus()
|
||||
|| !documentChanges.hasDocChanged()
|
||||
|| documentChanges.hasUserDeleted()
|
||||
|| documentChanges.hasMultipleCursors()
|
||||
|| documentChanges.hasSelection()
|
||||
|| documentChanges.hasUserUndone()
|
||||
|| documentChanges.hasUserRedone()
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const cachedSuggestion = this.context.getCachedSuggestionFor(documentChanges.getPrefix(), documentChanges.getSuffix());
|
||||
const isThereCachedSuggestion = cachedSuggestion !== undefined && cachedSuggestion.trim().length > 0;
|
||||
|
||||
if (this.context.settings.cacheSuggestions && isThereCachedSuggestion) {
|
||||
this.context.transitionToSuggestingState(cachedSuggestion, documentChanges.getPrefix(), documentChanges.getSuffix());
|
||||
return;
|
||||
|
||||
}
|
||||
|
||||
if (this.context.containsTriggerCharacters(documentChanges)) {
|
||||
this.context.transitionToQueuedState(documentChanges.getPrefix(), documentChanges.getSuffix());
|
||||
}
|
||||
}
|
||||
|
||||
handlePredictCommand(prefix: string, suffix: string): void {
|
||||
this.context.transitionToPredictingState(prefix, suffix);
|
||||
}
|
||||
|
||||
getStatusBarText(): string {
|
||||
return "Idle";
|
||||
}
|
||||
}
|
||||
|
||||
export default IdleState;
|
||||
38
src/core/autocomplete/states/init-state.ts
Normal file
38
src/core/autocomplete/states/init-state.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { TFile } from "obsidian";
|
||||
|
||||
import { InfioSettings } from "../../../types/settings";
|
||||
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
|
||||
|
||||
import { EventHandler } from "./types";
|
||||
|
||||
class InitState implements EventHandler {
|
||||
async handleDocumentChange(documentChanges: DocumentChanges): Promise<void> { }
|
||||
|
||||
handleSettingChanged(settings: InfioSettings): void { }
|
||||
|
||||
handleAcceptKeyPressed(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
handlePartialAcceptKeyPressed(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
handleCancelKeyPressed(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
handlePredictCommand(): void { }
|
||||
|
||||
handleAcceptCommand(): void { }
|
||||
|
||||
getStatusBarText(): string {
|
||||
return "Initializing...";
|
||||
}
|
||||
|
||||
handleFileChange(file: TFile): void {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
export default InitState;
|
||||
94
src/core/autocomplete/states/predicting-state.ts
Normal file
94
src/core/autocomplete/states/predicting-state.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
import { Notice } from "obsidian";
|
||||
|
||||
import Context from "../context-detection";
|
||||
import EventListener from "../../../event-listener";
|
||||
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
|
||||
|
||||
import State from "./state";
|
||||
|
||||
class PredictingState extends State {
|
||||
private predictionPromise: Promise<void> | null = null;
|
||||
private isStillNeeded = true;
|
||||
private readonly prefix: string;
|
||||
private readonly suffix: string;
|
||||
|
||||
constructor(context: EventListener, prefix: string, suffix: string) {
|
||||
super(context);
|
||||
this.prefix = prefix;
|
||||
this.suffix = suffix;
|
||||
}
|
||||
|
||||
static createAndStartPredicting(
|
||||
context: EventListener,
|
||||
prefix: string,
|
||||
suffix: string
|
||||
): PredictingState {
|
||||
const predictingState = new PredictingState(context, prefix, suffix);
|
||||
predictingState.startPredicting();
|
||||
context.setContext(Context.getContext(prefix, suffix));
|
||||
return predictingState;
|
||||
}
|
||||
|
||||
handleCancelKeyPressed(): boolean {
|
||||
this.cancelPrediction();
|
||||
return true;
|
||||
}
|
||||
|
||||
async handleDocumentChange(
|
||||
documentChanges: DocumentChanges
|
||||
): Promise<void> {
|
||||
if (
|
||||
documentChanges.hasCursorMoved() ||
|
||||
documentChanges.hasUserTyped() ||
|
||||
documentChanges.hasUserDeleted() ||
|
||||
documentChanges.isTextAdded()
|
||||
) {
|
||||
this.cancelPrediction();
|
||||
}
|
||||
}
|
||||
|
||||
private cancelPrediction(): void {
|
||||
this.isStillNeeded = false;
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
|
||||
startPredicting(): void {
|
||||
this.predictionPromise = this.predict();
|
||||
}
|
||||
|
||||
private async predict(): Promise<void> {
|
||||
|
||||
const result =
|
||||
await this.context.autocomplete?.fetchPredictions(
|
||||
this.prefix,
|
||||
this.suffix
|
||||
);
|
||||
|
||||
if (!this.isStillNeeded) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (result.isErr()) {
|
||||
new Notice(
|
||||
`Copilot: Something went wrong cannot make a prediction. Full error is available in the dev console. Please check your settings. `
|
||||
);
|
||||
console.error(result.error);
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
|
||||
const prediction = result.unwrapOr("");
|
||||
|
||||
if (prediction === "") {
|
||||
this.context.transitionToIdleState();
|
||||
return;
|
||||
}
|
||||
this.context.transitionToSuggestingState(prediction, this.prefix, this.suffix);
|
||||
}
|
||||
|
||||
|
||||
getStatusBarText(): string {
|
||||
return `Predicting for ${this.context.context}`;
|
||||
}
|
||||
}
|
||||
|
||||
export default PredictingState;
|
||||
84
src/core/autocomplete/states/queued-state.ts
Normal file
84
src/core/autocomplete/states/queued-state.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
import Context from "../context-detection";
|
||||
import EventListener from "../../../event-listener";
|
||||
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
|
||||
|
||||
import State from "./state";
|
||||
|
||||
|
||||
class QueuedState extends State {
|
||||
private timer: ReturnType<typeof setTimeout> | null = null;
|
||||
private readonly prefix: string;
|
||||
private readonly suffix: string;
|
||||
|
||||
|
||||
private constructor(
|
||||
context: EventListener,
|
||||
prefix: string,
|
||||
suffix: string
|
||||
) {
|
||||
super(context);
|
||||
this.prefix = prefix;
|
||||
this.suffix = suffix;
|
||||
}
|
||||
|
||||
static createAndStartTimer(
|
||||
context: EventListener,
|
||||
prefix: string,
|
||||
suffix: string
|
||||
): QueuedState {
|
||||
const state = new QueuedState(context, prefix, suffix);
|
||||
state.startTimer();
|
||||
context.setContext(Context.getContext(prefix, suffix));
|
||||
return state;
|
||||
}
|
||||
|
||||
handleCancelKeyPressed(): boolean {
|
||||
this.cancelTimer();
|
||||
this.context.transitionToIdleState();
|
||||
return true;
|
||||
}
|
||||
|
||||
async handleDocumentChange(
|
||||
documentChanges: DocumentChanges
|
||||
): Promise<void> {
|
||||
if (
|
||||
documentChanges.isDocInFocus() &&
|
||||
documentChanges.isTextAdded() &&
|
||||
this.context.containsTriggerCharacters(documentChanges)
|
||||
) {
|
||||
this.cancelTimer();
|
||||
this.context.transitionToQueuedState(documentChanges.getPrefix(), documentChanges.getSuffix());
|
||||
return
|
||||
}
|
||||
if (
|
||||
(documentChanges.hasCursorMoved() ||
|
||||
documentChanges.hasUserTyped() ||
|
||||
documentChanges.hasUserDeleted() ||
|
||||
documentChanges.isTextAdded() ||
|
||||
!documentChanges.isDocInFocus())
|
||||
) {
|
||||
this.cancelTimer();
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
}
|
||||
|
||||
startTimer(): void {
|
||||
this.cancelTimer();
|
||||
this.timer = setTimeout(() => {
|
||||
this.context.transitionToPredictingState(this.prefix, this.suffix);
|
||||
}, this.context.settings.delay);
|
||||
}
|
||||
|
||||
private cancelTimer(): void {
|
||||
if (this.timer !== null) {
|
||||
clearTimeout(this.timer);
|
||||
this.timer = null;
|
||||
}
|
||||
}
|
||||
|
||||
getStatusBarText(): string {
|
||||
return `Queued (${this.context.settings.delay} ms)`;
|
||||
}
|
||||
}
|
||||
|
||||
export default QueuedState;
|
||||
67
src/core/autocomplete/states/state.ts
Normal file
67
src/core/autocomplete/states/state.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
import { Notice, TFile } from "obsidian";
|
||||
|
||||
import EventListener from "../../../event-listener";
|
||||
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
|
||||
// import { Settings } from "../settings/versions";
|
||||
import { InfioSettings } from "../../../types/settings";
|
||||
import { checkForErrors } from "../../../utils/auto-complete";
|
||||
|
||||
import { EventHandler } from "./types";
|
||||
|
||||
abstract class State implements EventHandler {
|
||||
protected readonly context: EventListener;
|
||||
|
||||
constructor(context: EventListener) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
handleSettingChanged(settings: InfioSettings): void {
|
||||
const settingErrors = checkForErrors(settings);
|
||||
if (!settings.autocompleteEnabled) {
|
||||
new Notice("Copilot is now disabled.");
|
||||
this.context.transitionToDisabledManualState()
|
||||
} else if (settingErrors.size > 0) {
|
||||
new Notice(
|
||||
`Copilot: There are ${settingErrors.size} errors in your settings. The plugin will be disabled until they are fixed.`
|
||||
);
|
||||
this.context.transitionToDisabledInvalidSettingsState();
|
||||
} else if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
|
||||
this.context.transitionToDisabledFileSpecificState();
|
||||
}
|
||||
}
|
||||
|
||||
async handleDocumentChange(
|
||||
documentChanges: DocumentChanges
|
||||
): Promise<void> {
|
||||
}
|
||||
|
||||
handleAcceptKeyPressed(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
handlePartialAcceptKeyPressed(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
handleCancelKeyPressed(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
handlePredictCommand(prefix: string, suffix: string): void {
|
||||
}
|
||||
|
||||
handleAcceptCommand(): void {
|
||||
}
|
||||
|
||||
abstract getStatusBarText(): string;
|
||||
|
||||
handleFileChange(file: TFile): void {
|
||||
if (this.context.isCurrentFilePathIgnored() || this.context.currentFileContainsIgnoredTag()) {
|
||||
this.context.transitionToDisabledFileSpecificState();
|
||||
} else if (this.context.isDisabled()) {
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default State;
|
||||
177
src/core/autocomplete/states/suggesting-state.ts
Normal file
177
src/core/autocomplete/states/suggesting-state.ts
Normal file
@@ -0,0 +1,177 @@
|
||||
|
||||
import { Settings } from "../../../settings/versions";
|
||||
import { extractNextWordAndRemaining } from "../utils";
|
||||
import EventListener from "../../../event-listener";
|
||||
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
|
||||
|
||||
import State from "./state";
|
||||
|
||||
class SuggestingState extends State {
|
||||
private readonly suggestion: string;
|
||||
private readonly prefix: string;
|
||||
private readonly suffix: string;
|
||||
|
||||
|
||||
constructor(context: EventListener, suggestion: string, prefix: string, suffix: string) {
|
||||
super(context);
|
||||
this.suggestion = suggestion;
|
||||
this.prefix = prefix;
|
||||
this.suffix = suffix;
|
||||
}
|
||||
|
||||
|
||||
async handleDocumentChange(
|
||||
documentChanges: DocumentChanges
|
||||
): Promise<void> {
|
||||
|
||||
if (
|
||||
documentChanges.hasCursorMoved()
|
||||
|| documentChanges.hasUserUndone()
|
||||
|| documentChanges.hasUserDeleted()
|
||||
|| documentChanges.hasUserRedone()
|
||||
|| !documentChanges.isDocInFocus()
|
||||
|| documentChanges.hasSelection()
|
||||
|| documentChanges.hasMultipleCursors()
|
||||
) {
|
||||
this.clearPrediction();
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
documentChanges.noUserEvents()
|
||||
|| !documentChanges.hasDocChanged()
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.hasUserAddedPartOfSuggestion(documentChanges)) {
|
||||
this.acceptPartialAddedText(documentChanges);
|
||||
return
|
||||
}
|
||||
|
||||
const currentPrefix = documentChanges.getPrefix();
|
||||
const currentSuffix = documentChanges.getSuffix();
|
||||
const suggestion = this.context.getCachedSuggestionFor(currentPrefix, currentSuffix);
|
||||
const isThereCachedSuggestion = suggestion !== undefined;
|
||||
const isCachedSuggestionDifferent = suggestion !== this.suggestion;
|
||||
|
||||
if (!isCachedSuggestionDifferent) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isThereCachedSuggestion) {
|
||||
this.context.transitionToSuggestingState(suggestion, currentPrefix, currentSuffix);
|
||||
return;
|
||||
}
|
||||
this.clearPrediction();
|
||||
}
|
||||
|
||||
|
||||
hasUserAddedPartOfSuggestion(documentChanges: DocumentChanges): boolean {
|
||||
const addedPrefixText = documentChanges.getAddedPrefixText();
|
||||
const addedSuffixText = documentChanges.getAddedSuffixText();
|
||||
|
||||
return addedPrefixText !== undefined
|
||||
&& addedSuffixText !== undefined
|
||||
&& this.suggestion.toLowerCase().startsWith(addedPrefixText.toLowerCase())
|
||||
&& this.suggestion.toLowerCase().endsWith(addedSuffixText.toLowerCase());
|
||||
}
|
||||
|
||||
acceptPartialAddedText(documentChanges: DocumentChanges): void {
|
||||
const addedPrefixText = documentChanges.getAddedPrefixText();
|
||||
const addedSuffixText = documentChanges.getAddedSuffixText();
|
||||
if (addedSuffixText === undefined || addedPrefixText === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
const startIdx = addedPrefixText.length;
|
||||
const endIdx = this.suggestion.length - addedSuffixText.length
|
||||
const remainingSuggestion = this.suggestion.substring(startIdx, endIdx);
|
||||
|
||||
if (remainingSuggestion.trim() === "") {
|
||||
this.clearPrediction();
|
||||
} else {
|
||||
this.context.transitionToSuggestingState(remainingSuggestion, documentChanges.getPrefix(), documentChanges.getSuffix());
|
||||
}
|
||||
}
|
||||
|
||||
private clearPrediction(): void {
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
|
||||
handleAcceptKeyPressed(): boolean {
|
||||
this.accept();
|
||||
return true;
|
||||
}
|
||||
|
||||
private accept() {
|
||||
this.addPartialSuggestionCaches(this.suggestion);
|
||||
this.context.insertCurrentSuggestion(this.suggestion);
|
||||
this.context.transitionToIdleState();
|
||||
}
|
||||
|
||||
handlePartialAcceptKeyPressed(): boolean {
|
||||
this.acceptNextWord();
|
||||
return true;
|
||||
}
|
||||
|
||||
private acceptNextWord() {
|
||||
const [nextWord, remaining] = extractNextWordAndRemaining(this.suggestion);
|
||||
|
||||
if (nextWord !== undefined && remaining !== undefined) {
|
||||
const updatedPrefix = this.prefix + nextWord;
|
||||
|
||||
this.addPartialSuggestionCaches(nextWord, remaining);
|
||||
this.context.insertCurrentSuggestion(nextWord);
|
||||
this.context.transitionToSuggestingState(remaining, updatedPrefix, this.suffix, false);
|
||||
} else {
|
||||
this.accept();
|
||||
}
|
||||
}
|
||||
|
||||
private addPartialSuggestionCaches(acceptSuggestion: string, remainingSuggestion = "") {
|
||||
// store the sub-suggestions in the cache
|
||||
// so that we can have partial suggestions if the user edits a part
|
||||
for (let i = 0; i < acceptSuggestion.length; i++) {
|
||||
const prefix = this.prefix + acceptSuggestion.substring(0, i);
|
||||
const suggestion = acceptSuggestion.substring(i) + remainingSuggestion;
|
||||
this.context.addSuggestionToCache(prefix, this.suffix, suggestion);
|
||||
}
|
||||
}
|
||||
|
||||
private getNextWordAndRemaining(): [string | undefined, string | undefined] {
|
||||
const words = this.suggestion.split(" ");
|
||||
if (words.length === 0) {
|
||||
return ["", ""];
|
||||
}
|
||||
|
||||
if (words.length === 1) {
|
||||
return [words[0] + " ", ""];
|
||||
}
|
||||
|
||||
return [words[0] + " ", words.slice(1).join(" ")];
|
||||
}
|
||||
|
||||
handleCancelKeyPressed(): boolean {
|
||||
this.context.clearSuggestionsCache();
|
||||
this.clearPrediction();
|
||||
return true;
|
||||
}
|
||||
|
||||
handleAcceptCommand() {
|
||||
this.accept();
|
||||
}
|
||||
|
||||
getStatusBarText(): string {
|
||||
return `Suggesting for ${this.context.context}`;
|
||||
}
|
||||
|
||||
handleSettingChanged(settings: Settings): void {
|
||||
if (!settings.cacheSuggestions) {
|
||||
this.clearPrediction();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
export default SuggestingState;
|
||||
25
src/core/autocomplete/states/types.ts
Normal file
25
src/core/autocomplete/states/types.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
import { TFile } from "obsidian";
|
||||
|
||||
import { InfioSettings } from "../../../types/settings";
|
||||
import { DocumentChanges } from "../../../render-plugin/document-changes-listener";
|
||||
|
||||
export interface EventHandler {
|
||||
handleSettingChanged(settings: InfioSettings): void;
|
||||
|
||||
handleDocumentChange(documentChanges: DocumentChanges): Promise<void>;
|
||||
|
||||
handleAcceptKeyPressed(): boolean;
|
||||
|
||||
handlePartialAcceptKeyPressed(): boolean;
|
||||
|
||||
handleCancelKeyPressed(): boolean;
|
||||
|
||||
handlePredictCommand(prefix: string, suffix: string): void;
|
||||
handleAcceptCommand(): void;
|
||||
|
||||
getStatusBarText(): string;
|
||||
|
||||
handleFileChange(file: TFile): void;
|
||||
|
||||
}
|
||||
57
src/core/autocomplete/types.ts
Normal file
57
src/core/autocomplete/types.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { Result } from "neverthrow";
|
||||
|
||||
import Context from "./context-detection";
|
||||
|
||||
export interface AutocompleteService {
|
||||
fetchPredictions(
|
||||
prefix: string,
|
||||
suffix: string
|
||||
): Promise<Result<string, Error>>;
|
||||
}
|
||||
|
||||
export interface PostProcessor {
|
||||
process(
|
||||
prefix: string,
|
||||
suffix: string,
|
||||
completion: string,
|
||||
context: Context
|
||||
): string;
|
||||
}
|
||||
|
||||
export interface PreProcessor {
|
||||
process(prefix: string, suffix: string, context: Context): PrefixAndSuffix;
|
||||
removesCursor(prefix: string, suffix: string): boolean;
|
||||
}
|
||||
|
||||
export interface PrefixAndSuffix {
|
||||
prefix: string;
|
||||
suffix: string;
|
||||
}
|
||||
|
||||
export interface ChatMessage {
|
||||
content: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
}
|
||||
|
||||
|
||||
export interface UserMessageFormattingInputs {
|
||||
prefix: string;
|
||||
suffix: string;
|
||||
}
|
||||
|
||||
export type UserMessageFormatter = (
|
||||
inputs: UserMessageFormattingInputs
|
||||
) => string;
|
||||
|
||||
export interface ApiClient {
|
||||
queryChatModel(messages: ChatMessage[]): Promise<Result<string, Error>>;
|
||||
checkIfConfiguredCorrectly?(): Promise<string[]>;
|
||||
}
|
||||
|
||||
export interface ModelOptions {
|
||||
temperature: number;
|
||||
top_p: number;
|
||||
frequency_penalty: number;
|
||||
presence_penalty: number;
|
||||
max_tokens: number;
|
||||
}
|
||||
68
src/core/autocomplete/utils.ts
Normal file
68
src/core/autocomplete/utils.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
import * as mm from "micromatch";
|
||||
|
||||
export function sleep(ms: number) {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
export function enumKeys<O extends object, K extends keyof O = keyof O>(
|
||||
obj: O
|
||||
): K[] {
|
||||
return Object.keys(obj).filter((k) => Number.isNaN(+k)) as K[];
|
||||
}
|
||||
|
||||
export function generateRandomString(n: number): string {
|
||||
let result = '';
|
||||
const characters = '0123456789abcdef';
|
||||
|
||||
for (let i = 0; i < n; i++) {
|
||||
const randomIndex = Math.floor(Math.random() * characters.length);
|
||||
result += characters[randomIndex];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export function isMatchBetweenPathAndPatterns(
|
||||
path: string,
|
||||
patterns: string[],
|
||||
): boolean {
|
||||
patterns = patterns
|
||||
.map(p => p.trim())
|
||||
.filter((p) => p.length > 0);
|
||||
if (patterns.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const exclusionPatterns = patterns.filter((p) => p.startsWith('!')).map(p => p.slice(1));
|
||||
const inclusionPatterns = patterns.filter((p) => !p.startsWith('!'));
|
||||
|
||||
return mm.some(path, inclusionPatterns) && !mm.some(path, exclusionPatterns);
|
||||
}
|
||||
|
||||
export function extractNextWordAndRemaining(suggestion: string): [string | undefined, string | undefined] {
|
||||
const leadingWhitespacesMatch = suggestion.match(/^(\s*)/);
|
||||
const leadingWhitespaces = leadingWhitespacesMatch ? leadingWhitespacesMatch[0] : '';
|
||||
const trimmedSuggestion = suggestion.slice(leadingWhitespaces.length);
|
||||
|
||||
|
||||
let nextWord: string | undefined;
|
||||
let remaining: string | undefined = undefined;
|
||||
|
||||
const whitespaceAfterNextWordMatch = trimmedSuggestion.match(/\s+/);
|
||||
if (!whitespaceAfterNextWordMatch) {
|
||||
nextWord = trimmedSuggestion || undefined;
|
||||
} else {
|
||||
const whitespaceAfterNextWordStartingIndex = whitespaceAfterNextWordMatch.index!;
|
||||
const whitespaceAfterNextWord = whitespaceAfterNextWordMatch[0];
|
||||
const whitespaceLength = whitespaceAfterNextWord.length;
|
||||
const startOfWhitespaceAfterNextWordIndex = whitespaceAfterNextWordStartingIndex + whitespaceLength;
|
||||
|
||||
nextWord = trimmedSuggestion.substring(0, whitespaceAfterNextWordStartingIndex);
|
||||
if (startOfWhitespaceAfterNextWordIndex < trimmedSuggestion.length) {
|
||||
remaining = trimmedSuggestion.slice(startOfWhitespaceAfterNextWordIndex);
|
||||
nextWord += whitespaceAfterNextWord;
|
||||
}
|
||||
}
|
||||
|
||||
return [nextWord ? leadingWhitespaces + nextWord : undefined, remaining];
|
||||
}
|
||||
44
src/core/edit/inline-edit-processor.ts
Normal file
44
src/core/edit/inline-edit-processor.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import { MarkdownPostProcessorContext, Plugin } from "obsidian";
|
||||
import React from "react";
|
||||
import { createRoot } from "react-dom/client";
|
||||
|
||||
import { InlineEdit as InlineEditComponent } from "../../components/inline-edit/InlineEdit";
|
||||
import { InfioSettings } from "../../types/settings";
|
||||
|
||||
export class InlineEdit {
|
||||
plugin: Plugin;
|
||||
settings: InfioSettings;
|
||||
|
||||
constructor(plugin: Plugin, settings: InfioSettings) {
|
||||
this.plugin = plugin;
|
||||
this.settings = settings;
|
||||
}
|
||||
|
||||
/**
|
||||
* Markdown 处理器入口函数
|
||||
*/
|
||||
Processor(source: string, el: HTMLElement, ctx: MarkdownPostProcessorContext) {
|
||||
const sec = ctx.getSectionInfo(el);
|
||||
if (!sec) return;
|
||||
|
||||
const container = createDiv();
|
||||
const root = createRoot(container);
|
||||
|
||||
root.render(
|
||||
React.createElement(InlineEditComponent, {
|
||||
source: source,
|
||||
secStartLine: sec.lineStart,
|
||||
secEndLine: sec.lineEnd,
|
||||
plugin: this.plugin,
|
||||
settings: this.settings
|
||||
})
|
||||
);
|
||||
|
||||
// 移除父元素的代码块样式
|
||||
const parent = el.parentElement;
|
||||
if (parent) {
|
||||
parent.addClass("infio-ai-block");
|
||||
}
|
||||
el.replaceWith(container);
|
||||
}
|
||||
}
|
||||
323
src/core/llm/anthropic.ts
Normal file
323
src/core/llm/anthropic.ts
Normal file
@@ -0,0 +1,323 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
ImageBlockParam,
|
||||
MessageParam,
|
||||
MessageStreamEvent,
|
||||
TextBlockParam,
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
RequestMessage,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
ResponseUsage,
|
||||
} from '../../types/llm/response'
|
||||
import { parseImageDataUrl } from '../../utils/image'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
} from './exception'
|
||||
|
||||
export class AnthropicProvider implements BaseLLMProvider {
|
||||
private client: Anthropic
|
||||
|
||||
private static readonly DEFAULT_MAX_TOKENS = 8192
|
||||
|
||||
constructor(apiKey: string) {
|
||||
this.client = new Anthropic({ apiKey, dangerouslyAllowBrowser: true })
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'Anthropic API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new Anthropic({
|
||||
baseURL: model.baseUrl,
|
||||
apiKey: model.apiKey,
|
||||
dangerouslyAllowBrowser: true
|
||||
})
|
||||
}
|
||||
|
||||
const systemMessage = AnthropicProvider.validateSystemMessages(
|
||||
request.messages,
|
||||
)
|
||||
|
||||
try {
|
||||
const response = await this.client.messages.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages
|
||||
.filter((m) => m.role !== 'system')
|
||||
.filter((m) => !AnthropicProvider.isMessageEmpty(m))
|
||||
.map((m) => AnthropicProvider.parseRequestMessage(m)),
|
||||
system: systemMessage,
|
||||
max_tokens:
|
||||
request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
return AnthropicProvider.parseNonStreamingResponse(response)
|
||||
} catch (error) {
|
||||
if (error instanceof Anthropic.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'Anthropic API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'Anthropic API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new Anthropic({
|
||||
baseURL: model.baseUrl,
|
||||
apiKey: model.apiKey,
|
||||
dangerouslyAllowBrowser: true
|
||||
})
|
||||
}
|
||||
|
||||
const systemMessage = AnthropicProvider.validateSystemMessages(
|
||||
request.messages,
|
||||
)
|
||||
|
||||
try {
|
||||
const stream = await this.client.messages.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages
|
||||
.filter((m) => m.role !== 'system')
|
||||
.filter((m) => !AnthropicProvider.isMessageEmpty(m))
|
||||
.map((m) => AnthropicProvider.parseRequestMessage(m)),
|
||||
system: systemMessage,
|
||||
max_tokens:
|
||||
request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
// eslint-disable-next-line no-inner-declarations
|
||||
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
|
||||
let messageId = ''
|
||||
let model = ''
|
||||
let usage: ResponseUsage = {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.type === 'message_start') {
|
||||
messageId = chunk.message.id
|
||||
model = chunk.message.model
|
||||
usage = {
|
||||
prompt_tokens: chunk.message.usage.input_tokens,
|
||||
completion_tokens: chunk.message.usage.output_tokens,
|
||||
total_tokens:
|
||||
chunk.message.usage.input_tokens +
|
||||
chunk.message.usage.output_tokens,
|
||||
}
|
||||
} else if (chunk.type === 'content_block_delta') {
|
||||
yield AnthropicProvider.parseStreamingResponseChunk(
|
||||
chunk,
|
||||
messageId,
|
||||
model,
|
||||
)
|
||||
} else if (chunk.type === 'message_delta') {
|
||||
usage = {
|
||||
prompt_tokens: usage.prompt_tokens,
|
||||
completion_tokens:
|
||||
usage.completion_tokens + chunk.usage.output_tokens,
|
||||
total_tokens: usage.total_tokens + chunk.usage.output_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// After the stream is complete, yield the final usage
|
||||
yield {
|
||||
id: messageId,
|
||||
choices: [],
|
||||
object: 'chat.completion.chunk',
|
||||
model: model,
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
return streamResponse()
|
||||
} catch (error) {
|
||||
if (error instanceof Anthropic.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'Anthropic API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
static parseRequestMessage(message: RequestMessage): MessageParam {
|
||||
if (message.role !== 'user' && message.role !== 'assistant') {
|
||||
throw new Error(`Anthropic does not support role: ${message.role}`)
|
||||
}
|
||||
|
||||
if (message.role === 'user' && Array.isArray(message.content)) {
|
||||
const content = message.content.map(
|
||||
(part): TextBlockParam | ImageBlockParam => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return { type: 'text', text: part.text }
|
||||
case 'image_url': {
|
||||
const { mimeType, base64Data } = parseImageDataUrl(
|
||||
part.image_url.url,
|
||||
)
|
||||
AnthropicProvider.validateImageType(mimeType)
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data,
|
||||
media_type:
|
||||
mimeType as ImageBlockParam['source']['media_type'],
|
||||
type: 'base64',
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
return { role: 'user', content }
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.content as string,
|
||||
}
|
||||
}
|
||||
|
||||
static parseNonStreamingResponse(
|
||||
response: Anthropic.Message,
|
||||
): LLMResponseNonStreaming {
|
||||
if (response.content[0].type === 'tool_use') {
|
||||
throw new Error('Unsupported content type: tool_use')
|
||||
}
|
||||
return {
|
||||
id: response.id,
|
||||
choices: [
|
||||
{
|
||||
finish_reason: response.stop_reason,
|
||||
message: {
|
||||
content: response.content[0].text,
|
||||
role: response.role,
|
||||
},
|
||||
},
|
||||
],
|
||||
model: response.model,
|
||||
object: 'chat.completion',
|
||||
usage: {
|
||||
prompt_tokens: response.usage.input_tokens,
|
||||
completion_tokens: response.usage.output_tokens,
|
||||
total_tokens:
|
||||
response.usage.input_tokens + response.usage.output_tokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
static parseStreamingResponseChunk(
|
||||
chunk: MessageStreamEvent,
|
||||
messageId: string,
|
||||
model: string,
|
||||
): LLMResponseStreaming {
|
||||
if (chunk.type !== 'content_block_delta') {
|
||||
throw new Error('Unsupported chunk type')
|
||||
}
|
||||
if (chunk.delta.type === 'input_json_delta') {
|
||||
throw new Error('Unsupported content type: input_json_delta')
|
||||
}
|
||||
return {
|
||||
id: messageId,
|
||||
choices: [
|
||||
{
|
||||
finish_reason: null,
|
||||
delta: {
|
||||
content: chunk.delta.text,
|
||||
},
|
||||
},
|
||||
],
|
||||
object: 'chat.completion.chunk',
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
private static validateSystemMessages(
|
||||
messages: RequestMessage[],
|
||||
): string | undefined {
|
||||
const systemMessages = messages.filter((m) => m.role === 'system')
|
||||
if (systemMessages.length > 1) {
|
||||
throw new Error(`Anthropic does not support more than one system message`)
|
||||
}
|
||||
const systemMessage =
|
||||
systemMessages.length > 0 ? systemMessages[0].content : undefined
|
||||
if (systemMessage && typeof systemMessage !== 'string') {
|
||||
throw new Error(
|
||||
`Anthropic only supports string content for system messages`,
|
||||
)
|
||||
}
|
||||
return systemMessage
|
||||
}
|
||||
|
||||
private static isMessageEmpty(message: RequestMessage) {
|
||||
if (typeof message.content === 'string') {
|
||||
return message.content.trim() === ''
|
||||
}
|
||||
return message.content.length === 0
|
||||
}
|
||||
|
||||
private static validateImageType(mimeType: string) {
|
||||
const SUPPORTED_IMAGE_TYPES = [
|
||||
'image/jpeg',
|
||||
'image/png',
|
||||
'image/gif',
|
||||
'image/webp',
|
||||
]
|
||||
if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
|
||||
throw new Error(
|
||||
`Anthropic does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
|
||||
', ',
|
||||
)}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
23
src/core/llm/base.ts
Normal file
23
src/core/llm/base.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
export type BaseLLMProvider = {
|
||||
generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming>
|
||||
streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>>
|
||||
}
|
||||
34
src/core/llm/exception.ts
Normal file
34
src/core/llm/exception.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
export class LLMAPIKeyNotSetException extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'LLMAPIKeyNotSetException'
|
||||
}
|
||||
}
|
||||
|
||||
export class LLMAPIKeyInvalidException extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'LLMAPIKeyInvalidException'
|
||||
}
|
||||
}
|
||||
|
||||
export class LLMBaseUrlNotSetException extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'LLMBaseUrlNotSetException'
|
||||
}
|
||||
}
|
||||
|
||||
export class LLMModelNotSetException extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'LLMModelNotSetException'
|
||||
}
|
||||
}
|
||||
|
||||
export class LLMRateLimitExceededException extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'LLMRateLimitExceededException'
|
||||
}
|
||||
}
|
||||
299
src/core/llm/gemini.ts
Normal file
299
src/core/llm/gemini.ts
Normal file
@@ -0,0 +1,299 @@
|
||||
import {
|
||||
Content,
|
||||
EnhancedGenerateContentResponse,
|
||||
GenerateContentResult,
|
||||
GenerateContentStreamResult,
|
||||
GoogleGenerativeAI,
|
||||
} from '@google/generative-ai'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
RequestMessage,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
import { parseImageDataUrl } from '../../utils/image'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
} from './exception'
|
||||
|
||||
/**
|
||||
* Note on OpenAI Compatibility API:
|
||||
* Gemini provides an OpenAI-compatible endpoint (https://ai.google.dev/gemini-api/docs/openai)
|
||||
* which allows using the OpenAI SDK with Gemini models. However, there are currently CORS issues
|
||||
* preventing its use in Obsidian. Consider switching to this endpoint in the future once these
|
||||
* issues are resolved.
|
||||
*/
|
||||
export class GeminiProvider implements BaseLLMProvider {
|
||||
private client: GoogleGenerativeAI
|
||||
private apiKey: string
|
||||
|
||||
constructor(apiKey: string) {
|
||||
this.apiKey = apiKey
|
||||
this.client = new GoogleGenerativeAI(apiKey)
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!this.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
`Gemini API key is missing. Please set it in settings menu.`,
|
||||
)
|
||||
}
|
||||
this.apiKey = model.apiKey
|
||||
this.client = new GoogleGenerativeAI(model.apiKey)
|
||||
}
|
||||
|
||||
const systemMessages = request.messages.filter((m) => m.role === 'system')
|
||||
const systemInstruction: string | undefined =
|
||||
systemMessages.length > 0
|
||||
? systemMessages.map((m) => m.content).join('\n')
|
||||
: undefined
|
||||
|
||||
try {
|
||||
const model = this.client.getGenerativeModel({
|
||||
model: request.model,
|
||||
generationConfig: {
|
||||
maxOutputTokens: request.max_tokens,
|
||||
temperature: request.temperature,
|
||||
topP: request.top_p,
|
||||
presencePenalty: request.presence_penalty,
|
||||
frequencyPenalty: request.frequency_penalty,
|
||||
},
|
||||
systemInstruction: systemInstruction,
|
||||
})
|
||||
|
||||
const result = await model.generateContent(
|
||||
{
|
||||
systemInstruction: systemInstruction,
|
||||
contents: request.messages
|
||||
.map((message) => GeminiProvider.parseRequestMessage(message))
|
||||
.filter((m): m is Content => m !== null),
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
const messageId = crypto.randomUUID() // Gemini does not return a message id
|
||||
return GeminiProvider.parseNonStreamingResponse(
|
||||
result,
|
||||
request.model,
|
||||
messageId,
|
||||
)
|
||||
} catch (error) {
|
||||
const isInvalidApiKey =
|
||||
error.message?.includes('API_KEY_INVALID') ||
|
||||
error.message?.includes('API key not valid')
|
||||
|
||||
if (isInvalidApiKey) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
`Gemini API key is invalid. Please update it in settings menu.`,
|
||||
)
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!this.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
`Gemini API key is missing. Please set it in settings menu.`,
|
||||
)
|
||||
}
|
||||
this.apiKey = model.apiKey
|
||||
this.client = new GoogleGenerativeAI(model.apiKey)
|
||||
}
|
||||
|
||||
const systemMessages = request.messages.filter((m) => m.role === 'system')
|
||||
const systemInstruction: string | undefined =
|
||||
systemMessages.length > 0
|
||||
? systemMessages.map((m) => m.content).join('\n')
|
||||
: undefined
|
||||
|
||||
try {
|
||||
const model = this.client.getGenerativeModel({
|
||||
model: request.model,
|
||||
generationConfig: {
|
||||
maxOutputTokens: request.max_tokens,
|
||||
temperature: request.temperature,
|
||||
topP: request.top_p,
|
||||
presencePenalty: request.presence_penalty,
|
||||
frequencyPenalty: request.frequency_penalty,
|
||||
},
|
||||
systemInstruction: systemInstruction,
|
||||
})
|
||||
|
||||
const stream = await model.generateContentStream(
|
||||
{
|
||||
systemInstruction: systemInstruction,
|
||||
contents: request.messages
|
||||
.map((message) => GeminiProvider.parseRequestMessage(message))
|
||||
.filter((m): m is Content => m !== null),
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
const messageId = crypto.randomUUID() // Gemini does not return a message id
|
||||
return this.streamResponseGenerator(stream, request.model, messageId)
|
||||
} catch (error) {
|
||||
const isInvalidApiKey =
|
||||
error.message?.includes('API_KEY_INVALID') ||
|
||||
error.message?.includes('API key not valid')
|
||||
|
||||
if (isInvalidApiKey) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
`Gemini API key is invalid. Please update it in settings menu.`,
|
||||
)
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private async *streamResponseGenerator(
|
||||
stream: GenerateContentStreamResult,
|
||||
model: string,
|
||||
messageId: string,
|
||||
): AsyncIterable<LLMResponseStreaming> {
|
||||
for await (const chunk of stream.stream) {
|
||||
yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
|
||||
}
|
||||
}
|
||||
|
||||
static parseRequestMessage(message: RequestMessage): Content | null {
|
||||
if (message.role === 'system') {
|
||||
return null
|
||||
}
|
||||
|
||||
if (Array.isArray(message.content)) {
|
||||
return {
|
||||
role: message.role === 'user' ? 'user' : 'model',
|
||||
parts: message.content.map((part) => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return { text: part.text }
|
||||
case 'image_url': {
|
||||
const { mimeType, base64Data } = parseImageDataUrl(
|
||||
part.image_url.url,
|
||||
)
|
||||
GeminiProvider.validateImageType(mimeType)
|
||||
|
||||
return {
|
||||
inlineData: {
|
||||
data: base64Data,
|
||||
mimeType,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'user' ? 'user' : 'model',
|
||||
parts: [
|
||||
{
|
||||
text: message.content,
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
static parseNonStreamingResponse(
|
||||
response: GenerateContentResult,
|
||||
model: string,
|
||||
messageId: string,
|
||||
): LLMResponseNonStreaming {
|
||||
return {
|
||||
id: messageId,
|
||||
choices: [
|
||||
{
|
||||
finish_reason:
|
||||
response.response.candidates?.[0]?.finishReason ?? null,
|
||||
message: {
|
||||
content: response.response.text(),
|
||||
role: 'assistant',
|
||||
},
|
||||
},
|
||||
],
|
||||
created: Date.now(),
|
||||
model: model,
|
||||
object: 'chat.completion',
|
||||
usage: response.response.usageMetadata
|
||||
? {
|
||||
prompt_tokens: response.response.usageMetadata.promptTokenCount,
|
||||
completion_tokens:
|
||||
response.response.usageMetadata.candidatesTokenCount,
|
||||
total_tokens: response.response.usageMetadata.totalTokenCount,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
static parseStreamingResponseChunk(
|
||||
chunk: EnhancedGenerateContentResponse,
|
||||
model: string,
|
||||
messageId: string,
|
||||
): LLMResponseStreaming {
|
||||
return {
|
||||
id: messageId,
|
||||
choices: [
|
||||
{
|
||||
finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
|
||||
delta: {
|
||||
content: chunk.text(),
|
||||
},
|
||||
},
|
||||
],
|
||||
created: Date.now(),
|
||||
model: model,
|
||||
object: 'chat.completion.chunk',
|
||||
usage: chunk.usageMetadata
|
||||
? {
|
||||
prompt_tokens: chunk.usageMetadata.promptTokenCount,
|
||||
completion_tokens: chunk.usageMetadata.candidatesTokenCount,
|
||||
total_tokens: chunk.usageMetadata.totalTokenCount,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
private static validateImageType(mimeType: string) {
|
||||
const SUPPORTED_IMAGE_TYPES = [
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/webp',
|
||||
'image/heic',
|
||||
'image/heif',
|
||||
]
|
||||
if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
|
||||
throw new Error(
|
||||
`Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
|
||||
', ',
|
||||
)}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
200
src/core/llm/groq.ts
Normal file
200
src/core/llm/groq.ts
Normal file
@@ -0,0 +1,200 @@
|
||||
import Groq from 'groq-sdk'
|
||||
import {
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionMessageParam,
|
||||
} from 'groq-sdk/resources/chat/completions'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
RequestMessage,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
} from './exception'
|
||||
|
||||
export class GroqProvider implements BaseLLMProvider {
|
||||
private client: Groq
|
||||
|
||||
constructor(apiKey: string) {
|
||||
this.client = new Groq({
|
||||
apiKey,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'Groq API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new Groq({
|
||||
apiKey: model.apiKey,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await this.client.chat.completions.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages.map((m) =>
|
||||
GroqProvider.parseRequestMessage(m),
|
||||
),
|
||||
max_tokens: request.max_tokens,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
return GroqProvider.parseNonStreamingResponse(response)
|
||||
} catch (error) {
|
||||
if (error instanceof Groq.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'Groq API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'Groq API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new Groq({
|
||||
apiKey: model.apiKey,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
const stream = await this.client.chat.completions.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages.map((m) =>
|
||||
GroqProvider.parseRequestMessage(m),
|
||||
),
|
||||
max_tokens: request.max_tokens,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
// eslint-disable-next-line no-inner-declarations
|
||||
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
|
||||
for await (const chunk of stream) {
|
||||
yield GroqProvider.parseStreamingResponseChunk(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
return streamResponse()
|
||||
} catch (error) {
|
||||
if (error instanceof Groq.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'Groq API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
static parseRequestMessage(
|
||||
message: RequestMessage,
|
||||
): ChatCompletionMessageParam {
|
||||
switch (message.role) {
|
||||
case 'user': {
|
||||
const content = Array.isArray(message.content)
|
||||
? message.content.map((part): ChatCompletionContentPart => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return { type: 'text', text: part.text }
|
||||
case 'image_url':
|
||||
return { type: 'image_url', image_url: part.image_url }
|
||||
}
|
||||
})
|
||||
: message.content
|
||||
return { role: 'user', content }
|
||||
}
|
||||
case 'assistant': {
|
||||
if (Array.isArray(message.content)) {
|
||||
throw new Error('Assistant message should be a string')
|
||||
}
|
||||
return { role: 'assistant', content: message.content }
|
||||
}
|
||||
case 'system': {
|
||||
if (Array.isArray(message.content)) {
|
||||
throw new Error('System message should be a string')
|
||||
}
|
||||
return { role: 'system', content: message.content }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static parseNonStreamingResponse(
|
||||
response: ChatCompletion,
|
||||
): LLMResponseNonStreaming {
|
||||
return {
|
||||
id: response.id,
|
||||
choices: response.choices.map((choice) => ({
|
||||
finish_reason: choice.finish_reason,
|
||||
message: {
|
||||
content: choice.message.content,
|
||||
role: choice.message.role,
|
||||
},
|
||||
})),
|
||||
created: response.created,
|
||||
model: response.model,
|
||||
object: 'chat.completion',
|
||||
usage: response.usage,
|
||||
}
|
||||
}
|
||||
|
||||
static parseStreamingResponseChunk(
|
||||
chunk: ChatCompletionChunk,
|
||||
): LLMResponseStreaming {
|
||||
return {
|
||||
id: chunk.id,
|
||||
choices: chunk.choices.map((choice) => ({
|
||||
finish_reason: choice.finish_reason ?? null,
|
||||
delta: {
|
||||
content: choice.delta.content ?? null,
|
||||
role: choice.delta.role,
|
||||
},
|
||||
})),
|
||||
created: chunk.created,
|
||||
model: chunk.model,
|
||||
object: 'chat.completion.chunk',
|
||||
}
|
||||
}
|
||||
}
|
||||
252
src/core/llm/infio.ts
Normal file
252
src/core/llm/infio.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
import OpenAI from 'openai'
|
||||
import {
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
} from 'openai/resources/chat/completions'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
RequestMessage,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
} from './exception'
|
||||
|
||||
export interface RangeFilter {
|
||||
gte?: number;
|
||||
lte?: number;
|
||||
}
|
||||
|
||||
export interface ChunkFilter {
|
||||
field: string;
|
||||
match_all?: string[];
|
||||
range?: RangeFilter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for making requests to the Infio API
|
||||
*/
|
||||
export interface InfioRequest {
|
||||
/** Required: The content of the user message to attach to the topic and then generate an assistant message in response to */
|
||||
messages: RequestMessage[];
|
||||
// /** Required: The ID of the topic to attach the message to */
|
||||
// topic_id: string;
|
||||
/** Optional: URLs to include */
|
||||
links?: string[];
|
||||
/** Optional: Files to include */
|
||||
files?: string[];
|
||||
/** Optional: Whether to highlight results in chunk_html. Default is true */
|
||||
highlight_results?: boolean;
|
||||
/** Optional: Delimiters for highlighting citations. Default is [".", "!", "?", "\n", "\t", ","] */
|
||||
highlight_delimiters?: string[];
|
||||
/** Optional: Search type - "semantic", "fulltext", or "hybrid". Default is "hybrid" */
|
||||
search_type?: string;
|
||||
/** Optional: Filters for chunk filtering */
|
||||
filters?: ChunkFilter;
|
||||
/** Optional: Whether to use web search API. Default is false */
|
||||
use_web_search?: boolean;
|
||||
/** Optional: LLM model to use */
|
||||
llm_model?: string;
|
||||
/** Optional: Force source */
|
||||
force_source?: string;
|
||||
/** Optional: Whether completion should come before chunks in stream. Default is false */
|
||||
completion_first?: boolean;
|
||||
/** Optional: Whether to stream the response. Default is true */
|
||||
stream_response?: boolean;
|
||||
/** Optional: Sampling temperature between 0 and 2. Default is 0.5 */
|
||||
temperature?: number;
|
||||
/** Optional: Frequency penalty between -2.0 and 2.0. Default is 0.7 */
|
||||
frequency_penalty?: number;
|
||||
/** Optional: Presence penalty between -2.0 and 2.0. Default is 0.7 */
|
||||
presence_penalty?: number;
|
||||
/** Optional: Maximum tokens to generate */
|
||||
max_tokens?: number;
|
||||
/** Optional: Stop tokens (up to 4 sequences) */
|
||||
stop_tokens?: string[];
|
||||
}
|
||||
|
||||
export class InfioProvider implements BaseLLMProvider {
|
||||
// private adapter: OpenAIMessageAdapter
|
||||
// private client: OpenAI
|
||||
private apiKey: string
|
||||
private baseUrl: string
|
||||
|
||||
constructor(apiKey: string) {
|
||||
// this.client = new OpenAI({ apiKey, dangerouslyAllowBrowser: true })
|
||||
// this.adapter = new OpenAIMessageAdapter()
|
||||
this.apiKey = apiKey
|
||||
this.baseUrl = 'https://api.infio.com/api/raw_message'
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!this.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'OpenAI API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
try {
|
||||
const req: InfioRequest = {
|
||||
messages: request.messages,
|
||||
stream_response: false,
|
||||
temperature: request.temperature,
|
||||
frequency_penalty: request.frequency_penalty,
|
||||
presence_penalty: request.presence_penalty,
|
||||
max_tokens: request.max_tokens,
|
||||
}
|
||||
const options = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: this.apiKey,
|
||||
"TR-Dataset": "74aaec22-0cf0-4cba-80a5-ae5c0518344e",
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify(req)
|
||||
};
|
||||
|
||||
const response = await fetch(this.baseUrl, options);
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const data = await response.json() as ChatCompletion;
|
||||
return InfioProvider.parseNonStreamingResponse(data);
|
||||
} catch (error) {
|
||||
if (error instanceof OpenAI.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'OpenAI API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!this.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'OpenAI API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
try {
|
||||
const req: InfioRequest = {
|
||||
llm_model: request.model,
|
||||
messages: request.messages,
|
||||
stream_response: true,
|
||||
temperature: request.temperature,
|
||||
frequency_penalty: request.frequency_penalty,
|
||||
presence_penalty: request.presence_penalty,
|
||||
max_tokens: request.max_tokens,
|
||||
}
|
||||
const options = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: this.apiKey,
|
||||
"TR-Dataset": "74aaec22-0cf0-4cba-80a5-ae5c0518344e",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
body: JSON.stringify(req)
|
||||
};
|
||||
|
||||
const response = await fetch(this.baseUrl, options);
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
if (!response.body) {
|
||||
throw new Error('Response body is null');
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
return {
|
||||
[Symbol.asyncIterator]: async function* () {
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
const lines = chunk.split('\n').filter(line => line.trim());
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const jsonData = JSON.parse(line.slice(6)) as ChatCompletionChunk;
|
||||
if (!jsonData || typeof jsonData !== 'object' || !('choices' in jsonData)) {
|
||||
throw new Error('Invalid chunk format received');
|
||||
}
|
||||
yield InfioProvider.parseStreamingResponseChunk(jsonData);
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
};
|
||||
} catch (error) {
|
||||
if (error instanceof OpenAI.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'OpenAI API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
static parseNonStreamingResponse(
|
||||
response: ChatCompletion,
|
||||
): LLMResponseNonStreaming {
|
||||
return {
|
||||
id: response.id,
|
||||
choices: response.choices.map((choice) => ({
|
||||
finish_reason: choice.finish_reason,
|
||||
message: {
|
||||
content: choice.message.content,
|
||||
role: choice.message.role,
|
||||
},
|
||||
})),
|
||||
created: response.created,
|
||||
model: response.model,
|
||||
object: 'chat.completion',
|
||||
system_fingerprint: response.system_fingerprint,
|
||||
usage: response.usage,
|
||||
}
|
||||
}
|
||||
|
||||
static parseStreamingResponseChunk(
|
||||
chunk: ChatCompletionChunk,
|
||||
): LLMResponseStreaming {
|
||||
return {
|
||||
id: chunk.id,
|
||||
choices: chunk.choices.map((choice) => ({
|
||||
finish_reason: choice.finish_reason ?? null,
|
||||
delta: {
|
||||
content: choice.delta.content ?? null,
|
||||
role: choice.delta.role,
|
||||
},
|
||||
})),
|
||||
created: chunk.created,
|
||||
model: chunk.model,
|
||||
object: 'chat.completion.chunk',
|
||||
system_fingerprint: chunk.system_fingerprint,
|
||||
usage: chunk.usage ?? undefined,
|
||||
}
|
||||
}
|
||||
}
|
||||
142
src/core/llm/manager.ts
Normal file
142
src/core/llm/manager.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
import { DEEPSEEK_BASE_URL } from '../../constants'
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
import { AnthropicProvider } from './anthropic'
|
||||
import { GeminiProvider } from './gemini'
|
||||
import { GroqProvider } from './groq'
|
||||
import { InfioProvider } from './infio'
|
||||
import { OllamaProvider } from './ollama'
|
||||
import { OpenAIAuthenticatedProvider } from './openai'
|
||||
import { OpenAICompatibleProvider } from './openai-compatible-provider'
|
||||
|
||||
|
||||
export type LLMManagerInterface = {
|
||||
generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming>
|
||||
streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>>
|
||||
}
|
||||
|
||||
class LLMManager implements LLMManagerInterface {
|
||||
private openaiProvider: OpenAIAuthenticatedProvider
|
||||
private deepseekProvider: OpenAICompatibleProvider
|
||||
private anthropicProvider: AnthropicProvider
|
||||
private googleProvider: GeminiProvider
|
||||
private groqProvider: GroqProvider
|
||||
private infioProvider: InfioProvider
|
||||
private ollamaProvider: OllamaProvider
|
||||
private isInfioEnabled: boolean
|
||||
|
||||
constructor(apiKeys: {
|
||||
deepseek?: string
|
||||
openai?: string
|
||||
anthropic?: string
|
||||
gemini?: string
|
||||
groq?: string
|
||||
infio?: string
|
||||
}) {
|
||||
this.deepseekProvider = new OpenAICompatibleProvider(apiKeys.deepseek ?? '', DEEPSEEK_BASE_URL)
|
||||
this.openaiProvider = new OpenAIAuthenticatedProvider(apiKeys.openai ?? '')
|
||||
this.anthropicProvider = new AnthropicProvider(apiKeys.anthropic ?? '')
|
||||
this.googleProvider = new GeminiProvider(apiKeys.gemini ?? '')
|
||||
this.groqProvider = new GroqProvider(apiKeys.groq ?? '')
|
||||
this.infioProvider = new InfioProvider(apiKeys.infio ?? '')
|
||||
this.ollamaProvider = new OllamaProvider()
|
||||
this.isInfioEnabled = !!apiKeys.infio
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (this.isInfioEnabled) {
|
||||
return await this.infioProvider.generateResponse(
|
||||
model,
|
||||
request,
|
||||
options,
|
||||
)
|
||||
}
|
||||
// use custom provider
|
||||
switch (model.provider) {
|
||||
case 'deepseek':
|
||||
return await this.deepseekProvider.generateResponse(
|
||||
model,
|
||||
request,
|
||||
options,
|
||||
)
|
||||
case 'openai':
|
||||
return await this.openaiProvider.generateResponse(
|
||||
model,
|
||||
request,
|
||||
options,
|
||||
)
|
||||
case 'anthropic':
|
||||
return await this.anthropicProvider.generateResponse(
|
||||
model,
|
||||
request,
|
||||
options,
|
||||
)
|
||||
case 'google':
|
||||
return await this.googleProvider.generateResponse(
|
||||
model,
|
||||
request,
|
||||
options,
|
||||
)
|
||||
case 'groq':
|
||||
return await this.groqProvider.generateResponse(model, request, options)
|
||||
case 'ollama':
|
||||
return await this.ollamaProvider.generateResponse(
|
||||
model,
|
||||
request,
|
||||
options,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (this.isInfioEnabled) {
|
||||
return await this.infioProvider.streamResponse(model, request, options)
|
||||
}
|
||||
// use custom provider
|
||||
switch (model.provider) {
|
||||
case 'deepseek':
|
||||
return await this.deepseekProvider.streamResponse(model, request, options)
|
||||
case 'openai':
|
||||
return await this.openaiProvider.streamResponse(model, request, options)
|
||||
case 'anthropic':
|
||||
return await this.anthropicProvider.streamResponse(
|
||||
model,
|
||||
request,
|
||||
options,
|
||||
)
|
||||
case 'google':
|
||||
return await this.googleProvider.streamResponse(model, request, options)
|
||||
case 'groq':
|
||||
return await this.groqProvider.streamResponse(model, request, options)
|
||||
case 'ollama':
|
||||
return await this.ollamaProvider.streamResponse(model, request, options)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default LLMManager
|
||||
104
src/core/llm/ollama.ts
Normal file
104
src/core/llm/ollama.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
/**
|
||||
* This provider is nearly identical to OpenAICompatibleProvider, but uses a custom OpenAI client
|
||||
* (NoStainlessOpenAI) to work around CORS issues specific to Ollama.
|
||||
*/
|
||||
|
||||
import OpenAI from 'openai'
|
||||
import { FinalRequestOptions } from 'openai/core'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import { LLMBaseUrlNotSetException, LLMModelNotSetException } from './exception'
|
||||
import { OpenAIMessageAdapter } from './openai-message-adapter'
|
||||
|
||||
export class NoStainlessOpenAI extends OpenAI {
|
||||
defaultHeaders() {
|
||||
return {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
}
|
||||
|
||||
buildRequest<Req>(
|
||||
options: FinalRequestOptions<Req>,
|
||||
{ retryCount = 0 }: { retryCount?: number } = {},
|
||||
): { req: RequestInit; url: string; timeout: number } {
|
||||
const req = super.buildRequest(options, { retryCount })
|
||||
const headers = req.req.headers as Record<string, string>
|
||||
Object.keys(headers).forEach((k) => {
|
||||
if (k.startsWith('x-stainless')) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-dynamic-delete
|
||||
delete headers[k]
|
||||
}
|
||||
})
|
||||
return req
|
||||
}
|
||||
}
|
||||
|
||||
export class OllamaProvider implements BaseLLMProvider {
|
||||
private adapter: OpenAIMessageAdapter
|
||||
|
||||
constructor() {
|
||||
this.adapter = new OpenAIMessageAdapter()
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!model.baseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama base URL is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
if (!model.name) {
|
||||
throw new LLMModelNotSetException(
|
||||
'Ollama model is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
const client = new NoStainlessOpenAI({
|
||||
baseURL: `${model.baseUrl}/v1`,
|
||||
apiKey: '',
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
return this.adapter.generateResponse(client, request, options)
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!model.baseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama base URL is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
if (!model.name) {
|
||||
throw new LLMModelNotSetException(
|
||||
'Ollama model is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
const client = new NoStainlessOpenAI({
|
||||
baseURL: `${model.baseUrl}/v1`,
|
||||
apiKey: '',
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
return this.adapter.streamResponse(client, request, options)
|
||||
}
|
||||
}
|
||||
62
src/core/llm/openai-compatible-provider.ts
Normal file
62
src/core/llm/openai-compatible-provider.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import { LLMBaseUrlNotSetException } from './exception'
|
||||
import { OpenAIMessageAdapter } from './openai-message-adapter'
|
||||
|
||||
export class OpenAICompatibleProvider implements BaseLLMProvider {
|
||||
private adapter: OpenAIMessageAdapter
|
||||
private client: OpenAI
|
||||
private apiKey: string
|
||||
private baseURL: string
|
||||
|
||||
constructor(apiKey: string, baseURL: string) {
|
||||
this.adapter = new OpenAIMessageAdapter()
|
||||
this.client = new OpenAI({
|
||||
apiKey: apiKey,
|
||||
baseURL: baseURL,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
this.apiKey = apiKey
|
||||
this.baseURL = baseURL
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!this.baseURL || !this.apiKey) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
return this.adapter.generateResponse(this.client, request, options)
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!this.baseURL || !this.apiKey) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'OpenAI Compatible base URL or API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
|
||||
return this.adapter.streamResponse(this.client, request, options)
|
||||
}
|
||||
}
|
||||
155
src/core/llm/openai-message-adapter.ts
Normal file
155
src/core/llm/openai-message-adapter.ts
Normal file
@@ -0,0 +1,155 @@
|
||||
import OpenAI from 'openai'
|
||||
import {
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionMessageParam,
|
||||
} from 'openai/resources/chat/completions'
|
||||
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
RequestMessage,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
export class OpenAIMessageAdapter {
|
||||
async generateResponse(
|
||||
client: OpenAI,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
const response = await client.chat.completions.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages.map((m) =>
|
||||
OpenAIMessageAdapter.parseRequestMessage(m),
|
||||
),
|
||||
max_tokens: request.max_tokens,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
frequency_penalty: request.frequency_penalty,
|
||||
presence_penalty: request.presence_penalty,
|
||||
logit_bias: request.logit_bias,
|
||||
prediction: request.prediction,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
return OpenAIMessageAdapter.parseNonStreamingResponse(response)
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
client: OpenAI,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
const stream = await client.chat.completions.create(
|
||||
{
|
||||
model: request.model,
|
||||
messages: request.messages.map((m) =>
|
||||
OpenAIMessageAdapter.parseRequestMessage(m),
|
||||
),
|
||||
max_completion_tokens: request.max_tokens,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
frequency_penalty: request.frequency_penalty,
|
||||
presence_penalty: request.presence_penalty,
|
||||
logit_bias: request.logit_bias,
|
||||
stream: true,
|
||||
stream_options: {
|
||||
include_usage: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
)
|
||||
|
||||
// eslint-disable-next-line no-inner-declarations
|
||||
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
|
||||
for await (const chunk of stream) {
|
||||
yield OpenAIMessageAdapter.parseStreamingResponseChunk(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
return streamResponse()
|
||||
}
|
||||
|
||||
static parseRequestMessage(
|
||||
message: RequestMessage,
|
||||
): ChatCompletionMessageParam {
|
||||
switch (message.role) {
|
||||
case 'user': {
|
||||
const content = Array.isArray(message.content)
|
||||
? message.content.map((part): ChatCompletionContentPart => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return { type: 'text', text: part.text }
|
||||
case 'image_url':
|
||||
return { type: 'image_url', image_url: part.image_url }
|
||||
}
|
||||
})
|
||||
: message.content
|
||||
return { role: 'user', content }
|
||||
}
|
||||
case 'assistant': {
|
||||
if (Array.isArray(message.content)) {
|
||||
throw new Error('Assistant message should be a string')
|
||||
}
|
||||
return { role: 'assistant', content: message.content }
|
||||
}
|
||||
case 'system': {
|
||||
if (Array.isArray(message.content)) {
|
||||
throw new Error('System message should be a string')
|
||||
}
|
||||
return { role: 'system', content: message.content }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static parseNonStreamingResponse(
|
||||
response: ChatCompletion,
|
||||
): LLMResponseNonStreaming {
|
||||
return {
|
||||
id: response.id,
|
||||
choices: response.choices.map((choice) => ({
|
||||
finish_reason: choice.finish_reason,
|
||||
message: {
|
||||
content: choice.message.content,
|
||||
role: choice.message.role,
|
||||
},
|
||||
})),
|
||||
created: response.created,
|
||||
model: response.model,
|
||||
object: 'chat.completion',
|
||||
system_fingerprint: response.system_fingerprint,
|
||||
usage: response.usage,
|
||||
}
|
||||
}
|
||||
|
||||
static parseStreamingResponseChunk(
|
||||
chunk: ChatCompletionChunk,
|
||||
): LLMResponseStreaming {
|
||||
return {
|
||||
id: chunk.id,
|
||||
choices: chunk.choices.map((choice) => ({
|
||||
finish_reason: choice.finish_reason ?? null,
|
||||
delta: {
|
||||
content: choice.delta.content ?? null,
|
||||
role: choice.delta.role,
|
||||
},
|
||||
})),
|
||||
created: chunk.created,
|
||||
model: chunk.model,
|
||||
object: 'chat.completion.chunk',
|
||||
system_fingerprint: chunk.system_fingerprint,
|
||||
usage: chunk.usage ?? undefined,
|
||||
}
|
||||
}
|
||||
}
|
||||
91
src/core/llm/openai.ts
Normal file
91
src/core/llm/openai.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CustomLLMModel } from '../../types/llm/model'
|
||||
import {
|
||||
LLMOptions,
|
||||
LLMRequestNonStreaming,
|
||||
LLMRequestStreaming,
|
||||
} from '../../types/llm/request'
|
||||
import {
|
||||
LLMResponseNonStreaming,
|
||||
LLMResponseStreaming,
|
||||
} from '../../types/llm/response'
|
||||
|
||||
import { BaseLLMProvider } from './base'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
} from './exception'
|
||||
import { OpenAIMessageAdapter } from './openai-message-adapter'
|
||||
|
||||
export class OpenAIAuthenticatedProvider implements BaseLLMProvider {
|
||||
private adapter: OpenAIMessageAdapter
|
||||
private client: OpenAI
|
||||
|
||||
constructor(apiKey: string) {
|
||||
this.client = new OpenAI({
|
||||
apiKey,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
this.adapter = new OpenAIMessageAdapter()
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestNonStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<LLMResponseNonStreaming> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.baseUrl) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'OpenAI API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new OpenAI({
|
||||
apiKey: model.apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
}
|
||||
try {
|
||||
return this.adapter.generateResponse(this.client, request, options)
|
||||
} catch (error) {
|
||||
if (error instanceof OpenAI.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'OpenAI API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async streamResponse(
|
||||
model: CustomLLMModel,
|
||||
request: LLMRequestStreaming,
|
||||
options?: LLMOptions,
|
||||
): Promise<AsyncIterable<LLMResponseStreaming>> {
|
||||
if (!this.client.apiKey) {
|
||||
if (!model.baseUrl) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'OpenAI API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
this.client = new OpenAI({
|
||||
apiKey: model.apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
return this.adapter.streamResponse(this.client, request, options)
|
||||
} catch (error) {
|
||||
if (error instanceof OpenAI.AuthenticationError) {
|
||||
throw new LLMAPIKeyInvalidException(
|
||||
'OpenAI API key is invalid. Please update it in settings menu.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
151
src/core/rag/embedding.ts
Normal file
151
src/core/rag/embedding.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
import { GoogleGenerativeAI } from '@google/generative-ai'
|
||||
import { OpenAI } from 'openai'
|
||||
|
||||
import { EmbeddingModel } from '../../types/embedding'
|
||||
import {
|
||||
LLMAPIKeyNotSetException,
|
||||
LLMBaseUrlNotSetException,
|
||||
LLMRateLimitExceededException,
|
||||
} from '../llm/exception'
|
||||
import { NoStainlessOpenAI } from '../llm/ollama'
|
||||
|
||||
export const getEmbeddingModel = (
|
||||
embeddingModelId: string,
|
||||
apiKeys: {
|
||||
openAIApiKey: string
|
||||
geminiApiKey: string
|
||||
},
|
||||
ollamaBaseUrl: string,
|
||||
): EmbeddingModel => {
|
||||
switch (embeddingModelId) {
|
||||
case 'text-embedding-3-small': {
|
||||
const openai = new OpenAI({
|
||||
apiKey: apiKeys.openAIApiKey,
|
||||
dangerouslyAllowBrowser: true,
|
||||
})
|
||||
return {
|
||||
id: 'text-embedding-3-small',
|
||||
dimension: 1536,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
if (!openai.apiKey) {
|
||||
throw new LLMAPIKeyNotSetException(
|
||||
'OpenAI API key is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'text-embedding-3-small',
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
} catch (error) {
|
||||
if (
|
||||
error.status === 429 &&
|
||||
error.message.toLowerCase().includes('rate limit')
|
||||
) {
|
||||
throw new LLMRateLimitExceededException(
|
||||
'OpenAI API rate limit exceeded. Please try again later.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'text-embedding-004': {
|
||||
const client = new GoogleGenerativeAI(apiKeys.geminiApiKey)
|
||||
const model = client.getGenerativeModel({ model: 'text-embedding-004' })
|
||||
return {
|
||||
id: 'text-embedding-004',
|
||||
dimension: 768,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
const response = await model.embedContent(text)
|
||||
return response.embedding.values
|
||||
} catch (error) {
|
||||
if (
|
||||
error.status === 429 &&
|
||||
error.message.includes('RATE_LIMIT_EXCEEDED')
|
||||
) {
|
||||
throw new LLMRateLimitExceededException(
|
||||
'Gemini API rate limit exceeded. Please try again later.',
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'nomic-embed-text': {
|
||||
const openai = new NoStainlessOpenAI({
|
||||
apiKey: '',
|
||||
dangerouslyAllowBrowser: true,
|
||||
baseURL: `${ollamaBaseUrl}/v1`,
|
||||
})
|
||||
return {
|
||||
id: 'nomic-embed-text',
|
||||
dimension: 768,
|
||||
getEmbedding: async (text: string) => {
|
||||
if (!ollamaBaseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama Address is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'nomic-embed-text',
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'mxbai-embed-large': {
|
||||
const openai = new NoStainlessOpenAI({
|
||||
apiKey: '',
|
||||
dangerouslyAllowBrowser: true,
|
||||
baseURL: `${ollamaBaseUrl}/v1`,
|
||||
})
|
||||
return {
|
||||
id: 'mxbai-embed-large',
|
||||
dimension: 1024,
|
||||
getEmbedding: async (text: string) => {
|
||||
if (!ollamaBaseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama Address is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'mxbai-embed-large',
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
},
|
||||
}
|
||||
}
|
||||
case 'bge-m3': {
|
||||
const openai = new NoStainlessOpenAI({
|
||||
apiKey: '',
|
||||
dangerouslyAllowBrowser: true,
|
||||
baseURL: `${ollamaBaseUrl}/v1`,
|
||||
})
|
||||
return {
|
||||
id: 'bge-m3',
|
||||
dimension: 1024,
|
||||
getEmbedding: async (text: string) => {
|
||||
if (!ollamaBaseUrl) {
|
||||
throw new LLMBaseUrlNotSetException(
|
||||
'Ollama Address is missing. Please set it in settings menu.',
|
||||
)
|
||||
}
|
||||
const embedding = await openai.embeddings.create({
|
||||
model: 'bge-m3',
|
||||
input: text,
|
||||
})
|
||||
return embedding.data[0].embedding
|
||||
},
|
||||
}
|
||||
}
|
||||
default:
|
||||
throw new Error('Invalid embedding model')
|
||||
}
|
||||
}
|
||||
124
src/core/rag/rag-engine.ts
Normal file
124
src/core/rag/rag-engine.ts
Normal file
@@ -0,0 +1,124 @@
|
||||
import { App } from 'obsidian'
|
||||
|
||||
import { QueryProgressState } from '../../components/chat-view/QueryProgress'
|
||||
import { DBManager } from '../../database/database-manager'
|
||||
import { VectorManager } from '../../database/modules/vector/vector-manager'
|
||||
import { SelectVector } from '../../database/schema'
|
||||
import { EmbeddingModel } from '../../types/embedding'
|
||||
import { InfioSettings } from '../../types/settings'
|
||||
|
||||
import { getEmbeddingModel } from './embedding'
|
||||
|
||||
export class RAGEngine {
|
||||
private app: App
|
||||
private settings: InfioSettings
|
||||
private vectorManager: VectorManager
|
||||
private embeddingModel: EmbeddingModel | null = null
|
||||
|
||||
constructor(
|
||||
app: App,
|
||||
settings: InfioSettings,
|
||||
dbManager: DBManager,
|
||||
) {
|
||||
this.app = app
|
||||
this.settings = settings
|
||||
this.vectorManager = dbManager.getVectorManager()
|
||||
this.embeddingModel = getEmbeddingModel(
|
||||
settings.embeddingModelId,
|
||||
{
|
||||
openAIApiKey: settings.openAIApiKey,
|
||||
geminiApiKey: settings.geminiApiKey,
|
||||
},
|
||||
settings.ollamaEmbeddingModel.baseUrl,
|
||||
)
|
||||
}
|
||||
|
||||
setSettings(settings: InfioSettings) {
|
||||
this.settings = settings
|
||||
this.embeddingModel = getEmbeddingModel(
|
||||
settings.embeddingModelId,
|
||||
{
|
||||
openAIApiKey: settings.openAIApiKey,
|
||||
geminiApiKey: settings.geminiApiKey,
|
||||
},
|
||||
settings.ollamaEmbeddingModel.baseUrl,
|
||||
)
|
||||
}
|
||||
|
||||
// TODO: Implement automatic vault re-indexing when settings are changed.
|
||||
// Currently, users must manually re-index the vault.
|
||||
async updateVaultIndex(
|
||||
options: { reindexAll: boolean } = {
|
||||
reindexAll: false,
|
||||
},
|
||||
onQueryProgressChange?: (queryProgress: QueryProgressState) => void,
|
||||
): Promise<void> {
|
||||
if (!this.embeddingModel) {
|
||||
throw new Error('Embedding model is not set')
|
||||
}
|
||||
await this.vectorManager.updateVaultIndex(
|
||||
this.embeddingModel,
|
||||
{
|
||||
chunkSize: this.settings.ragOptions.chunkSize,
|
||||
excludePatterns: this.settings.ragOptions.excludePatterns,
|
||||
includePatterns: this.settings.ragOptions.includePatterns,
|
||||
reindexAll: options.reindexAll,
|
||||
},
|
||||
(indexProgress) => {
|
||||
onQueryProgressChange?.({
|
||||
type: 'indexing',
|
||||
indexProgress,
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
async processQuery({
|
||||
query,
|
||||
scope,
|
||||
onQueryProgressChange,
|
||||
}: {
|
||||
query: string
|
||||
scope?: {
|
||||
files: string[]
|
||||
folders: string[]
|
||||
}
|
||||
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
|
||||
}): Promise<
|
||||
(Omit<SelectVector, 'embedding'> & {
|
||||
similarity: number
|
||||
})[]
|
||||
> {
|
||||
if (!this.embeddingModel) {
|
||||
throw new Error('Embedding model is not set')
|
||||
}
|
||||
// TODO: Decide the vault index update strategy.
|
||||
// Current approach: Update on every query.
|
||||
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
|
||||
const queryEmbedding = await this.getQueryEmbedding(query)
|
||||
onQueryProgressChange?.({
|
||||
type: 'querying',
|
||||
})
|
||||
const queryResult = await this.vectorManager.performSimilaritySearch(
|
||||
queryEmbedding,
|
||||
this.embeddingModel,
|
||||
{
|
||||
minSimilarity: this.settings.ragOptions.minSimilarity,
|
||||
limit: this.settings.ragOptions.limit,
|
||||
scope,
|
||||
},
|
||||
)
|
||||
onQueryProgressChange?.({
|
||||
type: 'querying-done',
|
||||
queryResult,
|
||||
})
|
||||
return queryResult
|
||||
}
|
||||
|
||||
private async getQueryEmbedding(query: string): Promise<number[]> {
|
||||
if (!this.embeddingModel) {
|
||||
throw new Error('Embedding model is not set')
|
||||
}
|
||||
return this.embeddingModel.getEmbedding(query)
|
||||
}
|
||||
}
|
||||
180
src/database/database-manager.ts
Normal file
180
src/database/database-manager.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
import { PGlite } from '@electric-sql/pglite'
|
||||
import { type PGliteWithLive, live } from '@electric-sql/pglite/live'
|
||||
// import { PgliteDatabase, drizzle } from 'drizzle-orm/pglite'
|
||||
import { App, normalizePath } from 'obsidian'
|
||||
|
||||
import { PGLITE_DB_PATH } from '../constants'
|
||||
|
||||
import { ConversationManager } from './modules/conversation/conversation-manager'
|
||||
import { TemplateManager } from './modules/template/template-manager'
|
||||
import { VectorManager } from './modules/vector/vector-manager'
|
||||
import { pgliteResources } from './pglite-resources'
|
||||
import { migrations } from './sql'
|
||||
|
||||
export class DBManager {
|
||||
private app: App
|
||||
private dbPath: string
|
||||
private db: PGliteWithLive | null = null
|
||||
// private db: PgliteDatabase | null = null
|
||||
private vectorManager: VectorManager
|
||||
private templateManager: TemplateManager
|
||||
private conversationManager: ConversationManager
|
||||
|
||||
constructor(app: App, dbPath: string) {
|
||||
this.app = app
|
||||
this.dbPath = dbPath
|
||||
}
|
||||
|
||||
static async create(app: App): Promise<DBManager> {
|
||||
const dbManager = new DBManager(app, normalizePath(PGLITE_DB_PATH))
|
||||
await dbManager.loadExistingDatabase()
|
||||
if (!dbManager.db) {
|
||||
await dbManager.createNewDatabase()
|
||||
}
|
||||
await dbManager.migrateDatabase()
|
||||
await dbManager.save()
|
||||
|
||||
dbManager.vectorManager = new VectorManager(app, dbManager)
|
||||
dbManager.templateManager = new TemplateManager(app, dbManager)
|
||||
dbManager.conversationManager = new ConversationManager(app, dbManager)
|
||||
|
||||
console.log('Smart composer database initialized.')
|
||||
return dbManager
|
||||
}
|
||||
|
||||
// getDb() {
|
||||
// return this.db
|
||||
// }
|
||||
|
||||
getPgClient() {
|
||||
return this.db
|
||||
}
|
||||
|
||||
getVectorManager() {
|
||||
return this.vectorManager
|
||||
}
|
||||
|
||||
getTemplateManager() {
|
||||
return this.templateManager
|
||||
}
|
||||
|
||||
getConversationManager() {
|
||||
return this.conversationManager
|
||||
}
|
||||
|
||||
private async createNewDatabase() {
|
||||
const { fsBundle, wasmModule, vectorExtensionBundlePath } =
|
||||
await this.loadPGliteResources()
|
||||
this.db = await PGlite.create({
|
||||
fsBundle: fsBundle,
|
||||
wasmModule: wasmModule,
|
||||
extensions: {
|
||||
vector: vectorExtensionBundlePath,
|
||||
live,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
private async loadExistingDatabase() {
|
||||
try {
|
||||
const databaseFileExists = await this.app.vault.adapter.exists(
|
||||
this.dbPath,
|
||||
)
|
||||
if (!databaseFileExists) {
|
||||
return null
|
||||
}
|
||||
const fileBuffer = await this.app.vault.adapter.readBinary(this.dbPath)
|
||||
const fileBlob = new Blob([fileBuffer], { type: 'application/x-gzip' })
|
||||
const { fsBundle, wasmModule, vectorExtensionBundlePath } =
|
||||
await this.loadPGliteResources()
|
||||
this.db = await PGlite.create({
|
||||
loadDataDir: fileBlob,
|
||||
fsBundle: fsBundle,
|
||||
wasmModule: wasmModule,
|
||||
extensions: {
|
||||
vector: vectorExtensionBundlePath,
|
||||
live
|
||||
},
|
||||
})
|
||||
// return drizzle(this.pgClient)
|
||||
} catch (error) {
|
||||
console.error('Error loading database:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
private async migrateDatabase(): Promise<void> {
|
||||
if (!this.db) {
|
||||
throw new Error('Database client not initialized');
|
||||
}
|
||||
|
||||
try {
|
||||
// Execute SQL migrations
|
||||
for (const [_key, migration] of Object.entries(migrations)) {
|
||||
// Split SQL into individual commands and execute them one by one
|
||||
const commands = migration.sql.split('\n\n').filter(cmd => cmd.trim());
|
||||
for (const command of commands) {
|
||||
console.log('Executing SQL migration:', command);
|
||||
await this.db.query(command);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error executing SQL migrations:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async save(): Promise<void> {
|
||||
if (!this.db) {
|
||||
return
|
||||
}
|
||||
try {
|
||||
const blob: Blob = await this.db.dumpDataDir('gzip')
|
||||
await this.app.vault.adapter.writeBinary(
|
||||
this.dbPath,
|
||||
Buffer.from(await blob.arrayBuffer()),
|
||||
)
|
||||
} catch (error) {
|
||||
console.error('Error saving database:', error)
|
||||
}
|
||||
}
|
||||
|
||||
async cleanup() {
|
||||
this.db?.close()
|
||||
this.db = null
|
||||
}
|
||||
|
||||
private async loadPGliteResources(): Promise<{
|
||||
fsBundle: Blob
|
||||
wasmModule: WebAssembly.Module
|
||||
vectorExtensionBundlePath: URL
|
||||
}> {
|
||||
try {
|
||||
// Convert base64 to binary data
|
||||
const wasmBinary = Buffer.from(pgliteResources.wasmBase64, 'base64')
|
||||
const dataBinary = Buffer.from(pgliteResources.dataBase64, 'base64')
|
||||
const vectorBinary = Buffer.from(pgliteResources.vectorBase64, 'base64')
|
||||
|
||||
// Create blobs from binary data
|
||||
const fsBundle = new Blob([dataBinary], {
|
||||
type: 'application/octet-stream',
|
||||
})
|
||||
const wasmModule = await WebAssembly.compile(wasmBinary)
|
||||
|
||||
// Create a blob URL for the vector extension
|
||||
const vectorBlob = new Blob([vectorBinary], {
|
||||
type: 'application/gzip',
|
||||
})
|
||||
const vectorExtensionBundlePath = URL.createObjectURL(vectorBlob)
|
||||
|
||||
return {
|
||||
fsBundle,
|
||||
wasmModule,
|
||||
vectorExtensionBundlePath: new URL(vectorExtensionBundlePath),
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading PGlite resources:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
20
src/database/exception.ts
Normal file
20
src/database/exception.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
export class DatabaseException extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'DatabaseException'
|
||||
}
|
||||
}
|
||||
|
||||
export class DatabaseNotInitializedException extends DatabaseException {
|
||||
constructor(message = 'Database not initialized') {
|
||||
super(message)
|
||||
this.name = 'DatabaseNotInitializedException'
|
||||
}
|
||||
}
|
||||
|
||||
export class DuplicateTemplateException extends DatabaseException {
|
||||
constructor(templateName: string) {
|
||||
super(`Template with name "${templateName}" already exists`)
|
||||
this.name = 'DuplicateTemplateException'
|
||||
}
|
||||
}
|
||||
162
src/database/modules/conversation/conversation-manager.ts
Normal file
162
src/database/modules/conversation/conversation-manager.ts
Normal file
@@ -0,0 +1,162 @@
|
||||
import { SerializedEditorState } from 'lexical'
|
||||
import { App } from 'obsidian'
|
||||
|
||||
import { editorStateToPlainText } from '../../../components/chat-view/chat-input/utils/editor-state-to-plain-text'
|
||||
import { ChatAssistantMessage, ChatConversationMeta, ChatMessage, ChatUserMessage } from '../../../types/chat'
|
||||
import { ContentPart } from '../../../types/llm/request'
|
||||
import { Mentionable, SerializedMentionable } from '../../../types/mentionable'
|
||||
import { deserializeMentionable, serializeMentionable } from '../../../utils/mentionable'
|
||||
import { DBManager } from '../../database-manager'
|
||||
import { InsertMessage } from '../../schema'
|
||||
|
||||
import { ConversationRepository } from './conversation-repository'
|
||||
|
||||
export class ConversationManager {
|
||||
private app: App
|
||||
private repository: ConversationRepository
|
||||
private dbManager: DBManager
|
||||
|
||||
constructor(app: App, dbManager: DBManager) {
|
||||
this.app = app
|
||||
this.dbManager = dbManager
|
||||
const db = dbManager.getPgClient()
|
||||
if (!db) throw new Error('Database not initialized')
|
||||
this.repository = new ConversationRepository(app, db)
|
||||
}
|
||||
|
||||
async createConversation(id: string, title = 'New chat'): Promise<void> {
|
||||
const conversation = {
|
||||
id,
|
||||
title,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
await this.repository.create(conversation)
|
||||
await this.dbManager.save()
|
||||
}
|
||||
|
||||
async saveConversation(id: string, messages: ChatMessage[]): Promise<void> {
|
||||
const conversation = await this.repository.findById(id)
|
||||
if (!conversation) {
|
||||
let title = 'New chat'
|
||||
if (messages.length > 0 && messages[0].role === 'user') {
|
||||
const query = editorStateToPlainText(messages[0].content)
|
||||
if (query.length > 20) {
|
||||
title = `${query.slice(0, 20)}...`
|
||||
} else {
|
||||
title = query
|
||||
}
|
||||
}
|
||||
await this.createConversation(id, title)
|
||||
}
|
||||
|
||||
// Delete existing messages
|
||||
await this.repository.deleteAllMessagesFromConversation(id)
|
||||
|
||||
// Insert new messages
|
||||
for (const message of messages) {
|
||||
const insertMessage = this.serializeMessage(message, id)
|
||||
await this.repository.createMessage(insertMessage)
|
||||
}
|
||||
|
||||
// Update conversation timestamp
|
||||
await this.repository.update(id, { updatedAt: new Date() })
|
||||
await this.dbManager.save()
|
||||
}
|
||||
|
||||
async findConversation(id: string): Promise<ChatMessage[] | null> {
|
||||
const conversation = await this.repository.findById(id)
|
||||
if (!conversation) {
|
||||
return null
|
||||
}
|
||||
|
||||
const messages = await this.repository.findMessagesByConversationId(id)
|
||||
return messages.map(msg => this.deserializeMessage(msg))
|
||||
}
|
||||
|
||||
async deleteConversation(id: string): Promise<void> {
|
||||
await this.repository.delete(id)
|
||||
await this.dbManager.save()
|
||||
}
|
||||
|
||||
getAllConversations(callback: (conversations: ChatConversationMeta[]) => void): void {
|
||||
const db = this.dbManager.getPgClient()
|
||||
db?.live.query('SELECT * FROM conversations ORDER BY updated_at', [], (results) => {
|
||||
callback(results.rows.map(conv => ({
|
||||
id: conv.id,
|
||||
title: conv.title,
|
||||
schemaVersion: 2,
|
||||
createdAt: conv.createdAt instanceof Date ? conv.createdAt.getTime() : conv.createdAt,
|
||||
updatedAt: conv.updatedAt instanceof Date ? conv.updatedAt.getTime() : conv.updatedAt,
|
||||
})))
|
||||
})
|
||||
}
|
||||
|
||||
async updateConversationTitle(id: string, title: string): Promise<void> {
|
||||
await this.repository.update(id, { title })
|
||||
await this.dbManager.save()
|
||||
}
|
||||
|
||||
private serializeMessage(message: ChatMessage, conversationId: string): InsertMessage {
|
||||
const base = {
|
||||
id: message.id,
|
||||
conversationId,
|
||||
role: message.role,
|
||||
createdAt: new Date(),
|
||||
}
|
||||
|
||||
if (message.role === 'user') {
|
||||
const userMessage: ChatUserMessage = message
|
||||
return {
|
||||
...base,
|
||||
content: userMessage.content ? JSON.stringify(userMessage.content) : null,
|
||||
promptContent: userMessage.promptContent
|
||||
? typeof userMessage.promptContent === 'string'
|
||||
? userMessage.promptContent
|
||||
: JSON.stringify(userMessage.promptContent)
|
||||
: null,
|
||||
mentionables: JSON.stringify(userMessage.mentionables.map(serializeMentionable)),
|
||||
similaritySearchResults: userMessage.similaritySearchResults
|
||||
? JSON.stringify(userMessage.similaritySearchResults)
|
||||
: null,
|
||||
}
|
||||
} else {
|
||||
const assistantMessage: ChatAssistantMessage = message
|
||||
return {
|
||||
...base,
|
||||
content: assistantMessage.content,
|
||||
metadata: assistantMessage.metadata ? JSON.stringify(assistantMessage.metadata) : null,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private deserializeMessage(message: InsertMessage): ChatMessage {
|
||||
if (message.role === 'user') {
|
||||
return {
|
||||
id: message.id,
|
||||
role: 'user',
|
||||
content: message.content ? JSON.parse(message.content) as SerializedEditorState : null,
|
||||
promptContent: message.promptContent
|
||||
? message.promptContent.startsWith('{')
|
||||
? JSON.parse(message.promptContent) as ContentPart[]
|
||||
: message.promptContent
|
||||
: null,
|
||||
mentionables: message.mentionables
|
||||
? (JSON.parse(message.mentionables) as SerializedMentionable[])
|
||||
.map(m => deserializeMentionable(m, this.app))
|
||||
.filter((m: Mentionable | null): m is Mentionable => m !== null)
|
||||
: [],
|
||||
similaritySearchResults: message.similaritySearchResults
|
||||
? JSON.parse(message.similaritySearchResults)
|
||||
: undefined,
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
id: message.id,
|
||||
role: 'assistant',
|
||||
content: message.content || '',
|
||||
metadata: message.metadata ? JSON.parse(message.metadata) : undefined,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
131
src/database/modules/conversation/conversation-repository.ts
Normal file
131
src/database/modules/conversation/conversation-repository.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
import { PGliteInterface } from '@electric-sql/pglite'
|
||||
import { App } from 'obsidian'
|
||||
|
||||
import {
|
||||
InsertConversation,
|
||||
InsertMessage,
|
||||
SelectConversation,
|
||||
SelectMessage,
|
||||
} from '../../schema'
|
||||
|
||||
type QueryResult<T> = {
|
||||
rows: T[]
|
||||
}
|
||||
|
||||
export class ConversationRepository {
|
||||
private app: App
|
||||
private db: PGliteInterface
|
||||
|
||||
constructor(app: App, db: PGliteInterface) {
|
||||
this.app = app
|
||||
this.db = db
|
||||
}
|
||||
|
||||
async create(conversation: InsertConversation): Promise<SelectConversation> {
|
||||
const result = await this.db.query<SelectConversation>(
|
||||
`INSERT INTO conversations (id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING *`,
|
||||
[
|
||||
conversation.id,
|
||||
conversation.title,
|
||||
conversation.createdAt || new Date(),
|
||||
conversation.updatedAt || new Date()
|
||||
]
|
||||
) as QueryResult<SelectConversation>
|
||||
return result.rows[0]
|
||||
}
|
||||
|
||||
async createMessage(message: InsertMessage): Promise<SelectMessage> {
|
||||
const result = await this.db.query<SelectMessage>(
|
||||
`INSERT INTO messages (
|
||||
id, conversation_id, role, content,
|
||||
prompt_content, metadata, mentionables,
|
||||
similarity_search_results, created_at
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING *`,
|
||||
[
|
||||
message.id,
|
||||
message.conversationId,
|
||||
message.role,
|
||||
message.content,
|
||||
message.promptContent,
|
||||
message.metadata,
|
||||
message.mentionables,
|
||||
message.similaritySearchResults,
|
||||
message.createdAt || new Date()
|
||||
]
|
||||
) as QueryResult<SelectMessage>
|
||||
return result.rows[0]
|
||||
}
|
||||
|
||||
async findById(id: string): Promise<SelectConversation | undefined> {
|
||||
const result = await this.db.query<SelectConversation>(
|
||||
`SELECT * FROM conversations WHERE id = $1 LIMIT 1`,
|
||||
[id]
|
||||
) as QueryResult<SelectConversation>
|
||||
return result.rows[0]
|
||||
}
|
||||
|
||||
async findMessagesByConversationId(conversationId: string): Promise<SelectMessage[]> {
|
||||
const result = await this.db.query<SelectMessage>(
|
||||
`SELECT * FROM messages
|
||||
WHERE conversation_id = $1
|
||||
ORDER BY created_at`,
|
||||
[conversationId]
|
||||
) as QueryResult<SelectMessage>
|
||||
return result.rows
|
||||
}
|
||||
|
||||
async findAll(): Promise<SelectConversation[]> {
|
||||
const result = await this.db.query<SelectConversation>(
|
||||
`SELECT * FROM conversations ORDER BY updated_at DESC`
|
||||
) as QueryResult<SelectConversation>
|
||||
return result.rows
|
||||
}
|
||||
|
||||
async update(id: string, data: Partial<InsertConversation>): Promise<SelectConversation> {
|
||||
const setClauses: string[] = []
|
||||
const values: any[] = []
|
||||
let paramIndex = 1
|
||||
|
||||
if (data.title !== undefined) {
|
||||
setClauses.push(`title = $${paramIndex}`)
|
||||
values.push(data.title)
|
||||
paramIndex++
|
||||
}
|
||||
|
||||
// Always update updated_at
|
||||
setClauses.push(`updated_at = $${paramIndex}`)
|
||||
values.push(new Date())
|
||||
paramIndex++
|
||||
|
||||
// Add id as the last parameter
|
||||
values.push(id)
|
||||
|
||||
const result = await this.db.query<SelectConversation>(
|
||||
`UPDATE conversations
|
||||
SET ${setClauses.join(', ')}
|
||||
WHERE id = $${paramIndex}
|
||||
RETURNING *`,
|
||||
values
|
||||
) as QueryResult<SelectConversation>
|
||||
return result.rows[0]
|
||||
}
|
||||
|
||||
async delete(id: string): Promise<boolean> {
|
||||
const result = await this.db.query<SelectConversation>(
|
||||
`DELETE FROM conversations WHERE id = $1 RETURNING *`,
|
||||
[id]
|
||||
) as QueryResult<SelectConversation>
|
||||
return result.rows.length > 0
|
||||
}
|
||||
|
||||
async deleteAllMessagesFromConversation(conversationId: string): Promise<void> {
|
||||
await this.db.query(
|
||||
`DELETE FROM messages WHERE conversation_id = $1`,
|
||||
[conversationId]
|
||||
)
|
||||
}
|
||||
}
|
||||
51
src/database/modules/template/template-manager.ts
Normal file
51
src/database/modules/template/template-manager.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import fuzzysort from 'fuzzysort'
|
||||
import { App } from 'obsidian'
|
||||
|
||||
import { DBManager } from '../../database-manager'
|
||||
import { DuplicateTemplateException } from '../../exception'
|
||||
import { InsertTemplate, SelectTemplate } from '../../schema'
|
||||
|
||||
import { TemplateRepository } from './template-repository'
|
||||
|
||||
export class TemplateManager {
|
||||
private app: App
|
||||
private repository: TemplateRepository
|
||||
private dbManager: DBManager
|
||||
|
||||
constructor(app: App, dbManager: DBManager) {
|
||||
this.app = app
|
||||
this.dbManager = dbManager
|
||||
this.repository = new TemplateRepository(app, dbManager.getPgClient())
|
||||
}
|
||||
|
||||
async createTemplate(template: InsertTemplate): Promise<SelectTemplate> {
|
||||
const existingTemplate = await this.repository.findByName(template.name)
|
||||
if (existingTemplate) {
|
||||
throw new DuplicateTemplateException(template.name)
|
||||
}
|
||||
const created = await this.repository.create(template)
|
||||
await this.dbManager.save()
|
||||
return created
|
||||
}
|
||||
|
||||
async findAllTemplates(): Promise<SelectTemplate[]> {
|
||||
return await this.repository.findAll()
|
||||
}
|
||||
|
||||
async searchTemplates(query: string): Promise<SelectTemplate[]> {
|
||||
const templates = await this.findAllTemplates()
|
||||
const results = fuzzysort.go(query, templates, {
|
||||
keys: ['name'],
|
||||
threshold: 0.2,
|
||||
limit: 20,
|
||||
all: true,
|
||||
})
|
||||
return results.map((result) => result.obj)
|
||||
}
|
||||
|
||||
async deleteTemplate(id: string): Promise<boolean> {
|
||||
const deleted = await this.repository.delete(id)
|
||||
await this.dbManager.save()
|
||||
return deleted
|
||||
}
|
||||
}
|
||||
98
src/database/modules/template/template-repository.ts
Normal file
98
src/database/modules/template/template-repository.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import { PGliteInterface } from '@electric-sql/pglite'
|
||||
import { App } from 'obsidian'
|
||||
|
||||
import { DatabaseNotInitializedException } from '../../exception'
|
||||
import { type InsertTemplate, type SelectTemplate } from '../../schema'
|
||||
|
||||
export class TemplateRepository {
|
||||
private app: App
|
||||
private db: PGliteInterface | null
|
||||
|
||||
constructor(app: App, pgClient: PGliteInterface | null) {
|
||||
this.app = app
|
||||
this.db = pgClient
|
||||
}
|
||||
|
||||
async create(template: InsertTemplate): Promise<SelectTemplate> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
|
||||
const result = await this.db.query<SelectTemplate>(
|
||||
`INSERT INTO "template" (name, content)
|
||||
VALUES ($1, $2)
|
||||
RETURNING *`,
|
||||
[template.name, template.content]
|
||||
)
|
||||
return result.rows[0]
|
||||
}
|
||||
|
||||
async findAll(): Promise<SelectTemplate[]> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const result = await this.db.query<SelectTemplate>(
|
||||
`SELECT * FROM "template"`
|
||||
)
|
||||
return result.rows
|
||||
}
|
||||
|
||||
async findByName(name: string): Promise<SelectTemplate | null> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const result = await this.db.query<SelectTemplate>(
|
||||
`SELECT * FROM "template" WHERE name = $1`,
|
||||
[name]
|
||||
)
|
||||
return result.rows[0] ?? null
|
||||
}
|
||||
|
||||
async update(
|
||||
id: string,
|
||||
template: Partial<InsertTemplate>,
|
||||
): Promise<SelectTemplate | null> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
|
||||
const setClauses: string[] = []
|
||||
const params: any[] = []
|
||||
let paramIndex = 1
|
||||
|
||||
if (template.name !== undefined) {
|
||||
setClauses.push(`name = $${paramIndex}`)
|
||||
params.push(template.name)
|
||||
paramIndex++
|
||||
}
|
||||
|
||||
if (template.content !== undefined) {
|
||||
setClauses.push(`content = $${paramIndex}`)
|
||||
params.push(template.content)
|
||||
paramIndex++
|
||||
}
|
||||
|
||||
setClauses.push(`updated_at = now()`)
|
||||
params.push(id)
|
||||
|
||||
const result = await this.db.query<SelectTemplate>(
|
||||
`UPDATE "template"
|
||||
SET ${setClauses.join(', ')}
|
||||
WHERE id = $${paramIndex}
|
||||
RETURNING *`,
|
||||
params
|
||||
)
|
||||
return result.rows[0] ?? null
|
||||
}
|
||||
|
||||
async delete(id: string): Promise<boolean> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const result = await this.db.query<SelectTemplate>(
|
||||
`DELETE FROM "template" WHERE id = $1 RETURNING *`,
|
||||
[id]
|
||||
)
|
||||
return result.rows.length > 0
|
||||
}
|
||||
}
|
||||
277
src/database/modules/vector/vector-manager.ts
Normal file
277
src/database/modules/vector/vector-manager.ts
Normal file
@@ -0,0 +1,277 @@
|
||||
import { backOff } from 'exponential-backoff'
|
||||
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
|
||||
import { minimatch } from 'minimatch'
|
||||
import { App, Notice, TFile } from 'obsidian'
|
||||
import pLimit from 'p-limit'
|
||||
|
||||
import { IndexProgress } from '../../../components/chat-view/QueryProgress'
|
||||
import {
|
||||
LLMAPIKeyInvalidException,
|
||||
LLMAPIKeyNotSetException,
|
||||
LLMBaseUrlNotSetException,
|
||||
LLMRateLimitExceededException,
|
||||
} from '../../../core/llm/exception'
|
||||
import { InsertVector, SelectVector } from '../../../database/schema'
|
||||
import { EmbeddingModel } from '../../../types/embedding'
|
||||
import { openSettingsModalWithError } from '../../../utils/open-settings-modal'
|
||||
import { DBManager } from '../../database-manager'
|
||||
|
||||
import { VectorRepository } from './vector-repository'
|
||||
|
||||
export class VectorManager {
|
||||
private app: App
|
||||
private repository: VectorRepository
|
||||
private dbManager: DBManager
|
||||
|
||||
constructor(app: App, dbManager: DBManager) {
|
||||
this.app = app
|
||||
this.dbManager = dbManager
|
||||
this.repository = new VectorRepository(app, dbManager.getPgClient())
|
||||
}
|
||||
|
||||
async performSimilaritySearch(
|
||||
queryVector: number[],
|
||||
embeddingModel: EmbeddingModel,
|
||||
options: {
|
||||
minSimilarity: number
|
||||
limit: number
|
||||
scope?: {
|
||||
files: string[]
|
||||
folders: string[]
|
||||
}
|
||||
},
|
||||
): Promise<
|
||||
(Omit<SelectVector, 'embedding'> & {
|
||||
similarity: number
|
||||
})[]
|
||||
> {
|
||||
return await this.repository.performSimilaritySearch(
|
||||
queryVector,
|
||||
embeddingModel,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
async updateVaultIndex(
|
||||
embeddingModel: EmbeddingModel,
|
||||
options: {
|
||||
chunkSize: number
|
||||
excludePatterns: string[]
|
||||
includePatterns: string[]
|
||||
reindexAll?: boolean
|
||||
},
|
||||
updateProgress?: (indexProgress: IndexProgress) => void,
|
||||
): Promise<void> {
|
||||
let filesToIndex: TFile[]
|
||||
if (options.reindexAll) {
|
||||
filesToIndex = await this.getFilesToIndex({
|
||||
embeddingModel: embeddingModel,
|
||||
excludePatterns: options.excludePatterns,
|
||||
includePatterns: options.includePatterns,
|
||||
reindexAll: true,
|
||||
})
|
||||
await this.repository.clearAllVectors(embeddingModel)
|
||||
} else {
|
||||
await this.deleteVectorsForDeletedFiles(embeddingModel)
|
||||
filesToIndex = await this.getFilesToIndex({
|
||||
embeddingModel: embeddingModel,
|
||||
excludePatterns: options.excludePatterns,
|
||||
includePatterns: options.includePatterns,
|
||||
})
|
||||
await this.repository.deleteVectorsForMultipleFiles(
|
||||
filesToIndex.map((file) => file.path),
|
||||
embeddingModel,
|
||||
)
|
||||
}
|
||||
|
||||
if (filesToIndex.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const textSplitter = RecursiveCharacterTextSplitter.fromLanguage(
|
||||
'markdown',
|
||||
{
|
||||
chunkSize: options.chunkSize,
|
||||
// TODO: Use token-based chunking after migrating to WebAssembly-based tiktoken
|
||||
// Current token counting method is too slow for practical use
|
||||
// lengthFunction: async (text) => {
|
||||
// return await tokenCount(text)
|
||||
// },
|
||||
},
|
||||
)
|
||||
|
||||
const contentChunks: InsertVector[] = (
|
||||
await Promise.all(
|
||||
filesToIndex.map(async (file) => {
|
||||
const fileContent = await this.app.vault.cachedRead(file)
|
||||
const fileDocuments = await textSplitter.createDocuments([
|
||||
fileContent,
|
||||
])
|
||||
return fileDocuments.map((chunk): InsertVector => {
|
||||
return {
|
||||
path: file.path,
|
||||
mtime: file.stat.mtime,
|
||||
content: chunk.pageContent,
|
||||
metadata: {
|
||||
startLine: chunk.metadata.loc.lines.from as number,
|
||||
endLine: chunk.metadata.loc.lines.to as number,
|
||||
},
|
||||
}
|
||||
})
|
||||
}),
|
||||
)
|
||||
).flat()
|
||||
|
||||
updateProgress?.({
|
||||
completedChunks: 0,
|
||||
totalChunks: contentChunks.length,
|
||||
totalFiles: filesToIndex.length,
|
||||
})
|
||||
|
||||
const embeddingProgress = { completed: 0, inserted: 0 }
|
||||
const embeddingChunks: InsertVector[] = []
|
||||
const batchSize = 100
|
||||
const limit = pLimit(50)
|
||||
const abortController = new AbortController()
|
||||
const tasks = contentChunks.map((chunk) =>
|
||||
limit(async () => {
|
||||
if (abortController.signal.aborted) {
|
||||
throw new Error('Operation was aborted')
|
||||
}
|
||||
try {
|
||||
await backOff(
|
||||
async () => {
|
||||
const embedding = await embeddingModel.getEmbedding(chunk.content)
|
||||
const embeddedChunk = {
|
||||
path: chunk.path,
|
||||
mtime: chunk.mtime,
|
||||
content: chunk.content,
|
||||
embedding,
|
||||
metadata: chunk.metadata,
|
||||
}
|
||||
embeddingChunks.push(embeddedChunk)
|
||||
embeddingProgress.completed++
|
||||
updateProgress?.({
|
||||
completedChunks: embeddingProgress.completed,
|
||||
totalChunks: contentChunks.length,
|
||||
totalFiles: filesToIndex.length,
|
||||
})
|
||||
|
||||
// Insert vectors in batches
|
||||
if (
|
||||
embeddingChunks.length >=
|
||||
embeddingProgress.inserted + batchSize ||
|
||||
embeddingChunks.length === contentChunks.length
|
||||
) {
|
||||
await this.repository.insertVectors(
|
||||
embeddingChunks.slice(
|
||||
embeddingProgress.inserted,
|
||||
embeddingProgress.inserted + batchSize,
|
||||
),
|
||||
embeddingModel,
|
||||
)
|
||||
embeddingProgress.inserted += batchSize
|
||||
}
|
||||
},
|
||||
{
|
||||
numOfAttempts: 5,
|
||||
startingDelay: 1000,
|
||||
timeMultiple: 1.5,
|
||||
jitter: 'full',
|
||||
},
|
||||
)
|
||||
} catch (error) {
|
||||
abortController.abort()
|
||||
throw error
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
try {
|
||||
await Promise.all(tasks)
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof LLMAPIKeyNotSetException ||
|
||||
error instanceof LLMAPIKeyInvalidException ||
|
||||
error instanceof LLMBaseUrlNotSetException
|
||||
) {
|
||||
openSettingsModalWithError(this.app, (error as Error).message)
|
||||
} else if (error instanceof LLMRateLimitExceededException) {
|
||||
new Notice(error.message)
|
||||
} else {
|
||||
console.error('Error embedding chunks:', error)
|
||||
throw error
|
||||
}
|
||||
} finally {
|
||||
await this.dbManager.save()
|
||||
}
|
||||
}
|
||||
|
||||
private async deleteVectorsForDeletedFiles(embeddingModel: EmbeddingModel) {
|
||||
const indexedFilePaths =
|
||||
await this.repository.getIndexedFilePaths(embeddingModel)
|
||||
for (const filePath of indexedFilePaths) {
|
||||
if (!this.app.vault.getAbstractFileByPath(filePath)) {
|
||||
await this.repository.deleteVectorsForMultipleFiles(
|
||||
[filePath],
|
||||
embeddingModel,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async getFilesToIndex({
|
||||
embeddingModel,
|
||||
excludePatterns,
|
||||
includePatterns,
|
||||
reindexAll,
|
||||
}: {
|
||||
embeddingModel: EmbeddingModel
|
||||
excludePatterns: string[]
|
||||
includePatterns: string[]
|
||||
reindexAll?: boolean
|
||||
}): Promise<TFile[]> {
|
||||
let filesToIndex = this.app.vault.getMarkdownFiles()
|
||||
|
||||
filesToIndex = filesToIndex.filter((file) => {
|
||||
return !excludePatterns.some((pattern) => minimatch(file.path, pattern))
|
||||
})
|
||||
|
||||
if (includePatterns.length > 0) {
|
||||
filesToIndex = filesToIndex.filter((file) => {
|
||||
return includePatterns.some((pattern) => minimatch(file.path, pattern))
|
||||
})
|
||||
}
|
||||
|
||||
if (reindexAll) {
|
||||
return filesToIndex
|
||||
}
|
||||
|
||||
// Check for updated or new files
|
||||
filesToIndex = await Promise.all(
|
||||
filesToIndex.map(async (file) => {
|
||||
const fileChunks = await this.repository.getVectorsByFilePath(
|
||||
file.path,
|
||||
embeddingModel,
|
||||
)
|
||||
if (fileChunks.length === 0) {
|
||||
// File is not indexed, so we need to index it
|
||||
const fileContent = await this.app.vault.cachedRead(file)
|
||||
if (fileContent.length === 0) {
|
||||
// Ignore empty files
|
||||
return null
|
||||
}
|
||||
return file
|
||||
}
|
||||
const outOfDate = file.stat.mtime > fileChunks[0].mtime
|
||||
if (outOfDate) {
|
||||
// File has changed, so we need to re-index it
|
||||
return file
|
||||
}
|
||||
return null
|
||||
}),
|
||||
).then((files) => files.filter(Boolean) as TFile[])
|
||||
|
||||
return filesToIndex
|
||||
}
|
||||
}
|
||||
180
src/database/modules/vector/vector-repository.ts
Normal file
180
src/database/modules/vector/vector-repository.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
import { PGliteInterface } from '@electric-sql/pglite'
|
||||
import { App } from 'obsidian'
|
||||
|
||||
import { EmbeddingModel } from '../../../types/embedding'
|
||||
import { DatabaseNotInitializedException } from '../../exception'
|
||||
import { InsertVector, SelectVector, vectorTables } from '../../schema'
|
||||
|
||||
export class VectorRepository {
|
||||
private app: App
|
||||
private db: PGliteInterface | null
|
||||
|
||||
constructor(app: App, pgClient: PGliteInterface | null) {
|
||||
this.app = app
|
||||
this.db = pgClient
|
||||
}
|
||||
|
||||
private getTableName(embeddingModel: EmbeddingModel): string {
|
||||
const tableDefinition = vectorTables[embeddingModel.dimension]
|
||||
if (!tableDefinition) {
|
||||
throw new Error(`No table definition found for model: ${embeddingModel.id}`)
|
||||
}
|
||||
return tableDefinition.name
|
||||
}
|
||||
|
||||
async getIndexedFilePaths(embeddingModel: EmbeddingModel): Promise<string[]> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const tableName = this.getTableName(embeddingModel)
|
||||
const result = await this.db.query<{ path: string }>(
|
||||
`SELECT DISTINCT path FROM "${tableName}"`
|
||||
)
|
||||
return result.rows.map((row: { path: string }) => row.path)
|
||||
}
|
||||
|
||||
async getVectorsByFilePath(
|
||||
filePath: string,
|
||||
embeddingModel: EmbeddingModel,
|
||||
): Promise<SelectVector[]> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const tableName = this.getTableName(embeddingModel)
|
||||
const result = await this.db.query<SelectVector>(
|
||||
`SELECT * FROM "${tableName}" WHERE path = $1`,
|
||||
[filePath]
|
||||
)
|
||||
return result.rows
|
||||
}
|
||||
|
||||
async deleteVectorsForSingleFile(
|
||||
filePath: string,
|
||||
embeddingModel: EmbeddingModel,
|
||||
): Promise<void> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const tableName = this.getTableName(embeddingModel)
|
||||
await this.db.query(
|
||||
`DELETE FROM "${tableName}" WHERE path = $1`,
|
||||
[filePath]
|
||||
)
|
||||
}
|
||||
|
||||
async deleteVectorsForMultipleFiles(
|
||||
filePaths: string[],
|
||||
embeddingModel: EmbeddingModel,
|
||||
): Promise<void> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const tableName = this.getTableName(embeddingModel)
|
||||
await this.db.query(
|
||||
`DELETE FROM "${tableName}" WHERE path = ANY($1)`,
|
||||
[filePaths]
|
||||
)
|
||||
}
|
||||
|
||||
async clearAllVectors(embeddingModel: EmbeddingModel): Promise<void> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const tableName = this.getTableName(embeddingModel)
|
||||
await this.db.query(`DELETE FROM "${tableName}"`)
|
||||
}
|
||||
|
||||
async insertVectors(
|
||||
data: InsertVector[],
|
||||
embeddingModel: EmbeddingModel,
|
||||
): Promise<void> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const tableName = this.getTableName(embeddingModel)
|
||||
|
||||
// 构建批量插入的 SQL
|
||||
const values = data.map((vector, index) => {
|
||||
const offset = index * 5
|
||||
return `($${offset + 1}, $${offset + 2}, $${offset + 3}, $${offset + 4}, $${offset + 5})`
|
||||
}).join(',')
|
||||
|
||||
const params = data.flatMap(vector => [
|
||||
vector.path,
|
||||
vector.mtime,
|
||||
vector.content,
|
||||
`[${vector.embedding.join(',')}]`, // 转换为PostgreSQL vector格式
|
||||
vector.metadata
|
||||
])
|
||||
|
||||
await this.db.query(
|
||||
`INSERT INTO "${tableName}" (path, mtime, content, embedding, metadata)
|
||||
VALUES ${values}`,
|
||||
params
|
||||
)
|
||||
}
|
||||
|
||||
async performSimilaritySearch(
|
||||
queryVector: number[],
|
||||
embeddingModel: EmbeddingModel,
|
||||
options: {
|
||||
minSimilarity: number
|
||||
limit: number
|
||||
scope?: {
|
||||
files: string[]
|
||||
folders: string[]
|
||||
}
|
||||
},
|
||||
): Promise<
|
||||
(Omit<SelectVector, 'embedding'> & {
|
||||
similarity: number
|
||||
})[]
|
||||
> {
|
||||
if (!this.db) {
|
||||
throw new DatabaseNotInitializedException()
|
||||
}
|
||||
const tableName = this.getTableName(embeddingModel)
|
||||
|
||||
let scopeCondition = ''
|
||||
const params: any[] = [`[${queryVector.join(',')}]`, options.minSimilarity, options.limit]
|
||||
let paramIndex = 4
|
||||
|
||||
if (options.scope) {
|
||||
const conditions: string[] = []
|
||||
|
||||
if (options.scope.files.length > 0) {
|
||||
conditions.push(`path = ANY($${paramIndex})`)
|
||||
params.push(options.scope.files)
|
||||
paramIndex++
|
||||
}
|
||||
|
||||
if (options.scope.folders.length > 0) {
|
||||
const folderConditions = options.scope.folders.map((folder, idx) => {
|
||||
params.push(`${folder}/%`)
|
||||
return `path LIKE $${paramIndex + idx}`
|
||||
})
|
||||
conditions.push(`(${folderConditions.join(' OR ')})`)
|
||||
paramIndex += options.scope.folders.length
|
||||
}
|
||||
|
||||
if (conditions.length > 0) {
|
||||
scopeCondition = `AND (${conditions.join(' OR ')})`
|
||||
}
|
||||
}
|
||||
|
||||
const query = `
|
||||
SELECT
|
||||
id, path, mtime, content, metadata,
|
||||
1 - (embedding <=> $1::vector) as similarity
|
||||
FROM "${tableName}"
|
||||
WHERE 1 - (embedding <=> $1::vector) > $2
|
||||
${scopeCondition}
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $3
|
||||
`
|
||||
|
||||
type SearchResult = Omit<SelectVector, 'embedding'> & { similarity: number }
|
||||
const result = await this.db.query<SearchResult>(query, params)
|
||||
return result.rows
|
||||
}
|
||||
}
|
||||
7
src/database/pglite-resources.d.ts
vendored
Normal file
7
src/database/pglite-resources.d.ts
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
export interface PgliteResources {
|
||||
wasmBase64: string;
|
||||
dataBase64: string;
|
||||
vectorBase64: string;
|
||||
}
|
||||
|
||||
export const pgliteResources: PgliteResources;
|
||||
7
src/database/pglite-resources.ts
Normal file
7
src/database/pglite-resources.ts
Normal file
File diff suppressed because one or more lines are too long
156
src/database/schema.ts
Normal file
156
src/database/schema.ts
Normal file
@@ -0,0 +1,156 @@
|
||||
import { SerializedLexicalNode } from 'lexical'
|
||||
|
||||
import { SUPPORT_EMBEDDING_SIMENTION } from '../constants'
|
||||
import { EmbeddingModelId } from '../types/embedding'
|
||||
|
||||
// PostgreSQL column types
|
||||
interface ColumnDefinition {
|
||||
type: string
|
||||
notNull?: boolean
|
||||
primaryKey?: boolean
|
||||
defaultRandom?: boolean
|
||||
unique?: boolean
|
||||
defaultNow?: boolean
|
||||
dimensions?: number
|
||||
}
|
||||
|
||||
interface TableDefinition {
|
||||
name: string
|
||||
columns: Record<string, ColumnDefinition>
|
||||
indices?: Record<string, {
|
||||
type: string
|
||||
columns: string[]
|
||||
options?: string
|
||||
}>
|
||||
}
|
||||
|
||||
/* Vector Table */
|
||||
const createVectorTable = (dimension: number): TableDefinition => {
|
||||
const tableName = `embeddings_${dimension}`
|
||||
|
||||
const table: TableDefinition = {
|
||||
name: tableName,
|
||||
columns: {
|
||||
id: { type: 'SERIAL', primaryKey: true },
|
||||
path: { type: 'TEXT', notNull: true },
|
||||
mtime: { type: 'BIGINT', notNull: true },
|
||||
content: { type: 'TEXT', notNull: true },
|
||||
embedding: { type: 'VECTOR', dimensions: dimension },
|
||||
metadata: { type: 'JSONB', notNull: true },
|
||||
}
|
||||
}
|
||||
|
||||
if (dimension <= 2000) {
|
||||
table.indices = {
|
||||
[`embeddingIndex_${dimension}`]: {
|
||||
type: 'HNSW',
|
||||
columns: ['embedding'],
|
||||
options: 'vector_cosine_ops'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return table
|
||||
}
|
||||
|
||||
export const vectorTables = SUPPORT_EMBEDDING_SIMENTION.reduce<
|
||||
Record<number, TableDefinition>
|
||||
>((acc, dimension) => {
|
||||
acc[dimension] = createVectorTable(dimension)
|
||||
return acc
|
||||
}, {})
|
||||
|
||||
// Type definitions for vector table
|
||||
export interface VectorRecord {
|
||||
id: number
|
||||
path: string
|
||||
mtime: number
|
||||
content: string
|
||||
embedding: number[]
|
||||
metadata: VectorMetaData
|
||||
}
|
||||
|
||||
export type SelectVector = VectorRecord
|
||||
export type InsertVector = Omit<VectorRecord, 'id'>
|
||||
|
||||
export type VectorMetaData = {
|
||||
startLine: number
|
||||
endLine: number
|
||||
}
|
||||
|
||||
// // Export individual vector tables for reference
|
||||
// export const vectorTable0 = vectorTables[EMBEDDING_MODEL_OPTIONS[0].id]
|
||||
// export const vectorTable1 = vectorTables[EMBEDDING_MODEL_OPTIONS[1].id]
|
||||
// export const vectorTable2 = vectorTables[EMBEDDING_MODEL_OPTIONS[2].id]
|
||||
// export const vectorTable3 = vectorTables[EMBEDDING_MODEL_OPTIONS[3].id]
|
||||
// export const vectorTable4 = vectorTables[EMBEDDING_MODEL_OPTIONS[4].id]
|
||||
// export const vectorTable5 = vectorTables[EMBEDDING_MODEL_OPTIONS[5].id]
|
||||
|
||||
/* Template Table */
|
||||
export type TemplateContent = {
|
||||
nodes: SerializedLexicalNode[]
|
||||
}
|
||||
|
||||
export interface TemplateRecord {
|
||||
id: string
|
||||
name: string
|
||||
content: TemplateContent
|
||||
createdAt: Date
|
||||
updatedAt: Date
|
||||
}
|
||||
|
||||
export type SelectTemplate = TemplateRecord
|
||||
export type InsertTemplate = Omit<TemplateRecord, 'id' | 'createdAt' | 'updatedAt'>
|
||||
|
||||
export const templateTable: TableDefinition = {
|
||||
name: 'template',
|
||||
columns: {
|
||||
id: { type: 'UUID', primaryKey: true, defaultRandom: true },
|
||||
name: { type: 'TEXT', notNull: true, unique: true },
|
||||
content: { type: 'JSONB', notNull: true },
|
||||
createdAt: { type: 'TIMESTAMP', notNull: true, defaultNow: true },
|
||||
updatedAt: { type: 'TIMESTAMP', notNull: true, defaultNow: true }
|
||||
}
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
id: string // uuid
|
||||
title: string
|
||||
createdAt: Date
|
||||
updatedAt: Date
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
id: string // uuid
|
||||
conversationId: string // uuid
|
||||
role: 'user' | 'assistant'
|
||||
content: string | null
|
||||
promptContent?: string | null
|
||||
metadata?: string | null
|
||||
mentionables?: string | null
|
||||
similaritySearchResults?: string | null
|
||||
createdAt: Date
|
||||
}
|
||||
|
||||
export type InsertConversation = {
|
||||
id: string
|
||||
title: string
|
||||
createdAt?: Date
|
||||
updatedAt?: Date
|
||||
}
|
||||
|
||||
export type SelectConversation = Conversation
|
||||
|
||||
export type InsertMessage = {
|
||||
id: string
|
||||
conversationId: string
|
||||
role: 'user' | 'assistant'
|
||||
content: string | null
|
||||
promptContent?: string | null
|
||||
metadata?: string | null
|
||||
mentionables?: string | null
|
||||
similaritySearchResults?: string | null
|
||||
createdAt?: Date
|
||||
}
|
||||
|
||||
export type SelectMessage = Message
|
||||
118
src/database/sql.ts
Normal file
118
src/database/sql.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
export interface SqlMigration {
|
||||
description: string;
|
||||
sql: string;
|
||||
}
|
||||
|
||||
export const migrations: Record<string, SqlMigration> = {
|
||||
vector: {
|
||||
description: "Creates vector tables and indexes for different models",
|
||||
sql: `
|
||||
-- Enable required extensions
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- Create vector tables for different models
|
||||
CREATE TABLE IF NOT EXISTS "embeddings_1536" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"path" text NOT NULL,
|
||||
"mtime" bigint NOT NULL,
|
||||
"content" text NOT NULL,
|
||||
"embedding" vector(1536),
|
||||
"metadata" jsonb NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "embeddings_1024" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"path" text NOT NULL,
|
||||
"mtime" bigint NOT NULL,
|
||||
"content" text NOT NULL,
|
||||
"embedding" vector(1024),
|
||||
"metadata" jsonb NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "embeddings_768" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"path" text NOT NULL,
|
||||
"mtime" bigint NOT NULL,
|
||||
"content" text NOT NULL,
|
||||
"embedding" vector(768),
|
||||
"metadata" jsonb NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "embeddings_512" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"path" text NOT NULL,
|
||||
"mtime" bigint NOT NULL,
|
||||
"content" text NOT NULL,
|
||||
"embedding" vector(512),
|
||||
"metadata" jsonb NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "embeddings_384" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"path" text NOT NULL,
|
||||
"mtime" bigint NOT NULL,
|
||||
"content" text NOT NULL,
|
||||
"embedding" vector(384),
|
||||
"metadata" jsonb NOT NULL
|
||||
);
|
||||
|
||||
-- Create HNSW indexes for vector similarity search
|
||||
CREATE INDEX IF NOT EXISTS "embeddingIndex_1536"
|
||||
ON "embeddings_1536"
|
||||
USING hnsw ("embedding" vector_cosine_ops);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "embeddingIndex_1024"
|
||||
ON "embeddings_1024"
|
||||
USING hnsw ("embedding" vector_cosine_ops);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "embeddingIndex_768"
|
||||
ON "embeddings_768"
|
||||
USING hnsw ("embedding" vector_cosine_ops);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "embeddingIndex_512"
|
||||
ON "embeddings_512"
|
||||
USING hnsw ("embedding" vector_cosine_ops);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "embeddingIndex_384"
|
||||
ON "embeddings_384"
|
||||
USING hnsw ("embedding" vector_cosine_ops);
|
||||
`
|
||||
},
|
||||
template: {
|
||||
description: "Creates template table with UUID support",
|
||||
sql: `
|
||||
-- Create template table
|
||||
CREATE TABLE IF NOT EXISTS "template" (
|
||||
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
|
||||
"name" text NOT NULL,
|
||||
"content" jsonb NOT NULL,
|
||||
"created_at" timestamp DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp DEFAULT now() NOT NULL,
|
||||
CONSTRAINT "template_name_unique" UNIQUE("name")
|
||||
);
|
||||
`
|
||||
},
|
||||
conversation: {
|
||||
description: "Creates conversations and messages tables",
|
||||
sql: `
|
||||
CREATE TABLE IF NOT EXISTS "conversations" (
|
||||
"id" uuid PRIMARY KEY NOT NULL,
|
||||
"title" text NOT NULL,
|
||||
"created_at" timestamp DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "messages" (
|
||||
"id" uuid PRIMARY KEY NOT NULL,
|
||||
"conversation_id" uuid NOT NULL REFERENCES "conversations"("id") ON DELETE CASCADE,
|
||||
"role" text NOT NULL,
|
||||
"content" text,
|
||||
"prompt_content" text,
|
||||
"metadata" text,
|
||||
"mentionables" text,
|
||||
"similarity_search_results" text,
|
||||
"created_at" timestamp DEFAULT now() NOT NULL
|
||||
);
|
||||
`
|
||||
}
|
||||
};
|
||||
302
src/event-listener.ts
Normal file
302
src/event-listener.ts
Normal file
@@ -0,0 +1,302 @@
|
||||
import { EditorView } from "@codemirror/view";
|
||||
import { LRUCache } from "lru-cache";
|
||||
import { App, TFile } from "obsidian";
|
||||
|
||||
import AutoComplete from "./core/autocomplete";
|
||||
import Context from "./core/autocomplete/context-detection";
|
||||
import DisabledFileSpecificState from "./core/autocomplete/states/disabled-file-specific-state";
|
||||
import DisabledInvalidSettingsState from "./core/autocomplete/states/disabled-invalid-settings-state";
|
||||
import DisabledManualState from "./core/autocomplete/states/disabled-manual-state";
|
||||
import IdleState from "./core/autocomplete/states/idle-state";
|
||||
import InitState from "./core/autocomplete/states/init-state";
|
||||
import PredictingState from "./core/autocomplete/states/predicting-state";
|
||||
import QueuedState from "./core/autocomplete/states/queued-state";
|
||||
import State from "./core/autocomplete/states/state";
|
||||
import SuggestingState from "./core/autocomplete/states/suggesting-state";
|
||||
import { EventHandler } from "./core/autocomplete/states/types";
|
||||
import { AutocompleteService } from "./core/autocomplete/types";
|
||||
import { isMatchBetweenPathAndPatterns } from "./core/autocomplete/utils";
|
||||
import { DocumentChanges } from "./render-plugin/document-changes-listener";
|
||||
import { cancelSuggestion, insertSuggestion, updateSuggestion } from "./render-plugin/states";
|
||||
import StatusBar from "./status-bar";
|
||||
import { InfioSettings } from './types/settings';
|
||||
import { checkForErrors } from "./utils/auto-complete";
|
||||
|
||||
|
||||
const FIVE_MINUTES_IN_MS = 1000 * 60 * 5;
|
||||
const MAX_N_ITEMS_IN_CACHE = 5000;
|
||||
|
||||
class EventListener implements EventHandler {
|
||||
private view: EditorView | null = null;
|
||||
|
||||
private state: EventHandler = new InitState();
|
||||
private statusBar: StatusBar;
|
||||
private app: App;
|
||||
context: Context = Context.Text;
|
||||
autocomplete: AutocompleteService;
|
||||
settings: InfioSettings;
|
||||
private currentFile: TFile | null = null;
|
||||
private suggestionCache = new LRUCache<string, string>({ max: MAX_N_ITEMS_IN_CACHE, ttl: FIVE_MINUTES_IN_MS });
|
||||
|
||||
public static fromSettings(
|
||||
settings: InfioSettings,
|
||||
statusBar: StatusBar,
|
||||
app: App
|
||||
): EventListener {
|
||||
const autocomplete = createPredictionService(settings);
|
||||
|
||||
const eventListener = new EventListener(
|
||||
settings,
|
||||
statusBar,
|
||||
app,
|
||||
autocomplete
|
||||
);
|
||||
|
||||
const settingErrors = checkForErrors(settings);
|
||||
if (settings.autocompleteEnabled) {
|
||||
eventListener.transitionToIdleState()
|
||||
} else if (settingErrors.size > 0) {
|
||||
eventListener.transitionToDisabledInvalidSettingsState();
|
||||
} else if (!settings.autocompleteEnabled) {
|
||||
eventListener.transitionToDisabledManualState();
|
||||
}
|
||||
|
||||
return eventListener;
|
||||
}
|
||||
|
||||
private constructor(
|
||||
settings: InfioSettings,
|
||||
statusBar: StatusBar,
|
||||
app: App,
|
||||
autocomplete: AutocompleteService
|
||||
) {
|
||||
this.settings = settings;
|
||||
this.statusBar = statusBar;
|
||||
this.app = app;
|
||||
this.autocomplete = autocomplete;
|
||||
}
|
||||
|
||||
public setContext(context: Context): void {
|
||||
if (context === this.context) {
|
||||
return;
|
||||
}
|
||||
this.context = context;
|
||||
this.updateStatusBarText();
|
||||
}
|
||||
|
||||
public isSuggesting(): boolean {
|
||||
return this.state instanceof SuggestingState;
|
||||
}
|
||||
|
||||
public onViewUpdate(view: EditorView): void {
|
||||
this.view = view;
|
||||
}
|
||||
|
||||
public handleFileChange(file: TFile): void {
|
||||
this.currentFile = file;
|
||||
this.state.handleFileChange(file);
|
||||
|
||||
}
|
||||
|
||||
public isCurrentFilePathIgnored(): boolean {
|
||||
if (this.currentFile === null) {
|
||||
return false;
|
||||
}
|
||||
const patterns = this.settings.ignoredFilePatterns.split("\n");
|
||||
return isMatchBetweenPathAndPatterns(this.currentFile.path, patterns);
|
||||
}
|
||||
|
||||
public currentFileContainsIgnoredTag(): boolean {
|
||||
if (this.currentFile === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ignoredTags = this.settings.ignoredTags.toLowerCase().split('\n');
|
||||
|
||||
const metadata = this.app.metadataCache.getFileCache(this.currentFile);
|
||||
if (!metadata || !metadata.tags) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const tags = metadata.tags.map(tag => tag.tag.replace(/#/g, '').toLowerCase());
|
||||
return tags.some(tag => ignoredTags.includes(tag));
|
||||
}
|
||||
|
||||
|
||||
insertCurrentSuggestion(suggestion: string): void {
|
||||
if (this.view === null) {
|
||||
return;
|
||||
}
|
||||
insertSuggestion(this.view, suggestion);
|
||||
}
|
||||
|
||||
cancelSuggestion(): void {
|
||||
if (this.view === null) {
|
||||
return;
|
||||
}
|
||||
cancelSuggestion(this.view);
|
||||
}
|
||||
|
||||
private transitionTo(state: State): void {
|
||||
this.state = state;
|
||||
this.updateStatusBarText();
|
||||
}
|
||||
|
||||
transitionToDisabledFileSpecificState(): void {
|
||||
this.transitionTo(new DisabledFileSpecificState(this));
|
||||
}
|
||||
|
||||
transitionToDisabledManualState(): void {
|
||||
this.cancelSuggestion();
|
||||
this.transitionTo(new DisabledManualState(this));
|
||||
}
|
||||
|
||||
transitionToDisabledInvalidSettingsState(): void {
|
||||
this.cancelSuggestion();
|
||||
this.transitionTo(new DisabledInvalidSettingsState(this));
|
||||
}
|
||||
|
||||
transitionToQueuedState(prefix: string, suffix: string): void {
|
||||
this.transitionTo(
|
||||
QueuedState.createAndStartTimer(
|
||||
this,
|
||||
prefix,
|
||||
suffix
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
transitionToPredictingState(prefix: string, suffix: string): void {
|
||||
this.transitionTo(PredictingState.createAndStartPredicting(
|
||||
this,
|
||||
prefix,
|
||||
suffix
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
transitionToSuggestingState(
|
||||
suggestion: string,
|
||||
prefix: string,
|
||||
suffix: string,
|
||||
addToCache = true
|
||||
): void {
|
||||
if (this.view === null) {
|
||||
return;
|
||||
}
|
||||
if (suggestion.trim().length === 0) {
|
||||
this.transitionToIdleState();
|
||||
return;
|
||||
}
|
||||
if (addToCache) {
|
||||
this.addSuggestionToCache(prefix, suffix, suggestion);
|
||||
}
|
||||
this.transitionTo(new SuggestingState(this, suggestion, prefix, suffix));
|
||||
updateSuggestion(this.view, suggestion);
|
||||
}
|
||||
|
||||
public transitionToIdleState() {
|
||||
const previousState = this.state;
|
||||
|
||||
this.transitionTo(new IdleState(this));
|
||||
|
||||
if (previousState instanceof SuggestingState) {
|
||||
this.cancelSuggestion();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private updateStatusBarText(): void {
|
||||
this.statusBar.updateText(this.getStatusBarText());
|
||||
}
|
||||
|
||||
getStatusBarText(): string {
|
||||
return `Copilot: ${this.state.getStatusBarText()}`;
|
||||
}
|
||||
|
||||
handleSettingChanged(settings: InfioSettings): void {
|
||||
this.settings = settings;
|
||||
this.autocomplete = createPredictionService(settings);
|
||||
if (!this.settings.cacheSuggestions) {
|
||||
this.clearSuggestionsCache();
|
||||
}
|
||||
|
||||
this.state.handleSettingChanged(settings);
|
||||
}
|
||||
|
||||
async handleDocumentChange(
|
||||
documentChanges: DocumentChanges
|
||||
): Promise<void> {
|
||||
await this.state.handleDocumentChange(documentChanges);
|
||||
}
|
||||
|
||||
handleAcceptKeyPressed(): boolean {
|
||||
return this.state.handleAcceptKeyPressed();
|
||||
}
|
||||
|
||||
handlePartialAcceptKeyPressed(): boolean {
|
||||
return this.state.handlePartialAcceptKeyPressed();
|
||||
}
|
||||
|
||||
handleCancelKeyPressed(): boolean {
|
||||
return this.state.handleCancelKeyPressed();
|
||||
}
|
||||
|
||||
handlePredictCommand(prefix: string, suffix: string): void {
|
||||
this.state.handlePredictCommand(prefix, suffix);
|
||||
}
|
||||
|
||||
handleAcceptCommand(): void {
|
||||
this.state.handleAcceptCommand();
|
||||
}
|
||||
|
||||
containsTriggerCharacters(
|
||||
documentChanges: DocumentChanges
|
||||
): boolean {
|
||||
for (const trigger of this.settings.triggers) {
|
||||
if (trigger.type === "string" && documentChanges.getPrefix().endsWith(trigger.value)) {
|
||||
return true;
|
||||
}
|
||||
if (trigger.type === "regex" && documentChanges.getPrefix().match(trigger.value)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public isDisabled(): boolean {
|
||||
return this.state instanceof DisabledManualState || this.state instanceof DisabledInvalidSettingsState || this.state instanceof DisabledFileSpecificState;
|
||||
}
|
||||
|
||||
public isIdle(): boolean {
|
||||
return this.state instanceof IdleState;
|
||||
}
|
||||
|
||||
public getCachedSuggestionFor(prefix: string, suffix: string): string | undefined {
|
||||
return this.suggestionCache.get(this.getCacheKey(prefix, suffix));
|
||||
}
|
||||
|
||||
private getCacheKey(prefix: string, suffix: string): string {
|
||||
const nCharsToKeepPrefix = prefix.length;
|
||||
const nCharsToKeepSuffix = suffix.length;
|
||||
|
||||
return `${prefix.substring(prefix.length - nCharsToKeepPrefix)}<mask/>${suffix.substring(0, nCharsToKeepSuffix)}`
|
||||
}
|
||||
|
||||
public clearSuggestionsCache(): void {
|
||||
this.suggestionCache.clear();
|
||||
}
|
||||
|
||||
public addSuggestionToCache(prefix: string, suffix: string, suggestion: string): void {
|
||||
if (!this.settings.cacheSuggestions) {
|
||||
return;
|
||||
}
|
||||
this.suggestionCache.set(this.getCacheKey(prefix, suffix), suggestion);
|
||||
}
|
||||
}
|
||||
|
||||
function createPredictionService(settings: InfioSettings) {
|
||||
return AutoComplete.fromSettings(settings);
|
||||
}
|
||||
|
||||
export default EventListener;
|
||||
83
src/hooks/use-chat-history.ts
Normal file
83
src/hooks/use-chat-history.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
|
||||
import { useApp } from '../contexts/AppContext'
|
||||
import { useDatabase } from '../contexts/DatabaseContext'
|
||||
import { DBManager } from '../database/database-manager'
|
||||
import { ChatConversationMeta, ChatMessage } from '../types/chat'
|
||||
|
||||
type UseChatHistory = {
|
||||
createOrUpdateConversation: (
|
||||
id: string,
|
||||
messages: ChatMessage[],
|
||||
) => Promise<void>
|
||||
deleteConversation: (id: string) => Promise<void>
|
||||
getChatMessagesById: (id: string) => Promise<ChatMessage[] | null>
|
||||
updateConversationTitle: (id: string, title: string) => Promise<void>
|
||||
chatList: ChatConversationMeta[]
|
||||
}
|
||||
|
||||
export function useChatHistory(): UseChatHistory {
|
||||
const app = useApp()
|
||||
const { getDatabaseManager } = useDatabase()
|
||||
|
||||
// 这里更新有点繁琐, 但是能保持 chatList 实时更新
|
||||
const [chatList, setChatList] = useState<ChatConversationMeta[]>([])
|
||||
|
||||
const getManager = useCallback(async (): Promise<DBManager> => {
|
||||
return await getDatabaseManager()
|
||||
}, [getDatabaseManager])
|
||||
|
||||
const fetchChatList = useCallback(async () => {
|
||||
const dbManager = await getManager()
|
||||
dbManager.getConversationManager().getAllConversations(setChatList)
|
||||
}, [getManager])
|
||||
|
||||
useEffect(() => {
|
||||
void fetchChatList()
|
||||
}, [fetchChatList])
|
||||
|
||||
// 只新增消息
|
||||
const createConversation = useCallback(
|
||||
async (id: string, messages: ChatMessage[]): Promise<void> => {
|
||||
const dbManager = await getManager()
|
||||
const conversationManager = dbManager.getConversationManager()
|
||||
await conversationManager.saveConversation(id, messages)
|
||||
},
|
||||
[getManager],
|
||||
)
|
||||
|
||||
const deleteConversation = useCallback(
|
||||
async (id: string): Promise<void> => {
|
||||
const dbManager = await getManager()
|
||||
const conversationManager = dbManager.getConversationManager()
|
||||
await conversationManager.deleteConversation(id)
|
||||
},
|
||||
[getManager],
|
||||
)
|
||||
|
||||
const getChatMessagesById = useCallback(
|
||||
async (id: string): Promise<ChatMessage[] | null> => {
|
||||
const dbManager = await getManager()
|
||||
const conversationManager = dbManager.getConversationManager()
|
||||
return await conversationManager.findConversation(id)
|
||||
},
|
||||
[getManager],
|
||||
)
|
||||
|
||||
const updateConversationTitle = useCallback(
|
||||
async (id: string, title: string): Promise<void> => {
|
||||
const dbManager = await getManager()
|
||||
const conversationManager = dbManager.getConversationManager()
|
||||
await conversationManager.updateConversationTitle(id, title)
|
||||
},
|
||||
[getManager],
|
||||
)
|
||||
|
||||
return {
|
||||
createOrUpdateConversation: createConversation,
|
||||
deleteConversation,
|
||||
getChatMessagesById,
|
||||
updateConversationTitle,
|
||||
chatList,
|
||||
}
|
||||
}
|
||||
429
src/main.ts
Normal file
429
src/main.ts
Normal file
@@ -0,0 +1,429 @@
|
||||
import { EditorView } from '@codemirror/view'
|
||||
import { Editor, MarkdownView, Notice, Plugin, TFile } from 'obsidian'
|
||||
|
||||
import { ApplyView } from './ApplyView'
|
||||
import { ChatView } from './ChatView'
|
||||
import { ChatProps } from './components/chat-view/Chat'
|
||||
import { APPLY_VIEW_TYPE, CHAT_VIEW_TYPE } from './constants'
|
||||
import { InlineEdit } from './core/edit/inline-edit-processor'
|
||||
import { RAGEngine } from './core/rag/rag-engine'
|
||||
import { DBManager } from './database/database-manager'
|
||||
import EventListener from "./event-listener"
|
||||
import CompletionKeyWatcher from "./render-plugin/completion-key-watcher"
|
||||
import DocumentChangesListener, {
|
||||
DocumentChanges,
|
||||
getPrefix, getSuffix,
|
||||
hasMultipleCursors,
|
||||
hasSelection
|
||||
} from "./render-plugin/document-changes-listener"
|
||||
import RenderSuggestionPlugin from "./render-plugin/render-surgestion-plugin"
|
||||
import { InlineSuggestionState } from "./render-plugin/states"
|
||||
import { InfioSettingTab } from './settings/SettingTab'
|
||||
import StatusBar from "./status-bar"
|
||||
import {
|
||||
InfioSettings,
|
||||
parseInfioSettings,
|
||||
} from './types/settings'
|
||||
import { getMentionableBlockData } from './utils/obsidian'
|
||||
|
||||
// Remember to rename these classes and interfaces!
|
||||
export default class InfioPlugin extends Plugin {
|
||||
settings: InfioSettings
|
||||
settingsListeners: ((newSettings: InfioSettings) => void)[] = []
|
||||
initChatProps?: ChatProps
|
||||
dbManager: DBManager | null = null
|
||||
ragEngine: RAGEngine | null = null
|
||||
inlineEdit: InlineEdit | null = null
|
||||
private dbManagerInitPromise: Promise<DBManager> | null = null
|
||||
private ragEngineInitPromise: Promise<RAGEngine> | null = null
|
||||
|
||||
async onload() {
|
||||
await this.loadSettings()
|
||||
|
||||
// This creates an icon in the left ribbon.
|
||||
this.addRibbonIcon('wand-sparkles', 'Open smart composer', () =>
|
||||
this.openChatView(),
|
||||
)
|
||||
|
||||
this.registerView(CHAT_VIEW_TYPE, (leaf) => new ChatView(leaf, this))
|
||||
this.registerView(APPLY_VIEW_TYPE, (leaf) => new ApplyView(leaf))
|
||||
|
||||
// This adds a settings tab so the user can configure various aspects of the plugin
|
||||
this.addSettingTab(new InfioSettingTab(this.app, this))
|
||||
|
||||
// Register markdown processor for ai blocks
|
||||
this.inlineEdit = new InlineEdit(this, this.settings);
|
||||
this.registerMarkdownCodeBlockProcessor("infioedit", (source, el, ctx) => {
|
||||
this.inlineEdit?.Processor(source, el, ctx);
|
||||
});
|
||||
// Update inlineEdit when settings change
|
||||
this.addSettingsListener((newSettings) => {
|
||||
this.inlineEdit = new InlineEdit(this, newSettings);
|
||||
});
|
||||
|
||||
// Setup event listener
|
||||
const statusBar = StatusBar.fromApp(this);
|
||||
const eventListener = EventListener.fromSettings(
|
||||
this.settings,
|
||||
statusBar,
|
||||
this.app
|
||||
);
|
||||
this.addSettingsListener((newSettings) => {
|
||||
eventListener.handleSettingChanged(newSettings)
|
||||
});
|
||||
|
||||
// Setup render plugin
|
||||
this.registerEditorExtension([
|
||||
InlineSuggestionState,
|
||||
CompletionKeyWatcher(
|
||||
eventListener.handleAcceptKeyPressed.bind(eventListener) as () => boolean,
|
||||
eventListener.handlePartialAcceptKeyPressed.bind(eventListener) as () => boolean,
|
||||
eventListener.handleCancelKeyPressed.bind(eventListener) as () => boolean,
|
||||
),
|
||||
DocumentChangesListener(
|
||||
eventListener.handleDocumentChange.bind(eventListener) as (documentChange: DocumentChanges) => Promise<void>
|
||||
),
|
||||
RenderSuggestionPlugin(),
|
||||
]);
|
||||
|
||||
this.app.workspace.onLayoutReady(() => {
|
||||
const view = this.app.workspace.getActiveViewOfType(MarkdownView);
|
||||
|
||||
if (view) {
|
||||
// @ts-expect-error, not typed
|
||||
const editorView = view.editor.cm as EditorView;
|
||||
eventListener.onViewUpdate(editorView);
|
||||
}
|
||||
});
|
||||
|
||||
this.app.workspace.on("active-leaf-change", (leaf) => {
|
||||
if (leaf?.view instanceof MarkdownView) {
|
||||
// @ts-expect-error, not typed
|
||||
const editorView = leaf.view.editor.cm as EditorView;
|
||||
eventListener.onViewUpdate(editorView);
|
||||
if (leaf.view.file) {
|
||||
eventListener.handleFileChange(leaf.view.file);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
this.app.metadataCache.on("changed", (file: TFile) => {
|
||||
if (file) {
|
||||
eventListener.handleFileChange(file);
|
||||
}
|
||||
});
|
||||
|
||||
// This adds a simple command that can be triggered anywhere
|
||||
this.addCommand({
|
||||
id: 'infio-open-new-chat',
|
||||
name: 'Infio open new chat',
|
||||
callback: () => this.openChatView(true),
|
||||
})
|
||||
|
||||
this.addCommand({
|
||||
id: 'infio-add-selection-to-chat',
|
||||
name: 'Infio add selection to chat',
|
||||
editorCallback: (editor: Editor, view: MarkdownView) => {
|
||||
this.addSelectionToChat(editor, view)
|
||||
},
|
||||
hotkeys: [
|
||||
{
|
||||
modifiers: ['Mod', 'Shift'],
|
||||
key: 'l',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
this.addCommand({
|
||||
id: 'infio-rebuild-vault-index',
|
||||
name: 'Infio rebuild entire vault index',
|
||||
callback: async () => {
|
||||
const notice = new Notice('Rebuilding vault index...', 0)
|
||||
try {
|
||||
const ragEngine = await this.getRAGEngine()
|
||||
await ragEngine.updateVaultIndex(
|
||||
{ reindexAll: true },
|
||||
(queryProgress) => {
|
||||
if (queryProgress.type === 'indexing') {
|
||||
const { completedChunks, totalChunks } =
|
||||
queryProgress.indexProgress
|
||||
notice.setMessage(
|
||||
`Indexing chunks: ${completedChunks} / ${totalChunks}`,
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
notice.setMessage('Rebuilding vault index complete')
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
notice.setMessage('Rebuilding vault index failed')
|
||||
} finally {
|
||||
setTimeout(() => {
|
||||
notice.hide()
|
||||
}, 1000)
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
this.addCommand({
|
||||
id: 'infio-update-vault-index',
|
||||
name: 'Infio update index for modified files',
|
||||
callback: async () => {
|
||||
const notice = new Notice('Updating vault index...', 0)
|
||||
try {
|
||||
const ragEngine = await this.getRAGEngine()
|
||||
await ragEngine.updateVaultIndex(
|
||||
{ reindexAll: false },
|
||||
(queryProgress) => {
|
||||
if (queryProgress.type === 'indexing') {
|
||||
const { completedChunks, totalChunks } =
|
||||
queryProgress.indexProgress
|
||||
notice.setMessage(
|
||||
`Indexing chunks: ${completedChunks} / ${totalChunks}`,
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
notice.setMessage('Vault index updated')
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
notice.setMessage('Vault index update failed')
|
||||
} finally {
|
||||
setTimeout(() => {
|
||||
notice.hide()
|
||||
}, 1000)
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
this.addCommand({
|
||||
id: 'infio-autocomplete-accept',
|
||||
name: 'Infio Autocomplete Accept',
|
||||
editorCheckCallback: (
|
||||
checking: boolean,
|
||||
editor: Editor,
|
||||
view: MarkdownView
|
||||
) => {
|
||||
if (checking) {
|
||||
return (
|
||||
eventListener.isSuggesting()
|
||||
);
|
||||
}
|
||||
|
||||
eventListener.handleAcceptCommand();
|
||||
|
||||
return true;
|
||||
},
|
||||
})
|
||||
|
||||
this.addCommand({
|
||||
id: 'infio-autocomplete-predict',
|
||||
name: 'Infio Autocomplete Predict',
|
||||
editorCheckCallback: (
|
||||
checking: boolean,
|
||||
editor: Editor,
|
||||
view: MarkdownView
|
||||
) => {
|
||||
// @ts-expect-error, not typed
|
||||
const editorView = editor.cm as EditorView;
|
||||
const state = editorView.state;
|
||||
if (checking) {
|
||||
return eventListener.isIdle() && !hasMultipleCursors(state) && !hasSelection(state);
|
||||
}
|
||||
|
||||
const prefix = getPrefix(state)
|
||||
const suffix = getSuffix(state)
|
||||
|
||||
eventListener.handlePredictCommand(prefix, suffix);
|
||||
return true;
|
||||
},
|
||||
});
|
||||
|
||||
this.addCommand({
|
||||
id: "infio-autocomplete-toggle",
|
||||
name: "Infio Autocomplete Toggle",
|
||||
callback: () => {
|
||||
const newValue = !this.settings.autocompleteEnabled;
|
||||
this.setSettings({
|
||||
...this.settings,
|
||||
autocompleteEnabled: newValue,
|
||||
})
|
||||
},
|
||||
});
|
||||
|
||||
this.addCommand({
|
||||
id: "infio-autocomplete-enable",
|
||||
name: "Infio Autocomplete Enable",
|
||||
checkCallback: (checking) => {
|
||||
if (checking) {
|
||||
return !this.settings.autocompleteEnabled;
|
||||
}
|
||||
|
||||
this.setSettings({
|
||||
...this.settings,
|
||||
autocompleteEnabled: true,
|
||||
})
|
||||
return true;
|
||||
},
|
||||
});
|
||||
|
||||
this.addCommand({
|
||||
id: "infio-autocomplete-disable",
|
||||
name: "Infio Autocomplete Disable",
|
||||
checkCallback: (checking) => {
|
||||
if (checking) {
|
||||
return this.settings.autocompleteEnabled;
|
||||
}
|
||||
|
||||
this.setSettings({
|
||||
...this.settings,
|
||||
autocompleteEnabled: false,
|
||||
})
|
||||
return true;
|
||||
},
|
||||
});
|
||||
|
||||
this.addCommand({
|
||||
id: "infio-ai-inline-edit",
|
||||
name: "infio Inline Edit",
|
||||
hotkeys: [
|
||||
{
|
||||
modifiers: ['Mod', 'Shift'],
|
||||
key: "k",
|
||||
},
|
||||
],
|
||||
editorCallback: (editor: Editor) => {
|
||||
const selection = editor.getSelection();
|
||||
if (!selection) {
|
||||
new Notice("Please select some text first");
|
||||
return;
|
||||
}
|
||||
// Get the selection start position
|
||||
const from = editor.getCursor("from");
|
||||
// Create the position for inserting the block
|
||||
const insertPos = { line: from.line, ch: 0 };
|
||||
// Create the AI block with the selected text
|
||||
const customBlock = "```infioedit\n```\n";
|
||||
// Insert the block above the selection
|
||||
editor.replaceRange(customBlock, insertPos);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
onunload() {
|
||||
this.dbManager?.cleanup()
|
||||
this.dbManager = null
|
||||
}
|
||||
|
||||
async loadSettings() {
|
||||
this.settings = parseInfioSettings(await this.loadData())
|
||||
await this.saveData(this.settings) // Save updated settings
|
||||
}
|
||||
|
||||
async setSettings(newSettings: InfioSettings) {
|
||||
this.settings = newSettings
|
||||
await this.saveData(newSettings)
|
||||
this.ragEngine?.setSettings(newSettings)
|
||||
this.settingsListeners.forEach((listener) => listener(newSettings))
|
||||
}
|
||||
|
||||
addSettingsListener(
|
||||
listener: (newSettings: InfioSettings) => void,
|
||||
) {
|
||||
this.settingsListeners.push(listener)
|
||||
return () => {
|
||||
this.settingsListeners = this.settingsListeners.filter(
|
||||
(l) => l !== listener,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async openChatView(openNewChat = false) {
|
||||
const view = this.app.workspace.getActiveViewOfType(MarkdownView)
|
||||
const editor = view?.editor
|
||||
if (!view || !editor) {
|
||||
this.activateChatView(undefined, openNewChat)
|
||||
return
|
||||
}
|
||||
const selectedBlockData = await getMentionableBlockData(editor, view)
|
||||
this.activateChatView(
|
||||
{
|
||||
selectedBlock: selectedBlockData ?? undefined,
|
||||
},
|
||||
openNewChat,
|
||||
)
|
||||
}
|
||||
|
||||
async activateChatView(chatProps?: ChatProps, openNewChat = false) {
|
||||
// chatProps is consumed in ChatView.tsx
|
||||
this.initChatProps = chatProps
|
||||
|
||||
const leaf = this.app.workspace.getLeavesOfType(CHAT_VIEW_TYPE)[0]
|
||||
|
||||
await (leaf ?? this.app.workspace.getRightLeaf(false))?.setViewState({
|
||||
type: CHAT_VIEW_TYPE,
|
||||
active: true,
|
||||
})
|
||||
|
||||
if (openNewChat && leaf && leaf.view instanceof ChatView) {
|
||||
leaf.view.openNewChat(chatProps?.selectedBlock)
|
||||
}
|
||||
|
||||
this.app.workspace.revealLeaf(
|
||||
this.app.workspace.getLeavesOfType(CHAT_VIEW_TYPE)[0],
|
||||
)
|
||||
}
|
||||
|
||||
async addSelectionToChat(editor: Editor, view: MarkdownView) {
|
||||
const data = await getMentionableBlockData(editor, view)
|
||||
if (!data) return
|
||||
|
||||
const leaves = this.app.workspace.getLeavesOfType(CHAT_VIEW_TYPE)
|
||||
if (leaves.length === 0 || !(leaves[0].view instanceof ChatView)) {
|
||||
await this.activateChatView({
|
||||
selectedBlock: data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// bring leaf to foreground (uncollapse sidebar if it's collapsed)
|
||||
await this.app.workspace.revealLeaf(leaves[0])
|
||||
|
||||
const chatView = leaves[0].view
|
||||
chatView.addSelectionToChat(data)
|
||||
chatView.focusMessage()
|
||||
}
|
||||
|
||||
async getDbManager(): Promise<DBManager> {
|
||||
if (this.dbManager) {
|
||||
return this.dbManager
|
||||
}
|
||||
|
||||
if (!this.dbManagerInitPromise) {
|
||||
this.dbManagerInitPromise = (async () => {
|
||||
this.dbManager = await DBManager.create(this.app)
|
||||
return this.dbManager
|
||||
})()
|
||||
}
|
||||
|
||||
// if initialization is running, wait for it to complete instead of creating a new initialization promise
|
||||
return this.dbManagerInitPromise
|
||||
}
|
||||
|
||||
async getRAGEngine(): Promise<RAGEngine> {
|
||||
if (this.ragEngine) {
|
||||
return this.ragEngine
|
||||
}
|
||||
|
||||
if (!this.ragEngineInitPromise) {
|
||||
this.ragEngineInitPromise = (async () => {
|
||||
const dbManager = await this.getDbManager()
|
||||
this.ragEngine = new RAGEngine(this.app, this.settings, dbManager)
|
||||
return this.ragEngine
|
||||
})()
|
||||
}
|
||||
|
||||
// if initialization is running, wait for it to complete instead of creating a new initialization promise
|
||||
return this.ragEngineInitPromise
|
||||
}
|
||||
}
|
||||
17
src/open-settings-modal.ts
Normal file
17
src/open-settings-modal.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { App, Modal, Setting } from 'obsidian'
|
||||
|
||||
export class OpenSettingsModal extends Modal {
|
||||
constructor(app: App, title: string, onSubmit: () => void) {
|
||||
super(app)
|
||||
|
||||
this.setTitle(title)
|
||||
|
||||
new Setting(this.contentEl).addButton((button) => {
|
||||
button.setButtonText('Open settings')
|
||||
button.onClick(() => {
|
||||
this.close()
|
||||
onSubmit()
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
27
src/render-plugin/completion-key-watcher.ts
Normal file
27
src/render-plugin/completion-key-watcher.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { Prec } from "@codemirror/state";
|
||||
import { keymap } from "@codemirror/view";
|
||||
|
||||
function CompletionKeyWatcher(
|
||||
handleAcceptKey: () => boolean,
|
||||
handlePartialAcceptKey: () => boolean,
|
||||
handleCancelKey: () => boolean
|
||||
) {
|
||||
return Prec.highest(
|
||||
keymap.of([
|
||||
{
|
||||
key: "Tab",
|
||||
run: handleAcceptKey,
|
||||
},
|
||||
{
|
||||
key: "ArrowRight",
|
||||
run: handlePartialAcceptKey,
|
||||
},
|
||||
{
|
||||
key: "Escape",
|
||||
run: handleCancelKey,
|
||||
},
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
export default CompletionKeyWatcher;
|
||||
201
src/render-plugin/document-changes-listener.ts
Normal file
201
src/render-plugin/document-changes-listener.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
import { EditorState } from "@codemirror/state";
|
||||
import { ViewPlugin, ViewUpdate } from "@codemirror/view";
|
||||
|
||||
import UserEvent from "./user-event";
|
||||
|
||||
export class DocumentChanges {
|
||||
private update: ViewUpdate;
|
||||
|
||||
constructor(update: ViewUpdate) {
|
||||
this.update = update;
|
||||
|
||||
}
|
||||
|
||||
public isDocInFocus(): boolean {
|
||||
return this.update.view.hasFocus;
|
||||
}
|
||||
|
||||
public noUserEvents(): boolean {
|
||||
return this.getUserEvents().length === 0;
|
||||
}
|
||||
|
||||
public hasUserTyped(): boolean {
|
||||
const userEvents = this.getUserEvents();
|
||||
return userEvents.contains(UserEvent.INPUT_TYPE);
|
||||
}
|
||||
|
||||
public hasUserUndone(): boolean {
|
||||
const userEvents = this.getUserEvents();
|
||||
return userEvents.contains(UserEvent.UNDO);
|
||||
}
|
||||
|
||||
public hasUserRedone(): boolean {
|
||||
const userEvents = this.getUserEvents();
|
||||
return userEvents.contains(UserEvent.REDO);
|
||||
}
|
||||
|
||||
public hasUserDeleted(): boolean {
|
||||
const userEvents = this.getUserEvents();
|
||||
return (
|
||||
userEvents.filter((event) => UserEvent.isDelete(event)).length > 0
|
||||
);
|
||||
}
|
||||
|
||||
public hasDocChanged(): boolean {
|
||||
return (
|
||||
this.update.docChanged ||
|
||||
this.hasUserTyped() ||
|
||||
this.hasUserDeleted()
|
||||
);
|
||||
}
|
||||
|
||||
public hasCursorMoved(): boolean {
|
||||
return this.getUserEvents().contains(UserEvent.CURSOR_MOVED);
|
||||
}
|
||||
|
||||
public getUserEvents(): UserEvent[] {
|
||||
const userEvents: UserEvent[] = [];
|
||||
for (const transaction of this.update.transactions) {
|
||||
const event = UserEvent.fromTransaction(transaction);
|
||||
if (event) {
|
||||
userEvents.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
return userEvents;
|
||||
}
|
||||
|
||||
isTextAdded(): boolean {
|
||||
return this.getAddedText().length > 0;
|
||||
}
|
||||
|
||||
getAddedText(): string {
|
||||
let addedText = "";
|
||||
this.update.changes.iterChanges((fromA, toA, fromB, toB, inserted) => {
|
||||
addedText += inserted;
|
||||
});
|
||||
|
||||
return addedText;
|
||||
}
|
||||
|
||||
getPrefix(): string {
|
||||
return getPrefix(this.update.state);
|
||||
}
|
||||
|
||||
getSuffix(): string {
|
||||
return getSuffix(this.update.state);
|
||||
}
|
||||
|
||||
getAddedPrefixText(): string | undefined {
|
||||
if (!this.isDocInFocus() || this.hasCursorMoved()) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const previousPrefix = this.getPreviousPrefix();
|
||||
const updatedPrefix = this.getPrefix();
|
||||
|
||||
if (updatedPrefix.length > previousPrefix.length) {
|
||||
return updatedPrefix.substring(previousPrefix.length);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
getPreviousPrefix(): string {
|
||||
return getPrefix(this.update.startState);
|
||||
}
|
||||
|
||||
getAddedSuffixText(): string | undefined {
|
||||
if (!this.isDocInFocus() || this.hasCursorMoved()) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const previousSuffix = this.getPreviousSuffix();
|
||||
const updatedSuffix = this.getSuffix();
|
||||
|
||||
if (updatedSuffix.length > previousSuffix.length) {
|
||||
return updatedSuffix.substring(0, updatedSuffix.length - previousSuffix.length);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
getPreviousSuffix(): string {
|
||||
return getSuffix(this.update.startState);
|
||||
}
|
||||
|
||||
getRemovedPrefixText(): string | undefined {
|
||||
if (!this.isDocInFocus() || this.hasCursorMoved()) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const previousPrefix = this.getPreviousPrefix();
|
||||
const updatedPrefix = this.getPrefix();
|
||||
|
||||
if (updatedPrefix.length < previousPrefix.length) {
|
||||
return previousPrefix.substring(updatedPrefix.length);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
getRemovedSuffixText(): string | undefined {
|
||||
if (!this.isDocInFocus() || this.hasCursorMoved()) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const previousSuffix = this.getPreviousSuffix();
|
||||
const updatedSuffix = this.getSuffix();
|
||||
|
||||
if (updatedSuffix.length < previousSuffix.length) {
|
||||
return previousSuffix.substring(0, previousSuffix.length - updatedSuffix.length);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
hasSelection(): boolean {
|
||||
return hasSelection(this.update.state);
|
||||
}
|
||||
|
||||
hasMultipleCursors(): boolean {
|
||||
return hasMultipleCursors(this.update.state);
|
||||
}
|
||||
}
|
||||
|
||||
const DocumentChangesListener = (
|
||||
onDocumentChange: (documentChange: DocumentChanges) => Promise<void>
|
||||
) =>
|
||||
ViewPlugin.fromClass(
|
||||
class FetchPlugin {
|
||||
|
||||
async update(update: ViewUpdate) {
|
||||
|
||||
await onDocumentChange(new DocumentChanges(update));
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
export function getPrefix(state: EditorState): string {
|
||||
return state.doc.sliceString(0, getCursorLocation(state));
|
||||
}
|
||||
|
||||
export function getCursorLocation(state: EditorState): number {
|
||||
return state.selection.main.head;
|
||||
}
|
||||
|
||||
export function getSuffix(state: EditorState): string {
|
||||
return state.doc.sliceString(getCursorLocation(state));
|
||||
}
|
||||
|
||||
export function hasMultipleCursors(state: EditorState): boolean {
|
||||
return state.selection.ranges.length > 1;
|
||||
}
|
||||
|
||||
export function hasSelection(state: EditorState): boolean {
|
||||
for (const range of state.selection.ranges) {
|
||||
const { from, to } = range;
|
||||
if (from !== to) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export default DocumentChangesListener;
|
||||
105
src/render-plugin/render-surgestion-plugin.ts
Normal file
105
src/render-plugin/render-surgestion-plugin.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import { Prec } from "@codemirror/state";
|
||||
import {
|
||||
Decoration,
|
||||
DecorationSet,
|
||||
EditorView,
|
||||
ViewPlugin,
|
||||
ViewUpdate,
|
||||
WidgetType,
|
||||
} from "@codemirror/view";
|
||||
|
||||
import { cancelSuggestion, InlineSuggestionState } from "./states";
|
||||
import { OptionalSuggestion, Suggestion } from "./types";
|
||||
|
||||
const RenderSuggestionPlugin = () =>
|
||||
Prec.lowest(
|
||||
// must be lowest else you get infinite loop with state changes by our plugin
|
||||
ViewPlugin.fromClass(
|
||||
class RenderPlugin {
|
||||
decorations: DecorationSet;
|
||||
suggestion: Suggestion;
|
||||
|
||||
constructor(view: EditorView) {
|
||||
this.decorations = Decoration.none;
|
||||
this.suggestion = {
|
||||
value: "",
|
||||
render: false,
|
||||
}
|
||||
}
|
||||
|
||||
async update(update: ViewUpdate) {
|
||||
const suggestion: OptionalSuggestion = update.state.field(
|
||||
InlineSuggestionState
|
||||
);
|
||||
|
||||
if (suggestion !== null && suggestion !== undefined) {
|
||||
this.suggestion = suggestion;
|
||||
}
|
||||
|
||||
this.decorations = inlineSuggestionDecoration(
|
||||
update.view,
|
||||
this.suggestion
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
decorations: (v) => v.decorations,
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
function inlineSuggestionDecoration(
|
||||
view: EditorView,
|
||||
display_suggestion: Suggestion
|
||||
) {
|
||||
const post = view.state.selection.main.head;
|
||||
|
||||
if (!display_suggestion.render) {
|
||||
return Decoration.none;
|
||||
}
|
||||
try {
|
||||
const widget = new InlineSuggestionWidget(display_suggestion.value, view);
|
||||
const decoration = Decoration.widget({
|
||||
widget,
|
||||
side: 1,
|
||||
});
|
||||
|
||||
return Decoration.set([decoration.range(post)]);
|
||||
} catch (e) {
|
||||
return Decoration.none;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class InlineSuggestionWidget extends WidgetType {
|
||||
constructor(readonly display_suggestion: string, readonly view: EditorView) {
|
||||
super();
|
||||
this.display_suggestion = display_suggestion;
|
||||
this.view = view;
|
||||
}
|
||||
|
||||
eq(other: InlineSuggestionWidget) {
|
||||
return other.display_suggestion == this.display_suggestion;
|
||||
}
|
||||
|
||||
toDOM() {
|
||||
const span = document.createElement("span");
|
||||
span.textContent = this.display_suggestion;
|
||||
span.style.opacity = "0.4"; // TODO replace with css
|
||||
span.onclick = () => {
|
||||
cancelSuggestion(this.view);
|
||||
}
|
||||
span.onselect = () => {
|
||||
cancelSuggestion(this.view);
|
||||
}
|
||||
|
||||
return span;
|
||||
}
|
||||
|
||||
destroy(dom: HTMLElement) {
|
||||
super.destroy(dom);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
export default RenderSuggestionPlugin;
|
||||
151
src/render-plugin/states.ts
Normal file
151
src/render-plugin/states.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
import {
|
||||
EditorSelection,
|
||||
EditorState,
|
||||
SelectionRange,
|
||||
StateEffect,
|
||||
StateField,
|
||||
Transaction,
|
||||
TransactionSpec,
|
||||
} from "@codemirror/state";
|
||||
import { EditorView } from "@codemirror/view";
|
||||
|
||||
import { InlineSuggestion, OptionalSuggestion } from "./types";
|
||||
|
||||
const InlineSuggestionEffect = StateEffect.define<InlineSuggestion>();
|
||||
|
||||
export const InlineSuggestionState = StateField.define<OptionalSuggestion>({
|
||||
create(): OptionalSuggestion {
|
||||
return null;
|
||||
},
|
||||
update(
|
||||
value: OptionalSuggestion,
|
||||
transaction: Transaction
|
||||
): OptionalSuggestion {
|
||||
const inlineSuggestion = transaction.effects.find((effect) =>
|
||||
effect.is(InlineSuggestionEffect)
|
||||
);
|
||||
|
||||
if (
|
||||
inlineSuggestion?.value?.doc !== undefined
|
||||
) {
|
||||
return inlineSuggestion.value.suggestion;
|
||||
}
|
||||
return null;
|
||||
},
|
||||
});
|
||||
|
||||
export const updateSuggestion = (
|
||||
view: EditorView,
|
||||
suggestion: string
|
||||
) => {
|
||||
const doc = view.state.doc;
|
||||
sleep(1).then(() => {
|
||||
view.dispatch({
|
||||
effects: InlineSuggestionEffect.of({
|
||||
suggestion: {
|
||||
value: suggestion,
|
||||
render: true,
|
||||
},
|
||||
doc: doc,
|
||||
}),
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
export const cancelSuggestion = (view: EditorView) => {
|
||||
const doc = view.state.doc;
|
||||
sleep(1).then(() => {
|
||||
view.dispatch({
|
||||
effects: InlineSuggestionEffect.of({
|
||||
suggestion: {
|
||||
value: "",
|
||||
render: false,
|
||||
},
|
||||
doc: doc,
|
||||
}),
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
export const insertSuggestion = (view: EditorView, suggestion: string) => {
|
||||
view.dispatch({
|
||||
...createInsertSuggestionTransaction(
|
||||
view.state,
|
||||
suggestion,
|
||||
view.state.selection.main.from,
|
||||
view.state.selection.main.to
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
function createInsertSuggestionTransaction(
|
||||
state: EditorState,
|
||||
text: string,
|
||||
from: number,
|
||||
to: number
|
||||
): TransactionSpec {
|
||||
const docLength = state.doc.length;
|
||||
if (from < 0 || to > docLength || from > to) {
|
||||
// If the range is not valid, return an empty transaction spec.
|
||||
return { changes: [] };
|
||||
}
|
||||
|
||||
const createInsertSuggestionTransactionFromSelectionRange = (
|
||||
range: SelectionRange
|
||||
) => {
|
||||
|
||||
if (range === state.selection.main) {
|
||||
return {
|
||||
changes: { from, to, insert: text },
|
||||
range: EditorSelection.cursor(to + text.length),
|
||||
};
|
||||
}
|
||||
const length = to - from;
|
||||
if (hasTextChanged(from, to, state, range)) {
|
||||
return { range };
|
||||
}
|
||||
return {
|
||||
changes: {
|
||||
from: range.from - length,
|
||||
to: range.from,
|
||||
insert: text,
|
||||
},
|
||||
range: EditorSelection.cursor(range.from - length + text.length),
|
||||
};
|
||||
};
|
||||
|
||||
return {
|
||||
...state.changeByRange(
|
||||
createInsertSuggestionTransactionFromSelectionRange
|
||||
),
|
||||
userEvent: "input.complete",
|
||||
};
|
||||
}
|
||||
|
||||
function hasTextChanged(
|
||||
from: number,
|
||||
to: number,
|
||||
state: EditorState,
|
||||
changeRange: SelectionRange
|
||||
) {
|
||||
if (changeRange.empty) {
|
||||
return false;
|
||||
}
|
||||
const length = to - from;
|
||||
if (length <= 0) {
|
||||
return false;
|
||||
}
|
||||
if (changeRange.to <= from || changeRange.from >= to) {
|
||||
return false;
|
||||
}
|
||||
// check out of bound
|
||||
if (changeRange.from < 0 || changeRange.to > state.doc.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (
|
||||
state.sliceDoc(changeRange.from - length, changeRange.from) !=
|
||||
state.sliceDoc(from, to)
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user