def _compile_group_by_key(key: ex.Expression) -> sge.Expression:
    """Compile one window grouping key into a partition expression.

    The key is compiled with the scalar op compiler. Keys whose dtype needs
    converting to enable grouping are then wrapped: FLOAT keys are cast to
    STRING, GEO keys are cast to BYTES, and JSON keys are passed through
    ``TO_JSON_STRING``. Any other dtype is returned as compiled.

    Args:
        key: The grouping key expression. It is expected to already be a
            ``ResolvedDerefOp`` (the keys are rewritten by
            ``bind_schema_to_node`` before reaching this point).

    Returns:
        The sqlglot expression to use for this key.

    Raises:
        AssertionError: If ``key`` is not a ``ResolvedDerefOp``.
    """
    compiled = scalar_compiler.scalar_op_compiler.compile_expression(key)
    # Invariant check; also narrows the type so that ``dtype`` is available.
    assert isinstance(key, ex.ResolvedDerefOp)

    key_dtype = key.dtype
    if key_dtype == dtypes.FLOAT_DTYPE:
        return sge.Cast(this=compiled, to="STRING")
    if key_dtype == dtypes.GEO_DTYPE:
        return sge.Cast(this=compiled, to="BYTES")
    if key_dtype == dtypes.JSON_DTYPE:
        return sge.func("TO_JSON_STRING", compiled)
    return compiled
def test_apply_window_if_present_grouping_no_ordering(self):
    """Check the PARTITION BY rendering for grouping keys with no ordering.

    Uses one key each of STRING, FLOAT, JSON and GEO dtype and compares the
    BigQuery SQL produced against the expected window clause.
    """
    target = sge.Var(this="value")
    typed_columns = (
        ("col1", dtypes.STRING_DTYPE),
        ("col2", dtypes.FLOAT_DTYPE),
        ("col3", dtypes.JSON_DTYPE),
        ("col4", dtypes.GEO_DTYPE),
    )
    partition_keys = tuple(
        ex.ResolvedDerefOp(
            ids.ColumnId(column_name),
            dtype=column_dtype,
            is_nullable=True,
        )
        for column_name, column_dtype in typed_columns
    )
    spec = window_spec.WindowSpec(grouping_keys=partition_keys)

    windowed = apply_window_if_present(target, spec)

    expected_sql = "value OVER (PARTITION BY `col1`, CAST(`col2` AS STRING), TO_JSON_STRING(`col3`), CAST(`col4` AS BYTES))"
    self.assertEqual(windowed.sql(dialect="bigquery"), expected_sql)