diff --git a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py
new file mode 100644
index 00000000..e9ebdd11
--- /dev/null
+++ b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py
@@ -0,0 +1,76 @@
+"""adding blob column in collection table
+
+Revision ID: 041
+Revises: 040
+Create Date: 2025-12-24 11:03:44.620424
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+revision = "041"
+down_revision = "040"
+branch_labels = None
+depends_on = None
+
+provider_enum = postgresql.ENUM(
+    "openai",
+    name="providertype",
+    create_type=True,
+)
+
+
+def upgrade():
+    provider_enum.create(op.get_bind(), checkfirst=True)
+
+    op.add_column(
+        "collection",
+        sa.Column(
+            "collection_blob",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+            comment="Provider-specific collection parameters (name, description, chunking params etc.)",
+        ),
+    )
+
+    op.add_column(
+        "collection",
+        sa.Column(
+            "provider",
+            provider_enum,
+            nullable=True,
+            comment="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini')",
+        ),
+    )
+
+    op.execute("UPDATE collection SET provider = 'openai' WHERE provider IS NULL")
+
+    op.alter_column(
+        "collection",
+        "provider",
+        nullable=False,
+        existing_type=provider_enum,
+    )
+
+    op.alter_column(
+        "collection",
+        "llm_service_name",
+        existing_type=sa.VARCHAR(),
+        comment="Name of the LLM service",
+        existing_comment="Name of the LLM provider's service",
+        existing_nullable=False,
+    )
+
+
+def downgrade():
+    op.alter_column(
+        "collection",
+        "llm_service_name",
+        existing_type=sa.VARCHAR(),
+        comment="Name of the LLM provider's service",
+        existing_comment="Name of the LLM service",
+        existing_nullable=False,
+    )
+    op.drop_column("collection", "provider")
+    op.drop_column("collection", "collection_blob")
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py
index ac7e89d6..95f4948c 100644
--- a/backend/app/models/__init__.py
+++ b/backend/app/models/__init__.py
@@ -8,9 +8,14 @@
 from .collection import (
     Collection,
+    CreateCollectionParams,
+    CreateCollectionResult,
+    CreationRequest,
     CollectionPublic,
     CollectionIDPublic,
     CollectionWithDocsPublic,
+    DeletionRequest,
+    ProviderType,
 )
 from .collection_job import (
     CollectionActionType,
diff --git a/backend/app/models/collection/__init__.py b/backend/app/models/collection/__init__.py
new file mode 100644
index 00000000..a83620b1
--- /dev/null
+++ b/backend/app/models/collection/__init__.py
@@ -0,0 +1,15 @@
+from app.models.collection.request import (
+    Collection,
+    CreationRequest,
+    DeletionRequest,
+    CallbackRequest,
+    AssistantOptions,
+    CreateCollectionParams,
+    ProviderType,
+)
+from app.models.collection.response import (
+    CollectionIDPublic,
+    CollectionPublic,
+    CollectionWithDocsPublic,
+    CreateCollectionResult,
+)
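The upgrade backfills `provider` before tightening the column to NOT NULL, so pre-existing rows migrate cleanly. A quick post-upgrade sanity check could look like this (sketch; the DSN and engine setup are assumed and not part of this change):

# Sketch: verify the 041 backfill after `alembic upgrade head`.
import sqlalchemy as sa

engine = sa.create_engine("postgresql://localhost/kaapi")  # assumed DSN

with engine.connect() as conn:
    # The backfill should leave no NULL providers behind.
    nulls = conn.execute(
        sa.text("SELECT count(*) FROM collection WHERE provider IS NULL")
    ).scalar_one()
    assert nulls == 0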
diff --git a/backend/app/models/collection.py b/backend/app/models/collection/request.py
similarity index 53%
rename from backend/app/models/collection.py
rename to backend/app/models/collection/request.py
index 57e5a17b..0d38f837 100644
--- a/backend/app/models/collection.py
+++ b/backend/app/models/collection/request.py
@@ -1,15 +1,24 @@
 from datetime import datetime
+from enum import Enum
 from typing import Any, Literal
 from uuid import UUID, uuid4
 
 from pydantic import HttpUrl, model_validator
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import JSONB, ENUM
 from sqlmodel import Field, Relationship, SQLModel
 
 from app.core.util import now
-from app.models.document import DocumentPublic
+from app.models.organization import Organization
+from app.models.project import Project
 
-from .organization import Organization
-from .project import Project
+
+class ProviderType(str, Enum):
+    """Supported LLM providers for collections."""
+
+    OPENAI = "openai"
+    # BEDROCK = "bedrock"
+    # GEMINI = "gemini"
 
 
 class Collection(SQLModel, table=True):
@@ -20,6 +29,20 @@
         primary_key=True,
         sa_column_kwargs={"comment": "Unique identifier for the collection"},
     )
+
+    provider: ProviderType = Field(
+        sa_column=sa.Column(
+            ENUM(
+                "openai",
+                # "bedrock",
+                # "gemini",
+                name="providertype",
+                create_type=False,
+            ),
+            nullable=False,
+            comment="LLM provider used for this collection (e.g., 'openai', 'bedrock', 'gemini')",
+        ),
+    )
     llm_service_id: str = Field(
         nullable=False,
         sa_column_kwargs={
@@ -30,8 +53,13 @@
         nullable=False,
         sa_column_kwargs={"comment": "Name of the LLM service"},
     )
-
-    # Foreign keys
+    collection_blob: dict[str, Any] | None = Field(
+        sa_column=sa.Column(
+            JSONB,
+            nullable=True,
+            comment="Provider-specific collection parameters (name, description, chunking params etc.)",
+        )
+    )
     organization_id: int = Field(
         foreign_key="organization.id",
         nullable=False,
@@ -44,8 +72,6 @@
         ondelete="CASCADE",
         sa_column_kwargs={"comment": "Reference to the project"},
     )
-
-    # Timestamps
    inserted_at: datetime = Field(
         default_factory=now,
         sa_column_kwargs={"comment": "Timestamp when the collection was created"},
@@ -64,27 +90,55 @@
     project: Project = Relationship(back_populates="collections")
 
 
-# Request models
-class DocumentOptions(SQLModel):
-    documents: list[UUID] = Field(
-        description="List of document IDs",
+class DocumentInput(SQLModel):
+    """Document to be added to a knowledge base."""
+
+    name: str | None = Field(
+        description="Display name for the document",
     )
-    batch_size: int = Field(
-        default=1,
-        description=(
-            "Number of documents to send to OpenAI in a single "
-            "transaction. See the `file_ids` parameter in the "
-            "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)."
-        ),
+    id: UUID = Field(
+        description="Reference to an uploaded file/document in Kaapi",
+    )
+
+
+class CreateCollectionParams(SQLModel):
+    """Request-specific parameters for knowledge base creation."""
+
+    name: str | None = Field(
+        min_length=1,
+        description="Name of the knowledge base to create or update",
+    )
+    description: str | None = Field(
+        default=None,
+        description="Description of the knowledge base (required by Bedrock, optional for others)",
+    )
+    documents: list[DocumentInput] = Field(
+        default_factory=list,
+        description="List of documents to add to the knowledge base",
+    )
+    chunking_params: dict[str, Any] | None = Field(
+        default=None,
+        description="Chunking parameters for document processing (e.g., chunk_size, chunk_overlap)",
+    )
+    additional_params: dict[str, Any] | None = Field(
+        default=None,
+        description="Additional provider-specific parameters",
     )
 
     def model_post_init(self, __context: Any):
-        self.documents = list(set(self.documents))
+        """Deduplicate documents by document id."""
+        seen = set()
+        unique_docs = []
+        for doc in self.documents:
+            if doc.id not in seen:
+                seen.add(doc.id)
+                unique_docs.append(doc)
+        self.documents = unique_docs
 
 
 class AssistantOptions(SQLModel):
     # Fields to be passed along to OpenAI. They must be a subset of
-    # parameters accepted by the OpenAI.clien.beta.assistants.create
+    # parameters accepted by the OpenAI.client.beta.assistants.create
     # API.
     model: str | None = Field(
         default=None,
@@ -139,6 +193,8 @@ def norm(x: Any) -> Any:
 
 
 class CallbackRequest(SQLModel):
+    """Optional callback configuration for async job notifications."""
+
     callback_url: HttpUrl | None = Field(
         default=None,
         description="URL to call to report endpoint status",
@@ -148,45 +204,38 @@
 class ProviderOptions(SQLModel):
     """LLM provider configuration."""
 
-    provider: Literal["openai"] = Field(
-        default="openai", description="LLM provider to use for this collection"
+    provider: ProviderType = Field(
+        default=ProviderType.OPENAI,
+        description="LLM provider to use for this collection",
     )
 
-
-class CreationRequest(
-    DocumentOptions,
-    ProviderOptions,
-    AssistantOptions,
-    CallbackRequest,
-):
-    def extract_super_type(self, cls: "CreationRequest"):
-        for field_name in cls.model_fields.keys():
-            field_value = getattr(self, field_name)
-            yield (field_name, field_value)
-
-
-class DeletionRequest(CallbackRequest):
-    collection_id: UUID = Field(description="Collection to delete")
-
-
-# Response models
+    @model_validator(mode="before")
+    def normalize_provider(cls, values: dict[str, Any]) -> dict[str, Any]:
+        """Normalize the provider value to lowercase for case-insensitive matching."""
+        if isinstance(values, dict) and "provider" in values:
+            provider = values["provider"]
+            if isinstance(provider, str):
+                values["provider"] = provider.lower()
+        return values
 
-class CollectionIDPublic(SQLModel):
-    id: UUID
 
+class CreationRequest(AssistantOptions, ProviderOptions, CallbackRequest):
+    """API request for collection creation."""
 
-class CollectionPublic(SQLModel):
-    id: UUID
-    llm_service_id: str
-    llm_service_name: str
-    project_id: int
-    organization_id: int
-    inserted_at: datetime
-    updated_at: datetime
-    deleted_at: datetime | None = None
+    collection_params: CreateCollectionParams = Field(
+        ...,
+        description="Collection creation parameters (name, documents, etc.)",
+    )
+    batch_size: int = Field(
+        default=10,
+        ge=1,
+        le=500,
+        description="Number of documents to process in a single batch",
+    )
 
 
-class CollectionWithDocsPublic(CollectionPublic):
-    documents: list[DocumentPublic] | None = None
+class DeletionRequest(ProviderOptions, CallbackRequest):
+    """API request for collection deletion."""
+
+    collection_id: UUID = Field(description="Collection to delete")
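For reference, a request against the new shape might be built like this (sketch; ids and option values are illustrative only):

# Sketch: constructing the new CreationRequest (illustrative values).
from uuid import uuid4

from app.models.collection import CreationRequest, ProviderType

req = CreationRequest(
    provider="OpenAI",  # normalize_provider lowercases this before validation
    collection_params={
        "name": "support-kb",
        "documents": [
            {"id": str(uuid4()), "name": "faq.pdf"},
            # duplicate ids are dropped by CreateCollectionParams.model_post_init
        ],
        "chunking_params": {"chunk_size": 800, "chunk_overlap": 200},
    },
    batch_size=25,
    model="gpt-4o",  # model + instructions together opt in to assistant creation
    instructions="Answer using the attached knowledge base.",
)
assert req.provider is ProviderType.OPENAI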
diff --git a/backend/app/models/collection/response.py b/backend/app/models/collection/response.py
new file mode 100644
index 00000000..f72c5ee7
--- /dev/null
+++ b/backend/app/models/collection/response.py
@@ -0,0 +1,33 @@
+from datetime import datetime
+from typing import Any
+from uuid import UUID
+
+from sqlmodel import SQLModel
+
+from app.models.document import DocumentPublic
+
+
+class CreateCollectionResult(SQLModel):
+    llm_service_id: str
+    llm_service_name: str
+    collection_blob: dict[str, Any]
+
+
+class CollectionIDPublic(SQLModel):
+    id: UUID
+
+
+class CollectionPublic(SQLModel):
+    id: UUID
+    llm_service_id: str
+    llm_service_name: str
+    project_id: int
+    organization_id: int
+
+    inserted_at: datetime
+    updated_at: datetime
+    deleted_at: datetime | None = None
+
+
+class CollectionWithDocsPublic(CollectionPublic):
+    documents: list[DocumentPublic] | None = None
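`CreateCollectionResult` is the provider-to-worker contract: `collection_blob` carries everything from `CreateCollectionParams` except the documents, which are linked separately. A sketch of the vector-store-only shape (values illustrative):

from app.models.collection import CreateCollectionResult

result = CreateCollectionResult(
    llm_service_id="vs_abc123",  # illustrative vector store id
    llm_service_name="openai vector store",
    collection_blob={
        "name": "support-kb",
        "description": None,
        "chunking_params": None,
        "additional_params": None,
    },
)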
diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py
index ed83e4a8..088ebe0e 100644
--- a/backend/app/services/collections/create_collection.py
+++ b/backend/app/services/collections/create_collection.py
@@ -6,7 +6,6 @@
 from asgi_correlation_id import correlation_id
 
 from app.core.cloud import get_cloud_storage
-from app.core.util import now
 from app.core.db import engine
 from app.crud import (
     CollectionCrud,
@@ -14,7 +13,6 @@
     DocumentCollectionCrud,
     CollectionJobCrud,
 )
-from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
 from app.models import (
     CollectionJobStatus,
     CollectionJob,
@@ -22,19 +20,13 @@
     CollectionJobUpdate,
     CollectionPublic,
     CollectionJobPublic,
-)
-from app.models.collection import (
     CreationRequest,
-    AssistantOptions,
-)
-from app.services.collections.helpers import (
-    _backout,
-    batch_documents,
-    extract_error_message,
-    OPENAI_VECTOR_STORE,
+    ProviderType,
 )
+from app.services.collections.helpers import extract_error_message
+from app.services.collections.providers.registry import get_llm_provider
 from app.celery.utils import start_low_priority_job
-from app.utils import get_openai_client, send_callback, APIResponse
+from app.utils import send_callback, APIResponse
 
 logger = logging.getLogger(__name__)
@@ -60,7 +52,6 @@ def start_job(
         project_id=project_id,
         job_id=str(collection_job_id),
         trace_id=trace_id,
-        with_assistant=with_assistant,
         request=request.model_dump(mode="json"),
         organization_id=organization_id,
     )
@@ -116,26 +107,6 @@ def build_failure_payload(collection_job: CollectionJob, error_message: str) ->
     )
 
 
-def _cleanup_remote_resources(
-    assistant,
-    assistant_crud,
-    vector_store,
-    vector_store_crud,
-) -> None:
-    """Best-effort cleanup of partially created remote resources."""
-    try:
-        if assistant is not None and assistant_crud is not None:
-            _backout(assistant_crud, assistant.id)
-        elif vector_store is not None and vector_store_crud is not None:
-            _backout(vector_store_crud, vector_store.id)
-        else:
-            logger.warning(
-                "[create_collection._backout] Skipping: no resource/crud available"
-            )
-    except Exception:
-        logger.warning("[create_collection.execute_job] Backout failed")
-
-
 def _mark_job_failed(
     project_id: int,
     job_id: str,
@@ -167,22 +138,19 @@ def execute_job(
     organization_id: int,
     task_id: str,
     job_id: str,
-    with_assistant: bool,
     task_instance,
 ) -> None:
     """
     Worker entrypoint scheduled by start_job.
-    Orchestrates: job state, client/storage init, batching, vector-store upload,
+    Orchestrates: job state, provider init, collection creation,
     optional assistant creation, collection persistence, linking, callbacks, and cleanup.
     """
     start_time = time.time()
 
     # Keep references for potential backout/cleanup on failure
-    assistant = None
-    assistant_crud = None
-    vector_store = None
-    vector_store_crud = None
     collection_job = None
+    result = None
+    provider = None
 
     try:
         creation_request = CreationRequest(**request)
@@ -199,49 +167,32 @@ def execute_job(
                 ),
             )
 
-        client = get_openai_client(session, organization_id, project_id)
         storage = get_cloud_storage(session=session, project_id=project_id)
-
-        # Batch documents for upload, and flatten for linking/metrics later
         document_crud = DocumentCrud(session, project_id)
-        docs_batches = batch_documents(
-            document_crud,
-            creation_request.documents,
-            creation_request.batch_size,
+
+        provider = get_llm_provider(
+            session=session,
+            provider=creation_request.provider,
+            project_id=project_id,
+            organization_id=organization_id,
         )
-        flat_docs = [doc for batch in docs_batches for doc in batch]
 
-        vector_store_crud = OpenAIVectorStoreCrud(client)
-        vector_store = vector_store_crud.create()
-        list(vector_store_crud.update(vector_store.id, storage, docs_batches))
+        result = provider.create(
+            collection_request=creation_request,
+            storage=storage,
+            document_crud=document_crud,
+        )
 
-        # if with_assistant is true, create assistant backed by the vector store
-        if with_assistant:
-            assistant_crud = OpenAIAssistantCrud(client)
+        llm_service_id = result.llm_service_id
+        llm_service_name = result.llm_service_name
+        # Store collection params (name, description, chunking_params, etc.) in the DB
+        # for future reference and to support providers with varying configurations
+        collection_blob = result.collection_blob
 
-            # Filter out None to avoid sending unset options
-            assistant_options = dict(
-                creation_request.extract_super_type(AssistantOptions)
-            )
-            assistant_options = {
-                k: v for k, v in assistant_options.items() if v is not None
-            }
-
-            assistant = assistant_crud.create(vector_store.id, **assistant_options)
-            llm_service_id = assistant.id
-            llm_service_name = assistant_options.get("model") or "assistant"
-
-            logger.info(
-                "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s",
-                assistant.id,
-                vector_store.id,
-            )
-        else:
-            # If no assistant, the collection points directly at the vector store
-            llm_service_id = vector_store.id
-            llm_service_name = OPENAI_VECTOR_STORE
-            logger.info(
-                "[execute_job] Skipping assistant creation | with_assistant=False"
+        with Session(engine) as session:
+            document_crud = DocumentCrud(session, project_id)
+            flat_docs = document_crud.read_each(
+                [doc.id for doc in creation_request.collection_params.documents]
             )
 
        file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname}
@@ -259,11 +210,12 @@ def execute_job(
                 organization_id=organization_id,
                 llm_service_id=llm_service_id,
                 llm_service_name=llm_service_name,
+                collection_blob=collection_blob,
+                provider=creation_request.provider,
             )
             collection_crud.create(collection)
             collection = collection_crud.read_one(collection.id)
 
-            # Link documents to the new collection
             if flat_docs:
                 DocumentCollectionCrud(session).create(collection, flat_docs)
@@ -299,12 +251,13 @@ def execute_job(
             exc_info=True,
         )
 
-        _cleanup_remote_resources(
-            assistant=assistant,
-            assistant_crud=assistant_crud,
-            vector_store=vector_store,
-            vector_store_crud=vector_store_crud,
-        )
+        if provider is not None and result is not None:
+            try:
+                provider.cleanup(result)
+            except Exception:
+                logger.warning(
+                    "[create_collection.execute_job] Provider cleanup failed"
+                )
 
         collection_job = _mark_job_failed(
             project_id=project_id,
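Worth noting: `provider.cleanup(result)` only runs once `provider.create` has returned, so it covers failures in the persistence steps that follow; failures inside `create` itself are logged and re-raised by the provider. The contract, reduced to a sketch (names reused from this diff; `persist_collection` is a hypothetical stand-in for the DB steps):

result = None
try:
    result = provider.create(
        collection_request=creation_request,
        storage=storage,
        document_crud=document_crud,
    )
    persist_collection(result)  # hypothetical stand-in for the persistence/linking steps
except Exception:
    if provider is not None and result is not None:
        provider.cleanup(result)  # best-effort backout of remote resources
    raise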
diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py
index ca337b79..e9570964 100644
--- a/backend/app/services/collections/delete_collection.py
+++ b/backend/app/services/collections/delete_collection.py
@@ -6,7 +6,6 @@
 
 from app.core.db import engine
 from app.crud import CollectionCrud, CollectionJobCrud
-from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud
 from app.models import (
     CollectionJobStatus,
     CollectionJobUpdate,
@@ -15,9 +14,10 @@
     CollectionIDPublic,
 )
 from app.models.collection import DeletionRequest
-from app.services.collections.helpers import extract_error_message, OPENAI_VECTOR_STORE
+from app.services.collections.helpers import extract_error_message
+from app.services.collections.providers.registry import get_llm_provider
 from app.celery.utils import start_low_priority_job
-from app.utils import get_openai_client, send_callback, APIResponse
+from app.utils import send_callback, APIResponse
 
 logger = logging.getLogger(__name__)
@@ -155,7 +155,6 @@ def execute_job(
     job_uuid = UUID(job_id)
 
     collection_job = None
-    client = None
 
     try:
         with Session(engine) as session:
@@ -169,20 +168,16 @@ def execute_job(
                 ),
             )
 
-            client = get_openai_client(session, organization_id, project_id)
-
             collection = CollectionCrud(session, project_id).read_one(collection_id)
 
-            # Identify which external service (assistant/vector store) this collection belongs to
-            service = (collection.llm_service_name or "").strip().lower()
-            is_vector = service == OPENAI_VECTOR_STORE
-            llm_service_id = collection.llm_service_id
+            provider = get_llm_provider(
+                session=session,
+                provider=deletion_request.provider,
+                project_id=project_id,
+                organization_id=organization_id,
+            )
 
-            # Delete the corresponding OpenAI resource (vector store or assistant)
-            if is_vector:
-                OpenAIVectorStoreCrud(client).delete(llm_service_id)
-            else:
-                OpenAIAssistantCrud(client).delete(llm_service_id)
+            provider.delete(collection)
 
         with Session(engine) as session:
             CollectionCrud(session, project_id).delete_by_id(collection_id)
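Deletion requests now carry the provider explicitly instead of inferring it from `llm_service_name`. A minimal payload (sketch; the UUID is a placeholder):

from uuid import UUID

from app.models.collection import DeletionRequest

req = DeletionRequest(
    provider="openai",
    collection_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder
)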
diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py
index 795b04cd..34eaad42 100644
--- a/backend/app/services/collections/helpers.py
+++ b/backend/app/services/collections/helpers.py
@@ -14,8 +14,15 @@
 
 logger = logging.getLogger(__name__)
 
-# llm service name for when only an openai vector store is being made
-OPENAI_VECTOR_STORE = "openai vector store"
+
+def get_service_name(provider: str) -> str:
+    """Get the collection service name for a provider."""
+    names = {
+        "openai": "openai vector store",
+        # "bedrock": "bedrock knowledge base",
+        # "gemini": "gemini file search store",
+    }
+    return names.get(provider.lower(), "")
 
 
 def extract_error_message(err: Exception) -> str:
@@ -101,4 +108,4 @@ def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud):
     service = (
         (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else ""
     )
-    return v_crud if service == OPENAI_VECTOR_STORE else a_crud
+    return v_crud if service == get_service_name("openai") else a_crud
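`get_service_name` replaces the old `OPENAI_VECTOR_STORE` constant and is deliberately forgiving about case and unknown providers:

from app.services.collections.helpers import get_service_name

assert get_service_name("OpenAI") == "openai vector store"  # case-insensitive
assert get_service_name("unknown") == ""  # unmapped providers fall back to ""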
"bedrock knowledge base", + # "gemini": "gemini file search store", + } + return names.get(provider.lower(), "") def extract_error_message(err: Exception) -> str: @@ -101,4 +108,4 @@ def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): service = ( (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else "" ) - return v_crud if service == OPENAI_VECTOR_STORE else a_crud + return v_crud if service == get_service_name("openai") else a_crud diff --git a/backend/app/services/collections/providers/__init__.py b/backend/app/services/collections/providers/__init__.py new file mode 100644 index 00000000..5a9b6a55 --- /dev/null +++ b/backend/app/services/collections/providers/__init__.py @@ -0,0 +1,6 @@ +from app.services.collections.providers.base import BaseProvider +from app.services.collections.providers.openai import OpenAIProvider +from app.services.collections.providers.registry import ( + LLMProvider, + get_llm_provider, +) diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py new file mode 100644 index 00000000..9fb21f3e --- /dev/null +++ b/backend/app/services/collections/providers/base.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Any + +from app.crud import DocumentCrud +from app.core.cloud.storage import CloudStorage +from app.models import CreationRequest, CreateCollectionResult, Collection + + +class BaseProvider(ABC): + """Abstract base class for collection providers. + + All provider implementations (OpenAI, Bedrock, etc.) must inherit from + this class and implement the required methods. + + Providers handle creation of knowledge bases (vector stores) and + optional assistant/agent creation backed by those knowledge bases. + + Attributes: + client: The provider-specific client instance + """ + + def __init__(self, client: Any): + """Initialize provider with client. + + Args: + client: Provider-specific client instance + """ + self.client = client + + @abstractmethod + def create( + self, + collection_request: CreationRequest, + storage: CloudStorage, + document_crud: DocumentCrud, + ) -> CreateCollectionResult: + """Create collection with documents and optionally an assistant. + + Args: + collection_params: Collection parameters (name, description, chunking_params, etc.) + storage: Cloud storage instance for file access + document_crud: DocumentCrud instance for fetching documents + batch_size: Number of documents to process per batch + with_assistant: Whether to create an assistant/agent + assistant_options: Options for assistant creation (provider-specific) + + Returns: + CreateCollectionresult containing: + - llm_service_id: ID of the created resource (vector store or assistant) + - llm_service_name: Name of the service + - kb_blob: All collection params except documents + """ + raise NotImplementedError("Providers must implement execute method") + + @abstractmethod + def delete(self, collection: Collection) -> None: + """Delete remote resources associated with a collection. + + Called when a collection is being deleted and remote resources need to be cleaned up. + + Args: + llm_service_id: ID of the resource to delete + llm_service_name: Name of the service (determines resource type) + """ + raise NotImplementedError("Providers must implement delete method") + + @abstractmethod + def cleanup(self, collection_result: CreateCollectionResult) -> None: + """Clean up/rollback resources created during execute. 
diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py
new file mode 100644
index 00000000..998d8fb3
--- /dev/null
+++ b/backend/app/services/collections/providers/openai.py
@@ -0,0 +1,159 @@
+import logging
+
+from openai import OpenAI
+
+from app.services.collections.providers.base import BaseProvider
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
+from app.services.collections.helpers import (
+    batch_documents,
+    get_service_name,
+    _backout,
+)
+from app.models import CreateCollectionResult, CreationRequest, Collection
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(BaseProvider):
+    """OpenAI-specific collection provider for vector stores and assistants."""
+
+    def __init__(self, client: OpenAI):
+        super().__init__(client)
+        self.client = client
+
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> CreateCollectionResult:
+        """Create an OpenAI vector store with documents and optionally an assistant.
+
+        Args:
+            collection_request: Creation request (collection params, batch size, assistant options)
+            storage: Cloud storage instance for file access
+            document_crud: DocumentCrud instance for fetching documents
+
+        Returns:
+            CreateCollectionResult containing llm_service_id, llm_service_name, and collection_blob
+        """
+        try:
+            collection_params = collection_request.collection_params
+            document_ids = [doc.id for doc in collection_params.documents]
+
+            docs_batches = batch_documents(
+                document_crud,
+                document_ids,
+                collection_request.batch_size,
+            )
+
+            vector_store_crud = OpenAIVectorStoreCrud(self.client)
+            vector_store = vector_store_crud.create()
+
+            list(vector_store_crud.update(vector_store.id, storage, docs_batches))
+
+            logger.info(
+                "[OpenAIProvider.create] Vector store created | "
+                f"vector_store_id={vector_store.id}, batches={len(docs_batches)}"
+            )
+
+            collection_blob = {
+                "name": collection_params.name,
+                "description": collection_params.description,
+                "chunking_params": collection_params.chunking_params,
+                "additional_params": collection_params.additional_params,
+            }
+
+            # Check if we need to create an assistant (based on assistant options in the request)
+            with_assistant = (
+                collection_request.model is not None
+                and collection_request.instructions is not None
+            )
+            if with_assistant:
+                assistant_crud = OpenAIAssistantCrud(self.client)
+
+                assistant_options = {
+                    "model": collection_request.model,
+                    "instructions": collection_request.instructions,
+                    "temperature": collection_request.temperature,
+                }
+                filtered_options = {
+                    k: v for k, v in assistant_options.items() if v is not None
+                }
+
+                assistant = assistant_crud.create(vector_store.id, **filtered_options)
+
+                logger.info(
+                    "[OpenAIProvider.create] Assistant created | "
+                    f"assistant_id={assistant.id}, vector_store_id={vector_store.id}"
+                )
+
+                return CreateCollectionResult(
+                    llm_service_id=assistant.id,
+                    llm_service_name=filtered_options.get("model", "assistant"),
+                    collection_blob=collection_blob,
+                )
+            else:
+                logger.info(
+                    "[OpenAIProvider.create] Skipping assistant creation | with_assistant=False"
+                )
+
+                return CreateCollectionResult(
+                    llm_service_id=vector_store.id,
+                    llm_service_name=get_service_name("openai"),
+                    collection_blob=collection_blob,
+                )
+
+        except Exception as e:
+            logger.error(
+                f"[OpenAIProvider.create] Failed to create knowledge base: {str(e)}",
+                exc_info=True,
+            )
+            raise
+
+    def delete(self, collection: Collection) -> None:
+        """Delete OpenAI resources (assistant or vector store).
+
+        Determines what to delete based on llm_service_name:
+        - If an assistant was created, delete the assistant (which also removes the vector store)
+        - If only a vector store was created, delete the vector store
+
+        Args:
+            collection: Collection that has been requested to be deleted
+        """
+        try:
+            if collection.llm_service_name != get_service_name("openai"):
+                OpenAIAssistantCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}"
+                )
+            else:
+                OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}"
+                )
+        except Exception as e:
+            logger.error(
+                f"[OpenAIProvider.delete] Failed to delete resource | "
+                f"llm_service_id={collection.llm_service_id}, error={str(e)}",
+                exc_info=True,
+            )
+            raise
+
+    def cleanup(self, result: CreateCollectionResult) -> None:
+        """Clean up OpenAI resources (assistant or vector store) after a failed create.
+
+        Determines what to back out based on llm_service_name:
+        - If an assistant was created, delete the assistant (which also removes the vector store)
+        - If only a vector store was created, delete the vector store
+
+        Args:
+            result: The CreateCollectionResult from create, containing resource IDs
+        """
+        if result.llm_service_name != get_service_name("openai"):
+            _backout(OpenAIAssistantCrud(self.client), result.llm_service_id)
+        else:
+            _backout(OpenAIVectorStoreCrud(self.client), result.llm_service_id)
diff --git a/backend/app/services/collections/providers/registry.py b/backend/app/services/collections/providers/registry.py
new file mode 100644
index 00000000..10d07d45
--- /dev/null
+++ b/backend/app/services/collections/providers/registry.py
@@ -0,0 +1,71 @@
+import logging
+
+from sqlmodel import Session
+from openai import OpenAI
+
+from app.crud import get_provider_credential
+from app.services.collections.providers.base import BaseProvider
+from app.services.collections.providers.openai import OpenAIProvider
+
+
+logger = logging.getLogger(__name__)
+
+
+class LLMProvider:
+    OPENAI = "openai"
+    # Future provider constants:
+    # BEDROCK = "bedrock"
+    # GEMINI = "gemini"
+
+    _registry: dict[str, type[BaseProvider]] = {
+        OPENAI: OpenAIProvider,
+        # Future providers:
+        # BEDROCK: BedrockProvider,
+        # GEMINI: GeminiProvider,
+    }
+
+    @classmethod
+    def get(cls, name: str) -> type[BaseProvider]:
+        """Return the provider class for a given name."""
+        provider = cls._registry.get(name)
+        if not provider:
+            raise ValueError(
+                f"Provider '{name}' is not supported. "
+                f"Supported providers: {', '.join(cls._registry.keys())}"
+            )
+        return provider
+
+    @classmethod
+    def supported_providers(cls) -> list[str]:
+        """Return a list of supported provider names."""
+        return list(cls._registry.keys())
+
+
+def get_llm_provider(
+    session: Session, provider: str, project_id: int, organization_id: int
+) -> BaseProvider:
+    provider_class = LLMProvider.get(provider)
+
+    credentials = get_provider_credential(
+        session=session,
+        provider=provider,
+        project_id=project_id,
+        org_id=organization_id,
+    )
+
+    if not credentials:
+        raise ValueError(
+            f"Credentials for provider '{provider}' not configured for this project."
+        )
+
+    if provider == LLMProvider.OPENAI:
+        if "api_key" not in credentials:
+            raise ValueError("OpenAI credentials not configured for this project.")
+        client = OpenAI(api_key=credentials["api_key"])
+    else:
+        logger.error(
+            f"[get_llm_provider] Unsupported provider type requested: {provider}"
+        )
+        raise ValueError(f"Provider '{provider}' is not supported.")
+
+    return provider_class(client=client)
diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py
index 90f8b80c..877f9a19 100644
--- a/backend/app/tests/api/routes/collections/test_collection_info.py
+++ b/backend/app/tests/api/routes/collections/test_collection_info.py
@@ -9,6 +9,7 @@
 from app.tests.utils.collection import get_collection, get_vector_store_collection
 from app.crud import DocumentCollectionCrud
 from app.models import Collection, Document
+from app.services.collections.helpers import get_service_name
 
 
 def link_document_to_collection(
@@ -163,7 +164,7 @@ def test_collection_info_vector_store_collection(
     payload = data["data"]
 
     assert payload["id"] == str(collection.id)
-    assert payload["llm_service_name"] == "openai vector store"
+    assert payload["llm_service_name"] == get_service_name("openai")
     assert payload["llm_service_id"] == collection.llm_service_id
 
     docs = payload.get("documents", [])
diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py
index f7507c12..6d0e512e 100644
--- a/backend/app/tests/api/routes/collections/test_collection_list.py
+++ b/backend/app/tests/api/routes/collections/test_collection_list.py
@@ -7,6 +7,7 @@
     get_collection,
     get_vector_store_collection,
 )
+from app.services.collections.helpers import get_service_name
 
 
 def test_list_collections_returns_api_response(
@@ -101,7 +102,7 @@ def test_list_collections_includes_vector_store_collection_with_fields(
     row = matching[0]
 
     assert row["project_id"] == project.id
-    assert row["llm_service_name"] == "openai vector store"
+    assert row["llm_service_name"] == get_service_name("openai")
     assert row["llm_service_id"] == collection.llm_service_id
diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py
index 429bfc8b..cfacf0c4 100644
--- a/backend/app/tests/utils/collection.py
+++ b/backend/app/tests/utils/collection.py
@@ -8,8 +8,10 @@
     CollectionActionType,
     CollectionJob,
     CollectionJobStatus,
+    ProviderType,
 )
 from app.crud import CollectionCrud, CollectionJobCrud
+from app.services.collections.helpers import get_service_name
 
 
 class constants:
@@ -43,6 +45,7 @@ def get_collection(
         organization_id=project.organization_id,
         llm_service_name=model,
         llm_service_id=assistant_id,
+        provider=ProviderType.OPENAI,
     )
 
     return CollectionCrud(db, project.id).create(collection)
@@ -65,8 +68,9 @@ def get_vector_store_collection(
         id=collection_id or uuid4(),
         project_id=project.id,
         organization_id=project.organization_id,
-        llm_service_name="openai vector store",
+        llm_service_name=get_service_name("openai"),
         llm_service_id=vector_store_id,
+        provider=ProviderType.OPENAI,
    )
 
     return CollectionCrud(db, project.id).create(collection)
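A test for the new case-insensitive provider handling could sit alongside these (sketch, not part of this diff):

import pytest

from app.models.collection import CreationRequest, ProviderType


def test_provider_is_case_insensitive():
    req = CreationRequest(provider="OPENAI", collection_params={"name": "kb"})
    assert req.provider is ProviderType.OPENAI


def test_unknown_provider_is_rejected():
    with pytest.raises(ValueError):  # pydantic's ValidationError subclasses ValueError
        CreationRequest(provider="not-a-provider", collection_params={"name": "kb"})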