Conversation

codeflash-ai bot commented Oct 26, 2025

📄 69% (0.69x) speedup for SparseVectorIndexConfig.validate_embedding_function_field in chromadb/api/types.py

⏱️ Runtime: 1.71 milliseconds → 1.01 milliseconds (best of 102 runs)

📝 Explanation and details

The optimization introduces **signature caching** to eliminate redundant computation during validation. The key changes are:

1. **Cached protocol signature**: The protocol signature `signature(SparseEmbeddingFunction.__call__).parameters.keys()` is computed once and cached globally, rather than being recomputed on every validation call.

2. **Tuple conversion for faster comparison**: Both signatures are converted to tuples instead of comparing `dict_keys` objects directly, which provides faster equality comparison in Python.

3. **Lazy initialization**: The protocol signature is computed only when first needed via `_get_protocol_signature()`, avoiding any import-time overhead.

**Why this leads to speedup**: The `inspect.signature()` function performs introspection on the method, which involves parsing the function's metadata. This is computationally expensive when done repeatedly. By caching the protocol signature (which never changes), we eliminate this repeated work. The tuple conversion also optimizes the comparison operation itself.
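
For reference, here is a minimal sketch of the caching pattern described above. It is not the exact diff: `_PROTOCOL_SIGNATURE` and `_signature_matches` are illustrative names, only `_get_protocol_signature()` is taken from the description, and `SparseEmbeddingFunction` is assumed to be the protocol that lives alongside this code in `chromadb/api/types.py`.

```python
from inspect import signature
from typing import Optional, Tuple

# Assumption: SparseEmbeddingFunction is the protocol referenced in the
# description; in the real code it is defined in the same module.
from chromadb.api.types import SparseEmbeddingFunction

_PROTOCOL_SIGNATURE: Optional[Tuple[str, ...]] = None  # computed lazily, never changes


def _get_protocol_signature() -> Tuple[str, ...]:
    """Introspect the protocol's __call__ parameter names once and reuse the result."""
    global _PROTOCOL_SIGNATURE
    if _PROTOCOL_SIGNATURE is None:
        _PROTOCOL_SIGNATURE = tuple(
            signature(SparseEmbeddingFunction.__call__).parameters.keys()
        )
    return _PROTOCOL_SIGNATURE


def _signature_matches(candidate: object) -> bool:
    # Use the unbound __call__ of the candidate's class so 'self' is counted
    # the same way on both sides; tuple equality is cheaper than comparing
    # dict_keys views directly.
    candidate_params = tuple(signature(type(candidate).__call__).parameters.keys())
    return candidate_params == _get_protocol_signature()
```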

**Test case performance patterns**: The optimization shows consistent 30-40% speedups across all test cases that involve signature validation (e.g., `test_valid_sparse_embedding_function`: 38.9% faster, `test_multiple_instances_large_scale`: 90.3% faster). The most dramatic improvements occur in scenarios with multiple validations, where the caching benefit compounds. Simple cases like `test_none_embedding_function` show minimal impact since they bypass signature validation entirely.
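
To see the compounding benefit of repeated validations directly, a hypothetical micro-benchmark in the spirit of `test_multiple_instances_large_scale` could look like this (timings vary by machine; the class below mirrors the valid embedding functions used in the generated tests):

```python
import timeit

from chromadb.api.types import SparseVectorIndexConfig


class ValidEmbeddingFunction:
    def __call__(self, input):  # parameter names match the protocol's __call__
        return [{} for _ in input]


func = ValidEmbeddingFunction()

# Every iteration still inspects the candidate's signature, but with caching
# the protocol signature is introspected only once across all iterations.
elapsed = timeit.timeit(
    lambda: SparseVectorIndexConfig.validate_embedding_function_field(func),
    number=100,
)
print(f"100 validations took {elapsed * 1e3:.2f} ms")
```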

Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 134 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from abc import abstractmethod
from inspect import signature
# Function under test (copied from above, with necessary dependencies)
from typing import Any, List, Optional, TypeVar, Union, cast

# imports
import pytest
from chromadb.api.types import SparseVectorIndexConfig
from typing_extensions import Protocol, runtime_checkable

# Minimal local stubs for the annotations used in the test classes below
# (mirroring the stubs in the second test module); the validator only
# compares parameter names, so plain aliases suffice.
Document = str
Documents = List[Document]
SparseVector = dict
SparseVectors = List[SparseVector]

# ===========================
# Unit Tests for the function
# ===========================

# ----------- Basic Test Cases -----------

def test_none_embedding_function():
    # Should accept None
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(None) # 478ns -> 501ns (4.59% slower)

def test_valid_sparse_embedding_function():
    # Should accept a valid callable with correct signature
    class ValidEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, input: Documents) -> SparseVectors:
            return [{} for _ in input]
    func = ValidEmbeddingFunction()
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(func) # 30.9μs -> 22.2μs (38.9% faster)

def test_valid_lambda_function_with_correct_signature():
    # Should accept a lambda with correct signature
    func = lambda self, input: [{} for _ in input]
    # We need to bind it to an object to mimic __call__ method
    class LambdaEmbeddingFunction:
        def __init__(self):
            pass
        __call__ = func
    obj = LambdaEmbeddingFunction()
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(obj) # 24.9μs -> 17.5μs (42.5% faster)

def test_invalid_not_callable():
    # Should raise ValueError for non-callable non-None input
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field("not_callable") # 901ns -> 884ns (1.92% faster)

# ----------- Edge Test Cases -----------

def test_missing_input_parameter_in_call():
    # Should raise ValueError if __call__ signature is wrong (missing input)
    class InvalidEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self) -> SparseVectors:
            return [{}]
    func = InvalidEmbeddingFunction()
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(func) # 28.6μs -> 20.8μs (37.7% faster)

def test_extra_parameter_in_call():
    # Should raise ValueError if __call__ signature is wrong (extra param)
    class InvalidEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, input: Documents, extra: int) -> SparseVectors:
            return [{} for _ in input]
    func = InvalidEmbeddingFunction()
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(func) # 28.1μs -> 21.3μs (31.7% faster)

def test_wrong_return_type():
    # Should accept function even if return type is not correct (type hints are not enforced at runtime)
    class WeirdEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, input: Documents) -> int:
            return 42
    func = WeirdEmbeddingFunction()
    # The signature matches so it should pass, even though return type is wrong
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(func) # 23.9μs -> 17.1μs (39.7% faster)

def test_non_class_callable():
    # Should accept a callable object (not a class) with correct signature
    class CallableObj:
        def __init__(self):
            pass
        def __call__(self, input: Documents) -> SparseVectors:
            return [{} for _ in input]
    obj = CallableObj()
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(obj) # 23.6μs -> 16.9μs (39.4% faster)

def test_function_with_varargs():
    # Should raise ValueError if __call__ uses *args
    class VarArgsEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, *args) -> SparseVectors:
            return [{}]
    func = VarArgsEmbeddingFunction()
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(func) # 27.5μs -> 19.5μs (41.1% faster)

def test_function_with_kwargs():
    # Should raise ValueError if __call__ uses **kwargs
    class KwArgsEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, input: Documents, **kwargs) -> SparseVectors:
            return [{} for _ in input]
    func = KwArgsEmbeddingFunction()
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(func) # 27.9μs -> 20.8μs (34.0% faster)



def test_large_number_of_documents():
    # Should handle large input without error
    class ValidEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, input: Documents) -> SparseVectors:
            return [{} for _ in input]
    func = ValidEmbeddingFunction()
    large_docs = ["doc" + str(i) for i in range(1000)]
    # Should not raise error
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(func) # 31.1μs -> 22.3μs (39.6% faster)

def test_large_number_of_sparse_vectors_returned():
    # Should handle large output without error
    class ValidEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, input: Documents) -> SparseVectors:
            return [{} for _ in range(1000)]
    func = ValidEmbeddingFunction()
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(func) # 25.2μs -> 18.2μs (38.4% faster)

def test_multiple_instances_large_scale():
    # Should handle multiple instances of valid embedding functions
    class ValidEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self, input: Documents) -> SparseVectors:
            return [{} for _ in input]
    for _ in range(100):
        func = ValidEmbeddingFunction()
        codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(func) # 1.07ms -> 562μs (90.3% faster)

def test_large_scale_invalid_signature():
    # Should raise ValueError for many instances with invalid signature
    class InvalidEmbeddingFunction:
        def __init__(self):
            pass
        def __call__(self) -> SparseVectors:
            return [{}]
    for _ in range(10):
        func = InvalidEmbeddingFunction()
        with pytest.raises(ValueError):
            SparseVectorIndexConfig.validate_embedding_function_field(func)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from abc import abstractmethod
from inspect import signature
from typing import Any, List, Optional, TypeVar, Union, cast

# function to test (from chromadb/api/types.py)
import numpy as np
# imports
import pytest
from chromadb.api.types import SparseVectorIndexConfig
from pydantic import BaseModel, field_validator
from typing_extensions import Protocol, runtime_checkable

# Minimal stub for SparseVector and validate_sparse_vectors, since they are referenced but not defined
SparseVector = dict  # For testing, a dict is enough

# Documents
Document = str
Documents = List[Document]
from chromadb.api.types import SparseVectorIndexConfig

# --- UNIT TESTS BEGIN HERE ---

# Helper: a valid embedding function class (minimal)
class ValidSparseEmbeddingFunction:
    def __init__(self):
        pass
    def __call__(self, input: Documents) -> List[SparseVector]:
        # returns a list of dicts (SparseVector)
        return [{} for _ in input]

# Helper: a valid embedding function class with a different signature (should fail)
class InvalidSignatureEmbeddingFunction:
    def __init__(self):
        pass
    def __call__(self, input: Documents, extra: int) -> List[SparseVector]:
        return [{} for _ in input]

# Helper: not callable
class NotCallableEmbeddingFunction:
    pass

# Helper: returns None (should fail validate_sparse_vectors)
class ReturnsNoneEmbeddingFunction:
    def __init__(self):
        pass
    def __call__(self, input: Documents) -> List[SparseVector]:
        return None

# Helper: returns wrong type (should fail validate_sparse_vectors)
class ReturnsWrongTypeEmbeddingFunction:
    def __init__(self):
        pass
    def __call__(self, input: Documents) -> List[SparseVector]:
        return "notalist"

# Helper: returns list of non-dicts (should fail validate_sparse_vectors)
class ReturnsListOfNonDictsEmbeddingFunction:
    def __init__(self):
        pass
    def __call__(self, input: Documents) -> List[SparseVector]:
        return [1, 2, 3]

# -------------------- BASIC TEST CASES --------------------

def test_none_is_accepted():
    """Test that None is accepted and returned as is."""
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(None) # 510ns -> 573ns (11.0% slower)

def test_valid_callable_is_accepted():
    """Test that a valid callable embedding function is accepted and returned as is."""
    func = ValidSparseEmbeddingFunction()
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(func); result = codeflash_output # 31.6μs -> 22.5μs (40.5% faster)

def test_non_callable_raises():
    """Test that a non-callable value raises ValueError."""
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(123) # 894ns -> 937ns (4.59% slower)
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field("not a function") # 499ns -> 495ns (0.808% faster)
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(NotCallableEmbeddingFunction()) # 314ns -> 328ns (4.27% slower)

def test_invalid_signature_raises():
    """Test that a callable with wrong signature raises ValueError."""
    func = InvalidSignatureEmbeddingFunction()
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(func) # 31.5μs -> 23.5μs (34.3% faster)

# -------------------- EDGE TEST CASES --------------------




def test_empty_list_returned_is_accepted():
    """Test that a callable returning an empty list is accepted."""
    class EmptyListEmbeddingFunction:
        def __init__(self): pass
        def __call__(self, input: Documents) -> List[SparseVector]:
            return []
    func = EmptyListEmbeddingFunction()
    SparseVectorIndexConfig.validate_embedding_function_field(func) # 31.0μs -> 22.6μs (37.2% faster)

def test_empty_input_list_is_accepted():
    """Test that a callable can handle an empty input list."""
    func = ValidSparseEmbeddingFunction()
    SparseVectorIndexConfig.validate_embedding_function_field(func) # 24.8μs -> 17.9μs (38.6% faster)
    result = func([])

def test_callable_with_nonstandard_classname():
    """Test that a callable with a weird class name is still accepted if signature is correct."""
    WeirdName = type("WeirdName", (), {
        "__init__": lambda self: None,
        "__call__": lambda self, input: [{} for _ in input]
    })
    func = WeirdName()
    codeflash_output = SparseVectorIndexConfig.validate_embedding_function_field(func); result = codeflash_output # 24.3μs -> 17.0μs (42.7% faster)

# -------------------- LARGE SCALE TEST CASES --------------------

def test_large_input_list_performance():
    """Test that a valid embedding function handles a large input list efficiently."""
    func = ValidSparseEmbeddingFunction()
    SparseVectorIndexConfig.validate_embedding_function_field(func) # 23.9μs -> 16.9μs (40.9% faster)
    input_data = [f"doc{i}" for i in range(1000)]  # 1000 elements, within limit
    result = func(input_data)
    for item in result:
        pass

def test_large_output_list_of_dicts():
    """Test that a callable returning a large list of dicts is accepted."""
    class LargeOutputEmbeddingFunction:
        def __init__(self): pass
        def __call__(self, input: Documents) -> List[SparseVector]:
            # ignore input, return 1000 dicts
            return [{} for _ in range(1000)]
    func = LargeOutputEmbeddingFunction()
    SparseVectorIndexConfig.validate_embedding_function_field(func) # 25.4μs -> 19.4μs (31.2% faster)
    result = func(["doc"])  # input doesn't matter
    for item in result:
        pass



def test_callable_with_kwargs_is_invalid():
    """Test that a callable with **kwargs in signature is rejected."""
    class KwArgsEmbeddingFunction:
        def __init__(self): pass
        def __call__(self, input, **kwargs) -> List[SparseVector]:
            return [{}]
    func = KwArgsEmbeddingFunction()
    with pytest.raises(ValueError):
        SparseVectorIndexConfig.validate_embedding_function_field(func) # 35.6μs -> 26.7μs (33.4% faster)


#------------------------------------------------
from chromadb.api.types import SparseVectorIndexConfig
import pytest

def test_SparseVectorIndexConfig_validate_embedding_function_field():
    with pytest.raises(ValueError, match='embedding_function\\ must\\ be\\ a\\ callable\\ SparseEmbeddingFunction\\ or\\ None'):
        SparseVectorIndexConfig.validate_embedding_function_field(SparseVectorIndexConfig, 0)

To edit these changes, run `git checkout codeflash/optimize-SparseVectorIndexConfig.validate_embedding_function_field-mh7dusnp` and push.

Codeflash

codeflash-ai bot requested a review from mashraf-222 on October 26, 2025 07:24
codeflash-ai bot added the ⚡️ codeflash and 🎯 Quality: Medium labels on Oct 26, 2025