Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@
import struct
import math

_ALLOWED_DTYPES = {
np.dtype(np.float16),
np.dtype(np.float32),
np.dtype(np.float64),
np.dtype(np.int32),
np.dtype(np.int64),
}

# Re-export types from chromadb.types
__all__ = [
"Metadata",
Expand Down Expand Up @@ -1270,11 +1278,14 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
raise ValueError(
f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
)
if not all([isinstance(e, np.ndarray) for e in embeddings]):
raise ValueError(
"Expected each embedding in the embeddings to be a numpy array, got "
f"{list(set([type(e).__name__ for e in embeddings]))}"
)

for e in embeddings:
if not isinstance(e, np.ndarray):
raise ValueError(
"Expected each embedding in the embeddings to be a numpy array, got "
f"{list(set([type(e).__name__ for e in embeddings]))}"
)

for i, embedding in enumerate(embeddings):
if embedding.ndim == 0:
raise ValueError(
Expand All @@ -1285,13 +1296,7 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
f"Expected each embedding in the embeddings to be a 1-dimensional numpy array with at least 1 int/float value. Got a 1-dimensional numpy array with no values at pos {i}"
)

if embedding.dtype not in [
np.float16,
np.float32,
np.float64,
np.int32,
np.int64,
]:
if embedding.dtype not in _ALLOWED_DTYPES:
raise ValueError(
"Expected each value in the embedding to be a int or float, got an embedding with "
f"{embedding.dtype} - {embedding}"
Expand Down