diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 17ca55a83b9..9d3d7deaae1 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -36,6 +36,12 @@ VectorQueryResult, ) +_SEGMENT_SCOPE_FAST_MAP = { + chroma_pb.SegmentScope.VECTOR: SegmentScope.VECTOR, + chroma_pb.SegmentScope.METADATA: SegmentScope.METADATA, + chroma_pb.SegmentScope.RECORD: SegmentScope.RECORD, +} + class ProjectionRecord(TypedDict): id: str @@ -158,18 +164,16 @@ def from_proto_submit( def from_proto_segment(segment: chroma_pb.Segment) -> Segment: + has_metadata = segment.HasField("metadata") + file_paths = {name: list(paths.paths) for name, paths in segment.file_paths.items()} + return Segment( id=UUID(hex=segment.id), type=segment.type, - scope=from_proto_segment_scope(segment.scope), + scope=_from_proto_segment_scope_fast(segment.scope), collection=UUID(hex=segment.collection), - metadata=from_proto_metadata(segment.metadata) - if segment.HasField("metadata") - else None, - file_paths={ - name: [path for path in paths.paths] - for name, paths in segment.file_paths.items() - }, + metadata=from_proto_metadata(segment.metadata) if has_metadata else None, + file_paths=file_paths, ) @@ -686,3 +690,12 @@ def from_proto_knn_batch_result( [from_proto_knn_projection_record(record) for record in result.records] for result in results.results ] + + +def _from_proto_segment_scope_fast( + segment_scope: chroma_pb.SegmentScope, +) -> SegmentScope: + try: + return _SEGMENT_SCOPE_FAST_MAP[segment_scope] + except KeyError: + raise RuntimeError(f"Unknown segment scope {segment_scope}")