-
Notifications
You must be signed in to change notification settings - Fork 7
Collection: making the module provider agnostic #508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| """adding blob column in collection table | ||
| Revision ID: 041 | ||
| Revises: 040 | ||
| Create Date: 2025-12-24 11:03:44.620424 | ||
| """ | ||
| from alembic import op | ||
| import sqlalchemy as sa | ||
| from sqlalchemy.dialects import postgresql | ||
|
|
||
| revision = "041" | ||
| down_revision = "040" | ||
| branch_labels = None | ||
| depends_on = None | ||
|
|
||
| provider_enum = postgresql.ENUM( | ||
| "openai", | ||
| name="providertype", | ||
| create_type=True, | ||
| ) | ||
|
|
||
|
|
||
| def upgrade(): | ||
| provider_enum.create(op.get_bind(), checkfirst=True) | ||
|
|
||
| op.add_column( | ||
| "collection", | ||
| sa.Column( | ||
| "collection_blob", | ||
| postgresql.JSONB(astext_type=sa.Text()), | ||
| nullable=True, | ||
| comment="Provider-specific collection parameters (name, description, chunking params etc.)", | ||
| ), | ||
| ) | ||
|
|
||
| op.add_column( | ||
| "collection", | ||
| sa.Column( | ||
| "provider", | ||
| provider_enum, | ||
| nullable=True, | ||
| comment="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini')", | ||
| ), | ||
| ) | ||
|
|
||
| op.execute("UPDATE collection SET provider = 'openai' WHERE provider IS NULL") | ||
|
|
||
| op.alter_column( | ||
| "collection", | ||
| "provider", | ||
| nullable=False, | ||
| existing_type=provider_enum, | ||
| ) | ||
|
|
||
| op.alter_column( | ||
| "collection", | ||
| "llm_service_name", | ||
| existing_type=sa.VARCHAR(), | ||
| comment="Name of the LLM service", | ||
| existing_comment="Name of the LLM provider's service", | ||
| existing_nullable=False, | ||
| ) | ||
|
|
||
|
|
||
| def downgrade(): | ||
| op.alter_column( | ||
| "collection", | ||
| "llm_service_name", | ||
| existing_type=sa.VARCHAR(), | ||
| comment="Name of the LLM service provider", | ||
| existing_comment="Name of the LLM service", | ||
| existing_nullable=False, | ||
| ) | ||
| op.drop_column("collection", "provider") | ||
| op.drop_column("collection", "collection_blob") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from app.models.collection.request import ( | ||
| Collection, | ||
| CreationRequest, | ||
| DeletionRequest, | ||
| CallbackRequest, | ||
| AssistantOptions, | ||
| CreateCollectionParams, | ||
| ProviderType, | ||
| ) | ||
| from app.models.collection.response import ( | ||
| CollectionIDPublic, | ||
| CollectionPublic, | ||
| CollectionWithDocsPublic, | ||
| CreateCollectionResult, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,15 +1,26 @@ | ||
| from datetime import datetime | ||
| from enum import Enum | ||
| from typing import Any, Literal | ||
| from uuid import UUID, uuid4 | ||
|
|
||
| from pydantic import HttpUrl, model_validator | ||
| import sqlalchemy as sa | ||
| from sqlalchemy.dialects.postgresql import JSONB, ENUM | ||
| from sqlmodel import Field, Relationship, SQLModel | ||
|
|
||
| from app.core.util import now | ||
| from app.models.document import DocumentPublic | ||
| from app.models.organization import Organization | ||
| from app.models.project import Project | ||
|
|
||
| from .organization import Organization | ||
| from .project import Project | ||
|
|
||
| class ProviderType(str, Enum): | ||
| """Supported LLM providers for collections.""" | ||
|
|
||
| OPENAI = "openai" | ||
|
|
||
|
|
||
| # BEDROCK = "bedrock" | ||
| # GEMINI = "gemini" | ||
|
|
||
|
|
||
| class Collection(SQLModel, table=True): | ||
|
|
@@ -20,6 +31,20 @@ class Collection(SQLModel, table=True): | |
| primary_key=True, | ||
| sa_column_kwargs={"comment": "Unique identifier for the collection"}, | ||
| ) | ||
|
|
||
| provider: ProviderType = Field( | ||
| sa_column=sa.Column( | ||
| ENUM( | ||
| "openai", | ||
| # "bedrock", | ||
| # "gemini", | ||
| name="providertype", | ||
| create_type=False, | ||
| ), | ||
| nullable=False, | ||
| comment="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini', etc)", | ||
| ), | ||
| ) | ||
| llm_service_id: str = Field( | ||
| nullable=False, | ||
| sa_column_kwargs={ | ||
|
|
@@ -30,8 +55,13 @@ class Collection(SQLModel, table=True): | |
| nullable=False, | ||
| sa_column_kwargs={"comment": "Name of the LLM service"}, | ||
| ) | ||
|
|
||
| # Foreign keys | ||
| collection_blob: dict[str, Any] | None = Field( | ||
| sa_column=sa.Column( | ||
| JSONB, | ||
| nullable=True, | ||
| comment="Provider-specific collection parameters (name, description, chunking params etc.)", | ||
| ) | ||
| ) | ||
| organization_id: int = Field( | ||
| foreign_key="organization.id", | ||
| nullable=False, | ||
|
|
@@ -44,8 +74,6 @@ class Collection(SQLModel, table=True): | |
| ondelete="CASCADE", | ||
| sa_column_kwargs={"comment": "Reference to the project"}, | ||
| ) | ||
|
|
||
| # Timestamps | ||
| inserted_at: datetime = Field( | ||
| default_factory=now, | ||
| sa_column_kwargs={"comment": "Timestamp when the collection was created"}, | ||
|
|
@@ -64,27 +92,55 @@ class Collection(SQLModel, table=True): | |
| project: Project = Relationship(back_populates="collections") | ||
|
|
||
|
|
||
| # Request models | ||
| class DocumentOptions(SQLModel): | ||
| documents: list[UUID] = Field( | ||
| description="List of document IDs", | ||
| class DocumentInput(SQLModel): | ||
| """Document to be added to knowledge base.""" | ||
|
|
||
| name: str | None = Field( | ||
| description="Display name for the document", | ||
| ) | ||
| batch_size: int = Field( | ||
| default=1, | ||
| description=( | ||
| "Number of documents to send to OpenAI in a single " | ||
| "transaction. See the `file_ids` parameter in the " | ||
| "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)." | ||
| ), | ||
| id: UUID = Field( | ||
| description="Reference to uploaded file/document in Kaapi", | ||
| ) | ||
|
Comment on lines
+95
to
+103
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The `name` field is annotated `str | None` but declares no default, making it required despite being optional in intent; add `default=None`. 🔎 Proposed fix class DocumentInput(SQLModel):
"""Document to be added to knowledge base."""
name: str | None = Field(
+ default=None,
description="Display name for the document",
)
id: UUID = Field(
description="Reference to uploaded file/document in Kaapi",
)🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| class CreateCollectionParams(SQLModel): | ||
| """Request-specific parameters for knowledge base creation.""" | ||
|
|
||
| name: str | None = Field( | ||
| min_length=1, | ||
| description="Name of the knowledge base to create or update", | ||
| ) | ||
|
Comment on lines
+109
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # Find and read the file to see the actual context
fd backend/app/models/collection/request.py --type fRepository: ProjectTech4DevAI/kaapi-backend Length of output: 57 🏁 Script executed: # Read the specific file around lines 109-112
if [ -f "backend/app/models/collection/request.py" ]; then
wc -l backend/app/models/collection/request.py
fiRepository: ProjectTech4DevAI/kaapi-backend Length of output: 119 🏁 Script executed: # Get the file structure to understand imports and class definition
head -50 backend/app/models/collection/request.pyRepository: ProjectTech4DevAI/kaapi-backend Length of output: 1386 🏁 Script executed: # Get lines 100-120 to see the name field and context
sed -n '100,120p' backend/app/models/collection/request.pyRepository: ProjectTech4DevAI/kaapi-backend Length of output: 729 🏁 Script executed: python3 << 'EOF'
from pydantic import Field
from sqlmodel import SQLModel
class CreateCollectionParams(SQLModel):
name: str | None = Field(
min_length=1,
description="Name of the knowledge base to create or update",
)
# Test 1: None value
try:
t1 = CreateCollectionParams(name=None)
print(f'Test 1 - None accepted: {t1.name}')
except Exception as e:
print(f'Test 1 - Validation error with None: {type(e).__name__}: {e}')
# Test 2: Empty string
try:
t2 = CreateCollectionParams(name="")
print(f'Test 2 - Empty string accepted: {t2.name}')
except Exception as e:
print(f'Test 2 - Validation error with empty string: {type(e).__name__}: {e}')
# Test 3: Valid string
try:
t3 = CreateCollectionParams(name="test")
print(f'Test 3 - Valid string accepted: {t3.name}')
except Exception as e:
print(f'Test 3 - Validation error with valid string: {type(e).__name__}: {e}')
# Test 4: No name provided
try:
t4 = CreateCollectionParams()
print(f'Test 4 - No name provided, defaults to: {t4.name}')
except Exception as e:
print(f'Test 4 - Validation error when name not provided: {type(e).__name__}: {e}')
EOFRepository: ProjectTech4DevAI/kaapi-backend Length of output: 161 🏁 Script executed: # Check the full CreateCollectionParams class and surrounding context
sed -n '105,145p' backend/app/models/collection/request.pyRepository: ProjectTech4DevAI/kaapi-backend Length of output: 1494 🏁 Script executed: # Check if there are imports or other usage patterns in the file
grep -n "sa_column_kwargs" backend/app/models/collection/request.py | head -20Repository: ProjectTech4DevAI/kaapi-backend Length of output: 659 Add explicit The 🤖 Prompt for AI Agents |
||
| description: str | None = Field( | ||
| default=None, | ||
| description="Description of the knowledge base (required by Bedrock, optional for others)", | ||
| ) | ||
| documents: list[DocumentInput] = Field( | ||
| default_factory=list, | ||
| description="List of documents to add to the knowledge base", | ||
| ) | ||
| chunking_params: dict[str, Any] | None = Field( | ||
| default=None, | ||
| description="Chunking parameters for document processing (e.g., chunk_size, chunk_overlap)", | ||
| ) | ||
| additional_params: dict[str, Any] | None = Field( | ||
| default=None, | ||
| description="Additional provider-specific parameters", | ||
| ) | ||
|
|
||
| def model_post_init(self, __context: Any): | ||
| self.documents = list(set(self.documents)) | ||
| """Deduplicate documents by document id.""" | ||
| seen = set() | ||
| unique_docs = [] | ||
| for doc in self.documents: | ||
| if doc.id not in seen: | ||
| seen.add(doc.id) | ||
| unique_docs.append(doc) | ||
| self.documents = unique_docs | ||
|
|
||
|
|
||
| class AssistantOptions(SQLModel): | ||
| # Fields to be passed along to OpenAI. They must be a subset of | ||
| # parameters accepted by the OpenAI.clien.beta.assistants.create | ||
| # parameters accepted by the OpenAI.client.beta.assistants.create | ||
| # API. | ||
| model: str | None = Field( | ||
| default=None, | ||
|
|
@@ -139,6 +195,8 @@ def norm(x: Any) -> Any: | |
|
|
||
|
|
||
| class CallbackRequest(SQLModel): | ||
| """Optional callback configuration for async job notifications.""" | ||
|
|
||
| callback_url: HttpUrl | None = Field( | ||
| default=None, | ||
| description="URL to call to report endpoint status", | ||
|
|
@@ -148,45 +206,38 @@ class CallbackRequest(SQLModel): | |
| class ProviderOptions(SQLModel): | ||
| """LLM provider configuration.""" | ||
|
|
||
| provider: Literal["openai"] = Field( | ||
| default="openai", description="LLM provider to use for this collection" | ||
| provider: ProviderType = Field( | ||
| default=ProviderType.OPENAI, | ||
| description="LLM provider to use for this collection", | ||
| ) | ||
|
|
||
|
|
||
| class CreationRequest( | ||
| DocumentOptions, | ||
| ProviderOptions, | ||
| AssistantOptions, | ||
| CallbackRequest, | ||
| ): | ||
| def extract_super_type(self, cls: "CreationRequest"): | ||
| for field_name in cls.model_fields.keys(): | ||
| field_value = getattr(self, field_name) | ||
| yield (field_name, field_value) | ||
|
|
||
|
|
||
| class DeletionRequest(CallbackRequest): | ||
| collection_id: UUID = Field(description="Collection to delete") | ||
|
|
||
|
|
||
| # Response models | ||
| @model_validator(mode="before") | ||
| def normalize_provider(cls, values: dict[str, Any]) -> dict[str, Any]: | ||
| """Normalize provider value to lowercase for case-insensitive matching.""" | ||
| if isinstance(values, dict) and "provider" in values: | ||
| provider = values["provider"] | ||
| if isinstance(provider, str): | ||
| values["provider"] = provider.lower() | ||
| return values | ||
|
|
||
|
|
||
| class CollectionIDPublic(SQLModel): | ||
| id: UUID | ||
| class CreationRequest(AssistantOptions, ProviderOptions, CallbackRequest): | ||
| """API request for collection creation""" | ||
|
|
||
| collection_params: CreateCollectionParams = Field( | ||
| ..., | ||
| description="Collection creation specific parameters (name, documents, etc.)", | ||
| ) | ||
| batch_size: int = Field( | ||
| default=10, | ||
| ge=1, | ||
| le=500, | ||
| description="Number of documents to process in a single batch", | ||
| ) | ||
|
|
||
| class CollectionPublic(SQLModel): | ||
| id: UUID | ||
| llm_service_id: str | ||
| llm_service_name: str | ||
| project_id: int | ||
| organization_id: int | ||
|
|
||
| inserted_at: datetime | ||
| updated_at: datetime | ||
| deleted_at: datetime | None = None | ||
| class DeletionRequest(ProviderOptions, CallbackRequest): | ||
|
|
||
| """API request for collection deletion""" | ||
|
|
||
| class CollectionWithDocsPublic(CollectionPublic): | ||
| documents: list[DocumentPublic] | None = None | ||
| collection_id: UUID = Field(description="Collection to delete") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| from datetime import datetime | ||
| from typing import Any | ||
| from uuid import UUID | ||
|
|
||
| from sqlmodel import SQLModel | ||
|
|
||
| from app.models.document import DocumentPublic | ||
|
|
||
|
|
||
| class CreateCollectionResult(SQLModel): | ||
| llm_service_id: str | ||
| llm_service_name: str | ||
| collection_blob: dict[str, Any] | ||
|
|
||
|
|
||
| class CollectionIDPublic(SQLModel): | ||
| id: UUID | ||
|
|
||
|
|
||
| class CollectionPublic(SQLModel): | ||
| id: UUID | ||
| llm_service_id: str | ||
| llm_service_name: str | ||
| project_id: int | ||
| organization_id: int | ||
|
|
||
| inserted_at: datetime | ||
| updated_at: datetime | ||
| deleted_at: datetime | None = None | ||
|
|
||
|
|
||
| class CollectionWithDocsPublic(CollectionPublic): | ||
| documents: list[DocumentPublic] | None = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing ENUM type drop in downgrade.
The
`downgrade()` function drops the `provider` and `collection_blob` columns but doesn't drop the `providertype` ENUM type. This could leave orphaned types in the database after a rollback. 🔎 Proposed fix
def downgrade(): op.alter_column( "collection", "llm_service_name", existing_type=sa.VARCHAR(), comment="Name of the LLM service provider", existing_comment="Name of the LLM service", existing_nullable=False, ) op.drop_column("collection", "provider") op.drop_column("collection", "collection_blob") + provider_enum.drop(op.get_bind(), checkfirst=True)🤖 Prompt for AI Agents