1919from datetime import timezone
2020import json
2121import logging
22- import pickle
2322from typing import Any
2423from typing import Optional
2524import uuid
2625
27- from google .genai import types
28- from sqlalchemy import Boolean
2926from sqlalchemy import delete
3027from sqlalchemy import Dialect
3128from sqlalchemy import event
3229from sqlalchemy import ForeignKeyConstraint
3330from sqlalchemy import func
31+ from sqlalchemy import inspect
3432from sqlalchemy import select
3533from sqlalchemy import Text
3634from sqlalchemy .dialects import mysql
4139from sqlalchemy .ext .asyncio import AsyncSession as DatabaseSessionFactory
4240from sqlalchemy .ext .asyncio import create_async_engine
4341from sqlalchemy .ext .mutable import MutableDict
44- from sqlalchemy .inspection import inspect
4542from sqlalchemy .orm import DeclarativeBase
4643from sqlalchemy .orm import Mapped
4744from sqlalchemy .orm import mapped_column
4845from sqlalchemy .orm import relationship
49- from sqlalchemy .schema import MetaData
5046from sqlalchemy .types import DateTime
51- from sqlalchemy .types import PickleType
5247from sqlalchemy .types import String
5348from sqlalchemy .types import TypeDecorator
5449from typing_extensions import override
5752from . import _session_util
5853from ..errors .already_exists_error import AlreadyExistsError
5954from ..events .event import Event
60- from ..events .event_actions import EventActions
6155from .base_session_service import BaseSessionService
6256from .base_session_service import GetSessionConfig
6357from .base_session_service import ListSessionsResponse
58+ from .migration import _schema_check
6459from .session import Session
6560from .state import State
6661
@@ -111,41 +106,22 @@ def load_dialect_impl(self, dialect):
111106 return self .impl
112107
113108
114- class DynamicPickleType (TypeDecorator ):
115- """Represents a type that can be pickled."""
116-
117- impl = PickleType
118-
119- def load_dialect_impl (self , dialect ):
120- if dialect .name == "mysql" :
121- return dialect .type_descriptor (mysql .LONGBLOB )
122- if dialect .name == "spanner+spanner" :
123- from google .cloud .sqlalchemy_spanner .sqlalchemy_spanner import SpannerPickleType
124-
125- return dialect .type_descriptor (SpannerPickleType )
126- return self .impl
127-
128- def process_bind_param (self , value , dialect ):
129- """Ensures the pickled value is a bytes object before passing it to the database dialect."""
130- if value is not None :
131- if dialect .name in ("spanner+spanner" , "mysql" ):
132- return pickle .dumps (value )
133- return value
134-
135- def process_result_value (self , value , dialect ):
136- """Ensures the raw bytes from the database are unpickled back into a Python object."""
137- if value is not None :
138- if dialect .name in ("spanner+spanner" , "mysql" ):
139- return pickle .loads (value )
140- return value
141-
142-
class Base(DeclarativeBase):
  """Declarative base shared by all ADK session-service ORM tables."""
147113
148114
class StorageMetadata(Base):
  """Represents internal metadata stored in the database.

  A simple key/value table used by the service itself (e.g. to stamp the
  schema version under ``_schema_check.SCHEMA_VERSION_KEY``), not user data.
  """

  __tablename__ = "adk_internal_metadata"
  # Metadata key (e.g. the schema-version key); primary key of the table.
  key: Mapped[str] = mapped_column(
      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  # Metadata value stored as a plain string (e.g. the current schema version).
  value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
124+
149125class StorageSession (Base ):
150126 """Represents a session stored in the database."""
151127
@@ -237,46 +213,10 @@ class StorageEvent(Base):
237213 )
238214
  # Invocation that produced this event; duplicated from event_data so it
  # can be queried without deserializing the JSON payload.
  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
  # Event creation time; PreciseTimestamp preserves sub-second precision on
  # backends (e.g. MySQL) whose default DATETIME truncates it.
  timestamp: Mapped[PreciseTimestamp] = mapped_column(
      PreciseTimestamp, default=func.now()
  )
  # Full Event serialized as JSON (replaces the per-field/pickled columns of
  # the previous schema); id/invocation_id/timestamp are kept as columns too.
  event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
