Merged
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -36,3 +36,6 @@ jobs:

- name: Run CDC regression tests
run: go test -count=1 -v ./tests/integration -run 'CDC'

- name: Run crash recovery origin-filter test
run: go test -count=1 -v ./tests/integration -run 'TestTableDiffOnlyOriginWithUntil'
83 changes: 80 additions & 3 deletions db/queries/queries.go
@@ -14,6 +14,7 @@ package queries
import (
"bytes"
"context"
"errors"
"fmt"
"regexp"
"strings"
@@ -81,6 +82,7 @@ func GeneratePkeyOffsetsQuery(
tableSampleMethod string,
samplePercent float64,
ntileCount int,
filter string,
) (string, error) {
if len(keyColumns) == 0 {
return "", fmt.Errorf("keyColumns cannot be empty")
@@ -162,6 +164,8 @@ func GeneratePkeyOffsetsQuery(
"RangeStartColumns": strings.Join(rangeStarts, ",\n "),
"RangeEndColumns": strings.Join(rangeEnds, ",\n "),
"RangeOutputColumns": strings.Join(selectOutputCols, ",\n "),
"HasFilter": strings.TrimSpace(filter) != "",
"Filter": strings.TrimSpace(filter),
}

return RenderSQL(SQLTemplates.GetPkeyOffsets, data)
@@ -498,7 +502,7 @@ func InsertBlockRangesBatchComposite(ctx context.Context, db DBQuerier, mtreeTab
}

func GetPkeyOffsets(ctx context.Context, db DBQuerier, schema, table string, keyColumns []string, tableSampleMethod string, samplePercent float64, ntileCount int) ([]types.PkeyOffset, error) {
sql, err := GeneratePkeyOffsetsQuery(schema, table, keyColumns, tableSampleMethod, samplePercent, ntileCount)
sql, err := GeneratePkeyOffsetsQuery(schema, table, keyColumns, tableSampleMethod, samplePercent, ntileCount, "")
if err != nil {
return nil, fmt.Errorf("failed to generate GetPkeyOffsets SQL: %w", err)
}
@@ -536,7 +540,7 @@ func GetPkeyOffsets(ctx context.Context, db DBQuerier, schema, table string, key
return offsets, nil
}

func BlockHashSQL(schema, table string, primaryKeyCols []string, mode string, includeLower, includeUpper bool) (string, error) {
func BlockHashSQL(schema, table string, primaryKeyCols []string, mode string, includeLower, includeUpper bool, filter string) (string, error) {
if len(primaryKeyCols) == 0 {
return "", fmt.Errorf("primaryKeyCols cannot be empty")
}
@@ -610,6 +614,10 @@ func BlockHashSQL(schema, table string, primaryKeyCols []string, mode string, in
whereParts = append(whereParts, upperExpr)
}

if trimmed := strings.TrimSpace(filter); trimmed != "" {
whereParts = append(whereParts, fmt.Sprintf("(%s)", trimmed))
}

if len(whereParts) == 0 {
whereParts = append(whereParts, "TRUE")
}
@@ -800,6 +808,64 @@ func GetSpockNodeAndSubInfo(ctx context.Context, db DBQuerier) ([]types.SpockNod
return infos, nil
}

func GetSpockNodeNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
sql, err := RenderSQL(SQLTemplates.GetSpockNodeNames, nil)
if err != nil {
return nil, err
}

rows, err := db.Query(ctx, sql)
if err != nil {
return nil, err
}
defer rows.Close()

names := make(map[string]string)
for rows.Next() {
var id, name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
names[id] = name
}

if err := rows.Err(); err != nil {
return nil, err
}

return names, nil
}

func GetSpockOriginLSNForNode(ctx context.Context, db DBQuerier, failedNode, survivor string) (*string, error) {
sql, err := RenderSQL(SQLTemplates.GetSpockOriginLSNForNode, nil)
if err != nil {
return nil, err
}
var lsn *string
if err := db.QueryRow(ctx, sql, failedNode, survivor).Scan(&lsn); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to fetch spock origin lsn: %w", err)
}
return lsn, nil
}

func GetSpockSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string) (*string, error) {
sql, err := RenderSQL(SQLTemplates.GetSpockSlotLSNForNode, nil)
if err != nil {
return nil, err
}
var lsn *string
if err := db.QueryRow(ctx, sql, failedNode).Scan(&lsn); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to fetch spock slot lsn: %w", err)
}
return lsn, nil
}

func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetInfo, error) {
sql, err := RenderSQL(SQLTemplates.SpockRepSetInfo, nil)
if err != nil {
@@ -831,6 +897,17 @@ func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetI
return infos, nil
}

func EnsurePgcrypto(ctx context.Context, db DBQuerier) error {
sql, err := RenderSQL(SQLTemplates.EnsurePgcrypto, nil)
if err != nil {
return fmt.Errorf("failed to render ensure-pgcrypto SQL: %w", err)
}
if _, err := db.Exec(ctx, sql); err != nil {
return fmt.Errorf("failed to ensure pgcrypto extension: %w", err)
}
return nil
}

func CheckSchemaExists(ctx context.Context, db DBQuerier, schema string) (bool, error) {
sql, err := RenderSQL(SQLTemplates.CheckSchemaExists, nil)
if err != nil {
@@ -1075,7 +1152,7 @@ func ComputeLeafHashes(ctx context.Context, db DBQuerier, schema, table string,
hasLower := len(start) > 0 && !sliceAllNil(start)
hasUpper := len(end) > 0 && !sliceAllNil(end)

sql, err := BlockHashSQL(schema, table, key, "MTREE_LEAF_HASH", hasLower, hasUpper)
sql, err := BlockHashSQL(schema, table, key, "MTREE_LEAF_HASH", hasLower, hasUpper, "")
if err != nil {
return nil, err
}
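The two exported SQL builders in this file now take a trailing filter argument: an empty string preserves the previous behaviour, while a non-empty predicate is appended to the generated WHERE clause (parenthesised, in the case of BlockHashSQL). A minimal caller sketch follows; the import path, table, and predicate are illustrative assumptions, and only the two signatures come from this diff.

```go
package main

import (
	"fmt"
	"log"

	"example.com/project/db/queries" // import path is an assumption
)

func main() {
	// Offsets query limited to active rows; pass "" to skip the filter entirely.
	offsetsSQL, err := queries.GeneratePkeyOffsetsQuery(
		"public", "users", []string{"id"},
		"BERNOULLI", 10, 100,
		"status = 'active'",
	)
	if err != nil {
		log.Fatalf("offsets query: %v", err)
	}
	fmt.Println(offsetsSQL)

	// Block-hash query with the same predicate; it is wrapped as
	// "(status = 'active')" and combined with the primary-key range bounds.
	hashSQL, err := queries.BlockHashSQL(
		"public", "users", []string{"id"},
		"TD_BLOCK_HASH", true, true,
		"status = 'active'",
	)
	if err != nil {
		log.Fatalf("block hash query: %v", err)
	}
	fmt.Println(hashSQL)
}
```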
43 changes: 42 additions & 1 deletion db/queries/queries_test.go
@@ -138,6 +138,7 @@ func TestGeneratePkeyOffsetsQuery(t *testing.T) {
tableSampleMethod string
samplePercent float64
ntileCount int
filter string
wantQueryContains []string
wantErr bool
}{
@@ -149,6 +150,7 @@
tableSampleMethod: "BERNOULLI",
samplePercent: 10,
ntileCount: 100,
filter: "",
wantQueryContains: []string{
`FROM "public"."users"`,
`TABLESAMPLE BERNOULLI(10)`,
@@ -166,6 +168,7 @@
tableSampleMethod: "SYSTEM",
samplePercent: 5.5,
ntileCount: 50,
filter: "",
wantQueryContains: []string{
`FROM "myschema"."orders"`,
`TABLESAMPLE SYSTEM(5.5)`,
@@ -215,8 +218,25 @@
tableSampleMethod: "BERNOULLI",
samplePercent: 10,
ntileCount: 100,
filter: "",
wantErr: true, // Assuming empty key columns is an invalid input causing SanitiseIdentifier to err
},
{
name: "valid inputs - with filter clause",
schema: "public",
table: "users",
keyColumns: []string{"id"},
tableSampleMethod: "SYSTEM_ROWS",
samplePercent: 1000,
ntileCount: 10,
filter: `status = 'active'`,
wantQueryContains: []string{
`FROM "public"."users"`,
`WHERE status = 'active'`,
`TABLESAMPLE SYSTEM_ROWS(1000)`,
},
wantErr: false,
},
}

for _, tt := range tests {
@@ -228,6 +248,7 @@
tt.tableSampleMethod,
tt.samplePercent,
tt.ntileCount,
tt.filter,
)

if (err != nil) != tt.wantErr {
@@ -258,6 +279,7 @@ func TestBlockHashSQL(t *testing.T) {
primaryKeyCols []string
includeLower bool
includeUpper bool
filter string
wantQueryContains []string
wantErr bool
}{
@@ -268,6 +290,7 @@
primaryKeyCols: []string{"event_id"},
includeLower: true,
includeUpper: true,
filter: "",
wantQueryContains: []string{
`FROM "public"."events" AS _tbl_`,
`ORDER BY "event_id"`,
@@ -283,6 +306,7 @@
primaryKeyCols: []string{"order_id", "item_seq"},
includeLower: true,
includeUpper: true,
filter: "",
wantQueryContains: []string{
`FROM "commerce"."line_items" AS _tbl_`,
`ORDER BY "order_id", "item_seq"`,
@@ -296,34 +320,51 @@
schema: "bad-schema!",
table: "events",
primaryKeyCols: []string{"event_id"},
filter: "",
wantErr: true,
},
{
name: "invalid table identifier",
schema: "public",
table: "events 123",
primaryKeyCols: []string{"event_id"},
filter: "",
wantErr: true,
},
{
name: "invalid primary key column identifier",
schema: "public",
table: "events",
primaryKeyCols: []string{"event-id"},
filter: "",
wantErr: true,
},
{
name: "empty primary key columns",
schema: "public",
table: "events",
primaryKeyCols: []string{},
filter: "",
wantErr: true, // We need this to error out here
},
{
name: "valid inputs - with filter",
schema: "public",
table: "events",
primaryKeyCols: []string{"event_id"},
includeLower: false,
includeUpper: false,
filter: "status = 'live'",
wantQueryContains: []string{
`WHERE (status = 'live')`,
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query, err := BlockHashSQL(tt.schema, tt.table, tt.primaryKeyCols, "TD_BLOCK_HASH" /* mode */, tt.includeLower, tt.includeUpper)
query, err := BlockHashSQL(tt.schema, tt.table, tt.primaryKeyCols, "TD_BLOCK_HASH" /* mode */, tt.includeLower, tt.includeUpper, tt.filter)

if (err != nil) != tt.wantErr {
t.Errorf("BlockHashSQL() error = %v, wantErr %v", err, tt.wantErr)
44 changes: 44 additions & 0 deletions db/queries/templates.go
@@ -21,6 +21,8 @@ type Templates struct {
CheckUserPrivileges *template.Template
SpockNodeAndSubInfo *template.Template
SpockRepSetInfo *template.Template
EnsurePgcrypto *template.Template
GetSpockNodeNames *template.Template
CheckSchemaExists *template.Template
GetTablesInSchema *template.Template
GetViewsInSchema *template.Template
@@ -116,6 +118,8 @@ type Templates struct {
AlterPublicationDropTable *template.Template
DeleteMetadata *template.Template
RemoveTableFromCDCMetadata *template.Template
GetSpockOriginLSNForNode *template.Template
GetSpockSlotLSNForNode *template.Template
}

var SQLTemplates = Templates{
Expand Down Expand Up @@ -554,6 +558,16 @@ var SQLTemplates = Templates{
ORDER BY
set_name;
`)),
EnsurePgcrypto: template.Must(template.New("ensurePgcrypto").Parse(`
CREATE EXTENSION IF NOT EXISTS pgcrypto;
`)),
GetSpockNodeNames: template.Must(template.New("getSpockNodeNames").Parse(`
SELECT
node_id::text,
node_name
FROM
spock.node;
`)),
CheckSchemaExists: template.Must(template.New("checkSchemaExists").Parse(
`SELECT EXISTS (SELECT 1 FROM pg_namespace WHERE nspname = $1);`,
)),
@@ -613,6 +627,10 @@ var SQLTemplates = Templates{
FROM
{{.SchemaIdent}}.{{.TableIdent}}
TABLESAMPLE {{.TableSampleMethod}}({{.SamplePercent}})
{{- if .HasFilter }}
WHERE
{{.Filter}}
{{- end }}
ORDER BY
{{.KeyColumnsOrder}}
),
@@ -621,6 +639,10 @@
{{.KeyColumnsSelect}}
FROM
{{.SchemaIdent}}.{{.TableIdent}}
{{- if .HasFilter }}
WHERE
{{.Filter}}
{{- end }}
ORDER BY
{{.KeyColumnsOrder}}
LIMIT 1
@@ -630,6 +652,10 @@
{{.KeyColumnsSelect}}
FROM
{{.SchemaIdent}}.{{.TableIdent}}
{{- if .HasFilter }}
WHERE
{{.Filter}}
{{- end }}
ORDER BY
{{.KeyColumnsOrderDesc}}
LIMIT 1
@@ -1499,4 +1525,22 @@ var SQLTemplates = Templates{
CreateSchema: template.Must(template.New("createSchema").Parse(`
CREATE SCHEMA IF NOT EXISTS {{.SchemaName}}
`)),
GetSpockOriginLSNForNode: template.Must(template.New("getSpockOriginLSNForNode").Parse(`
SELECT ros.remote_lsn::text
FROM pg_catalog.pg_replication_origin_status ros
JOIN pg_catalog.pg_replication_origin ro ON ro.roident = ros.local_id
WHERE ro.roname LIKE 'spk_%_' || $1 || '_sub_' || $1 || '_' || $2
AND ros.remote_lsn IS NOT NULL
LIMIT 1
`)),
GetSpockSlotLSNForNode: template.Must(template.New("getSpockSlotLSNForNode").Parse(`
SELECT rs.confirmed_flush_lsn::text
FROM pg_catalog.pg_replication_slots rs
JOIN spock.subscription s ON rs.slot_name = s.sub_slot_name
JOIN spock.node o ON o.node_id = s.sub_origin
WHERE o.node_name = $1
AND rs.confirmed_flush_lsn IS NOT NULL
ORDER BY rs.confirmed_flush_lsn DESC
LIMIT 1
`)),
}
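Together with the new helpers in db/queries/queries.go, these two templates let a crash-recovery path resolve a starting LSN for a failed node. The sketch below is hypothetical wiring: the package name and import path are assumptions, and the origin-then-slot fallback order is an illustration rather than something this diff prescribes; only the helper signatures come from the PR.

```go
package recovery // package name and import path are assumptions

import (
	"context"
	"fmt"

	"example.com/project/db/queries" // adjust to the real module path
)

// resolveRecoveryLSN prefers the replication-origin LSN recorded on the
// survivor and falls back to the confirmed_flush_lsn of the slot that fed
// the failed node. Both helpers return nil (not an error) when no row matches.
func resolveRecoveryLSN(ctx context.Context, db queries.DBQuerier, failedNode, survivor string) (string, error) {
	lsn, err := queries.GetSpockOriginLSNForNode(ctx, db, failedNode, survivor)
	if err != nil {
		return "", err
	}
	if lsn == nil {
		lsn, err = queries.GetSpockSlotLSNForNode(ctx, db, failedNode)
		if err != nil {
			return "", err
		}
	}
	if lsn == nil {
		return "", fmt.Errorf("no LSN recorded for failed node %q", failedNode)
	}
	return *lsn, nil
}
```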