diff --git a/rust/Cargo.lock b/rust/Cargo.lock index a00f0b1b..93960b8d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -7,6 +7,7 @@ name = "_rustgrimp" version = "0.1.0" dependencies = [ "bimap", + "bincode", "const_format", "derive-new", "encoding_rs", @@ -52,6 +53,26 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7" +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bitflags" version = "2.9.4" @@ -833,12 +854,24 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "version_check" version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 4eb4e139..302b434a 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -27,6 +27,7 @@ ruff_source_file = { git = 
"https://github.com/astral-sh/ruff.git", tag = "v0.4. serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.137" serde_yaml = "0.9" +bincode = "2.0.1" unindent = "0.2.4" encoding_rs = "0.8.35" diff --git a/rust/src/caching.rs b/rust/src/caching.rs index cb00b417..00dcbc49 100644 --- a/rust/src/caching.rs +++ b/rust/src/caching.rs @@ -2,10 +2,36 @@ use crate::errors::{GrimpError, GrimpResult}; use crate::filesystem::get_file_system_boxed; use crate::import_scanning::{DirectImport, imports_by_module_to_py}; use crate::module_finding::Module; -use pyo3::types::PyDict; +use pyo3::types::PyAnyMethods; +use pyo3::types::{PyDict, PySet}; +use pyo3::types::{PyDictMethods, PySetMethods}; use pyo3::{Bound, PyAny, PyResult, Python, pyfunction}; use std::collections::{HashMap, HashSet}; +/// Writes the cache file containing all the imports for a given package. +/// Args: +/// - filename: str +/// - imports_by_module: dict[Module, Set[DirectImport]] +/// - file_system: The file system interface to use. (A BasicFileSystem.) +#[pyfunction] +pub fn write_cache_data_map_file<'py>( + _py: Python<'py>, + filename: &str, + imports_by_module: Bound<'py, PyDict>, + file_system: Bound<'py, PyAny>, +) -> PyResult<()> { + let mut file_system_boxed = get_file_system_boxed(&file_system)?; + + let imports_by_module_rust = imports_by_module_to_rust(imports_by_module); + + let file_contents = serialize_imports_by_module(&imports_by_module_rust) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to serialize: {}", e)))?; + + file_system_boxed.write_bytes(filename, &file_contents)?; + + Ok(()) +} + /// Reads the cache file containing all the imports for a given package. 
/// Args: /// - filename: str @@ -19,19 +45,70 @@ pub fn read_cache_data_map_file<'py>( ) -> PyResult> { let file_system_boxed = get_file_system_boxed(&file_system)?; - let file_contents = file_system_boxed.read(filename)?; + let file_contents = file_system_boxed.read_bytes(filename)?; - let imports_by_module = parse_json_to_map(&file_contents, filename)?; + let imports_by_module = parse_bincode_to_map(&file_contents, filename)?; Ok(imports_by_module_to_py(py, imports_by_module)) } -pub fn parse_json_to_map( - json_str: &str, +#[allow(unused_variables)] +fn imports_by_module_to_rust( + imports_by_module_py: Bound, +) -> HashMap> { + let mut imports_by_module_rust = HashMap::new(); + + for (py_key, py_value) in imports_by_module_py.iter() { + let module: Module = py_key.extract().unwrap(); + let py_set = py_value + .downcast::() + .expect("Expected value to be a Python set."); + let mut hashset: HashSet = HashSet::new(); + for element in py_set.iter() { + let direct_import: DirectImport = element + .extract() + .expect("Expected value to be DirectImport."); + hashset.insert(direct_import); + } + imports_by_module_rust.insert(module, hashset); + } + + imports_by_module_rust +} + +#[allow(unused_variables)] +fn serialize_imports_by_module( + imports_by_module: &HashMap>, +) -> Result, bincode::error::EncodeError> { + let raw_map: HashMap<&str, Vec<(&str, usize, &str)>> = imports_by_module + .iter() + .map(|(module, imports)| { + let imports_vec: Vec<(&str, usize, &str)> = imports + .iter() + .map(|import| { + ( + import.imported.as_str(), + import.line_number, + import.line_contents.as_str(), + ) + }) + .collect(); + (module.name.as_str(), imports_vec) + }) + .collect(); + + let config = bincode::config::standard(); + bincode::encode_to_vec(&raw_map, config) +} + +pub fn parse_bincode_to_map( + bytes: &[u8], filename: &str, ) -> GrimpResult>> { - let raw_map: HashMap> = serde_json::from_str(json_str) - .map_err(|_| GrimpError::CorruptCache(filename.to_string()))?; + 
let config = bincode::config::standard(); + let (raw_map, _): (HashMap>, usize) = + bincode::decode_from_slice(bytes, config) + .map_err(|_| GrimpError::CorruptCache(filename.to_string()))?; let mut parsed_map: HashMap> = HashMap::new(); diff --git a/rust/src/filesystem.rs b/rust/src/filesystem.rs index 4b54781c..0b5c6cdf 100644 --- a/rust/src/filesystem.rs +++ b/rust/src/filesystem.rs @@ -5,8 +5,10 @@ use regex::Regex; use std::collections::HashMap; use std::ffi::OsStr; use std::fs; +use std::fs::File; +use std::io::prelude::*; use std::path::{Path, PathBuf}; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock, Mutex}; use unindent::unindent; static ENCODING_RE: LazyLock = @@ -22,6 +24,12 @@ pub trait FileSystem: Send + Sync { fn exists(&self, file_name: &str) -> bool; fn read(&self, file_name: &str) -> PyResult; + + fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()>; + + fn read_bytes(&self, file_name: &str) -> PyResult>; + + fn write_bytes(&mut self, file_name: &str, contents: &[u8]) -> PyResult<()>; } #[derive(Clone)] @@ -129,6 +137,33 @@ impl FileSystem for RealBasicFileSystem { }) } } + + fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> { + let file_path: PathBuf = file_name.into(); + if let Some(parent_dir) = file_path.parent() { + fs::create_dir_all(parent_dir)?; + } + let mut file = File::create(file_path)?; + file.write_all(contents.as_bytes())?; + Ok(()) + } + + fn read_bytes(&self, file_name: &str) -> PyResult> { + let path = Path::new(file_name); + fs::read(path).map_err(|e| { + PyFileNotFoundError::new_err(format!("Failed to read file {file_name}: {e}")) + }) + } + + fn write_bytes(&mut self, file_name: &str, contents: &[u8]) -> PyResult<()> { + let file_path: PathBuf = file_name.into(); + if let Some(parent_dir) = file_path.parent() { + fs::create_dir_all(parent_dir)?; + } + let mut file = File::create(file_path)?; + file.write_all(contents)?; + Ok(()) + } } #[pymethods] @@ -161,13 +196,27 @@ impl
PyRealBasicFileSystem { fn read(&self, file_name: &str) -> PyResult { self.inner.read(file_name) } + + fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> { + self.inner.write(file_name, contents) + } + + fn read_bytes(&self, file_name: &str) -> PyResult> { + self.inner.read_bytes(file_name) + } + + fn write_bytes(&mut self, file_name: &str, contents: Vec) -> PyResult<()> { + self.inner.write_bytes(file_name, &contents) + } } type FileSystemContents = HashMap; +type BinaryFileSystemContents = HashMap>; #[derive(Clone)] pub struct FakeBasicFileSystem { - contents: Box, + contents: Arc>, + binary_contents: Arc>, } // Implements BasicFileSystem (defined in grimp.application.ports.filesystem.BasicFileSystem). @@ -190,7 +239,8 @@ impl FakeBasicFileSystem { parsed_contents.extend(unindented_map); }; Ok(FakeBasicFileSystem { - contents: Box::new(parsed_contents), + contents: Arc::new(Mutex::new(parsed_contents)), + binary_contents: Arc::new(Mutex::new(HashMap::new())), }) } } @@ -232,17 +282,42 @@ impl FileSystem for FakeBasicFileSystem { /// Checks if a file or directory exists within the file system. 
fn exists(&self, file_name: &str) -> bool { - self.contents.contains_key(file_name) + self.contents.lock().unwrap().contains_key(file_name) + || self.binary_contents.lock().unwrap().contains_key(file_name) } fn read(&self, file_name: &str) -> PyResult { - match self.contents.get(file_name) { - Some(file_name) => Ok(file_name.clone()), + let contents = self.contents.lock().unwrap(); + match contents.get(file_name) { + Some(file_contents) => Ok(file_contents.clone()), + None => Err(PyFileNotFoundError::new_err(format!( + "No such file: {file_name}" + ))), + } + } + + #[allow(unused_variables)] + fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> { + let mut contents_mut = self.contents.lock().unwrap(); + contents_mut.insert(file_name.to_string(), contents.to_string()); + Ok(()) + } + + fn read_bytes(&self, file_name: &str) -> PyResult> { + let binary_contents = self.binary_contents.lock().unwrap(); + match binary_contents.get(file_name) { + Some(file_contents) => Ok(file_contents.clone()), None => Err(PyFileNotFoundError::new_err(format!( "No such file: {file_name}" ))), } } + + fn write_bytes(&mut self, file_name: &str, contents: &[u8]) -> PyResult<()> { + let mut binary_contents_mut = self.binary_contents.lock().unwrap(); + binary_contents_mut.insert(file_name.to_string(), contents.to_vec()); + Ok(()) + } } #[pymethods] @@ -278,6 +353,18 @@ impl PyFakeBasicFileSystem { self.inner.read(file_name) } + fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> { + self.inner.write(file_name, contents) + } + + fn read_bytes(&self, file_name: &str) -> PyResult> { + self.inner.read_bytes(file_name) + } + + fn write_bytes(&mut self, file_name: &str, contents: Vec) -> PyResult<()> { + self.inner.write_bytes(file_name, &contents) + } + // Temporary workaround method for Python tests. 
fn convert_to_basic(&self) -> PyResult { Ok(PyFakeBasicFileSystem { @@ -381,7 +468,6 @@ pub fn get_file_system_boxed<'py>( file_system: &Bound<'py, PyAny>, ) -> PyResult> { let file_system_boxed: Box; - if let Ok(py_real) = file_system.extract::>() { file_system_boxed = Box::new(py_real.inner.clone()); } else if let Ok(py_fake) = file_system.extract::>() { @@ -391,5 +477,6 @@ pub fn get_file_system_boxed<'py>( "file_system must be an instance of RealBasicFileSystem or FakeBasicFileSystem", )); } + Ok(file_system_boxed) } diff --git a/rust/src/import_scanning.rs b/rust/src/import_scanning.rs index 0719c035..045ce411 100644 --- a/rust/src/import_scanning.rs +++ b/rust/src/import_scanning.rs @@ -18,6 +18,22 @@ pub struct DirectImport { pub line_contents: String, } +impl<'py> FromPyObject<'py> for DirectImport { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let importer: String = ob.getattr("importer")?.getattr("name")?.extract()?; + let imported: String = ob.getattr("imported")?.getattr("name")?.extract()?; + let line_number: usize = ob.getattr("line_number")?.extract()?; + let line_contents: String = ob.getattr("line_contents")?.extract()?; + + Ok(DirectImport { + importer, + imported, + line_number, + line_contents, + }) + } +} + pub fn py_found_packages_to_rust(py_found_packages: &Bound<'_, PyAny>) -> HashSet { let py_set = py_found_packages .downcast::() diff --git a/rust/src/lib.rs b/rust/src/lib.rs index e0453f26..e63b501b 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -18,6 +18,9 @@ mod _rustgrimp { #[pymodule_export] use crate::caching::read_cache_data_map_file; + #[pymodule_export] + use crate::caching::write_cache_data_map_file; + #[pymodule_export] use crate::graph::GraphWrapper; diff --git a/src/grimp/adaptors/caching.py b/src/grimp/adaptors/caching.py index 13fe9167..00d97292 100644 --- a/src/grimp/adaptors/caching.py +++ b/src/grimp/adaptors/caching.py @@ -4,7 +4,7 @@ import logging from typing import Dict, List, Optional, Set, Tuple, 
Type -from grimp.application.ports.filesystem import AbstractFileSystem +from grimp.application.ports.filesystem import BasicFileSystem from grimp.application.ports.modulefinder import FoundPackage, ModuleFile from grimp.domain.valueobjects import DirectImport, Module @@ -77,7 +77,7 @@ def __init__(self, *args, namer: Type[CacheFileNamer], **kwargs) -> None: @classmethod def setup( cls, - file_system: AbstractFileSystem, + file_system: BasicFileSystem, found_packages: Set[FoundPackage], include_external_packages: bool, exclude_type_checking_imports: bool = False, @@ -122,22 +122,6 @@ def write( ) -> None: self._write_marker_files_if_not_already_there() # Write data file. - primitives_map: PrimitiveFormat = {} - for found_package in self.found_packages: - primitives_map_for_found_package: PrimitiveFormat = { - module_file.module.name: [ - ( - direct_import.imported.name, - direct_import.line_number, - direct_import.line_contents, - ) - for direct_import in imports_by_module[module_file.module] - ] - for module_file in found_package.module_files - } - primitives_map.update(primitives_map_for_found_package) - - serialized = json.dumps(primitives_map) data_cache_filename = self.file_system.join( self.cache_dir, self._namer.make_data_file_name( @@ -146,7 +130,12 @@ def write( exclude_type_checking_imports=self.exclude_type_checking_imports, ), ) - self.file_system.write(data_cache_filename, serialized) + rust.write_cache_data_map_file( + filename=data_cache_filename, + imports_by_module=imports_by_module, + file_system=self.file_system, + ) + logger.info(f"Wrote data cache file {data_cache_filename}.") # Write meta files. 
@@ -202,7 +191,7 @@ def _read_data_map_file(self) -> Dict[Module, Set[DirectImport]]: ) try: imports_by_module = rust.read_cache_data_map_file( - data_cache_filename, self.file_system.convert_to_basic() + data_cache_filename, self.file_system ) except FileNotFoundError: logger.info(f"No cache file: {data_cache_filename}.") diff --git a/src/grimp/application/ports/caching.py b/src/grimp/application/ports/caching.py index d3f58f9b..6d3aa2bd 100644 --- a/src/grimp/application/ports/caching.py +++ b/src/grimp/application/ports/caching.py @@ -3,7 +3,7 @@ from grimp.application.ports.modulefinder import FoundPackage, ModuleFile from grimp.domain.valueobjects import DirectImport, Module -from .filesystem import AbstractFileSystem +from .filesystem import BasicFileSystem class CacheMiss(Exception): @@ -13,7 +13,7 @@ class CacheMiss(Exception): class Cache: def __init__( self, - file_system: AbstractFileSystem, + file_system: BasicFileSystem, include_external_packages: bool, exclude_type_checking_imports: bool, found_packages: Set[FoundPackage], @@ -31,7 +31,7 @@ def __init__( @classmethod def setup( cls, - file_system: AbstractFileSystem, + file_system: BasicFileSystem, found_packages: Set[FoundPackage], *, include_external_packages: bool, diff --git a/src/grimp/application/ports/filesystem.py b/src/grimp/application/ports/filesystem.py index 5e966cc7..aa06e352 100644 --- a/src/grimp/application/ports/filesystem.py +++ b/src/grimp/application/ports/filesystem.py @@ -93,7 +93,7 @@ def convert_to_basic(self) -> BasicFileSystem: class BasicFileSystem(Protocol): """ - A more limited file system, used by the Rust-based scan_for_imports function. + A more limited file system. Having two different file system APIs is an interim approach, allowing us to implement BasicFileSystem in Rust without needing to implement the full range @@ -109,4 +109,6 @@ def split(self, file_name: str) -> Tuple[str, str]: ... def read(self, file_name: str) -> str: ... 
+ def write(self, file_name: str, contents: str) -> None: ... + def exists(self, file_name: str) -> bool: ... diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index 45d581f8..ae6884b3 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -109,7 +109,7 @@ def _scan_packages( if cache_dir is not None: cache_dir_if_supplied = cache_dir if cache_dir != NotSupplied else None cache: caching.Cache = settings.CACHE_CLASS.setup( - file_system=file_system, + file_system=file_system.convert_to_basic(), found_packages=found_packages, include_external_packages=include_external_packages, exclude_type_checking_imports=exclude_type_checking_imports, diff --git a/tests/unit/adaptors/test_caching.py b/tests/unit/adaptors/test_caching.py index 749f14b3..f4cc4cf3 100644 --- a/tests/unit/adaptors/test_caching.py +++ b/tests/unit/adaptors/test_caching.py @@ -203,7 +203,7 @@ class TestCache: ] }""", }, - ) + ).convert_to_basic() MODULE_FILE_UNMODIFIED = ModuleFile( module=Module("mypackage.foo.unmodified"), mtime=SOME_MTIME ) @@ -250,7 +250,7 @@ def test_logs_missing_cache_files(self, caplog): caplog.set_level(logging.INFO, logger=Cache.__module__) Cache.setup( - file_system=FakeFileSystem(), # No cache files. + file_system=FakeFileSystem().convert_to_basic(), # No cache files. 
found_packages=self.FOUND_PACKAGES, namer=SimplisticFileNamer, include_external_packages=False, @@ -278,7 +278,7 @@ def test_logs_corrupt_cache_meta_file_reading(self, serialized_mtime: str, caplo }}""", ".grimp_cache/mypackage.data.json": "{}", }, - ) + ).convert_to_basic() Cache.setup( file_system=file_system, found_packages=self.FOUND_PACKAGES, @@ -306,7 +306,7 @@ def test_logs_corrupt_cache_data_file_reading(self, caplog): }}""", ".grimp_cache/mypackage.data.json": "INVALID JSON", }, - ) + ).convert_to_basic() Cache.setup( file_system=file_system, @@ -408,7 +408,7 @@ def test_raises_cache_miss_for_missing_module_from_data(self): }}""", ".grimp_cache/mypackage.data.json": """{}""", }, - ) + ).convert_to_basic() module_file = ModuleFile(module=Module("mypackage.somemodule"), mtime=self.SOME_MTIME) cache = Cache.setup( file_system=file_system, @@ -451,7 +451,7 @@ def test_raises_cache_miss_for_corrupt_meta_file(self, serialized_mtime): ] }""", }, - ) + ).convert_to_basic() cache = Cache.setup( file_system=file_system, found_packages=self.FOUND_PACKAGES, @@ -487,7 +487,7 @@ def test_raises_cache_miss_for_corrupt_data_file(self, serialized_import): ] }}""", }, - ) + ).convert_to_basic() cache = Cache.setup( file_system=file_system, found_packages=self.FOUND_PACKAGES, @@ -554,7 +554,14 @@ def test_uses_cache_multiple_packages( | expected_additional_imports ) - @pytest.mark.parametrize("cache_dir", ("/tmp/some-cache-dir", "/tmp/some-cache-dir/", None)) + @pytest.mark.parametrize( + "cache_dir", + ( + "/tmp/some-cache-dir", + "/tmp/some-cache-dir/", + None, + ), + ) @pytest.mark.parametrize( "include_external_packages, expected_data_file_name", ( @@ -566,7 +573,7 @@ def test_write_to_cache( self, include_external_packages, expected_data_file_name, cache_dir, caplog ): caplog.set_level(logging.INFO, logger=Cache.__module__) - file_system = FakeFileSystem() + file_system = FakeFileSystem().convert_to_basic() blue_one = Module(name="blue.one") blue_two = 
Module(name="blue.two") green_one = Module(name="green.one") @@ -673,7 +680,7 @@ def test_write_to_cache( def test_write_to_cache_adds_marker_files(self): some_cache_dir = "/tmp/some-cache-dir" - file_system = FakeFileSystem() + file_system = FakeFileSystem().convert_to_basic() cache = Cache.setup( file_system=file_system, cache_dir=some_cache_dir, diff --git a/tests/unit/adaptors/test_filesystem.py b/tests/unit/adaptors/test_filesystem.py index 68d86f45..117e63fd 100644 --- a/tests/unit/adaptors/test_filesystem.py +++ b/tests/unit/adaptors/test_filesystem.py @@ -148,6 +148,14 @@ def test_read(self, file_name, expected_contents): else: assert file_system.read(file_name) == expected_contents + def test_write(self): + some_filename, some_contents = "path/to/some-file.txt", "Some contents." + file_system = self.file_system_cls() + + file_system.write(some_filename, some_contents) + + assert file_system.read(some_filename) == some_contents + class TestFakeFileSystem(_Base): file_system_cls = FakeFileSystem