24 changes: 18 additions & 6 deletions AGENTS.md
@@ -123,7 +123,6 @@ All configuration in `pyproject.toml` under `[tool.codeflash]`:
module-root = "codeflash" # Source code location
tests-root = "tests" # Test directory
benchmarks-root = "tests/benchmarks" # Benchmark tests
test-framework = "pytest" # Always pytest
formatter-cmds = [ # Auto-formatting commands
"uvx ruff check --exit-zero --fix $file",
"uvx ruff format $file",
@@ -186,6 +185,12 @@ uv run pytest -m "not ci_skip"
# Initialize CodeFlash in a project
uv run codeflash init

# Install VSCode extension
uv run codeflash vscode-install

# Initialize GitHub Actions workflow
uv run codeflash init-actions

# Optimize entire codebase
uv run codeflash --all

@@ -209,6 +214,12 @@ uv run codeflash --benchmark --file target_file.py

# Use replay tests for debugging
uv run codeflash --replay-test tests/specific_test.py

# Create PR (skip for local testing)
uv run codeflash --no-pr --file target_file.py

# Use worktree isolation
uv run codeflash --worktree --file target_file.py
```

## Development Guidelines
@@ -218,7 +229,7 @@ uv run codeflash --replay-test tests/specific_test.py
- Strict mypy type checking enabled
- Pre-commit hooks enforce code quality
- Line length: 120 characters
- Python 3.10+ syntax
- Python 3.9+ syntax (requires-python = ">=3.9")

### Testing Strategy
- Primary test framework: pytest
@@ -230,10 +241,11 @@ uv run codeflash --replay-test tests/specific_test.py
- Test isolation via custom pytest plugin

### Key Dependencies
- **Core**: `libcst`, `jedi`, `gitpython`, `pydantic`
- **Testing**: `pytest`, `coverage`, `crosshair-tool`
- **Performance**: `line_profiler`, `timeout-decorator`
- **UI**: `rich`, `inquirer`, `click`
- **Core**: `libcst`, `jedi`, `gitpython`, `pydantic`, `unidiff`, `tomlkit`
- **Testing**: `pytest`, `pytest-timeout`, `pytest-asyncio`, `coverage`, `crosshair-tool`, `parameterized`, `junitparser`
- **Performance**: `line_profiler`, `codeflash-benchmark`, `dill`
- **UI**: `rich`, `inquirer`, `click`, `humanize`
- **Language Server**: `pygls`
- **AI**: Custom API client for LLM interactions

### Data Models & Types
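The `[tool.codeflash]` configuration near the top of this file defines `formatter-cmds` entries with a `$file` placeholder that is substituted with the path of the file being formatted. A minimal sketch of how such commands could be applied, assuming plain string substitution and sequential execution; this is illustrative only, not codeflash's actual formatter runner:

```python
from __future__ import annotations

import shlex
import subprocess

# Commands copied from the [tool.codeflash] config above; $file is a placeholder.
formatter_cmds = [
    "uvx ruff check --exit-zero --fix $file",
    "uvx ruff format $file",
]


def format_file(path: str) -> None:
    """Run each configured formatter command against one file, in order."""
    for cmd in formatter_cmds:
        args = shlex.split(cmd.replace("$file", path))
        # check=False: a formatter warning should not abort the run (assumption).
        subprocess.run(args, check=False)


# Usage: format_file("codeflash/models/models.py")
```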
19 changes: 16 additions & 3 deletions codeflash/models/models.py
@@ -354,7 +354,9 @@ class CandidateEvaluationContext:
optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
is_correct: dict[str, bool] = Field(default_factory=dict)
optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
# Maps AST normalized code to: optimization_id, shorter_source_code, and diff_len (for deduplication)
ast_code_to_id: dict = Field(default_factory=dict)
# Stores final code strings for candidates (may differ from original if shorter/longer versions were found)
optimizations_post: dict[str, str] = Field(default_factory=dict)
valid_optimizations: list = Field(default_factory=list)

@@ -377,9 +379,14 @@ def record_line_profiler_result(self, optimization_id: str, result: str) -> None
def handle_duplicate_candidate(
self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext
) -> None:
"""Handle a candidate that has been seen before."""
"""Handle a candidate that has been seen before.

When we encounter duplicate candidates (same AST-normalized code), we reuse the previous
evaluation results instead of re-running tests.
"""
past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"]

# Update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes
# Copy results from the previous evaluation
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios[past_opt_id]
self.is_correct[candidate.optimization_id] = self.is_correct[past_opt_id]
@@ -398,14 +405,20 @@ def handle_duplicate_candidate(

# Update to shorter code if this candidate has a shorter diff
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
if (
new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]
): # new candidate has a shorter diff than the previously encountered one
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len

def register_new_candidate(
self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext
) -> None:
"""Register a new candidate that hasn't been seen before."""
"""Register a new candidate that hasn't been seen before.

Maps AST normalized code to: diff len, unnormalized code, and optimization ID.
Tracks the shortest unnormalized code for each unique AST structure.
"""
self.ast_code_to_id[normalized_code] = {
"optimization_id": candidate.optimization_id,
"shorter_source_code": candidate.source_code,
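The `ast_code_to_id` bookkeeping in this file deduplicates candidates whose code differs only in formatting or comments: the AST-normalized source is the dictionary key, and each entry remembers the originating `optimization_id`, the shortest raw source seen so far, and its diff length against the original code, exactly as `handle_duplicate_candidate` and `register_new_candidate` describe. A rough sketch of that idea, using the standard-library `ast` and `difflib` modules as assumed stand-ins for codeflash's `normalize_code` and `diff_length` helpers (the real implementations may differ):

```python
from __future__ import annotations

import ast
import difflib

# normalized code -> {"optimization_id", "shorter_source_code", "diff_len"}
ast_code_to_id: dict[str, dict] = {}


def normalize_source(source: str) -> str:
    """Round-trip through the AST so comments and formatting differences vanish."""
    return ast.unparse(ast.parse(source))


def diff_len(candidate_code: str, original_code: str) -> int:
    """Count changed lines between the original code and a candidate (illustrative metric)."""
    diff = difflib.unified_diff(original_code.splitlines(), candidate_code.splitlines(), lineterm="")
    return sum(
        1 for line in diff
        if line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
    )


def register_or_dedupe(opt_id: str, source: str, original: str) -> str | None:
    """Return the earlier optimization_id if this candidate is an AST-level duplicate."""
    key = normalize_source(source.strip())
    new_len = diff_len(source, original)
    if key not in ast_code_to_id:
        ast_code_to_id[key] = {"optimization_id": opt_id, "shorter_source_code": source, "diff_len": new_len}
        return None
    entry = ast_code_to_id[key]
    if new_len < entry["diff_len"]:  # keep the variant with the smaller diff
        entry["shorter_source_code"], entry["diff_len"] = source, new_len
    return entry["optimization_id"]
```

A later duplicate can then reuse the stored results keyed by the earlier `optimization_id`, which is what `handle_duplicate_candidate` does when it copies the previous runtimes and correctness flags instead of re-running tests.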
13 changes: 10 additions & 3 deletions codeflash/optimization/function_optimizer.py
@@ -583,7 +583,11 @@ def select_best_optimization(
exp_type: str,
function_references: str,
) -> BestOptimization | None:
"""Select the best optimization from valid candidates."""
"""Select the best optimization from valid candidates.

Reassigns the shorter code versions for candidates with the same AST structure,
then ranks them to determine the best optimization.
"""
if not eval_ctx.valid_optimizations:
return None

@@ -645,11 +649,12 @@ def select_best_optimization(
else:
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
# TODO: better way to resolve conflicts with same min ranking
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking}
min_key = min(overall_ranking, key=overall_ranking.get)
elif len(optimization_ids) == 1:
min_key = 0
else:
min_key = 0 # only one candidate in valid optimizations
else: # 0 candidates - shouldn't happen, but defensive check
return None

return valid_candidates_with_shorter_code[min_key]
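When several distinct optimizations survive, the hunk above combines two dense rankings, one by diff length and one by runtime, and picks the candidate with the smallest summed rank. A small self-contained sketch of that combination; `rank_compact` and the sample numbers are only assumed stand-ins for `create_rank_dictionary_compact` and real measurements:

```python
from __future__ import annotations


def rank_compact(values: list[float]) -> dict[int, int]:
    """Dense ranking: equal values share a rank and ranks have no gaps."""
    rank_of = {v: r for r, v in enumerate(sorted(set(values)))}
    return {i: rank_of[v] for i, v in enumerate(values)}


# Hypothetical candidates: index 0 has the smallest diff, index 1 the best runtime.
diff_lens = [12.0, 40.0, 12.0]    # lines changed per candidate
runtimes = [950.0, 700.0, 990.0]  # best measured runtime (ns) per candidate

diff_rank = rank_compact(diff_lens)   # {0: 0, 1: 1, 2: 0}
time_rank = rank_compact(runtimes)    # {0: 1, 1: 0, 2: 2}
overall = {i: diff_rank[i] + time_rank[i] for i in diff_rank}  # {0: 1, 1: 1, 2: 2}
best = min(overall, key=overall.get)  # ties fall back to the lowest index -> 0
```

The tie between candidates 0 and 1 in this toy example is exactly the situation the TODO in the hunk points at: candidates with equal combined rank currently resolve by iteration order rather than an explicit tiebreaker.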
@@ -726,6 +731,7 @@ def process_single_candidate(
return None

# Check for duplicate candidates
# Check if this code has been evaluated before by checking the AST normalized code string
normalized_code = normalize_code(candidate.source_code.flat.strip())
if normalized_code in eval_ctx.ast_code_to_id:
logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.")
@@ -754,6 +760,7 @@ def process_single_candidate(
eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain)

# Check if this is a successful optimization
# For async functions, prioritize throughput metrics over runtime
is_successful_opt = speedup_critic(
candidate_result,
original_code_baseline.runtime,
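The `speedup_critic` call above decides which candidates count as successful, and the new comment notes that async functions are judged primarily on throughput rather than raw runtime. A hedged sketch of such a gate; the function name, the 5% noise floor, and the throughput parameters are assumptions for illustration, not codeflash's actual logic:

```python
from __future__ import annotations


def is_successful_optimization(
    candidate_runtime_ns: int,
    baseline_runtime_ns: int,
    candidate_throughput: float | None = None,
    baseline_throughput: float | None = None,
    min_gain: float = 0.05,  # assumed 5% noise floor, not codeflash's threshold
) -> bool:
    """Accept a candidate only if it beats the baseline by more than the noise floor.

    When throughput numbers exist (for example, async functions whose wall-clock
    runtime is dominated by event-loop scheduling), they take priority.
    """
    if candidate_throughput is not None and baseline_throughput is not None:
        return candidate_throughput > baseline_throughput * (1 + min_gain)
    return candidate_runtime_ns * (1 + min_gain) < baseline_runtime_ns
```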
7 changes: 3 additions & 4 deletions codeflash/verification/test_runner.py
@@ -68,7 +68,7 @@ def run_behavioral_tests(
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
common_pytest_args.append(f"--timeout={pytest_timeout}")
common_pytest_args.insert(1, f"--timeout={pytest_timeout}")

result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
@@ -164,7 +164,7 @@ def run_line_profile_tests(
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
pytest_args.append(f"--timeout={pytest_timeout}")
pytest_args.insert(1, f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
@@ -214,8 +214,7 @@ def run_benchmarking_tests(
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
pytest_args.append(f"--timeout={pytest_timeout}")

pytest_args.insert(1, f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
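All three hunks in this file make the same change: the `--timeout` flag moves from being appended at the end of the pytest argument list to being inserted at index 1. The diff does not state the motivation, but the mechanical effect is easy to show; the surrounding argument values below are placeholders, not the real lists built in `test_runner.py`:

```python
pytest_timeout = 30

# Placeholder arguments standing in for the list assembled in test_runner.py.
pytest_args = ["--capture=tee-sys", "--codeflash_seconds=10"]

if pytest_timeout is not None:
    # append would yield [..., "--codeflash_seconds=10", "--timeout=30"];
    # insert(1, ...) keeps the flag near the front of the option block instead.
    pytest_args.insert(1, f"--timeout={pytest_timeout}")

print(pytest_args)  # ['--capture=tee-sys', '--timeout=30', '--codeflash_seconds=10']
```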