Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions bigframes/core/log_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def wrapper(*args, **kwargs):
full_method_name = f"{base_name.lower()}-{api_method_name}"
# Track directly called methods
if len(_call_stack) == 0:
add_api_method(full_method_name)
session = _find_session(*args, **kwargs)
add_api_method(full_method_name, session=session)

_call_stack.append(full_method_name)

Expand Down Expand Up @@ -220,7 +221,8 @@ def wrapped(*args, **kwargs):
full_property_name = f"{class_name.lower()}-{property_name.lower()}"

if len(_call_stack) == 0:
add_api_method(full_property_name)
session = _find_session(*args, **kwargs)
add_api_method(full_property_name, session=session)

_call_stack.append(full_property_name)
try:
Expand Down Expand Up @@ -250,25 +252,41 @@ def wrapper(func):
return wrapper


def add_api_method(api_method_name, session=None):
    """Record an API method name for job-label telemetry.

    Args:
        api_method_name:
            Method label such as "dataframe-head". Angle brackets are
            stripped before recording.
        session:
            If provided and fully initialized, the name is recorded on the
            session's own list under the session's lock. Otherwise (e.g. the
            logger fired before ``Session.__init__`` ran) the module-level
            globals are used as a fallback.
    """
    global _lock
    global _api_methods

    clean_method_name = api_method_name.replace("<", "").replace(">", "")

    if session is not None and _is_session_initialized(session):
        with session._api_methods_lock:
            # Push the method to the front of the session's list.
            session._api_methods.insert(0, clean_method_name)
            # Keep the list length within the maximum limit.
            session._api_methods = session._api_methods[:MAX_LABELS_COUNT]
    else:
        with _lock:
            # Push the method to the front of the _api_methods list
            _api_methods.insert(0, clean_method_name)
            # Keep the list length within the maximum limit (adjust MAX_LABELS_COUNT as needed)
            _api_methods = _api_methods[:MAX_LABELS_COUNT]

def get_and_reset_api_methods(dry_run: bool = False, session=None):
    """Return the recorded API method names, clearing the logs on real runs.

    Combines the session-scoped log (when ``session`` is provided and fully
    initialized) with the module-level fallback log, most recent first within
    each source.

    Args:
        dry_run:
            dry_run might not make a job resource, so only reset the logs on
            real queries (``dry_run=False``).
        session:
            Optional session whose per-session method log should be included.

    Returns:
        list[str]: the accumulated method names.
    """
    global _lock
    methods = []

    if session is not None and _is_session_initialized(session):
        with session._api_methods_lock:
            methods.extend(session._api_methods)
            if not dry_run:
                session._api_methods.clear()

    with _lock:
        methods.extend(_api_methods)

        # dry_run might not make a job resource, so only reset the log on real queries.
        if not dry_run:
            _api_methods.clear()
    return methods


def _get_bq_client(*args, **kwargs):
Expand All @@ -283,3 +301,36 @@ def _get_bq_client(*args, **kwargs):
return kwargv._block.session.bqclient

return None


def _is_session_initialized(session):
"""Return True if fully initialized.

Because the method logger could get called before Session.__init__ has a
chance to run, we use the globals in that case.
"""
return hasattr(session, "_api_methods_lock") and hasattr(session, "_api_methods")


def _find_session(*args, **kwargs):
    """Locate a fully-initialized Session among the call's arguments.

    Checks the first positional argument (usually ``self``) and the
    ``session`` keyword argument, in that order. Returns ``None`` if
    neither is an initialized ``Session``.
    """
    # Deferred import: Session imports log_adapter, so importing it at the
    # top level of this module would create a circular import.
    from bigframes.session import Session

    candidates = (args[0] if args else None, kwargs.get("session"))
    for candidate in candidates:
        # isinstance(None, Session) is False, so None candidates fall through.
        if isinstance(candidate, Session) and _is_session_initialized(candidate):
            return candidate

    return None
