
Commit 0094eea

DeanChensj authored and copybara-github committed
feat!: Migrate DatabaseSessionService to use JSON serialization schema
Also provides a command line tool, `adk migrate session`, for DB migration.

Addresses #3605
Addresses #3681

To verify:

```
# Start one Postgres DB
docker run --name my-postgres -d \
  -e POSTGRES_DB=agent \
  -e POSTGRES_USER=agent \
  -e POSTGRES_PASSWORD=agent \
  -e PGDATA=/var/lib/postgresql/data/pgdata \
  -v pgvolume:/var/lib/postgresql/data \
  -p 5532:5432 postgres

# Connect with an old version of ADK and produce some query data
adk web --session_service_uri=postgresql://agent:agent@localhost:5532/agent

# Check out the latest branch and restart adk web.
# You should see an error log asking you to migrate the DB.

# Start a new DB
docker run --name migration-test-db -d --rm \
  -e POSTGRES_DB=agent \
  -e POSTGRES_USER=agent \
  -e POSTGRES_PASSWORD=agent \
  -e PGDATA=/var/lib/postgresql/data/pgdata \
  -v migration_test_vol:/var/lib/postgresql/data \
  -p 5533:5432 postgres

# DB migration
adk migrate session \
  --source_db_url="postgresql://agent:agent@localhost:5532/agent" \
  --dest_db_url="postgresql://agent:agent@localhost:5533/agent"

# Run adk web with the new DB
adk web --session_service_uri=postgresql+asyncpg://agent:agent@localhost:5533/agent

# You should see that the data from the old DB has been migrated.
```

Co-authored-by: Shangjie Chen <deanchen@google.com>

PiperOrigin-RevId: 837341139
1 parent 786aaed commit 0094eea

File tree

9 files changed: +939, -342 lines


src/google/adk/cli/cli_tools_click.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@
 from . import cli_deploy
 from .. import version
 from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
+from ..sessions.migration import migration_runner
 from .cli import run_cli
 from .fast_api import get_fast_api_app
 from .utils import envs
@@ -1485,6 +1486,41 @@ def cli_deploy_cloud_run(
     click.secho(f"Deploy failed: {e}", fg="red", err=True)
 
 
+@main.group()
+def migrate():
+  """Migrate ADK database schemas."""
+  pass
+
+
+@migrate.command("session", cls=HelpfulCommand)
+@click.option(
+    "--source_db_url",
+    required=True,
+    help="SQLAlchemy URL of source database.",
+)
+@click.option(
+    "--dest_db_url",
+    required=True,
+    help="SQLAlchemy URL of destination database.",
+)
+@click.option(
+    "--log_level",
+    type=LOG_LEVELS,
+    default="INFO",
+    help="Optional. Set the logging level",
+)
+def cli_migrate_session(
+    *, source_db_url: str, dest_db_url: str, log_level: str
+):
+  """Migrates a session database to the latest schema version."""
+  logs.setup_adk_logger(getattr(logging, log_level.upper()))
+  try:
+    migration_runner.upgrade(source_db_url, dest_db_url)
+    click.secho("Migration check and upgrade process finished.", fg="green")
+  except Exception as e:
+    click.secho(f"Migration failed: {e}", fg="red", err=True)
+
+
 @deploy.command("agent_engine")
 @click.option(
     "--api_key",
```

src/google/adk/sessions/database_session_service.py

Lines changed: 62 additions & 160 deletions
```diff
@@ -19,18 +19,16 @@
 from datetime import timezone
 import json
 import logging
-import pickle
 from typing import Any
 from typing import Optional
 import uuid
 
-from google.genai import types
-from sqlalchemy import Boolean
 from sqlalchemy import delete
 from sqlalchemy import Dialect
 from sqlalchemy import event
 from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
+from sqlalchemy import inspect
 from sqlalchemy import select
 from sqlalchemy import Text
 from sqlalchemy.dialects import mysql
@@ -41,14 +39,11 @@
 from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
-from sqlalchemy.schema import MetaData
 from sqlalchemy.types import DateTime
-from sqlalchemy.types import PickleType
 from sqlalchemy.types import String
 from sqlalchemy.types import TypeDecorator
 from typing_extensions import override
@@ -57,10 +52,10 @@
 from . import _session_util
 from ..errors.already_exists_error import AlreadyExistsError
 from ..events.event import Event
-from ..events.event_actions import EventActions
 from .base_session_service import BaseSessionService
 from .base_session_service import GetSessionConfig
 from .base_session_service import ListSessionsResponse
+from .migration import _schema_check
 from .session import Session
 from .state import State
@@ -111,41 +106,22 @@ def load_dialect_impl(self, dialect):
     return self.impl
 
 
-class DynamicPickleType(TypeDecorator):
-  """Represents a type that can be pickled."""
-
-  impl = PickleType
-
-  def load_dialect_impl(self, dialect):
-    if dialect.name == "mysql":
-      return dialect.type_descriptor(mysql.LONGBLOB)
-    if dialect.name == "spanner+spanner":
-      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
-
-      return dialect.type_descriptor(SpannerPickleType)
-    return self.impl
-
-  def process_bind_param(self, value, dialect):
-    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
-    if value is not None:
-      if dialect.name in ("spanner+spanner", "mysql"):
-        return pickle.dumps(value)
-    return value
-
-  def process_result_value(self, value, dialect):
-    """Ensures the raw bytes from the database are unpickled back into a Python object."""
-    if value is not None:
-      if dialect.name in ("spanner+spanner", "mysql"):
-        return pickle.loads(value)
-    return value
-
-
 class Base(DeclarativeBase):
   """Base class for database tables."""
 
   pass
 
 
+class StorageMetadata(Base):
+  """Represents internal metadata stored in the database."""
+
+  __tablename__ = "adk_internal_metadata"
+  key: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+
+
 class StorageSession(Base):
   """Represents a session stored in the database."""
```
```diff
@@ -237,46 +213,10 @@ class StorageEvent(Base):
   )
 
   invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
-  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
-  actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
-  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
-      Text, nullable=True
-  )
-  branch: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
-  )
   timestamp: Mapped[PreciseTimestamp] = mapped_column(
       PreciseTimestamp, default=func.now()
   )
-
-  # === Fields from llm_response.py ===
-  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
-  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-
-  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  error_code: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
-  )
-  error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
-  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  input_transcription: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  output_transcription: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
+  event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
 
   storage_session: Mapped[StorageSession] = relationship(
       "StorageSession",
@@ -291,102 +231,27 @@ class StorageEvent(Base):
       ),
   )
 
-  @property
-  def long_running_tool_ids(self) -> set[str]:
-    return (
-        set(json.loads(self.long_running_tool_ids_json))
-        if self.long_running_tool_ids_json
-        else set()
-    )
-
-  @long_running_tool_ids.setter
-  def long_running_tool_ids(self, value: set[str]):
-    if value is None:
-      self.long_running_tool_ids_json = None
-    else:
-      self.long_running_tool_ids_json = json.dumps(list(value))
-
   @classmethod
   def from_event(cls, session: Session, event: Event) -> StorageEvent:
-    storage_event = StorageEvent(
+    """Creates a StorageEvent from an Event."""
+    return StorageEvent(
         id=event.id,
         invocation_id=event.invocation_id,
-        author=event.author,
-        branch=event.branch,
-        actions=event.actions,
         session_id=session.id,
         app_name=session.app_name,
         user_id=session.user_id,
         timestamp=datetime.fromtimestamp(event.timestamp),
-        long_running_tool_ids=event.long_running_tool_ids,
-        partial=event.partial,
-        turn_complete=event.turn_complete,
-        error_code=event.error_code,
-        error_message=event.error_message,
-        interrupted=event.interrupted,
+        event_data=event.model_dump(exclude_none=True, mode="json"),
     )
-    if event.content:
-      storage_event.content = event.content.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.grounding_metadata:
-      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.custom_metadata:
-      storage_event.custom_metadata = event.custom_metadata
-    if event.usage_metadata:
-      storage_event.usage_metadata = event.usage_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.citation_metadata:
-      storage_event.citation_metadata = event.citation_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.input_transcription:
-      storage_event.input_transcription = event.input_transcription.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.output_transcription:
-      storage_event.output_transcription = (
-          event.output_transcription.model_dump(exclude_none=True, mode="json")
-      )
-    return storage_event
 
   def to_event(self) -> Event:
-    return Event(
-        id=self.id,
-        invocation_id=self.invocation_id,
-        author=self.author,
-        branch=self.branch,
-        # This is needed as previous ADK version pickled actions might not have
-        # value defined in the current version of the EventActions model.
-        actions=EventActions().model_copy(update=self.actions.model_dump()),
-        timestamp=self.timestamp.timestamp(),
-        long_running_tool_ids=self.long_running_tool_ids,
-        partial=self.partial,
-        turn_complete=self.turn_complete,
-        error_code=self.error_code,
-        error_message=self.error_message,
-        interrupted=self.interrupted,
-        custom_metadata=self.custom_metadata,
-        content=_session_util.decode_model(self.content, types.Content),
-        grounding_metadata=_session_util.decode_model(
-            self.grounding_metadata, types.GroundingMetadata
-        ),
-        usage_metadata=_session_util.decode_model(
-            self.usage_metadata, types.GenerateContentResponseUsageMetadata
-        ),
-        citation_metadata=_session_util.decode_model(
-            self.citation_metadata, types.CitationMetadata
-        ),
-        input_transcription=_session_util.decode_model(
-            self.input_transcription, types.Transcription
-        ),
-        output_transcription=_session_util.decode_model(
-            self.output_transcription, types.Transcription
-        ),
-    )
+    """Converts the StorageEvent to an Event."""
+    return Event.model_validate({
+        **self.event_data,
+        "id": self.id,
+        "invocation_id": self.invocation_id,
+        "timestamp": self.timestamp.timestamp(),
+    })
 
 
 class StorageAppState(Base):
```
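The per-field columns collapse into a single `event_data` JSON column, so the read/write paths reduce to a Pydantic round trip, with the separately stored columns (`id`, `invocation_id`, `timestamp`) layered back on top at read time. A self-contained sketch of that pattern, using a stand-in model rather than ADK's real `Event`:

```python
from typing import Optional

from pydantic import BaseModel


class ToyEvent(BaseModel):
  """Stand-in for ADK's Event model, which has many more fields."""

  id: str
  invocation_id: str
  timestamp: float
  author: Optional[str] = None
  branch: Optional[str] = None


original = ToyEvent(
    id="e1", invocation_id="inv1", timestamp=1700000000.0, author="user"
)

# Write path: one JSON-safe dict for the whole event. exclude_none keeps the
# stored blob compact and tolerant of optional fields added in later versions.
event_data = original.model_dump(exclude_none=True, mode="json")

# Read path: rehydrate from the blob, letting the dedicated DB columns win
# for the fields that are stored separately.
restored = ToyEvent.model_validate({
    **event_data,
    "id": "e1",
    "invocation_id": "inv1",
    "timestamp": 1700000000.0,
})
assert restored == original
```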
```diff
@@ -463,7 +328,6 @@ def __init__(self, db_url: str, **kwargs: Any):
     logger.info("Local timezone: %s", local_timezone)
 
     self.db_engine: AsyncEngine = db_engine
-    self.metadata: MetaData = MetaData()
 
     # DB session factory method
     self.database_session_factory: async_sessionmaker[
```
```diff
@@ -483,10 +347,46 @@ async def _ensure_tables_created(self):
     async with self._table_creation_lock:
       # Double-check after acquiring the lock
       if not self._tables_created:
+        # Check schema version BEFORE creating tables.
+        # This prevents creating metadata table on a v0.1 DB.
+        async with self.database_session_factory() as sql_session:
+          version, is_v01 = await sql_session.run_sync(
+              _schema_check.get_version_and_v01_status_sync
+          )
+
+          if is_v01:
+            raise RuntimeError(
+                "Database schema appears to be v0.1, but"
+                f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please"
+                " migrate the database using 'adk migrate session'."
+            )
+          elif version and version < _schema_check.CURRENT_SCHEMA_VERSION:
+            raise RuntimeError(
+                f"Database schema version is {version}, but current version is"
+                f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate"
+                " the database to the latest version using 'adk migrate"
+                " session'."
+            )
+
         async with self.db_engine.begin() as conn:
           # Uncomment to recreate DB every time
           # await conn.run_sync(Base.metadata.drop_all)
           await conn.run_sync(Base.metadata.create_all)
+
+        # If we are here, DB is either new or >= current version.
+        # If new or without metadata row, stamp it as current version.
+        async with self.database_session_factory() as sql_session:
+          metadata = await sql_session.get(
+              StorageMetadata, _schema_check.SCHEMA_VERSION_KEY
+          )
+          if not metadata:
+            sql_session.add(
+                StorageMetadata(
+                    key=_schema_check.SCHEMA_VERSION_KEY,
+                    value=_schema_check.CURRENT_SCHEMA_VERSION,
+                )
+            )
+            await sql_session.commit()
         self._tables_created = True
 
   @override
```
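The `_schema_check` helper itself is not part of this diff. Given how it is called above (on a sync `Session` via `run_sync`, returning a `(version, is_v01)` tuple), here is a plausible sketch of what such a helper could look like; the table/column probes and the literal `"schema_version"` key are assumptions for illustration, not the shipped implementation:

```python
from typing import Optional, Tuple

from sqlalchemy import inspect, text
from sqlalchemy.orm import Session


def get_version_and_v01_status_sketch(
    session: Session,
) -> Tuple[Optional[str], bool]:
  """Hypothetical stand-in for _schema_check.get_version_and_v01_status_sync."""
  inspector = inspect(session.connection())
  tables = set(inspector.get_table_names())

  if "adk_internal_metadata" not in tables:
    # Assumption: a pre-existing events table that still carries the
    # pickle-era "actions" column marks a v0.1 schema.
    if "events" in tables:
      columns = {col["name"] for col in inspector.get_columns("events")}
      return None, "actions" in columns
    return None, False  # brand-new database

  version = session.execute(
      text("SELECT value FROM adk_internal_metadata WHERE key = :k"),
      {"k": "schema_version"},  # assumed literal for SCHEMA_VERSION_KEY
  ).scalar_one_or_none()
  return version, False
```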
```diff
@@ -723,7 +623,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
       storage_session.state = storage_session.state | session_state_delta
 
       if storage_session._dialect_name == "sqlite":
-        update_time = datetime.utcfromtimestamp(event.timestamp)
+        update_time = datetime.fromtimestamp(
+            event.timestamp, timezone.utc
+        ).replace(tzinfo=None)
       else:
         update_time = datetime.fromtimestamp(event.timestamp)
       storage_session.update_time = update_time
```
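The sqlite branch also swaps out `datetime.utcfromtimestamp`, which has been deprecated since Python 3.12. The replacement yields the same naive UTC datetime:

```python
from datetime import datetime, timezone

ts = 1700000000.5

# Old spelling (DeprecationWarning on Python 3.12+):
#   datetime.utcfromtimestamp(ts)
# New spelling from the diff: build an aware UTC datetime, then drop the
# tzinfo so the value stored for sqlite stays naive, as before.
update_time = datetime.fromtimestamp(ts, timezone.utc).replace(tzinfo=None)

assert update_time == datetime(2023, 11, 14, 22, 13, 20, 500000)
```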
