diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 578bff4dd..e63f52b77 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -25,6 +25,7 @@ from codeflash.cli_cmds.cli_common import apologize_and_exit from codeflash.cli_cmds.console import console, logger from codeflash.cli_cmds.extension import install_vscode_extension +from codeflash.code_utils.code_utils import validate_relative_directory_path from codeflash.code_utils.compat import LF from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.env_utils import check_formatter_installed, get_codeflash_api_key @@ -349,20 +350,32 @@ def collect_setup_info() -> CLISetupInfo: console.print(custom_panel) console.print() - custom_questions = [ - inquirer.Path( - "custom_path", - message="Enter the path to your module directory", - path_type=inquirer.Path.DIRECTORY, - exists=True, - ) - ] + # Retry loop for custom module root path + module_root = None + while module_root is None: + custom_questions = [ + inquirer.Path( + "custom_path", + message="Enter the path to your module directory", + path_type=inquirer.Path.DIRECTORY, + exists=True, + ) + ] - custom_answers = inquirer.prompt(custom_questions, theme=CodeflashTheme()) - if custom_answers: - module_root = Path(custom_answers["custom_path"]) - else: - apologize_and_exit() + custom_answers = inquirer.prompt(custom_questions, theme=CodeflashTheme()) + if not custom_answers: + apologize_and_exit() + return None # unreachable but satisfies type checker + + custom_path_str = str(custom_answers["custom_path"]) + # Validate the path is safe + is_valid, error_msg = validate_relative_directory_path(custom_path_str) + if not is_valid: + click.echo(f"❌ Invalid path: {error_msg}") + click.echo("Please enter a valid relative directory path.") + console.print() # Add spacing before retry + continue # Retry the prompt + module_root = Path(custom_path_str) else: module_root = 
module_root_answer ph("cli-project-root-provided") @@ -420,20 +433,32 @@ def collect_setup_info() -> CLISetupInfo: console.print(custom_tests_panel) console.print() - custom_tests_questions = [ - inquirer.Path( - "custom_tests_path", - message="Enter the path to your tests directory", - path_type=inquirer.Path.DIRECTORY, - exists=True, - ) - ] + # Retry loop for custom tests root path + tests_root = None + while tests_root is None: + custom_tests_questions = [ + inquirer.Path( + "custom_tests_path", + message="Enter the path to your tests directory", + path_type=inquirer.Path.DIRECTORY, + exists=True, + ) + ] - custom_tests_answers = inquirer.prompt(custom_tests_questions, theme=CodeflashTheme()) - if custom_tests_answers: - tests_root = Path(curdir) / Path(custom_tests_answers["custom_tests_path"]) - else: - apologize_and_exit() + custom_tests_answers = inquirer.prompt(custom_tests_questions, theme=CodeflashTheme()) + if not custom_tests_answers: + apologize_and_exit() + return None # unreachable but satisfies type checker + + custom_tests_path_str = str(custom_tests_answers["custom_tests_path"]) + # Validate the path is safe + is_valid, error_msg = validate_relative_directory_path(custom_tests_path_str) + if not is_valid: + click.echo(f"❌ Invalid path: {error_msg}") + click.echo("Please enter a valid relative directory path.") + console.print() # Add spacing before retry + continue # Retry the prompt + tests_root = Path(curdir) / Path(custom_tests_path_str) else: tests_root = Path(curdir) / Path(cast("str", tests_root_answer)) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 37e0dd94e..5682bdf42 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -19,6 +19,10 @@ from codeflash.code_utils.config_parser import find_pyproject_toml, get_all_closest_config_files from codeflash.lsp.helpers import is_LSP_enabled +_INVALID_CHARS_NT = {"<", ">", ":", '"', "|", "?", "*"} + 
_INVALID_CHARS_UNIX = {"\0"}


def validate_relative_directory_path(path: str) -> tuple[bool, str]:
    """Validate that a path is a safe relative directory path.

    Rejects empty input, parent-directory traversal, absolute or rooted paths,
    and characters the current operating system does not allow in a path.
    The check is purely lexical: the filesystem is never consulted, so callers
    that need the directory to exist must verify that separately.

    Args:
        path: The path string to validate. Surrounding whitespace is ignored.

    Returns:
        tuple[bool, str]: (is_valid, error_message)
            - is_valid: True if path is valid, False otherwise
            - error_message: Empty string if valid, error description if invalid

    """
    if not path or not path.strip():
        return False, "Path cannot be empty"

    path = path.strip()

    # Treat both separator styles alike so the checks behave the same on every platform.
    normalized = path.replace("\\", "/")

    # Reject parent-directory *components* only; a substring test would also refuse
    # harmless names such as 'my..pkg'. Each component is stripped first because
    # Windows drops trailing spaces from names, which would let '.. ' act as '..'.
    if any(part.strip() == ".." for part in normalized.split("/")):
        return False, "Path cannot contain '..'. Use a relative path like 'tests' or 'src/app' instead"

    # Path.is_absolute() on Windows requires a drive AND a root, so rooted paths such
    # as '\tests' or '/tests' would pass it; joining one onto a base directory replaces
    # the base's path instead of extending it, hence the explicit leading-separator check.
    if normalized.startswith("/") or Path(path).is_absolute():
        return False, "Path must be relative, not absolute"

    # NUL is never legal in a path; the reserved Windows characters only apply on Windows.
    invalid_chars = (_INVALID_CHARS_UNIX | _INVALID_CHARS_NT) if os.name == "nt" else _INVALID_CHARS_UNIX
    if any(char in invalid_chars for char in path):
        return False, "Path contains invalid characters for this operating system"

    return True, ""
provided + base_dir = cfg_file.parent if cfg_file else Path.cwd() + tests_root_path = (base_dir / tests_root).resolve() + if not tests_root_path.exists() or not tests_root_path.is_dir(): + return { + "status": "error", + "message": f"Invalid 'tests_root': directory does not exist at {tests_root_path}", + "field_errors": {"tests_root": f"Directory does not exist at {tests_root_path}"}, + } + + # Validate module_root path format and safety + module_root = get_config_value("module_root", "") + if module_root: + is_valid, error_msg = validate_relative_directory_path(module_root) + if not is_valid: + return { + "status": "error", + "message": f"Invalid 'module_root': {error_msg}", + "field_errors": {"module_root": error_msg}, + } + setup_info = VsCodeSetupInfo( - module_root=getattr(cfg, "module_root", ""), - tests_root=getattr(cfg, "tests_root", ""), - formatter=get_formatter_cmds(getattr(cfg, "formatter_cmds", "disabled")), + module_root=module_root, + tests_root=tests_root, + formatter=get_formatter_cmds(get_config_value("formatter_cmds", "disabled")), ) devnull_writer = open(os.devnull, "w") # noqa