6 changes: 6 additions & 0 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
import os
import secrets
import threading
import typing
from typing import (
Any,
Expand Down Expand Up @@ -208,6 +209,9 @@ def __init__(
self._session_id: str = "session" + secrets.token_hex(3)
# store table ids and delete them when the session is closed

self._api_methods: list[str] = []
self._api_methods_lock = threading.Lock()

self._objects: list[
weakref.ReferenceType[
Union[
Expand Down Expand Up @@ -2160,6 +2164,7 @@ def _start_query_ml_ddl(
query_with_job=True,
job_retry=third_party_gcb_retry.DEFAULT_ML_JOB_RETRY,
publisher=self._publisher,
session=self,
)
return iterator, query_job

Expand Down Expand Up @@ -2188,6 +2193,7 @@ def _create_object_table(self, path: str, connection: str) -> str:
timeout=None,
query_with_job=True,
publisher=self._publisher,
session=self,
)

return table
Expand Down
15 changes: 12 additions & 3 deletions bigframes/session/_io/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def create_temp_table(
schema: Optional[Iterable[bigquery.SchemaField]] = None,
cluster_columns: Optional[list[str]] = None,
kms_key: Optional[str] = None,
session=None,
) -> str:
"""Create an empty table with an expiration in the desired session.

Expand Down Expand Up @@ -153,6 +154,7 @@ def create_temp_view(
*,
expiration: datetime.datetime,
sql: str,
session=None,
) -> str:
    """Create an empty view with an expiration in the desired session.

Expand Down Expand Up @@ -228,12 +230,14 @@ def format_option(key: str, value: Union[bool, str]) -> str:
return f"{key}={repr(value)}"


def add_and_trim_labels(job_config):
def add_and_trim_labels(job_config, session=None):
"""
Add additional labels to the job configuration and trim the total number of labels
to ensure they do not exceed MAX_LABELS_COUNT labels per job.
"""
api_methods = log_adapter.get_and_reset_api_methods(dry_run=job_config.dry_run)
api_methods = log_adapter.get_and_reset_api_methods(
dry_run=job_config.dry_run, session=session
)
job_config.labels = create_job_configs_labels(
job_configs_labels=job_config.labels,
api_methods=api_methods,
Expand Down Expand Up @@ -270,6 +274,7 @@ def start_query_with_client(
metrics: Optional[bigframes.session.metrics.ExecutionMetrics],
query_with_job: Literal[True],
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
...

Expand All @@ -286,6 +291,7 @@ def start_query_with_client(
metrics: Optional[bigframes.session.metrics.ExecutionMetrics],
query_with_job: Literal[False],
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]:
...

Expand All @@ -303,6 +309,7 @@ def start_query_with_client(
query_with_job: Literal[True],
job_retry: google.api_core.retry.Retry,
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
...

Expand All @@ -320,6 +327,7 @@ def start_query_with_client(
query_with_job: Literal[False],
job_retry: google.api_core.retry.Retry,
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]:
...

Expand All @@ -340,14 +348,15 @@ def start_query_with_client(
# version 3.36.0 or later.
job_retry: google.api_core.retry.Retry = third_party_gcb_retry.DEFAULT_JOB_RETRY,
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]:
"""
Starts query job and waits for results.
"""
# Note: Ensure no additional labels are added to job_config after this
# point, as `add_and_trim_labels` ensures the label count does not
# exceed MAX_LABELS_COUNT.
add_and_trim_labels(job_config)
add_and_trim_labels(job_config, session=session)

try:
if not query_with_job:
Expand Down
5 changes: 5 additions & 0 deletions bigframes/session/bq_caching_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def _export_gbq(
iterator, job = self._run_execute_query(
sql=sql,
job_config=job_config,
session=array_value.session,
)

has_timedelta_col = any(
Expand Down Expand Up @@ -389,6 +390,7 @@ def _run_execute_query(
sql: str,
job_config: Optional[bq_job.QueryJobConfig] = None,
query_with_job: bool = True,
session=None,
) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]:
"""
Starts BigQuery query job and waits for results.
Expand All @@ -415,6 +417,7 @@ def _run_execute_query(
timeout=None,
query_with_job=True,
publisher=self._publisher,
session=session,
)
else:
return bq_io.start_query_with_client(
Expand All @@ -427,6 +430,7 @@ def _run_execute_query(
timeout=None,
query_with_job=False,
publisher=self._publisher,
session=session,
)

except google.api_core.exceptions.BadRequest as e:
Expand Down Expand Up @@ -661,6 +665,7 @@ def _execute_plan_gbq(
sql=compiled.sql,
job_config=job_config,
query_with_job=(destination_table is not None),
session=plan.session,
)

# we could actually cache even when caching is not explicitly requested, but being conservative for now
Expand Down
3 changes: 3 additions & 0 deletions bigframes/session/direct_gbq_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def execute(

iterator, query_job = self._run_execute_query(
sql=compiled.sql,
session=plan.session,
)

# just immediately download everything for simplicity
Expand All @@ -75,6 +76,7 @@ def _run_execute_query(
self,
sql: str,
job_config: Optional[bq_job.QueryJobConfig] = None,
session=None,
) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]:
"""
Starts BigQuery query job and waits for results.
Expand All @@ -89,4 +91,5 @@ def _run_execute_query(
metrics=None,
query_with_job=False,
publisher=self._publisher,
session=session,
)
2 changes: 2 additions & 0 deletions bigframes/session/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,7 @@ def _start_query_with_job_optional(
metrics=None,
query_with_job=False,
publisher=self._publisher,
session=self._session,
)
return rows

Expand All @@ -1350,6 +1351,7 @@ def _start_query_with_job(
metrics=None,
query_with_job=True,
publisher=self._publisher,
session=self._session,
)
return query_job

Expand Down