add tool use, update system prompt
This commit is contained in:
22
src/core/diff/DiffStrategy.ts
Normal file
22
src/core/diff/DiffStrategy.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import type { DiffStrategy } from "./types"
|
||||
import { UnifiedDiffStrategy } from "./strategies/unified"
|
||||
import { SearchReplaceDiffStrategy } from "./strategies/search-replace"
|
||||
import { NewUnifiedDiffStrategy } from "./strategies/new-unified"
|
||||
/**
|
||||
* Get the appropriate diff strategy for the given model
|
||||
* @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus')
|
||||
* @returns The appropriate diff strategy for the model
|
||||
*/
|
||||
export function getDiffStrategy(
|
||||
model: string,
|
||||
fuzzyMatchThreshold?: number,
|
||||
experimentalDiffStrategy: boolean = false,
|
||||
): DiffStrategy {
|
||||
if (experimentalDiffStrategy) {
|
||||
return new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
|
||||
}
|
||||
return new SearchReplaceDiffStrategy(fuzzyMatchThreshold)
|
||||
}
|
||||
|
||||
export type { DiffStrategy }
|
||||
export { UnifiedDiffStrategy, SearchReplaceDiffStrategy }
|
||||
31
src/core/diff/insert-groups.ts
Normal file
31
src/core/diff/insert-groups.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Inserts multiple groups of elements at specified indices in an array
|
||||
* @param original Array to insert into, split by lines
|
||||
* @param insertGroups Array of groups to insert, each with an index and elements to insert
|
||||
* @returns New array with all insertions applied
|
||||
*/
|
||||
export interface InsertGroup {
|
||||
index: number
|
||||
elements: string[]
|
||||
}
|
||||
|
||||
export function insertGroups(original: string[], insertGroups: InsertGroup[]): string[] {
|
||||
// Sort groups by index to maintain order
|
||||
insertGroups.sort((a, b) => a.index - b.index)
|
||||
|
||||
let result: string[] = []
|
||||
let lastIndex = 0
|
||||
|
||||
insertGroups.forEach(({ index, elements }) => {
|
||||
// Add elements from original array up to insertion point
|
||||
result.push(...original.slice(lastIndex, index))
|
||||
// Add the group of elements
|
||||
result.push(...elements)
|
||||
lastIndex = index
|
||||
})
|
||||
|
||||
// Add remaining elements from original array
|
||||
result.push(...original.slice(lastIndex))
|
||||
|
||||
return result
|
||||
}
|
||||
738
src/core/diff/strategies/__tests__/new-unified.test.ts
Normal file
738
src/core/diff/strategies/__tests__/new-unified.test.ts
Normal file
@@ -0,0 +1,738 @@
|
||||
import { NewUnifiedDiffStrategy } from "../new-unified"
|
||||
|
||||
describe("main", () => {
|
||||
let strategy: NewUnifiedDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new NewUnifiedDiffStrategy(0.97)
|
||||
})
|
||||
|
||||
describe("constructor", () => {
|
||||
it("should use default confidence threshold when not provided", () => {
|
||||
const defaultStrategy = new NewUnifiedDiffStrategy()
|
||||
expect(defaultStrategy["confidenceThreshold"]).toBe(1)
|
||||
})
|
||||
|
||||
it("should use provided confidence threshold", () => {
|
||||
const customStrategy = new NewUnifiedDiffStrategy(0.85)
|
||||
expect(customStrategy["confidenceThreshold"]).toBe(0.85)
|
||||
})
|
||||
|
||||
it("should enforce minimum confidence threshold", () => {
|
||||
const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8
|
||||
expect(lowStrategy["confidenceThreshold"]).toBe(0.8)
|
||||
})
|
||||
})
|
||||
|
||||
describe("getToolDescription", () => {
|
||||
it("should return tool description with correct cwd", () => {
|
||||
const cwd = "/test/path"
|
||||
const description = strategy.getToolDescription({ cwd })
|
||||
|
||||
expect(description).toContain("apply_diff Tool - Generate Precise Code Changes")
|
||||
expect(description).toContain(cwd)
|
||||
expect(description).toContain("Step-by-Step Instructions")
|
||||
expect(description).toContain("Requirements")
|
||||
expect(description).toContain("Examples")
|
||||
expect(description).toContain("Parameters:")
|
||||
})
|
||||
})
|
||||
|
||||
it("should apply simple diff correctly", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3`
|
||||
|
||||
const diff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
+new line
|
||||
line2
|
||||
-line3
|
||||
+modified line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
new line
|
||||
line2
|
||||
modified line3`)
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle multiple hunks", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3
|
||||
line4
|
||||
line5`
|
||||
|
||||
const diff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
+new line
|
||||
line2
|
||||
-line3
|
||||
+modified line3
|
||||
@@ ... @@
|
||||
line4
|
||||
-line5
|
||||
+modified line5
|
||||
+new line at end`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
new line
|
||||
line2
|
||||
modified line3
|
||||
line4
|
||||
modified line5
|
||||
new line at end`)
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle complex large", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3
|
||||
line4
|
||||
line5
|
||||
line6
|
||||
line7
|
||||
line8
|
||||
line9
|
||||
line10`
|
||||
|
||||
const diff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
+header line
|
||||
+another header
|
||||
line2
|
||||
-line3
|
||||
-line4
|
||||
+modified line3
|
||||
+modified line4
|
||||
+extra line
|
||||
@@ ... @@
|
||||
line6
|
||||
+middle section
|
||||
line7
|
||||
-line8
|
||||
+changed line8
|
||||
+bonus line
|
||||
@@ ... @@
|
||||
line9
|
||||
-line10
|
||||
+final line
|
||||
+very last line`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
header line
|
||||
another header
|
||||
line2
|
||||
modified line3
|
||||
modified line4
|
||||
extra line
|
||||
line5
|
||||
line6
|
||||
middle section
|
||||
line7
|
||||
changed line8
|
||||
bonus line
|
||||
line9
|
||||
final line
|
||||
very last line`)
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle indentation changes", async () => {
|
||||
const original = `first line
|
||||
indented line
|
||||
double indented line
|
||||
back to single indent
|
||||
no indent
|
||||
indented again
|
||||
double indent again
|
||||
triple indent
|
||||
back to single
|
||||
last line`
|
||||
|
||||
const diff = `--- original
|
||||
+++ modified
|
||||
@@ ... @@
|
||||
first line
|
||||
indented line
|
||||
+ tab indented line
|
||||
+ new indented line
|
||||
double indented line
|
||||
back to single indent
|
||||
no indent
|
||||
indented again
|
||||
double indent again
|
||||
- triple indent
|
||||
+ hi there mate
|
||||
back to single
|
||||
last line`
|
||||
|
||||
const expected = `first line
|
||||
indented line
|
||||
tab indented line
|
||||
new indented line
|
||||
double indented line
|
||||
back to single indent
|
||||
no indent
|
||||
indented again
|
||||
double indent again
|
||||
hi there mate
|
||||
back to single
|
||||
last line`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle high level edits", async () => {
|
||||
const original = `def factorial(n):
|
||||
if n == 0:
|
||||
return 1
|
||||
else:
|
||||
return n * factorial(n-1)`
|
||||
const diff = `@@ ... @@
|
||||
-def factorial(n):
|
||||
- if n == 0:
|
||||
- return 1
|
||||
- else:
|
||||
- return n * factorial(n-1)
|
||||
+def factorial(number):
|
||||
+ if number == 0:
|
||||
+ return 1
|
||||
+ else:
|
||||
+ return number * factorial(number-1)`
|
||||
|
||||
const expected = `def factorial(number):
|
||||
if number == 0:
|
||||
return 1
|
||||
else:
|
||||
return number * factorial(number-1)`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
|
||||
it("it should handle very complex edits", async () => {
|
||||
const original = `//Initialize the array that will hold the primes
|
||||
var primeArray = [];
|
||||
/*Write a function that checks for primeness and
|
||||
pushes those values to t*he array*/
|
||||
function PrimeCheck(candidate){
|
||||
isPrime = true;
|
||||
for(var i = 2; i < candidate && isPrime; i++){
|
||||
if(candidate%i === 0){
|
||||
isPrime = false;
|
||||
} else {
|
||||
isPrime = true;
|
||||
}
|
||||
}
|
||||
if(isPrime){
|
||||
primeArray.push(candidate);
|
||||
}
|
||||
return primeArray;
|
||||
}
|
||||
/*Write the code that runs the above until the
|
||||
l ength of the array equa*ls the number of primes
|
||||
desired*/
|
||||
|
||||
var numPrimes = prompt("How many primes?");
|
||||
|
||||
//Display the finished array of primes
|
||||
|
||||
//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
|
||||
for (var i = 2; primeArray.length < numPrimes; i++) {
|
||||
PrimeCheck(i); //
|
||||
}
|
||||
console.log(primeArray);
|
||||
`
|
||||
|
||||
const diff = `--- test_diff.js
|
||||
+++ test_diff.js
|
||||
@@ ... @@
|
||||
-//Initialize the array that will hold the primes
|
||||
var primeArray = [];
|
||||
-/*Write a function that checks for primeness and
|
||||
- pushes those values to t*he array*/
|
||||
function PrimeCheck(candidate){
|
||||
isPrime = true;
|
||||
for(var i = 2; i < candidate && isPrime; i++){
|
||||
@@ ... @@
|
||||
return primeArray;
|
||||
}
|
||||
-/*Write the code that runs the above until the
|
||||
- l ength of the array equa*ls the number of primes
|
||||
- desired*/
|
||||
|
||||
var numPrimes = prompt("How many primes?");
|
||||
|
||||
-//Display the finished array of primes
|
||||
-
|
||||
-//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
|
||||
for (var i = 2; primeArray.length < numPrimes; i++) {
|
||||
- PrimeCheck(i); //
|
||||
+ PrimeCheck(i);
|
||||
}
|
||||
console.log(primeArray);`
|
||||
|
||||
const expected = `var primeArray = [];
|
||||
function PrimeCheck(candidate){
|
||||
isPrime = true;
|
||||
for(var i = 2; i < candidate && isPrime; i++){
|
||||
if(candidate%i === 0){
|
||||
isPrime = false;
|
||||
} else {
|
||||
isPrime = true;
|
||||
}
|
||||
}
|
||||
if(isPrime){
|
||||
primeArray.push(candidate);
|
||||
}
|
||||
return primeArray;
|
||||
}
|
||||
|
||||
var numPrimes = prompt("How many primes?");
|
||||
|
||||
for (var i = 2; primeArray.length < numPrimes; i++) {
|
||||
PrimeCheck(i);
|
||||
}
|
||||
console.log(primeArray);
|
||||
`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
|
||||
describe("error handling and edge cases", () => {
|
||||
it("should reject completely invalid diff format", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
const invalidDiff = "this is not a diff at all"
|
||||
|
||||
const result = await strategy.applyDiff(original, invalidDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it("should reject diff with invalid hunk format", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
const invalidHunkDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
invalid hunk header
|
||||
line1
|
||||
-line2
|
||||
+new line`
|
||||
|
||||
const result = await strategy.applyDiff(original, invalidHunkDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it("should fail when diff tries to modify non-existent content", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
const nonMatchingDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
-nonexistent line
|
||||
+new line
|
||||
line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, nonMatchingDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it("should handle overlapping hunks", async () => {
|
||||
const original = `line1
|
||||
line2
|
||||
line3
|
||||
line4
|
||||
line5`
|
||||
const overlappingDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
line2
|
||||
-line3
|
||||
+modified3
|
||||
line4
|
||||
@@ ... @@
|
||||
line2
|
||||
-line3
|
||||
-line4
|
||||
+modified3and4
|
||||
line5`
|
||||
|
||||
const result = await strategy.applyDiff(original, overlappingDiff)
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
it("should handle empty lines modifications", async () => {
|
||||
const original = `line1
|
||||
|
||||
line3
|
||||
|
||||
line5`
|
||||
const emptyLinesDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
|
||||
-line3
|
||||
+line3modified
|
||||
|
||||
line5`
|
||||
|
||||
const result = await strategy.applyDiff(original, emptyLinesDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`line1
|
||||
|
||||
line3modified
|
||||
|
||||
line5`)
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle mixed line endings in diff", async () => {
|
||||
const original = "line1\r\nline2\nline3\r\n"
|
||||
const mixedEndingsDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
line1\r
|
||||
-line2
|
||||
+modified2\r
|
||||
line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, mixedEndingsDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n")
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle partial line modifications", async () => {
|
||||
const original = "const value = oldValue + 123;"
|
||||
const partialDiff = `--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ ... @@
|
||||
-const value = oldValue + 123;
|
||||
+const value = newValue + 123;`
|
||||
|
||||
const result = await strategy.applyDiff(original, partialDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe("const value = newValue + 123;")
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle slightly malformed but recoverable diff", async () => {
|
||||
const original = "line1\nline2\nline3"
|
||||
// Missing space after --- and +++
|
||||
const slightlyBadDiff = `---a/file.txt
|
||||
+++b/file.txt
|
||||
@@ ... @@
|
||||
line1
|
||||
-line2
|
||||
+new line
|
||||
line3`
|
||||
|
||||
const result = await strategy.applyDiff(original, slightlyBadDiff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe("line1\nnew line\nline3")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe("similar code sections", () => {
|
||||
it("should correctly modify the right section when similar code exists", async () => {
|
||||
const original = `function add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
function subtract(a, b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
function multiply(a, b) {
|
||||
return a + b; // Bug here
|
||||
}`
|
||||
|
||||
const diff = `--- a/math.js
|
||||
+++ b/math.js
|
||||
@@ ... @@
|
||||
function multiply(a, b) {
|
||||
- return a + b; // Bug here
|
||||
+ return a * b;
|
||||
}`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`function add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
function subtract(a, b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
function multiply(a, b) {
|
||||
return a * b;
|
||||
}`)
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle multiple similar sections with correct context", async () => {
|
||||
const original = `if (condition) {
|
||||
doSomething();
|
||||
doSomething();
|
||||
doSomething();
|
||||
}
|
||||
|
||||
if (otherCondition) {
|
||||
doSomething();
|
||||
doSomething();
|
||||
doSomething();
|
||||
}`
|
||||
|
||||
const diff = `--- a/file.js
|
||||
+++ b/file.js
|
||||
@@ ... @@
|
||||
if (otherCondition) {
|
||||
doSomething();
|
||||
- doSomething();
|
||||
+ doSomethingElse();
|
||||
doSomething();
|
||||
}`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(`if (condition) {
|
||||
doSomething();
|
||||
doSomething();
|
||||
doSomething();
|
||||
}
|
||||
|
||||
if (otherCondition) {
|
||||
doSomething();
|
||||
doSomethingElse();
|
||||
doSomething();
|
||||
}`)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe("hunk splitting", () => {
|
||||
it("should handle large diffs with multiple non-contiguous changes", async () => {
|
||||
const original = `import { readFile } from 'fs';
|
||||
import { join } from 'path';
|
||||
import { Logger } from './logger';
|
||||
|
||||
const logger = new Logger();
|
||||
|
||||
async function processFile(filePath: string) {
|
||||
try {
|
||||
const data = await readFile(filePath, 'utf8');
|
||||
logger.info('File read successfully');
|
||||
return data;
|
||||
} catch (error) {
|
||||
logger.error('Failed to read file:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
function validateInput(input: string): boolean {
|
||||
if (!input) {
|
||||
logger.warn('Empty input received');
|
||||
return false;
|
||||
}
|
||||
return input.length > 0;
|
||||
}
|
||||
|
||||
async function writeOutput(data: string) {
|
||||
logger.info('Processing output');
|
||||
// TODO: Implement output writing
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
function parseConfig(configPath: string) {
|
||||
logger.debug('Reading config from:', configPath);
|
||||
// Basic config parsing
|
||||
return {
|
||||
enabled: true,
|
||||
maxRetries: 3
|
||||
};
|
||||
}
|
||||
|
||||
export {
|
||||
processFile,
|
||||
validateInput,
|
||||
writeOutput,
|
||||
parseConfig
|
||||
};`
|
||||
|
||||
const diff = `--- a/file.ts
|
||||
+++ b/file.ts
|
||||
@@ ... @@
|
||||
-import { readFile } from 'fs';
|
||||
+import { readFile, writeFile } from 'fs';
|
||||
import { join } from 'path';
|
||||
-import { Logger } from './logger';
|
||||
+import { Logger } from './utils/logger';
|
||||
+import { Config } from './types';
|
||||
|
||||
-const logger = new Logger();
|
||||
+const logger = new Logger('FileProcessor');
|
||||
|
||||
async function processFile(filePath: string) {
|
||||
try {
|
||||
const data = await readFile(filePath, 'utf8');
|
||||
- logger.info('File read successfully');
|
||||
+ logger.info(\`File \${filePath} read successfully\`);
|
||||
return data;
|
||||
} catch (error) {
|
||||
- logger.error('Failed to read file:', error);
|
||||
+ logger.error(\`Failed to read file \${filePath}:\`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
function validateInput(input: string): boolean {
|
||||
if (!input) {
|
||||
- logger.warn('Empty input received');
|
||||
+ logger.warn('Validation failed: Empty input received');
|
||||
return false;
|
||||
}
|
||||
- return input.length > 0;
|
||||
+ return input.trim().length > 0;
|
||||
}
|
||||
|
||||
-async function writeOutput(data: string) {
|
||||
- logger.info('Processing output');
|
||||
- // TODO: Implement output writing
|
||||
- return Promise.resolve();
|
||||
+async function writeOutput(data: string, outputPath: string) {
|
||||
+ try {
|
||||
+ await writeFile(outputPath, data, 'utf8');
|
||||
+ logger.info(\`Output written to \${outputPath}\`);
|
||||
+ } catch (error) {
|
||||
+ logger.error(\`Failed to write output to \${outputPath}:\`, error);
|
||||
+ throw error;
|
||||
+ }
|
||||
}
|
||||
|
||||
-function parseConfig(configPath: string) {
|
||||
- logger.debug('Reading config from:', configPath);
|
||||
- // Basic config parsing
|
||||
- return {
|
||||
- enabled: true,
|
||||
- maxRetries: 3
|
||||
- };
|
||||
+async function parseConfig(configPath: string): Promise<Config> {
|
||||
+ try {
|
||||
+ const configData = await readFile(configPath, 'utf8');
|
||||
+ logger.debug(\`Reading config from \${configPath}\`);
|
||||
+ return JSON.parse(configData);
|
||||
+ } catch (error) {
|
||||
+ logger.error(\`Failed to parse config from \${configPath}:\`, error);
|
||||
+ throw error;
|
||||
+ }
|
||||
}
|
||||
|
||||
export {
|
||||
processFile,
|
||||
validateInput,
|
||||
writeOutput,
|
||||
- parseConfig
|
||||
+ parseConfig,
|
||||
+ type Config
|
||||
};`
|
||||
|
||||
const expected = `import { readFile, writeFile } from 'fs';
|
||||
import { join } from 'path';
|
||||
import { Logger } from './utils/logger';
|
||||
import { Config } from './types';
|
||||
|
||||
const logger = new Logger('FileProcessor');
|
||||
|
||||
async function processFile(filePath: string) {
|
||||
try {
|
||||
const data = await readFile(filePath, 'utf8');
|
||||
logger.info(\`File \${filePath} read successfully\`);
|
||||
return data;
|
||||
} catch (error) {
|
||||
logger.error(\`Failed to read file \${filePath}:\`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
function validateInput(input: string): boolean {
|
||||
if (!input) {
|
||||
logger.warn('Validation failed: Empty input received');
|
||||
return false;
|
||||
}
|
||||
return input.trim().length > 0;
|
||||
}
|
||||
|
||||
async function writeOutput(data: string, outputPath: string) {
|
||||
try {
|
||||
await writeFile(outputPath, data, 'utf8');
|
||||
logger.info(\`Output written to \${outputPath}\`);
|
||||
} catch (error) {
|
||||
logger.error(\`Failed to write output to \${outputPath}:\`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async function parseConfig(configPath: string): Promise<Config> {
|
||||
try {
|
||||
const configData = await readFile(configPath, 'utf8');
|
||||
logger.debug(\`Reading config from \${configPath}\`);
|
||||
return JSON.parse(configData);
|
||||
} catch (error) {
|
||||
logger.error(\`Failed to parse config from \${configPath}:\`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export {
|
||||
processFile,
|
||||
validateInput,
|
||||
writeOutput,
|
||||
parseConfig,
|
||||
type Config
|
||||
};`
|
||||
|
||||
const result = await strategy.applyDiff(original, diff)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
1557
src/core/diff/strategies/__tests__/search-replace.test.ts
Normal file
1557
src/core/diff/strategies/__tests__/search-replace.test.ts
Normal file
File diff suppressed because it is too large
Load Diff
228
src/core/diff/strategies/__tests__/unified.test.ts
Normal file
228
src/core/diff/strategies/__tests__/unified.test.ts
Normal file
@@ -0,0 +1,228 @@
|
||||
import { UnifiedDiffStrategy } from "../unified"
|
||||
|
||||
describe("UnifiedDiffStrategy", () => {
|
||||
let strategy: UnifiedDiffStrategy
|
||||
|
||||
beforeEach(() => {
|
||||
strategy = new UnifiedDiffStrategy()
|
||||
})
|
||||
|
||||
describe("getToolDescription", () => {
|
||||
it("should return tool description with correct cwd", () => {
|
||||
const cwd = "/test/path"
|
||||
const description = strategy.getToolDescription({ cwd })
|
||||
|
||||
expect(description).toContain("apply_diff")
|
||||
expect(description).toContain(cwd)
|
||||
expect(description).toContain("Parameters:")
|
||||
expect(description).toContain("Format Requirements:")
|
||||
})
|
||||
})
|
||||
|
||||
describe("applyDiff", () => {
|
||||
it("should successfully apply a function modification diff", async () => {
|
||||
const originalContent = `import { Logger } from '../logger';
|
||||
|
||||
function calculateTotal(items: number[]): number {
|
||||
return items.reduce((sum, item) => {
|
||||
return sum + item;
|
||||
}, 0);
|
||||
}
|
||||
|
||||
export { calculateTotal };`
|
||||
|
||||
const diffContent = `--- src/utils/helper.ts
|
||||
+++ src/utils/helper.ts
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Logger } from '../logger';
|
||||
|
||||
function calculateTotal(items: number[]): number {
|
||||
- return items.reduce((sum, item) => {
|
||||
- return sum + item;
|
||||
+ const total = items.reduce((sum, item) => {
|
||||
+ return sum + item * 1.1; // Add 10% markup
|
||||
}, 0);
|
||||
+ return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||
}
|
||||
|
||||
export { calculateTotal };`
|
||||
|
||||
const expected = `import { Logger } from '../logger';
|
||||
|
||||
function calculateTotal(items: number[]): number {
|
||||
const total = items.reduce((sum, item) => {
|
||||
return sum + item * 1.1; // Add 10% markup
|
||||
}, 0);
|
||||
return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||
}
|
||||
|
||||
export { calculateTotal };`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
|
||||
it("should successfully apply a diff adding a new method", async () => {
|
||||
const originalContent = `class Calculator {
|
||||
add(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
}`
|
||||
|
||||
const diffContent = `--- src/Calculator.ts
|
||||
+++ src/Calculator.ts
|
||||
@@ -1,5 +1,9 @@
|
||||
class Calculator {
|
||||
add(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
+
|
||||
+ multiply(a: number, b: number): number {
|
||||
+ return a * b;
|
||||
+ }
|
||||
}`
|
||||
|
||||
const expected = `class Calculator {
|
||||
add(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
multiply(a: number, b: number): number {
|
||||
return a * b;
|
||||
}
|
||||
}`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
|
||||
it("should successfully apply a diff modifying imports", async () => {
|
||||
const originalContent = `import { useState } from 'react';
|
||||
import { Button } from './components';
|
||||
|
||||
function App() {
|
||||
const [count, setCount] = useState(0);
|
||||
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||
}`
|
||||
|
||||
const diffContent = `--- src/App.tsx
|
||||
+++ src/App.tsx
|
||||
@@ -1,7 +1,8 @@
|
||||
-import { useState } from 'react';
|
||||
+import { useState, useEffect } from 'react';
|
||||
import { Button } from './components';
|
||||
|
||||
function App() {
|
||||
const [count, setCount] = useState(0);
|
||||
+ useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
|
||||
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||
}`
|
||||
|
||||
const expected = `import { useState, useEffect } from 'react';
|
||||
import { Button } from './components';
|
||||
|
||||
function App() {
|
||||
const [count, setCount] = useState(0);
|
||||
useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
|
||||
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||
}`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
|
||||
it("should successfully apply a diff with multiple hunks", async () => {
|
||||
const originalContent = `import { readFile, writeFile } from 'fs';
|
||||
|
||||
function processFile(path: string) {
|
||||
readFile(path, 'utf8', (err, data) => {
|
||||
if (err) throw err;
|
||||
const processed = data.toUpperCase();
|
||||
writeFile(path, processed, (err) => {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export { processFile };`
|
||||
|
||||
const diffContent = `--- src/file-processor.ts
|
||||
+++ src/file-processor.ts
|
||||
@@ -1,12 +1,14 @@
|
||||
-import { readFile, writeFile } from 'fs';
|
||||
+import { promises as fs } from 'fs';
|
||||
+import { join } from 'path';
|
||||
|
||||
-function processFile(path: string) {
|
||||
- readFile(path, 'utf8', (err, data) => {
|
||||
- if (err) throw err;
|
||||
+async function processFile(path: string) {
|
||||
+ try {
|
||||
+ const data = await fs.readFile(join(__dirname, path), 'utf8');
|
||||
const processed = data.toUpperCase();
|
||||
- writeFile(path, processed, (err) => {
|
||||
- if (err) throw err;
|
||||
- });
|
||||
- });
|
||||
+ await fs.writeFile(join(__dirname, path), processed);
|
||||
+ } catch (error) {
|
||||
+ console.error('Failed to process file:', error);
|
||||
+ throw error;
|
||||
+ }
|
||||
}
|
||||
|
||||
export { processFile };`
|
||||
|
||||
const expected = `import { promises as fs } from 'fs';
|
||||
import { join } from 'path';
|
||||
|
||||
async function processFile(path: string) {
|
||||
try {
|
||||
const data = await fs.readFile(join(__dirname, path), 'utf8');
|
||||
const processed = data.toUpperCase();
|
||||
await fs.writeFile(join(__dirname, path), processed);
|
||||
} catch (error) {
|
||||
console.error('Failed to process file:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export { processFile };`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
|
||||
it("should handle empty original content", async () => {
|
||||
const originalContent = ""
|
||||
const diffContent = `--- empty.ts
|
||||
+++ empty.ts
|
||||
@@ -0,0 +1,3 @@
|
||||
+export function greet(name: string): string {
|
||||
+ return \`Hello, \${name}!\`;
|
||||
+}`
|
||||
|
||||
const expected = `export function greet(name: string): string {
|
||||
return \`Hello, \${name}!\`;
|
||||
}\n`
|
||||
|
||||
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.content).toBe(expected)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,295 @@
|
||||
import { applyContextMatching, applyDMP, applyGitFallback } from "../edit-strategies"
|
||||
import { Hunk } from "../types"
|
||||
|
||||
/**
 * Table-driven cases shared by the applyContextMatching and applyDMP suites.
 * Each case supplies a hunk (context/add/remove changes), the file content as
 * lines, the match position to apply at, the minimum expected confidence, and
 * the expected joined result. `strategies` lists which suites run the case.
 * NOTE(review): multi-line `content` values inside a single change (e.g.
 * "line2\nline3") deliberately exercise the strategies' line-count handling.
 */
const testCases = [
	{
		name: "should return original content if no match is found",
		hunk: {
			changes: [
				{ type: "context", content: "line1" },
				{ type: "add", content: "line2" },
			],
		} as Hunk,
		content: ["line1", "line3"],
		matchPosition: -1,
		expected: {
			confidence: 0,
			result: ["line1", "line3"],
		},
		expectedResult: "line1\nline3",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a simple add change",
		hunk: {
			changes: [
				{ type: "context", content: "line1" },
				{ type: "add", content: "line2" },
			],
		} as Hunk,
		content: ["line1", "line3"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: ["line1", "line2", "line3"],
		},
		expectedResult: "line1\nline2\nline3",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a simple remove change",
		hunk: {
			changes: [
				{ type: "context", content: "line1" },
				{ type: "remove", content: "line2" },
			],
		} as Hunk,
		content: ["line1", "line2", "line3"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: ["line1", "line3"],
		},
		expectedResult: "line1\nline3",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a simple context change",
		hunk: {
			changes: [{ type: "context", content: "line1" }],
		} as Hunk,
		content: ["line1", "line2", "line3"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: ["line1", "line2", "line3"],
		},
		expectedResult: "line1\nline2\nline3",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a multi-line add change",
		hunk: {
			changes: [
				{ type: "context", content: "line1" },
				{ type: "add", content: "line2\nline3" },
			],
		} as Hunk,
		content: ["line1", "line4"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: ["line1", "line2\nline3", "line4"],
		},
		expectedResult: "line1\nline2\nline3\nline4",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a multi-line remove change",
		hunk: {
			changes: [
				{ type: "context", content: "line1" },
				{ type: "remove", content: "line2\nline3" },
			],
		} as Hunk,
		content: ["line1", "line2", "line3", "line4"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: ["line1", "line4"],
		},
		expectedResult: "line1\nline4",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a multi-line context change",
		hunk: {
			changes: [
				{ type: "context", content: "line1" },
				{ type: "context", content: "line2\nline3" },
			],
		} as Hunk,
		content: ["line1", "line2", "line3", "line4"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: ["line1", "line2\nline3", "line4"],
		},
		expectedResult: "line1\nline2\nline3\nline4",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a change with indentation",
		hunk: {
			changes: [
				{ type: "context", content: " line1" },
				{ type: "add", content: " line2" },
			],
		} as Hunk,
		content: [" line1", " line3"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: [" line1", " line2", " line3"],
		},
		expectedResult: " line1\n line2\n line3",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a change with mixed indentation",
		hunk: {
			changes: [
				{ type: "context", content: "\tline1" },
				{ type: "add", content: " line2" },
			],
		} as Hunk,
		content: ["\tline1", " line3"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: ["\tline1", " line2", " line3"],
		},
		expectedResult: "\tline1\n line2\n line3",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a change with mixed indentation and multi-line",
		hunk: {
			changes: [
				{ type: "context", content: " line1" },
				{ type: "add", content: "\tline2\n line3" },
			],
		} as Hunk,
		content: [" line1", " line4"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: [" line1", "\tline2\n line3", " line4"],
		},
		expectedResult: " line1\n\tline2\n line3\n line4",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a complex change with mixed indentation and multi-line",
		hunk: {
			changes: [
				{ type: "context", content: " line1" },
				{ type: "remove", content: " line2" },
				{ type: "add", content: "\tline3\n line4" },
				{ type: "context", content: " line5" },
			],
		} as Hunk,
		content: [" line1", " line2", " line5", " line6"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: [" line1", "\tline3\n line4", " line5", " line6"],
		},
		expectedResult: " line1\n\tline3\n line4\n line5\n line6",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a complex change with mixed indentation and multi-line and context",
		hunk: {
			changes: [
				{ type: "context", content: " line1" },
				{ type: "remove", content: " line2" },
				{ type: "add", content: "\tline3\n line4" },
				{ type: "context", content: " line5" },
				{ type: "context", content: " line6" },
			],
		} as Hunk,
		content: [" line1", " line2", " line5", " line6", " line7"],
		matchPosition: 0,
		expected: {
			confidence: 1,
			result: [" line1", "\tline3\n line4", " line5", " line6", " line7"],
		},
		expectedResult: " line1\n\tline3\n line4\n line5\n line6\n line7",
		strategies: ["context", "dmp"],
	},
	{
		name: "should apply a complex change with mixed indentation and multi-line and context and a different match position",
		hunk: {
			changes: [
				{ type: "context", content: " line1" },
				{ type: "remove", content: " line2" },
				{ type: "add", content: "\tline3\n line4" },
				{ type: "context", content: " line5" },
				{ type: "context", content: " line6" },
			],
		} as Hunk,
		content: [" line0", " line1", " line2", " line5", " line6", " line7"],
		matchPosition: 1,
		expected: {
			confidence: 1,
			result: [" line0", " line1", "\tline3\n line4", " line5", " line6", " line7"],
		},
		expectedResult: " line0\n line1\n\tline3\n line4\n line5\n line6\n line7",
		strategies: ["context", "dmp"],
	},
]
|
||||
|
||||
describe("applyContextMatching", () => {
|
||||
testCases.forEach(({ name, hunk, content, matchPosition, expected, strategies, expectedResult }) => {
|
||||
if (!strategies?.includes("context")) {
|
||||
return
|
||||
}
|
||||
it(name, () => {
|
||||
const result = applyContextMatching(hunk, content, matchPosition)
|
||||
expect(result.result.join("\n")).toEqual(expectedResult)
|
||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||
expect(result.strategy).toBe("context")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("applyDMP", () => {
|
||||
testCases.forEach(({ name, hunk, content, matchPosition, expected, strategies, expectedResult }) => {
|
||||
if (!strategies?.includes("dmp")) {
|
||||
return
|
||||
}
|
||||
it(name, () => {
|
||||
const result = applyDMP(hunk, content, matchPosition)
|
||||
expect(result.result.join("\n")).toEqual(expectedResult)
|
||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||
expect(result.strategy).toBe("dmp")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("applyGitFallback", () => {
|
||||
it("should successfully apply changes using git operations", async () => {
|
||||
const hunk = {
|
||||
changes: [
|
||||
{ type: "context", content: "line1", indent: "" },
|
||||
{ type: "remove", content: "line2", indent: "" },
|
||||
{ type: "add", content: "new line2", indent: "" },
|
||||
{ type: "context", content: "line3", indent: "" },
|
||||
],
|
||||
} as Hunk
|
||||
|
||||
const content = ["line1", "line2", "line3"]
|
||||
const result = await applyGitFallback(hunk, content)
|
||||
|
||||
expect(result.result.join("\n")).toEqual("line1\nnew line2\nline3")
|
||||
expect(result.confidence).toBe(1)
|
||||
expect(result.strategy).toBe("git-fallback")
|
||||
})
|
||||
|
||||
it("should return original content with 0 confidence when changes cannot be applied", async () => {
|
||||
const hunk = {
|
||||
changes: [
|
||||
{ type: "context", content: "nonexistent", indent: "" },
|
||||
{ type: "add", content: "new line", indent: "" },
|
||||
],
|
||||
} as Hunk
|
||||
|
||||
const content = ["line1", "line2", "line3"]
|
||||
const result = await applyGitFallback(hunk, content)
|
||||
|
||||
expect(result.result).toEqual(content)
|
||||
expect(result.confidence).toBe(0)
|
||||
expect(result.strategy).toBe("git-fallback")
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,262 @@
|
||||
import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMatch } from "../search-strategies"
|
||||
|
||||
/**
 * Common shape of the search strategy functions under test: given a search
 * block, the file content split into lines, and an optional starting line,
 * return the best match index (-1 for none), a confidence score, and the
 * name of the strategy that produced the match.
 *
 * NOTE(review): this alias is not referenced anywhere in the visible tests —
 * confirm it is used (or intended as documentation) before removing.
 */
type SearchStrategy = (
	searchStr: string,
	content: string[],
	startIndex?: number,
) => {
	index: number
	confidence: number
	strategy: string
}
|
||||
|
||||
/**
 * Table-driven cases shared by the findExactMatch / findSimilarityMatch /
 * findLevenshteinMatch suites. Each case gives the search string (possibly
 * multi-line, with significant leading/trailing whitespace), the content
 * lines, an optional startIndex, and the expected match index/confidence.
 * `strategies` lists which suites run the case — fuzzy strategies skip the
 * "not found" cases they would (by design) still match.
 */
const testCases = [
	{
		name: "should return no match if the search string is not found",
		searchStr: "not found",
		content: ["line1", "line2", "line3"],
		expected: { index: -1, confidence: 0 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match if the search string is found",
		searchStr: "line2",
		content: ["line1", "line2", "line3"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match with correct index when startIndex is provided",
		searchStr: "line3",
		content: ["line1", "line2", "line3", "line4", "line3"],
		startIndex: 3,
		expected: { index: 4, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match even if there are more lines in content",
		searchStr: "line2",
		content: ["line1", "line2", "line3", "line4", "line5"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match even if the search string is at the beginning of the content",
		searchStr: "line1",
		content: ["line1", "line2", "line3"],
		expected: { index: 0, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match even if the search string is at the end of the content",
		searchStr: "line3",
		content: ["line1", "line2", "line3"],
		expected: { index: 2, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match for a multi-line search string",
		searchStr: "line2\nline3",
		content: ["line1", "line2", "line3", "line4"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return no match if a multi-line search string is not found",
		searchStr: "line2\nline4",
		content: ["line1", "line2", "line3", "line4"],
		expected: { index: -1, confidence: 0 },
		strategies: ["exact", "similarity"],
	},
	{
		name: "should return a match with indentation",
		searchStr: " line2",
		content: ["line1", " line2", "line3"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match with more complex indentation",
		searchStr: " line3",
		content: [" line1", " line2", " line3", " line4"],
		expected: { index: 2, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match with mixed indentation",
		searchStr: "\tline2",
		content: [" line1", "\tline2", " line3"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match with mixed indentation and multi-line",
		searchStr: " line2\n\tline3",
		content: ["line1", " line2", "\tline3", " line4"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return no match if mixed indentation and multi-line is not found",
		searchStr: " line2\n line4",
		content: ["line1", " line2", "\tline3", " line4"],
		expected: { index: -1, confidence: 0 },
		strategies: ["exact", "similarity"],
	},
	{
		name: "should return a match with leading and trailing spaces",
		searchStr: " line2 ",
		content: ["line1", " line2 ", "line3"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match with leading and trailing tabs",
		searchStr: "\tline2\t",
		content: ["line1", "\tline2\t", "line3"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match with mixed leading and trailing spaces and tabs",
		searchStr: " \tline2\t ",
		content: ["line1", " \tline2\t ", "line3"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return a match with mixed leading and trailing spaces and tabs and multi-line",
		searchStr: " \tline2\t \n line3 ",
		content: ["line1", " \tline2\t ", " line3 ", "line4"],
		expected: { index: 1, confidence: 1 },
		strategies: ["exact", "similarity", "levenshtein"],
	},
	{
		name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found",
		searchStr: " \tline2\t \n line4 ",
		content: ["line1", " \tline2\t ", " line3 ", "line4"],
		expected: { index: -1, confidence: 0 },
		strategies: ["exact", "similarity"],
	},
]
|
||||
|
||||
describe("findExactMatch", () => {
|
||||
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||
if (!strategies?.includes("exact")) {
|
||||
return
|
||||
}
|
||||
it(name, () => {
|
||||
const result = findExactMatch(searchStr, content, startIndex)
|
||||
expect(result.index).toBe(expected.index)
|
||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||
expect(result.strategy).toMatch(/exact(-overlapping)?/)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("findAnchorMatch", () => {
|
||||
const anchorTestCases = [
|
||||
{
|
||||
name: "should return no match if no anchors are found",
|
||||
searchStr: " \n \n ",
|
||||
content: ["line1", "line2", "line3"],
|
||||
expected: { index: -1, confidence: 0 },
|
||||
},
|
||||
{
|
||||
name: "should return no match if anchor positions cannot be validated",
|
||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||
content: [
|
||||
"different line 1",
|
||||
"different line 2",
|
||||
"different line 3",
|
||||
"another unique line",
|
||||
"context line 1",
|
||||
"context line 2",
|
||||
],
|
||||
expected: { index: -1, confidence: 0 },
|
||||
},
|
||||
{
|
||||
name: "should return a match if anchor positions can be validated",
|
||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||
content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"],
|
||||
expected: { index: 2, confidence: 1 },
|
||||
},
|
||||
{
|
||||
name: "should return a match with correct index when startIndex is provided",
|
||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||
content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"],
|
||||
startIndex: 3,
|
||||
expected: { index: 3, confidence: 1 },
|
||||
},
|
||||
{
|
||||
name: "should return a match even if there are more lines in content",
|
||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||
content: [
|
||||
"line1",
|
||||
"line2",
|
||||
"unique line",
|
||||
"context line 1",
|
||||
"context line 2",
|
||||
"line 6",
|
||||
"extra line 1",
|
||||
"extra line 2",
|
||||
],
|
||||
expected: { index: 2, confidence: 1 },
|
||||
},
|
||||
{
|
||||
name: "should return a match even if the anchor is at the beginning of the content",
|
||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||
content: ["unique line", "context line 1", "context line 2", "line 6"],
|
||||
expected: { index: 0, confidence: 1 },
|
||||
},
|
||||
{
|
||||
name: "should return a match even if the anchor is at the end of the content",
|
||||
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||
content: ["line1", "line2", "unique line", "context line 1", "context line 2"],
|
||||
expected: { index: 2, confidence: 1 },
|
||||
},
|
||||
{
|
||||
name: "should return no match if no valid anchor is found",
|
||||
searchStr: "non-unique line\ncontext line 1\ncontext line 2",
|
||||
content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"],
|
||||
expected: { index: -1, confidence: 0 },
|
||||
},
|
||||
]
|
||||
|
||||
anchorTestCases.forEach(({ name, searchStr, content, startIndex, expected }) => {
|
||||
it(name, () => {
|
||||
const result = findAnchorMatch(searchStr, content, startIndex)
|
||||
expect(result.index).toBe(expected.index)
|
||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||
expect(result.strategy).toBe("anchor")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("findSimilarityMatch", () => {
|
||||
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||
if (!strategies?.includes("similarity")) {
|
||||
return
|
||||
}
|
||||
it(name, () => {
|
||||
const result = findSimilarityMatch(searchStr, content, startIndex)
|
||||
expect(result.index).toBe(expected.index)
|
||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||
expect(result.strategy).toBe("similarity")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("findLevenshteinMatch", () => {
|
||||
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||
if (!strategies?.includes("levenshtein")) {
|
||||
return
|
||||
}
|
||||
it(name, () => {
|
||||
const result = findLevenshteinMatch(searchStr, content, startIndex)
|
||||
expect(result.index).toBe(expected.index)
|
||||
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||
expect(result.strategy).toBe("levenshtein")
|
||||
})
|
||||
})
|
||||
})
|
||||
297
src/core/diff/strategies/new-unified/edit-strategies.ts
Normal file
297
src/core/diff/strategies/new-unified/edit-strategies.ts
Normal file
@@ -0,0 +1,297 @@
|
||||
import { diff_match_patch } from "diff-match-patch"
|
||||
import { EditResult, Hunk } from "./types"
|
||||
import { getDMPSimilarity, validateEditResult } from "./search-strategies"
|
||||
import * as path from "path"
|
||||
import simpleGit, { SimpleGit } from "simple-git"
|
||||
import * as tmp from "tmp"
|
||||
import * as fs from "fs"
|
||||
|
||||
// Helper function to infer indentation - simplified version
|
||||
function inferIndentation(line: string, contextLines: string[], previousIndent: string = ""): string {
|
||||
// If the line has explicit indentation in the change, use it exactly
|
||||
const lineMatch = line.match(/^(\s+)/)
|
||||
if (lineMatch) {
|
||||
return lineMatch[1]
|
||||
}
|
||||
|
||||
// If we have context lines, use the indentation from the first context line
|
||||
const contextLine = contextLines[0]
|
||||
if (contextLine) {
|
||||
const contextMatch = contextLine.match(/^(\s+)/)
|
||||
if (contextMatch) {
|
||||
return contextMatch[1]
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to previous indent
|
||||
return previousIndent
|
||||
}
|
||||
|
||||
// Context matching edit strategy
|
||||
export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||
if (matchPosition === -1) {
|
||||
return { confidence: 0, result: content, strategy: "context" }
|
||||
}
|
||||
|
||||
const newResult = [...content.slice(0, matchPosition)]
|
||||
let sourceIndex = matchPosition
|
||||
|
||||
for (const change of hunk.changes) {
|
||||
if (change.type === "context") {
|
||||
// Use the original line from content if available
|
||||
if (sourceIndex < content.length) {
|
||||
newResult.push(content[sourceIndex])
|
||||
} else {
|
||||
const line = change.indent ? change.indent + change.content : change.content
|
||||
newResult.push(line)
|
||||
}
|
||||
sourceIndex++
|
||||
} else if (change.type === "add") {
|
||||
// Use exactly the indentation from the change
|
||||
const baseIndent = change.indent || ""
|
||||
|
||||
// Handle multi-line additions
|
||||
const lines = change.content.split("\n").map((line) => {
|
||||
// If the line already has indentation, preserve it relative to the base indent
|
||||
const lineIndentMatch = line.match(/^(\s*)(.*)/)
|
||||
if (lineIndentMatch) {
|
||||
const [, lineIndent, content] = lineIndentMatch
|
||||
// Only add base indent if the line doesn't already have it
|
||||
return lineIndent ? line : baseIndent + content
|
||||
}
|
||||
return baseIndent + line
|
||||
})
|
||||
|
||||
newResult.push(...lines)
|
||||
} else if (change.type === "remove") {
|
||||
// Handle multi-line removes by incrementing sourceIndex for each line
|
||||
const removedLines = change.content.split("\n").length
|
||||
sourceIndex += removedLines
|
||||
}
|
||||
}
|
||||
|
||||
// Append remaining content
|
||||
newResult.push(...content.slice(sourceIndex))
|
||||
|
||||
// Calculate confidence based on the actual changes
|
||||
const afterText = newResult.slice(matchPosition, newResult.length - (content.length - sourceIndex)).join("\n")
|
||||
|
||||
const confidence = validateEditResult(hunk, afterText)
|
||||
|
||||
return {
|
||||
confidence,
|
||||
result: newResult,
|
||||
strategy: "context",
|
||||
}
|
||||
}
|
||||
|
||||
// DMP edit strategy
|
||||
export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||
if (matchPosition === -1) {
|
||||
return { confidence: 0, result: content, strategy: "dmp" }
|
||||
}
|
||||
|
||||
const dmp = new diff_match_patch()
|
||||
|
||||
// Calculate total lines in before block accounting for multi-line content
|
||||
const beforeLineCount = hunk.changes
|
||||
.filter((change) => change.type === "context" || change.type === "remove")
|
||||
.reduce((count, change) => count + change.content.split("\n").length, 0)
|
||||
|
||||
// Build BEFORE block (context + removals)
|
||||
const beforeLines = hunk.changes
|
||||
.filter((change) => change.type === "context" || change.type === "remove")
|
||||
.map((change) => {
|
||||
if (change.originalLine) {
|
||||
return change.originalLine
|
||||
}
|
||||
return change.indent ? change.indent + change.content : change.content
|
||||
})
|
||||
|
||||
// Build AFTER block (context + additions)
|
||||
const afterLines = hunk.changes
|
||||
.filter((change) => change.type === "context" || change.type === "add")
|
||||
.map((change) => {
|
||||
if (change.originalLine) {
|
||||
return change.originalLine
|
||||
}
|
||||
return change.indent ? change.indent + change.content : change.content
|
||||
})
|
||||
|
||||
// Convert to text with proper line endings
|
||||
const beforeText = beforeLines.join("\n")
|
||||
const afterText = afterLines.join("\n")
|
||||
|
||||
// Create and apply patch
|
||||
const patch = dmp.patch_make(beforeText, afterText)
|
||||
const targetText = content.slice(matchPosition, matchPosition + beforeLineCount).join("\n")
|
||||
const [patchedText] = dmp.patch_apply(patch, targetText)
|
||||
|
||||
// Split result and preserve line endings
|
||||
const patchedLines = patchedText.split("\n")
|
||||
|
||||
// Construct final result
|
||||
const newResult = [
|
||||
...content.slice(0, matchPosition),
|
||||
...patchedLines,
|
||||
...content.slice(matchPosition + beforeLineCount),
|
||||
]
|
||||
|
||||
const confidence = validateEditResult(hunk, patchedText)
|
||||
|
||||
return {
|
||||
confidence,
|
||||
result: newResult,
|
||||
strategy: "dmp",
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Last-resort edit strategy: uses a throwaway git repository to apply the
 * hunk as a cherry-pick, letting git's own merge machinery resolve the edit
 * against the full content. Tries two commit orderings before giving up:
 *
 *   Strategy 1: commit original → search → replace, check out the original,
 *   then cherry-pick the replace commit onto it.
 *   Strategy 2: commit search → replace, check out search, commit the
 *   original on top, then cherry-pick the replace commit.
 *
 * Returns confidence 1 with the merged lines on success, or confidence 0
 * with the content unchanged if both strategies fail. The temp repository is
 * always removed in the finally block.
 */
export async function applyGitFallback(hunk: Hunk, content: string[]): Promise<EditResult> {
	let tmpDir: tmp.DirResult | undefined

	try {
		// unsafeCleanup lets removeCallback delete a non-empty directory.
		tmpDir = tmp.dirSync({ unsafeCleanup: true })
		const git: SimpleGit = simpleGit(tmpDir.name)

		await git.init()
		await git.addConfig("user.name", "Temp")
		await git.addConfig("user.email", "temp@example.com")

		const filePath = path.join(tmpDir.name, "file.txt")

		// BEFORE text: context + removed lines.
		// NOTE(review): unlike applyDMP, this concatenates change.indent +
		// change.content without a fallback — assumes indent is always set
		// ("" at minimum) on parsed changes; confirm against the parser.
		const searchLines = hunk.changes
			.filter((change) => change.type === "context" || change.type === "remove")
			.map((change) => change.originalLine || change.indent + change.content)

		// AFTER text: context + added lines.
		const replaceLines = hunk.changes
			.filter((change) => change.type === "context" || change.type === "add")
			.map((change) => change.originalLine || change.indent + change.content)

		const searchText = searchLines.join("\n")
		const replaceText = replaceLines.join("\n")
		const originalText = content.join("\n")

		// Strategy 1: original → search → replace, cherry-pick replace onto original.
		try {
			fs.writeFileSync(filePath, originalText)
			await git.add("file.txt")
			const originalCommit = await git.commit("original")
			console.log("Strategy 1 - Original commit:", originalCommit.commit)

			fs.writeFileSync(filePath, searchText)
			await git.add("file.txt")
			const searchCommit1 = await git.commit("search")
			console.log("Strategy 1 - Search commit:", searchCommit1.commit)

			fs.writeFileSync(filePath, replaceText)
			await git.add("file.txt")
			const replaceCommit = await git.commit("replace")
			console.log("Strategy 1 - Replace commit:", replaceCommit.commit)

			console.log("Strategy 1 - Attempting checkout of:", originalCommit.commit)
			await git.raw(["checkout", originalCommit.commit])
			try {
				// --minimal asks git to spend extra time producing the smallest diff.
				console.log("Strategy 1 - Attempting cherry-pick of:", replaceCommit.commit)
				await git.raw(["cherry-pick", "--minimal", replaceCommit.commit])

				const newText = fs.readFileSync(filePath, "utf-8")
				const newLines = newText.split("\n")
				return {
					confidence: 1,
					result: newLines,
					strategy: "git-fallback",
				}
			} catch (cherryPickError) {
				// A conflict here is expected when the hunk doesn't apply; fall
				// through to Strategy 2 rather than failing outright.
				console.error("Strategy 1 failed with merge conflict")
			}
		} catch (error) {
			console.error("Strategy 1 failed:", error)
		}

		// Strategy 2: search → replace, then commit the original on top of
		// search and cherry-pick replace onto it.
		try {
			// Re-init in the same directory (re-running init on an existing repo).
			await git.init()
			await git.addConfig("user.name", "Temp")
			await git.addConfig("user.email", "temp@example.com")

			fs.writeFileSync(filePath, searchText)
			await git.add("file.txt")
			const searchCommit = await git.commit("search")
			// commit hashes are sometimes reported with a "HEAD " prefix; strip it.
			const searchHash = searchCommit.commit.replace(/^HEAD /, "")
			console.log("Strategy 2 - Search commit:", searchHash)

			fs.writeFileSync(filePath, replaceText)
			await git.add("file.txt")
			const replaceCommit = await git.commit("replace")
			const replaceHash = replaceCommit.commit.replace(/^HEAD /, "")
			console.log("Strategy 2 - Replace commit:", replaceHash)

			console.log("Strategy 2 - Attempting checkout of:", searchHash)
			await git.raw(["checkout", searchHash])
			fs.writeFileSync(filePath, originalText)
			await git.add("file.txt")
			const originalCommit2 = await git.commit("original")
			console.log("Strategy 2 - Original commit:", originalCommit2.commit)

			try {
				console.log("Strategy 2 - Attempting cherry-pick of:", replaceHash)
				await git.raw(["cherry-pick", "--minimal", replaceHash])

				const newText = fs.readFileSync(filePath, "utf-8")
				const newLines = newText.split("\n")
				return {
					confidence: 1,
					result: newLines,
					strategy: "git-fallback",
				}
			} catch (cherryPickError) {
				console.error("Strategy 2 failed with merge conflict")
			}
		} catch (error) {
			console.error("Strategy 2 failed:", error)
		}

		// Both orderings failed: report zero confidence, content unchanged.
		console.error("Git fallback failed")
		return { confidence: 0, result: content, strategy: "git-fallback" }
	} catch (error) {
		console.error("Git fallback strategy failed:", error)
		return { confidence: 0, result: content, strategy: "git-fallback" }
	} finally {
		// Always remove the temp repository, success or failure.
		if (tmpDir) {
			tmpDir.removeCallback()
		}
	}
}
|
||||
|
||||
// Main edit function that tries strategies sequentially
|
||||
export async function applyEdit(
|
||||
hunk: Hunk,
|
||||
content: string[],
|
||||
matchPosition: number,
|
||||
confidence: number,
|
||||
confidenceThreshold: number = 0.97,
|
||||
): Promise<EditResult> {
|
||||
// Don't attempt regular edits if confidence is too low
|
||||
if (confidence < confidenceThreshold) {
|
||||
console.log(
|
||||
`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`,
|
||||
)
|
||||
return applyGitFallback(hunk, content)
|
||||
}
|
||||
|
||||
// Try each strategy in sequence until one succeeds
|
||||
const strategies = [
|
||||
{ name: "dmp", apply: () => applyDMP(hunk, content, matchPosition) },
|
||||
{ name: "context", apply: () => applyContextMatching(hunk, content, matchPosition) },
|
||||
{ name: "git-fallback", apply: () => applyGitFallback(hunk, content) },
|
||||
]
|
||||
|
||||
// Try strategies sequentially until one succeeds
|
||||
for (const strategy of strategies) {
|
||||
const result = await strategy.apply()
|
||||
if (result.confidence >= confidenceThreshold) {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return { confidence: 0, result: content, strategy: "none" }
|
||||
}
|
||||
350
src/core/diff/strategies/new-unified/index.ts
Normal file
350
src/core/diff/strategies/new-unified/index.ts
Normal file
@@ -0,0 +1,350 @@
|
||||
import { Diff, Hunk, Change } from "./types"
|
||||
import { findBestMatch, prepareSearchString } from "./search-strategies"
|
||||
import { applyEdit } from "./edit-strategies"
|
||||
import { DiffResult, DiffStrategy } from "../../types"
|
||||
|
||||
export class NewUnifiedDiffStrategy implements DiffStrategy {
|
||||
	// Minimum match confidence required before a hunk's edit is accepted.
	private readonly confidenceThreshold: number

	/**
	 * @param confidenceThreshold requested minimum confidence (default 1);
	 * values below 0.8 are clamped up to 0.8 so the strategy never accepts
	 * very weak matches.
	 */
	constructor(confidenceThreshold: number = 1) {
		this.confidenceThreshold = Math.max(confidenceThreshold, 0.8)
	}
|
||||
|
||||
private parseUnifiedDiff(diff: string): Diff {
|
||||
const MAX_CONTEXT_LINES = 6 // Number of context lines to keep before/after changes
|
||||
const lines = diff.split("\n")
|
||||
const hunks: Hunk[] = []
|
||||
let currentHunk: Hunk | null = null
|
||||
|
||||
let i = 0
|
||||
while (i < lines.length && !lines[i].startsWith("@@")) {
|
||||
i++
|
||||
}
|
||||
|
||||
for (; i < lines.length; i++) {
|
||||
const line = lines[i]
|
||||
|
||||
if (line.startsWith("@@")) {
|
||||
if (
|
||||
currentHunk &&
|
||||
currentHunk.changes.length > 0 &&
|
||||
currentHunk.changes.some((change) => change.type === "add" || change.type === "remove")
|
||||
) {
|
||||
const changes = currentHunk.changes
|
||||
let startIdx = 0
|
||||
let endIdx = changes.length - 1
|
||||
|
||||
for (let j = 0; j < changes.length; j++) {
|
||||
if (changes[j].type !== "context") {
|
||||
startIdx = Math.max(0, j - MAX_CONTEXT_LINES)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for (let j = changes.length - 1; j >= 0; j--) {
|
||||
if (changes[j].type !== "context") {
|
||||
endIdx = Math.min(changes.length - 1, j + MAX_CONTEXT_LINES)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
currentHunk.changes = changes.slice(startIdx, endIdx + 1)
|
||||
hunks.push(currentHunk)
|
||||
}
|
||||
currentHunk = { changes: [] }
|
||||
continue
|
||||
}
|
||||
|
||||
if (!currentHunk) {
|
||||
continue
|
||||
}
|
||||
|
||||
const content = line.slice(1)
|
||||
const indentMatch = content.match(/^(\s*)/)
|
||||
const indent = indentMatch ? indentMatch[0] : ""
|
||||
const trimmedContent = content.slice(indent.length)
|
||||
|
||||
if (line.startsWith(" ")) {
|
||||
currentHunk.changes.push({
|
||||
type: "context",
|
||||
content: trimmedContent,
|
||||
indent,
|
||||
originalLine: content,
|
||||
})
|
||||
} else if (line.startsWith("+")) {
|
||||
currentHunk.changes.push({
|
||||
type: "add",
|
||||
content: trimmedContent,
|
||||
indent,
|
||||
originalLine: content,
|
||||
})
|
||||
} else if (line.startsWith("-")) {
|
||||
currentHunk.changes.push({
|
||||
type: "remove",
|
||||
content: trimmedContent,
|
||||
indent,
|
||||
originalLine: content,
|
||||
})
|
||||
} else {
|
||||
const finalContent = trimmedContent ? " " + trimmedContent : " "
|
||||
currentHunk.changes.push({
|
||||
type: "context",
|
||||
content: finalContent,
|
||||
indent,
|
||||
originalLine: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
currentHunk &&
|
||||
currentHunk.changes.length > 0 &&
|
||||
currentHunk.changes.some((change) => change.type === "add" || change.type === "remove")
|
||||
) {
|
||||
hunks.push(currentHunk)
|
||||
}
|
||||
|
||||
return { hunks }
|
||||
}
|
||||
|
||||
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||
return `# apply_diff Tool - Generate Precise Code Changes
|
||||
|
||||
Generate a unified diff that can be cleanly applied to modify code files.
|
||||
|
||||
## Step-by-Step Instructions:
|
||||
|
||||
1. Start with file headers:
|
||||
- First line: "--- {original_file_path}"
|
||||
- Second line: "+++ {new_file_path}"
|
||||
|
||||
2. For each change section:
|
||||
- Begin with "@@ ... @@" separator line without line numbers
|
||||
- Include 2-3 lines of context before and after changes
|
||||
- Mark removed lines with "-"
|
||||
- Mark added lines with "+"
|
||||
- Preserve exact indentation
|
||||
|
||||
3. Group related changes:
|
||||
- Keep related modifications in the same hunk
|
||||
- Start new hunks for logically separate changes
|
||||
- When modifying functions/methods, include the entire block
|
||||
|
||||
## Requirements:
|
||||
|
||||
1. MUST include exact indentation
|
||||
2. MUST include sufficient context for unique matching
|
||||
3. MUST group related changes together
|
||||
4. MUST use proper unified diff format
|
||||
5. MUST NOT include timestamps in file headers
|
||||
6. MUST NOT include line numbers in the @@ header
|
||||
|
||||
## Examples:
|
||||
|
||||
✅ Good diff (follows all requirements):
|
||||
\`\`\`diff
|
||||
--- src/utils.ts
|
||||
+++ src/utils.ts
|
||||
@@ ... @@
|
||||
def calculate_total(items):
|
||||
- total = 0
|
||||
- for item in items:
|
||||
- total += item.price
|
||||
+ return sum(item.price for item in items)
|
||||
\`\`\`
|
||||
|
||||
❌ Bad diff (violates requirements #1 and #2):
|
||||
\`\`\`diff
|
||||
--- src/utils.ts
|
||||
+++ src/utils.ts
|
||||
@@ ... @@
|
||||
-total = 0
|
||||
-for item in items:
|
||||
+return sum(item.price for item in items)
|
||||
\`\`\`
|
||||
|
||||
Parameters:
|
||||
- path: (required) File path relative to ${args.cwd}
|
||||
- diff: (required) Unified diff content in unified format to apply to the file.
|
||||
|
||||
Usage:
|
||||
<apply_diff>
|
||||
<path>path/to/file.ext</path>
|
||||
<diff>
|
||||
Your diff here
|
||||
</diff>
|
||||
</apply_diff>`
|
||||
}
|
||||
|
||||
// Helper function to split a hunk into smaller hunks based on contiguous changes
|
||||
private splitHunk(hunk: Hunk): Hunk[] {
|
||||
const result: Hunk[] = []
|
||||
let currentHunk: Hunk | null = null
|
||||
let contextBefore: Change[] = []
|
||||
let contextAfter: Change[] = []
|
||||
const MAX_CONTEXT_LINES = 3 // Keep 3 lines of context before/after changes
|
||||
|
||||
for (let i = 0; i < hunk.changes.length; i++) {
|
||||
const change = hunk.changes[i]
|
||||
|
||||
if (change.type === "context") {
|
||||
if (!currentHunk) {
|
||||
contextBefore.push(change)
|
||||
if (contextBefore.length > MAX_CONTEXT_LINES) {
|
||||
contextBefore.shift()
|
||||
}
|
||||
} else {
|
||||
contextAfter.push(change)
|
||||
if (contextAfter.length > MAX_CONTEXT_LINES) {
|
||||
// We've collected enough context after changes, create a new hunk
|
||||
currentHunk.changes.push(...contextAfter)
|
||||
result.push(currentHunk)
|
||||
currentHunk = null
|
||||
// Keep the last few context lines for the next hunk
|
||||
contextBefore = contextAfter
|
||||
contextAfter = []
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!currentHunk) {
|
||||
currentHunk = { changes: [...contextBefore] }
|
||||
contextAfter = []
|
||||
} else if (contextAfter.length > 0) {
|
||||
// Add accumulated context to current hunk
|
||||
currentHunk.changes.push(...contextAfter)
|
||||
contextAfter = []
|
||||
}
|
||||
currentHunk.changes.push(change)
|
||||
}
|
||||
}
|
||||
|
||||
// Add any remaining changes
|
||||
if (currentHunk) {
|
||||
if (contextAfter.length > 0) {
|
||||
currentHunk.changes.push(...contextAfter)
|
||||
}
|
||||
result.push(currentHunk)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
async applyDiff(
|
||||
originalContent: string,
|
||||
diffContent: string,
|
||||
startLine?: number,
|
||||
endLine?: number,
|
||||
): Promise<DiffResult> {
|
||||
const parsedDiff = this.parseUnifiedDiff(diffContent)
|
||||
const originalLines = originalContent.split("\n")
|
||||
let result = [...originalLines]
|
||||
|
||||
if (!parsedDiff.hunks.length) {
|
||||
return {
|
||||
success: false,
|
||||
error: "No hunks found in diff. Please ensure your diff includes actual changes and follows the unified diff format.",
|
||||
}
|
||||
}
|
||||
|
||||
for (const hunk of parsedDiff.hunks) {
|
||||
const contextStr = prepareSearchString(hunk.changes)
|
||||
const {
|
||||
index: matchPosition,
|
||||
confidence,
|
||||
strategy,
|
||||
} = findBestMatch(contextStr, result, 0, this.confidenceThreshold)
|
||||
|
||||
if (confidence < this.confidenceThreshold) {
|
||||
console.log("Full hunk application failed, trying sub-hunks strategy")
|
||||
// Try splitting the hunk into smaller hunks
|
||||
const subHunks = this.splitHunk(hunk)
|
||||
let subHunkSuccess = true
|
||||
let subHunkResult = [...result]
|
||||
|
||||
for (const subHunk of subHunks) {
|
||||
const subContextStr = prepareSearchString(subHunk.changes)
|
||||
const subSearchResult = findBestMatch(subContextStr, subHunkResult, 0, this.confidenceThreshold)
|
||||
|
||||
if (subSearchResult.confidence >= this.confidenceThreshold) {
|
||||
const subEditResult = await applyEdit(
|
||||
subHunk,
|
||||
subHunkResult,
|
||||
subSearchResult.index,
|
||||
subSearchResult.confidence,
|
||||
this.confidenceThreshold,
|
||||
)
|
||||
if (subEditResult.confidence >= this.confidenceThreshold) {
|
||||
subHunkResult = subEditResult.result
|
||||
continue
|
||||
}
|
||||
}
|
||||
subHunkSuccess = false
|
||||
break
|
||||
}
|
||||
|
||||
if (subHunkSuccess) {
|
||||
result = subHunkResult
|
||||
continue
|
||||
}
|
||||
|
||||
// If sub-hunks also failed, return the original error
|
||||
const contextLines = hunk.changes.filter((c) => c.type === "context").length
|
||||
const totalLines = hunk.changes.length
|
||||
const contextRatio = contextLines / totalLines
|
||||
|
||||
let errorMsg = `Failed to find a matching location in the file (${Math.floor(
|
||||
confidence * 100,
|
||||
)}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n`
|
||||
errorMsg += "Debug Info:\n"
|
||||
errorMsg += `- Search Strategy Used: ${strategy}\n`
|
||||
errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor(
|
||||
contextRatio * 100,
|
||||
)}%)\n`
|
||||
errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n`
|
||||
|
||||
if (contextRatio < 0.2) {
|
||||
errorMsg += "\nPossible Issues:\n"
|
||||
errorMsg += "- Not enough context lines to uniquely identify the location\n"
|
||||
errorMsg += "- Add a few more lines of unchanged code around your changes\n"
|
||||
} else if (contextRatio > 0.5) {
|
||||
errorMsg += "\nPossible Issues:\n"
|
||||
errorMsg += "- Too many context lines may reduce search accuracy\n"
|
||||
errorMsg += "- Try to keep only 2-3 lines of context before and after changes\n"
|
||||
} else {
|
||||
errorMsg += "\nPossible Issues:\n"
|
||||
errorMsg += "- The diff may be targeting a different version of the file\n"
|
||||
errorMsg +=
|
||||
"- There may be too many changes in a single hunk, try splitting the changes into multiple hunks\n"
|
||||
}
|
||||
|
||||
if (startLine && endLine) {
|
||||
errorMsg += `\nSearch Range: lines ${startLine}-${endLine}\n`
|
||||
}
|
||||
|
||||
return { success: false, error: errorMsg }
|
||||
}
|
||||
|
||||
const editResult = await applyEdit(hunk, result, matchPosition, confidence, this.confidenceThreshold)
|
||||
if (editResult.confidence >= this.confidenceThreshold) {
|
||||
result = editResult.result
|
||||
} else {
|
||||
// Edit failure - likely due to content mismatch
|
||||
let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor(
|
||||
editResult.confidence * 100,
|
||||
)}% confidence)\n\n`
|
||||
errorMsg += "Debug Info:\n"
|
||||
errorMsg += "- The location was found but the content didn't match exactly\n"
|
||||
errorMsg += "- This usually means the file has been modified since the diff was created\n"
|
||||
errorMsg += "- Or the diff may be targeting a different version of the file\n"
|
||||
errorMsg += "\nPossible Solutions:\n"
|
||||
errorMsg += "1. Refresh your view of the file and create a new diff\n"
|
||||
errorMsg += "2. Double-check that the removed lines (-) match the current file content\n"
|
||||
errorMsg += "3. Ensure your diff targets the correct version of the file"
|
||||
|
||||
return { success: false, error: errorMsg }
|
||||
}
|
||||
}
|
||||
|
||||
return { success: true, content: result.join("\n") }
|
||||
}
|
||||
}
|
||||
408
src/core/diff/strategies/new-unified/search-strategies.ts
Normal file
408
src/core/diff/strategies/new-unified/search-strategies.ts
Normal file
@@ -0,0 +1,408 @@
|
||||
import { compareTwoStrings } from "string-similarity"
|
||||
import { closest } from "fastest-levenshtein"
|
||||
import { diff_match_patch } from "diff-match-patch"
|
||||
import { Change, Hunk } from "./types"
|
||||
|
||||
// Result of a single search-strategy run over file content.
export type SearchResult = {
	// Line index in the content where the match starts, or -1 for no match.
	index: number
	// Confidence in [0, 1]; 0 means no usable match was found.
	confidence: number
	// Name of the strategy that produced this result (e.g. "exact", "anchor").
	strategy: string
}
|
||||
|
||||
const LARGE_FILE_THRESHOLD = 1000 // lines; above this the confidence threshold adapts downward
const UNIQUE_CONTENT_BOOST = 0.05 // max confidence boost earned by unique search content
const DEFAULT_OVERLAP_SIZE = 3 // lines of overlap between windows
const MAX_WINDOW_SIZE = 500 // maximum lines in a window
|
||||
|
||||
// Helper function to calculate adaptive confidence threshold based on file size
|
||||
function getAdaptiveThreshold(contentLength: number, baseThreshold: number): number {
|
||||
if (contentLength <= LARGE_FILE_THRESHOLD) {
|
||||
return baseThreshold
|
||||
}
|
||||
return Math.max(baseThreshold - 0.07, 0.8) // Reduce threshold for large files but keep minimum at 80%
|
||||
}
|
||||
|
||||
// Helper function to evaluate content uniqueness
|
||||
function evaluateContentUniqueness(searchStr: string, content: string[]): number {
|
||||
const searchLines = searchStr.split("\n")
|
||||
const uniqueLines = new Set(searchLines)
|
||||
const contentStr = content.join("\n")
|
||||
|
||||
// Calculate how many search lines are relatively unique in the content
|
||||
let uniqueCount = 0
|
||||
for (const line of uniqueLines) {
|
||||
const regex = new RegExp(line.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"), "g")
|
||||
const matches = contentStr.match(regex)
|
||||
if (matches && matches.length <= 2) {
|
||||
// Line appears at most twice
|
||||
uniqueCount++
|
||||
}
|
||||
}
|
||||
|
||||
return uniqueCount / uniqueLines.size
|
||||
}
|
||||
|
||||
// Helper function to prepare search string from context
|
||||
export function prepareSearchString(changes: Change[]): string {
|
||||
const lines = changes.filter((c) => c.type === "context" || c.type === "remove").map((c) => c.originalLine)
|
||||
return lines.join("\n")
|
||||
}
|
||||
|
||||
// Helper function to evaluate similarity between two texts
|
||||
export function evaluateSimilarity(original: string, modified: string): number {
|
||||
return compareTwoStrings(original, modified)
|
||||
}
|
||||
|
||||
// Helper function to validate using diff-match-patch
|
||||
export function getDMPSimilarity(original: string, modified: string): number {
|
||||
const dmp = new diff_match_patch()
|
||||
const diffs = dmp.diff_main(original, modified)
|
||||
dmp.diff_cleanupSemantic(diffs)
|
||||
const patches = dmp.patch_make(original, diffs)
|
||||
const [expectedText] = dmp.patch_apply(patches, original)
|
||||
|
||||
const similarity = evaluateSimilarity(expectedText, modified)
|
||||
return similarity
|
||||
}
|
||||
|
||||
// Helper function to validate edit results using hunk information
|
||||
export function validateEditResult(hunk: Hunk, result: string): number {
|
||||
// Build the expected text from the hunk
|
||||
const expectedText = hunk.changes
|
||||
.filter((change) => change.type === "context" || change.type === "add")
|
||||
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||
.join("\n")
|
||||
|
||||
// Calculate similarity between the result and expected text
|
||||
const similarity = getDMPSimilarity(expectedText, result)
|
||||
|
||||
// If the result is unchanged from original, return low confidence
|
||||
const originalText = hunk.changes
|
||||
.filter((change) => change.type === "context" || change.type === "remove")
|
||||
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||
.join("\n")
|
||||
|
||||
const originalSimilarity = getDMPSimilarity(originalText, result)
|
||||
if (originalSimilarity > 0.97 && similarity !== 1) {
|
||||
return 0.8 * similarity // Some confidence since we found the right location
|
||||
}
|
||||
|
||||
// For partial matches, scale the confidence but keep it high if we're close
|
||||
return similarity
|
||||
}
|
||||
|
||||
// Helper function to validate context lines against original content
|
||||
function validateContextLines(searchStr: string, content: string, confidenceThreshold: number): number {
|
||||
// Extract just the context lines from the search string
|
||||
const contextLines = searchStr.split("\n").filter((line) => !line.startsWith("-")) // Exclude removed lines
|
||||
|
||||
// Compare context lines with content
|
||||
const similarity = evaluateSimilarity(contextLines.join("\n"), content)
|
||||
|
||||
// Get adaptive threshold based on content size
|
||||
const threshold = getAdaptiveThreshold(content.split("\n").length, confidenceThreshold)
|
||||
|
||||
// Calculate uniqueness boost
|
||||
const uniquenessScore = evaluateContentUniqueness(searchStr, content.split("\n"))
|
||||
const uniquenessBoost = uniquenessScore * UNIQUE_CONTENT_BOOST
|
||||
|
||||
// Adjust confidence based on threshold and uniqueness
|
||||
return similarity < threshold ? similarity * 0.3 + uniquenessBoost : similarity + uniquenessBoost
|
||||
}
|
||||
|
||||
// Helper function to create overlapping windows
|
||||
function createOverlappingWindows(
|
||||
content: string[],
|
||||
searchSize: number,
|
||||
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||
): { window: string[]; startIndex: number }[] {
|
||||
const windows: { window: string[]; startIndex: number }[] = []
|
||||
|
||||
// Ensure minimum window size is at least searchSize
|
||||
const effectiveWindowSize = Math.max(searchSize, Math.min(searchSize * 2, MAX_WINDOW_SIZE))
|
||||
|
||||
// Ensure overlap size doesn't exceed window size
|
||||
const effectiveOverlapSize = Math.min(overlapSize, effectiveWindowSize - 1)
|
||||
|
||||
// Calculate step size, ensure it's at least 1
|
||||
const stepSize = Math.max(1, effectiveWindowSize - effectiveOverlapSize)
|
||||
|
||||
for (let i = 0; i < content.length; i += stepSize) {
|
||||
const windowContent = content.slice(i, i + effectiveWindowSize)
|
||||
if (windowContent.length >= searchSize) {
|
||||
windows.push({ window: windowContent, startIndex: i })
|
||||
}
|
||||
}
|
||||
|
||||
return windows
|
||||
}
|
||||
|
||||
// Helper function to combine overlapping matches
|
||||
function combineOverlappingMatches(
|
||||
matches: (SearchResult & { windowIndex: number })[],
|
||||
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||
): SearchResult[] {
|
||||
if (matches.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
// Sort matches by confidence
|
||||
matches.sort((a, b) => b.confidence - a.confidence)
|
||||
|
||||
const combinedMatches: SearchResult[] = []
|
||||
const usedIndices = new Set<number>()
|
||||
|
||||
for (const match of matches) {
|
||||
if (usedIndices.has(match.windowIndex)) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find overlapping matches
|
||||
const overlapping = matches.filter(
|
||||
(m) =>
|
||||
Math.abs(m.windowIndex - match.windowIndex) === 1 &&
|
||||
Math.abs(m.index - match.index) <= overlapSize &&
|
||||
!usedIndices.has(m.windowIndex),
|
||||
)
|
||||
|
||||
if (overlapping.length > 0) {
|
||||
// Boost confidence if we find same match in overlapping windows
|
||||
const avgConfidence =
|
||||
(match.confidence + overlapping.reduce((sum, m) => sum + m.confidence, 0)) / (overlapping.length + 1)
|
||||
const boost = Math.min(0.05 * overlapping.length, 0.1) // Max 10% boost
|
||||
|
||||
combinedMatches.push({
|
||||
index: match.index,
|
||||
confidence: Math.min(1, avgConfidence + boost),
|
||||
strategy: `${match.strategy}-overlapping`,
|
||||
})
|
||||
|
||||
usedIndices.add(match.windowIndex)
|
||||
overlapping.forEach((m) => usedIndices.add(m.windowIndex))
|
||||
} else {
|
||||
combinedMatches.push({
|
||||
index: match.index,
|
||||
confidence: match.confidence,
|
||||
strategy: match.strategy,
|
||||
})
|
||||
usedIndices.add(match.windowIndex)
|
||||
}
|
||||
}
|
||||
|
||||
return combinedMatches
|
||||
}
|
||||
|
||||
// Exact-substring strategy: slide overlapping windows over the content and
// look for the search string verbatim inside each window; overlapping hits
// are merged and the best combined match is returned.
export function findExactMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const searchLines = searchStr.split("\n")
	const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length)
	const matches: (SearchResult & { windowIndex: number })[] = []

	windows.forEach((windowData, windowIndex) => {
		const windowStr = windowData.window.join("\n")
		// NOTE(review): indexOf finds only the FIRST occurrence per window;
		// later duplicates inside the same window are never considered.
		const exactMatch = windowStr.indexOf(searchStr)

		if (exactMatch !== -1) {
			// Counting newlines before the hit converts the character offset
			// into a line offset within this window.
			const matchedContent = windowData.window
				.slice(
					windowStr.slice(0, exactMatch).split("\n").length - 1,
					windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length,
				)
				.join("\n")

			// Confidence is the weaker of DMP similarity and context validation.
			const similarity = getDMPSimilarity(searchStr, matchedContent)
			const contextSimilarity = validateContextLines(searchStr, matchedContent, confidenceThreshold)
			const confidence = Math.min(similarity, contextSimilarity)

			matches.push({
				// Absolute line index: search offset + window start + line
				// offset of the hit inside the window.
				index: startIndex + windowData.startIndex + windowStr.slice(0, exactMatch).split("\n").length - 1,
				confidence,
				strategy: "exact",
				windowIndex,
			})
		}
	})

	const combinedMatches = combineOverlappingMatches(matches)
	return combinedMatches.length > 0 ? combinedMatches[0] : { index: -1, confidence: 0, strategy: "exact" }
}
|
||||
|
||||
// String similarity strategy: slide a window of exactly searchLines.length
// lines over the content and keep the position with the best adjusted
// (similarity x context-validation) score above the threshold.
export function findSimilarityMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const searchLines = searchStr.split("\n")
	let bestScore = 0
	let bestIndex = -1

	for (let i = startIndex; i < content.length - searchLines.length + 1; i++) {
		const windowStr = content.slice(i, i + searchLines.length).join("\n")
		const score = compareTwoStrings(searchStr, windowStr)
		// NOTE(review): `score` (raw Dice similarity) is compared against
		// `bestScore`, which holds ADJUSTED scores from earlier iterations —
		// two different scales. Confirm this mixing is intended.
		if (score > bestScore && score >= confidenceThreshold) {
			// Expensive validation only for candidates that pass the cheap gate.
			const similarity = getDMPSimilarity(searchStr, windowStr)
			const contextSimilarity = validateContextLines(searchStr, windowStr, confidenceThreshold)
			const adjustedScore = Math.min(similarity, contextSimilarity) * score

			if (adjustedScore > bestScore) {
				bestScore = adjustedScore
				bestIndex = i
			}
		}
	}

	return {
		index: bestIndex,
		confidence: bestIndex !== -1 ? bestScore : 0,
		strategy: "similarity",
	}
}
|
||||
|
||||
// Levenshtein strategy: build every window of the right height as a candidate
// and pick the one with the smallest edit distance to the search string, then
// validate it with DMP similarity and context checking.
export function findLevenshteinMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const searchLines = searchStr.split("\n")
	const candidates = []

	for (let i = startIndex; i < content.length - searchLines.length + 1; i++) {
		candidates.push(content.slice(i, i + searchLines.length).join("\n"))
	}

	if (candidates.length > 0) {
		// closest() returns the candidate with the smallest edit distance.
		const closestMatch = closest(searchStr, candidates)
		const index = startIndex + candidates.indexOf(closestMatch)
		const similarity = getDMPSimilarity(searchStr, closestMatch)
		const contextSimilarity = validateContextLines(searchStr, closestMatch, confidenceThreshold)
		const confidence = Math.min(similarity, contextSimilarity)
		// NOTE(review): `index` is always >= startIndex here, so the
		// `index !== -1` guard below can never be false — dead condition.
		return {
			index: confidence === 0 ? -1 : index,
			confidence: index !== -1 ? confidence : 0,
			strategy: "levenshtein",
		}
	}

	return { index: -1, confidence: 0, strategy: "levenshtein" }
}
|
||||
|
||||
// Helper function to identify anchor lines
|
||||
function identifyAnchors(searchStr: string): { first: string | null; last: string | null } {
|
||||
const searchLines = searchStr.split("\n")
|
||||
let first: string | null = null
|
||||
let last: string | null = null
|
||||
|
||||
// Find the first non-empty line
|
||||
for (const line of searchLines) {
|
||||
if (line.trim()) {
|
||||
first = line
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Find the last non-empty line
|
||||
for (let i = searchLines.length - 1; i >= 0; i--) {
|
||||
if (searchLines[i].trim()) {
|
||||
last = searchLines[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return { first, last }
|
||||
}
|
||||
|
||||
// Anchor-based search strategy: locate the search block by its first and last
// non-blank lines. Only used when the first anchor line occurs exactly once
// in the content; the text between the anchors is then validated for
// similarity against the expected context.
export function findAnchorMatch(
	searchStr: string,
	content: string[],
	startIndex: number = 0,
	confidenceThreshold: number = 0.97,
): SearchResult {
	const searchLines = searchStr.split("\n")
	const { first, last } = identifyAnchors(searchStr)

	// No usable anchors (e.g. blank search text).
	if (!first || !last) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	let firstIndex = -1
	let lastIndex = -1

	// Check if the first anchor is unique — anchors are only trustworthy when
	// the opening line cannot be confused with another occurrence.
	let firstOccurrences = 0
	for (const contentLine of content) {
		if (contentLine === first) {
			firstOccurrences++
		}
	}

	if (firstOccurrences !== 1) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	// Find the first anchor (forward scan from startIndex).
	for (let i = startIndex; i < content.length; i++) {
		if (content[i] === first) {
			firstIndex = i
			break
		}
	}

	// Find the last anchor (backward scan — takes the LAST occurrence).
	for (let i = content.length - 1; i >= startIndex; i--) {
		if (content[i] === last) {
			lastIndex = i
			break
		}
	}

	// Both anchors must exist and be in the right order.
	if (firstIndex === -1 || lastIndex === -1 || lastIndex <= firstIndex) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	// Validate the context between the anchors against the search text.
	// NOTE(review): searchLines.indexOf(first/last) finds the FIRST occurrence
	// of each anchor line in the search text — confirm that is always the
	// intended slice when an anchor line repeats.
	const expectedContext = searchLines.slice(searchLines.indexOf(first) + 1, searchLines.indexOf(last)).join("\n")
	const actualContext = content.slice(firstIndex + 1, lastIndex).join("\n")
	const contextSimilarity = evaluateSimilarity(expectedContext, actualContext)

	if (contextSimilarity < getAdaptiveThreshold(content.length, confidenceThreshold)) {
		return { index: -1, confidence: 0, strategy: "anchor" }
	}

	// A validated unique anchor pair is treated as a certain match.
	const confidence = 1

	return {
		index: firstIndex,
		confidence: confidence,
		strategy: "anchor",
	}
}
|
||||
|
||||
// Main search function that tries all strategies
|
||||
export function findBestMatch(
|
||||
searchStr: string,
|
||||
content: string[],
|
||||
startIndex: number = 0,
|
||||
confidenceThreshold: number = 0.97,
|
||||
): SearchResult {
|
||||
const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]
|
||||
|
||||
let bestResult: SearchResult = { index: -1, confidence: 0, strategy: "none" }
|
||||
|
||||
for (const strategy of strategies) {
|
||||
const result = strategy(searchStr, content, startIndex, confidenceThreshold)
|
||||
if (result.confidence > bestResult.confidence) {
|
||||
bestResult = result
|
||||
}
|
||||
}
|
||||
|
||||
return bestResult
|
||||
}
|
||||
20
src/core/diff/strategies/new-unified/types.ts
Normal file
20
src/core/diff/strategies/new-unified/types.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
// A single line of a parsed diff hunk.
export type Change = {
	// How the line participates in the diff.
	type: "context" | "add" | "remove"
	// Line text with its leading whitespace stripped off.
	content: string
	// The leading whitespace that was stripped from `content`.
	indent: string
	// The raw line as it appeared in the diff (marker removed), when available.
	originalLine?: string
}

// A contiguous group of changes from one "@@" section of a diff.
export type Hunk = {
	changes: Change[]
}

// A whole parsed unified diff.
export type Diff = {
	hunks: Hunk[]
}

// Outcome of attempting to apply a hunk to file content.
export type EditResult = {
	// Confidence in [0, 1] that the edit was applied correctly.
	confidence: number
	// Resulting file content, split into lines.
	result: string[]
	// Name of the edit strategy that produced this result.
	strategy: string
}
|
||||
302
src/core/diff/strategies/search-replace.ts
Normal file
302
src/core/diff/strategies/search-replace.ts
Normal file
@@ -0,0 +1,302 @@
|
||||
import { DiffStrategy, DiffResult } from "../types"
|
||||
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
|
||||
import { distance } from "fastest-levenshtein"
|
||||
|
||||
const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches
|
||||
|
||||
function getSimilarity(original: string, search: string): number {
|
||||
if (search === "") {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Normalize strings by removing extra whitespace but preserve case
|
||||
const normalizeStr = (str: string) => str.replace(/\s+/g, " ").trim()
|
||||
|
||||
const normalizedOriginal = normalizeStr(original)
|
||||
const normalizedSearch = normalizeStr(search)
|
||||
|
||||
if (normalizedOriginal === normalizedSearch) {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Calculate Levenshtein distance using fastest-levenshtein's distance function
|
||||
const dist = distance(normalizedOriginal, normalizedSearch)
|
||||
|
||||
// Calculate similarity ratio (0 to 1, where 1 is an exact match)
|
||||
const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length)
|
||||
return 1 - dist / maxLength
|
||||
}
|
||||
|
||||
export class SearchReplaceDiffStrategy implements DiffStrategy {
|
||||
private fuzzyThreshold: number
|
||||
private bufferLines: number
|
||||
|
||||
constructor(fuzzyThreshold?: number, bufferLines?: number) {
|
||||
// Use provided threshold or default to exact matching (1.0)
|
||||
// Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
|
||||
// so we use it directly here
|
||||
this.fuzzyThreshold = fuzzyThreshold ?? 1.0
|
||||
this.bufferLines = bufferLines ?? BUFFER_LINES
|
||||
}
|
||||
|
||||
	/**
	 * Prompt text describing the apply_diff tool to the model: the expected
	 * SEARCH/REPLACE block format, parameters, and a worked example.
	 * Behavior-bearing string — do not edit casually.
	 */
	getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
		return `## apply_diff
Description: Request to replace existing code using a search and replace block.
This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with.
The tool will maintain proper indentation and formatting while making changes.
Only a single operation is allowed per tool use.
The SEARCH section must exactly match existing content including whitespace and indentation.
If you're not confident in the exact content to search for, use the read_file tool first to get the exact content.
When applying the diffs, be extra careful to remember to change any closing brackets or other syntax that may be affected by the diff farther down in the file.

Parameters:
- path: (required) The path of the file to modify (relative to the current working directory ${args.cwd})
- diff: (required) The search/replace block defining the changes.
- start_line: (required) The line number where the search block starts.
- end_line: (required) The line number where the search block ends.

Diff format:
\`\`\`
<<<<<<< SEARCH
[exact content to find including whitespace]
=======
[new content to replace with]
>>>>>>> REPLACE
\`\`\`

Example:

Original file:
\`\`\`
1 | def calculate_total(items):
2 |     total = 0
3 |     for item in items:
4 |         total += item
5 |     return total
\`\`\`

Search/Replace content:
\`\`\`
<<<<<<< SEARCH
def calculate_total(items):
    total = 0
    for item in items:
        total += item
    return total
=======
def calculate_total(items):
    """Calculate total with 10% markup"""
    return sum(item * 1.1 for item in items)
>>>>>>> REPLACE
\`\`\`

Usage:
<apply_diff>
<path>File path here</path>
<diff>
Your search/replace content here
</diff>
<start_line>1</start_line>
<end_line>5</end_line>
</apply_diff>`
	}
|
||||
|
||||
	/**
	 * Apply a single SEARCH/REPLACE block to originalContent.
	 *
	 * Matching is fuzzy: the SEARCH text is scored against windows of the
	 * original file and accepted when similarity >= this.fuzzyThreshold.
	 * When startLine/endLine are given, an exact-position match is tried first,
	 * then a buffered window around that range; otherwise the whole file is
	 * scanned middle-out. Replacement lines are re-indented relative to the
	 * matched region before being spliced back in.
	 *
	 * @param originalContent Full current file content.
	 * @param diffContent     Text containing <<<<<<< SEARCH / ======= / >>>>>>> REPLACE markers.
	 * @param startLine       Optional 1-based line where the search block is expected to start.
	 * @param endLine         Optional 1-based line where the search block is expected to end.
	 * @returns Patched content on success, or a detailed diagnostic error on failure.
	 */
	async applyDiff(
		originalContent: string,
		diffContent: string,
		startLine?: number,
		endLine?: number,
	): Promise<DiffResult> {
		// Extract the search and replace blocks
		const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/)
		if (!match) {
			return {
				success: false,
				error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`,
			}
		}

		let [_, searchContent, replaceContent] = match

		// Detect line ending from original content
		const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n"

		// Strip line numbers from search and replace content if every line starts with a line number
		if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
			searchContent = stripLineNumbers(searchContent)
			replaceContent = stripLineNumbers(replaceContent)
		}

		// Split content into lines, handling both \n and \r\n
		const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/)
		const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/)
		const originalLines = originalContent.split(/\r?\n/)

		// Validate that empty search requires start line
		if (searchLines.length === 0 && !startLine) {
			return {
				success: false,
				error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`,
			}
		}

		// Validate that empty search requires same start and end line
		if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
			return {
				success: false,
				error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`,
			}
		}

		// Initialize search variables
		let matchIndex = -1
		let bestMatchScore = 0
		let bestMatchContent = ""
		const searchChunk = searchLines.join("\n")

		// Determine search bounds
		let searchStartIndex = 0
		let searchEndIndex = originalLines.length

		// Validate and handle line range if provided
		if (startLine && endLine) {
			// Convert to 0-based index
			const exactStartIndex = startLine - 1
			const exactEndIndex = endLine - 1

			// NOTE(review): `exactEndIndex > originalLines.length` admits endLine == length + 1 — confirm intended.
			if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
				return {
					success: false,
					error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
				}
			}

			// Try exact match first
			const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n")
			const similarity = getSimilarity(originalChunk, searchChunk)
			if (similarity >= this.fuzzyThreshold) {
				matchIndex = exactStartIndex
				bestMatchScore = similarity
				bestMatchContent = originalChunk
			} else {
				// Set bounds for buffered search
				searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1))
				searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines)
			}
		}

		// If no match found yet, try middle-out search within bounds
		if (matchIndex === -1) {
			const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2)
			let leftIndex = midPoint
			let rightIndex = midPoint + 1

			// Search outward from the middle within bounds.
			// Keeps the best-scoring window seen so far; ties go to the first (leftmost-from-middle) hit.
			while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
				// Check left side if still in range
				if (leftIndex >= searchStartIndex) {
					const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n")
					const similarity = getSimilarity(originalChunk, searchChunk)
					if (similarity > bestMatchScore) {
						bestMatchScore = similarity
						matchIndex = leftIndex
						bestMatchContent = originalChunk
					}
					leftIndex--
				}

				// Check right side if still in range
				if (rightIndex <= searchEndIndex - searchLines.length) {
					const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n")
					const similarity = getSimilarity(originalChunk, searchChunk)
					if (similarity > bestMatchScore) {
						bestMatchScore = similarity
						matchIndex = rightIndex
						bestMatchContent = originalChunk
					}
					rightIndex++
				}
			}
		}

		// Require similarity to meet threshold; otherwise build a rich diagnostic
		// (best candidate plus the surrounding original content) to guide a retry.
		if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
			const searchChunk = searchLines.join("\n")
			const originalContentSection =
				startLine !== undefined && endLine !== undefined
					? `\n\nOriginal Content:\n${addLineNumbers(
							originalLines
								.slice(
									Math.max(0, startLine - 1 - this.bufferLines),
									Math.min(originalLines.length, endLine + this.bufferLines),
								)
								.join("\n"),
							Math.max(1, startLine - this.bufferLines),
						)}`
					: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}`

			const bestMatchSection = bestMatchContent
				? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
				: `\n\nBest Match Found:\n(no match)`

			const lineRange =
				startLine || endLine
					? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}`
					: ""
			return {
				success: false,
				error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n- Tip: Use read_file to get the latest content of the file before attempting the diff again, as the file content may have changed\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`,
			}
		}

		// Get the matched lines from the original content
		const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length)

		// Get the exact indentation (preserving tabs/spaces) of each line
		const originalIndents = matchedLines.map((line) => {
			const match = line.match(/^[\t ]*/)
			return match ? match[0] : ""
		})

		// Get the exact indentation of each line in the search block
		const searchIndents = searchLines.map((line) => {
			const match = line.match(/^[\t ]*/)
			return match ? match[0] : ""
		})

		// Apply the replacement while preserving exact indentation.
		// Each replacement line keeps its indent *relative to the first search line*,
		// re-based onto the first matched line's indent.
		const indentedReplaceLines = replaceLines.map((line, i) => {
			// Get the matched line's exact indentation
			const matchedIndent = originalIndents[0] || ""

			// Get the current line's indentation relative to the search content
			const currentIndentMatch = line.match(/^[\t ]*/)
			const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ""
			const searchBaseIndent = searchIndents[0] || ""

			// Calculate the relative indentation level
			const searchBaseLevel = searchBaseIndent.length
			const currentLevel = currentIndent.length
			const relativeLevel = currentLevel - searchBaseLevel

			// If relative level is negative, remove indentation from matched indent
			// If positive, add to matched indent
			const finalIndent =
				relativeLevel < 0
					? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
					: matchedIndent + currentIndent.slice(searchBaseLevel)

			return finalIndent + line.trim()
		})

		// Construct the final content
		const beforeMatch = originalLines.slice(0, matchIndex)
		const afterMatch = originalLines.slice(matchIndex + searchLines.length)

		// Re-join with the line ending detected from the original file.
		const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding)
		return {
			success: true,
			content: finalContent,
		}
	}
|
||||
}
|
||||
137
src/core/diff/strategies/unified.ts
Normal file
137
src/core/diff/strategies/unified.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
import { applyPatch } from "diff"
|
||||
import { DiffStrategy, DiffResult } from "../types"
|
||||
|
||||
export class UnifiedDiffStrategy implements DiffStrategy {
|
||||
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||
return `## apply_diff
|
||||
Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3).
|
||||
|
||||
Parameters:
|
||||
- path: (required) The path of the file to apply the diff to (relative to the current working directory ${args.cwd})
|
||||
- diff: (required) The diff content in unified format to apply to the file.
|
||||
|
||||
Format Requirements:
|
||||
|
||||
1. Header (REQUIRED):
|
||||
\`\`\`
|
||||
--- path/to/original/file
|
||||
+++ path/to/modified/file
|
||||
\`\`\`
|
||||
- Must include both lines exactly as shown
|
||||
- Use actual file paths
|
||||
- NO timestamps after paths
|
||||
|
||||
2. Hunks:
|
||||
\`\`\`
|
||||
@@ -lineStart,lineCount +lineStart,lineCount @@
|
||||
-removed line
|
||||
+added line
|
||||
\`\`\`
|
||||
- Each hunk starts with @@ showing line numbers for changes
|
||||
- Format: @@ -originalStart,originalCount +newStart,newCount @@
|
||||
- Use - for removed/changed lines
|
||||
- Use + for new/modified lines
|
||||
- Indentation must match exactly
|
||||
|
||||
Complete Example:
|
||||
|
||||
Original file (with line numbers):
|
||||
\`\`\`
|
||||
1 | import { Logger } from '../logger';
|
||||
2 |
|
||||
3 | function calculateTotal(items: number[]): number {
|
||||
4 | return items.reduce((sum, item) => {
|
||||
5 | return sum + item;
|
||||
6 | }, 0);
|
||||
7 | }
|
||||
8 |
|
||||
9 | export { calculateTotal };
|
||||
\`\`\`
|
||||
|
||||
After applying the diff, the file would look like:
|
||||
\`\`\`
|
||||
1 | import { Logger } from '../logger';
|
||||
2 |
|
||||
3 | function calculateTotal(items: number[]): number {
|
||||
4 | const total = items.reduce((sum, item) => {
|
||||
5 | return sum + item * 1.1; // Add 10% markup
|
||||
6 | }, 0);
|
||||
7 | return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||
8 | }
|
||||
9 |
|
||||
10 | export { calculateTotal };
|
||||
\`\`\`
|
||||
|
||||
Diff to modify the file:
|
||||
\`\`\`
|
||||
--- src/utils/helper.ts
|
||||
+++ src/utils/helper.ts
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Logger } from '../logger';
|
||||
|
||||
function calculateTotal(items: number[]): number {
|
||||
- return items.reduce((sum, item) => {
|
||||
- return sum + item;
|
||||
+ const total = items.reduce((sum, item) => {
|
||||
+ return sum + item * 1.1; // Add 10% markup
|
||||
}, 0);
|
||||
+ return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||
}
|
||||
|
||||
export { calculateTotal };
|
||||
\`\`\`
|
||||
|
||||
Common Pitfalls:
|
||||
1. Missing or incorrect header lines
|
||||
2. Incorrect line numbers in @@ lines
|
||||
3. Wrong indentation in changed lines
|
||||
4. Incomplete context (missing lines that need changing)
|
||||
5. Not marking all modified lines with - and +
|
||||
|
||||
Best Practices:
|
||||
1. Replace entire code blocks:
|
||||
- Remove complete old version with - lines
|
||||
- Add complete new version with + lines
|
||||
- Include correct line numbers
|
||||
2. Moving code requires two hunks:
|
||||
- First hunk: Remove from old location
|
||||
- Second hunk: Add to new location
|
||||
3. One hunk per logical change
|
||||
4. Verify line numbers match the line numbers you have in the file
|
||||
|
||||
Usage:
|
||||
<apply_diff>
|
||||
<path>File path here</path>
|
||||
<diff>
|
||||
Your diff here
|
||||
</diff>
|
||||
</apply_diff>`
|
||||
}
|
||||
|
||||
async applyDiff(originalContent: string, diffContent: string): Promise<DiffResult> {
|
||||
try {
|
||||
const result = applyPatch(originalContent, diffContent)
|
||||
if (result === false) {
|
||||
return {
|
||||
success: false,
|
||||
error: "Failed to apply unified diff - patch rejected",
|
||||
details: {
|
||||
searchContent: diffContent,
|
||||
},
|
||||
}
|
||||
}
|
||||
return {
|
||||
success: true,
|
||||
content: result,
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: `Error applying unified diff: ${error.message}`,
|
||||
details: {
|
||||
searchContent: diffContent,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
36
src/core/diff/types.ts
Normal file
36
src/core/diff/types.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
/**
 * Interface for implementing different diff strategies
 */

/**
 * Outcome of applying a diff.
 * On success: the complete patched file content.
 * On failure: a human-readable error message plus optional diagnostic
 * details (similarity score, threshold, matched range, search text, best match).
 */
export type DiffResult =
	| { success: true; content: string }
	| {
			success: false
			error: string
			details?: {
				similarity?: number
				threshold?: number
				matchedRange?: { start: number; end: number }
				searchContent?: string
				bestMatch?: string
			}
	  }
|
||||
|
||||
/**
 * Contract every diff strategy implements: describe its input format to the
 * model, and apply a diff in that format to file content.
 */
export interface DiffStrategy {
	/**
	 * Get the tool description for this diff strategy
	 * @param args The tool arguments including cwd and toolOptions
	 * @returns The complete tool description including format requirements and examples
	 */
	getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string

	/**
	 * Apply a diff to the original content
	 * @param originalContent The original file content
	 * @param diffContent The diff content in the strategy's format
	 * @param startLine Optional line number where the search block starts. If not provided, searches the entire file.
	 * @param endLine Optional line number where the search block ends. If not provided, searches the entire file.
	 * @returns A DiffResult object containing either the successful result or error details
	 */
	applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult>
}
|
||||
746
src/core/mcp/McpHub.ts
Normal file
746
src/core/mcp/McpHub.ts
Normal file
@@ -0,0 +1,746 @@
|
||||
import { Client } from "@modelcontextprotocol/sdk/client/index.js"
|
||||
import { StdioClientTransport, StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio.js"
|
||||
import {
|
||||
CallToolResultSchema,
|
||||
ListResourcesResultSchema,
|
||||
ListResourceTemplatesResultSchema,
|
||||
ListToolsResultSchema,
|
||||
ReadResourceResultSchema,
|
||||
} from "@modelcontextprotocol/sdk/types.js"
|
||||
import chokidar, { FSWatcher } from "chokidar"
|
||||
import delay from "delay"
|
||||
import deepEqual from "fast-deep-equal"
|
||||
import * as fs from "fs/promises"
|
||||
import * as path from "path"
|
||||
import * as vscode from "vscode"
|
||||
import { z } from "zod"
|
||||
|
||||
import { ClineProvider } from "../../core/webview/ClineProvider"
|
||||
import { GlobalFileNames } from "../../shared/globalFileNames"
|
||||
import {
|
||||
McpResource,
|
||||
McpResourceResponse,
|
||||
McpResourceTemplate,
|
||||
McpServer,
|
||||
McpTool,
|
||||
McpToolCallResponse,
|
||||
} from "../../shared/mcp"
|
||||
import { fileExistsAtPath } from "../../utils/fs"
|
||||
import { arePathsEqual } from "../../utils/path"
|
||||
|
||||
/**
 * One live MCP server connection: the server's UI-facing state plus the
 * SDK client and stdio transport that talk to the spawned server process.
 */
export type McpConnection = {
	server: McpServer
	client: Client
	transport: StdioClientTransport
}
|
||||
|
||||
// StdioServerParameters
// Tool names the user has pre-approved for a server (no confirmation prompt).
const AlwaysAllowSchema = z.array(z.string()).default([])

// Schema for a single server entry in the MCP settings file.
export const StdioConfigSchema = z.object({
	command: z.string(),
	args: z.array(z.string()).optional(),
	env: z.record(z.string()).optional(),
	alwaysAllow: AlwaysAllowSchema.optional(),
	disabled: z.boolean().optional(),
	// Timeout bounded to 1s..1h, defaulting to 60 (units presumably seconds — confirm against callers).
	timeout: z.number().min(1).max(3600).optional().default(60),
})

// Top-level settings file shape: server name -> stdio config.
const McpSettingsSchema = z.object({
	mcpServers: z.record(StdioConfigSchema),
})
|
||||
|
||||
export class McpHub {
|
||||
	// Weak ref so the hub never keeps the provider (and its webview) alive on its own.
	private providerRef: WeakRef<ClineProvider>
	private disposables: vscode.Disposable[] = []
	private settingsWatcher?: vscode.FileSystemWatcher
	// chokidar watchers keyed by server name, watching each server's build output.
	private fileWatchers: Map<string, FSWatcher> = new Map()
	connections: McpConnection[] = []
	// True while connections are being (re)established; consumers can check this flag.
	isConnecting: boolean = false

	constructor(provider: ClineProvider) {
		this.providerRef = new WeakRef(provider)
		// Fire-and-forget async setup: watch the settings file and connect configured servers.
		this.watchMcpSettingsFile()
		this.initializeMcpServers()
	}
|
||||
|
||||
getServers(): McpServer[] {
|
||||
// Only return enabled servers
|
||||
return this.connections.filter((conn) => !conn.server.disabled).map((conn) => conn.server)
|
||||
}
|
||||
|
||||
getAllServers(): McpServer[] {
|
||||
// Return all servers regardless of state
|
||||
return this.connections.map((conn) => conn.server)
|
||||
}
|
||||
|
||||
async getMcpServersPath(): Promise<string> {
|
||||
const provider = this.providerRef.deref()
|
||||
if (!provider) {
|
||||
throw new Error("Provider not available")
|
||||
}
|
||||
const mcpServersPath = await provider.ensureMcpServersDirectoryExists()
|
||||
return mcpServersPath
|
||||
}
|
||||
|
||||
	/**
	 * Resolve the path to the MCP settings JSON file, seeding a default file
	 * (empty "mcpServers" map) on first use so later reads/parses never fail.
	 * @throws Error when the provider has been garbage-collected.
	 */
	async getMcpSettingsFilePath(): Promise<string> {
		const provider = this.providerRef.deref()
		if (!provider) {
			throw new Error("Provider not available")
		}
		const mcpSettingsFilePath = path.join(
			await provider.ensureSettingsDirectoryExists(),
			GlobalFileNames.mcpSettings,
		)
		const fileExists = await fileExistsAtPath(mcpSettingsFilePath)
		if (!fileExists) {
			// Write a minimal valid settings document so the file always parses.
			await fs.writeFile(
				mcpSettingsFilePath,
				`{
  "mcpServers": {
    
  }
}`,
			)
		}
		return mcpSettingsFilePath
	}
|
||||
|
||||
	/**
	 * Reconnect servers whenever the user saves the MCP settings file.
	 * Invalid JSON or a schema mismatch surfaces an error message and leaves
	 * the current connections untouched.
	 */
	private async watchMcpSettingsFile(): Promise<void> {
		const settingsPath = await this.getMcpSettingsFilePath()
		this.disposables.push(
			vscode.workspace.onDidSaveTextDocument(async (document) => {
				// Only react to saves of the settings file itself.
				if (arePathsEqual(document.uri.fsPath, settingsPath)) {
					const content = await fs.readFile(settingsPath, "utf-8")
					const errorMessage =
						"Invalid MCP settings format. Please ensure your settings follow the correct JSON format."
					let config: any
					try {
						config = JSON.parse(content)
					} catch (error) {
						vscode.window.showErrorMessage(errorMessage)
						return
					}
					// Validate against the zod schema before touching any connection.
					const result = McpSettingsSchema.safeParse(config)
					if (!result.success) {
						vscode.window.showErrorMessage(errorMessage)
						return
					}
					try {
						await this.updateServerConnections(result.data.mcpServers || {})
					} catch (error) {
						console.error("Failed to process MCP settings change:", error)
					}
				}
			}),
		)
	}
|
||||
|
||||
private async initializeMcpServers(): Promise<void> {
|
||||
try {
|
||||
const settingsPath = await this.getMcpSettingsFilePath()
|
||||
const content = await fs.readFile(settingsPath, "utf-8")
|
||||
const config = JSON.parse(content)
|
||||
await this.updateServerConnections(config.mcpServers || {})
|
||||
} catch (error) {
|
||||
console.error("Failed to initialize MCP servers:", error)
|
||||
}
|
||||
}
|
||||
|
||||
	/**
	 * Spawn and connect to one MCP server over stdio, registering the
	 * connection (even on invalid config, so the UI can show the error state).
	 *
	 * Statement order is load-bearing here: the transport is started manually
	 * (to capture stderr during startup), then its start() is monkey-patched to
	 * a no-op before client.connect() runs — see inline comments.
	 *
	 * @param name   Server name from the settings file.
	 * @param config Stdio launch parameters for the server process.
	 * @throws Re-throws any connection error after recording it on the connection.
	 */
	private async connectToServer(name: string, config: StdioServerParameters): Promise<void> {
		// Remove existing connection if it exists (should never happen, the connection should be deleted beforehand)
		this.connections = this.connections.filter((conn) => conn.server.name !== name)

		try {
			// Each MCP server requires its own transport connection and has unique capabilities, configurations, and error handling. Having separate clients also allows proper scoping of resources/tools and independent server management like reconnection.
			const client = new Client(
				{
					name: "Roo Code",
					version: this.providerRef.deref()?.context.extension?.packageJSON?.version ?? "1.0.0",
				},
				{
					capabilities: {},
				},
			)

			const transport = new StdioClientTransport({
				command: config.command,
				args: config.args,
				env: {
					...config.env,
					// PATH is forwarded so the spawned server can resolve executables.
					...(process.env.PATH ? { PATH: process.env.PATH } : {}),
					// ...(process.env.NODE_PATH ? { NODE_PATH: process.env.NODE_PATH } : {}),
				},
				stderr: "pipe", // necessary for stderr to be available
			})

			transport.onerror = async (error) => {
				console.error(`Transport error for "${name}":`, error)
				const connection = this.connections.find((conn) => conn.server.name === name)
				if (connection) {
					connection.server.status = "disconnected"
					this.appendErrorMessage(connection, error.message)
				}
				await this.notifyWebviewOfServerChanges()
			}

			transport.onclose = async () => {
				const connection = this.connections.find((conn) => conn.server.name === name)
				if (connection) {
					connection.server.status = "disconnected"
				}
				await this.notifyWebviewOfServerChanges()
			}

			// If the config is invalid, show an error
			if (!StdioConfigSchema.safeParse(config).success) {
				console.error(`Invalid config for "${name}": missing or invalid parameters`)
				// Still register a (disconnected) connection so the UI can show the error.
				const connection: McpConnection = {
					server: {
						name,
						config: JSON.stringify(config),
						status: "disconnected",
						error: "Invalid config: missing or invalid parameters",
					},
					client,
					transport,
				}
				this.connections.push(connection)
				return
			}

			// valid schema
			const parsedConfig = StdioConfigSchema.parse(config)
			const connection: McpConnection = {
				server: {
					name,
					config: JSON.stringify(config),
					status: "connecting",
					disabled: parsedConfig.disabled,
				},
				client,
				transport,
			}
			this.connections.push(connection)

			// transport.stderr is only available after the process has been started. However we can't start it separately from the .connect() call because it also starts the transport. And we can't place this after the connect call since we need to capture the stderr stream before the connection is established, in order to capture errors during the connection process.
			// As a workaround, we start the transport ourselves, and then monkey-patch the start method to no-op so that .connect() doesn't try to start it again.
			await transport.start()
			const stderrStream = transport.stderr
			if (stderrStream) {
				stderrStream.on("data", async (data: Buffer) => {
					const errorOutput = data.toString()
					console.error(`Server "${name}" stderr:`, errorOutput)
					const connection = this.connections.find((conn) => conn.server.name === name)
					if (connection) {
						// NOTE: we do not set server status to "disconnected" because stderr logs do not necessarily mean the server crashed or disconnected, it could just be informational. In fact when the server first starts up, it immediately logs "<name> server running on stdio" to stderr.
						this.appendErrorMessage(connection, errorOutput)
						// Only need to update webview right away if it's already disconnected
						if (connection.server.status === "disconnected") {
							await this.notifyWebviewOfServerChanges()
						}
					}
				})
			} else {
				console.error(`No stderr stream for ${name}`)
			}
			transport.start = async () => {} // No-op now, .connect() won't fail

			// Connect
			await client.connect(transport)
			connection.server.status = "connected"
			connection.server.error = ""

			// Initial fetch of tools and resources
			connection.server.tools = await this.fetchToolsList(name)
			connection.server.resources = await this.fetchResourcesList(name)
			connection.server.resourceTemplates = await this.fetchResourceTemplatesList(name)
		} catch (error) {
			// Update status with error
			const connection = this.connections.find((conn) => conn.server.name === name)
			if (connection) {
				connection.server.status = "disconnected"
				this.appendErrorMessage(connection, error instanceof Error ? error.message : String(error))
			}
			throw error
		}
	}
|
||||
|
||||
private appendErrorMessage(connection: McpConnection, error: string) {
|
||||
const newError = connection.server.error ? `${connection.server.error}\n${error}` : error
|
||||
connection.server.error = newError //.slice(0, 800)
|
||||
}
|
||||
|
||||
private async fetchToolsList(serverName: string): Promise<McpTool[]> {
|
||||
try {
|
||||
const response = await this.connections
|
||||
.find((conn) => conn.server.name === serverName)
|
||||
?.client.request({ method: "tools/list" }, ListToolsResultSchema)
|
||||
|
||||
// Get always allow settings
|
||||
const settingsPath = await this.getMcpSettingsFilePath()
|
||||
const content = await fs.readFile(settingsPath, "utf-8")
|
||||
const config = JSON.parse(content)
|
||||
const alwaysAllowConfig = config.mcpServers[serverName]?.alwaysAllow || []
|
||||
|
||||
// Mark tools as always allowed based on settings
|
||||
const tools = (response?.tools || []).map((tool) => ({
|
||||
...tool,
|
||||
alwaysAllow: alwaysAllowConfig.includes(tool.name),
|
||||
}))
|
||||
|
||||
console.log(`[MCP] Fetched tools for ${serverName}:`, tools)
|
||||
return tools
|
||||
} catch (error) {
|
||||
// console.error(`Failed to fetch tools for ${serverName}:`, error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
private async fetchResourcesList(serverName: string): Promise<McpResource[]> {
|
||||
try {
|
||||
const response = await this.connections
|
||||
.find((conn) => conn.server.name === serverName)
|
||||
?.client.request({ method: "resources/list" }, ListResourcesResultSchema)
|
||||
return response?.resources || []
|
||||
} catch (error) {
|
||||
// console.error(`Failed to fetch resources for ${serverName}:`, error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
private async fetchResourceTemplatesList(serverName: string): Promise<McpResourceTemplate[]> {
|
||||
try {
|
||||
const response = await this.connections
|
||||
.find((conn) => conn.server.name === serverName)
|
||||
?.client.request({ method: "resources/templates/list" }, ListResourceTemplatesResultSchema)
|
||||
return response?.resourceTemplates || []
|
||||
} catch (error) {
|
||||
// console.error(`Failed to fetch resource templates for ${serverName}:`, error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
async deleteConnection(name: string): Promise<void> {
|
||||
const connection = this.connections.find((conn) => conn.server.name === name)
|
||||
if (connection) {
|
||||
try {
|
||||
await connection.transport.close()
|
||||
await connection.client.close()
|
||||
} catch (error) {
|
||||
console.error(`Failed to close transport for ${name}:`, error)
|
||||
}
|
||||
this.connections = this.connections.filter((conn) => conn.server.name !== name)
|
||||
}
|
||||
}
|
||||
|
||||
	/**
	 * Reconcile live connections against a new settings map: drop servers that
	 * disappeared, connect new ones, and reconnect servers whose config changed.
	 * Per-server failures are logged, never thrown, so one bad server cannot
	 * block the rest.
	 *
	 * @param newServers Server name -> config, as read from the settings file.
	 */
	async updateServerConnections(newServers: Record<string, any>): Promise<void> {
		this.isConnecting = true
		// File watchers are rebuilt from scratch below via setupFileWatcher.
		this.removeAllFileWatchers()
		const currentNames = new Set(this.connections.map((conn) => conn.server.name))
		const newNames = new Set(Object.keys(newServers))

		// Delete removed servers
		for (const name of currentNames) {
			if (!newNames.has(name)) {
				await this.deleteConnection(name)
				console.log(`Deleted MCP server: ${name}`)
			}
		}

		// Update or add servers
		for (const [name, config] of Object.entries(newServers)) {
			const currentConnection = this.connections.find((conn) => conn.server.name === name)

			if (!currentConnection) {
				// New server
				try {
					this.setupFileWatcher(name, config)
					await this.connectToServer(name, config)
				} catch (error) {
					console.error(`Failed to connect to new MCP server ${name}:`, error)
				}
			} else if (!deepEqual(JSON.parse(currentConnection.server.config), config)) {
				// Existing server with changed config
				try {
					this.setupFileWatcher(name, config)
					await this.deleteConnection(name)
					await this.connectToServer(name, config)
					console.log(`Reconnected MCP server with updated config: ${name}`)
				} catch (error) {
					console.error(`Failed to reconnect MCP server ${name}:`, error)
				}
			}
			// If server exists with same config, do nothing
		}
		await this.notifyWebviewOfServerChanges()
		this.isConnecting = false
	}
|
||||
|
||||
	/**
	 * Watch the server's build artifact and auto-restart the server when it
	 * changes (hot-reload for locally developed servers).
	 * Heuristic: only watches an arg containing "build/index.js" — servers
	 * launched any other way get no watcher.
	 */
	private setupFileWatcher(name: string, config: any) {
		const filePath = config.args?.find((arg: string) => arg.includes("build/index.js"))
		if (filePath) {
			// we use chokidar instead of onDidSaveTextDocument because it doesn't require the file to be open in the editor. The settings config is better suited for onDidSave since that will be manually updated by the user or Cline (and we want to detect save events, not every file change)
			const watcher = chokidar.watch(filePath, {
				// persistent: true,
				// ignoreInitial: true,
				// awaitWriteFinish: true, // This helps with atomic writes
			})

			watcher.on("change", () => {
				console.log(`Detected change in ${filePath}. Restarting server ${name}...`)
				// Fire-and-forget; restartConnection reports its own errors.
				this.restartConnection(name)
			})

			this.fileWatchers.set(name, watcher)
		}
	}
|
||||
|
||||
private removeAllFileWatchers() {
|
||||
this.fileWatchers.forEach((watcher) => watcher.close())
|
||||
this.fileWatchers.clear()
|
||||
}
|
||||
|
||||
	/**
	 * Tear down and re-establish the named server's connection using its last
	 * known config, surfacing progress/result to the user via VS Code messages.
	 * No-ops when the provider is gone or the server has no stored config.
	 */
	async restartConnection(serverName: string): Promise<void> {
		this.isConnecting = true
		const provider = this.providerRef.deref()
		if (!provider) {
			return
		}

		// Get existing connection and update its status
		const connection = this.connections.find((conn) => conn.server.name === serverName)
		const config = connection?.server.config
		if (config) {
			vscode.window.showInformationMessage(`Restarting ${serverName} MCP server...`)
			connection.server.status = "connecting"
			connection.server.error = ""
			await this.notifyWebviewOfServerChanges()
			await delay(500) // artificial delay to show user that server is restarting
			try {
				await this.deleteConnection(serverName)
				// Try to connect again using existing config
				await this.connectToServer(serverName, JSON.parse(config))
				vscode.window.showInformationMessage(`${serverName} MCP server connected`)
			} catch (error) {
				console.error(`Failed to restart connection for ${serverName}:`, error)
				vscode.window.showErrorMessage(`Failed to connect to ${serverName} MCP server`)
			}
		}

		await this.notifyWebviewOfServerChanges()
		this.isConnecting = false
	}
|
||||
|
||||
	/**
	 * Push the current server list to the webview, ordered to match the
	 * settings file (servers missing from the file sort first, since
	 * indexOf returns -1 for them).
	 */
	private async notifyWebviewOfServerChanges(): Promise<void> {
		// servers should always be sorted in the order they are defined in the settings file
		const settingsPath = await this.getMcpSettingsFilePath()
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)
		const serverOrder = Object.keys(config.mcpServers || {})
		await this.providerRef.deref()?.postMessageToWebview({
			type: "mcpServers",
			mcpServers: [...this.connections]
				.sort((a, b) => {
					const indexA = serverOrder.indexOf(a.server.name)
					const indexB = serverOrder.indexOf(b.server.name)
					return indexA - indexB
				})
				.map((connection) => connection.server),
		})
	}
|
||||
|
||||
/**
 * Enables or disables a server: persists the `disabled` flag to the MCP
 * settings file, mirrors it on the in-memory connection, refreshes the
 * server's capability lists when it is currently connected, and notifies
 * the webview.
 * @param serverName Server to update (must exist in the settings file;
 *        unknown names are silently ignored).
 * @param disabled New disabled state to persist.
 * @throws when the settings file is inaccessible or malformed.
 */
public async toggleServerDisabled(serverName: string, disabled: boolean): Promise<void> {
	let settingsPath: string
	try {
		settingsPath = await this.getMcpSettingsFilePath()

		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (error) {
			console.error("Settings file not accessible:", error)
			throw new Error("Settings file not accessible")
		}
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)

		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}

		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}

		// NOTE(review): if serverName is absent the call is a silent no-op — confirm intended.
		if (config.mcpServers[serverName]) {
			// Create a new server config object to ensure clean structure
			const serverConfig = {
				...config.mcpServers[serverName],
				disabled,
			}

			// Ensure required fields exist
			if (!serverConfig.alwaysAllow) {
				serverConfig.alwaysAllow = []
			}

			config.mcpServers[serverName] = serverConfig

			// Write the entire config back
			const updatedConfig = {
				mcpServers: config.mcpServers,
			}

			await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))

			const connection = this.connections.find((conn) => conn.server.name === serverName)
			if (connection) {
				try {
					connection.server.disabled = disabled

					// Only refresh capabilities if connected
					if (connection.server.status === "connected") {
						connection.server.tools = await this.fetchToolsList(serverName)
						connection.server.resources = await this.fetchResourcesList(serverName)
						connection.server.resourceTemplates = await this.fetchResourceTemplatesList(serverName)
					}
				} catch (error) {
					// Capability refresh is best-effort; the disabled flag is already persisted.
					console.error(`Failed to refresh capabilities for ${serverName}:`, error)
				}
			}

			await this.notifyWebviewOfServerChanges()
		}
	} catch (error) {
		console.error("Failed to update server disabled state:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to update server state: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
|
||||
|
||||
/**
 * Persists a new request timeout for a server to the MCP settings file
 * and refreshes the webview's server list.
 * @param serverName Server to update (unknown names are silently ignored).
 * @param timeout Timeout value stored verbatim in the server's config
 *        (consumed elsewhere as seconds — see callTool's `* 1000`).
 * @throws when the settings file is inaccessible or malformed.
 */
public async updateServerTimeout(serverName: string, timeout: number): Promise<void> {
	let settingsPath: string
	try {
		settingsPath = await this.getMcpSettingsFilePath()

		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (error) {
			console.error("Settings file not accessible:", error)
			throw new Error("Settings file not accessible")
		}
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)

		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}

		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}

		// NOTE(review): no validation of the timeout value here — confirm callers sanitize it.
		if (config.mcpServers[serverName]) {
			// Create a new server config object to ensure clean structure
			const serverConfig = {
				...config.mcpServers[serverName],
				timeout,
			}

			config.mcpServers[serverName] = serverConfig

			// Write the entire config back
			const updatedConfig = {
				mcpServers: config.mcpServers,
			}

			await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))
			await this.notifyWebviewOfServerChanges()
		}
	} catch (error) {
		console.error("Failed to update server timeout:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to update server timeout: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
|
||||
|
||||
/**
 * Removes a server entry from the MCP settings file and tears down its
 * connection via updateServerConnections. Shows a warning message when
 * the server is not present in the configuration.
 * @param serverName Server to remove.
 * @throws when the settings file is inaccessible or malformed.
 */
public async deleteServer(serverName: string): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()

		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (error) {
			throw new Error("Settings file not accessible")
		}

		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)

		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}

		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}

		// Remove the server from the settings
		if (config.mcpServers[serverName]) {
			delete config.mcpServers[serverName]

			// Write the entire config back
			const updatedConfig = {
				mcpServers: config.mcpServers,
			}

			await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))

			// Update server connections (drops the deleted server's connection)
			await this.updateServerConnections(config.mcpServers)

			vscode.window.showInformationMessage(`Deleted MCP server: ${serverName}`)
		} else {
			vscode.window.showWarningMessage(`Server "${serverName}" not found in configuration`)
		}
	} catch (error) {
		console.error("Failed to delete MCP server:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to delete MCP server: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
|
||||
|
||||
async readResource(serverName: string, uri: string): Promise<McpResourceResponse> {
|
||||
const connection = this.connections.find((conn) => conn.server.name === serverName)
|
||||
if (!connection) {
|
||||
throw new Error(`No connection found for server: ${serverName}`)
|
||||
}
|
||||
if (connection.server.disabled) {
|
||||
throw new Error(`Server "${serverName}" is disabled`)
|
||||
}
|
||||
return await connection.client.request(
|
||||
{
|
||||
method: "resources/read",
|
||||
params: {
|
||||
uri,
|
||||
},
|
||||
},
|
||||
ReadResourceResultSchema,
|
||||
)
|
||||
}
|
||||
|
||||
/**
 * Invokes a tool on a connected MCP server.
 * @param serverName Server to call (must be connected and enabled).
 * @param toolName Name of the tool as advertised by the server.
 * @param toolArguments Optional arguments object forwarded to the tool.
 * @returns The server's `tools/call` response.
 * @throws when no connection exists for `serverName` or the server is disabled.
 */
async callTool(
	serverName: string,
	toolName: string,
	toolArguments?: Record<string, unknown>,
): Promise<McpToolCallResponse> {
	const connection = this.connections.find((conn) => conn.server.name === serverName)
	if (!connection) {
		throw new Error(
			`No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
		)
	}
	if (connection.server.disabled) {
		throw new Error(`Server "${serverName}" is disabled and cannot be used`)
	}

	// Per-server request timeout comes from the stored config (seconds),
	// converted to milliseconds for the client request options.
	let timeout: number
	try {
		const parsedConfig = StdioConfigSchema.parse(JSON.parse(connection.server.config))
		timeout = (parsedConfig.timeout ?? 60) * 1000
	} catch (error) {
		console.error("Failed to parse server config for timeout:", error)
		// Default to 60 seconds if parsing fails
		timeout = 60 * 1000
	}

	return await connection.client.request(
		{
			method: "tools/call",
			params: {
				name: toolName,
				arguments: toolArguments,
			},
		},
		CallToolResultSchema,
		{
			timeout,
		},
	)
}
|
||||
|
||||
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
|
||||
try {
|
||||
const settingsPath = await this.getMcpSettingsFilePath()
|
||||
const content = await fs.readFile(settingsPath, "utf-8")
|
||||
const config = JSON.parse(content)
|
||||
|
||||
// Initialize alwaysAllow if it doesn't exist
|
||||
if (!config.mcpServers[serverName].alwaysAllow) {
|
||||
config.mcpServers[serverName].alwaysAllow = []
|
||||
}
|
||||
|
||||
const alwaysAllow = config.mcpServers[serverName].alwaysAllow
|
||||
const toolIndex = alwaysAllow.indexOf(toolName)
|
||||
|
||||
if (shouldAllow && toolIndex === -1) {
|
||||
// Add tool to always allow list
|
||||
alwaysAllow.push(toolName)
|
||||
} else if (!shouldAllow && toolIndex !== -1) {
|
||||
// Remove tool from always allow list
|
||||
alwaysAllow.splice(toolIndex, 1)
|
||||
}
|
||||
|
||||
// Write updated config back to file
|
||||
await fs.writeFile(settingsPath, JSON.stringify(config, null, 2))
|
||||
|
||||
// Update the tools list to reflect the change
|
||||
const connection = this.connections.find((conn) => conn.server.name === serverName)
|
||||
if (connection) {
|
||||
connection.server.tools = await this.fetchToolsList(serverName)
|
||||
await this.notifyWebviewOfServerChanges()
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to update always allow settings:", error)
|
||||
vscode.window.showErrorMessage("Failed to update always allow settings")
|
||||
throw error // Re-throw to ensure the error is properly handled
|
||||
}
|
||||
}
|
||||
|
||||
async dispose(): Promise<void> {
|
||||
this.removeAllFileWatchers()
|
||||
for (const connection of this.connections) {
|
||||
try {
|
||||
await this.deleteConnection(connection.server.name)
|
||||
} catch (error) {
|
||||
console.error(`Failed to close connection for ${connection.server.name}:`, error)
|
||||
}
|
||||
}
|
||||
this.connections = []
|
||||
if (this.settingsWatcher) {
|
||||
this.settingsWatcher.dispose()
|
||||
}
|
||||
this.disposables.forEach((d) => d.dispose())
|
||||
}
|
||||
}
|
||||
83
src/core/mcp/McpServerManager.ts
Normal file
83
src/core/mcp/McpServerManager.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import * as vscode from "vscode"
|
||||
import { ClineProvider } from "../../core/webview/ClineProvider"
|
||||
import { McpHub } from "./McpHub"
|
||||
|
||||
/**
 * Singleton manager for MCP server instances.
 * Ensures only one set of MCP servers runs across all webviews.
 */
export class McpServerManager {
	// The single shared hub; null until first getInstance() call.
	private static instance: McpHub | null = null
	// globalState key used to record which session created the primary instance.
	private static readonly GLOBAL_STATE_KEY = "mcpHubInstanceId"
	// All webview providers currently registered for server notifications.
	private static providers: Set<ClineProvider> = new Set()
	// Promise-based lock: non-null only while an instance is being created.
	private static initializationPromise: Promise<McpHub> | null = null

	/**
	 * Get the singleton McpHub instance.
	 * Creates a new instance if one doesn't exist.
	 * Thread-safe implementation using a promise-based lock.
	 */
	static async getInstance(context: vscode.ExtensionContext, provider: ClineProvider): Promise<McpHub> {
		// Register the provider
		this.providers.add(provider)

		// If we already have an instance, return it
		if (this.instance) {
			return this.instance
		}

		// If initialization is in progress, wait for it
		if (this.initializationPromise) {
			return this.initializationPromise
		}

		// Create a new initialization promise
		this.initializationPromise = (async () => {
			try {
				// Double-check instance in case it was created while we were waiting
				if (!this.instance) {
					// NOTE(review): the hub is bound to whichever provider called first — confirm intended.
					this.instance = new McpHub(provider)
					// Store a unique identifier in global state to track the primary instance
					await context.globalState.update(this.GLOBAL_STATE_KEY, Date.now().toString())
				}
				return this.instance
			} finally {
				// Clear the initialization promise after completion or error
				this.initializationPromise = null
			}
		})()

		return this.initializationPromise
	}

	/**
	 * Remove a provider from the tracked set.
	 * This is called when a webview is disposed.
	 */
	static unregisterProvider(provider: ClineProvider): void {
		this.providers.delete(provider)
	}

	/**
	 * Notify all registered providers of server state changes.
	 * Per-provider delivery failures are logged, not propagated.
	 */
	static notifyProviders(message: any): void {
		this.providers.forEach((provider) => {
			provider.postMessageToWebview(message).catch((error) => {
				console.error("Failed to notify provider:", error)
			})
		})
	}

	/**
	 * Clean up the singleton instance and all its resources.
	 */
	static async cleanup(context: vscode.ExtensionContext): Promise<void> {
		if (this.instance) {
			await this.instance.dispose()
			this.instance = null
			await context.globalState.update(this.GLOBAL_STATE_KEY, undefined)
		}
		this.providers.clear()
	}
}
|
||||
206
src/core/services/ripgrep/index.ts
Normal file
206
src/core/services/ripgrep/index.ts
Normal file
@@ -0,0 +1,206 @@
|
||||
// import * as vscode from "vscode"
|
||||
import * as childProcess from "child_process"
|
||||
import * as fs from "fs"
|
||||
import * as path from "path"
|
||||
import * as readline from "readline"
|
||||
|
||||
const isWindows = /^win/.test(process.platform)
|
||||
const binName = isWindows ? "rg.exe" : "rg"
|
||||
|
||||
/** One ripgrep match together with its surrounding context lines. */
interface SearchResult {
	// Path of the file containing the match, as reported by ripgrep.
	file: string
	// Line number of the matched line (from ripgrep's line_number field).
	line: number
	// Start offset of the first submatch within the line — presumably a
	// byte offset per ripgrep's JSON format; TODO confirm.
	column: number
	// Text of the matched line (possibly truncated by truncateLine).
	match: string
	// Context lines preceding/following the match (from --context).
	beforeContext: string[]
	afterContext: string[]
}

// Constants
// Maximum number of results included in the formatted output.
const MAX_RESULTS = 300
// Lines longer than this are truncated to keep output manageable.
const MAX_LINE_LENGTH = 500
|
||||
|
||||
/**
|
||||
* Truncates a line if it exceeds the maximum length
|
||||
* @param line The line to truncate
|
||||
* @param maxLength The maximum allowed length (defaults to MAX_LINE_LENGTH)
|
||||
* @returns The truncated line, or the original line if it's shorter than maxLength
|
||||
*/
|
||||
export function truncateLine(line: string, maxLength: number = MAX_LINE_LENGTH): string {
|
||||
return line.length > maxLength ? line.substring(0, maxLength) + " [truncated...]" : line
|
||||
}
|
||||
|
||||
async function getBinPath(): Promise<string | undefined> {
|
||||
const binPath = path.join("/opt/homebrew/bin/", binName)
|
||||
return (await pathExists(binPath)) ? binPath : undefined
|
||||
}
|
||||
|
||||
async function pathExists(path: string): Promise<boolean> {
|
||||
return new Promise((resolve) => {
|
||||
fs.access(path, (err) => {
|
||||
resolve(err === null)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/**
 * Spawns ripgrep and collects its stdout line by line, up to a cap.
 * Once the cap is reached the reader is closed and the process killed.
 * @param bin Absolute path to the ripgrep binary.
 * @param args Command-line arguments passed to ripgrep.
 * @returns Captured stdout (possibly capped at maxLines lines).
 * @throws when ripgrep cannot be spawned or produced stderr output.
 */
async function execRipgrep(bin: string, args: string[]): Promise<string> {
	return new Promise((resolve, reject) => {
		const rgProcess = childProcess.spawn(bin, args)
		// cross-platform alternative to head, which is ripgrep author's recommendation for limiting output.
		const rl = readline.createInterface({
			input: rgProcess.stdout,
			crlfDelay: Infinity, // treat \r\n as a single line break even if it's split across chunks. This ensures consistent behavior across different operating systems.
		})

		let output = ""
		let lineCount = 0
		const maxLines = MAX_RESULTS * 5 // limiting ripgrep output with max lines since there's no other way to limit results. it's okay that we're outputting as json, since we're parsing it line by line and ignore anything that's not part of a match. This assumes each result is at most 5 lines.

		rl.on("line", (line) => {
			if (lineCount < maxLines) {
				output += line + "\n"
				lineCount++
			} else {
				// Cap reached: stop reading and terminate ripgrep early.
				rl.close()
				rgProcess.kill()
			}
		})

		let errorOutput = ""
		rgProcess.stderr.on("data", (data) => {
			errorOutput += data.toString()
		})
		rl.on("close", () => {
			// NOTE(review): any stderr output causes rejection even when stdout
			// held valid results (e.g. ripgrep permission warnings) — confirm intended.
			if (errorOutput) {
				reject(new Error(`ripgrep process error: ${errorOutput}`))
			} else {
				resolve(output)
			}
		})
		rgProcess.on("error", (error) => {
			reject(new Error(`ripgrep process error: ${error.message}`))
		})
	})
}
|
||||
|
||||
export async function regexSearchFiles(
|
||||
directoryPath: string,
|
||||
regex: string,
|
||||
): Promise<string> {
|
||||
const rgPath = await getBinPath()
|
||||
|
||||
if (!rgPath) {
|
||||
throw new Error("Could not find ripgrep binary")
|
||||
}
|
||||
|
||||
// 使用--glob参数排除.obsidian目录
|
||||
const args = [
|
||||
"--json",
|
||||
"-e",
|
||||
regex,
|
||||
"--glob",
|
||||
"!.obsidian/**", // 排除.obsidian目录及其所有子目录
|
||||
"--glob",
|
||||
"!.git/**",
|
||||
"--context",
|
||||
"1",
|
||||
directoryPath
|
||||
]
|
||||
|
||||
let output: string
|
||||
try {
|
||||
output = await execRipgrep(rgPath, args)
|
||||
console.log("output", output)
|
||||
} catch (error) {
|
||||
console.error("Error executing ripgrep:", error)
|
||||
return "No results found"
|
||||
}
|
||||
const results: SearchResult[] = []
|
||||
let currentResult: Partial<SearchResult> | null = null
|
||||
|
||||
output.split("\n").forEach((line) => {
|
||||
if (line) {
|
||||
try {
|
||||
const parsed = JSON.parse(line)
|
||||
if (parsed.type === "match") {
|
||||
if (currentResult) {
|
||||
results.push(currentResult as SearchResult)
|
||||
}
|
||||
|
||||
// Safety check: truncate extremely long lines to prevent excessive output
|
||||
const matchText = parsed.data.lines.text
|
||||
const truncatedMatch = truncateLine(matchText)
|
||||
|
||||
currentResult = {
|
||||
file: parsed.data.path.text,
|
||||
line: parsed.data.line_number,
|
||||
column: parsed.data.submatches[0].start,
|
||||
match: truncatedMatch,
|
||||
beforeContext: [],
|
||||
afterContext: [],
|
||||
}
|
||||
} else if (parsed.type === "context" && currentResult) {
|
||||
// Apply the same truncation logic to context lines
|
||||
const contextText = parsed.data.lines.text
|
||||
const truncatedContext = truncateLine(contextText)
|
||||
|
||||
if (parsed.data.line_number < currentResult.line!) {
|
||||
currentResult.beforeContext!.push(truncatedContext)
|
||||
} else {
|
||||
currentResult.afterContext!.push(truncatedContext)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error parsing ripgrep output:", error)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if (currentResult) {
|
||||
results.push(currentResult as SearchResult)
|
||||
}
|
||||
|
||||
console.log("results", results)
|
||||
console.log("currentResult", currentResult)
|
||||
|
||||
return formatResults(results, directoryPath)
|
||||
}
|
||||
|
||||
function formatResults(results: SearchResult[], cwd: string): string {
|
||||
const groupedResults: { [key: string]: SearchResult[] } = {}
|
||||
|
||||
let output = ""
|
||||
if (results.length >= MAX_RESULTS) {
|
||||
output += `Showing first ${MAX_RESULTS} of ${MAX_RESULTS}+ results. Use a more specific search if necessary.\n\n`
|
||||
} else {
|
||||
output += `Found ${results.length === 1 ? "1 result" : `${results.length.toLocaleString()} results`}.\n\n`
|
||||
}
|
||||
|
||||
// Group results by file name
|
||||
results.slice(0, MAX_RESULTS).forEach((result) => {
|
||||
const relativeFilePath = path.relative(cwd, result.file)
|
||||
if (!groupedResults[relativeFilePath]) {
|
||||
groupedResults[relativeFilePath] = []
|
||||
}
|
||||
groupedResults[relativeFilePath].push(result)
|
||||
})
|
||||
|
||||
for (const [filePath, fileResults] of Object.entries(groupedResults)) {
|
||||
output += `${filePath.toPosix()}\n│----\n`
|
||||
|
||||
fileResults.forEach((result, index) => {
|
||||
const allLines = [...result.beforeContext, result.match, ...result.afterContext]
|
||||
allLines.forEach((line) => {
|
||||
output += `│${line?.trimEnd() ?? ""}\n`
|
||||
})
|
||||
|
||||
if (index < fileResults.length - 1) {
|
||||
output += "│----\n"
|
||||
}
|
||||
})
|
||||
|
||||
output += "│----\n\n"
|
||||
}
|
||||
|
||||
return output.trim()
|
||||
}
|
||||
0
src/core/services/semantic/index.ts
Normal file
0
src/core/services/semantic/index.ts
Normal file
Reference in New Issue
Block a user