diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 0d568b098b..cbc601ea63 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -21,6 +21,7 @@ from google.cloud import bigquery import numpy as np +import pandas as pd import pyarrow as pa import sqlglot as sg import sqlglot.dialects.bigquery @@ -28,7 +29,7 @@ from bigframes import dtypes from bigframes.core import guid, local_data, schema, utils -from bigframes.core.compile.sqlglot.expressions import typed_expr +from bigframes.core.compile.sqlglot.expressions import constants, typed_expr import bigframes.core.compile.sqlglot.sqlglot_types as sgt # shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0. @@ -639,12 +640,30 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None if sqlglot_type is None: - if value is not None: - raise ValueError("Cannot infer SQLGlot type from None dtype.") + if not pd.isna(value): + raise ValueError(f"Cannot infer SQLGlot type from None dtype: {value}") return sge.Null() if value is None: return _cast(sge.Null(), sqlglot_type) + if dtypes.is_struct_like(dtype): + items = [ + _literal(value=value[field_name], dtype=field_dtype).as_( + field_name, quoted=True + ) + for field_name, field_dtype in dtypes.get_struct_fields(dtype).items() + ] + return sge.Struct.from_arg_list(items) + elif dtypes.is_array_like(dtype): + value_type = dtypes.get_array_inner_type(dtype) + values = sge.Array( + expressions=[_literal(value=v, dtype=value_type) for v in value] + ) + return values if len(value) > 0 else _cast(values, sqlglot_type) + elif pd.isna(value): + return _cast(sge.Null(), sqlglot_type) + elif dtype == dtypes.JSON_DTYPE: + return sge.ParseJSON(this=sge.convert(str(value))) elif dtype == dtypes.BYTES_DTYPE: return _cast(str(value), sqlglot_type) elif dtypes.is_time_like(dtype): @@ -658,24 +677,12 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: elif dtypes.is_geo_like(dtype): wkt = value if isinstance(value, str) else to_wkt(value) return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt)) - elif dtype == dtypes.JSON_DTYPE: - return sge.ParseJSON(this=sge.convert(str(value))) elif dtype == dtypes.TIMEDELTA_DTYPE: return sge.convert(utils.timedelta_to_micros(value)) - elif dtypes.is_struct_like(dtype): - items = [ - _literal(value=value[field_name], dtype=field_dtype).as_( - field_name, quoted=True - ) - for field_name, field_dtype in dtypes.get_struct_fields(dtype).items() - ] - return sge.Struct.from_arg_list(items) - elif dtypes.is_array_like(dtype): - value_type = dtypes.get_array_inner_type(dtype) - values = sge.Array( - expressions=[_literal(value=v, dtype=value_type) for v in value] - ) - return values if len(value) > 0 else _cast(values, sqlglot_type) + elif dtype == dtypes.FLOAT_DTYPE: + if np.isinf(value): + return constants._INF if value > 0 else constants._NEG_INF + return sge.convert(value) else: if isinstance(value, np.generic): value = value.item() diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql new file mode 100644 index 0000000000..ba5e0c8f1c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_special_values/out.sql @@ -0,0 +1,25 @@ +WITH `bfcte_0` AS ( + SELECT + * + FROM UNNEST(ARRAY, `bfcol_5` STRUCT, `bfcol_6` ARRAY, `bfcol_7` INT64>>[STRUCT( + CAST(NULL AS FLOAT64), + CAST('Infinity' AS FLOAT64), + CAST('-Infinity' AS FLOAT64), + CAST(NULL AS FLOAT64), + CAST(NULL AS STRUCT), + STRUCT(CAST(NULL AS INT64) AS `foo`), + ARRAY[], + 0 + ), STRUCT(1.0, 1.0, 1.0, 1.0, STRUCT(1 AS `foo`), STRUCT(1 AS `foo`), [1, 2], 1), STRUCT(2.0, 2.0, 2.0, 2.0, STRUCT(2 AS `foo`), STRUCT(2 AS `foo`), [3, 4], 2)]) +) +SELECT + `bfcol_0` AS `col_none`, + `bfcol_1` AS `col_inf`, + `bfcol_2` AS `col_neginf`, + `bfcol_3` AS `col_nan`, + `bfcol_4` AS `col_struct_none`, + `bfcol_5` AS `col_struct_w_none`, + `bfcol_6` AS `col_list_none` +FROM `bfcte_0` +ORDER BY + `bfcol_7` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py index 7307fd9b4e..c5fabd99e6 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py +++ b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys + +import numpy as np import pandas as pd import pytest @@ -58,3 +61,23 @@ def test_compile_readlocal_w_json_df( ): bf_df = bpd.DataFrame(json_pandas_df, session=compiler_session_w_json_types) snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_compile_readlocal_w_special_values( + compiler_session: bigframes.Session, snapshot +): + if sys.version_info < (3, 12): + pytest.skip("Skipping test due to inconsistent SQL formatting") + df = pd.DataFrame( + { + "col_none": [None, 1, 2], + "col_inf": [np.inf, 1.0, 2.0], + "col_neginf": [-np.inf, 1.0, 2.0], + "col_nan": [np.nan, 1.0, 2.0], + "col_struct_none": [None, {"foo": 1}, {"foo": 2}], + "col_struct_w_none": [{"foo": None}, {"foo": 1}, {"foo": 2}], + "col_list_none": [None, [1, 2], [3, 4]], + } + ) + bf_df = bpd.DataFrame(df, session=compiler_session) + snapshot.assert_match(bf_df.sql, "out.sql")