diff --git a/select.go b/select.go index 872be1c..a6e00ef 100644 --- a/select.go +++ b/select.go @@ -85,6 +85,14 @@ func (s *SelectStatement) Where(where AsExpr) *SelectStatement { } func (s *SelectStatement) AndWhere(where AsExpr) *SelectStatement { + return andOrWhere(where, s, "AND") +} + +func (s *SelectStatement) OrWhere(where AsExpr) *SelectStatement { + return andOrWhere(where, s, "OR") +} + +func andOrWhere(where AsExpr, s *SelectStatement, operator string) *SelectStatement { c := s.clone() if c.where == nil { @@ -92,10 +100,10 @@ func (s *SelectStatement) AndWhere(where AsExpr) *SelectStatement { } else { expr := c.where - if bexpr, ok := expr.(*BooleanOperatorExpr); ok && bexpr.operator == "AND" { - c.where = BooleanOperator("AND", append(bexpr.elements[:], where)...) + if bexpr, ok := expr.(*BooleanOperatorExpr); ok && bexpr.operator == operator { + c.where = BooleanOperator(operator, append(bexpr.elements[:], where)...) } else { - c.where = BooleanOperator("AND", expr, where) + c.where = BooleanOperator(operator, expr, where) } } diff --git a/sqlbuilder_test.go b/sqlbuilder_test.go index 2f6530b..d6b8613 100644 --- a/sqlbuilder_test.go +++ b/sqlbuilder_test.go @@ -78,6 +78,10 @@ func TestSelect(t *testing.T) { AliasColumn(InfixOperator("+", Literal("1"), Literal("2"), Literal("3")), "Five"), ).Where( In(tbl.C("PartNo"), s.BindAllAsExpr(1000, 1001, 1002)...), + ).AndWhere( + In(tbl.C("Grade"), s.BindAllAsExpr("A", "B")...), + ).OrWhere( + In(tbl.C("Finish"), s.BindAllAsExpr("HeapsGood", "TopNotch")...), ).OrderBy( OrderAsc(tbl.C("Type")), OrderAsc(tbl.C("Product")), @@ -94,8 +98,8 @@ func TestSelect(t *testing.T) { qs, qv, err := s.F(q.AsStatement).ToSQL() a.NoError(err) - a.Equal("SELECT p.PartNo, p.Type, p.Product, p.Grade, p.Coating, p.Finish, p.Thickness, p.Width, p.Length, p.Dim1, p.Dim2, p.ClassFBR, p.ClassFME, p.ClassFSY, @p3 AS ClassFHO, p.SLOB, dbo.productAvailablePlusALTOAmountOnHand(p.PartNo, @p1) AS OnHandAmount, dbo.productAvailableWeightOnHand(p.PartNo, @p1) AS OnHandWeight, dbo.productReservedAmount(p.PartNo, @p1) AS ReservedAmount, dbo.productALTOAmount(p.PartNo, @p1) AS ALTO, dbo.getPartAvgCost(p.PartNo) AS AverageCost, dbo.productOnOrderAmount(p.PartNo, @p1) AS OnOrderAmount, dbo.productOnOrderWeight(p.PartNo, @p1) AS OnOrderWeight, dbo.getListPriceGivenAvgCost(@p2, p.PartNo, dbo.getPartAvgCost(p.PartNo)) AS MinimumPrice, dbo.customerLastPrice(p.PartNo, @p2) AS CustomerLastPrice, CONVERT(VARCHAR(23), dbo.customerLastSoldDate(p.PartNo, @p2), 126) AS CustomerLastSoldDate, (1 + 2 + 3) AS Five FROM tblproducts p WHERE p.PartNo IN (@p4, @p5, @p6) ORDER BY p.Type ASC, p.Product ASC, p.Grade ASC, p.Coating ASC, p.Finish ASC, p.Thickness ASC, p.Width ASC, p.Dim1 ASC, p.Dim2 ASC, p.Length ASC OFFSET @p7 ROWS FETCH NEXT @p8 ROWS ONLY", qs) - a.Equal([]interface{}{"REGION_1", "CUSTOMER_1", "D", 1000, 1001, 1002, 30, 30}, qv) + a.Equal("SELECT p.PartNo, p.Type, p.Product, p.Grade, p.Coating, p.Finish, p.Thickness, p.Width, p.Length, p.Dim1, p.Dim2, p.ClassFBR, p.ClassFME, p.ClassFSY, @p3 AS ClassFHO, p.SLOB, dbo.productAvailablePlusALTOAmountOnHand(p.PartNo, @p1) AS OnHandAmount, dbo.productAvailableWeightOnHand(p.PartNo, @p1) AS OnHandWeight, dbo.productReservedAmount(p.PartNo, @p1) AS ReservedAmount, dbo.productALTOAmount(p.PartNo, @p1) AS ALTO, dbo.getPartAvgCost(p.PartNo) AS AverageCost, dbo.productOnOrderAmount(p.PartNo, @p1) AS OnOrderAmount, dbo.productOnOrderWeight(p.PartNo, @p1) AS OnOrderWeight, dbo.getListPriceGivenAvgCost(@p2, p.PartNo, dbo.getPartAvgCost(p.PartNo)) AS MinimumPrice, dbo.customerLastPrice(p.PartNo, @p2) AS CustomerLastPrice, CONVERT(VARCHAR(23), dbo.customerLastSoldDate(p.PartNo, @p2), 126) AS CustomerLastSoldDate, (1 + 2 + 3) AS Five FROM tblproducts p WHERE ((p.PartNo IN (@p4, @p5, @p6) AND p.Grade IN (@p7, @p8)) OR p.Finish IN (@p9, @p10)) ORDER BY p.Type ASC, p.Product ASC, p.Grade ASC, p.Coating ASC, p.Finish ASC, p.Thickness ASC, p.Width ASC, p.Dim1 ASC, p.Dim2 ASC, p.Length ASC OFFSET @p11 ROWS FETCH NEXT @p12 ROWS ONLY", qs) + a.Equal([]interface{}{"REGION_1", "CUSTOMER_1", "D", 1000, 1001, 1002, "A", "B", "HeapsGood", "TopNotch", 30, 30}, qv) } func TestJoin(t *testing.T) {