diff --git a/bigframes/core/compile/sqlglot/expressions/array_ops.py b/bigframes/core/compile/sqlglot/expressions/array_ops.py index 2758178beb..f7b96d0418 100644 --- a/bigframes/core/compile/sqlglot/expressions/array_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/array_ops.py @@ -30,9 +30,12 @@ @register_unary_op(ops.ArrayIndexOp, pass_op=True) def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression: + if expr.dtype == dtypes.STRING_DTYPE: + return _string_index(expr, op) + return sge.Bracket( this=expr.expr, - expressions=[sge.Literal.number(op.index)], + expressions=[sge.convert(op.index)], safe=True, offset=False, ) @@ -115,3 +118,16 @@ def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression: if typed_expr.dtype == dtypes.BOOL_DTYPE: return sge.Cast(this=typed_expr.expr, to="INT64") return typed_expr.expr + + +def _string_index(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression: + sub_str = sge.Substring( + this=expr.expr, + start=sge.convert(op.index + 1), + length=sge.convert(1), + ) + return sge.If( + this=sge.NEQ(this=sub_str, expression=sge.convert("")), + true=sub_str, + false=sge.Null(), + ) diff --git a/bigframes/core/compile/sqlglot/expressions/string_ops.py b/bigframes/core/compile/sqlglot/expressions/string_ops.py index 3e19a2fe33..3f0578f843 100644 --- a/bigframes/core/compile/sqlglot/expressions/string_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/string_ops.py @@ -15,6 +15,7 @@ from __future__ import annotations import functools +import typing import sqlglot.expressions as sge @@ -29,7 +30,7 @@ @register_unary_op(ops.capitalize_op) def _(expr: TypedExpr) -> sge.Expression: - return sge.Initcap(this=expr.expr) + return sge.Initcap(this=expr.expr, expression=sge.convert("")) @register_unary_op(ops.StrContainsOp, pass_op=True) @@ -44,9 +45,17 @@ def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression: @register_unary_op(ops.StrExtractOp, pass_op=True) def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression: - return sge.RegexpExtract( - this=expr.expr, expression=sge.convert(op.pat), group=sge.convert(op.n) - ) + # Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one + # capturing group. + pat_expr = sge.convert(op.pat) + if op.n != 0: + pat_expr = sge.func("CONCAT", sge.convert(".*?"), pat_expr, sge.convert(".*")) + else: + pat_expr = sge.func("CONCAT", sge.convert(".*?("), pat_expr, sge.convert(").*")) + + rex_replace = sge.func("REGEXP_REPLACE", expr.expr, pat_expr, sge.convert(r"\1")) + rex_contains = sge.func("REGEXP_CONTAINS", expr.expr, sge.convert(op.pat)) + return sge.If(this=rex_contains, true=rex_replace, false=sge.null()) @register_unary_op(ops.StrFindOp, pass_op=True) @@ -75,47 +84,43 @@ def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression: @register_unary_op(ops.StrLstripOp, pass_op=True) def _(expr: TypedExpr, op: ops.StrLstripOp) -> sge.Expression: - return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT") + return sge.func("LTRIM", expr.expr, sge.convert(op.to_strip)) + + +@register_unary_op(ops.StrRstripOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression: + return sge.func("RTRIM", expr.expr, sge.convert(op.to_strip)) @register_unary_op(ops.StrPadOp, pass_op=True) def _(expr: TypedExpr, op: ops.StrPadOp) -> sge.Expression: - pad_length = sge.func( - "GREATEST", sge.Length(this=expr.expr), sge.convert(op.length) - ) + expr_length = sge.Length(this=expr.expr) + fillchar = sge.convert(op.fillchar) + pad_length = sge.func("GREATEST", expr_length, sge.convert(op.length)) + if op.side == "left": - return sge.func( - "LPAD", - expr.expr, - pad_length, - sge.convert(op.fillchar), - ) + return sge.func("LPAD", expr.expr, pad_length, fillchar) elif op.side == "right": - return sge.func( - "RPAD", - expr.expr, - pad_length, - sge.convert(op.fillchar), - ) + return sge.func("RPAD", expr.expr, pad_length, fillchar) else: # side == both - lpad_amount = sge.Cast( - this=sge.func( - "SAFE_DIVIDE", - sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)), - sge.convert(2), - ), - to="INT64", - ) + sge.Length(this=expr.expr) + lpad_amount = ( + sge.Cast( + this=sge.Floor( + this=sge.func( + "SAFE_DIVIDE", + sge.Sub(this=pad_length, expression=expr_length), + sge.convert(2), + ) + ), + to="INT64", + ) + + expr_length + ) return sge.func( "RPAD", - sge.func( - "LPAD", - expr.expr, - lpad_amount, - sge.convert(op.fillchar), - ), + sge.func("LPAD", expr.expr, lpad_amount, fillchar), pad_length, - sge.convert(op.fillchar), + fillchar, ) @@ -224,11 +229,6 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.func("REVERSE", expr.expr) -@register_unary_op(ops.StrRstripOp, pass_op=True) -def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression: - return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="RIGHT") - - @register_unary_op(ops.StartsWithOp, pass_op=True) def _(expr: TypedExpr, op: ops.StartsWithOp) -> sge.Expression: if not op.pat: @@ -253,26 +253,78 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression: @register_unary_op(ops.StrGetOp, pass_op=True) def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression: - return sge.Substring( + sub_str = sge.Substring( this=expr.expr, start=sge.convert(op.i + 1), length=sge.convert(1), ) + return sge.If( + this=sge.NEQ(this=sub_str, expression=sge.convert("")), + true=sub_str, + false=sge.Null(), + ) + @register_unary_op(ops.StrSliceOp, pass_op=True) def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression: - start = op.start + 1 if op.start is not None else None - if op.end is None: - length = None - elif op.start is None: - length = op.end + column_length = sge.Length(this=expr.expr) + if op.start is None: + start = 0 else: - length = op.end - op.start + start = op.start + + start_expr = sge.convert(start) if start < 0 else sge.convert(start + 1) + length_expr: typing.Optional[sge.Expression] + if op.end is None: + length_expr = None + elif op.end < 0: + if start < 0: + start_expr = sge.Greatest( + expressions=[ + sge.convert(1), + column_length + sge.convert(start + 1), + ] + ) + length_expr = sge.Greatest( + expressions=[ + sge.convert(0), + column_length + sge.convert(op.end), + ] + ) - sge.Greatest( + expressions=[ + sge.convert(0), + column_length + sge.convert(start), + ] + ) + else: + length_expr = sge.Greatest( + expressions=[ + sge.convert(0), + column_length + sge.convert(op.end - start), + ] + ) + else: # op.end >= 0 + if start < 0: + start_expr = sge.Greatest( + expressions=[ + sge.convert(1), + column_length + sge.convert(start + 1), + ] + ) + length_expr = sge.convert(op.end) - sge.Greatest( + expressions=[ + sge.convert(0), + column_length + sge.convert(start), + ] + ) + else: + length_expr = sge.convert(op.end - start) + return sge.Substring( this=expr.expr, - start=sge.convert(start) if start is not None else None, - length=sge.convert(length) if length is not None else None, + start=start_expr, + length=length_expr, ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_capitalize/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_capitalize/out.sql index b429007ffc..dd1f1473f4 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_capitalize/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_capitalize/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - INITCAP(`string_col`) AS `bfcol_1` + INITCAP(`string_col`, '') AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_lstrip/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_lstrip/out.sql index ebe4c39bbf..1b73ee3258 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_lstrip/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_lstrip/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - TRIM(`string_col`, ' ') AS `bfcol_1` + LTRIM(`string_col`, ' ') AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_rstrip/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_rstrip/out.sql index ebe4c39bbf..72bdbba29f 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_rstrip/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_rstrip/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - TRIM(`string_col`, ' ') AS `bfcol_1` + RTRIM(`string_col`, ' ') AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_extract/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_extract/out.sql index 3e59f617ac..ad02f6b223 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_extract/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_extract/out.sql @@ -5,7 +5,11 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - REGEXP_EXTRACT(`string_col`, '([a-z]*)') AS `bfcol_1` + IF( + REGEXP_CONTAINS(`string_col`, '([a-z]*)'), + REGEXP_REPLACE(`string_col`, CONCAT('.*?', '([a-z]*)', '.*'), '\\1'), + NULL + ) AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_get/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_get/out.sql index b2a08e0e9d..f868b73032 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_get/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_get/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - SUBSTRING(`string_col`, 2, 1) AS `bfcol_1` + IF(SUBSTRING(`string_col`, 2, 1) <> '', SUBSTRING(`string_col`, 2, 1), NULL) AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_pad/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_pad/out.sql index 5f157bc5cb..2bb6042fe9 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_pad/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_pad/out.sql @@ -10,7 +10,7 @@ WITH `bfcte_0` AS ( RPAD( LPAD( `string_col`, - CAST(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2) AS INT64) + LENGTH(`string_col`), + CAST(FLOOR(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2)) AS INT64) + LENGTH(`string_col`), '-' ), GREATEST(LENGTH(`string_col`), 10),