280220
281221 storage_session : Mapped [StorageSession ] = relationship (
282222 "StorageSession" ,
@@ -291,102 +231,27 @@ class StorageEvent(Base):
291231 ),
292232 )
293233
294- @property
295- def long_running_tool_ids (self ) -> set [str ]:
296- return (
297- set (json .loads (self .long_running_tool_ids_json ))
298- if self .long_running_tool_ids_json
299- else set ()
300- )
301-
302- @long_running_tool_ids .setter
303- def long_running_tool_ids (self , value : set [str ]):
304- if value is None :
305- self .long_running_tool_ids_json = None
306- else :
307- self .long_running_tool_ids_json = json .dumps (list (value ))
308-
309234 @classmethod
310235 def from_event (cls , session : Session , event : Event ) -> StorageEvent :
311- storage_event = StorageEvent (
236+ """Creates a StorageEvent from an Event."""
237+ return StorageEvent (
312238 id = event .id ,
313239 invocation_id = event .invocation_id ,
314- author = event .author ,
315- branch = event .branch ,
316- actions = event .actions ,
317240 session_id = session .id ,
318241 app_name = session .app_name ,
319242 user_id = session .user_id ,
320243 timestamp = datetime .fromtimestamp (event .timestamp ),
321- long_running_tool_ids = event .long_running_tool_ids ,
322- partial = event .partial ,
323- turn_complete = event .turn_complete ,
324- error_code = event .error_code ,
325- error_message = event .error_message ,
326- interrupted = event .interrupted ,
244+ event_data = event .model_dump (exclude_none = True , mode = "json" ),
327245 )
328- if event .content :
329- storage_event .content = event .content .model_dump (
330- exclude_none = True , mode = "json"
331- )
332- if event .grounding_metadata :
333- storage_event .grounding_metadata = event .grounding_metadata .model_dump (
334- exclude_none = True , mode = "json"
335- )
336- if event .custom_metadata :
337- storage_event .custom_metadata = event .custom_metadata
338- if event .usage_metadata :
339- storage_event .usage_metadata = event .usage_metadata .model_dump (
340- exclude_none = True , mode = "json"
341- )
342- if event .citation_metadata :
343- storage_event .citation_metadata = event .citation_metadata .model_dump (
344- exclude_none = True , mode = "json"
345- )
346- if event .input_transcription :
347- storage_event .input_transcription = event .input_transcription .model_dump (
348- exclude_none = True , mode = "json"
349- )
350- if event .output_transcription :
351- storage_event .output_transcription = (
352- event .output_transcription .model_dump (exclude_none = True , mode = "json" )
353- )
354- return storage_event
355246
356247 def to_event (self ) -> Event :
357- return Event (
358- id = self .id ,
359- invocation_id = self .invocation_id ,
360- author = self .author ,
361- branch = self .branch ,
362- # This is needed as previous ADK version pickled actions might not have
363- # value defined in the current version of the EventActions model.
364- actions = EventActions ().model_copy (update = self .actions .model_dump ()),
365- timestamp = self .timestamp .timestamp (),
366- long_running_tool_ids = self .long_running_tool_ids ,
367- partial = self .partial ,
368- turn_complete = self .turn_complete ,
369- error_code = self .error_code ,
370- error_message = self .error_message ,
371- interrupted = self .interrupted ,
372- custom_metadata = self .custom_metadata ,
373- content = _session_util .decode_model (self .content , types .Content ),
374- grounding_metadata = _session_util .decode_model (
375- self .grounding_metadata , types .GroundingMetadata
376- ),
377- usage_metadata = _session_util .decode_model (
378- self .usage_metadata , types .GenerateContentResponseUsageMetadata
379- ),
380- citation_metadata = _session_util .decode_model (
381- self .citation_metadata , types .CitationMetadata
382- ),
383- input_transcription = _session_util .decode_model (
384- self .input_transcription , types .Transcription
385- ),
386- output_transcription = _session_util .decode_model (
387- self .output_transcription , types .Transcription
388- ),
389- )
248+ """Converts the StorageEvent to an Event."""
249+ return Event .model_validate ({
250+ ** self .event_data ,
251+ "id" : self .id ,
252+ "invocation_id" : self .invocation_id ,
253+ "timestamp" : self .timestamp .timestamp (),
254+ })
390255
391256
392257class StorageAppState (Base ):
@@ -463,7 +328,6 @@ def __init__(self, db_url: str, **kwargs: Any):
463328 logger .info ("Local timezone: %s" , local_timezone )
464329
465330 self .db_engine : AsyncEngine = db_engine
466- self .metadata : MetaData = MetaData ()
467331
468332 # DB session factory method
469333 self .database_session_factory : async_sessionmaker [
@@ -483,10 +347,46 @@ async def _ensure_tables_created(self):
483347 async with self ._table_creation_lock :
484348 # Double-check after acquiring the lock
485349 if not self ._tables_created :
350+ # Check schema version BEFORE creating tables.
351+ # This prevents creating metadata table on a v0.1 DB.
352+ async with self .database_session_factory () as sql_session :
353+ version , is_v01 = await sql_session .run_sync (
354+ _schema_check .get_version_and_v01_status_sync
355+ )
356+
357+ if is_v01 :
358+ raise RuntimeError (
359+ "Database schema appears to be v0.1, but"
360+ f" { _schema_check .CURRENT_SCHEMA_VERSION } is required. Please"
361+ " migrate the database using 'adk migrate session'."
362+ )
363+ elif version and version < _schema_check .CURRENT_SCHEMA_VERSION :
364+ raise RuntimeError (
365+ f"Database schema version is { version } , but current version is"
366+ f" { _schema_check .CURRENT_SCHEMA_VERSION } . Please migrate"
367+ " the database to the latest version using 'adk migrate"
368+ " session'."
369+ )
370+
486371 async with self .db_engine .begin () as conn :
487372 # Uncomment to recreate DB every time
488373 # await conn.run_sync(Base.metadata.drop_all)
489374 await conn .run_sync (Base .metadata .create_all )
375+
376+ # If we are here, DB is either new or >= current version.
377+ # If new or without metadata row, stamp it as current version.
378+ async with self .database_session_factory () as sql_session :
379+ metadata = await sql_session .get (
380+ StorageMetadata , _schema_check .SCHEMA_VERSION_KEY
381+ )
382+ if not metadata :
383+ sql_session .add (
384+ StorageMetadata (
385+ key = _schema_check .SCHEMA_VERSION_KEY ,
386+ value = _schema_check .CURRENT_SCHEMA_VERSION ,
387+ )
388+ )
389+ await sql_session .commit ()
490390 self ._tables_created = True
491391
492392 @override
@@ -723,7 +623,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
723623 storage_session .state = storage_session .state | session_state_delta
724624
725625 if storage_session ._dialect_name == "sqlite" :
726- update_time = datetime .utcfromtimestamp (event .timestamp )
626+ update_time = datetime .fromtimestamp (
627+ event .timestamp , timezone .utc
628+ ).replace (tzinfo = None )
727629 else :
728630 update_time = datetime .fromtimestamp (event .timestamp )
729631 storage_session .update_time = update_time
0 commit comments