Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions chromadb/proto/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json

import numpy as np
from numpy.typing import NDArray

import chromadb.proto.chroma_pb2 as chroma_pb
import chromadb.proto.query_executor_pb2 as query_pb
Expand Down Expand Up @@ -36,6 +35,18 @@
VectorQueryResult,
)

# Protobuf enum values resolved once at import time. The conversion helpers
# below compare against these module-level names rather than re-resolving
# chroma_pb.ScalarEncoding.* / chroma_pb.Operation.* on every call.

# Scalar encodings recognised by from_proto_vector.
_float32_encoding = chroma_pb.ScalarEncoding.FLOAT32

_int32_encoding = chroma_pb.ScalarEncoding.INT32

# Wire-level operation codes translated by from_proto_operation.
_add_operation = chroma_pb.Operation.ADD

_update_operation = chroma_pb.Operation.UPDATE

_upsert_operation = chroma_pb.Operation.UPSERT

_delete_operation = chroma_pb.Operation.DELETE


class ProjectionRecord(TypedDict):
id: str
Expand Down Expand Up @@ -70,32 +81,32 @@ def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> chroma_pb.Vecto

def from_proto_vector(vector: chroma_pb.Vector) -> Tuple[Embedding, ScalarEncoding]:
    """Decode a protobuf Vector into a NumPy array and its scalar encoding.

    Args:
        vector: Protobuf message whose ``vector`` field holds the raw element
            buffer and whose ``encoding`` field names the element type.

    Returns:
        A ``(embedding, encoding)`` tuple: the buffer reinterpreted as
        ``np.float32`` or ``np.int32``, paired with the matching
        ``ScalarEncoding`` member.

    Raises:
        ValueError: If ``vector.encoding`` is neither FLOAT32 nor INT32.
    """
    encoding = vector.encoding
    # np.frombuffer reinterprets the existing buffer instead of copying it.
    # NOTE(review): if vector.vector is an immutable bytes object the
    # resulting array is read-only -- confirm callers do not mutate it.
    if encoding == _float32_encoding:
        return (np.frombuffer(vector.vector, dtype=np.float32), ScalarEncoding.FLOAT32)
    if encoding == _int32_encoding:
        return (np.frombuffer(vector.vector, dtype=np.int32), ScalarEncoding.INT32)
    # TODO: full error
    # Adjacent f-strings (not a backslash continuation) keep source
    # indentation out of the message text.
    raise ValueError(
        f"Unknown encoding {encoding}, expected one of "
        f"{_float32_encoding} or {_int32_encoding}"
    )


# Lookup table from protobuf operation codes to the internal Operation enum.
# Built once at import time so each conversion is a single dict lookup rather
# than a dict construction followed by a lookup.
_operation_map = {
    _add_operation: Operation.ADD,
    _update_operation: Operation.UPDATE,
    _upsert_operation: Operation.UPSERT,
    _delete_operation: Operation.DELETE,
}


def from_proto_operation(operation: chroma_pb.Operation) -> Operation:
    """Translate a protobuf Operation code into the internal Operation enum.

    Args:
        operation: Protobuf operation value (ADD, UPDATE, UPSERT or DELETE).

    Returns:
        The ``Operation`` member corresponding to ``operation``.

    Raises:
        RuntimeError: If ``operation`` is not one of the four known codes.
    """
    try:
        return _operation_map[operation]
    except KeyError:
        # TODO: full error
        # "from None" drops the internal KeyError from the traceback so
        # callers see only the RuntimeError.
        raise RuntimeError(f"Unknown operation {operation}") from None

Expand All @@ -107,6 +118,7 @@ def from_proto_metadata(metadata: chroma_pb.UpdateMetadata) -> Optional[Metadata
def from_proto_update_metadata(
    metadata: chroma_pb.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    """Convert a protobuf UpdateMetadata message to an optional UpdateMetadata.

    Delegates to ``_from_proto_metadata_handle_none`` with ``True`` as its
    second positional argument and narrows the result type for callers.
    NOTE(review): the meaning of that flag is defined by the helper, which is
    not visible here -- confirm against its definition.
    """
    converted = _from_proto_metadata_handle_none(metadata, True)
    # cast() is a no-op at runtime; it only informs the type checker.
    return cast(Optional[UpdateMetadata], converted)
Expand Down Expand Up @@ -144,7 +156,8 @@ def from_proto_submit(
operation_record: chroma_pb.OperationRecord, seq_id: SeqId
) -> LogRecord:
embedding, encoding = from_proto_vector(operation_record.vector)
record = LogRecord(
# Inline OperationRecord and LogRecord construction; minimize attribute lookups
return LogRecord(
log_offset=seq_id,
record=OperationRecord(
id=operation_record.id,
Expand All @@ -154,7 +167,6 @@ def from_proto_submit(
operation=from_proto_operation(operation_record.operation),
),
)
return record


def from_proto_segment(segment: chroma_pb.Segment) -> Segment:
Expand Down