diff --git a/bigframes/core/log_adapter.py b/bigframes/core/log_adapter.py index 8179ffbeed..77c09437c0 100644 --- a/bigframes/core/log_adapter.py +++ b/bigframes/core/log_adapter.py @@ -174,7 +174,8 @@ def wrapper(*args, **kwargs): full_method_name = f"{base_name.lower()}-{api_method_name}" # Track directly called methods if len(_call_stack) == 0: - add_api_method(full_method_name) + session = _find_session(*args, **kwargs) + add_api_method(full_method_name, session=session) _call_stack.append(full_method_name) @@ -220,7 +221,8 @@ def wrapped(*args, **kwargs): full_property_name = f"{class_name.lower()}-{property_name.lower()}" if len(_call_stack) == 0: - add_api_method(full_property_name) + session = _find_session(*args, **kwargs) + add_api_method(full_property_name, session=session) _call_stack.append(full_property_name) try: @@ -250,25 +252,41 @@ def wrapper(func): return wrapper -def add_api_method(api_method_name): +def add_api_method(api_method_name, session=None): global _lock global _api_methods - with _lock: - # Push the method to the front of the _api_methods list - _api_methods.insert(0, api_method_name.replace("<", "").replace(">", "")) - # Keep the list length within the maximum limit (adjust MAX_LABELS_COUNT as needed) - _api_methods = _api_methods[:MAX_LABELS_COUNT] + clean_method_name = api_method_name.replace("<", "").replace(">", "") + + if session is not None and _is_session_initialized(session): + with session._api_methods_lock: + session._api_methods.insert(0, clean_method_name) + session._api_methods = session._api_methods[:MAX_LABELS_COUNT] + else: + with _lock: + # Push the method to the front of the _api_methods list + _api_methods.insert(0, clean_method_name) + # Keep the list length within the maximum limit (adjust MAX_LABELS_COUNT as needed) + _api_methods = _api_methods[:MAX_LABELS_COUNT] -def get_and_reset_api_methods(dry_run: bool = False): + +def get_and_reset_api_methods(dry_run: bool = False, session=None): global _lock + methods = [] + + if session is not None and _is_session_initialized(session): + with session._api_methods_lock: + methods.extend(session._api_methods) + if not dry_run: + session._api_methods.clear() + with _lock: - previous_api_methods = list(_api_methods) + methods.extend(_api_methods) # dry_run might not make a job resource, so only reset the log on real queries. if not dry_run: _api_methods.clear() - return previous_api_methods + return methods def _get_bq_client(*args, **kwargs): @@ -283,3 +301,36 @@ def _get_bq_client(*args, **kwargs): return kwargv._block.session.bqclient return None + + +def _is_session_initialized(session): + """Return True if fully initialized. + + Because the method logger could get called before Session.__init__ has a + chance to run, we use the globals in that case. + """ + return hasattr(session, "_api_methods_lock") and hasattr(session, "_api_methods") + + +def _find_session(*args, **kwargs): + # This function cannot import Session at the top level because Session + # imports log_adapter. + from bigframes.session import Session + + session = args[0] if args else None + if ( + session is not None + and isinstance(session, Session) + and _is_session_initialized(session) + ): + return session + + session = kwargs.get("session") + if ( + session is not None + and isinstance(session, Session) + and _is_session_initialized(session) + ): + return session + + return None diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 3cb9d2bb68..4f32514652 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -23,6 +23,7 @@ import logging import os import secrets +import threading import typing from typing import ( Any, @@ -208,6 +209,9 @@ def __init__( self._session_id: str = "session" + secrets.token_hex(3) # store table ids and delete them when the session is closed + self._api_methods: list[str] = [] + self._api_methods_lock = threading.Lock() + self._objects: list[ weakref.ReferenceType[ Union[ @@ -2160,6 +2164,7 @@ def _start_query_ml_ddl( query_with_job=True, job_retry=third_party_gcb_retry.DEFAULT_ML_JOB_RETRY, publisher=self._publisher, + session=self, ) return iterator, query_job @@ -2188,6 +2193,7 @@ def _create_object_table(self, path: str, connection: str) -> str: timeout=None, query_with_job=True, publisher=self._publisher, + session=self, ) return table diff --git a/bigframes/session/_io/bigquery/__init__.py b/bigframes/session/_io/bigquery/__init__.py index aa56dc0040..9114770224 100644 --- a/bigframes/session/_io/bigquery/__init__.py +++ b/bigframes/session/_io/bigquery/__init__.py @@ -126,6 +126,7 @@ def create_temp_table( schema: Optional[Iterable[bigquery.SchemaField]] = None, cluster_columns: Optional[list[str]] = None, kms_key: Optional[str] = None, + session=None, ) -> str: """Create an empty table with an expiration in the desired session. @@ -153,6 +154,7 @@ def create_temp_view( *, expiration: datetime.datetime, sql: str, + session=None, ) -> str: """Create an empty table with an expiration in the desired session. @@ -228,12 +230,14 @@ def format_option(key: str, value: Union[bool, str]) -> str: return f"{key}={repr(value)}" -def add_and_trim_labels(job_config): +def add_and_trim_labels(job_config, session=None): """ Add additional labels to the job configuration and trim the total number of labels to ensure they do not exceed MAX_LABELS_COUNT labels per job. """ - api_methods = log_adapter.get_and_reset_api_methods(dry_run=job_config.dry_run) + api_methods = log_adapter.get_and_reset_api_methods( + dry_run=job_config.dry_run, session=session + ) job_config.labels = create_job_configs_labels( job_configs_labels=job_config.labels, api_methods=api_methods, @@ -270,6 +274,7 @@ def start_query_with_client( metrics: Optional[bigframes.session.metrics.ExecutionMetrics], query_with_job: Literal[True], publisher: bigframes.core.events.Publisher, + session=None, ) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ... @@ -286,6 +291,7 @@ def start_query_with_client( metrics: Optional[bigframes.session.metrics.ExecutionMetrics], query_with_job: Literal[False], publisher: bigframes.core.events.Publisher, + session=None, ) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ... @@ -303,6 +309,7 @@ def start_query_with_client( query_with_job: Literal[True], job_retry: google.api_core.retry.Retry, publisher: bigframes.core.events.Publisher, + session=None, ) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ... @@ -320,6 +327,7 @@ def start_query_with_client( query_with_job: Literal[False], job_retry: google.api_core.retry.Retry, publisher: bigframes.core.events.Publisher, + session=None, ) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ... @@ -340,6 +348,7 @@ def start_query_with_client( # version 3.36.0 or later. job_retry: google.api_core.retry.Retry = third_party_gcb_retry.DEFAULT_JOB_RETRY, publisher: bigframes.core.events.Publisher, + session=None, ) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: """ Starts query job and waits for results. @@ -347,7 +356,7 @@ def start_query_with_client( # Note: Ensure no additional labels are added to job_config after this # point, as `add_and_trim_labels` ensures the label count does not # exceed MAX_LABELS_COUNT. - add_and_trim_labels(job_config) + add_and_trim_labels(job_config, session=session) try: if not query_with_job: diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index 736dbf7be1..ca19d1be86 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -323,6 +323,7 @@ def _export_gbq( iterator, job = self._run_execute_query( sql=sql, job_config=job_config, + session=array_value.session, ) has_timedelta_col = any( @@ -389,6 +390,7 @@ def _run_execute_query( sql: str, job_config: Optional[bq_job.QueryJobConfig] = None, query_with_job: bool = True, + session=None, ) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]: """ Starts BigQuery query job and waits for results. @@ -415,6 +417,7 @@ def _run_execute_query( timeout=None, query_with_job=True, publisher=self._publisher, + session=session, ) else: return bq_io.start_query_with_client( @@ -427,6 +430,7 @@ def _run_execute_query( timeout=None, query_with_job=False, publisher=self._publisher, + session=session, ) except google.api_core.exceptions.BadRequest as e: @@ -661,6 +665,7 @@ def _execute_plan_gbq( sql=compiled.sql, job_config=job_config, query_with_job=(destination_table is not None), + session=plan.session, ) # we could actually cache even when caching is not explicitly requested, but being conservative for now diff --git a/bigframes/session/direct_gbq_execution.py b/bigframes/session/direct_gbq_execution.py index 748c43e66c..3ec10bf20f 100644 --- a/bigframes/session/direct_gbq_execution.py +++ b/bigframes/session/direct_gbq_execution.py @@ -60,6 +60,7 @@ def execute( iterator, query_job = self._run_execute_query( sql=compiled.sql, + session=plan.session, ) # just immediately downlaod everything for simplicity @@ -75,6 +76,7 @@ def _run_execute_query( self, sql: str, job_config: Optional[bq_job.QueryJobConfig] = None, + session=None, ) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]: """ Starts BigQuery query job and waits for results. @@ -89,4 +91,5 @@ def _run_execute_query( metrics=None, query_with_job=False, publisher=self._publisher, + session=session, ) diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index d248cf4ff5..bf91637be4 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -1324,6 +1324,7 @@ def _start_query_with_job_optional( metrics=None, query_with_job=False, publisher=self._publisher, + session=self._session, ) return rows @@ -1350,6 +1351,7 @@ def _start_query_with_job( metrics=None, query_with_job=True, publisher=self._publisher, + session=self._session, ) return query_job