From ab2252d2281f5e8782408af050c217025825abe8 Mon Sep 17 00:00:00 2001 From: Dominikus Gierlach Date: Mon, 17 Nov 2025 20:41:18 +0100 Subject: [PATCH 1/2] feat: add scikit-learn based index --- pyproject.toml | 3 +- .../extensions/index/scikit_index.py | 126 ++++++++++++++++++ uv.lock | 90 +++++++++++++ 3 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 src/graph_sitter/extensions/index/scikit_index.py diff --git a/pyproject.toml b/pyproject.toml index f2c0e652f..7d23f1175 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "hatch-vcs>=0.4.0", "hatchling>=1.25.0", "pyinstrument>=5.0.0", - "pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work! + "pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work! "rich-click>=1.8.5", "python-dotenv>=1.0.1", "giturlparse", @@ -70,6 +70,7 @@ dependencies = [ "datasets", "colorlog>=6.9.0", "codegen-sdk-pink>=0.1.0", + "scikit-learn>=1.7.2", ] # renovate: datasource=python-version depName=python diff --git a/src/graph_sitter/extensions/index/scikit_index.py b/src/graph_sitter/extensions/index/scikit_index.py new file mode 100644 index 000000000..2dd89963d --- /dev/null +++ b/src/graph_sitter/extensions/index/scikit_index.py @@ -0,0 +1,126 @@ +"""File-level semantic code search index using scikit-learn.""" + +import pickle +from pathlib import Path +from typing import Any, override + +from sklearn.feature_extraction.text import TfidfVectorizer + +from graph_sitter.core.codebase import Codebase +from graph_sitter.extensions.index.code_index import CodeIndex + + +class ScikitCodeIndex(CodeIndex): + """Local code index using TF-IDF vectorization for semantic search. + + Chis CodeIndex implementation builds a local vector database with scikit, not requiring openai api access. + """ + + def __init__(self, codebase: Codebase, vectorizer: TfidfVectorizer | None = None) -> None: + super().__init__(codebase) + if vectorizer: + self.vectorizer = vectorizer + else: + self.vectorizer: TfidfVectorizer = TfidfVectorizer(stop_words="english", max_features=5000, ngram_range=(1, 2)) + self._fitted: bool = False + + @property + @override + def save_file_name(self) -> str: + return "local_index_{commit}.pkl" + + @override + def _get_embeddings(self, items: list[Any]) -> list[list[float]]: + """Get TF-IDF embeddings for content.""" + if not self._fitted: + all_items = [content for _, content in self._get_items_to_index()] + if all_items: + _ = self.vectorizer.fit(all_items) + self._fitted = True + + if not items: + return [] + + # Extract content strings from items if they are tuples + content_items = [] + for item in items: + if isinstance(item, tuple) and len(item) >= 2: + content_items.append(item[1]) # Get content from tuple + elif isinstance(item, str): + content_items.append(item) + else: + content_items.append(str(item)) + + vectors = self.vectorizer.transform(content_items) + return vectors.toarray().tolist() # pyright: ignore [reportAttributeAccessIssue] + + @override + def _get_items_to_index(self) -> list[tuple[Any, str]]: + """Get all files and their content.""" + items = [] + for file in self.codebase.files(): + try: + content = file.content + if content.strip(): # Only index non-empty files + items.append((file, content)) + # pylint: disable-next=broad-exception-caught, can't do a lot anyways here + except Exception: + continue # Skip files that can't be read + return items + + @override + def _get_changed_items(self) -> set[Any]: + """Get files that have changed since last commit.""" + if not self.commit_hash: + return set() + + changed = set() + try: + current_commit = self._get_current_commit() + if current_commit != self.commit_hash: + # For simplicity, consider all files as potentially changed + changed = set(self.codebase.files()) + # pylint: disable-next=broad-exception-caught, can't do a lot anyways here + except Exception: + pass + + return changed + + @override + def _save_index(self, path: Path) -> None: + """Save index data to disk.""" + data = { + "E": self.E, + "items": self.items, + "commit_hash": self.commit_hash, + "vectorizer": self.vectorizer, + "fitted": self._fitted, + } + with open(path, "wb") as f: + pickle.dump(data, f) + + @override + def _load_index(self, path: Path) -> None: + """Load index data from disk.""" + with open(path, "rb") as f: + data = pickle.load(f) + + self.E = data["E"] + self.items = data["items"] + self.commit_hash = data["commit_hash"] + self.vectorizer = data["vectorizer"] + self._fitted = data["fitted"] + + @override + def similarity_search(self, query: str, k: int = 5) -> list[tuple[Any, float]]: + """Find the k most similar files to a query.""" + raw_results = self._similarity_search_raw(query, k) + + results = [] + for item_str, score in raw_results: + for file in self.codebase.files(): + if str(file) == item_str: + results.append((file, score)) + break + + return results diff --git a/uv.lock b/uv.lock index 115df0d22..d87ae859d 100644 --- a/uv.lock +++ b/uv.lock @@ -1180,6 +1180,7 @@ dependencies = [ { name = "rich" }, { name = "rich-click" }, { name = "rustworkx" }, + { name = "scikit-learn" }, { name = "sentry-sdk" }, { name = "starlette" }, { name = "tabulate" }, @@ -1307,6 +1308,7 @@ requires-dist = [ { name = "rich", specifier = ">=13.7.1,<14.0.0" }, { name = "rich-click", specifier = ">=1.8.5" }, { name = "rustworkx", specifier = ">=0.15.1" }, + { name = "scikit-learn" }, { name = "sentry-sdk", specifier = "==2.41.0" }, { name = "starlette", specifier = ">=0.16.0,<1.0.0" }, { name = "tabulate", specifier = ">=0.9.0,<1.0.0" }, @@ -1758,6 +1760,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/47/3729f00f35a696e68da15d64eb9283c330e776f3b5789bac7f2c0c4df209/jiter-0.9.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6f7838bc467ab7e8ef9f387bd6de195c43bad82a569c1699cb822f6609dd4cdf", size = 206867, upload-time = "2025-03-10T21:36:25.843Z" }, ] +[[package]] +name = "joblib" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, +] + [[package]] name = "jsbeautifier" version = "1.15.4" @@ -3858,6 +3869,76 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3e/79/9bdd52d2a33d468c81c1827de1b588080cb055d1d3561b194ab7bf2635b5/rustworkx-0.16.0-cp39-abi3-win_amd64.whl", hash = "sha256:905df608843c32fa45ac023687769fe13056edf7584474c801d5c50705d76e9b", size = 1953559, upload-time = "2025-01-24T01:22:06.136Z" }, ] +[[package]] +name = "scikit-learn" +version = "1.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", size = 7193136, upload-time = "2025-09-09T08:21:29.075Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/aa/3996e2196075689afb9fce0410ebdb4a09099d7964d061d7213700204409/scikit_learn-1.7.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8d91a97fa2b706943822398ab943cde71858a50245e31bc71dba62aab1d60a96", size = 9259818, upload-time = "2025-09-09T08:20:43.19Z" }, + { url = "https://files.pythonhosted.org/packages/43/5d/779320063e88af9c4a7c2cf463ff11c21ac9c8bd730c4a294b0000b666c9/scikit_learn-1.7.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:acbc0f5fd2edd3432a22c69bed78e837c70cf896cd7993d71d51ba6708507476", size = 8636997, upload-time = "2025-09-09T08:20:45.468Z" }, + { url = "https://files.pythonhosted.org/packages/5c/d0/0c577d9325b05594fdd33aa970bf53fb673f051a45496842caee13cfd7fe/scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e5bf3d930aee75a65478df91ac1225ff89cd28e9ac7bd1196853a9229b6adb0b", size = 9478381, upload-time = "2025-09-09T08:20:47.982Z" }, + { url = "https://files.pythonhosted.org/packages/82/70/8bf44b933837ba8494ca0fc9a9ab60f1c13b062ad0197f60a56e2fc4c43e/scikit_learn-1.7.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4d6e9deed1a47aca9fe2f267ab8e8fe82ee20b4526b2c0cd9e135cea10feb44", size = 9300296, upload-time = "2025-09-09T08:20:50.366Z" }, + { url = "https://files.pythonhosted.org/packages/c6/99/ed35197a158f1fdc2fe7c3680e9c70d0128f662e1fee4ed495f4b5e13db0/scikit_learn-1.7.2-cp312-cp312-win_amd64.whl", hash = "sha256:6088aa475f0785e01bcf8529f55280a3d7d298679f50c0bb70a2364a82d0b290", size = 8731256, upload-time = "2025-09-09T08:20:52.627Z" }, + { url = "https://files.pythonhosted.org/packages/ae/93/a3038cb0293037fd335f77f31fe053b89c72f17b1c8908c576c29d953e84/scikit_learn-1.7.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b7dacaa05e5d76759fb071558a8b5130f4845166d88654a0f9bdf3eb57851b7", size = 9212382, upload-time = "2025-09-09T08:20:54.731Z" }, + { url = "https://files.pythonhosted.org/packages/40/dd/9a88879b0c1104259136146e4742026b52df8540c39fec21a6383f8292c7/scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe", size = 8592042, upload-time = "2025-09-09T08:20:57.313Z" }, + { url = "https://files.pythonhosted.org/packages/46/af/c5e286471b7d10871b811b72ae794ac5fe2989c0a2df07f0ec723030f5f5/scikit_learn-1.7.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:502c18e39849c0ea1a5d681af1dbcf15f6cce601aebb657aabbfe84133c1907f", size = 9434180, upload-time = "2025-09-09T08:20:59.671Z" }, + { url = "https://files.pythonhosted.org/packages/f1/fd/df59faa53312d585023b2da27e866524ffb8faf87a68516c23896c718320/scikit_learn-1.7.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a4c328a71785382fe3fe676a9ecf2c86189249beff90bf85e22bdb7efaf9ae0", size = 9283660, upload-time = "2025-09-09T08:21:01.71Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c7/03000262759d7b6f38c836ff9d512f438a70d8a8ddae68ee80de72dcfb63/scikit_learn-1.7.2-cp313-cp313-win_amd64.whl", hash = "sha256:63a9afd6f7b229aad94618c01c252ce9e6fa97918c5ca19c9a17a087d819440c", size = 8702057, upload-time = "2025-09-09T08:21:04.234Z" }, + { url = "https://files.pythonhosted.org/packages/55/87/ef5eb1f267084532c8e4aef98a28b6ffe7425acbfd64b5e2f2e066bc29b3/scikit_learn-1.7.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9acb6c5e867447b4e1390930e3944a005e2cb115922e693c08a323421a6966e8", size = 9558731, upload-time = "2025-09-09T08:21:06.381Z" }, + { url = "https://files.pythonhosted.org/packages/93/f8/6c1e3fc14b10118068d7938878a9f3f4e6d7b74a8ddb1e5bed65159ccda8/scikit_learn-1.7.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:2a41e2a0ef45063e654152ec9d8bcfc39f7afce35b08902bfe290c2498a67a6a", size = 9038852, upload-time = "2025-09-09T08:21:08.628Z" }, + { url = "https://files.pythonhosted.org/packages/83/87/066cafc896ee540c34becf95d30375fe5cbe93c3b75a0ee9aa852cd60021/scikit_learn-1.7.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98335fb98509b73385b3ab2bd0639b1f610541d3988ee675c670371d6a87aa7c", size = 9527094, upload-time = "2025-09-09T08:21:11.486Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2b/4903e1ccafa1f6453b1ab78413938c8800633988c838aa0be386cbb33072/scikit_learn-1.7.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:191e5550980d45449126e23ed1d5e9e24b2c68329ee1f691a3987476e115e09c", size = 9367436, upload-time = "2025-09-09T08:21:13.602Z" }, + { url = "https://files.pythonhosted.org/packages/b5/aa/8444be3cfb10451617ff9d177b3c190288f4563e6c50ff02728be67ad094/scikit_learn-1.7.2-cp313-cp313t-win_amd64.whl", hash = "sha256:57dc4deb1d3762c75d685507fbd0bc17160144b2f2ba4ccea5dc285ab0d0e973", size = 9275749, upload-time = "2025-09-09T08:21:15.96Z" }, +] + +[[package]] +name = "scipy" +version = "1.16.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/ca/d8ace4f98322d01abcd52d381134344bf7b431eba7ed8b42bdea5a3c2ac9/scipy-1.16.3.tar.gz", hash = "sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb", size = 30597883, upload-time = "2025-10-28T17:38:54.068Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/41/5bf55c3f386b1643812f3a5674edf74b26184378ef0f3e7c7a09a7e2ca7f/scipy-1.16.3-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6", size = 36659043, upload-time = "2025-10-28T17:32:40.285Z" }, + { url = "https://files.pythonhosted.org/packages/1e/0f/65582071948cfc45d43e9870bf7ca5f0e0684e165d7c9ef4e50d783073eb/scipy-1.16.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07", size = 28898986, upload-time = "2025-10-28T17:32:45.325Z" }, + { url = "https://files.pythonhosted.org/packages/96/5e/36bf3f0ac298187d1ceadde9051177d6a4fe4d507e8f59067dc9dd39e650/scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9", size = 20889814, upload-time = "2025-10-28T17:32:49.277Z" }, + { url = "https://files.pythonhosted.org/packages/80/35/178d9d0c35394d5d5211bbff7ac4f2986c5488b59506fef9e1de13ea28d3/scipy-1.16.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686", size = 23565795, upload-time = "2025-10-28T17:32:53.337Z" }, + { url = "https://files.pythonhosted.org/packages/fa/46/d1146ff536d034d02f83c8afc3c4bab2eddb634624d6529a8512f3afc9da/scipy-1.16.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203", size = 33349476, upload-time = "2025-10-28T17:32:58.353Z" }, + { url = "https://files.pythonhosted.org/packages/79/2e/415119c9ab3e62249e18c2b082c07aff907a273741b3f8160414b0e9193c/scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1", size = 35676692, upload-time = "2025-10-28T17:33:03.88Z" }, + { url = "https://files.pythonhosted.org/packages/27/82/df26e44da78bf8d2aeaf7566082260cfa15955a5a6e96e6a29935b64132f/scipy-1.16.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe", size = 36019345, upload-time = "2025-10-28T17:33:09.773Z" }, + { url = "https://files.pythonhosted.org/packages/82/31/006cbb4b648ba379a95c87262c2855cd0d09453e500937f78b30f02fa1cd/scipy-1.16.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70", size = 38678975, upload-time = "2025-10-28T17:33:15.809Z" }, + { url = "https://files.pythonhosted.org/packages/c2/7f/acbd28c97e990b421af7d6d6cd416358c9c293fc958b8529e0bd5d2a2a19/scipy-1.16.3-cp312-cp312-win_amd64.whl", hash = "sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc", size = 38555926, upload-time = "2025-10-28T17:33:21.388Z" }, + { url = "https://files.pythonhosted.org/packages/ce/69/c5c7807fd007dad4f48e0a5f2153038dc96e8725d3345b9ee31b2b7bed46/scipy-1.16.3-cp312-cp312-win_arm64.whl", hash = "sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2", size = 25463014, upload-time = "2025-10-28T17:33:25.975Z" }, + { url = "https://files.pythonhosted.org/packages/72/f1/57e8327ab1508272029e27eeef34f2302ffc156b69e7e233e906c2a5c379/scipy-1.16.3-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c", size = 36617856, upload-time = "2025-10-28T17:33:31.375Z" }, + { url = "https://files.pythonhosted.org/packages/44/13/7e63cfba8a7452eb756306aa2fd9b37a29a323b672b964b4fdeded9a3f21/scipy-1.16.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d", size = 28874306, upload-time = "2025-10-28T17:33:36.516Z" }, + { url = "https://files.pythonhosted.org/packages/15/65/3a9400efd0228a176e6ec3454b1fa998fbbb5a8defa1672c3f65706987db/scipy-1.16.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9", size = 20865371, upload-time = "2025-10-28T17:33:42.094Z" }, + { url = "https://files.pythonhosted.org/packages/33/d7/eda09adf009a9fb81827194d4dd02d2e4bc752cef16737cc4ef065234031/scipy-1.16.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4", size = 23524877, upload-time = "2025-10-28T17:33:48.483Z" }, + { url = "https://files.pythonhosted.org/packages/7d/6b/3f911e1ebc364cb81320223a3422aab7d26c9c7973109a9cd0f27c64c6c0/scipy-1.16.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959", size = 33342103, upload-time = "2025-10-28T17:33:56.495Z" }, + { url = "https://files.pythonhosted.org/packages/21/f6/4bfb5695d8941e5c570a04d9fcd0d36bce7511b7d78e6e75c8f9791f82d0/scipy-1.16.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88", size = 35697297, upload-time = "2025-10-28T17:34:04.722Z" }, + { url = "https://files.pythonhosted.org/packages/04/e1/6496dadbc80d8d896ff72511ecfe2316b50313bfc3ebf07a3f580f08bd8c/scipy-1.16.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234", size = 36021756, upload-time = "2025-10-28T17:34:13.482Z" }, + { url = "https://files.pythonhosted.org/packages/fe/bd/a8c7799e0136b987bda3e1b23d155bcb31aec68a4a472554df5f0937eef7/scipy-1.16.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d", size = 38696566, upload-time = "2025-10-28T17:34:22.384Z" }, + { url = "https://files.pythonhosted.org/packages/cd/01/1204382461fcbfeb05b6161b594f4007e78b6eba9b375382f79153172b4d/scipy-1.16.3-cp313-cp313-win_amd64.whl", hash = "sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304", size = 38529877, upload-time = "2025-10-28T17:35:51.076Z" }, + { url = "https://files.pythonhosted.org/packages/7f/14/9d9fbcaa1260a94f4bb5b64ba9213ceb5d03cd88841fe9fd1ffd47a45b73/scipy-1.16.3-cp313-cp313-win_arm64.whl", hash = "sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2", size = 25455366, upload-time = "2025-10-28T17:35:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/e2/a3/9ec205bd49f42d45d77f1730dbad9ccf146244c1647605cf834b3a8c4f36/scipy-1.16.3-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b", size = 37027931, upload-time = "2025-10-28T17:34:31.451Z" }, + { url = "https://files.pythonhosted.org/packages/25/06/ca9fd1f3a4589cbd825b1447e5db3a8ebb969c1eaf22c8579bd286f51b6d/scipy-1.16.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079", size = 29400081, upload-time = "2025-10-28T17:34:39.087Z" }, + { url = "https://files.pythonhosted.org/packages/6a/56/933e68210d92657d93fb0e381683bc0e53a965048d7358ff5fbf9e6a1b17/scipy-1.16.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a", size = 21391244, upload-time = "2025-10-28T17:34:45.234Z" }, + { url = "https://files.pythonhosted.org/packages/a8/7e/779845db03dc1418e215726329674b40576879b91814568757ff0014ad65/scipy-1.16.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119", size = 23929753, upload-time = "2025-10-28T17:34:51.793Z" }, + { url = "https://files.pythonhosted.org/packages/4c/4b/f756cf8161d5365dcdef9e5f460ab226c068211030a175d2fc7f3f41ca64/scipy-1.16.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c", size = 33496912, upload-time = "2025-10-28T17:34:59.8Z" }, + { url = "https://files.pythonhosted.org/packages/09/b5/222b1e49a58668f23839ca1542a6322bb095ab8d6590d4f71723869a6c2c/scipy-1.16.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e", size = 35802371, upload-time = "2025-10-28T17:35:08.173Z" }, + { url = "https://files.pythonhosted.org/packages/c1/8d/5964ef68bb31829bde27611f8c9deeac13764589fe74a75390242b64ca44/scipy-1.16.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135", size = 36190477, upload-time = "2025-10-28T17:35:16.7Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f2/b31d75cb9b5fa4dd39a0a931ee9b33e7f6f36f23be5ef560bf72e0f92f32/scipy-1.16.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6", size = 38796678, upload-time = "2025-10-28T17:35:26.354Z" }, + { url = "https://files.pythonhosted.org/packages/b4/1e/b3723d8ff64ab548c38d87055483714fefe6ee20e0189b62352b5e015bb1/scipy-1.16.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc", size = 38640178, upload-time = "2025-10-28T17:35:35.304Z" }, + { url = "https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, +] + [[package]] name = "send2trash" version = "1.8.3" @@ -4075,6 +4156,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/a5/c0b6468d3824fe3fde30dbb5e1f687b291608f9473681bbf7dabbf5a87d7/text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8", size = 78154, upload-time = "2019-08-30T21:37:03.543Z" }, ] +[[package]] +name = "threadpoolctl" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, +] + [[package]] name = "tiktoken" version = "0.9.0" From 51b935bdeb6a3c68f1ba4af27628cfc3d8b06ae6 Mon Sep 17 00:00:00 2001 From: Dominikus Gierlach Date: Mon, 17 Nov 2025 20:53:42 +0100 Subject: [PATCH 2/2] test: add scikit tests --- tests/integration/test_scikit_index.py | 137 +++++++++++++++++++++++++ uv.lock | 2 +- 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 tests/integration/test_scikit_index.py diff --git a/tests/integration/test_scikit_index.py b/tests/integration/test_scikit_index.py new file mode 100644 index 000000000..c82a52841 --- /dev/null +++ b/tests/integration/test_scikit_index.py @@ -0,0 +1,137 @@ +from pathlib import Path + +import numpy as np +import pytest + +from graph_sitter.codebase.factory.get_session import get_codebase_session +from graph_sitter.extensions.index.scikit_index import ScikitCodeIndex + + +def test_scikit_index_lifecycle(tmpdir) -> None: + # language=python + content1 = """ +def hello(): + print("Hello, world!") + +def goodbye(): + print("Goodbye, world!") +""" + + # language=python + content2 = """ +def greet(name: str): + print(f"Hi {name}!") +""" + + with get_codebase_session(tmpdir=tmpdir, files={"greetings.py": content1, "hello.py": content2}) as codebase: + # Test construction and initial indexing + index = ScikitCodeIndex(codebase=codebase) + index.create() + + # Verify initial state + assert index.E is not None + assert index.items is not None + assert len(index.items) == 2 # Both files should be indexed + assert index.commit_hash is not None + + # Test similarity search + results = index.similarity_search("greeting someone", k=2) + assert len(results) == 2 + # The greet function should be most relevant to greeting + assert any("hello.py" in file.filepath for file, _ in results) + + # Test saving + save_dir = Path(tmpdir) / ".codegen" + index.save() + assert save_dir.exists() + saved_files = list(save_dir.glob("file_index_*.pkl")) + assert len(saved_files) == 1 + + # Test loading + new_index = FileIndex(codebase) + new_index.load(saved_files[0]) + assert np.array_equal(index.E, new_index.E) + assert np.array_equal(index.items, new_index.items) + assert index.commit_hash == new_index.commit_hash + + # Test updating after file changes + # Add a new function to greetings.py + greetings_file = codebase.get_file("greetings.py") + new_content = greetings_file.content + "\n\ndef welcome():\n print('Welcome!')\n" + greetings_file.edit(new_content) + + # Update the index + index.update() + + # Verify the update + assert len(index.items) >= 2 # Should have at least the original files + + # Search for the new content + results = index.similarity_search("welcome message", k=2) + assert len(results) == 2 + # The updated greetings.py should be relevant now + assert any("greetings.py" in file.filepath for file, _ in results) + + +def test_file_index_empty_file(tmpdir) -> None: + """Test that the file index handles empty files gracefully.""" + with get_codebase_session(tmpdir=tmpdir, files={"empty.py": ""}) as codebase: + index = FileIndex(codebase) + index.create() + assert len(index.items) == 0 # Empty file should be skipped + + +def test_file_index_large_file(tmpdir) -> None: + """Test that the file index handles files larger than the token limit.""" + # Create a large file by repeating a simple function many times + large_content = "def f():\n print('test')\n\n" * 10000 + + with get_codebase_session(tmpdir=tmpdir, files={"large.py": large_content}) as codebase: + index = FileIndex(codebase) + index.create() + + # Should have multiple chunks for the large file + assert len([item for item in index.items if "large.py" in item]) > 1 + + # Test searching in large file + results = index.similarity_search("function that prints test", k=1) + assert len(results) == 1 + assert "large.py" in results[0][0].filepath + + +def test_file_index_invalid_operations(tmpdir) -> None: + """Test that the file index properly handles invalid operations.""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": "print('test')"}) as codebase: + index = FileIndex(codebase) + + # Test searching before creating index + with pytest.raises(ValueError, match="No embeddings available"): + index.similarity_search("test") + + # Test saving before creating index + with pytest.raises(ValueError, match="No embeddings to save"): + index.save() + + # Test updating before creating index + with pytest.raises(ValueError, match="No index to update"): + index.update() + + # Test loading from non-existent path + with pytest.raises(FileNotFoundError): + index.load("nonexistent.pkl") + + +def test_file_index_binary_files(tmpdir) -> None: + """Test that the file index properly handles binary files.""" + # Create a binary file + binary_content = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) # PNG header + binary_path = Path(tmpdir) / "test.png" + binary_path.write_bytes(binary_content) + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": "print('test')", "test.png": binary_content}) as codebase: + index = FileIndex(codebase) + index.create() + + # Should only index the Python file + assert len(index.items) == 1 + assert all("test.py" in item for item in index.items) diff --git a/uv.lock b/uv.lock index d87ae859d..c3f7cbdab 100644 --- a/uv.lock +++ b/uv.lock @@ -1308,7 +1308,7 @@ requires-dist = [ { name = "rich", specifier = ">=13.7.1,<14.0.0" }, { name = "rich-click", specifier = ">=1.8.5" }, { name = "rustworkx", specifier = ">=0.15.1" }, - { name = "scikit-learn" }, + { name = "scikit-learn", specifier = ">=1.7.2" }, { name = "sentry-sdk", specifier = "==2.41.0" }, { name = "starlette", specifier = ">=0.16.0,<1.0.0" }, { name = "tabulate", specifier = ">=0.9.0,<1.0.0" },