add tool use, update system prompt

This commit is contained in:
duanfuxiang
2025-03-12 21:39:29 +08:00
parent cabf2d5fa4
commit b0fbbb22d3
36 changed files with 7149 additions and 430 deletions

View File

@@ -0,0 +1,22 @@
import type { DiffStrategy } from "./types"
import { UnifiedDiffStrategy } from "./strategies/unified"
import { SearchReplaceDiffStrategy } from "./strategies/search-replace"
import { NewUnifiedDiffStrategy } from "./strategies/new-unified"
/**
 * Chooses the diff strategy implementation to use.
 *
 * @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus').
 *   Accepted for API compatibility; the current selection logic does not read it.
 * @param fuzzyMatchThreshold Optional threshold passed unchanged to the strategy constructor.
 * @param experimentalDiffStrategy When true, the new unified strategy is returned;
 *   otherwise (the default) the search/replace strategy is returned.
 * @returns A freshly constructed strategy instance.
 */
export function getDiffStrategy(
	model: string,
	fuzzyMatchThreshold?: number,
	experimentalDiffStrategy: boolean = false,
): DiffStrategy {
	return experimentalDiffStrategy
		? new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
		: new SearchReplaceDiffStrategy(fuzzyMatchThreshold)
}
export type { DiffStrategy }
export { UnifiedDiffStrategy, SearchReplaceDiffStrategy }

View File

@@ -0,0 +1,31 @@
/** A run of lines to be spliced into an array at a given position. */
export interface InsertGroup {
	/** Position in the original array before which `elements` are placed. */
	index: number
	/** Lines to insert, kept in the order given. */
	elements: string[]
}

/**
 * Inserts multiple groups of elements at specified indices in an array.
 *
 * Indices always refer to positions in `original`, not in the partially built
 * result, so groups do not shift one another. Groups that share an index are
 * emitted in the order they were passed in (the sort is stable), and an index
 * past the end of `original` appends the group after the last line.
 *
 * Neither argument is modified.
 *
 * @param original Array to insert into, split by lines
 * @param insertGroups Array of groups to insert, each with an index and elements to insert
 * @returns New array with all insertions applied
 */
export function insertGroups(original: string[], insertGroups: InsertGroup[]): string[] {
	// Sort a copy: Array.prototype.sort works in place and would otherwise
	// reorder the array the caller handed us.
	const ordered = [...insertGroups].sort((a, b) => a.index - b.index)

	const result: string[] = []
	let lastIndex = 0

	for (const { index, elements } of ordered) {
		// Original lines between the previous insertion point and this one
		result.push(...original.slice(lastIndex, index))
		// The inserted lines themselves
		result.push(...elements)
		lastIndex = index
	}

	// Whatever is left of the original after the final insertion point
	result.push(...original.slice(lastIndex))

	return result
}

View File

@@ -0,0 +1,738 @@
import { NewUnifiedDiffStrategy } from "../new-unified"
describe("main", () => {
let strategy: NewUnifiedDiffStrategy
beforeEach(() => {
strategy = new NewUnifiedDiffStrategy(0.97)
})
describe("constructor", () => {
	// The threshold is read through bracket access on the instance.
	const thresholdOf = (instance: NewUnifiedDiffStrategy) => instance["confidenceThreshold"]

	it("should use default confidence threshold when not provided", () => {
		expect(thresholdOf(new NewUnifiedDiffStrategy())).toBe(1)
	})

	it("should use provided confidence threshold", () => {
		expect(thresholdOf(new NewUnifiedDiffStrategy(0.85))).toBe(0.85)
	})

	it("should enforce minimum confidence threshold", () => {
		// 0.7 is below the minimum of 0.8, so the stored value is expected to be 0.8
		const belowMinimum = new NewUnifiedDiffStrategy(0.7)
		expect(thresholdOf(belowMinimum)).toBe(0.8)
	})
})
describe("getToolDescription", () => {
	it("should return tool description with correct cwd", () => {
		const workingDir = "/test/path"
		const text = strategy.getToolDescription({ cwd: workingDir })

		// Each fragment must appear somewhere in the generated description;
		// they are checked in this order and the first miss fails the test.
		const requiredFragments = [
			"apply_diff Tool - Generate Precise Code Changes",
			workingDir,
			"Step-by-Step Instructions",
			"Requirements",
			"Examples",
			"Parameters:",
		]
		for (const fragment of requiredFragments) {
			expect(text).toContain(fragment)
		}
	})
})
// Smallest happy path: one hunk that both inserts a line ("+new line") and
// replaces one ("-line3" / "+modified line3"). The hunk header is the
// placeholder "@@ ... @@" rather than numeric line ranges.
it("should apply simple diff correctly", async () => {
const original = `line1
line2
line3`
const diff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
+new line
line2
-line3
+modified line3`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(`line1
new line
line2
modified line3`)
}
})
// Two hunks in one diff, separated by a second "@@ ... @@" header.
// The second hunk replaces the last line and appends one more after it.
it("should handle multiple hunks", async () => {
const original = `line1
line2
line3
line4
line5`
const diff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
+new line
line2
-line3
+modified line3
@@ ... @@
line4
-line5
+modified line5
+new line at end`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(`line1
new line
line2
modified line3
line4
modified line5
new line at end`)
}
})
// Three hunks over a ten-line file, each mixing insertions with multi-line
// replacements. "line5" is not mentioned by any hunk and must come through
// unchanged between the first and second hunk's output.
it("should handle complex large", async () => {
const original = `line1
line2
line3
line4
line5
line6
line7
line8
line9
line10`
const diff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
+header line
+another header
line2
-line3
-line4
+modified line3
+modified line4
+extra line
@@ ... @@
line6
+middle section
line7
-line8
+changed line8
+bonus line
@@ ... @@
line9
-line10
+final line
+very last line`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(`line1
header line
another header
line2
modified line3
modified line4
extra line
line5
line6
middle section
line7
changed line8
bonus line
line9
final line
very last line`)
}
})
// One hunk spanning the whole file: two lines are added after "indented line"
// and "triple indent" is replaced by "hi there mate"; every other line is context.
// NOTE(review): every fixture line below starts at column 0, so the indentation
// differences named in the line labels are not visible here — confirm the
// fixtures still carry the leading whitespace this test is meant to exercise.
it("should handle indentation changes", async () => {
const original = `first line
indented line
double indented line
back to single indent
no indent
indented again
double indent again
triple indent
back to single
last line`
const diff = `--- original
+++ modified
@@ ... @@
first line
indented line
+ tab indented line
+ new indented line
double indented line
back to single indent
no indent
indented again
double indent again
- triple indent
+ hi there mate
back to single
last line`
const expected = `first line
indented line
tab indented line
new indented line
double indented line
back to single indent
no indent
indented again
double indent again
hi there mate
back to single
last line`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
// Whole-function rewrite: every original line is removed and replaced, so the
// hunk contains no context lines at all. The diff also starts directly at the
// "@@ ... @@" header, with no "---" / "+++" file header lines.
it("should handle high level edits", async () => {
const original = `def factorial(n):
if n == 0:
return 1
else:
return n * factorial(n-1)`
const diff = `@@ ... @@
-def factorial(n):
- if n == 0:
- return 1
- else:
- return n * factorial(n-1)
+def factorial(number):
+ if number == 0:
+ return 1
+ else:
+ return number * factorial(number-1)`
const expected = `def factorial(number):
if number == 0:
return 1
else:
return number * factorial(number-1)`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
// Comment-stripping edit over a small JS program, expressed as two hunks.
// The first hunk only covers the top of PrimeCheck and the second resumes at
// "return primeArray;", so the lines in between are covered by neither hunk
// and must be carried over unchanged.
// The original ends with a trailing newline while the diff text does not;
// the expected output keeps the trailing newline.
it("it should handle very complex edits", async () => {
const original = `//Initialize the array that will hold the primes
var primeArray = [];
/*Write a function that checks for primeness and
pushes those values to t*he array*/
function PrimeCheck(candidate){
isPrime = true;
for(var i = 2; i < candidate && isPrime; i++){
if(candidate%i === 0){
isPrime = false;
} else {
isPrime = true;
}
}
if(isPrime){
primeArray.push(candidate);
}
return primeArray;
}
/*Write the code that runs the above until the
l ength of the array equa*ls the number of primes
desired*/
var numPrimes = prompt("How many primes?");
//Display the finished array of primes
//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
for (var i = 2; primeArray.length < numPrimes; i++) {
PrimeCheck(i); //
}
console.log(primeArray);
`
const diff = `--- test_diff.js
+++ test_diff.js
@@ ... @@
-//Initialize the array that will hold the primes
var primeArray = [];
-/*Write a function that checks for primeness and
- pushes those values to t*he array*/
function PrimeCheck(candidate){
isPrime = true;
for(var i = 2; i < candidate && isPrime; i++){
@@ ... @@
return primeArray;
}
-/*Write the code that runs the above until the
- l ength of the array equa*ls the number of primes
- desired*/
var numPrimes = prompt("How many primes?");
-//Display the finished array of primes
-
-//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
for (var i = 2; primeArray.length < numPrimes; i++) {
- PrimeCheck(i); //
+ PrimeCheck(i);
}
console.log(primeArray);`
const expected = `var primeArray = [];
function PrimeCheck(candidate){
isPrime = true;
for(var i = 2; i < candidate && isPrime; i++){
if(candidate%i === 0){
isPrime = false;
} else {
isPrime = true;
}
}
if(isPrime){
primeArray.push(candidate);
}
return primeArray;
}
var numPrimes = prompt("How many primes?");
for (var i = 2; primeArray.length < numPrimes; i++) {
PrimeCheck(i);
}
console.log(primeArray);
`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
// Inputs that should be rejected (success === false) and unusual-but-valid
// inputs that should still apply.
describe("error handling and edge cases", () => {
// Plain text with no diff markers of any kind.
it("should reject completely invalid diff format", async () => {
const original = "line1\nline2\nline3"
const invalidDiff = "this is not a diff at all"
const result = await strategy.applyDiff(original, invalidDiff)
expect(result.success).toBe(false)
})
// File header lines are present, but the line where the "@@" hunk header
// belongs is arbitrary text.
it("should reject diff with invalid hunk format", async () => {
const original = "line1\nline2\nline3"
const invalidHunkDiff = `--- a/file.txt
+++ b/file.txt
invalid hunk header
line1
-line2
+new line`
const result = await strategy.applyDiff(original, invalidHunkDiff)
expect(result.success).toBe(false)
})
// Well-formed diff whose removed line does not occur in the original.
it("should fail when diff tries to modify non-existent content", async () => {
const original = "line1\nline2\nline3"
const nonMatchingDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
-nonexistent line
+new line
line3`
const result = await strategy.applyDiff(original, nonMatchingDiff)
expect(result.success).toBe(false)
})
// Both hunks remove "line3". Despite the "handle" in the title, the expected
// outcome is a failure, not a merged result.
it("should handle overlapping hunks", async () => {
const original = `line1
line2
line3
line4
line5`
const overlappingDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
line2
-line3
+modified3
line4
@@ ... @@
line2
-line3
-line4
+modified3and4
line5`
const result = await strategy.applyDiff(original, overlappingDiff)
expect(result.success).toBe(false)
})
// NOTE(review): none of the fixtures below contains an empty line as shown
// here — confirm the blank lines this test is named after are still present.
it("should handle empty lines modifications", async () => {
const original = `line1
line3
line5`
const emptyLinesDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
-line3
+line3modified
line5`
const result = await strategy.applyDiff(original, emptyLinesDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`line1
line3modified
line5`)
}
})
// The original mixes CRLF and LF endings and the diff carries explicit "\r"
// escapes on some lines. The expected output ends every line with CRLF,
// including the replaced one.
it("should handle mixed line endings in diff", async () => {
const original = "line1\r\nline2\nline3\r\n"
const mixedEndingsDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1\r
-line2
+modified2\r
line3`
const result = await strategy.applyDiff(original, mixedEndingsDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n")
}
})
// Single-line file whose only line is replaced; the hunk has no context lines.
it("should handle partial line modifications", async () => {
const original = "const value = oldValue + 123;"
const partialDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
-const value = oldValue + 123;
+const value = newValue + 123;`
const result = await strategy.applyDiff(original, partialDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe("const value = newValue + 123;")
}
})
// A cosmetic defect in the file header lines must not stop the hunk from applying.
it("should handle slightly malformed but recoverable diff", async () => {
const original = "line1\nline2\nline3"
// Missing space after --- and +++
const slightlyBadDiff = `---a/file.txt
+++b/file.txt
@@ ... @@
line1
-line2
+new line
line3`
const result = await strategy.applyDiff(original, slightlyBadDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe("line1\nnew line\nline3")
}
})
})
// Files that contain near-duplicate code: the context lines in the hunk are
// what must steer the edit to the intended occurrence.
describe("similar code sections", () => {
// "return a + b;" appears in both add() and multiply(); only the one inside
// multiply(), which carries the "// Bug here" comment, may change.
it("should correctly modify the right section when similar code exists", async () => {
const original = `function add(a, b) {
return a + b;
}
function subtract(a, b) {
return a - b;
}
function multiply(a, b) {
return a + b; // Bug here
}`
const diff = `--- a/math.js
+++ b/math.js
@@ ... @@
function multiply(a, b) {
- return a + b; // Bug here
+ return a * b;
}`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`function add(a, b) {
return a + b;
}
function subtract(a, b) {
return a - b;
}
function multiply(a, b) {
return a * b;
}`)
}
})
// Two blocks with identical bodies; the "if (otherCondition) {" context line
// selects the second block, and only its middle call is replaced.
it("should handle multiple similar sections with correct context", async () => {
const original = `if (condition) {
doSomething();
doSomething();
doSomething();
}
if (otherCondition) {
doSomething();
doSomething();
doSomething();
}`
const diff = `--- a/file.js
+++ b/file.js
@@ ... @@
if (otherCondition) {
doSomething();
- doSomething();
+ doSomethingElse();
doSomething();
}`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`if (condition) {
doSomething();
doSomething();
doSomething();
}
if (otherCondition) {
doSomething();
doSomethingElse();
doSomething();
}`)
}
})
})
// A single "@@ ... @@" hunk that carries many separate edits spread over a
// whole file. The suite title suggests the strategy breaks such a hunk into
// smaller pieces internally; this test only observes the final content.
describe("hunk splitting", () => {
it("should handle large diffs with multiple non-contiguous changes", async () => {
const original = `import { readFile } from 'fs';
import { join } from 'path';
import { Logger } from './logger';
const logger = new Logger();
async function processFile(filePath: string) {
try {
const data = await readFile(filePath, 'utf8');
logger.info('File read successfully');
return data;
} catch (error) {
logger.error('Failed to read file:', error);
throw error;
}
}
function validateInput(input: string): boolean {
if (!input) {
logger.warn('Empty input received');
return false;
}
return input.length > 0;
}
async function writeOutput(data: string) {
logger.info('Processing output');
// TODO: Implement output writing
return Promise.resolve();
}
function parseConfig(configPath: string) {
logger.debug('Reading config from:', configPath);
// Basic config parsing
return {
enabled: true,
maxRetries: 3
};
}
export {
processFile,
validateInput,
writeOutput,
parseConfig
};`
// Backticks and "${" inside the fixtures are escaped so they stay literal text
// instead of being interpolated by the enclosing template literal.
const diff = `--- a/file.ts
+++ b/file.ts
@@ ... @@
-import { readFile } from 'fs';
+import { readFile, writeFile } from 'fs';
import { join } from 'path';
-import { Logger } from './logger';
+import { Logger } from './utils/logger';
+import { Config } from './types';
-const logger = new Logger();
+const logger = new Logger('FileProcessor');
async function processFile(filePath: string) {
try {
const data = await readFile(filePath, 'utf8');
- logger.info('File read successfully');
+ logger.info(\`File \${filePath} read successfully\`);
return data;
} catch (error) {
- logger.error('Failed to read file:', error);
+ logger.error(\`Failed to read file \${filePath}:\`, error);
throw error;
}
}
function validateInput(input: string): boolean {
if (!input) {
- logger.warn('Empty input received');
+ logger.warn('Validation failed: Empty input received');
return false;
}
- return input.length > 0;
+ return input.trim().length > 0;
}
-async function writeOutput(data: string) {
- logger.info('Processing output');
- // TODO: Implement output writing
- return Promise.resolve();
+async function writeOutput(data: string, outputPath: string) {
+ try {
+ await writeFile(outputPath, data, 'utf8');
+ logger.info(\`Output written to \${outputPath}\`);
+ } catch (error) {
+ logger.error(\`Failed to write output to \${outputPath}:\`, error);
+ throw error;
+ }
}
-function parseConfig(configPath: string) {
- logger.debug('Reading config from:', configPath);
- // Basic config parsing
- return {
- enabled: true,
- maxRetries: 3
- };
+async function parseConfig(configPath: string): Promise<Config> {
+ try {
+ const configData = await readFile(configPath, 'utf8');
+ logger.debug(\`Reading config from \${configPath}\`);
+ return JSON.parse(configData);
+ } catch (error) {
+ logger.error(\`Failed to parse config from \${configPath}:\`, error);
+ throw error;
+ }
}
export {
processFile,
validateInput,
writeOutput,
- parseConfig
+ parseConfig,
+ type Config
};`
const expected = `import { readFile, writeFile } from 'fs';
import { join } from 'path';
import { Logger } from './utils/logger';
import { Config } from './types';
const logger = new Logger('FileProcessor');
async function processFile(filePath: string) {
try {
const data = await readFile(filePath, 'utf8');
logger.info(\`File \${filePath} read successfully\`);
return data;
} catch (error) {
logger.error(\`Failed to read file \${filePath}:\`, error);
throw error;
}
}
function validateInput(input: string): boolean {
if (!input) {
logger.warn('Validation failed: Empty input received');
return false;
}
return input.trim().length > 0;
}
async function writeOutput(data: string, outputPath: string) {
try {
await writeFile(outputPath, data, 'utf8');
logger.info(\`Output written to \${outputPath}\`);
} catch (error) {
logger.error(\`Failed to write output to \${outputPath}:\`, error);
throw error;
}
}
async function parseConfig(configPath: string): Promise<Config> {
try {
const configData = await readFile(configPath, 'utf8');
logger.debug(\`Reading config from \${configPath}\`);
return JSON.parse(configData);
} catch (error) {
logger.error(\`Failed to parse config from \${configPath}:\`, error);
throw error;
}
}
export {
processFile,
validateInput,
writeOutput,
parseConfig,
type Config
};`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,228 @@
import { UnifiedDiffStrategy } from "../unified"
describe("UnifiedDiffStrategy", () => {
let strategy: UnifiedDiffStrategy
beforeEach(() => {
strategy = new UnifiedDiffStrategy()
})
describe("getToolDescription", () => {
	it("should return tool description with correct cwd", () => {
		const workingDir = "/test/path"
		const text = strategy.getToolDescription({ cwd: workingDir })

		// Each fragment must appear somewhere in the generated description;
		// they are checked in this order and the first miss fails the test.
		const requiredFragments = ["apply_diff", workingDir, "Parameters:", "Format Requirements:"]
		for (const fragment of requiredFragments) {
			expect(text).toContain(fragment)
		}
	})
})
describe("applyDiff", () => {
// Rewrites the body of one function in place. Unlike the "@@ ... @@"
// placeholders used for the new-unified strategy, this fixture uses a numeric
// hunk header ("@@ -1,9 +1,10 @@").
it("should successfully apply a function modification diff", async () => {
const originalContent = `import { Logger } from '../logger';
function calculateTotal(items: number[]): number {
return items.reduce((sum, item) => {
return sum + item;
}, 0);
}
export { calculateTotal };`
const diffContent = `--- src/utils/helper.ts
+++ src/utils/helper.ts
@@ -1,9 +1,10 @@
import { Logger } from '../logger';
function calculateTotal(items: number[]): number {
- return items.reduce((sum, item) => {
- return sum + item;
+ const total = items.reduce((sum, item) => {
+ return sum + item * 1.1; // Add 10% markup
}, 0);
+ return Math.round(total * 100) / 100; // Round to 2 decimal places
}
export { calculateTotal };`
const expected = `import { Logger } from '../logger';
function calculateTotal(items: number[]): number {
const total = items.reduce((sum, item) => {
return sum + item * 1.1; // Add 10% markup
}, 0);
return Math.round(total * 100) / 100; // Round to 2 decimal places
}
export { calculateTotal };`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
// Pure addition: the hunk only contains context and "+" lines, inserting a
// second method before the closing brace of the class.
it("should successfully apply a diff adding a new method", async () => {
const originalContent = `class Calculator {
add(a: number, b: number): number {
return a + b;
}
}`
const diffContent = `--- src/Calculator.ts
+++ src/Calculator.ts
@@ -1,5 +1,9 @@
class Calculator {
add(a: number, b: number): number {
return a + b;
}
+
+ multiply(a: number, b: number): number {
+ return a * b;
+ }
}`
const expected = `class Calculator {
add(a: number, b: number): number {
return a + b;
}
multiply(a: number, b: number): number {
return a * b;
}
}`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
// Two separate edits inside one hunk: an import line is replaced and a
// useEffect call is inserted further down. Backticks and "${" in the added
// line are escaped so they stay literal text inside the fixture.
it("should successfully apply a diff modifying imports", async () => {
const originalContent = `import { useState } from 'react';
import { Button } from './components';
function App() {
const [count, setCount] = useState(0);
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
}`
const diffContent = `--- src/App.tsx
+++ src/App.tsx
@@ -1,7 +1,8 @@
-import { useState } from 'react';
+import { useState, useEffect } from 'react';
import { Button } from './components';
function App() {
const [count, setCount] = useState(0);
+ useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
}`
const expected = `import { useState, useEffect } from 'react';
import { Button } from './components';
function App() {
const [count, setCount] = useState(0);
useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
}`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
// Converts a callback-style function to async/await, with removals and
// additions interleaved around a single unchanged line
// ("const processed = data.toUpperCase();").
// NOTE(review): the title says "multiple hunks" but the fixture contains one
// "@@" header — confirm whether a second header was intended.
it("should successfully apply a diff with multiple hunks", async () => {
const originalContent = `import { readFile, writeFile } from 'fs';
function processFile(path: string) {
readFile(path, 'utf8', (err, data) => {
if (err) throw err;
const processed = data.toUpperCase();
writeFile(path, processed, (err) => {
if (err) throw err;
});
});
}
export { processFile };`
const diffContent = `--- src/file-processor.ts
+++ src/file-processor.ts
@@ -1,12 +1,14 @@
-import { readFile, writeFile } from 'fs';
+import { promises as fs } from 'fs';
+import { join } from 'path';
-function processFile(path: string) {
- readFile(path, 'utf8', (err, data) => {
- if (err) throw err;
+async function processFile(path: string) {
+ try {
+ const data = await fs.readFile(join(__dirname, path), 'utf8');
const processed = data.toUpperCase();
- writeFile(path, processed, (err) => {
- if (err) throw err;
- });
- });
+ await fs.writeFile(join(__dirname, path), processed);
+ } catch (error) {
+ console.error('Failed to process file:', error);
+ throw error;
+ }
}
export { processFile };`
const expected = `import { promises as fs } from 'fs';
import { join } from 'path';
async function processFile(path: string) {
try {
const data = await fs.readFile(join(__dirname, path), 'utf8');
const processed = data.toUpperCase();
await fs.writeFile(join(__dirname, path), processed);
} catch (error) {
console.error('Failed to process file:', error);
throw error;
}
}
export { processFile };`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
// Creating content from nothing: the original is the empty string and the
// hunk ("@@ -0,0 +1,3 @@") consists solely of added lines.
// The expected result ends with "\n" even though the diff text itself has no
// trailing newline.
it("should handle empty original content", async () => {
const originalContent = ""
const diffContent = `--- empty.ts
+++ empty.ts
@@ -0,0 +1,3 @@
+export function greet(name: string): string {
+ return \`Hello, \${name}!\`;
+}`
const expected = `export function greet(name: string): string {
return \`Hello, \${name}!\`;
}\n`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
// content is only read once success has been confirmed
if (result.success) {
expect(result.content).toBe(expected)
}
})
})
})

View File

@@ -0,0 +1,295 @@
import { applyContextMatching, applyDMP, applyGitFallback } from "../edit-strategies"
import { Hunk } from "../types"
// Shared table of scenarios run against both applyContextMatching and applyDMP.
//
// Fields of each entry:
//   name           - test title passed to it()
//   hunk           - the changes to apply; each change has a type
//                    ("context" | "add" | "remove") and a content string that
//                    may itself contain "\n" to represent several lines
//   content        - the file being edited, one array element per line
//   matchPosition  - index into content handed to the strategy (-1 = no match)
//   expected       - confidence is used by the runners as a lower bound;
//                    result is not read by the runners in this file
//   expectedResult - the strategy's result array joined with "\n" must equal this
//   strategies     - which runners pick the case up ("context" and/or "dmp")
const testCases = [
{
name: "should return original content if no match is found",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "add", content: "line2" },
],
} as Hunk,
content: ["line1", "line3"],
// -1 signals that no match position was found
matchPosition: -1,
expected: {
confidence: 0,
result: ["line1", "line3"],
},
expectedResult: "line1\nline3",
strategies: ["context", "dmp"],
},
{
name: "should apply a simple add change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "add", content: "line2" },
],
} as Hunk,
content: ["line1", "line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2", "line3"],
},
expectedResult: "line1\nline2\nline3",
strategies: ["context", "dmp"],
},
{
name: "should apply a simple remove change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "remove", content: "line2" },
],
} as Hunk,
content: ["line1", "line2", "line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line3"],
},
expectedResult: "line1\nline3",
strategies: ["context", "dmp"],
},
{
// Context-only hunk: nothing is added or removed, so the content is unchanged
name: "should apply a simple context change",
hunk: {
changes: [{ type: "context", content: "line1" }],
} as Hunk,
content: ["line1", "line2", "line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2", "line3"],
},
expectedResult: "line1\nline2\nline3",
strategies: ["context", "dmp"],
},
{
// From here on a single change may span several lines via an embedded "\n"
name: "should apply a multi-line add change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "add", content: "line2\nline3" },
],
} as Hunk,
content: ["line1", "line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2\nline3", "line4"],
},
expectedResult: "line1\nline2\nline3\nline4",
strategies: ["context", "dmp"],
},
{
name: "should apply a multi-line remove change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "remove", content: "line2\nline3" },
],
} as Hunk,
content: ["line1", "line2", "line3", "line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line4"],
},
expectedResult: "line1\nline4",
strategies: ["context", "dmp"],
},
{
name: "should apply a multi-line context change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "context", content: "line2\nline3" },
],
} as Hunk,
content: ["line1", "line2", "line3", "line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2\nline3", "line4"],
},
expectedResult: "line1\nline2\nline3\nline4",
strategies: ["context", "dmp"],
},
{
// Leading whitespace in both the hunk and the content must be preserved
name: "should apply a change with indentation",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "add", content: " line2" },
],
} as Hunk,
content: [" line1", " line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", " line2", " line3"],
},
expectedResult: " line1\n line2\n line3",
strategies: ["context", "dmp"],
},
{
// Tab-indented context next to a space-indented addition
name: "should apply a change with mixed indentation",
hunk: {
changes: [
{ type: "context", content: "\tline1" },
{ type: "add", content: " line2" },
],
} as Hunk,
content: ["\tline1", " line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["\tline1", " line2", " line3"],
},
expectedResult: "\tline1\n line2\n line3",
strategies: ["context", "dmp"],
},
{
name: "should apply a change with mixed indentation and multi-line",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "add", content: "\tline2\n line3" },
],
} as Hunk,
content: [" line1", " line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", "\tline2\n line3", " line4"],
},
expectedResult: " line1\n\tline2\n line3\n line4",
strategies: ["context", "dmp"],
},
{
// Replace: a remove immediately followed by an add, bracketed by context
name: "should apply a complex change with mixed indentation and multi-line",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "remove", content: " line2" },
{ type: "add", content: "\tline3\n line4" },
{ type: "context", content: " line5" },
],
} as Hunk,
content: [" line1", " line2", " line5", " line6"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", "\tline3\n line4", " line5", " line6"],
},
expectedResult: " line1\n\tline3\n line4\n line5\n line6",
strategies: ["context", "dmp"],
},
{
// Same replace, with two trailing context lines instead of one
name: "should apply a complex change with mixed indentation and multi-line and context",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "remove", content: " line2" },
{ type: "add", content: "\tline3\n line4" },
{ type: "context", content: " line5" },
{ type: "context", content: " line6" },
],
} as Hunk,
content: [" line1", " line2", " line5", " line6", " line7"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", "\tline3\n line4", " line5", " line6", " line7"],
},
expectedResult: " line1\n\tline3\n line4\n line5\n line6\n line7",
strategies: ["context", "dmp"],
},
{
// Same hunk again, but the content has an extra leading line so the match
// starts at index 1 rather than 0
name: "should apply a complex change with mixed indentation and multi-line and context and a different match position",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "remove", content: " line2" },
{ type: "add", content: "\tline3\n line4" },
{ type: "context", content: " line5" },
{ type: "context", content: " line6" },
],
} as Hunk,
content: [" line0", " line1", " line2", " line5", " line6", " line7"],
matchPosition: 1,
expected: {
confidence: 1,
result: [" line0", " line1", "\tline3\n line4", " line5", " line6", " line7"],
},
expectedResult: " line0\n line1\n\tline3\n line4\n line5\n line6\n line7",
strategies: ["context", "dmp"],
},
]
// Registers one test per table entry that lists the "context" strategy.
describe("applyContextMatching", () => {
	const applicableCases = testCases.filter((testCase) => testCase.strategies?.includes("context"))

	for (const { name, hunk, content, matchPosition, expected, expectedResult } of applicableCases) {
		it(name, () => {
			const outcome = applyContextMatching(hunk, content, matchPosition)

			// Lines are compared as one joined string; confidence is a lower bound.
			expect(outcome.result.join("\n")).toEqual(expectedResult)
			expect(outcome.confidence).toBeGreaterThanOrEqual(expected.confidence)
			expect(outcome.strategy).toBe("context")
		})
	}
})
// Registers one test per table entry that lists the "dmp" strategy.
describe("applyDMP", () => {
	const applicableCases = testCases.filter((testCase) => testCase.strategies?.includes("dmp"))

	for (const { name, hunk, content, matchPosition, expected, expectedResult } of applicableCases) {
		it(name, () => {
			const outcome = applyDMP(hunk, content, matchPosition)

			// Lines are compared as one joined string; confidence is a lower bound.
			expect(outcome.result.join("\n")).toEqual(expectedResult)
			expect(outcome.confidence).toBeGreaterThanOrEqual(expected.confidence)
			expect(outcome.strategy).toBe("dmp")
		})
	}
})
// applyGitFallback is asynchronous and takes no match position, so it is
// tested on its own rather than through the shared table.
describe("applyGitFallback", () => {
	it("should successfully apply changes using git operations", async () => {
		const fileLines = ["line1", "line2", "line3"]
		// Replace the middle line, with one context line on either side
		const replaceMiddleLine = {
			changes: [
				{ type: "context", content: "line1", indent: "" },
				{ type: "remove", content: "line2", indent: "" },
				{ type: "add", content: "new line2", indent: "" },
				{ type: "context", content: "line3", indent: "" },
			],
		} as Hunk

		const outcome = await applyGitFallback(replaceMiddleLine, fileLines)

		expect(outcome.result.join("\n")).toEqual("line1\nnew line2\nline3")
		expect(outcome.confidence).toBe(1)
		expect(outcome.strategy).toBe("git-fallback")
	})

	it("should return original content with 0 confidence when changes cannot be applied", async () => {
		const fileLines = ["line1", "line2", "line3"]
		// The context line does not occur anywhere in the file
		const unmatchableHunk = {
			changes: [
				{ type: "context", content: "nonexistent", indent: "" },
				{ type: "add", content: "new line", indent: "" },
			],
		} as Hunk

		const outcome = await applyGitFallback(unmatchableHunk, fileLines)

		expect(outcome.result).toEqual(fileLines)
		expect(outcome.confidence).toBe(0)
		expect(outcome.strategy).toBe("git-fallback")
	})
})

