From caf84f6b4264b9548dd35ba20f33512b17e132ce Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 17 Dec 2025 13:32:57 +0000 Subject: [PATCH 1/2] refactor: Handle special float values and None consistently in sqlglot _literal --- bigframes/core/compile/sqlglot/sqlglot_ir.py | 45 ++++++++++--------- .../out.sql | 23 ++++++++++ .../compile/sqlglot/test_compile_readlocal.py | 18 ++++++++ 3 files changed, 66 insertions(+), 20 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 0d568b098b..fbcdbfb9d0 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -21,6 +21,7 @@ from google.cloud import bigquery import numpy as np +import pandas as pd import pyarrow as pa import sqlglot as sg import sqlglot.dialects.bigquery @@ -28,7 +29,7 @@ from bigframes import dtypes from bigframes.core import guid, local_data, schema, utils -from bigframes.core.compile.sqlglot.expressions import typed_expr +from bigframes.core.compile.sqlglot.expressions import constants, typed_expr import bigframes.core.compile.sqlglot.sqlglot_types as sgt # shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0. @@ -639,11 +640,27 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None if sqlglot_type is None: - if value is not None: - raise ValueError("Cannot infer SQLGlot type from None dtype.") + if not pd.isna(value): + raise ValueError(f"Cannot infer SQLGlot type from None dtype: {value}") return sge.Null() - if value is None: + if dtypes.is_struct_like(dtype): + items = [ + _literal(value=value[field_name], dtype=field_dtype).as_( + field_name, quoted=True + ) + for field_name, field_dtype in dtypes.get_struct_fields(dtype).items() + ] + return sge.Struct.from_arg_list(items) + elif dtypes.is_array_like(dtype): + value_type = dtypes.get_array_inner_type(dtype) + values = sge.Array( + expressions=[_literal(value=v, dtype=value_type) for v in value] + ) + return values if len(value) > 0 else _cast(values, sqlglot_type) + elif dtype == dtypes.JSON_DTYPE: + return sge.ParseJSON(this=sge.convert(str(value))) + elif pd.isna(value): return _cast(sge.Null(), sqlglot_type) elif dtype == dtypes.BYTES_DTYPE: return _cast(str(value), sqlglot_type) @@ -658,24 +675,12 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: elif dtypes.is_geo_like(dtype): wkt = value if isinstance(value, str) else to_wkt(value) return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt)) - elif dtype == dtypes.JSON_DTYPE: - return sge.ParseJSON(this=sge.convert(str(value))) elif dtype == dtypes.TIMEDELTA_DTYPE: return sge.convert(utils.timedelta_to_micros(value)) - elif dtypes.is_struct_like(dtype): - items = [ - _literal(value=value[field_name], dtype=field_dtype).as_( - field_name, quoted=True - ) - for field_name, field_dtype in dtypes.get_struct_fields(dtype).items() - ] - return sge.Struct.from_arg_list(items) - elif dtypes.is_array_like(dtype): - value_type = dtypes.get_array_inner_type(dtype) - values = sge.Array( - expressions=[_literal(value=v, dtype=value_type) for v in value] - ) - return values if len(value) > 0 else _cast(values, sqlglot_type) + elif dtype == dtypes.FLOAT_DTYPE: + if np.isinf(value): + return constants._INF if value > 0 else constants._NEG_INF + return sge.convert(value) else: if isinstance(value, np.generic): value = value.item() diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql new file mode 100644 index 0000000000..d4aac40af8 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql @@ -0,0 +1,23 @@ +WITH `bfcte_0` AS ( + SELECT + * + FROM UNNEST(ARRAY>[STRUCT( + CAST(NULL AS FLOAT64), + CAST('Infinity' AS FLOAT64), + CAST('-Infinity' AS FLOAT64), + CAST(NULL AS FLOAT64), + CAST(NULL AS FLOAT64), + TRUE, + 0 + ), STRUCT(1.0, 1.0, 1.0, 1.0, 10.0, CAST(NULL AS BOOLEAN), 1), STRUCT(2.0, 2.0, 2.0, 2.0, 20.0, FALSE, 2)]) +) +SELECT + `bfcol_0` AS `col_none`, + `bfcol_1` AS `col_inf`, + `bfcol_2` AS `col_neginf`, + `bfcol_3` AS `col_nan`, + `bfcol_4` AS `col_int_none`, + `bfcol_5` AS `col_bool_none` +FROM `bfcte_0` +ORDER BY + `bfcol_6` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py index 7307fd9b4e..6ff0c1f790 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py +++ b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pandas as pd import pytest @@ -58,3 +59,20 @@ def test_compile_readlocal_w_json_df( ): bf_df = bpd.DataFrame(json_pandas_df, session=compiler_session_w_json_types) snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_compile_readlocal_w_special_values( + compiler_session: bigframes.Session, snapshot +): + df = pd.DataFrame( + { + "col_none": [None, 1, 2], + "col_inf": [np.inf, 1.0, 2.0], + "col_neginf": [-np.inf, 1.0, 2.0], + "col_nan": [np.nan, 1.0, 2.0], + "col_int_none": [None, 10, 20], + "col_bool_none": [True, None, False], + } + ) + bf_df = bpd.DataFrame(df, session=compiler_session) + snapshot.assert_match(bf_df.sql, "out.sql") From 60d3483e40e832ac602e932cb0fdf04b4453795e Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 18 Dec 2025 02:49:27 +0000 Subject: [PATCH 2/2] fix test_sql_scalar_for_array_series --- bigframes/core/compile/sqlglot/sqlglot_ir.py | 6 ++++-- .../out.sql | 16 +++++++++------- .../compile/sqlglot/test_compile_readlocal.py | 9 +++++++-- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index fbcdbfb9d0..cbc601ea63 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -644,6 +644,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: raise ValueError(f"Cannot infer SQLGlot type from None dtype: {value}") return sge.Null() + if value is None: + return _cast(sge.Null(), sqlglot_type) if dtypes.is_struct_like(dtype): items = [ _literal(value=value[field_name], dtype=field_dtype).as_( @@ -658,10 +660,10 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: expressions=[_literal(value=v, dtype=value_type) for v in value] ) return values if len(value) > 0 else _cast(values, sqlglot_type) - elif dtype == dtypes.JSON_DTYPE: - return sge.ParseJSON(this=sge.convert(str(value))) elif pd.isna(value): return _cast(sge.Null(), sqlglot_type) + elif dtype == dtypes.JSON_DTYPE: + return sge.ParseJSON(this=sge.convert(str(value))) elif dtype == dtypes.BYTES_DTYPE: return _cast(str(value), sqlglot_type) elif dtypes.is_time_like(dtype): diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql index d4aac40af8..ba5e0c8f1c 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql @@ -1,23 +1,25 @@ WITH `bfcte_0` AS ( SELECT * - FROM UNNEST(ARRAY>[STRUCT( + FROM UNNEST(ARRAY, `bfcol_5` STRUCT, `bfcol_6` ARRAY, `bfcol_7` INT64>>[STRUCT( CAST(NULL AS FLOAT64), CAST('Infinity' AS FLOAT64), CAST('-Infinity' AS FLOAT64), CAST(NULL AS FLOAT64), - CAST(NULL AS FLOAT64), - TRUE, + CAST(NULL AS STRUCT), + STRUCT(CAST(NULL AS INT64) AS `foo`), + ARRAY[], 0 - ), STRUCT(1.0, 1.0, 1.0, 1.0, 10.0, CAST(NULL AS BOOLEAN), 1), STRUCT(2.0, 2.0, 2.0, 2.0, 20.0, FALSE, 2)]) + ), STRUCT(1.0, 1.0, 1.0, 1.0, STRUCT(1 AS `foo`), STRUCT(1 AS `foo`), [1, 2], 1), STRUCT(2.0, 2.0, 2.0, 2.0, STRUCT(2 AS `foo`), STRUCT(2 AS `foo`), [3, 4], 2)]) ) SELECT `bfcol_0` AS `col_none`, `bfcol_1` AS `col_inf`, `bfcol_2` AS `col_neginf`, `bfcol_3` AS `col_nan`, - `bfcol_4` AS `col_int_none`, - `bfcol_5` AS `col_bool_none` + `bfcol_4` AS `col_struct_none`, + `bfcol_5` AS `col_struct_w_none`, + `bfcol_6` AS `col_list_none` FROM `bfcte_0` ORDER BY - `bfcol_6` ASC NULLS LAST \ No newline at end of file + `bfcol_7` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py index 6ff0c1f790..c5fabd99e6 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py +++ b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys + import numpy as np import pandas as pd import pytest @@ -64,14 +66,17 @@ def test_compile_readlocal_w_json_df( def test_compile_readlocal_w_special_values( compiler_session: bigframes.Session, snapshot ): + if sys.version_info < (3, 12): + pytest.skip("Skipping test due to inconsistent SQL formatting") df = pd.DataFrame( { "col_none": [None, 1, 2], "col_inf": [np.inf, 1.0, 2.0], "col_neginf": [-np.inf, 1.0, 2.0], "col_nan": [np.nan, 1.0, 2.0], - "col_int_none": [None, 10, 20], - "col_bool_none": [True, None, False], + "col_struct_none": [None, {"foo": 1}, {"foo": 2}], + "col_struct_w_none": [{"foo": None}, {"foo": 1}, {"foo": 2}], + "col_list_none": [None, [1, 2], [3, 4]], } ) bf_df = bpd.DataFrame(df, session=compiler_session)