From 1eaa3e3b8215d447a80d30e1f5caccf57324673e Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 25 Nov 2025 23:17:07 +0000 Subject: [PATCH 1/2] refactor: fix date and time diff ops --- .../sqlglot/aggregations/unary_compiler.py | 66 ++++++------------- .../compile/sqlglot/aggregations/windows.py | 4 +- .../test_date_series_diff/out.sql | 13 ---- .../test_diff/diff_datetime.sql | 17 +++++ .../out.sql => test_diff/diff_timestamp.sql} | 4 +- .../aggregations/test_unary_compiler.py | 42 ++++++------ 6 files changed, 62 insertions(+), 84 deletions(-) delete mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_date_series_diff/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_datetime.sql rename tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/{test_time_series_diff/out.sql => test_diff/diff_timestamp.sql} (70%) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 171c3cc239..ec711c7fa1 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -245,27 +245,6 @@ def _cut_ops_w_intervals( return case_expr -@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp) -def _( - op: agg_ops.DateSeriesDiffOp, - column: typed_expr.TypedExpr, - window: typing.Optional[window_spec.WindowSpec] = None, -) -> sge.Expression: - if column.dtype != dtypes.DATE_DTYPE: - raise TypeError(f"Cannot perform date series diff on type {column.dtype}") - shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)] - shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window) - # Conversion factor from days to microseconds - conversion_factor = 24 * 60 * 60 * 1_000_000 - return sge.Cast( - this=sge.DateDiff( - this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY") - ) - * sge.convert(conversion_factor), - to="INT64", - ) - - @UNARY_OP_REGISTRATION.register(agg_ops.DenseRankOp) def _( op: agg_ops.DenseRankOp, @@ -327,13 +306,27 @@ def _( ) -> sge.Expression: shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)] shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window) - if column.dtype in (dtypes.BOOL_DTYPE, dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE): - if column.dtype == dtypes.BOOL_DTYPE: - return sge.NEQ(this=column.expr, expression=shifted) - else: - return sge.Sub(this=column.expr, expression=shifted) - else: - raise TypeError(f"Cannot perform diff on type {column.dtype}") + if column.dtype == dtypes.BOOL_DTYPE: + return sge.NEQ(this=column.expr, expression=shifted) + + if column.dtype in (dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE): + return sge.Sub(this=column.expr, expression=shifted) + + if column.dtype == dtypes.TIMESTAMP_DTYPE: + return sge.TimestampDiff( + this=column.expr, + expression=shifted, + unit=sge.Identifier(this="MICROSECOND"), + ) + + if column.dtype == dtypes.DATETIME_DTYPE: + return sge.DatetimeDiff( + this=column.expr, + expression=shifted, + unit=sge.Identifier(this="MICROSECOND"), + ) + + raise TypeError(f"Cannot perform diff on type {column.dtype}") @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) @@ -593,23 +586,6 @@ def _( return sge.func("IFNULL", expr, ir._literal(zero, column.dtype)) -@UNARY_OP_REGISTRATION.register(agg_ops.TimeSeriesDiffOp) -def _( - op: agg_ops.TimeSeriesDiffOp, - column: typed_expr.TypedExpr, - window: typing.Optional[window_spec.WindowSpec] = None, -) -> sge.Expression: - if column.dtype != dtypes.TIMESTAMP_DTYPE: - raise TypeError(f"Cannot perform time series diff on type {column.dtype}") - shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)] - shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window) - return sge.TimestampDiff( - this=column.expr, - expression=shifted, - unit=sge.Identifier(this="MICROSECOND"), - ) - - @UNARY_OP_REGISTRATION.register(agg_ops.VarOp) def _( op: agg_ops.VarOp, diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index b775d6666a..5ca66ee505 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -62,10 +62,10 @@ def apply_window_if_present( # This is the key change. Don't create a spec for the default window frame # if there's no ordering. This avoids generating an `ORDER BY NULL` clause. - if not window.bounds and not order: + if window.is_unbounded and not order: return sge.Window(this=value, partition_by=group_by) - if not window.bounds and not include_framing_clauses: + if window.is_unbounded and not include_framing_clauses: return sge.Window(this=value, partition_by=group_by, order=order) kind = ( diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_date_series_diff/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_date_series_diff/out.sql deleted file mode 100644 index 84c95fd010..0000000000 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_date_series_diff/out.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `date_col` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - CAST(DATE_DIFF(`date_col`, LAG(`date_col`, 1) OVER (ORDER BY `date_col` ASC NULLS LAST), DAY) * 86400000000 AS INT64) AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `diff_date` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_datetime.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_datetime.sql new file mode 100644 index 0000000000..9c279a479d --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_datetime.sql @@ -0,0 +1,17 @@ +WITH `bfcte_0` AS ( + SELECT + `datetime_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + DATETIME_DIFF( + `datetime_col`, + LAG(`datetime_col`, 1) OVER (ORDER BY `datetime_col` ASC NULLS LAST), + MICROSECOND + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `diff_datetime` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_time_series_diff/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_timestamp.sql similarity index 70% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_time_series_diff/out.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_timestamp.sql index 645f583dc1..1f8b8227b4 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_time_series_diff/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_timestamp.sql @@ -7,11 +7,11 @@ WITH `bfcte_0` AS ( *, TIMESTAMP_DIFF( `timestamp_col`, - LAG(`timestamp_col`, 1) OVER (ORDER BY `timestamp_col` ASC NULLS LAST), + LAG(`timestamp_col`, 1) OVER (ORDER BY `timestamp_col` DESC), MICROSECOND ) AS `bfcol_1` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `diff_time` + `bfcol_1` AS `diff_timestamp` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 5f7d0d7653..9c7f0caeb2 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -214,17 +214,6 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") -def test_date_series_diff(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "date_col" - bf_df = scalar_types_df[[col_name]] - window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) - op = agg_exprs.UnaryAggregation( - agg_ops.DateSeriesDiffOp(periods=1), expression.deref(col_name) - ) - sql = _apply_unary_window_op(bf_df, op, window, "diff_date") - snapshot.assert_match(sql, "out.sql") - - def test_diff(scalar_types_df: bpd.DataFrame, snapshot): # Test integer int_col = "int64_col" @@ -246,6 +235,26 @@ def test_diff(scalar_types_df: bpd.DataFrame, snapshot): bool_sql = _apply_unary_window_op(bf_df_bool, bool_op, window, "diff_bool") snapshot.assert_match(bool_sql, "diff_bool.sql") + # Test date + col_name = "datetime_col" + bf_df_date = scalar_types_df[[col_name]] + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + op = agg_exprs.UnaryAggregation( + agg_ops.DiffOp(periods=1), expression.deref(col_name) + ) + sql = _apply_unary_window_op(bf_df_date, op, window, "diff_datetime") + snapshot.assert_match(sql, "diff_datetime.sql") + + # Test date + col_name = "timestamp_col" + bf_df_timestamp = scalar_types_df[[col_name]] + window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),)) + op = agg_exprs.UnaryAggregation( + agg_ops.DiffOp(periods=1), expression.deref(col_name) + ) + sql = _apply_unary_window_op(bf_df_timestamp, op, window, "diff_timestamp") + snapshot.assert_match(sql, "diff_timestamp.sql") + def test_first(scalar_types_df: bpd.DataFrame, snapshot): if sys.version_info < (3, 12): @@ -606,17 +615,6 @@ def test_sum(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql_window_partition, "window_partition_out.sql") -def test_time_series_diff(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "timestamp_col" - bf_df = scalar_types_df[[col_name]] - window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) - op = agg_exprs.UnaryAggregation( - agg_ops.TimeSeriesDiffOp(periods=1), expression.deref(col_name) - ) - sql = _apply_unary_window_op(bf_df, op, window, "diff_time") - snapshot.assert_match(sql, "out.sql") - - def test_var(scalar_types_df: bpd.DataFrame, snapshot): col_names = ["int64_col", "bool_col"] bf_df = scalar_types_df[col_names] From 3916225a11c5b20ba5199f6561caf7a8479eac47 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 26 Nov 2025 19:07:24 +0000 Subject: [PATCH 2/2] address comments --- .../out.sql} | 0 .../out.sql} | 0 .../diff_int.sql => test_diff_w_int/out.sql} | 0 .../out.sql} | 0 .../aggregations/test_unary_compiler.py | 19 +++++++++++-------- 5 files changed, 11 insertions(+), 8 deletions(-) rename tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/{test_diff/diff_bool.sql => test_diff_w_bool/out.sql} (100%) rename tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/{test_diff/diff_datetime.sql => test_diff_w_datetime/out.sql} (100%) rename tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/{test_diff/diff_int.sql => test_diff_w_int/out.sql} (100%) rename tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/{test_diff/diff_timestamp.sql => test_diff_w_timestamp/out.sql} (100%) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_bool.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_bool/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_bool.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_bool/out.sql diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_datetime.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_datetime/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_datetime.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_datetime/out.sql diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_int.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_int/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_int.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_int/out.sql diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_timestamp.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_timestamp/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_timestamp.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_timestamp/out.sql diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 9c7f0caeb2..fbf631d1a0 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -214,7 +214,7 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") -def test_diff(scalar_types_df: bpd.DataFrame, snapshot): +def test_diff_w_int(scalar_types_df: bpd.DataFrame, snapshot): # Test integer int_col = "int64_col" bf_df_int = scalar_types_df[[int_col]] @@ -223,9 +223,10 @@ def test_diff(scalar_types_df: bpd.DataFrame, snapshot): agg_ops.DiffOp(periods=1), expression.deref(int_col) ) int_sql = _apply_unary_window_op(bf_df_int, int_op, window, "diff_int") - snapshot.assert_match(int_sql, "diff_int.sql") + snapshot.assert_match(int_sql, "out.sql") - # Test boolean + +def test_diff_w_bool(scalar_types_df: bpd.DataFrame, snapshot): bool_col = "bool_col" bf_df_bool = scalar_types_df[[bool_col]] window = window_spec.WindowSpec(ordering=(ordering.descending_over(bool_col),)) @@ -233,9 +234,10 @@ def test_diff(scalar_types_df: bpd.DataFrame, snapshot): agg_ops.DiffOp(periods=1), expression.deref(bool_col) ) bool_sql = _apply_unary_window_op(bf_df_bool, bool_op, window, "diff_bool") - snapshot.assert_match(bool_sql, "diff_bool.sql") + snapshot.assert_match(bool_sql, "out.sql") + - # Test date +def test_diff_w_datetime(scalar_types_df: bpd.DataFrame, snapshot): col_name = "datetime_col" bf_df_date = scalar_types_df[[col_name]] window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) @@ -243,9 +245,10 @@ def test_diff(scalar_types_df: bpd.DataFrame, snapshot): agg_ops.DiffOp(periods=1), expression.deref(col_name) ) sql = _apply_unary_window_op(bf_df_date, op, window, "diff_datetime") - snapshot.assert_match(sql, "diff_datetime.sql") + snapshot.assert_match(sql, "out.sql") - # Test date + +def test_diff_w_timestamp(scalar_types_df: bpd.DataFrame, snapshot): col_name = "timestamp_col" bf_df_timestamp = scalar_types_df[[col_name]] window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),)) @@ -253,7 +256,7 @@ def test_diff(scalar_types_df: bpd.DataFrame, snapshot): agg_ops.DiffOp(periods=1), expression.deref(col_name) ) sql = _apply_unary_window_op(bf_df_timestamp, op, window, "diff_timestamp") - snapshot.assert_match(sql, "diff_timestamp.sql") + snapshot.assert_match(sql, "out.sql") def test_first(scalar_types_df: bpd.DataFrame, snapshot):