View File

@@ -0,0 +1,262 @@
import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMatch } from "../search-strategies"
/**
 * Call signature for a search function exercised by this test file.
 *
 * Takes the text to look for (which may contain "\n" for a multi-line search),
 * the file as an array of lines, and an optional index to start searching from.
 * The result reports where the match starts, how confident the match is, and
 * the name of the strategy that produced it. In the fixtures below, "no match"
 * is written as index -1 with confidence 0.
 *
 * NOTE(review): presumably shared by the search functions imported above so
 * they can be driven from one table — the usage is outside this view.
 */
type SearchStrategy = (
searchStr: string,
content: string[],
startIndex?: number,
) => {
index: number
confidence: number
strategy: string
}
/**
 * Shared fixtures for the exact / similarity / levenshtein search tests below.
 *
 * Each case provides:
 * - `name`: the test title
 * - `searchStr`: the text to look for (multi-line strings are joined with "\n")
 * - `content`: the file content as an array of lines
 * - `startIndex` (optional): line index passed to the search function
 * - `expected`: the exact `index` to find (-1 for "no match") and the minimum `confidence`
 * - `strategies`: which search functions the case applies to; each describe block
 *   skips cases that do not list its strategy
 */
const testCases = [
// --- Basic single-line lookups ---
{
name: "should return no match if the search string is not found",
searchStr: "not found",
content: ["line1", "line2", "line3"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match if the search string is found",
searchStr: "line2",
content: ["line1", "line2", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with correct index when startIndex is provided",
searchStr: "line3",
content: ["line1", "line2", "line3", "line4", "line3"],
startIndex: 3,
expected: { index: 4, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match even if there are more lines in content",
searchStr: "line2",
content: ["line1", "line2", "line3", "line4", "line5"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match even if the search string is at the beginning of the content",
searchStr: "line1",
content: ["line1", "line2", "line3"],
expected: { index: 0, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match even if the search string is at the end of the content",
searchStr: "line3",
content: ["line1", "line2", "line3"],
expected: { index: 2, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
// --- Multi-line lookups ---
{
name: "should return a match for a multi-line search string",
searchStr: "line2\nline3",
content: ["line1", "line2", "line3", "line4"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
// Not run against "levenshtein" (see `strategies`).
name: "should return no match if a multi-line search string is not found",
searchStr: "line2\nline4",
content: ["line1", "line2", "line3", "line4"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity"],
},
// --- Indentation (spaces, tabs, mixed) ---
{
name: "should return a match with indentation",
searchStr: " line2",
content: ["line1", " line2", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with more complex indentation",
searchStr: " line3",
content: [" line1", " line2", " line3", " line4"],
expected: { index: 2, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed indentation",
searchStr: "\tline2",
content: [" line1", "\tline2", " line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed indentation and multi-line",
searchStr: " line2\n\tline3",
content: ["line1", " line2", "\tline3", " line4"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
// Not run against "levenshtein" (see `strategies`).
name: "should return no match if mixed indentation and multi-line is not found",
searchStr: " line2\n line4",
content: ["line1", " line2", "\tline3", " line4"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity"],
},
// --- Leading and trailing whitespace ---
{
name: "should return a match with leading and trailing spaces",
searchStr: " line2 ",
content: ["line1", " line2 ", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with leading and trailing tabs",
searchStr: "\tline2\t",
content: ["line1", "\tline2\t", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed leading and trailing spaces and tabs",
searchStr: " \tline2\t ",
content: ["line1", " \tline2\t ", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed leading and trailing spaces and tabs and multi-line",
searchStr: " \tline2\t \n line3 ",
content: ["line1", " \tline2\t ", " line3 ", "line4"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
// Not run against "levenshtein" (see `strategies`).
name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found",
searchStr: " \tline2\t \n line4 ",
content: ["line1", " \tline2\t ", " line3 ", "line4"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity"],
},
]
// findExactMatch: runs every shared fixture tagged "exact". The strategy label may be
// either "exact" or "exact-overlapping".
describe("findExactMatch", () => {
	const exactCases = testCases.filter(({ strategies }) => strategies?.includes("exact"))
	for (const { name, searchStr, content, startIndex, expected } of exactCases) {
		it(name, () => {
			const found = findExactMatch(searchStr, content, startIndex)
			expect(found.index).toBe(expected.index)
			expect(found.confidence).toBeGreaterThanOrEqual(expected.confidence)
			expect(found.strategy).toMatch(/exact(-overlapping)?/)
		})
	}
})
// findAnchorMatch: uses its own fixtures because anchor matching needs a line of the
// search text to locate, followed by context lines that must line up around it.
describe("findAnchorMatch", () => {
	// Search text reused by most cases: one anchor line followed by two context lines.
	const ANCHORED_SEARCH = "unique line\ncontext line 1\ncontext line 2"

	const anchorTestCases = [
		{
			name: "should return no match if no anchors are found",
			searchStr: " \n \n ",
			content: ["line1", "line2", "line3"],
			expected: { index: -1, confidence: 0 },
		},
		{
			name: "should return no match if anchor positions cannot be validated",
			searchStr: ANCHORED_SEARCH,
			content: [
				"different line 1",
				"different line 2",
				"different line 3",
				"another unique line",
				"context line 1",
				"context line 2",
			],
			expected: { index: -1, confidence: 0 },
		},
		{
			name: "should return a match if anchor positions can be validated",
			searchStr: ANCHORED_SEARCH,
			content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"],
			expected: { index: 2, confidence: 1 },
		},
		{
			name: "should return a match with correct index when startIndex is provided",
			searchStr: ANCHORED_SEARCH,
			content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"],
			startIndex: 3,
			expected: { index: 3, confidence: 1 },
		},
		{
			name: "should return a match even if there are more lines in content",
			searchStr: ANCHORED_SEARCH,
			content: [
				"line1",
				"line2",
				"unique line",
				"context line 1",
				"context line 2",
				"line 6",
				"extra line 1",
				"extra line 2",
			],
			expected: { index: 2, confidence: 1 },
		},
		{
			name: "should return a match even if the anchor is at the beginning of the content",
			searchStr: ANCHORED_SEARCH,
			content: ["unique line", "context line 1", "context line 2", "line 6"],
			expected: { index: 0, confidence: 1 },
		},
		{
			name: "should return a match even if the anchor is at the end of the content",
			searchStr: ANCHORED_SEARCH,
			content: ["line1", "line2", "unique line", "context line 1", "context line 2"],
			expected: { index: 2, confidence: 1 },
		},
		{
			name: "should return no match if no valid anchor is found",
			searchStr: "non-unique line\ncontext line 1\ncontext line 2",
			content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"],
			expected: { index: -1, confidence: 0 },
		},
	]

	for (const { name, searchStr, content, startIndex, expected } of anchorTestCases) {
		it(name, () => {
			const found = findAnchorMatch(searchStr, content, startIndex)
			expect(found.index).toBe(expected.index)
			expect(found.confidence).toBeGreaterThanOrEqual(expected.confidence)
			expect(found.strategy).toBe("anchor")
		})
	}
})
// findSimilarityMatch: runs every shared fixture tagged "similarity".
describe("findSimilarityMatch", () => {
	const similarityCases = testCases.filter(({ strategies }) => strategies?.includes("similarity"))
	for (const { name, searchStr, content, startIndex, expected } of similarityCases) {
		it(name, () => {
			const found = findSimilarityMatch(searchStr, content, startIndex)
			expect(found.index).toBe(expected.index)
			expect(found.confidence).toBeGreaterThanOrEqual(expected.confidence)
			expect(found.strategy).toBe("similarity")
		})
	}
})
// findLevenshteinMatch: runs every shared fixture tagged "levenshtein".
describe("findLevenshteinMatch", () => {
	const levenshteinCases = testCases.filter(({ strategies }) => strategies?.includes("levenshtein"))
	for (const { name, searchStr, content, startIndex, expected } of levenshteinCases) {
		it(name, () => {
			const found = findLevenshteinMatch(searchStr, content, startIndex)
			expect(found.index).toBe(expected.index)
			expect(found.confidence).toBeGreaterThanOrEqual(expected.confidence)
			expect(found.strategy).toBe("levenshtein")
		})
	}
})

View File

@@ -0,0 +1,297 @@
import { diff_match_patch } from "diff-match-patch"
import { EditResult, Hunk } from "./types"
import { getDMPSimilarity, validateEditResult } from "./search-strategies"
import * as path from "path"
import simpleGit, { SimpleGit } from "simple-git"
import * as tmp from "tmp"
import * as fs from "fs"
/**
 * Picks the indentation to use for a line (simplified heuristic).
 *
 * Precedence:
 * 1. the line's own leading whitespace, if it has any;
 * 2. the leading whitespace of the first context line, if it has any;
 * 3. `previousIndent`.
 *
 * @param line The line whose indentation is wanted.
 * @param contextLines Surrounding context; only the first entry is consulted.
 * @param previousIndent Fallback indentation (defaults to the empty string).
 * @returns The chosen indentation string.
 */
function inferIndentation(line: string, contextLines: string[], previousIndent: string = ""): string {
	const leadingWhitespace = /^(\s+)/

	const ownIndent = leadingWhitespace.exec(line)
	if (ownIndent) {
		return ownIndent[1]
	}

	// An absent or empty first context line contributes nothing.
	const [firstContextLine] = contextLines
	const contextIndent = firstContextLine ? leadingWhitespace.exec(firstContextLine) : null

	return contextIndent ? contextIndent[1] : previousIndent
}
/**
 * Context matching edit strategy: rebuilds the file by replaying the hunk's changes
 * line by line, starting at `matchPosition`.
 *
 * - context: the file's own line at the current position is kept (the hunk's text is
 *   only used when the hunk runs past the end of the file);
 * - add: the added text is inserted; a line that already starts with whitespace is kept
 *   as is, otherwise the change's indent is put in front of it;
 * - remove: the corresponding file lines are skipped.
 *
 * @param hunk The hunk to apply.
 * @param content The file content as an array of lines (not modified).
 * @param matchPosition Line index where the hunk starts, or -1 when no location was found.
 * @returns The edited lines and a confidence from `validateEditResult`; for
 *          `matchPosition === -1`, the untouched content with confidence 0.
 */
export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult {
	if (matchPosition === -1) {
		return { confidence: 0, result: content, strategy: "context" }
	}

	const rebuilt = content.slice(0, matchPosition)
	let cursor = matchPosition // next unread line of the original file

	for (const change of hunk.changes) {
		switch (change.type) {
			case "context": {
				if (cursor < content.length) {
					// Prefer the file's own text over the hunk's rendition of it
					rebuilt.push(content[cursor])
				} else {
					rebuilt.push(change.indent ? change.indent + change.content : change.content)
				}
				cursor++
				break
			}
			case "add": {
				const baseIndent = change.indent || ""
				// An addition may carry several lines in one change
				for (const addedLine of change.content.split("\n")) {
					const parts = addedLine.match(/^(\s*)(.*)/)
					if (!parts) {
						rebuilt.push(baseIndent + addedLine)
						continue
					}
					const [, ownIndent, text] = parts
					// Lines bringing their own indentation are kept verbatim
					rebuilt.push(ownIndent ? addedLine : baseIndent + text)
				}
				break
			}
			case "remove": {
				// A multi-line removal consumes one file line per removed line
				cursor += change.content.split("\n").length
				break
			}
		}
	}

	// Everything after the hunk is carried over unchanged
	rebuilt.push(...content.slice(cursor))

	// Confidence is computed on the edited region only (between the untouched head and tail)
	const editedEnd = rebuilt.length - (content.length - cursor)
	const afterText = rebuilt.slice(matchPosition, editedEnd).join("\n")

	return {
		confidence: validateEditResult(hunk, afterText),
		result: rebuilt,
		strategy: "context",
	}
}
/**
 * DMP edit strategy: builds a diff-match-patch patch from the hunk's "before" text
 * (context + removed lines) to its "after" text (context + added lines) and applies it
 * to the slice of the file starting at `matchPosition`.
 *
 * @param hunk The hunk to apply.
 * @param content The file content as an array of lines (not modified).
 * @param matchPosition Line index where the hunk starts, or -1 when no location was found.
 * @returns The patched lines and a confidence from `validateEditResult`; for
 *          `matchPosition === -1`, the untouched content with confidence 0.
 */
export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult {
	if (matchPosition === -1) {
		return { confidence: 0, result: content, strategy: "dmp" }
	}

	// Text of one change: the recorded original line when present, else indent + content
	const renderLine = (change: Hunk["changes"][number]): string =>
		change.originalLine || (change.indent ? change.indent + change.content : change.content)

	const beforeChanges = hunk.changes.filter((change) => change.type === "context" || change.type === "remove")
	const afterChanges = hunk.changes.filter((change) => change.type === "context" || change.type === "add")

	// Number of file lines the "before" block covers (a change may hold several lines)
	let beforeLineCount = 0
	for (const change of beforeChanges) {
		beforeLineCount += change.content.split("\n").length
	}

	const beforeText = beforeChanges.map(renderLine).join("\n")
	const afterText = afterChanges.map(renderLine).join("\n")

	// Patch before -> after, then apply it to the matching region of the file
	const dmp = new diff_match_patch()
	const patch = dmp.patch_make(beforeText, afterText)
	const regionEnd = matchPosition + beforeLineCount
	const targetText = content.slice(matchPosition, regionEnd).join("\n")
	const [patchedText] = dmp.patch_apply(patch, targetText)

	// Stitch the patched region back between the untouched head and tail
	const newResult = content.slice(0, matchPosition).concat(patchedText.split("\n"), content.slice(regionEnd))

	return {
		confidence: validateEditResult(hunk, patchedText),
		result: newResult,
		strategy: "dmp",
	}
}
/**
 * Git fallback strategy that works with the full file content.
 *
 * Creates a throwaway git repository in a temporary directory and tries to let git
 * merge the hunk into the file via `cherry-pick`. Three whole-file snapshots of
 * "file.txt" are involved:
 * - the original content,
 * - the hunk's "search" text (context + removed lines),
 * - the hunk's "replace" text (context + added lines).
 *
 * Two commit orderings are tried in sequence; the first one whose cherry-pick succeeds
 * wins and is returned with confidence 1.
 *
 * @param hunk The hunk to apply.
 * @param content The file content as an array of lines.
 * @returns The merged lines with confidence 1, or the untouched `content` with
 *          confidence 0 when both attempts fail or any unexpected error occurs.
 */
export async function applyGitFallback(hunk: Hunk, content: string[]): Promise<EditResult> {
let tmpDir: tmp.DirResult | undefined
try {
// unsafeCleanup lets removeCallback() delete the directory even though it is not empty
tmpDir = tmp.dirSync({ unsafeCleanup: true })
const git: SimpleGit = simpleGit(tmpDir.name)
await git.init()
// Identity is configured explicitly so the commits below do not depend on the user's git setup
await git.addConfig("user.name", "Temp")
await git.addConfig("user.email", "temp@example.com")
const filePath = path.join(tmpDir.name, "file.txt")
// "Before" side of the hunk: context + removed lines
const searchLines = hunk.changes
.filter((change) => change.type === "context" || change.type === "remove")
.map((change) => change.originalLine || change.indent + change.content)
// "After" side of the hunk: context + added lines
const replaceLines = hunk.changes
.filter((change) => change.type === "context" || change.type === "add")
.map((change) => change.originalLine || change.indent + change.content)
const searchText = searchLines.join("\n")
const replaceText = replaceLines.join("\n")
const originalText = content.join("\n")
// Strategy 1: commit original -> search -> replace, go back to the original commit
// and cherry-pick the replace commit onto it.
try {
fs.writeFileSync(filePath, originalText)
await git.add("file.txt")
const originalCommit = await git.commit("original")
console.log("Strategy 1 - Original commit:", originalCommit.commit)
fs.writeFileSync(filePath, searchText)
await git.add("file.txt")
const searchCommit1 = await git.commit("search")
console.log("Strategy 1 - Search commit:", searchCommit1.commit)
fs.writeFileSync(filePath, replaceText)
await git.add("file.txt")
const replaceCommit = await git.commit("replace")
console.log("Strategy 1 - Replace commit:", replaceCommit.commit)
console.log("Strategy 1 - Attempting checkout of:", originalCommit.commit)
await git.raw(["checkout", originalCommit.commit])
try {
console.log("Strategy 1 - Attempting cherry-pick of:", replaceCommit.commit)
await git.raw(["cherry-pick", "--minimal", replaceCommit.commit])
// Cherry-pick succeeded: the working-tree file now holds the merged result
const newText = fs.readFileSync(filePath, "utf-8")
const newLines = newText.split("\n")
return {
confidence: 1,
result: newLines,
strategy: "git-fallback",
}
} catch (cherryPickError) {
// Any cherry-pick failure is reported as a merge conflict; the error itself is not inspected
console.error("Strategy 1 failed with merge conflict")
}
} catch (error) {
console.error("Strategy 1 failed:", error)
}
// Strategy 2: commit search -> replace, check out the search commit, commit the
// original content on top of it, then cherry-pick the replace commit.
// NOTE(review): this runs in the same directory/repository that strategy 1 left behind
// (possibly mid cherry-pick) — confirm that re-running init here is sufficient.
try {
await git.init()
await git.addConfig("user.name", "Temp")
await git.addConfig("user.email", "temp@example.com")
fs.writeFileSync(filePath, searchText)
await git.add("file.txt")
const searchCommit = await git.commit("search")
// Strips a leading "HEAD " from the reported commit id — presumably what simple-git
// returns when committing on a detached HEAD; TODO confirm
const searchHash = searchCommit.commit.replace(/^HEAD /, "")
console.log("Strategy 2 - Search commit:", searchHash)
fs.writeFileSync(filePath, replaceText)
await git.add("file.txt")
const replaceCommit = await git.commit("replace")
const replaceHash = replaceCommit.commit.replace(/^HEAD /, "")
console.log("Strategy 2 - Replace commit:", replaceHash)
console.log("Strategy 2 - Attempting checkout of:", searchHash)
await git.raw(["checkout", searchHash])
fs.writeFileSync(filePath, originalText)
await git.add("file.txt")
const originalCommit2 = await git.commit("original")
console.log("Strategy 2 - Original commit:", originalCommit2.commit)
try {
console.log("Strategy 2 - Attempting cherry-pick of:", replaceHash)
await git.raw(["cherry-pick", "--minimal", replaceHash])
const newText = fs.readFileSync(filePath, "utf-8")
const newLines = newText.split("\n")
return {
confidence: 1,
result: newLines,
strategy: "git-fallback",
}
} catch (cherryPickError) {
console.error("Strategy 2 failed with merge conflict")
}
} catch (error) {
console.error("Strategy 2 failed:", error)
}
// Both orderings failed: hand back the content unchanged
console.error("Git fallback failed")
return { confidence: 0, result: content, strategy: "git-fallback" }
} catch (error) {
// Failures during setup (temp dir, git init/config) land here
console.error("Git fallback strategy failed:", error)
return { confidence: 0, result: content, strategy: "git-fallback" }
} finally {
// Always remove the temporary repository, on success and on failure
if (tmpDir) {
tmpDir.removeCallback()
}
}
}
/**
 * Main edit function: applies a hunk by trying the edit strategies one after another.
 *
 * When the search confidence is already below the threshold, the position-based
 * strategies are skipped and the git fallback is used directly. Otherwise DMP, context
 * matching and the git fallback are tried in that order, and the first result whose
 * confidence reaches the threshold is returned.
 *
 * @param hunk The hunk to apply.
 * @param content The file content as an array of lines.
 * @param matchPosition Line index where the hunk was located.
 * @param confidence Confidence reported by the search that produced `matchPosition`.
 * @param confidenceThreshold Minimum confidence to accept (default 0.97).
 * @returns The first acceptable edit result, or the untouched content with confidence 0
 *          and strategy "none" when every strategy falls short.
 */
export async function applyEdit(
	hunk: Hunk,
	content: string[],
	matchPosition: number,
	confidence: number,
	confidenceThreshold: number = 0.97,
): Promise<EditResult> {
	// A weak search result makes matchPosition unreliable, so only git gets a chance
	if (confidence < confidenceThreshold) {
		console.log(
			`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`,
		)
		return applyGitFallback(hunk, content)
	}

	// Lazily evaluated, in order of preference: dmp, context, git-fallback
	const attempts: Array<() => EditResult | Promise<EditResult>> = [
		() => applyDMP(hunk, content, matchPosition),
		() => applyContextMatching(hunk, content, matchPosition),
		() => applyGitFallback(hunk, content),
	]

	for (const attempt of attempts) {
		const outcome = await attempt()
		if (outcome.confidence >= confidenceThreshold) {
			return outcome
		}
	}

	return { confidence: 0, result: content, strategy: "none" }
}

View File

@@ -0,0 +1,350 @@
import { Diff, Hunk, Change } from "./types"
import { findBestMatch, prepareSearchString } from "./search-strategies"
import { applyEdit } from "./edit-strategies"
import { DiffResult, DiffStrategy } from "../../types"
export class NewUnifiedDiffStrategy implements DiffStrategy {
private readonly confidenceThreshold: number
/**
 * @param confidenceThreshold Minimum confidence required to accept a match or an edit
 *        (default 1). Values below 0.8 are raised to 0.8.
 */
constructor(confidenceThreshold: number = 1) {
	const MIN_CONFIDENCE_THRESHOLD = 0.8
	this.confidenceThreshold = Math.max(MIN_CONFIDENCE_THRESHOLD, confidenceThreshold)
}
/**
 * Parses unified-diff text into hunks of context / add / remove changes.
 *
 * - Everything before the first line starting with "@@" (e.g. file headers) is skipped.
 * - Each "@@" line starts a new hunk; its own content (line numbers etc.) is ignored.
 * - Lines are classified by their first character: " " context, "+" add, "-" remove.
 *   The first character is dropped, the remaining leading whitespace becomes `indent`
 *   and the rest becomes `content`; `originalLine` keeps indent + content together.
 * - Hunks without any add/remove change are discarded.
 * - Every kept hunk is trimmed to at most MAX_CONTEXT_LINES context lines before its
 *   first and after its last change. This also applies to the last hunk of the diff,
 *   which previously was stored untrimmed.
 *
 * @param diff The diff text.
 * @returns The parsed hunks, in order of appearance.
 */
private parseUnifiedDiff(diff: string): Diff {
	const MAX_CONTEXT_LINES = 6 // Number of context lines to keep before/after changes
	const hunks: Hunk[] = []

	const isEdit = (change: Change): boolean => change.type === "add" || change.type === "remove"

	// Trims surplus leading/trailing context and stores the hunk.
	// Hunks that are absent or contain no add/remove change are dropped.
	const finishHunk = (hunk: Hunk | null): void => {
		if (!hunk) {
			return
		}
		const changes = hunk.changes
		const firstEdit = changes.findIndex(isEdit)
		if (firstEdit === -1) {
			return
		}
		let lastEdit = firstEdit
		for (let j = changes.length - 1; j > firstEdit; j--) {
			if (isEdit(changes[j])) {
				lastEdit = j
				break
			}
		}
		const startIdx = Math.max(0, firstEdit - MAX_CONTEXT_LINES)
		const endIdx = Math.min(changes.length - 1, lastEdit + MAX_CONTEXT_LINES)
		hunk.changes = changes.slice(startIdx, endIdx + 1)
		hunks.push(hunk)
	}

	const lines = diff.split("\n")
	let currentHunk: Hunk | null = null

	// Skip the preamble (file headers) up to the first hunk header
	let i = 0
	while (i < lines.length && !lines[i].startsWith("@@")) {
		i++
	}

	for (; i < lines.length; i++) {
		const line = lines[i]

		if (line.startsWith("@@")) {
			finishHunk(currentHunk)
			currentHunk = { changes: [] }
			continue
		}
		if (!currentHunk) {
			continue
		}

		// Drop the one-character marker, then separate indentation from text
		const content = line.slice(1)
		const indent = content.match(/^(\s*)/)?.[0] ?? ""
		const trimmedContent = content.slice(indent.length)

		let change: Change
		if (line.startsWith(" ")) {
			change = { type: "context", content: trimmedContent, indent, originalLine: content }
		} else if (line.startsWith("+")) {
			change = { type: "add", content: trimmedContent, indent, originalLine: content }
		} else if (line.startsWith("-")) {
			change = { type: "remove", content: trimmedContent, indent, originalLine: content }
		} else {
			// Lines without a recognised marker (including empty lines) are treated as context.
			// NOTE(review): the first character of such a line has already been dropped by
			// slice(1) above, and `content` gets a leading space that `originalLine` lacks —
			// confirm this is intended for unmarked, non-empty lines.
			change = {
				type: "context",
				content: trimmedContent ? " " + trimmedContent : " ",
				indent,
				originalLine: content,
			}
		}
		currentHunk.changes.push(change)
	}

	// The last hunk has no following "@@" to close it
	finishHunk(currentHunk)

	return { hunks }
}
/**
 * Returns the prompt text describing the apply_diff tool: how to write a unified diff
 * this strategy can apply, a good and a bad example, and the expected parameters.
 *
 * @param args.cwd Working directory; interpolated into the description of the `path` parameter.
 * @param args.toolOptions Accepted for interface compatibility; not read by this implementation.
 * @returns The tool description as a single multi-line string.
 */
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
return `# apply_diff Tool - Generate Precise Code Changes
Generate a unified diff that can be cleanly applied to modify code files.
## Step-by-Step Instructions:
1. Start with file headers:
- First line: "--- {original_file_path}"
- Second line: "+++ {new_file_path}"
2. For each change section:
- Begin with "@@ ... @@" separator line without line numbers
- Include 2-3 lines of context before and after changes
- Mark removed lines with "-"
- Mark added lines with "+"
- Preserve exact indentation
3. Group related changes:
- Keep related modifications in the same hunk
- Start new hunks for logically separate changes
- When modifying functions/methods, include the entire block
## Requirements:
1. MUST include exact indentation
2. MUST include sufficient context for unique matching
3. MUST group related changes together
4. MUST use proper unified diff format
5. MUST NOT include timestamps in file headers
6. MUST NOT include line numbers in the @@ header
## Examples:
✅ Good diff (follows all requirements):
\`\`\`diff
--- src/utils.ts
+++ src/utils.ts
@@ ... @@
def calculate_total(items):
- total = 0
- for item in items:
- total += item.price
+ return sum(item.price for item in items)
\`\`\`
❌ Bad diff (violates requirements #1 and #2):
\`\`\`diff
--- src/utils.ts
+++ src/utils.ts
@@ ... @@
-total = 0
-for item in items:
+return sum(item.price for item in items)
\`\`\`
Parameters:
- path: (required) File path relative to ${args.cwd}
- diff: (required) Unified diff content in unified format to apply to the file.
Usage:
<apply_diff>
<path>path/to/file.ext</path>
<diff>
Your diff here
</diff>
</apply_diff>`
}
/**
 * Splits a hunk into smaller hunks, one per group of changes.
 *
 * Context lines seen before a group are buffered (the oldest one is dropped whenever the
 * buffer grows beyond MAX_CONTEXT_LINES) and become the group's leading context. Context
 * lines seen after a change are collected as trailing context; as soon as more than
 * MAX_CONTEXT_LINES of them have accumulated, the group is closed and those same lines
 * seed the leading context of the next group. Shorter runs of context between two
 * changes keep both changes in the same group.
 *
 * @param hunk The hunk to split.
 * @returns The sub-hunks in order; each one contains at least one non-context change.
 */
private splitHunk(hunk: Hunk): Hunk[] {
	const MAX_CONTEXT_LINES = 3 // Keep 3 lines of context before/after changes
	const pieces: Hunk[] = []
	let open: Hunk | null = null // sub-hunk currently being assembled
	let leading: Change[] = []
	let trailing: Change[] = []

	for (const change of hunk.changes) {
		if (change.type !== "context") {
			if (open === null) {
				// First change of a new group: start it with the buffered leading context
				open = { changes: [...leading] }
			} else {
				// Context seen since the previous change belongs inside this group
				open.changes.push(...trailing)
			}
			trailing = []
			open.changes.push(change)
			continue
		}

		if (open === null) {
			// Between groups: slide the leading-context buffer forward
			leading.push(change)
			if (leading.length > MAX_CONTEXT_LINES) {
				leading.shift()
			}
			continue
		}

		trailing.push(change)
		if (trailing.length > MAX_CONTEXT_LINES) {
			// Enough context after the last change: close the group
			open.changes.push(...trailing)
			pieces.push(open)
			open = null
			// The same lines serve as leading context for whatever comes next
			leading = trailing
			trailing = []
		}
	}

	// Flush the group that was still open when the input ended
	if (open !== null) {
		open.changes.push(...trailing)
		pieces.push(open)
	}

	return pieces
}
/**
 * Applies a unified diff to `originalContent`, hunk by hunk.
 *
 * For each hunk the best matching location is searched in the current (already partly
 * edited) content. When the match reaches the confidence threshold the hunk is applied
 * with `applyEdit`. Otherwise the hunk is split into sub-hunks, which must all be located
 * and applied successfully; if that fails too, the whole operation is aborted with a
 * diagnostic message.
 *
 * @param originalContent The file content.
 * @param diffContent The unified diff to apply.
 * @param startLine Only echoed in the "search failed" message (together with `endLine`).
 * @param endLine Only echoed in the "search failed" message (together with `startLine`).
 * @returns `{ success: true, content }` with the edited text, or `{ success: false, error }`.
 */
async applyDiff(
	originalContent: string,
	diffContent: string,
	startLine?: number,
	endLine?: number,
): Promise<DiffResult> {
	const { hunks } = this.parseUnifiedDiff(diffContent)
	const threshold = this.confidenceThreshold
	let result = originalContent.split("\n")

	if (hunks.length === 0) {
		return {
			success: false,
			error: "No hunks found in diff. Please ensure your diff includes actual changes and follows the unified diff format.",
		}
	}

	const percent = (fraction: number): number => Math.floor(fraction * 100)

	// Applies the sub-hunks in order on a copy of `base`; null as soon as one of them
	// cannot be located or applied with sufficient confidence.
	const applySubHunks = async (subHunks: Hunk[], base: string[]): Promise<string[] | null> => {
		let working = [...base]
		for (const subHunk of subHunks) {
			const located = findBestMatch(prepareSearchString(subHunk.changes), working, 0, threshold)
			if (!(located.confidence >= threshold)) {
				return null
			}
			const edited = await applyEdit(subHunk, working, located.index, located.confidence, threshold)
			if (!(edited.confidence >= threshold)) {
				return null
			}
			working = edited.result
		}
		return working
	}

	// Diagnostic for a hunk whose location could not be found, even after splitting
	const describeSearchFailure = (
		hunk: Hunk,
		confidence: number,
		strategy: string,
		subHunkCount: number,
	): string => {
		const totalLines = hunk.changes.length
		const contextLines = hunk.changes.filter((c) => c.type === "context").length
		const contextRatio = contextLines / totalLines

		let hints: string[]
		if (contextRatio < 0.2) {
			hints = [
				"- Not enough context lines to uniquely identify the location\n",
				"- Add a few more lines of unchanged code around your changes\n",
			]
		} else if (contextRatio > 0.5) {
			hints = [
				"- Too many context lines may reduce search accuracy\n",
				"- Try to keep only 2-3 lines of context before and after changes\n",
			]
		} else {
			hints = [
				"- The diff may be targeting a different version of the file\n",
				"- There may be too many changes in a single hunk, try splitting the changes into multiple hunks\n",
			]
		}

		const parts = [
			`Failed to find a matching location in the file (${percent(confidence)}% confidence, needs ${percent(threshold)}%)\n\n`,
			"Debug Info:\n",
			`- Search Strategy Used: ${strategy}\n`,
			`- Context Lines: ${contextLines} out of ${totalLines} total lines (${percent(contextRatio)}%)\n`,
			`- Attempted to split into ${subHunkCount} sub-hunks but still failed\n`,
			"\nPossible Issues:\n",
			...hints,
		]
		if (startLine && endLine) {
			parts.push(`\nSearch Range: lines ${startLine}-${endLine}\n`)
		}
		return parts.join("")
	}

	// Diagnostic for a hunk that was located but could not be applied
	const describeEditFailure = (strategy: string, confidence: number): string =>
		[
			`Failed to apply the edit using ${strategy} strategy (${percent(confidence)}% confidence)\n\n`,
			"Debug Info:\n",
			"- The location was found but the content didn't match exactly\n",
			"- This usually means the file has been modified since the diff was created\n",
			"- Or the diff may be targeting a different version of the file\n",
			"\nPossible Solutions:\n",
			"1. Refresh your view of the file and create a new diff\n",
			"2. Double-check that the removed lines (-) match the current file content\n",
			"3. Ensure your diff targets the correct version of the file",
		].join("")

	for (const hunk of hunks) {
		const located = findBestMatch(prepareSearchString(hunk.changes), result, 0, threshold)

		if (located.confidence < threshold) {
			console.log("Full hunk application failed, trying sub-hunks strategy")
			// Try splitting the hunk into smaller hunks
			const subHunks = this.splitHunk(hunk)
			const viaSubHunks = await applySubHunks(subHunks, result)
			if (viaSubHunks !== null) {
				result = viaSubHunks
				continue
			}
			// Sub-hunks failed as well: report the original search failure
			return {
				success: false,
				error: describeSearchFailure(hunk, located.confidence, located.strategy, subHunks.length),
			}
		}

		const editResult = await applyEdit(hunk, result, located.index, located.confidence, threshold)
		if (!(editResult.confidence >= threshold)) {
			// Edit failure - likely due to content mismatch
			return { success: false, error: describeEditFailure(editResult.strategy, editResult.confidence) }
		}
		result = editResult.result
	}

	return { success: true, content: result.join("\n") }
}
}

View File

@@ -0,0 +1,408 @@
import { compareTwoStrings } from "string-similarity"
import { closest } from "fastest-levenshtein"
import { diff_match_patch } from "diff-match-patch"
import { Change, Hunk } from "./types"
/** Outcome of one search strategy looking for a block of lines in a file. */
export type SearchResult = {
/** Line index where the match was found. */
index: number
/** Score of the match; higher means a better match. */
confidence: number
/** Label of the strategy that produced this result; combined results get an "-overlapping" suffix. */
strategy: string
}
// Content with more lines than this gets a relaxed confidence threshold (see getAdaptiveThreshold)
const LARGE_FILE_THRESHOLD = 1000 // lines
// Maximum bonus added to a context-validation score; scaled by the uniqueness score (see validateContextLines)
const UNIQUE_CONTENT_BOOST = 0.05
// Default overlap used when creating windows and when combining matches from neighbouring windows
const DEFAULT_OVERLAP_SIZE = 3 // lines of overlap between windows
// Upper bound for the doubled search size when sizing a window (a longer search still gets a window of its own length)
const MAX_WINDOW_SIZE = 500 // maximum lines in a window
/**
 * Picks the confidence threshold to use for content of the given size.
 *
 * Content of up to LARGE_FILE_THRESHOLD lines uses `baseThreshold` unchanged. Larger
 * content uses `baseThreshold - 0.07`, but never less than 0.8.
 *
 * @param contentLength Number of lines in the content.
 * @param baseThreshold The threshold requested by the caller.
 * @returns The threshold to apply.
 */
function getAdaptiveThreshold(contentLength: number, baseThreshold: number): number {
	const LARGE_FILE_DISCOUNT = 0.07
	const MIN_THRESHOLD = 0.8

	return contentLength <= LARGE_FILE_THRESHOLD
		? baseThreshold
		: Math.max(baseThreshold - LARGE_FILE_DISCOUNT, MIN_THRESHOLD)
}
/**
 * Scores how distinctive the search text is within the content.
 *
 * Each distinct line of `searchStr` is looked up as a literal substring of the
 * newline-joined content. A line counts as "relatively unique" when it occurs at least
 * once and at most twice.
 *
 * @param searchStr The search text (lines joined with "\n").
 * @param content The content as an array of lines.
 * @returns The fraction (0..1) of distinct search lines that are relatively unique.
 */
function evaluateContentUniqueness(searchStr: string, content: string[]): number {
	const distinctLines = [...new Set(searchStr.split("\n"))]
	const haystack = content.join("\n")

	const rareLines = distinctLines.filter((line) => {
		// Escape regex metacharacters so the line is matched literally
		const escaped = line.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
		const occurrences = haystack.match(new RegExp(escaped, "g"))
		// Line appears at most twice (and at least once)
		return occurrences !== null && occurrences.length <= 2
	})

	return rareLines.length / distinctLines.length
}
/**
 * Builds the text to search for in the file from a hunk's changes: the lines that must
 * already exist there, i.e. context and removed lines, joined with "\n".
 *
 * Each line is taken from `originalLine` when it is set (an empty string counts as set).
 * Changes without an `originalLine` fall back to `indent + content` instead of
 * contributing an empty line.
 *
 * @param changes The changes of one hunk.
 * @returns The search text.
 */
export function prepareSearchString(changes: Change[]): string {
	return changes
		.filter((change) => change.type === "context" || change.type === "remove")
		.map((change) => change.originalLine ?? ((change.indent ?? "") + change.content))
		.join("\n")
}
/**
 * Scores how similar two texts are by delegating to `compareTwoStrings` from the
 * string-similarity package.
 *
 * @param original The first text.
 * @param modified The second text.
 * @returns The score reported by `compareTwoStrings` (documented by that library as a
 *          value between 0 and 1; see its documentation for how whitespace and very
 *          short strings are treated).
 */
export function evaluateSimilarity(original: string, modified: string): number {
	const score = compareTwoStrings(original, modified)
	return score
}
/**
 * Similarity check based on diff-match-patch: diffs `original` against `modified`,
 * turns the diff into patches, applies them to `original` and scores the patched text
 * against `modified`.
 *
 * NOTE(review): the patches are derived from `original` and applied back onto
 * `original`, so the patched text presumably always equals `modified`, which would make
 * this score 1 regardless of the inputs — confirm whether comparing `original` with
 * `modified` directly was intended.
 *
 * @param original The reference text.
 * @param modified The text to compare against.
 * @returns The similarity between the patched text and `modified`.
 */
export function getDMPSimilarity(original: string, modified: string): number {
	const dmp = new diff_match_patch()

	const diffs = dmp.diff_main(original, modified)
	dmp.diff_cleanupSemantic(diffs)

	const [expectedText] = dmp.patch_apply(dmp.patch_make(original, diffs), original)
	return evaluateSimilarity(expectedText, modified)
}
/**
 * Rates an edit result against what the hunk says the text should look like.
 *
 * The result is compared with the hunk's expected text (context + added lines). If the
 * result still looks like the hunk's original text (context + removed lines, similarity
 * above 0.97) and does not match the expected text perfectly, the score is scaled by 0.8.
 *
 * @param hunk The hunk that was applied.
 * @param result The text produced for the hunk's region.
 * @returns The confidence in the edit.
 */
export function validateEditResult(hunk: Hunk, result: string): number {
	// Renders the context lines plus the lines of the given kind as indent + content
	const renderSide = (kind: "add" | "remove"): string =>
		hunk.changes
			.filter((change) => change.type === "context" || change.type === kind)
			.map((change) => (change.indent ? change.indent + change.content : change.content))
			.join("\n")

	// How close is the result to the text the hunk asks for?
	const similarity = getDMPSimilarity(renderSide("add"), result)

	// How close is it to the text before the edit?
	const originalSimilarity = getDMPSimilarity(renderSide("remove"), result)

	// Unchanged from the original: some confidence remains since the location was right
	const looksUnchanged = originalSimilarity > 0.97 && similarity !== 1
	return looksUnchanged ? 0.8 * similarity : similarity
}
/**
 * Scores how well the context lines of a search string match a piece of content.
 *
 * Lines of `searchStr` starting with "-" are ignored for the comparison. The similarity
 * is kept as is when it reaches the size-adjusted threshold and scaled by 0.3 otherwise;
 * in both cases a uniqueness bonus of at most UNIQUE_CONTENT_BOOST is added.
 *
 * @param searchStr The search text (lines joined with "\n").
 * @param content The content to compare against, as one string.
 * @param confidenceThreshold The base threshold, adjusted for large content.
 * @returns The adjusted confidence.
 */
function validateContextLines(searchStr: string, content: string, confidenceThreshold: number): number {
	const contentLines = content.split("\n")

	// Only context lines take part in the comparison; removed lines are excluded
	const contextText = searchStr
		.split("\n")
		.filter((line) => !line.startsWith("-"))
		.join("\n")
	const similarity = evaluateSimilarity(contextText, content)

	const threshold = getAdaptiveThreshold(contentLines.length, confidenceThreshold)
	const uniquenessBoost = evaluateContentUniqueness(searchStr, contentLines) * UNIQUE_CONTENT_BOOST

	// Scores under the threshold are heavily discounted before the bonus is added
	const baseScore = similarity < threshold ? similarity * 0.3 : similarity
	return baseScore + uniquenessBoost
}
/**
 * Cuts the content into overlapping windows of lines to be searched one by one.
 *
 * The window size is twice the search size, capped at MAX_WINDOW_SIZE but never smaller
 * than the search size. Consecutive windows overlap by at least `searchSize - 1` lines,
 * so every run of `searchSize` consecutive lines lies completely inside at least one
 * window. With only the configured overlap, a longer search block could straddle two
 * windows and be missed by both. Windows shorter than the search size (at the end of
 * the content) are omitted.
 *
 * @param content The content as an array of lines.
 * @param searchSize Number of lines in the block being searched for.
 * @param overlapSize Minimum number of lines shared by consecutive windows.
 * @returns The windows together with the index of their first line in `content`.
 */
function createOverlappingWindows(
	content: string[],
	searchSize: number,
	overlapSize: number = DEFAULT_OVERLAP_SIZE,
): { window: string[]; startIndex: number }[] {
	const windows: { window: string[]; startIndex: number }[] = []

	// Ensure minimum window size is at least searchSize
	const effectiveWindowSize = Math.max(searchSize, Math.min(searchSize * 2, MAX_WINDOW_SIZE))

	// The overlap must cover searchSize - 1 lines so no candidate block falls between two
	// windows; it can never exceed the window size minus one
	const requiredOverlap = Math.max(overlapSize, searchSize - 1)
	const effectiveOverlapSize = Math.min(requiredOverlap, effectiveWindowSize - 1)

	// Calculate step size, ensure it's at least 1
	const stepSize = Math.max(1, effectiveWindowSize - effectiveOverlapSize)

	for (let start = 0; start < content.length; start += stepSize) {
		const windowContent = content.slice(start, start + effectiveWindowSize)
		if (windowContent.length >= searchSize) {
			windows.push({ window: windowContent, startIndex: start })
		}
	}

	return windows
}
/**
 * Collapses per-window hits into a final list, rewarding hits that are
 * confirmed by a neighbouring window.
 *
 * The input array is sorted in place by descending confidence. Each hit is then
 * visited once; hits from directly adjacent windows (window index differs by
 * exactly 1) whose line index lies within `overlapSize` are merged into it. A
 * merged hit gets the mean confidence of the group plus a boost of 0.05 per
 * neighbour (at most 0.1), capped at 1, and its strategy name gains the
 * "-overlapping" suffix.
 *
 * @param matches Hits tagged with the index of the window that produced them
 * @param overlapSize Maximum line-index distance for two hits to be merged
 * @returns Combined hits, in the order their best member was visited
 */
function combineOverlappingMatches(
	matches: (SearchResult & { windowIndex: number })[],
	overlapSize: number = DEFAULT_OVERLAP_SIZE,
): SearchResult[] {
	const merged: SearchResult[] = []
	if (matches.length === 0) {
		return merged
	}

	// Highest confidence first, so each group is represented by its strongest hit
	matches.sort((a, b) => b.confidence - a.confidence)

	const consumed = new Set<number>()
	for (const current of matches) {
		if (consumed.has(current.windowIndex)) {
			continue
		}
		// A neighbour always has a different window index, so marking the current
		// window before filtering cannot exclude any neighbour.
		consumed.add(current.windowIndex)

		const neighbours = matches.filter(
			(other) =>
				!consumed.has(other.windowIndex) &&
				Math.abs(other.windowIndex - current.windowIndex) === 1 &&
				Math.abs(other.index - current.index) <= overlapSize,
		)

		if (neighbours.length === 0) {
			merged.push({
				index: current.index,
				confidence: current.confidence,
				strategy: current.strategy,
			})
			continue
		}

		let neighbourTotal = 0
		for (const other of neighbours) {
			neighbourTotal += other.confidence
		}
		const mean = (current.confidence + neighbourTotal) / (neighbours.length + 1)
		const bonus = Math.min(0.05 * neighbours.length, 0.1) // Max 10% boost

		merged.push({
			index: current.index,
			confidence: Math.min(1, mean + bonus),
			strategy: `${current.strategy}-overlapping`,
		})
		for (const other of neighbours) {
			consumed.add(other.windowIndex)
		}
	}
	return merged
}
/**
 * Exact-substring strategy: looks for `searchStr` verbatim inside overlapping
 * windows of `content`.
 *
 * @param searchStr Text to locate (lines joined with "\n")
 * @param content File content as lines
 * @param startIndex First line of `content` to consider
 * @param confidenceThreshold Threshold forwarded to validateContextLines
 * @returns The first combined hit, or index -1 with confidence 0 when no window contains the text
 */
export function findExactMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const searchLineCount = searchStr.split("\n").length
	const windows = createOverlappingWindows(content.slice(startIndex), searchLineCount)
	const hits: (SearchResult & { windowIndex: number })[] = []

	for (const [windowIndex, entry] of windows.entries()) {
		const joined = entry.window.join("\n")
		const charOffset = joined.indexOf(searchStr)
		if (charOffset === -1) {
			continue
		}

		// Splitting the text before the hit on "\n" yields (preceding newlines + 1)
		// pieces, so this is the line offset of the hit inside the window.
		const lineOffset = joined.slice(0, charOffset).split("\n").length - 1
		const matchedContent = entry.window.slice(lineOffset, lineOffset + searchLineCount).join("\n")

		const similarity = getDMPSimilarity(searchStr, matchedContent)
		const contextSimilarity = validateContextLines(searchStr, matchedContent, confidenceThreshold)

		hits.push({
			// Translate the window-relative offset back to an index into `content`
			index: startIndex + entry.startIndex + lineOffset,
			confidence: Math.min(similarity, contextSimilarity),
			strategy: "exact",
			windowIndex,
		})
	}

	const combined = combineOverlappingMatches(hits)
	if (combined.length > 0) {
		return combined[0]
	}
	return { index: -1, confidence: 0, strategy: "exact" }
}
/**
 * String-similarity strategy: slides a window the height of the search block
 * over `content` and keeps the best-scoring position.
 *
 * A position is only evaluated further when its compareTwoStrings score beats
 * the current best and reaches `confidenceThreshold`. Its final score is that
 * raw score multiplied by the lower of the getDMPSimilarity and
 * validateContextLines values.
 *
 * @param searchStr Text to locate (lines joined with "\n")
 * @param content File content as lines
 * @param startIndex First line index to try
 * @param confidenceThreshold Minimum raw score for a position to be considered
 * @returns Best position and its adjusted score, or index -1 with confidence 0
 */
export function findSimilarityMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const span = searchStr.split("\n").length
	let topScore = 0
	let topIndex = -1

	for (let pos = startIndex; pos < content.length - span + 1; pos++) {
		const candidate = content.slice(pos, pos + span).join("\n")
		const rawScore = compareTwoStrings(searchStr, candidate)
		const worthChecking = rawScore > topScore && rawScore >= confidenceThreshold
		if (!worthChecking) {
			continue
		}

		const similarity = getDMPSimilarity(searchStr, candidate)
		const contextSimilarity = validateContextLines(searchStr, candidate, confidenceThreshold)
		const adjusted = Math.min(similarity, contextSimilarity) * rawScore
		if (adjusted > topScore) {
			topScore = adjusted
			topIndex = pos
		}
	}

	return {
		index: topIndex,
		confidence: topIndex !== -1 ? topScore : 0,
		strategy: "similarity",
	}
}
/**
 * Levenshtein strategy: builds every window of `content` that is as tall as the
 * search block, lets `closest` pick one of them, and scores that pick.
 *
 * @param searchStr Text to locate (lines joined with "\n")
 * @param content File content as lines
 * @param startIndex First line index to build a window from
 * @param confidenceThreshold Threshold forwarded to validateContextLines
 * @returns Position of the picked window with the lower of its getDMPSimilarity
 *          and validateContextLines scores; index -1 when there are no windows
 *          or the confidence is exactly 0
 */
export function findLevenshteinMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const span = searchStr.split("\n").length

	const candidates: string[] = []
	for (let pos = startIndex; pos < content.length - span + 1; pos++) {
		candidates.push(content.slice(pos, pos + span).join("\n"))
	}

	if (candidates.length === 0) {
		return { index: -1, confidence: 0, strategy: "levenshtein" }
	}

	const nearest = closest(searchStr, candidates)
	// candidates[k] was built from line startIndex + k
	const index = startIndex + candidates.indexOf(nearest)
	const confidence = Math.min(
		getDMPSimilarity(searchStr, nearest),
		validateContextLines(searchStr, nearest, confidenceThreshold),
	)

	return {
		index: confidence === 0 ? -1 : index,
		confidence: index !== -1 ? confidence : 0,
		strategy: "levenshtein",
	}
}
/**
 * Picks the anchor lines of a search block: its first and its last line that
 * contain something other than whitespace.
 *
 * @param searchStr Search block (lines joined with "\n")
 * @returns The two anchor lines, untrimmed; both are null when every line is blank
 */
function identifyAnchors(searchStr: string): { first: string | null; last: string | null } {
	const hasText = (line: string): boolean => line.trim().length > 0
	const lines = searchStr.split("\n")

	const first = lines.find(hasText) ?? null
	// Scan a reversed copy so the original line order stays untouched
	const last = [...lines].reverse().find(hasText) ?? null

	return { first, last }
}
/**
 * Anchor strategy: locates the search block by its first and last non-blank
 * lines and then checks that what lies between them is similar enough.
 *
 * The match is rejected when either anchor is missing, when the first anchor
 * does not occur exactly once in the whole of `content`, when the anchors cannot
 * be found in order at or after `startIndex`, or when the lines between them
 * score below the adaptive threshold.
 *
 * @param searchStr Text to locate (lines joined with "\n")
 * @param content File content as lines
 * @param startIndex First line index considered when locating the anchors
 * @param confidenceThreshold Base threshold handed to getAdaptiveThreshold
 * @returns The first anchor's index with confidence 1, or index -1 with confidence 0
 */
export function findAnchorMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const searchLines = searchStr.split("\n")
	const { first, last } = identifyAnchors(searchStr)
	if (!first || !last) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	// The first anchor must be unique across the whole content (not only from
	// startIndex), otherwise the location it points to would be ambiguous.
	let firstOccurrences = 0
	for (const contentLine of content) {
		if (contentLine === first) {
			firstOccurrences++
		}
	}
	if (firstOccurrences !== 1) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	// First anchor: earliest occurrence at or after startIndex
	let firstIndex = -1
	for (let i = startIndex; i < content.length; i++) {
		if (content[i] === first) {
			firstIndex = i
			break
		}
	}

	// Last anchor: latest occurrence at or after startIndex
	let lastIndex = -1
	for (let i = content.length - 1; i >= startIndex; i--) {
		if (content[i] === last) {
			lastIndex = i
			break
		}
	}

	if (firstIndex === -1 || lastIndex === -1 || lastIndex <= firstIndex) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	// Compare what lies between the anchors in the search block with what lies
	// between them in the content. The last anchor is the LAST non-blank search
	// line, so it has to be located from the end: indexOf would stop at an earlier
	// identical line (for example a repeated "}") and cut the expected context short.
	const expectedContext = searchLines
		.slice(searchLines.indexOf(first) + 1, searchLines.lastIndexOf(last))
		.join("\n")
	const actualContext = content.slice(firstIndex + 1, lastIndex).join("\n")
	const contextSimilarity = evaluateSimilarity(expectedContext, actualContext)
	if (contextSimilarity < getAdaptiveThreshold(content.length, confidenceThreshold)) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	return {
		index: firstIndex,
		confidence: 1,
		strategy: "anchor",
	}
}
/**
 * Runs every search strategy (exact, anchor, similarity, Levenshtein) and
 * returns the result with the highest confidence.
 *
 * Strategies run in the listed order and a later one only wins with a strictly
 * higher confidence, so ties go to the earlier strategy. When nothing scores
 * above 0 the result is index -1, confidence 0, strategy "none".
 *
 * @param searchStr Text to locate (lines joined with "\n")
 * @param content File content as lines
 * @param startIndex Line index passed through to each strategy
 * @param confidenceThreshold Threshold passed through to each strategy
 */
export function findBestMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]

	return strategies.reduce<SearchResult>(
		(best, runStrategy) => {
			const candidate = runStrategy(searchStr, content, startIndex, confidenceThreshold)
			return candidate.confidence > best.confidence ? candidate : best
		},
		{ index: -1, confidence: 0, strategy: "none" },
	)
}

View File

@@ -0,0 +1,20 @@
/**
 * One line-level entry of a hunk.
 */
export type Change = {
// Role of the line: unchanged context, an added line, or a removed line
type: "context" | "add" | "remove"
// Line text; presumably without its leading whitespace, which is kept in `indent` — TODO confirm against the parser
content: string
// Leading whitespace of the line; presumably re-attached in front of `content` when the line is rebuilt — TODO confirm
indent: string
// NOTE(review): presumably the untouched source line this change was parsed from — confirm against the parser
originalLine?: string
}
/**
 * A group of line changes (context, additions, removals).
 */
export type Hunk = {
// Changes belonging to this hunk; presumably in file order — TODO confirm
changes: Change[]
}
/**
 * A parsed diff: the list of hunks it contains.
 */
export type Diff = {
// Hunks making up the diff
hunks: Hunk[]
}
/**
 * Outcome reported by an edit strategy.
 */
export type EditResult = {
// How sure the strategy is about its result; presumably in the 0..1 range — TODO confirm
confidence: number
// Resulting content; presumably one array element per line — TODO confirm
result: string[]
// Name of the strategy that produced this result
strategy: string
}

View File

@@ -0,0 +1,302 @@
import { DiffStrategy, DiffResult } from "../types"
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
import { distance } from "fastest-levenshtein"
// Default padding, in lines, used by SearchReplaceDiffStrategy when no bufferLines is passed:
// it widens the fallback search window around the requested line range and the
// excerpt of original content included in the "no match" error message.
const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches
/**
 * Whitespace-insensitive similarity between two strings.
 *
 * Both inputs have every run of whitespace collapsed to one space and are
 * trimmed (case is kept). Identical results score 1; otherwise the score is
 * 1 minus the edit distance divided by the longer normalized length.
 *
 * @param original Text taken from the file
 * @param search Text being looked for; an empty string always scores 1
 * @returns Similarity ratio, where 1 means the normalized texts are equal
 */
function getSimilarity(original: string, search: string): number {
	if (search === "") {
		return 1
	}

	const collapse = (text: string): string => text.replace(/\s+/g, " ").trim()
	const left = collapse(original)
	const right = collapse(search)
	if (left === right) {
		return 1
	}

	// The two strings differ here, so at least one is non-empty and the divisor is > 0
	const longest = Math.max(left.length, right.length)
	return 1 - distance(left, right) / longest
}
/**
 * Diff strategy that applies a single SEARCH/REPLACE block to a file's content.
 *
 * The SEARCH text is located by similarity (see getSimilarity, which ignores
 * differences in whitespace runs), optionally guided by a start/end line hint,
 * and the REPLACE text is re-indented relative to the matched location.
 */
export class SearchReplaceDiffStrategy implements DiffStrategy {
// Minimum similarity a candidate must reach to be accepted (1.0 = normalized texts must be equal)
private fuzzyThreshold: number
// Lines of padding around the requested range for the fallback search and the error excerpt
private bufferLines: number
/**
 * @param fuzzyThreshold Required similarity; defaults to 1.0
 * @param bufferLines Padding in lines around the requested range; defaults to BUFFER_LINES
 */
constructor(fuzzyThreshold?: number, bufferLines?: number) {
// Use provided threshold or default to exact matching (1.0)
// Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
// so we use it directly here
this.fuzzyThreshold = fuzzyThreshold ?? 1.0
this.bufferLines = bufferLines ?? BUFFER_LINES
}
/**
 * Returns the apply_diff tool description text for this strategy.
 * Only `args.cwd` is interpolated into the text; `args.toolOptions` is not read.
 */
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
return `## apply_diff
Description: Request to replace existing code using a search and replace block.
This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with.
The tool will maintain proper indentation and formatting while making changes.
Only a single operation is allowed per tool use.
The SEARCH section must exactly match existing content including whitespace and indentation.
If you're not confident in the exact content to search for, use the read_file tool first to get the exact content.
When applying the diffs, be extra careful to remember to change any closing brackets or other syntax that may be affected by the diff farther down in the file.
Parameters:
- path: (required) The path of the file to modify (relative to the current working directory ${args.cwd})
- diff: (required) The search/replace block defining the changes.
- start_line: (required) The line number where the search block starts.
- end_line: (required) The line number where the search block ends.
Diff format:
\`\`\`
<<<<<<< SEARCH
[exact content to find including whitespace]
=======
[new content to replace with]
>>>>>>> REPLACE
\`\`\`
Example:
Original file:
\`\`\`
1 | def calculate_total(items):
2 | total = 0
3 | for item in items:
4 | total += item
5 | return total
\`\`\`
Search/Replace content:
\`\`\`
<<<<<<< SEARCH
def calculate_total(items):
total = 0
for item in items:
total += item
return total
=======
def calculate_total(items):
"""Calculate total with 10% markup"""
return sum(item * 1.1 for item in items)
>>>>>>> REPLACE
\`\`\`
Usage:
<apply_diff>
<path>File path here</path>
<diff>
Your search/replace content here
</diff>
<start_line>1</start_line>
<end_line>5</end_line>
</apply_diff>`
}
/**
 * Applies one SEARCH/REPLACE block to `originalContent`.
 *
 * Steps: parse the block, strip line-number prefixes when both sections carry
 * them, validate the arguments, try the requested line range first, otherwise
 * search middle-out for the most similar chunk, then splice in the re-indented
 * replacement using the original file's line ending.
 *
 * @param originalContent Current file content
 * @param diffContent Text containing the SEARCH/REPLACE block
 * @param startLine Optional 1-based line where the search block is expected to start
 * @param endLine Optional 1-based line where the search block is expected to end
 * @returns `success: true` with the new content, or `success: false` with an error message
 */
async applyDiff(
originalContent: string,
diffContent: string,
startLine?: number,
endLine?: number,
): Promise<DiffResult> {
// Extract the search and replace blocks
// Only the first SEARCH/REPLACE block in diffContent is used (non-global regex)
const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/)
if (!match) {
return {
success: false,
error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`,
}
}
let [_, searchContent, replaceContent] = match
// Detect line ending from original content
// A single "\r\n" anywhere makes the whole output use "\r\n"
const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n"
// Strip line numbers from search and replace content if every line starts with a line number
if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
searchContent = stripLineNumbers(searchContent)
replaceContent = stripLineNumbers(replaceContent)
}
// Split content into lines, handling both \n and \r\n
// An empty section becomes zero lines rather than one empty line
const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/)
const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/)
const originalLines = originalContent.split(/\r?\n/)
// Validate that empty search requires start line
if (searchLines.length === 0 && !startLine) {
return {
success: false,
error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`,
}
}
// Validate that empty search requires same start and end line
if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
return {
success: false,
error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`,
}
}
// Initialize search variables
let matchIndex = -1
let bestMatchScore = 0
let bestMatchContent = ""
const searchChunk = searchLines.join("\n")
// Determine search bounds
// Without a usable line range the whole file is searched
let searchStartIndex = 0
let searchEndIndex = originalLines.length
// Validate and handle line range if provided
// The range is only used when BOTH startLine and endLine are truthy
if (startLine && endLine) {
// Convert to 0-based index
const exactStartIndex = startLine - 1
const exactEndIndex = endLine - 1
// NOTE(review): `exactEndIndex > originalLines.length` accepts endLine === length + 1
// (0-based index equal to length); looks like an off-by-one unless the leniency is intended — confirm
if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
return {
success: false,
error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
}
}
// Try exact match first
// "Exact" refers to the location: the requested range is still compared by similarity
const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n")
const similarity = getSimilarity(originalChunk, searchChunk)
if (similarity >= this.fuzzyThreshold) {
matchIndex = exactStartIndex
bestMatchScore = similarity
bestMatchContent = originalChunk
} else {
// Set bounds for buffered search
// Start index is 0-based (hence the extra 1); end index is exclusive
searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1))
searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines)
}
}
// If no match found yet, try middle-out search within bounds
if (matchIndex === -1) {
const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2)
// leftIndex walks down from the midpoint, rightIndex walks up from midpoint + 1
let leftIndex = midPoint
let rightIndex = midPoint + 1
// Search outward from the middle within bounds
// Scores are compared with strict ">", so the first candidate visited wins ties
while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
// Check left side if still in range
// NOTE(review): the left side only checks its start position; a chunk starting near
// searchEndIndex may extend past it (slice simply truncates at the end of the file)
if (leftIndex >= searchStartIndex) {
const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n")
const similarity = getSimilarity(originalChunk, searchChunk)
if (similarity > bestMatchScore) {
bestMatchScore = similarity
matchIndex = leftIndex
bestMatchContent = originalChunk
}
leftIndex--
}
// Check right side if still in range
if (rightIndex <= searchEndIndex - searchLines.length) {
const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n")
const similarity = getSimilarity(originalChunk, searchChunk)
if (similarity > bestMatchScore) {
bestMatchScore = similarity
matchIndex = rightIndex
bestMatchContent = originalChunk
}
rightIndex++
}
}
}
// Require similarity to meet threshold
// On failure, build an error message with the search text, the best candidate and an excerpt of the file
if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
// Shadows the outer searchChunk with the same value
const searchChunk = searchLines.join("\n")
// With a line range: show the range padded by bufferLines; otherwise show the whole file
const originalContentSection =
startLine !== undefined && endLine !== undefined
? `\n\nOriginal Content:\n${addLineNumbers(
originalLines
.slice(
Math.max(0, startLine - 1 - this.bufferLines),
Math.min(originalLines.length, endLine + this.bufferLines),
)
.join("\n"),
Math.max(1, startLine - this.bufferLines),
)}`
: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}`
const bestMatchSection = bestMatchContent
? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
: `\n\nBest Match Found:\n(no match)`
const lineRange =
startLine || endLine
? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}`
: ""
return {
success: false,
error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n- Tip: Use read_file to get the latest content of the file before attempting the diff again, as the file content may have changed\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`,
}
}
// Get the matched lines from the original content
const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length)
// Get the exact indentation (preserving tabs/spaces) of each line
const originalIndents = matchedLines.map((line) => {
const match = line.match(/^[\t ]*/)
return match ? match[0] : ""
})
// Get the exact indentation of each line in the search block
const searchIndents = searchLines.map((line) => {
const match = line.match(/^[\t ]*/)
return match ? match[0] : ""
})
// Apply the replacement while preserving exact indentation
// Only the FIRST matched line and the FIRST search line serve as indentation baselines
const indentedReplaceLines = replaceLines.map((line, i) => {
// Get the matched line's exact indentation
const matchedIndent = originalIndents[0] || ""
// Get the current line's indentation relative to the search content
const currentIndentMatch = line.match(/^[\t ]*/)
const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ""
const searchBaseIndent = searchIndents[0] || ""
// Calculate the relative indentation level
// Levels are measured in characters, so a tab and a space each count as one
const searchBaseLevel = searchBaseIndent.length
const currentLevel = currentIndent.length
const relativeLevel = currentLevel - searchBaseLevel
// If relative level is negative, remove indentation from matched indent
// If positive, add to matched indent
const finalIndent =
relativeLevel < 0
? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
: matchedIndent + currentIndent.slice(searchBaseLevel)
// trim() drops trailing whitespace as well as the original leading indent
return finalIndent + line.trim()
})
// Construct the final content
const beforeMatch = originalLines.slice(0, matchIndex)
const afterMatch = originalLines.slice(matchIndex + searchLines.length)
const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding)
return {
success: true,
content: finalContent,
}
}
}

View File

@@ -0,0 +1,137 @@
import { applyPatch } from "diff"
import { DiffStrategy, DiffResult } from "../types"
/**
 * Diff strategy that applies a standard unified diff through `applyPatch`.
 */
export class UnifiedDiffStrategy implements DiffStrategy {
	/**
	 * Returns the apply_diff tool description text for the unified-diff format.
	 * Only `args.cwd` is interpolated into the text; `args.toolOptions` is not read.
	 */
	getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
		return `## apply_diff
Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3).
Parameters:
- path: (required) The path of the file to apply the diff to (relative to the current working directory ${args.cwd})
- diff: (required) The diff content in unified format to apply to the file.
Format Requirements:
1. Header (REQUIRED):
\`\`\`
--- path/to/original/file
+++ path/to/modified/file
\`\`\`
- Must include both lines exactly as shown
- Use actual file paths
- NO timestamps after paths
2. Hunks:
\`\`\`
@@ -lineStart,lineCount +lineStart,lineCount @@
-removed line
+added line
\`\`\`
- Each hunk starts with @@ showing line numbers for changes
- Format: @@ -originalStart,originalCount +newStart,newCount @@
- Use - for removed/changed lines
- Use + for new/modified lines
- Indentation must match exactly
Complete Example:
Original file (with line numbers):
\`\`\`
1 | import { Logger } from '../logger';
2 |
3 | function calculateTotal(items: number[]): number {
4 | return items.reduce((sum, item) => {
5 | return sum + item;
6 | }, 0);
7 | }
8 |
9 | export { calculateTotal };
\`\`\`
After applying the diff, the file would look like:
\`\`\`
1 | import { Logger } from '../logger';
2 |
3 | function calculateTotal(items: number[]): number {
4 | const total = items.reduce((sum, item) => {
5 | return sum + item * 1.1; // Add 10% markup
6 | }, 0);
7 | return Math.round(total * 100) / 100; // Round to 2 decimal places
8 | }
9 |
10 | export { calculateTotal };
\`\`\`
Diff to modify the file:
\`\`\`
--- src/utils/helper.ts
+++ src/utils/helper.ts
@@ -1,9 +1,10 @@
import { Logger } from '../logger';
function calculateTotal(items: number[]): number {
- return items.reduce((sum, item) => {
- return sum + item;
+ const total = items.reduce((sum, item) => {
+ return sum + item * 1.1; // Add 10% markup
}, 0);
+ return Math.round(total * 100) / 100; // Round to 2 decimal places
}
export { calculateTotal };
\`\`\`
Common Pitfalls:
1. Missing or incorrect header lines
2. Incorrect line numbers in @@ lines
3. Wrong indentation in changed lines
4. Incomplete context (missing lines that need changing)
5. Not marking all modified lines with - and +
Best Practices:
1. Replace entire code blocks:
- Remove complete old version with - lines
- Add complete new version with + lines
- Include correct line numbers
2. Moving code requires two hunks:
- First hunk: Remove from old location
- Second hunk: Add to new location
3. One hunk per logical change
4. Verify line numbers match the line numbers you have in the file
Usage:
<apply_diff>
<path>File path here</path>
<diff>
Your diff here
</diff>
</apply_diff>`
	}

	/**
	 * Applies a unified diff to `originalContent`.
	 *
	 * @param originalContent Current file content
	 * @param diffContent Patch text in unified diff format
	 * @returns `success: true` with the patched content; `success: false` when the
	 *          patch is rejected or applying it throws. The failing diff is echoed
	 *          back in `details.searchContent`.
	 */
	async applyDiff(originalContent: string, diffContent: string): Promise<DiffResult> {
		try {
			const result = applyPatch(originalContent, diffContent)
			// applyPatch signals a rejected patch by returning false instead of throwing
			if (result === false) {
				return {
					success: false,
					error: "Failed to apply unified diff - patch rejected",
					details: {
						searchContent: diffContent,
					},
				}
			}
			return {
				success: true,
				content: result,
			}
		} catch (error) {
			// A thrown value is not guaranteed to be an Error, so narrow before reading
			// `.message`; anything else is stringified instead of printing "undefined".
			const message = error instanceof Error ? error.message : String(error)
			return {
				success: false,
				error: `Error applying unified diff: ${message}`,
				details: {
					searchContent: diffContent,
				},
			}
		}
	}
}

36
src/core/diff/types.ts Normal file
View File

@@ -0,0 +1,36 @@
/**
 * Result of applying a diff, discriminated by `success`:
 * either the updated content, or an error message with optional diagnostics.
 */
export type DiffResult =
| { success: true; content: string }
| {
success: false
// Human-readable description of why the diff could not be applied
error: string
// Optional diagnostics; each field is only present when the strategy provides it
details?: {
// NOTE(review): presumably the best similarity score that was reached — confirm with the producing strategy
similarity?: number
// NOTE(review): presumably the similarity that was required — confirm with the producing strategy
threshold?: number
// NOTE(review): presumably the line range of the best candidate; 0- or 1-based is not visible here — confirm
matchedRange?: { start: number; end: number }
// The search/diff text that failed to apply
searchContent?: string
// NOTE(review): presumably the text of the closest candidate found — confirm with the producing strategy
bestMatch?: string
}
}
/**
 * Contract implemented by every diff strategy: it describes its own diff format
 * for the tool prompt and applies diffs written in that format.
 */
export interface DiffStrategy {
/**
* Get the tool description for this diff strategy
* @param args The tool arguments including cwd and toolOptions
* @returns The complete tool description including format requirements and examples
*/
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string
/**
* Apply a diff to the original content
* @param originalContent The original file content
* @param diffContent The diff content in the strategy's format
* @param startLine Optional line number where the search block starts. If not provided, searches the entire file.
* @param endLine Optional line number where the search block ends. If not provided, searches the entire file.
* @returns A DiffResult object containing either the successful result or error details
*/
applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult>
}

746
src/core/mcp/McpHub.ts Normal file
View File

@@ -0,0 +1,746 @@
import { Client } from "@modelcontextprotocol/sdk/client/index.js"
import { StdioClientTransport, StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio.js"
import {
CallToolResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ListToolsResultSchema,
ReadResourceResultSchema,
} from "@modelcontextprotocol/sdk/types.js"
import chokidar, { FSWatcher } from "chokidar"
import delay from "delay"
import deepEqual from "fast-deep-equal"
import * as fs from "fs/promises"
import * as path from "path"
import * as vscode from "vscode"
import { z } from "zod"
import { ClineProvider } from "../../core/webview/ClineProvider"
import { GlobalFileNames } from "../../shared/globalFileNames"
import {
McpResource,
McpResourceResponse,
McpResourceTemplate,
McpServer,
McpTool,
McpToolCallResponse,
} from "../../shared/mcp"
import { fileExistsAtPath } from "../../utils/fs"
import { arePathsEqual } from "../../utils/path"
/**
 * Everything McpHub tracks for one MCP server.
 */
export type McpConnection = {
// Server state kept by the hub (name, serialized config, status, error text, tools, resources)
server: McpServer
// MCP client used to send requests to this server
client: Client
// Stdio transport the client is connected through
transport: StdioClientTransport
}
// Validation schemas for the MCP settings file.
// The per-server shape mirrors StdioServerParameters (command/args/env) plus extra fields.

// List of tool names that are marked as always allowed; defaults to an empty list
const AlwaysAllowSchema = z.array(z.string()).default([])
// Configuration of a single stdio MCP server
export const StdioConfigSchema = z.object({
// Executable to launch (required)
command: z.string(),
// Command-line arguments for the executable
args: z.array(z.string()).optional(),
// Environment variables passed to the server process
env: z.record(z.string()).optional(),
alwaysAllow: AlwaysAllowSchema.optional(),
// When true the server is left out of the enabled-server list
disabled: z.boolean().optional(),
// Must be between 1 and 3600, defaults to 60; presumably seconds — TODO confirm where it is consumed
timeout: z.number().min(1).max(3600).optional().default(60),
})
// Top-level shape of the settings file: server name -> server configuration
const McpSettingsSchema = z.object({
mcpServers: z.record(StdioConfigSchema),
})
export class McpHub {
private providerRef: WeakRef<ClineProvider>
private disposables: vscode.Disposable[] = []
private settingsWatcher?: vscode.FileSystemWatcher
private fileWatchers: Map<string, FSWatcher> = new Map()
connections: McpConnection[] = []
isConnecting: boolean = false
/**
 * Creates the hub and kicks off its asynchronous setup.
 *
 * @param provider Owning provider; only a weak reference is kept
 */
constructor(provider: ClineProvider) {
	this.providerRef = new WeakRef(provider)
	// Both setup steps are fire-and-forget because a constructor cannot await.
	// watchMcpSettingsFile() can reject (resolving the settings path throws when the
	// provider is gone or the file cannot be created), so the rejection is handled
	// here instead of surfacing as an unhandled promise rejection.
	void this.watchMcpSettingsFile().catch((error: unknown) => {
		console.error("Failed to watch MCP settings file:", error)
	})
	// initializeMcpServers() catches and logs its own errors
	void this.initializeMcpServers()
}
/**
 * Lists the servers that are not disabled.
 *
 * @returns The `server` entry of every connection whose `disabled` flag is falsy
 */
getServers(): McpServer[] {
	const enabled: McpServer[] = []
	for (const { server } of this.connections) {
		if (!server.disabled) {
			enabled.push(server)
		}
	}
	return enabled
}
/**
 * Lists every known server, whether enabled or disabled.
 *
 * @returns The `server` entry of each connection, in connection order
 */
getAllServers(): McpServer[] {
	return Array.from(this.connections, ({ server }) => server)
}
/**
 * Resolves the MCP servers directory through the provider.
 *
 * @returns The path returned by the provider's ensureMcpServersDirectoryExists()
 * @throws Error("Provider not available") when the provider has been garbage-collected
 */
async getMcpServersPath(): Promise<string> {
	const provider = this.providerRef.deref()
	if (provider) {
		return await provider.ensureMcpServersDirectoryExists()
	}
	throw new Error("Provider not available")
}
/**
 * Resolves the path of the MCP settings file, creating the file with an empty
 * "mcpServers" object when it does not exist yet.
 *
 * @returns Absolute path of the settings file inside the provider's settings directory
 * @throws Error("Provider not available") when the provider has been garbage-collected
 */
async getMcpSettingsFilePath(): Promise<string> {
	const provider = this.providerRef.deref()
	if (!provider) {
		throw new Error("Provider not available")
	}

	const settingsDir = await provider.ensureSettingsDirectoryExists()
	const settingsFile = path.join(settingsDir, GlobalFileNames.mcpSettings)

	const alreadyThere = await fileExistsAtPath(settingsFile)
	if (!alreadyThere) {
		// Seed the file with a valid, empty configuration
		await fs.writeFile(
			settingsFile,
			`{
"mcpServers": {
}
}`,
		)
	}
	return settingsFile
}
/**
 * Reloads server connections whenever the MCP settings file is saved in the editor.
 *
 * The save listener is stored in `this.disposables`. A saved file that is not
 * valid JSON, or does not match McpSettingsSchema, produces an error popup and
 * is otherwise ignored.
 */
private async watchMcpSettingsFile(): Promise<void> {
	const settingsPath = await this.getMcpSettingsFilePath()

	const handleSave = async (document: vscode.TextDocument): Promise<void> => {
		// Saves of any other document are not our concern
		if (!arePathsEqual(document.uri.fsPath, settingsPath)) {
			return
		}

		const content = await fs.readFile(settingsPath, "utf-8")
		const errorMessage =
			"Invalid MCP settings format. Please ensure your settings follow the correct JSON format."

		let parsed: unknown
		try {
			parsed = JSON.parse(content)
		} catch {
			vscode.window.showErrorMessage(errorMessage)
			return
		}

		const validation = McpSettingsSchema.safeParse(parsed)
		if (!validation.success) {
			vscode.window.showErrorMessage(errorMessage)
			return
		}

		try {
			await this.updateServerConnections(validation.data.mcpServers || {})
		} catch (error) {
			console.error("Failed to process MCP settings change:", error)
		}
	}

	this.disposables.push(vscode.workspace.onDidSaveTextDocument(handleSave))
}
/**
 * Reads the MCP settings file once and connects to the servers listed in it.
 * Any failure (path lookup, read, JSON parsing, connecting) is logged and swallowed.
 */
private async initializeMcpServers(): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		const raw = await fs.readFile(settingsPath, "utf-8")
		const { mcpServers } = JSON.parse(raw)
		// A missing "mcpServers" key is treated as "no servers configured"
		await this.updateServerConnections(mcpServers || {})
	} catch (error) {
		console.error("Failed to initialize MCP servers:", error)
	}
}
/**
 * Creates a client and stdio transport for one server, registers the connection
 * in `this.connections`, connects, and loads the server's tools, resources and
 * resource templates.
 *
 * A config that fails StdioConfigSchema validation is registered as a
 * "disconnected" connection carrying an error message, and no process is started.
 *
 * @param name Server name; used as the key for looking the connection up
 * @param config Launch parameters for the server process
 * @throws Rethrows any error raised while starting or connecting, after marking
 *         the connection "disconnected" and recording the message
 */
private async connectToServer(name: string, config: StdioServerParameters): Promise<void> {
// Remove existing connection if it exists (should never happen, the connection should be deleted beforehand)
this.connections = this.connections.filter((conn) => conn.server.name !== name)
try {
// Each MCP server requires its own transport connection and has unique capabilities, configurations, and error handling. Having separate clients also allows proper scoping of resources/tools and independent server management like reconnection.
const client = new Client(
{
name: "Roo Code",
// Extension version when the provider is still alive, otherwise "1.0.0"
version: this.providerRef.deref()?.context.extension?.packageJSON?.version ?? "1.0.0",
},
{
capabilities: {},
},
)
const transport = new StdioClientTransport({
command: config.command,
args: config.args,
env: {
...config.env,
// Spread after config.env, so the host process PATH wins over a PATH set in the config
...(process.env.PATH ? { PATH: process.env.PATH } : {}),
// ...(process.env.NODE_PATH ? { NODE_PATH: process.env.NODE_PATH } : {}),
},
stderr: "pipe", // necessary for stderr to be available
})
// Transport failure: mark the server disconnected, record the message, refresh the webview
transport.onerror = async (error) => {
console.error(`Transport error for "${name}":`, error)
const connection = this.connections.find((conn) => conn.server.name === name)
if (connection) {
connection.server.status = "disconnected"
this.appendErrorMessage(connection, error.message)
}
await this.notifyWebviewOfServerChanges()
}
// Transport closed: mark the server disconnected and refresh the webview
transport.onclose = async () => {
const connection = this.connections.find((conn) => conn.server.name === name)
if (connection) {
connection.server.status = "disconnected"
}
await this.notifyWebviewOfServerChanges()
}
// If the config is invalid, show an error
if (!StdioConfigSchema.safeParse(config).success) {
console.error(`Invalid config for "${name}": missing or invalid parameters`)
const connection: McpConnection = {
server: {
name,
config: JSON.stringify(config),
status: "disconnected",
error: "Invalid config: missing or invalid parameters",
},
client,
transport,
}
this.connections.push(connection)
// The transport is never started for an invalid config
return
}
// valid schema
// Only `disabled` is read from the parsed config in this method
const parsedConfig = StdioConfigSchema.parse(config)
const connection: McpConnection = {
server: {
name,
config: JSON.stringify(config),
status: "connecting",
disabled: parsedConfig.disabled,
},
client,
transport,
}
// Registered before starting, so the error handlers above can find it by name
this.connections.push(connection)
// transport.stderr is only available after the process has been started. However we can't start it separately from the .connect() call because it also starts the transport. And we can't place this after the connect call since we need to capture the stderr stream before the connection is established, in order to capture errors during the connection process.
// As a workaround, we start the transport ourselves, and then monkey-patch the start method to no-op so that .connect() doesn't try to start it again.
await transport.start()
const stderrStream = transport.stderr
if (stderrStream) {
stderrStream.on("data", async (data: Buffer) => {
const errorOutput = data.toString()
console.error(`Server "${name}" stderr:`, errorOutput)
const connection = this.connections.find((conn) => conn.server.name === name)
if (connection) {
// NOTE: we do not set server status to "disconnected" because stderr logs do not necessarily mean the server crashed or disconnected, it could just be informational. In fact when the server first starts up, it immediately logs "<name> server running on stdio" to stderr.
this.appendErrorMessage(connection, errorOutput)
// Only need to update webview right away if it's already disconnected
if (connection.server.status === "disconnected") {
await this.notifyWebviewOfServerChanges()
}
}
})
} else {
console.error(`No stderr stream for ${name}`)
}
transport.start = async () => {} // No-op now, .connect() won't fail
// Connect
await client.connect(transport)
connection.server.status = "connected"
// Clears anything collected from stderr while the server was starting up
connection.server.error = ""
// Initial fetch of tools and resources
connection.server.tools = await this.fetchToolsList(name)
connection.server.resources = await this.fetchResourcesList(name)
connection.server.resourceTemplates = await this.fetchResourceTemplatesList(name)
} catch (error) {
// Update status with error
const connection = this.connections.find((conn) => conn.server.name === name)
if (connection) {
connection.server.status = "disconnected"
this.appendErrorMessage(connection, error instanceof Error ? error.message : String(error))
}
// Callers are expected to handle the failure themselves
throw error
}
}
/**
 * Adds `error` to the server's accumulated error text, separated by a newline
 * from anything already recorded.
 *
 * @param connection Connection whose server entry is updated in place
 * @param error Message to append
 */
private appendErrorMessage(connection: McpConnection, error: string) {
	const { server } = connection
	// No length cap is applied to the accumulated text
	server.error = server.error ? `${server.error}\n${error}` : error
}
/**
 * Requests the tool list from the named server and flags each tool with `alwaysAllow`
 * according to the `alwaysAllow` list stored for that server in the MCP settings file.
 *
 * @param serverName Name of the server whose connection is queried.
 * @returns The reported tools with their `alwaysAllow` flag; an empty array when no
 * connection with that name exists or the request fails.
 */
private async fetchToolsList(serverName: string): Promise<McpTool[]> {
	try {
		const connection = this.connections.find((conn) => conn.server.name === serverName)
		const response = await connection?.client.request({ method: "tools/list" }, ListToolsResultSchema)

		// Look up the always-allow list in its own try block: an unreadable or malformed
		// settings file must not discard the tools the server just reported.
		let alwaysAllowConfig: unknown[] = []
		try {
			const settingsPath = await this.getMcpSettingsFilePath()
			const content = await fs.readFile(settingsPath, "utf-8")
			const config = JSON.parse(content)
			const configured = config?.mcpServers?.[serverName]?.alwaysAllow
			// Only an array is a valid allow list; anything else is treated as "nothing allowed"
			if (Array.isArray(configured)) {
				alwaysAllowConfig = configured
			}
		} catch (error) {
			console.error(`Failed to read always allow settings for ${serverName}:`, error)
		}

		// Mark tools as always allowed based on settings
		const tools = (response?.tools || []).map((tool) => ({
			...tool,
			alwaysAllow: alwaysAllowConfig.includes(tool.name),
		}))
		console.log(`[MCP] Fetched tools for ${serverName}:`, tools)
		return tools
	} catch (error) {
		// A failed "tools/list" request is treated as "no tools" and is intentionally not logged
		return []
	}
}
/**
 * Asks the named server for its resources via a "resources/list" request.
 *
 * @param serverName Name of the server whose connection is queried.
 * @returns The reported resources, or an empty array when there is no such
 * connection, the response carries no resources, or the request fails.
 */
private async fetchResourcesList(serverName: string): Promise<McpResource[]> {
	try {
		const target = this.connections.find((conn) => conn.server.name === serverName)
		const listing = await target?.client.request({ method: "resources/list" }, ListResourcesResultSchema)
		return listing?.resources || []
	} catch {
		// Request failures are deliberately swallowed and reported as "no resources"
		return []
	}
}
/**
 * Asks the named server for its resource templates via a "resources/templates/list" request.
 *
 * @param serverName Name of the server whose connection is queried.
 * @returns The reported resource templates, or an empty array when there is no such
 * connection, the response carries no templates, or the request fails.
 */
private async fetchResourceTemplatesList(serverName: string): Promise<McpResourceTemplate[]> {
	try {
		const target = this.connections.find((conn) => conn.server.name === serverName)
		const listing = await target?.client.request(
			{ method: "resources/templates/list" },
			ListResourceTemplatesResultSchema,
		)
		return listing?.resourceTemplates || []
	} catch {
		// Request failures are deliberately swallowed and reported as "no templates"
		return []
	}
}
/**
 * Closes the transport and client of the named connection and drops it from
 * `this.connections`. Does nothing when no connection has that name.
 * A failure while closing is logged; the connection is removed regardless.
 */
async deleteConnection(name: string): Promise<void> {
	const target = this.connections.find((conn) => conn.server.name === name)
	if (!target) {
		return
	}
	try {
		await target.transport.close()
		await target.client.close()
	} catch (error) {
		console.error(`Failed to close transport for ${name}:`, error)
	}
	this.connections = this.connections.filter((conn) => conn.server.name !== name)
}
/**
 * Reconciles the live connections with a new server configuration map.
 *
 * - Servers no longer present in `newServers` are disconnected.
 * - Servers that are new are connected.
 * - Servers whose stored config differs from the new one are reconnected.
 * - Servers with an unchanged config keep their connection.
 *
 * All file watchers are torn down first and recreated for every server in `newServers`.
 * The webview is notified once at the end. `isConnecting` is true for the duration of
 * the call and is always reset, even when an error propagates.
 *
 * @param newServers Map of server name to its configuration object.
 */
async updateServerConnections(newServers: Record<string, any>): Promise<void> {
	this.isConnecting = true
	try {
		this.removeAllFileWatchers()
		const currentNames = new Set(this.connections.map((conn) => conn.server.name))
		const newNames = new Set(Object.keys(newServers))

		// Delete removed servers
		for (const name of currentNames) {
			if (!newNames.has(name)) {
				await this.deleteConnection(name)
				console.log(`Deleted MCP server: ${name}`)
			}
		}

		// Update or add servers
		for (const [name, config] of Object.entries(newServers)) {
			const currentConnection = this.connections.find((conn) => conn.server.name === name)
			if (!currentConnection) {
				// New server
				try {
					this.setupFileWatcher(name, config)
					await this.connectToServer(name, config)
				} catch (error) {
					console.error(`Failed to connect to new MCP server ${name}:`, error)
				}
			} else if (!deepEqual(JSON.parse(currentConnection.server.config), config)) {
				// Existing server with changed config
				try {
					this.setupFileWatcher(name, config)
					await this.deleteConnection(name)
					await this.connectToServer(name, config)
					console.log(`Reconnected MCP server with updated config: ${name}`)
				} catch (error) {
					console.error(`Failed to reconnect MCP server ${name}:`, error)
				}
			} else {
				// Unchanged server: the connection stays as is, but its watcher was closed by
				// removeAllFileWatchers() above and has to be recreated.
				this.setupFileWatcher(name, config)
			}
		}
		await this.notifyWebviewOfServerChanges()
	} finally {
		// Never leave the flag stuck when a step above throws
		this.isConnecting = false
	}
}
/**
 * Watches a server's build output and restarts the server when it changes.
 *
 * The watched file is the first string entry of `config.args` containing "build/index.js";
 * when there is none, no watcher is created. The watcher is registered in
 * `this.fileWatchers` under the server name, replacing (and closing) any previous one.
 *
 * @param name Server name used as the watcher key and for the restart.
 * @param config Server configuration; only its `args` array is inspected.
 */
private setupFileWatcher(name: string, config: any) {
	const args: unknown[] = Array.isArray(config?.args) ? config.args : []
	const filePath = args.find(
		(arg): arg is string => typeof arg === "string" && arg.includes("build/index.js"),
	)
	if (filePath) {
		// Close a watcher already registered under this name so it is not leaked when replaced
		void this.fileWatchers.get(name)?.close()
		// we use chokidar instead of onDidSaveTextDocument because it doesn't require the file to be open in the editor. The settings config is better suited for onDidSave since that will be manually updated by the user or Cline (and we want to detect save events, not every file change)
		const watcher = chokidar.watch(filePath, {
			// persistent: true,
			// ignoreInitial: true,
			// awaitWriteFinish: true, // This helps with atomic writes
		})
		watcher.on("change", () => {
			console.log(`Detected change in ${filePath}. Restarting server ${name}...`)
			// Fire-and-forget, but make sure a rejection does not surface as an unhandled rejection
			this.restartConnection(name).catch((error) => {
				console.error(`Failed to restart server ${name} after file change:`, error)
			})
		})
		this.fileWatchers.set(name, watcher)
	}
}
/**
 * Closes every registered file watcher and empties the watcher registry.
 */
private removeAllFileWatchers() {
	for (const watcher of this.fileWatchers.values()) {
		watcher.close()
	}
	this.fileWatchers.clear()
}
/**
 * Tears down and re-establishes the connection to a server using its stored config.
 *
 * Does nothing when the provider reference is gone. When the server has no stored
 * connection/config, only the webview is refreshed. `isConnecting` is true while the
 * restart runs and is always reset afterwards, including on errors.
 *
 * @param serverName Name of the server to restart.
 */
async restartConnection(serverName: string): Promise<void> {
	// Check the provider before raising the flag, so an early return cannot leave it stuck
	const provider = this.providerRef.deref()
	if (!provider) {
		return
	}
	this.isConnecting = true
	try {
		// Get existing connection and update its status
		const connection = this.connections.find((conn) => conn.server.name === serverName)
		const config = connection?.server.config
		if (connection && config) {
			vscode.window.showInformationMessage(`Restarting ${serverName} MCP server...`)
			connection.server.status = "connecting"
			connection.server.error = ""
			await this.notifyWebviewOfServerChanges()
			await delay(500) // artificial delay to show user that server is restarting
			try {
				await this.deleteConnection(serverName)
				// Try to connect again using existing config
				await this.connectToServer(serverName, JSON.parse(config))
				vscode.window.showInformationMessage(`${serverName} MCP server connected`)
			} catch (error) {
				console.error(`Failed to restart connection for ${serverName}:`, error)
				vscode.window.showErrorMessage(`Failed to connect to ${serverName} MCP server`)
			}
		}
		await this.notifyWebviewOfServerChanges()
	} finally {
		// Reset even when notifying the webview throws
		this.isConnecting = false
	}
}
/**
 * Posts the current server list to the webview as an "mcpServers" message.
 * Servers are ordered by the position of their name among the `mcpServers` keys of
 * the settings file; names not found there compare as index -1.
 */
private async notifyWebviewOfServerChanges(): Promise<void> {
	const settingsPath = await this.getMcpSettingsFilePath()
	const rawSettings = await fs.readFile(settingsPath, "utf-8")
	const definedOrder = Object.keys(JSON.parse(rawSettings).mcpServers || {})
	const rankOf = (conn: McpConnection): number => definedOrder.indexOf(conn.server.name)
	await this.providerRef.deref()?.postMessageToWebview({
		type: "mcpServers",
		// Sort a copy so `this.connections` keeps its own order
		mcpServers: [...this.connections]
			.sort((left, right) => rankOf(left) - rankOf(right))
			.map((conn) => conn.server),
	})
}
/**
 * Persists the `disabled` flag of a server in the MCP settings file and mirrors it on
 * the live connection.
 *
 * When the server is not present in the settings, nothing is written. Otherwise the
 * file is rewritten with only the `mcpServers` key, the connection (if any) gets the
 * new flag, its capabilities are re-fetched when it is connected, and the webview is
 * notified.
 *
 * @param serverName Name of the server to update.
 * @param disabled New value of the flag.
 * @throws Re-throws any failure after logging it and showing an error message.
 */
public async toggleServerDisabled(serverName: string, disabled: boolean): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (accessError) {
			console.error("Settings file not accessible:", accessError)
			throw new Error("Settings file not accessible")
		}

		const config = JSON.parse(await fs.readFile(settingsPath, "utf-8"))
		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}
		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}

		const existingEntry = config.mcpServers[serverName]
		if (!existingEntry) {
			// Unknown server: nothing to persist
			return
		}

		// Build a fresh entry carrying the new flag, and make sure alwaysAllow exists
		const serverConfig = { ...existingEntry, disabled }
		if (!serverConfig.alwaysAllow) {
			serverConfig.alwaysAllow = []
		}
		config.mcpServers[serverName] = serverConfig

		// Only the mcpServers key is written back
		await fs.writeFile(settingsPath, JSON.stringify({ mcpServers: config.mcpServers }, null, 2))

		const connection = this.connections.find((conn) => conn.server.name === serverName)
		if (connection) {
			try {
				connection.server.disabled = disabled
				// Only refresh capabilities if connected
				if (connection.server.status === "connected") {
					connection.server.tools = await this.fetchToolsList(serverName)
					connection.server.resources = await this.fetchResourcesList(serverName)
					connection.server.resourceTemplates = await this.fetchResourceTemplatesList(serverName)
				}
			} catch (refreshError) {
				console.error(`Failed to refresh capabilities for ${serverName}:`, refreshError)
			}
		}
		await this.notifyWebviewOfServerChanges()
	} catch (error) {
		console.error("Failed to update server disabled state:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to update server state: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
/**
 * Persists a new `timeout` value for a server in the MCP settings file.
 *
 * When the server is not present in the settings, nothing is written. Otherwise the
 * file is rewritten with only the `mcpServers` key and the webview is notified.
 *
 * @param serverName Name of the server to update.
 * @param timeout New timeout value stored verbatim on the server entry.
 * @throws Re-throws any failure after logging it and showing an error message.
 */
public async updateServerTimeout(serverName: string, timeout: number): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (accessError) {
			console.error("Settings file not accessible:", accessError)
			throw new Error("Settings file not accessible")
		}

		const config = JSON.parse(await fs.readFile(settingsPath, "utf-8"))
		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}
		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}

		const existingEntry = config.mcpServers[serverName]
		if (!existingEntry) {
			// Unknown server: nothing to persist
			return
		}

		// Replace the entry with a copy carrying the new timeout
		config.mcpServers[serverName] = { ...existingEntry, timeout }

		// Only the mcpServers key is written back
		await fs.writeFile(settingsPath, JSON.stringify({ mcpServers: config.mcpServers }, null, 2))
		await this.notifyWebviewOfServerChanges()
	} catch (error) {
		console.error("Failed to update server timeout:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to update server timeout: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
/**
 * Removes a server from the MCP settings file and reconciles the live connections.
 *
 * When the server is not present in the settings, a warning is shown and nothing is
 * written. Otherwise the file is rewritten with only the `mcpServers` key, connections
 * are updated from the remaining servers, and an information message is shown.
 *
 * @param serverName Name of the server to delete.
 * @throws Re-throws any failure after logging it and showing an error message.
 */
public async deleteServer(serverName: string): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (accessError) {
			throw new Error("Settings file not accessible")
		}

		const config = JSON.parse(await fs.readFile(settingsPath, "utf-8"))
		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}
		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}

		if (!config.mcpServers[serverName]) {
			vscode.window.showWarningMessage(`Server "${serverName}" not found in configuration`)
			return
		}

		// Remove the server from the settings and persist only the mcpServers key
		delete config.mcpServers[serverName]
		await fs.writeFile(settingsPath, JSON.stringify({ mcpServers: config.mcpServers }, null, 2))

		// Bring the live connections in line with the remaining servers
		await this.updateServerConnections(config.mcpServers)
		vscode.window.showInformationMessage(`Deleted MCP server: ${serverName}`)
	} catch (error) {
		console.error("Failed to delete MCP server:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to delete MCP server: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
/**
 * Reads a resource from a server via a "resources/read" request.
 *
 * @param serverName Name of the server to query.
 * @param uri URI of the resource to read.
 * @returns The server's response, validated against `ReadResourceResultSchema`.
 * @throws Error when no connection exists for the server or the server is disabled.
 */
async readResource(serverName: string, uri: string): Promise<McpResourceResponse> {
	const target = this.connections.find((conn) => conn.server.name === serverName)
	if (!target) {
		throw new Error(`No connection found for server: ${serverName}`)
	}
	if (target.server.disabled) {
		throw new Error(`Server "${serverName}" is disabled`)
	}
	return await target.client.request({ method: "resources/read", params: { uri } }, ReadResourceResultSchema)
}
/**
 * Invokes a tool on a server via a "tools/call" request.
 *
 * The request timeout is the server config's `timeout` (default 60) multiplied by 1000;
 * when the stored config cannot be parsed, 60 * 1000 is used.
 *
 * @param serverName Name of the server hosting the tool.
 * @param toolName Name of the tool to call.
 * @param toolArguments Optional arguments forwarded to the tool.
 * @returns The server's response, validated against `CallToolResultSchema`.
 * @throws Error when no connection exists for the server or the server is disabled.
 */
async callTool(
	serverName: string,
	toolName: string,
	toolArguments?: Record<string, unknown>,
): Promise<McpToolCallResponse> {
	const target = this.connections.find((conn) => conn.server.name === serverName)
	if (!target) {
		throw new Error(
			`No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
		)
	}
	if (target.server.disabled) {
		throw new Error(`Server "${serverName}" is disabled and cannot be used`)
	}

	// Start from the 60 second default and override it when the stored config parses
	let timeout = 60 * 1000
	try {
		const parsedConfig = StdioConfigSchema.parse(JSON.parse(target.server.config))
		timeout = (parsedConfig.timeout ?? 60) * 1000
	} catch (error) {
		console.error("Failed to parse server config for timeout:", error)
	}

	return await target.client.request(
		{
			method: "tools/call",
			params: {
				name: toolName,
				arguments: toolArguments,
			},
		},
		CallToolResultSchema,
		{ timeout },
	)
}
/**
 * Adds a tool to, or removes it from, a server's `alwaysAllow` list in the MCP
 * settings file, then refreshes the server's tool list and notifies the webview
 * when a connection for that server exists.
 *
 * @param serverName Name of the server owning the tool.
 * @param toolName Name of the tool to toggle.
 * @param shouldAllow True to add the tool to the list, false to remove it.
 * @throws Re-throws any failure after logging it and showing an error message.
 */
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		const config = JSON.parse(await fs.readFile(settingsPath, "utf-8"))

		// Initialize alwaysAllow if it doesn't exist
		const serverEntry = config.mcpServers[serverName]
		if (!serverEntry.alwaysAllow) {
			serverEntry.alwaysAllow = []
		}
		const allowList = serverEntry.alwaysAllow
		const position = allowList.indexOf(toolName)
		const isListed = position !== -1

		if (shouldAllow) {
			// Add only when not already present
			if (!isListed) {
				allowList.push(toolName)
			}
		} else if (isListed) {
			// Remove the existing entry
			allowList.splice(position, 1)
		}

		// Write updated config back to file
		await fs.writeFile(settingsPath, JSON.stringify(config, null, 2))

		// Update the tools list to reflect the change
		const connection = this.connections.find((conn) => conn.server.name === serverName)
		if (connection) {
			connection.server.tools = await this.fetchToolsList(serverName)
			await this.notifyWebviewOfServerChanges()
		}
	} catch (error) {
		console.error("Failed to update always allow settings:", error)
		vscode.window.showErrorMessage("Failed to update always allow settings")
		throw error // Re-throw to ensure the error is properly handled
	}
}
/**
 * Releases everything this hub holds: file watchers, all server connections,
 * the settings watcher (when set) and the registered disposables.
 */
async dispose(): Promise<void> {
	this.removeAllFileWatchers()
	for (const { server } of this.connections) {
		try {
			await this.deleteConnection(server.name)
		} catch (error) {
			console.error(`Failed to close connection for ${server.name}:`, error)
		}
	}
	this.connections = []
	this.settingsWatcher?.dispose()
	this.disposables.forEach((disposable) => disposable.dispose())
}
}

View File

@@ -0,0 +1,83 @@
import * as vscode from "vscode"
import { ClineProvider } from "../../core/webview/ClineProvider"
import { McpHub } from "./McpHub"
/**
 * Singleton manager for MCP server instances.
 * Ensures only one set of MCP servers runs across all webviews.
 */
export class McpServerManager {
	/** The shared hub, or null until the first successful `getInstance` call. */
	private static instance: McpHub | null = null
	/** Global-state key under which a marker for the primary instance is stored. */
	private static readonly GLOBAL_STATE_KEY = "mcpHubInstanceId"
	/** Providers registered through `getInstance` and not yet unregistered. */
	private static providers: Set<ClineProvider> = new Set()
	/** In-flight initialization shared by concurrent `getInstance` callers. */
	private static initializationPromise: Promise<McpHub> | null = null

	/**
	 * Get the singleton McpHub instance, creating it if it doesn't exist.
	 * Concurrent callers share a single in-flight initialization promise.
	 *
	 * @param context Extension context whose global state records the instance marker.
	 * @param provider Provider to register; also passed to the hub when it is created.
	 * @returns The shared hub.
	 * @throws Whatever hub construction or the global-state update throws; a failed
	 * initialization is not cached, so a later call retries.
	 */
	static async getInstance(context: vscode.ExtensionContext, provider: ClineProvider): Promise<McpHub> {
		// Register the provider
		this.providers.add(provider)

		// If we already have an instance, return it
		if (this.instance) {
			return this.instance
		}

		// If initialization is in progress, wait for it
		if (this.initializationPromise) {
			return this.initializationPromise
		}

		const initialization: Promise<McpHub> = (async () => {
			// Double-check instance in case it was created in the meantime
			if (!this.instance) {
				this.instance = new McpHub(provider)
				// Store a unique identifier in global state to track the primary instance
				await context.globalState.update(this.GLOBAL_STATE_KEY, Date.now().toString())
			}
			return this.instance
		})()
		this.initializationPromise = initialization

		try {
			return await initialization
		} finally {
			// Clear the lock only AFTER it has been assigned. An async function body runs
			// synchronously up to its first await, so clearing it from inside the function
			// would be overwritten by the assignment above when construction throws
			// synchronously, caching a rejected promise forever.
			if (this.initializationPromise === initialization) {
				this.initializationPromise = null
			}
		}
	}

	/**
	 * Remove a provider from the tracked set.
	 * This is called when a webview is disposed.
	 */
	static unregisterProvider(provider: ClineProvider): void {
		this.providers.delete(provider)
	}

	/**
	 * Notify all registered providers of server state changes.
	 * A failure to reach one provider is logged and does not affect the others.
	 */
	static notifyProviders(message: any): void {
		this.providers.forEach((provider) => {
			provider.postMessageToWebview(message).catch((error) => {
				console.error("Failed to notify provider:", error)
			})
		})
	}

	/**
	 * Clean up the singleton instance and all its resources, clear the global-state
	 * marker, and forget every registered provider.
	 */
	static async cleanup(context: vscode.ExtensionContext): Promise<void> {
		if (this.instance) {
			await this.instance.dispose()
			this.instance = null
			await context.globalState.update(this.GLOBAL_STATE_KEY, undefined)
		}
		this.providers.clear()
	}
}

View File

@@ -0,0 +1,206 @@
// import * as vscode from "vscode"
import * as childProcess from "child_process"
import * as fs from "fs"
import * as path from "path"
import * as readline from "readline"
// The ripgrep executable carries an .exe suffix when the platform name starts with "win"
const isWindows = process.platform.startsWith("win")
const binName = isWindows ? "rg.exe" : "rg"
/**
 * One ripgrep match together with its surrounding context lines,
 * as assembled by regexSearchFiles from ripgrep's JSON output.
 */
interface SearchResult {
// Path of the file containing the match, as reported by ripgrep
file: string
// Line number of the match, as reported by ripgrep
line: number
// Start offset of the first submatch within the line, as reported by ripgrep
column: number
// Text of the matching line, truncated to the maximum line length
match: string
// Context lines attributed to the lines before the match (each truncated)
beforeContext: string[]
// Context lines attributed to the lines after the match (each truncated)
afterContext: string[]
}
// Limits applied to search output
const MAX_RESULTS = 300
const MAX_LINE_LENGTH = 500

/**
 * Caps a line at a maximum length.
 * @param line The text to cap
 * @param maxLength Longest length returned unmodified (defaults to MAX_LINE_LENGTH)
 * @returns `line` itself when it fits; otherwise its first `maxLength` characters
 * followed by the marker " [truncated...]"
 */
export function truncateLine(line: string, maxLength: number = MAX_LINE_LENGTH): string {
	if (line.length <= maxLength) {
		return line
	}
	const kept = line.substring(0, maxLength)
	return `${kept} [truncated...]`
}
/**
 * Locates the ripgrep executable.
 *
 * Directories are probed in order: "/opt/homebrew/bin/", "/usr/local/bin/", "/usr/bin/",
 * then every non-empty entry of the PATH environment variable. The first directory that
 * contains the platform-specific binary name wins.
 *
 * @returns Full path to the ripgrep binary, or undefined when none of the directories has it.
 */
async function getBinPath(): Promise<string | undefined> {
	const pathDirs = (process.env.PATH ?? "").split(path.delimiter).filter((dir) => dir.length > 0)
	// The original Homebrew location stays first so existing setups resolve exactly as before
	const candidateDirs = ["/opt/homebrew/bin/", "/usr/local/bin/", "/usr/bin/", ...pathDirs]
	for (const dir of candidateDirs) {
		const binPath = path.join(dir, binName)
		// Probe sequentially: order expresses priority
		if (await pathExists(binPath)) {
			return binPath
		}
	}
	return undefined
}
/**
 * Reports whether a filesystem path is accessible.
 * @param path Path to probe with `fs.access`
 * @returns true when `fs.access` reports no error, false otherwise
 */
async function pathExists(path: string): Promise<boolean> {
	const probe = (done: (exists: boolean) => void): void => {
		fs.access(path, (err) => done(err === null))
	}
	return new Promise<boolean>(probe)
}
/**
 * Runs ripgrep and collects its standard output line by line.
 *
 * At most MAX_RESULTS * 5 lines are kept; once that budget is used up the reader is
 * closed and the process is killed. The budget assumes each result spans at most 5 lines.
 *
 * @param bin Path of the ripgrep executable
 * @param args Command-line arguments passed to ripgrep
 * @returns The kept output lines, each terminated by "\n"
 * @throws Error when the process emits an "error" event, or when anything was written
 * to stderr by the time the line reader closes
 */
async function execRipgrep(bin: string, args: string[]): Promise<string> {
	return new Promise((resolve, reject) => {
		const child = childProcess.spawn(bin, args)
		// cross-platform alternative to head, which is ripgrep author's recommendation for limiting output.
		const lineReader = readline.createInterface({
			input: child.stdout,
			// treat \r\n as a single line break even if it's split across chunks
			crlfDelay: Infinity,
		})
		const lineBudget = MAX_RESULTS * 5
		const keptLines: string[] = []
		let stderrText = ""

		child.stderr.on("data", (chunk) => {
			stderrText += chunk.toString()
		})
		child.on("error", (error) => {
			reject(new Error(`ripgrep process error: ${error.message}`))
		})

		lineReader.on("line", (line) => {
			if (keptLines.length >= lineBudget) {
				// Budget exhausted: stop reading and terminate ripgrep
				lineReader.close()
				child.kill()
				return
			}
			keptLines.push(line)
		})
		lineReader.on("close", () => {
			if (stderrText) {
				reject(new Error(`ripgrep process error: ${stderrText}`))
				return
			}
			resolve(keptLines.map((kept) => `${kept}\n`).join(""))
		})
	})
}
/**
 * Searches a directory tree with ripgrep and returns the formatted matches.
 *
 * ripgrep is run with JSON output and one line of context; the ".obsidian" and ".git"
 * directories are excluded. Each match is collected with the context lines that belong
 * to it: a context line directly following the current match in the same file is
 * "after" context, any other context line is held back and becomes "before" context of
 * the next match in the same file.
 *
 * @param directoryPath Directory to search; also the base for relative paths in the output
 * @param regex Pattern passed to ripgrep via `-e`
 * @returns The formatted results, or "No results found" when running ripgrep fails
 * @throws Error when the ripgrep binary cannot be located
 */
export async function regexSearchFiles(
	directoryPath: string,
	regex: string,
): Promise<string> {
	const rgPath = await getBinPath()
	if (!rgPath) {
		throw new Error("Could not find ripgrep binary")
	}

	// Number of context lines requested on each side of a match
	const contextLines = 1
	// Use --glob to exclude directories that should not be searched
	const args = [
		"--json",
		"-e",
		regex,
		"--glob",
		"!.obsidian/**", // exclude the .obsidian directory and all of its subdirectories
		"--glob",
		"!.git/**",
		"--context",
		String(contextLines),
		directoryPath,
	]

	let output: string
	try {
		output = await execRipgrep(rgPath, args)
	} catch (error) {
		console.error("Error executing ripgrep:", error)
		return "No results found"
	}

	const results: SearchResult[] = []
	let currentResult: SearchResult | null = null
	// Context lines seen since the last match that do not belong to that match
	let pendingContext: { file: string; line: number; text: string }[] = []

	for (const line of output.split("\n")) {
		if (!line) {
			continue
		}
		try {
			const parsed = JSON.parse(line)
			if (parsed.type === "match") {
				const file: string = parsed.data.path.text
				const lineNumber: number = parsed.data.line_number
				// Build the new result completely before touching any state, so a malformed
				// record is skipped without pushing the previous result twice
				const nextResult: SearchResult = {
					file,
					line: lineNumber,
					column: parsed.data.submatches[0].start,
					// Safety check: truncate extremely long lines to prevent excessive output
					match: truncateLine(parsed.data.lines.text),
					beforeContext: pendingContext
						.filter((ctx) => ctx.file === file && ctx.line < lineNumber)
						.map((ctx) => ctx.text),
					afterContext: [],
				}
				if (currentResult) {
					results.push(currentResult)
				}
				currentResult = nextResult
				pendingContext = []
			} else if (parsed.type === "context") {
				const file: string = parsed.data.path.text
				const lineNumber: number = parsed.data.line_number
				// Apply the same truncation logic to context lines
				const text = truncateLine(parsed.data.lines.text)
				const followsCurrent =
					currentResult !== null &&
					currentResult.file === file &&
					lineNumber > currentResult.line &&
					lineNumber <= currentResult.line + contextLines
				if (currentResult && followsCurrent) {
					currentResult.afterContext.push(text)
				} else {
					// Belongs to an upcoming match (or to none): hold it back
					pendingContext.push({ file, line: lineNumber, text })
				}
			}
		} catch (error) {
			console.error("Error parsing ripgrep output:", error)
		}
	}
	if (currentResult) {
		results.push(currentResult)
	}
	return formatResults(results, directoryPath)
}
/**
 * Renders search results as text grouped by file.
 *
 * The output starts with a summary line, then one section per file (path relative to
 * `cwd`) listing, for every result, its before-context, the match and its after-context.
 * Sections and results are delimited by "│----" lines. At most MAX_RESULTS results are shown.
 *
 * @param results Results to render
 * @param cwd Base directory used to relativize file paths
 * @returns The rendered text with surrounding whitespace trimmed
 */
function formatResults(results: SearchResult[], cwd: string): string {
	const summary =
		results.length >= MAX_RESULTS
			? `Showing first ${MAX_RESULTS} of ${MAX_RESULTS}+ results. Use a more specific search if necessary.\n\n`
			: `Found ${results.length === 1 ? "1 result" : `${results.length.toLocaleString()} results`}.\n\n`

	// Bucket the shown results by their path relative to cwd
	const byFile: { [key: string]: SearchResult[] } = {}
	for (const hit of results.slice(0, MAX_RESULTS)) {
		const relativePath = path.relative(cwd, hit.file)
		if (!byFile[relativePath]) {
			byFile[relativePath] = []
		}
		byFile[relativePath].push(hit)
	}

	const pieces: string[] = [summary]
	for (const [filePath, hits] of Object.entries(byFile)) {
		pieces.push(`${filePath.toPosix()}\n│----\n`)
		const lastIndex = hits.length - 1
		hits.forEach((hit, position) => {
			for (const text of [...hit.beforeContext, hit.match, ...hit.afterContext]) {
				pieces.push(`${text?.trimEnd() ?? ""}\n`)
			}
			// Separator between consecutive results of the same file
			if (position < lastIndex) {
				pieces.push("│----\n")
			}
		})
		pieces.push("│----\n\n")
	}
	return pieces.join("").trim()
}

View File