From b7964868c775868ec3a535b072c632964f385ac5 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 24 Dec 2025 16:45:24 +0530 Subject: [PATCH 1/3] pushing the logic first --- ..._adding_blob_column_in_collection_table.py | 47 ++++++ backend/app/models/__init__.py | 4 + backend/app/models/collection/__init__.py | 14 ++ .../{collection.py => collection/request.py} | 122 ++++++++------ backend/app/models/collection/response.py | 33 ++++ .../services/collections/create_collection.py | 118 ++++--------- .../services/collections/delete_collection.py | 28 ++-- backend/app/services/collections/helpers.py | 11 -- .../collections/providers/__init__.py | 6 + .../services/collections/providers/base.py | 84 ++++++++++ .../services/collections/providers/openai.py | 156 ++++++++++++++++++ .../collections/providers/registry.py | 71 ++++++++ 12 files changed, 533 insertions(+), 161 deletions(-) create mode 100644 backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py create mode 100644 backend/app/models/collection/__init__.py rename backend/app/models/{collection.py => collection/request.py} (63%) create mode 100644 backend/app/models/collection/response.py create mode 100644 backend/app/services/collections/providers/__init__.py create mode 100644 backend/app/services/collections/providers/base.py create mode 100644 backend/app/services/collections/providers/openai.py create mode 100644 backend/app/services/collections/providers/registry.py diff --git a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py new file mode 100644 index 00000000..8f65f055 --- /dev/null +++ b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py @@ -0,0 +1,47 @@ +"""adding blob column in collection table + +Revision ID: 041 +Revises: 040 +Create Date: 2025-12-24 11:03:44.620424 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = "041" +down_revision = "040" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection", + sa.Column( + "collection_blob", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Provider-specific knowledge base creation parameters (name, description, chunking params etc.)", + ), + ) + op.alter_column( + "collection", + "llm_service_name", + existing_type=sa.VARCHAR(), + comment="Name of the LLM service", + existing_comment="Name of the LLM service provider", + existing_nullable=False, + ) + + +def downgrade(): + op.alter_column( + "collection", + "llm_service_name", + existing_type=sa.VARCHAR(), + comment="Name of the LLM service provider", + existing_comment="Name of the LLM service", + existing_nullable=False, + ) + op.drop_column("collection", "collection_blob") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 9a351825..a23a05ae 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -8,9 +8,13 @@ from .collection import ( Collection, + CreateCollectionParams, + CreateCollectionResult, + CreationRequest, CollectionPublic, CollectionIDPublic, CollectionWithDocsPublic, + DeletionRequest, ) from .collection_job import ( CollectionActionType, diff --git a/backend/app/models/collection/__init__.py b/backend/app/models/collection/__init__.py new file mode 100644 index 00000000..e31f65bc --- /dev/null +++ b/backend/app/models/collection/__init__.py @@ -0,0 +1,14 @@ +from 
app.models.collection.request import ( + Collection, + CreationRequest, + DeletionRequest, + CallbackRequest, + AssistantOptions, + CreateCollectionParams, +) +from app.models.collection.response import ( + CollectionIDPublic, + CollectionPublic, + CollectionWithDocsPublic, + CreateCollectionResult, +) diff --git a/backend/app/models/collection.py b/backend/app/models/collection/request.py similarity index 63% rename from backend/app/models/collection.py rename to backend/app/models/collection/request.py index 57e5a17b..9f8e106b 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection/request.py @@ -3,13 +3,13 @@ from uuid import UUID, uuid4 from pydantic import HttpUrl, model_validator +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field, Relationship, SQLModel from app.core.util import now -from app.models.document import DocumentPublic - -from .organization import Organization -from .project import Project +from app.models.organization import Organization +from app.models.project import Project class Collection(SQLModel, table=True): @@ -30,8 +30,13 @@ class Collection(SQLModel, table=True): nullable=False, sa_column_kwargs={"comment": "Name of the LLM service"}, ) - - # Foreign keys + collection_blob: dict[str, Any] | None = Field( + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Provider-specific collection parameters (name, description, chunking params etc.)", + ) + ) organization_id: int = Field( foreign_key="organization.id", nullable=False, @@ -44,8 +49,6 @@ class Collection(SQLModel, table=True): ondelete="CASCADE", sa_column_kwargs={"comment": "Reference to the project"}, ) - - # Timestamps inserted_at: datetime = Field( default_factory=now, sa_column_kwargs={"comment": "Timestamp when the collection was created"}, @@ -64,27 +67,55 @@ class Collection(SQLModel, table=True): project: Project = Relationship(back_populates="collections") -# Request models -class DocumentOptions(SQLModel): - documents: list[UUID] = Field( - description="List of document IDs", +class DocumentInput(SQLModel): + """Document to be added to knowledge base.""" + + name: str | None = Field( + description="Display name for the document", ) - batch_size: int = Field( - default=1, - description=( - "Number of documents to send to OpenAI in a single " - "transaction. See the `file_ids` parameter in the " - "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)." 
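(Reviewer aside: the request-shape change in this rename is easiest to see side by side. A hedged sketch, with illustrative values only and field names as defined by this patch series in its final PATCH 3/3 state:)

```python
# Old flat request body: DocumentOptions and AssistantOptions merged at top level.
old_payload = {
    "documents": ["3fa85f64-5717-4562-b3fc-2c963f66afa6"],
    "batch_size": 1,
    "model": "gpt-4o",
}

# New nested request body: collection-specific fields move under
# collection_params; provider and batch_size sit alongside assistant options.
new_payload = {
    "provider": "openai",
    "batch_size": 10,
    "model": "gpt-4o",
    "instructions": "Answer using the attached knowledge base.",
    "collection_params": {
        "name": "support-kb",
        "description": "Tier-1 support documents",
        "documents": [
            {"id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", "name": "faq.pdf"}
        ],
        "chunking_params": {"chunk_size": 800, "chunk_overlap": 80},
    },
}
```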
- ), + id: UUID = Field( + description="Reference to uploaded file/document in Kaapi", + ) + + +class CreateCollectionParams(SQLModel): + """Request-specific parameters for knowledge base creation.""" + + name: str | None = Field( + min_length=1, + description="Name of the knowledge base to create or update", + ) + description: str | None = Field( + default=None, + description="Description of the knowledge base (required by Bedrock, optional for others)", + ) + documents: list[DocumentInput] = Field( + default_factory=list, + description="List of documents to add to the knowledge base", + ) + chunking_params: dict[str, Any] | None = Field( + default=None, + description="Chunking parameters for document processing (e.g., chunk_size, chunk_overlap)", + ) + additional_params: dict[str, Any] | None = Field( + default=None, + description="Additional provider-specific parameters", ) def model_post_init(self, __context: Any): - self.documents = list(set(self.documents)) + """Deduplicate documents by file_id.""" + seen = set() + unique_docs = [] + for doc in self.documents: + if doc.file_id not in seen: + seen.add(doc.file_id) + unique_docs.append(doc) + self.documents = unique_docs class AssistantOptions(SQLModel): # Fields to be passed along to OpenAI. They must be a subset of - # parameters accepted by the OpenAI.clien.beta.assistants.create + # parameters accepted by the OpenAI.client.beta.assistants.create # API. model: str | None = Field( default=None, @@ -139,6 +170,8 @@ def norm(x: Any) -> Any: class CallbackRequest(SQLModel): + """Optional callback configuration for async job notifications.""" + callback_url: HttpUrl | None = Field( default=None, description="URL to call to report endpoint status", @@ -153,40 +186,23 @@ class ProviderOptions(SQLModel): ) -class CreationRequest( - DocumentOptions, - ProviderOptions, - AssistantOptions, - CallbackRequest, -): - def extract_super_type(self, cls: "CreationRequest"): - for field_name in cls.model_fields.keys(): - field_value = getattr(self, field_name) - yield (field_name, field_value) - - -class DeletionRequest(CallbackRequest): - collection_id: UUID = Field(description="Collection to delete") - - -# Response models - - -class CollectionIDPublic(SQLModel): - id: UUID +class CreationRequest(AssistantOptions, ProviderOptions, CallbackRequest): + """API request for collection creation""" + collection_params: CreateCollectionParams = Field( + ..., + description="Collection creation specific parameters (name, documents, etc.)", + ) + batch_size: int = Field( + default=10, + ge=1, + le=500, + description="Number of documents to process in a single batch", + ) -class CollectionPublic(SQLModel): - id: UUID - llm_service_id: str - llm_service_name: str - project_id: int - organization_id: int - inserted_at: datetime - updated_at: datetime - deleted_at: datetime | None = None +class DeletionRequest(ProviderOptions, CallbackRequest): + """API request for collection deletion""" -class CollectionWithDocsPublic(CollectionPublic): - documents: list[DocumentPublic] | None = None + collection_id: UUID = Field(description="Collection to delete") diff --git a/backend/app/models/collection/response.py b/backend/app/models/collection/response.py new file mode 100644 index 00000000..f72c5ee7 --- /dev/null +++ b/backend/app/models/collection/response.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import Any +from uuid import UUID + +from sqlmodel import SQLModel + +from app.models.document import DocumentPublic + + +class 
CreateCollectionResult(SQLModel): + llm_service_id: str + llm_service_name: str + collection_blob: dict[str, Any] + + +class CollectionIDPublic(SQLModel): + id: UUID + + +class CollectionPublic(SQLModel): + id: UUID + llm_service_id: str + llm_service_name: str + project_id: int + organization_id: int + + inserted_at: datetime + updated_at: datetime + deleted_at: datetime | None = None + + +class CollectionWithDocsPublic(CollectionPublic): + documents: list[DocumentPublic] | None = None diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index ed83e4a8..1086dc71 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -6,7 +6,6 @@ from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage -from app.core.util import now from app.core.db import engine from app.crud import ( CollectionCrud, @@ -14,7 +13,6 @@ DocumentCollectionCrud, CollectionJobCrud, ) -from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud from app.models import ( CollectionJobStatus, CollectionJob, @@ -23,18 +21,11 @@ CollectionPublic, CollectionJobPublic, ) -from app.models.collection import ( - CreationRequest, - AssistantOptions, -) -from app.services.collections.helpers import ( - _backout, - batch_documents, - extract_error_message, - OPENAI_VECTOR_STORE, -) +from app.models.collection import CreationRequest +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -116,26 +107,6 @@ def build_failure_payload(collection_job: CollectionJob, error_message: str) -> ) -def _cleanup_remote_resources( - assistant, - assistant_crud, - vector_store, - vector_store_crud, -) -> None: - """Best-effort cleanup of partially created remote resources.""" - try: - if assistant is not None and assistant_crud is not None: - _backout(assistant_crud, assistant.id) - elif vector_store is not None and vector_store_crud is not None: - _backout(vector_store_crud, vector_store.id) - else: - logger.warning( - "[create_collection._backout] Skipping: no resource/crud available" - ) - except Exception: - logger.warning("[create_collection.execute_job] Backout failed") - - def _mark_job_failed( project_id: int, job_id: str, @@ -172,17 +143,15 @@ def execute_job( ) -> None: """ Worker entrypoint scheduled by start_job. - Orchestrates: job state, client/storage init, batching, vector-store upload, + Orchestrates: job state, provider init, collection creation, optional assistant creation, collection persistence, linking, callbacks, and cleanup. 
""" start_time = time.time() - # Keep references for potential backout/cleanup on failure - assistant = None - assistant_crud = None - vector_store = None - vector_store_crud = None + # Keeping the references for potential backout/cleanup on failure collection_job = None + result = None + provider = None try: creation_request = CreationRequest(**request) @@ -199,49 +168,32 @@ def execute_job( ), ) - client = get_openai_client(session, organization_id, project_id) storage = get_cloud_storage(session=session, project_id=project_id) - - # Batch documents for upload, and flatten for linking/metrics later document_crud = DocumentCrud(session, project_id) - docs_batches = batch_documents( - document_crud, - creation_request.documents, - creation_request.batch_size, + + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, ) - flat_docs = [doc for batch in docs_batches for doc in batch] - vector_store_crud = OpenAIVectorStoreCrud(client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + result = provider.create( + collection_request=creation_request, + storage=storage, + document_crud=document_crud, + ) - # if with_assistant is true, create assistant backed by the vector store - if with_assistant: - assistant_crud = OpenAIAssistantCrud(client) + llm_service_id = result.llm_service_id + llm_service_name = result.llm_service_name + # Storing collection params (name, description, chunking_params, etc.) in DB + # for future reference and to support different providers with varying configurations + collection_blob = result.collection_blob - # Filter out None to avoid sending unset options - assistant_options = dict( - creation_request.extract_super_type(AssistantOptions) - ) - assistant_options = { - k: v for k, v in assistant_options.items() if v is not None - } - - assistant = assistant_crud.create(vector_store.id, **assistant_options) - llm_service_id = assistant.id - llm_service_name = assistant_options.get("model") or "assistant" - - logger.info( - "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s", - assistant.id, - vector_store.id, - ) - else: - # If no assistant, the collection points directly at the vector store - llm_service_id = vector_store.id - llm_service_name = OPENAI_VECTOR_STORE - logger.info( - "[execute_job] Skipping assistant creation | with_assistant=False" + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + flat_docs = document_crud.read_many_by_ids( + [doc.id for doc in creation_request.collection_params.documents] ) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} @@ -259,6 +211,7 @@ def execute_job( organization_id=organization_id, llm_service_id=llm_service_id, llm_service_name=llm_service_name, + collection_blob=collection_blob, ) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) @@ -299,12 +252,13 @@ def execute_job( exc_info=True, ) - _cleanup_remote_resources( - assistant=assistant, - assistant_crud=assistant_crud, - vector_store=vector_store, - vector_store_crud=vector_store_crud, - ) + if provider is not None and result is not None: + try: + provider.cleanup(result) + except Exception: + logger.warning( + "[create_collection.execute_job] Provider cleanup failed" + ) collection_job = _mark_job_failed( project_id=project_id, diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index ca337b79..a49fa4b1 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -6,7 +6,6 @@ from app.core.db import engine from app.crud import CollectionCrud, CollectionJobCrud -from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud from app.models import ( CollectionJobStatus, CollectionJobUpdate, @@ -15,9 +14,10 @@ CollectionIDPublic, ) from app.models.collection import DeletionRequest -from app.services.collections.helpers import extract_error_message, OPENAI_VECTOR_STORE +from app.services.collections.helpers import extract_error_message +from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client, send_callback, APIResponse +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) @@ -155,7 +155,6 @@ def execute_job( job_uuid = UUID(job_id) collection_job = None - client = None try: with Session(engine) as session: @@ -169,20 +168,19 @@ def execute_job( ), ) - client = get_openai_client(session, organization_id, project_id) - collection = CollectionCrud(session, project_id).read_one(collection_id) - # Identify which external service (assistant/vector store) this collection belongs to - service = (collection.llm_service_name or "").strip().lower() - is_vector = service == OPENAI_VECTOR_STORE - llm_service_id = collection.llm_service_id + provider = get_llm_provider( + session=session, + provider=deletion_request.provider, + project_id=project_id, + organization_id=organization_id, + ) - # Delete the corresponding OpenAI resource (vector store or assistant) - if is_vector: - OpenAIVectorStoreCrud(client).delete(llm_service_id) - else: - OpenAIAssistantCrud(client).delete(llm_service_id) + provider.delete( + llm_service_id=collection.llm_service_id, + llm_service_name=collection.llm_service_name, + ) with Session(engine) as session: CollectionCrud(session, project_id).delete_by_id(collection_id) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 795b04cd..6995e081 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -69,17 +69,6 @@ def batch_documents( return docs_batches -def _backout(crud, llm_service_id: str): - """Best-effort cleanup: attempt to delete the assistant by ID""" - try: - crud.delete(llm_service_id) - except OpenAIError as err: - logger.error( - f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}", - exc_info=True, - ) - - # Even though this 
function is used in the documents router, it's kept here for now since the assistant creation logic will
 # eventually be removed from Kaapi. Once that happens, this function can be safely deleted
-
 def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud):
diff --git a/backend/app/services/collections/providers/__init__.py b/backend/app/services/collections/providers/__init__.py
new file mode 100644
index 00000000..5a9b6a55
--- /dev/null
+++ b/backend/app/services/collections/providers/__init__.py
@@ -0,0 +1,6 @@
+from app.services.collections.providers.base import BaseProvider
+from app.services.collections.providers.openai import OpenAIProvider
+from app.services.collections.providers.registry import (
+    LLMProvider,
+    get_llm_provider,
+)
diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py
new file mode 100644
index 00000000..9fb21f3e
--- /dev/null
+++ b/backend/app/services/collections/providers/base.py
@@ -0,0 +1,84 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.models import CreationRequest, CreateCollectionResult, Collection
+
+
+class BaseProvider(ABC):
+    """Abstract base class for collection providers.
+
+    All provider implementations (OpenAI, Bedrock, etc.) must inherit from
+    this class and implement the required methods.
+
+    Providers handle creation of knowledge bases (vector stores) and
+    optional assistant/agent creation backed by those knowledge bases.
+
+    Attributes:
+        client: The provider-specific client instance
+    """
+
+    def __init__(self, client: Any):
+        """Initialize provider with client.
+
+        Args:
+            client: Provider-specific client instance
+        """
+        self.client = client
+
+    @abstractmethod
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> CreateCollectionResult:
+        """Create collection with documents and optionally an assistant.
+
+        Args:
+            collection_request: Creation request carrying the collection
+                params (name, description, chunking_params), the
+                batch size, and any assistant options for providers
+                that support assistant creation
+            storage: Cloud storage instance for file access
+            document_crud: DocumentCrud instance for fetching documents
+
+        Returns:
+            CreateCollectionResult containing:
+            - llm_service_id: ID of the created resource (vector store or assistant)
+            - llm_service_name: Name of the service
+            - collection_blob: All collection params except documents
+        """
+        raise NotImplementedError("Providers must implement create method")
+
+    @abstractmethod
+    def delete(self, collection: Collection) -> None:
+        """Delete remote resources associated with a collection.
+
+        Called when a collection is being deleted and remote resources need to be cleaned up.
+
+        Args:
+            collection: Collection whose remote resources (vector store
+                or assistant) should be deleted
+        """
+        raise NotImplementedError("Providers must implement delete method")
+
+    @abstractmethod
+    def cleanup(self, collection_result: CreateCollectionResult) -> None:
+        """Clean up/rollback resources created during create.
+
+        Called when collection creation fails and remote resources need to be deleted.
+
+        Args:
+            collection_result: The CreateCollectionResult returned from create, containing resource IDs
+        """
+        raise NotImplementedError("Providers must implement cleanup method")
+
+    def get_provider_name(self) -> str:
+        """Get the name of the provider.
+
+        Returns:
+            Provider name (e.g., "openai", "bedrock", "pinecone")
+        """
+        return self.__class__.__name__.replace("Provider", "").lower()
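(To make the contract above concrete, here is a minimal hypothetical subclass; it is not part of this patch, and `EchoProvider` with its return values is invented purely for illustration:)

```python
from app.models import Collection, CreateCollectionResult, CreationRequest
from app.services.collections.providers.base import BaseProvider


class EchoProvider(BaseProvider):
    """Toy provider that pretends to build a remote knowledge base."""

    def create(
        self, collection_request: CreationRequest, storage, document_crud
    ) -> CreateCollectionResult:
        params = collection_request.collection_params
        # A real provider would stream documents from `storage` and create
        # a remote knowledge base here.
        return CreateCollectionResult(
            llm_service_id="echo-store-1",
            llm_service_name="echo store",
            collection_blob={"name": params.name, "description": params.description},
        )

    def delete(self, collection: Collection) -> None:
        pass  # no remote resources to remove in this toy example

    def cleanup(self, collection_result: CreateCollectionResult) -> None:
        pass  # best-effort rollback of partially created resources would go here
```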
diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py
new file mode 100644
index 00000000..36c76af1
--- /dev/null
+++ b/backend/app/services/collections/providers/openai.py
@@ -0,0 +1,156 @@
+import logging
+from typing import Any
+
+from openai import OpenAI
+
+from app.services.collections.providers import BaseProvider
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
+from app.services.collections.helpers import batch_documents, OPENAI_VECTOR_STORE
+from app.models import CreateCollectionResult, CreationRequest, Collection
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(BaseProvider):
+    """OpenAI-specific collection provider for vector stores and assistants."""
+
+    def __init__(self, client: OpenAI):
+        super().__init__(client)
+        self.client = client
+
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> CreateCollectionResult:
+        """Create OpenAI vector store with documents and optionally an assistant.
+
+        Args:
+            collection_request: Creation request carrying the collection
+                params (name, description, chunking_params), the batch
+                size, and any assistant options (model, instructions,
+                temperature)
+            storage: Cloud storage instance for file access
+            document_crud: DocumentCrud instance for fetching documents
+
+        Returns:
+            CreateCollectionResult containing llm_service_id, llm_service_name, and collection_blob
+        """
+        try:
+            collection_params = collection_request.collection_params
+            document_ids = [doc.id for doc in collection_params.documents]
+
+            docs_batches = batch_documents(
+                document_crud,
+                document_ids,
+                collection_request.batch_size,
+            )
+
+            vector_store_crud = OpenAIVectorStoreCrud(self.client)
+            vector_store = vector_store_crud.create()
+
+            list(vector_store_crud.update(vector_store.id, storage, docs_batches))
+
+            logger.info(
+                "[OpenAIProvider.create] Vector store created | "
+                f"vector_store_id={vector_store.id}, batches={len(docs_batches)}"
+            )
+
+            collection_blob = {
+                "name": collection_params.name,
+                "description": collection_params.description,
+                "chunking_params": collection_params.chunking_params,
+                "additional_params": collection_params.additional_params,
+            }
+
+            # Check if we need to create an assistant (based on assistant options in request)
+            with_assistant = (
+                collection_request.model is not None
+                and collection_request.instructions is not None
+            )
+            if with_assistant:
+                assistant_crud = OpenAIAssistantCrud(self.client)
+
+                assistant_options = {
+                    "model": collection_request.model,
+                    "instructions": collection_request.instructions,
+                    "temperature": collection_request.temperature,
+                }
+                filtered_options = {
+                    k: v for k, v in assistant_options.items() if v is not None
+                }
+
+                assistant = assistant_crud.create(vector_store.id, **filtered_options)
+
+                logger.info(
+                    "[OpenAIProvider.create] Assistant created | "
+                    f"assistant_id={assistant.id}, vector_store_id={vector_store.id}"
+                )
+
+                return CreateCollectionResult(
+                    llm_service_id=assistant.id,
+                    llm_service_name=filtered_options.get("model", "assistant"),
+                    collection_blob=collection_blob,
+                )
+            else:
+                logger.info(
+                    "[OpenAIProvider.create] Skipping assistant creation | with_assistant=False"
+                )
+
+                return CreateCollectionResult(
+                    llm_service_id=vector_store.id,
+                    llm_service_name=OPENAI_VECTOR_STORE,
+                    collection_blob=collection_blob,
+                )
+
+        except Exception as e:
+            logger.error(
+                f"[OpenAIProvider.create] Failed to create knowledge base: {str(e)}",
+                exc_info=True,
+            )
+            raise
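(For context, the Celery worker in create_collection.py drives this method roughly as follows. This is a condensed, hedged sketch of the execute_job happy path as of the final patch in this series; job bookkeeping, callbacks, and error handling are omitted, and `run_creation` is an invented name:)

```python
from sqlmodel import Session

from app.models import Collection, CreationRequest
from app.services.collections.providers.registry import get_llm_provider


def run_creation(
    session: Session,
    creation_request: CreationRequest,
    project_id: int,
    organization_id: int,
    storage,
    document_crud,
) -> Collection:
    provider = get_llm_provider(
        session=session,
        provider=creation_request.provider,
        project_id=project_id,
        organization_id=organization_id,
    )
    result = provider.create(
        collection_request=creation_request,
        storage=storage,
        document_crud=document_crud,
    )
    # The result fields map one-to-one onto the new Collection columns.
    return Collection(
        project_id=project_id,
        organization_id=organization_id,
        provider=creation_request.provider,
        llm_service_id=result.llm_service_id,
        llm_service_name=result.llm_service_name,
        collection_blob=result.collection_blob,
    )
```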
+
+    def delete(self, collection: Collection) -> None:
+        """Delete OpenAI resources (assistant or vector store).
+
+        Determines what to delete based on llm_service_name:
+        - If assistant was created, delete the assistant (which also removes the vector store)
+        - If only vector store was created, delete the vector store
+
+        Args:
+            collection: Collection that has been requested to be deleted
+        """
+        try:
+            if collection.llm_service_name != OPENAI_VECTOR_STORE:
+                OpenAIAssistantCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}"
+                )
+            else:
+                OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}"
+                )
+        except Exception as e:
+            logger.error(
+                f"[OpenAIProvider.delete] Failed to delete resource | "
+                f"llm_service_id={collection.llm_service_id}, error={str(e)}",
+                exc_info=True,
+            )
+            raise
+
+    def cleanup(self, result: CreateCollectionResult) -> None:
+        """Clean up OpenAI resources (assistant or vector store).
+
+        Determines what to delete based on llm_service_name:
+        - If assistant was created, delete the assistant (which also removes the vector store)
+        - If only vector store was created, delete the vector store
+
+        Args:
+            result: The CreateCollectionResult from create, containing resource IDs
+        """
+        self.delete(result.llm_service_id, result.llm_service_name)
diff --git a/backend/app/services/collections/providers/registry.py b/backend/app/services/collections/providers/registry.py
new file mode 100644
index 00000000..10d07d45
--- /dev/null
+++ b/backend/app/services/collections/providers/registry.py
@@ -0,0 +1,71 @@
+import logging
+
+from sqlmodel import Session
+from openai import OpenAI
+
+from app.crud import get_provider_credential
+from app.services.collections.providers.base import BaseProvider
+from app.services.collections.providers.openai import OpenAIProvider
+
+
+logger = logging.getLogger(__name__)
+
+
+class LLMProvider:
+    OPENAI = "openai"
+    # Future constants for providers:
+    # BEDROCK = "bedrock"
+    # GEMINI = "gemini"
+
+    _registry: dict[str, type[BaseProvider]] = {
+        OPENAI: OpenAIProvider,
+        # Future providers:
+        # BEDROCK: BedrockProvider,
+        # GEMINI: GeminiProvider,
+    }
+
+    @classmethod
+    def get(cls, name: str) -> type[BaseProvider]:
+        """Return the provider class for a given name."""
+        provider = cls._registry.get(name)
+        if not provider:
+            raise ValueError(
+                f"Provider '{name}' is not supported. "
+                f"Supported providers: {', '.join(cls._registry.keys())}"
+            )
+        return provider
+
+    @classmethod
+    def supported_providers(cls) -> list[str]:
+        """Return a list of supported provider names."""
+        return list(cls._registry.keys())
+
+
+def get_llm_provider(
+    session: Session, provider: str, project_id: int, organization_id: int
+) -> BaseProvider:
+    provider_class = LLMProvider.get(provider)
+
+    credentials = get_provider_credential(
+        session=session,
+        provider=provider,
+        project_id=project_id,
+        org_id=organization_id,
+    )
+
+    if not credentials:
+        raise ValueError(
+            f"Credentials for provider '{provider}' not configured for this project."
+ ) + + if provider == LLMProvider.OPENAI: + if "api_key" not in credentials: + raise ValueError("OpenAI credentials not configured for this project.") + client = OpenAI(api_key=credentials["api_key"]) + else: + logger.error( + f"[get_llm_provider] Unsupported provider type requested: {provider}" + ) + raise ValueError(f"Provider '{provider}' is not supported.") + + return provider_class(client=client) From 7648f48e8af94d316a5680200626a33acfeff774 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 24 Dec 2025 16:57:59 +0530 Subject: [PATCH 2/3] fixing a delete mistake --- backend/app/services/collections/delete_collection.py | 5 +---- backend/app/services/collections/helpers.py | 11 +++++++++++ backend/app/services/collections/providers/openai.py | 8 ++++++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index a49fa4b1..e9570964 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -177,10 +177,7 @@ def execute_job( organization_id=organization_id, ) - provider.delete( - llm_service_id=collection.llm_service_id, - llm_service_name=collection.llm_service_name, - ) + provider.delete(collection) with Session(engine) as session: CollectionCrud(session, project_id).delete_by_id(collection_id) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 6995e081..795b04cd 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -69,6 +69,17 @@ def batch_documents( return docs_batches +def _backout(crud, llm_service_id: str): + """Best-effort cleanup: attempt to delete the assistant by ID""" + try: + crud.delete(llm_service_id) + except OpenAIError as err: + logger.error( + f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}", + exc_info=True, + ) + + # Even though this function is used in the documents router, it's kept here for now since the assistant creation logic will # eventually be removed from Kaapi. 
Once that happens, this function can be safely deleted
-
 def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud):
diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py
index 36c76af1..ba734d85 100644
--- a/backend/app/services/collections/providers/openai.py
+++ b/backend/app/services/collections/providers/openai.py
@@ -7,7 +7,11 @@
 from app.crud import DocumentCrud
 from app.core.cloud.storage import CloudStorage
 from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
-from app.services.collections.helpers import batch_documents, OPENAI_VECTOR_STORE
+from app.services.collections.helpers import (
+    batch_documents,
+    OPENAI_VECTOR_STORE,
+    _backout,
+)
 from app.models import CreateCollectionResult, CreationRequest, Collection
@@ -153,4 +157,9 @@ def cleanup(self, result: CreateCollectionResult) -> None:
         Args:
             result: The CreateCollectionResult from create, containing resource IDs
         """
-        self.delete(result.llm_service_id, result.llm_service_name)
+        # Wrap the result in a transient Collection so delete() can route it
+        collection = Collection(
+            llm_service_id=result.llm_service_id,
+            llm_service_name=result.llm_service_name,
+        )
+        self.delete(collection)
From 946e7c7bcb28c5fa932146071bac6378f6e82e69 Mon Sep 17 00:00:00 2001
From: nishika26
Date: Fri, 26 Dec 2025 09:53:27 +0530
Subject: [PATCH 3/3] tested and fixed

---
 ..._adding_blob_column_in_collection_table.py | 33 ++++++++++++-
 backend/app/models/__init__.py                |  1 +
 backend/app/models/collection/__init__.py     |  1 +
 backend/app/models/collection/request.py      | 47 ++++++++++++++---
 .../services/collections/create_collection.py |  9 ++--
 backend/app/services/collections/helpers.py   | 13 +++--
 .../services/collections/providers/openai.py  |  6 +--
 .../collections/test_collection_info.py       |  3 +-
 .../collections/test_collection_list.py       |  3 +-
 backend/app/tests/utils/collection.py         |  6 ++-
 10 files changed, 100 insertions(+), 22 deletions(-)

diff --git a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py
index 8f65f055..e9ebdd11 100644
--- a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py
+++ b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py
@@ -14,23 +14,51 @@
 branch_labels = None
 depends_on = None
 
+provider_enum = postgresql.ENUM(
+    "openai",
+    name="providertype",
+    create_type=True,
+)
+
 
 def upgrade():
+    provider_enum.create(op.get_bind(), checkfirst=True)
+
     op.add_column(
         "collection",
         sa.Column(
             "collection_blob",
             postgresql.JSONB(astext_type=sa.Text()),
             nullable=True,
-            comment="Provider-specific knowledge base creation parameters (name, description, chunking params etc.)",
+            comment="Provider-specific collection parameters (name, description, chunking params etc.)",
        ),
    )
+
+    op.add_column(
+        "collection",
+        sa.Column(
+            "provider",
+            provider_enum,
+            nullable=True,
+            comment="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini')",
+        ),
+    )
+
+    op.execute("UPDATE collection SET provider = 'openai' WHERE provider IS NULL")
+
+    op.alter_column(
+        "collection",
+        "provider",
+        nullable=False,
+        existing_type=provider_enum,
+    )
+
     op.alter_column(
         "collection",
         "llm_service_name",
         existing_type=sa.VARCHAR(),
         comment="Name of the LLM service",
-        existing_comment="Name of the LLM service provider",
+        existing_comment="Name of the LLM provider's service",
         existing_nullable=False,
     )
 
@@ -44,4 +72,5 @@ def downgrade():
         existing_comment="Name of the LLM service",
         existing_nullable=False,
     )
+    op.drop_column("collection", "provider")
op.drop_column("collection", "collection_blob") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index ef08fd09..95f4948c 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -15,6 +15,7 @@ CollectionIDPublic, CollectionWithDocsPublic, DeletionRequest, + ProviderType, ) from .collection_job import ( CollectionActionType, diff --git a/backend/app/models/collection/__init__.py b/backend/app/models/collection/__init__.py index e31f65bc..a83620b1 100644 --- a/backend/app/models/collection/__init__.py +++ b/backend/app/models/collection/__init__.py @@ -5,6 +5,7 @@ CallbackRequest, AssistantOptions, CreateCollectionParams, + ProviderType, ) from app.models.collection.response import ( CollectionIDPublic, diff --git a/backend/app/models/collection/request.py b/backend/app/models/collection/request.py index 9f8e106b..0d38f837 100644 --- a/backend/app/models/collection/request.py +++ b/backend/app/models/collection/request.py @@ -1,10 +1,11 @@ from datetime import datetime +from enum import Enum from typing import Any, Literal from uuid import UUID, uuid4 from pydantic import HttpUrl, model_validator import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import JSONB, ENUM from sqlmodel import Field, Relationship, SQLModel from app.core.util import now @@ -12,6 +13,16 @@ from app.models.project import Project +class ProviderType(str, Enum): + """Supported LLM providers for collections.""" + + OPENAI = "openai" + + +# BEDROCK = "bedrock" +# GEMINI = "gemini" + + class Collection(SQLModel, table=True): """Database model for Collection operations.""" @@ -20,6 +31,20 @@ class Collection(SQLModel, table=True): primary_key=True, sa_column_kwargs={"comment": "Unique identifier for the collection"}, ) + + provider: ProviderType = Field( + sa_column=sa.Column( + ENUM( + "openai", + # "bedrock", + # "gemini", + name="providertype", + create_type=False, + ), + nullable=False, + comment="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini', etc)", + ), + ) llm_service_id: str = Field( nullable=False, sa_column_kwargs={ @@ -103,12 +128,12 @@ class CreateCollectionParams(SQLModel): ) def model_post_init(self, __context: Any): - """Deduplicate documents by file_id.""" + """Deduplicate documents by document id.""" seen = set() unique_docs = [] for doc in self.documents: - if doc.file_id not in seen: - seen.add(doc.file_id) + if doc.id not in seen: + seen.add(doc.id) unique_docs.append(doc) self.documents = unique_docs @@ -181,10 +206,20 @@ class CallbackRequest(SQLModel): class ProviderOptions(SQLModel): """LLM provider configuration.""" - provider: Literal["openai"] = Field( - default="openai", description="LLM provider to use for this collection" + provider: ProviderType = Field( + default=ProviderType.OPENAI, + description="LLM provider to use for this collection", ) + @model_validator(mode="before") + def normalize_provider(cls, values: dict[str, Any]) -> dict[str, Any]: + """Normalize provider value to lowercase for case-insensitive matching.""" + if isinstance(values, dict) and "provider" in values: + provider = values["provider"] + if isinstance(provider, str): + values["provider"] = provider.lower() + return values + class CreationRequest(AssistantOptions, ProviderOptions, CallbackRequest): """API request for collection creation""" diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 
1086dc71..088ebe0e 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -20,8 +20,9 @@ CollectionJobUpdate, CollectionPublic, CollectionJobPublic, + CreationRequest, + ProviderType, ) -from app.models.collection import CreationRequest from app.services.collections.helpers import extract_error_message from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_low_priority_job @@ -51,7 +52,6 @@ def start_job( project_id=project_id, job_id=str(collection_job_id), trace_id=trace_id, - with_assistant=with_assistant, request=request.model_dump(mode="json"), organization_id=organization_id, ) @@ -138,7 +138,6 @@ def execute_job( organization_id: int, task_id: str, job_id: str, - with_assistant: bool, task_instance, ) -> None: """ @@ -192,7 +191,7 @@ def execute_job( with Session(engine) as session: document_crud = DocumentCrud(session, project_id) - flat_docs = document_crud.read_many_by_ids( + flat_docs = document_crud.read_each( [doc.id for doc in creation_request.collection_params.documents] ) @@ -212,11 +211,11 @@ def execute_job( llm_service_id=llm_service_id, llm_service_name=llm_service_name, collection_blob=collection_blob, + provider=creation_request.provider, ) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) - # Link documents to the new collection if flat_docs: DocumentCollectionCrud(session).create(collection, flat_docs) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 795b04cd..34eaad42 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -14,8 +14,15 @@ logger = logging.getLogger(__name__) -# llm service name for when only an openai vector store is being made -OPENAI_VECTOR_STORE = "openai vector store" + +def get_service_name(provider: str) -> str: + """Get the collection service name for a provider.""" + names = { + "openai": "openai vector store", + # "bedrock": "bedrock knowledge base", + # "gemini": "gemini file search store", + } + return names.get(provider.lower(), "") def extract_error_message(err: Exception) -> str: @@ -101,4 +108,4 @@ def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): service = ( (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else "" ) - return v_crud if service == OPENAI_VECTOR_STORE else a_crud + return v_crud if service == get_service_name("openai") else a_crud diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index ba734d85..998d8fb3 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -9,7 +9,7 @@ from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud from app.services.collections.helpers import ( batch_documents, - OPENAI_VECTOR_STORE, + get_service_name, _backout, ) from app.models import CreateCollectionResult, CreationRequest, Collection @@ -107,7 +107,7 @@ def create( return CreateCollectionResult( llm_service_id=vector_store.id, - llm_service_name=OPENAI_VECTOR_STORE, + llm_service_name=get_service_name("openai"), collection_blob=collection_blob, ) @@ -129,7 +129,7 @@ def delete(self, collection: Collection) -> None: collection: Collection that has been requested to be deleted """ try: - if collection.llm_service_name != OPENAI_VECTOR_STORE: + if 
collection.llm_service_name != get_service_name("openai"): OpenAIAssistantCrud(self.client).delete(collection.llm_service_id) logger.info( f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}" diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 90f8b80c..877f9a19 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -9,6 +9,7 @@ from app.tests.utils.collection import get_collection, get_vector_store_collection from app.crud import DocumentCollectionCrud from app.models import Collection, Document +from app.services.collections.helpers import get_service_name def link_document_to_collection( @@ -163,7 +164,7 @@ def test_collection_info_vector_store_collection( payload = data["data"] assert payload["id"] == str(collection.id) - assert payload["llm_service_name"] == "openai vector store" + assert payload["llm_service_name"] == get_service_name("openai") assert payload["llm_service_id"] == collection.llm_service_id docs = payload.get("documents", []) diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py index f7507c12..6d0e512e 100644 --- a/backend/app/tests/api/routes/collections/test_collection_list.py +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -7,6 +7,7 @@ get_collection, get_vector_store_collection, ) +from app.services.collections.helpers import get_service_name def test_list_collections_returns_api_response( @@ -101,7 +102,7 @@ def test_list_collections_includes_vector_store_collection_with_fields( row = matching[0] assert row["project_id"] == project.id - assert row["llm_service_name"] == "openai vector store" + assert row["llm_service_name"] == get_service_name("openai") assert row["llm_service_id"] == collection.llm_service_id diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index 429bfc8b..cfacf0c4 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -8,8 +8,10 @@ CollectionActionType, CollectionJob, CollectionJobStatus, + ProviderType, ) from app.crud import CollectionCrud, CollectionJobCrud +from app.services.collections.helpers import get_service_name class constants: @@ -43,6 +45,7 @@ def get_collection( organization_id=project.organization_id, llm_service_name=model, llm_service_id=assistant_id, + provider=ProviderType.OPENAI, ) return CollectionCrud(db, project.id).create(collection) @@ -65,8 +68,9 @@ def get_vector_store_collection( id=collection_id or uuid4(), project_id=project.id, organization_id=project.organization_id, - llm_service_name="openai vector store", + llm_service_name=get_service_name("openai"), llm_service_id=vector_store_id, + provider=ProviderType.OPENAI, ) return CollectionCrud(db, project.id).create(collection)
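(Closing aside: the registry's error path is easy to pin down with a unit test. A hedged sketch, assuming the pytest setup used by the test files above; this test is not part of the patch:)

```python
import pytest

from app.services.collections.providers.registry import LLMProvider


def test_get_unknown_provider_raises():
    # "bedrock" is commented out in the registry, so it must be rejected.
    with pytest.raises(ValueError, match="not supported"):
        LLMProvider.get("bedrock")


def test_supported_providers_lists_openai():
    assert LLMProvider.supported_providers() == ["openai"]
```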