diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a1a9296..0c6a177 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,3 +36,6 @@ jobs: - name: Run CDC regression tests run: go test -count=1 -v ./tests/integration -run 'CDC' + + - name: Run crash recovery origin-filter test + run: go test -count=1 -v ./tests/integration -run 'TestTableDiffOnlyOriginWithUntil' diff --git a/db/queries/queries.go b/db/queries/queries.go index 2be6dbb..d4e612c 100644 --- a/db/queries/queries.go +++ b/db/queries/queries.go @@ -14,6 +14,7 @@ package queries import ( "bytes" "context" + "errors" "fmt" "regexp" "strings" @@ -81,6 +82,7 @@ func GeneratePkeyOffsetsQuery( tableSampleMethod string, samplePercent float64, ntileCount int, + filter string, ) (string, error) { if len(keyColumns) == 0 { return "", fmt.Errorf("keyColumns cannot be empty") @@ -162,6 +164,8 @@ func GeneratePkeyOffsetsQuery( "RangeStartColumns": strings.Join(rangeStarts, ",\n "), "RangeEndColumns": strings.Join(rangeEnds, ",\n "), "RangeOutputColumns": strings.Join(selectOutputCols, ",\n "), + "HasFilter": strings.TrimSpace(filter) != "", + "Filter": strings.TrimSpace(filter), } return RenderSQL(SQLTemplates.GetPkeyOffsets, data) @@ -498,7 +502,7 @@ func InsertBlockRangesBatchComposite(ctx context.Context, db DBQuerier, mtreeTab } func GetPkeyOffsets(ctx context.Context, db DBQuerier, schema, table string, keyColumns []string, tableSampleMethod string, samplePercent float64, ntileCount int) ([]types.PkeyOffset, error) { - sql, err := GeneratePkeyOffsetsQuery(schema, table, keyColumns, tableSampleMethod, samplePercent, ntileCount) + sql, err := GeneratePkeyOffsetsQuery(schema, table, keyColumns, tableSampleMethod, samplePercent, ntileCount, "") if err != nil { return nil, fmt.Errorf("failed to generate GetPkeyOffsets SQL: %w", err) } @@ -536,7 +540,7 @@ func GetPkeyOffsets(ctx context.Context, db DBQuerier, schema, table string, key return offsets, nil } -func BlockHashSQL(schema, table string, primaryKeyCols []string, mode string, includeLower, includeUpper bool) (string, error) { +func BlockHashSQL(schema, table string, primaryKeyCols []string, mode string, includeLower, includeUpper bool, filter string) (string, error) { if len(primaryKeyCols) == 0 { return "", fmt.Errorf("primaryKeyCols cannot be empty") } @@ -610,6 +614,10 @@ func BlockHashSQL(schema, table string, primaryKeyCols []string, mode string, in whereParts = append(whereParts, upperExpr) } + if trimmed := strings.TrimSpace(filter); trimmed != "" { + whereParts = append(whereParts, fmt.Sprintf("(%s)", trimmed)) + } + if len(whereParts) == 0 { whereParts = append(whereParts, "TRUE") } @@ -800,6 +808,64 @@ func GetSpockNodeAndSubInfo(ctx context.Context, db DBQuerier) ([]types.SpockNod return infos, nil } +func GetSpockNodeNames(ctx context.Context, db DBQuerier) (map[string]string, error) { + sql, err := RenderSQL(SQLTemplates.GetSpockNodeNames, nil) + if err != nil { + return nil, err + } + + rows, err := db.Query(ctx, sql) + if err != nil { + return nil, err + } + defer rows.Close() + + names := make(map[string]string) + for rows.Next() { + var id, name string + if err := rows.Scan(&id, &name); err != nil { + return nil, err + } + names[id] = name + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return names, nil +} + +func GetSpockOriginLSNForNode(ctx context.Context, db DBQuerier, failedNode, survivor string) (*string, error) { + sql, err := RenderSQL(SQLTemplates.GetSpockOriginLSNForNode, nil) + if err != 
nil { + return nil, err + } + var lsn *string + if err := db.QueryRow(ctx, sql, failedNode, survivor).Scan(&lsn); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("failed to fetch spock origin lsn: %w", err) + } + return lsn, nil +} + +func GetSpockSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string) (*string, error) { + sql, err := RenderSQL(SQLTemplates.GetSpockSlotLSNForNode, nil) + if err != nil { + return nil, err + } + var lsn *string + if err := db.QueryRow(ctx, sql, failedNode).Scan(&lsn); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("failed to fetch spock slot lsn: %w", err) + } + return lsn, nil +} + func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetInfo, error) { sql, err := RenderSQL(SQLTemplates.SpockRepSetInfo, nil) if err != nil { @@ -831,6 +897,17 @@ func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetI return infos, nil } +func EnsurePgcrypto(ctx context.Context, db DBQuerier) error { + sql, err := RenderSQL(SQLTemplates.EnsurePgcrypto, nil) + if err != nil { + return fmt.Errorf("failed to render ensure-pgcrypto SQL: %w", err) + } + if _, err := db.Exec(ctx, sql); err != nil { + return fmt.Errorf("failed to ensure pgcrypto extension: %w", err) + } + return nil +} + func CheckSchemaExists(ctx context.Context, db DBQuerier, schema string) (bool, error) { sql, err := RenderSQL(SQLTemplates.CheckSchemaExists, nil) if err != nil { @@ -1075,7 +1152,7 @@ func ComputeLeafHashes(ctx context.Context, db DBQuerier, schema, table string, hasLower := len(start) > 0 && !sliceAllNil(start) hasUpper := len(end) > 0 && !sliceAllNil(end) - sql, err := BlockHashSQL(schema, table, key, "MTREE_LEAF_HASH", hasLower, hasUpper) + sql, err := BlockHashSQL(schema, table, key, "MTREE_LEAF_HASH", hasLower, hasUpper, "") if err != nil { return nil, err } diff --git a/db/queries/queries_test.go b/db/queries/queries_test.go index 3ffa74a..0c76d07 100644 --- a/db/queries/queries_test.go +++ b/db/queries/queries_test.go @@ -138,6 +138,7 @@ func TestGeneratePkeyOffsetsQuery(t *testing.T) { tableSampleMethod string samplePercent float64 ntileCount int + filter string wantQueryContains []string wantErr bool }{ @@ -149,6 +150,7 @@ func TestGeneratePkeyOffsetsQuery(t *testing.T) { tableSampleMethod: "BERNOULLI", samplePercent: 10, ntileCount: 100, + filter: "", wantQueryContains: []string{ `FROM "public"."users"`, `TABLESAMPLE BERNOULLI(10)`, @@ -166,6 +168,7 @@ func TestGeneratePkeyOffsetsQuery(t *testing.T) { tableSampleMethod: "SYSTEM", samplePercent: 5.5, ntileCount: 50, + filter: "", wantQueryContains: []string{ `FROM "myschema"."orders"`, `TABLESAMPLE SYSTEM(5.5)`, @@ -215,8 +218,25 @@ func TestGeneratePkeyOffsetsQuery(t *testing.T) { tableSampleMethod: "BERNOULLI", samplePercent: 10, ntileCount: 100, + filter: "", wantErr: true, // Assuming empty key columns is an invalid input causing SanitiseIdentifier to err }, + { + name: "valid inputs - with filter clause", + schema: "public", + table: "users", + keyColumns: []string{"id"}, + tableSampleMethod: "SYSTEM_ROWS", + samplePercent: 1000, + ntileCount: 10, + filter: `status = 'active'`, + wantQueryContains: []string{ + `FROM "public"."users"`, + `WHERE status = 'active'`, + `TABLESAMPLE SYSTEM_ROWS(1000)`, + }, + wantErr: false, + }, } for _, tt := range tests { @@ -228,6 +248,7 @@ func TestGeneratePkeyOffsetsQuery(t *testing.T) { tt.tableSampleMethod, tt.samplePercent, tt.ntileCount, + 
tt.filter, ) if (err != nil) != tt.wantErr { @@ -258,6 +279,7 @@ func TestBlockHashSQL(t *testing.T) { primaryKeyCols []string includeLower bool includeUpper bool + filter string wantQueryContains []string wantErr bool }{ @@ -268,6 +290,7 @@ func TestBlockHashSQL(t *testing.T) { primaryKeyCols: []string{"event_id"}, includeLower: true, includeUpper: true, + filter: "", wantQueryContains: []string{ `FROM "public"."events" AS _tbl_`, `ORDER BY "event_id"`, @@ -283,6 +306,7 @@ func TestBlockHashSQL(t *testing.T) { primaryKeyCols: []string{"order_id", "item_seq"}, includeLower: true, includeUpper: true, + filter: "", wantQueryContains: []string{ `FROM "commerce"."line_items" AS _tbl_`, `ORDER BY "order_id", "item_seq"`, @@ -296,6 +320,7 @@ func TestBlockHashSQL(t *testing.T) { schema: "bad-schema!", table: "events", primaryKeyCols: []string{"event_id"}, + filter: "", wantErr: true, }, { @@ -303,6 +328,7 @@ func TestBlockHashSQL(t *testing.T) { schema: "public", table: "events 123", primaryKeyCols: []string{"event_id"}, + filter: "", wantErr: true, }, { @@ -310,6 +336,7 @@ func TestBlockHashSQL(t *testing.T) { schema: "public", table: "events", primaryKeyCols: []string{"event-id"}, + filter: "", wantErr: true, }, { @@ -317,13 +344,27 @@ func TestBlockHashSQL(t *testing.T) { schema: "public", table: "events", primaryKeyCols: []string{}, + filter: "", wantErr: true, // We need this to error out here }, + { + name: "valid inputs - with filter", + schema: "public", + table: "events", + primaryKeyCols: []string{"event_id"}, + includeLower: false, + includeUpper: false, + filter: "status = 'live'", + wantQueryContains: []string{ + `WHERE (status = 'live')`, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - query, err := BlockHashSQL(tt.schema, tt.table, tt.primaryKeyCols, "TD_BLOCK_HASH" /* mode */, tt.includeLower, tt.includeUpper) + query, err := BlockHashSQL(tt.schema, tt.table, tt.primaryKeyCols, "TD_BLOCK_HASH" /* mode */, tt.includeLower, tt.includeUpper, tt.filter) if (err != nil) != tt.wantErr { t.Errorf("BlockHashSQL() error = %v, wantErr %v", err, tt.wantErr) diff --git a/db/queries/templates.go b/db/queries/templates.go index 2dfc345..bf7e4ad 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -21,6 +21,8 @@ type Templates struct { CheckUserPrivileges *template.Template SpockNodeAndSubInfo *template.Template SpockRepSetInfo *template.Template + EnsurePgcrypto *template.Template + GetSpockNodeNames *template.Template CheckSchemaExists *template.Template GetTablesInSchema *template.Template GetViewsInSchema *template.Template @@ -116,6 +118,8 @@ type Templates struct { AlterPublicationDropTable *template.Template DeleteMetadata *template.Template RemoveTableFromCDCMetadata *template.Template + GetSpockOriginLSNForNode *template.Template + GetSpockSlotLSNForNode *template.Template } var SQLTemplates = Templates{ @@ -554,6 +558,16 @@ var SQLTemplates = Templates{ ORDER BY set_name; `)), + EnsurePgcrypto: template.Must(template.New("ensurePgcrypto").Parse(` + CREATE EXTENSION IF NOT EXISTS pgcrypto; + `)), + GetSpockNodeNames: template.Must(template.New("getSpockNodeNames").Parse(` + SELECT + node_id::text, + node_name + FROM + spock.node; + `)), CheckSchemaExists: template.Must(template.New("checkSchemaExists").Parse( `SELECT EXISTS (SELECT 1 FROM pg_namespace WHERE nspname = $1);`, )), @@ -613,6 +627,10 @@ var SQLTemplates = Templates{ FROM {{.SchemaIdent}}.{{.TableIdent}} TABLESAMPLE {{.TableSampleMethod}}({{.SamplePercent}}) + 
{{- if .HasFilter }} + WHERE + {{.Filter}} + {{- end }} ORDER BY {{.KeyColumnsOrder}} ), @@ -621,6 +639,10 @@ var SQLTemplates = Templates{ {{.KeyColumnsSelect}} FROM {{.SchemaIdent}}.{{.TableIdent}} + {{- if .HasFilter }} + WHERE + {{.Filter}} + {{- end }} ORDER BY {{.KeyColumnsOrder}} LIMIT 1 @@ -630,6 +652,10 @@ var SQLTemplates = Templates{ {{.KeyColumnsSelect}} FROM {{.SchemaIdent}}.{{.TableIdent}} + {{- if .HasFilter }} + WHERE + {{.Filter}} + {{- end }} ORDER BY {{.KeyColumnsOrderDesc}} LIMIT 1 @@ -1499,4 +1525,22 @@ var SQLTemplates = Templates{ CreateSchema: template.Must(template.New("createSchema").Parse(` CREATE SCHEMA IF NOT EXISTS {{.SchemaName}} `)), + GetSpockOriginLSNForNode: template.Must(template.New("getSpockOriginLSNForNode").Parse(` + SELECT ros.remote_lsn::text + FROM pg_catalog.pg_replication_origin_status ros + JOIN pg_catalog.pg_replication_origin ro ON ro.roident = ros.local_id + WHERE ro.roname LIKE 'spk_%_' || $1 || '_sub_' || $1 || '_' || $2 + AND ros.remote_lsn IS NOT NULL + LIMIT 1 + `)), + GetSpockSlotLSNForNode: template.Must(template.New("getSpockSlotLSNForNode").Parse(` + SELECT rs.confirmed_flush_lsn::text + FROM pg_catalog.pg_replication_slots rs + JOIN spock.subscription s ON rs.slot_name = s.sub_slot_name + JOIN spock.node o ON o.node_id = s.sub_origin + WHERE o.node_name = $1 + AND rs.confirmed_flush_lsn IS NOT NULL + ORDER BY rs.confirmed_flush_lsn DESC + LIMIT 1 + `)), } diff --git a/docs/commands/diff/table-diff.md b/docs/commands/diff/table-diff.md index cab1ea0..a7720ec 100644 --- a/docs/commands/diff/table-diff.md +++ b/docs/commands/diff/table-diff.md @@ -22,6 +22,8 @@ This command compares the data in the specified table across nodes in a cluster | `--output ` | `-o` | Report format. Default `json`. When `html`, both JSON and HTML files share the same timestamped prefix. | | `--nodes ` | `-n` | Comma-separated node list or `all`. Up to three-way diffs are supported. | | `--table-filter ` | `-F` | Optional SQL `WHERE` clause applied on every node before hashing. | +| `--only-origin ` | | Limit the diff to rows whose `node_origin` matches this Spock node id or name (useful for failed-node recovery). | +| `--until ` | | Optional commit timestamp fence (RFC3339) applied with `--only-origin` and `--table-filter`; excludes newer rows. | | `--override-block-size` | `-B` | Skip block-size safety checks defined in `ace.yaml`. | | `--quiet` | `-q` | Suppress progress output. Results still write to the diff file. | | `--debug` | `-v` | Enable verbose logging. | @@ -74,9 +76,19 @@ ACE optimises comparisons with multiprocessing and block hashing: ### Tuning tips 1. Tune `--block-size` and `--concurrency-factor` for your hardware and data profile. 2. Use `--table-filter`/`-F` to narrow scope on very large tables. + - Filters are applied inline (no temporary views) and recorded in the diff summary as both the raw filter and the effective filter (which also includes `--only-origin`/`--until` if set). 3. Prefer `--output html` when you’ll manually review diffs. 4. Use `--override-block-size` sparingly; the guardrails in `ace.yaml` prevent allocations that can overwhelm memory. +### Recovery-focused options + +Use `--only-origin` when you need a diff scoped to transactions from a failed node; combine with `--until` to fence at a known commit timestamp. 
The diff summary records: +- `only_origin` (raw id), `only_origin_resolved` (node name if known), `origin_only` (bool) +- `until` (commit timestamp fence) +- `table_filter` (raw) and `effective_filter` (combined predicates) + +When running `table-repair` on an origin-only diff, you must pass `--recovery-mode` (see `table-repair` docs) or the repair will refuse to run. + ### Scheduling runs Add `--schedule --every=` to keep the diff running on a loop. The command performs an initial comparison immediately, then repeats after each interval until you stop the process: diff --git a/docs/commands/repair/advanced-repair-examples.md b/docs/commands/repair/advanced-repair-examples.md new file mode 100644 index 0000000..b5cf829 --- /dev/null +++ b/docs/commands/repair/advanced-repair-examples.md @@ -0,0 +1,218 @@ +# Advanced repair examples + +This page shows practical repair-file snippets you can adapt. All examples assume `version: 1` at the top and a `tables:` section; only the relevant table entry is shown. + +## 1) Classic source-of-truth per batch + +Take most rows from `n1`, but EU rows from `n2`: +```yaml +tables: + public.accounts: + default_action: { type: keep_n1 } + rules: + - name: eu_from_n2 + diff_type: [row_mismatch, missing_on_n2] + when: "n1.region = 'eu'" + action: { type: keep_n2 } +``` + +## 2) Insert-only or upsert-only without global flags + +Insert missing rows into `n2`; skip updates/deletes: +```yaml +tables: + public.orders: + default_action: { type: skip } + rules: + - name: insert_missing_n2 + diff_type: [missing_on_n2] + action: + type: apply_from + from: n1 + mode: insert +``` + +Upsert (insert or update) into `n2`, but never delete: +```yaml +tables: + public.orders: + default_action: { type: skip } + rules: + - name: upsert_into_n2 + diff_type: [row_mismatch, missing_on_n2] + action: + type: apply_from + from: n1 + mode: upsert +``` + +## 3) Bidirectional convergence for mismatches only + +Copy rows both ways when they differ; leave single-sided misses untouched: +```yaml +tables: + public.features: + default_action: { type: skip } + rules: + - name: converge_mismatches + diff_type: [row_mismatch] + action: { type: bidirectional } +``` + +## 4) Coalesce (fix-nulls style) with helpers + +Fill NULLs using non-NULL values preferring `n1`, then `n2`: +```yaml +tables: + public.customers: + rules: + - name: coalesce_contact + columns_changed: [email, phone] + action: + type: custom + helpers: + coalesce_priority: [n1, n2] +``` + +## 5) Pick freshest based on a timestamp + +Keep the row with the newer `updated_at`; tie-break to `n1`: +```yaml +tables: + public.inventory: + rules: + - name: pick_newer + diff_type: [row_mismatch] + action: + type: custom + helpers: + pick_freshest: + key: updated_at + tie: n1 +``` + +## 5b) Use Spock commit metadata + +Pick the row with the newer replication commit timestamp (Spock metadata), else tie to n1: +```yaml +tables: + public.inventory: + rules: + - name: pick_newer_commit + diff_type: [row_mismatch] + action: + type: custom + helpers: + pick_freshest: + key: commit_ts # from _spock_metadata_ + tie: n1 +``` + +## 5c) Use Spock node origin (route by producer) + +Treat rows produced on node `n3` as authoritative; otherwise fall back to n1: +```yaml +tables: + public.inventory: + default_action: { type: keep_n1 } + rules: + - name: prefer_n3_origin + diff_type: [row_mismatch, missing_on_n2] + when: "n1.node_origin = 'n3'" + action: { type: keep_n1 } + - name: prefer_n3_origin_missing + diff_type: [missing_on_n1] + when: 
"n2.node_origin = 'n3'" + action: { type: keep_n2 } +``` + +## 5d) Split by origin for batch decisions + +Use n2 when the origin is n2, otherwise use n1: +```yaml +tables: + public.orders: + rules: + - name: origin_n2 + diff_type: [row_mismatch, missing_on_n1, missing_on_n2] + when: "n1.node_origin = 'n2' OR n2.node_origin = 'n2'" + action: { type: keep_n2 } + - name: default_to_n1 + action: { type: keep_n1 } +``` + +## 6) Custom row per PK + +Pin a specific PK to a hand-crafted row: +```yaml +tables: + public.products: + row_overrides: + - name: fix_widget42 + pk: { id: 42 } + action: + type: custom + custom_row: + id: 42 + status: "retired" + notes: "manual override" +``` + +## 7) Mixed SOT by ranges + +PK 1–100 from `n1`, 101–200 from `n2`, everything else from `n1`: +```yaml +tables: + public.accounts: + default_action: { type: keep_n1 } + rules: + - name: range_101_200_n2 + pk_in: + - range: { from: 101, to: 200 } + action: { type: keep_n2 } +``` + +## 8) Delete stray rows on a target + +Delete rows present only on `n2`: +```yaml +tables: + public.logs: + rules: + - name: delete_extras_n2 + diff_type: [missing_on_n1] + action: { type: delete } +``` + +## 9) Combine predicates + +Only take `n2` for VIPs in EU where status changed: +```yaml +tables: + public.users: + rules: + - name: vip_eu_from_n2 + diff_type: [row_mismatch] + columns_changed: [status] + when: "n1.region = 'eu' AND n1.tier = 'vip'" + action: { type: keep_n2 } +``` + +## 10) Coalesce with templating + +Build a row with an explicit status and templated columns: +```yaml +tables: + public.tasks: + rules: + - name: coalesce_with_template + diff_type: [row_mismatch, missing_on_n2] + action: + type: custom + custom_row: + id: "{{n1.id}}" + status: "active" + title: "{{n2.title}}" + helpers: + coalesce_priority: [n2, n1] +``` diff --git a/docs/commands/repair/advanced-repair.md b/docs/commands/repair/advanced-repair.md new file mode 100644 index 0000000..4d9cca7 --- /dev/null +++ b/docs/commands/repair/advanced-repair.md @@ -0,0 +1,79 @@ +# Advanced repair (repair files) + +`table-repair` can run from a repair file instead of a single source-of-truth flag. The repair file is a versioned YAML/JSON document that describes per-table defaults, ordered rules, and explicit row overrides. This lets you pick different sources of truth (or custom rows) per batch, and keep a reproducible plan on disk. + +## File structure and precedence + +- Top level: `version`, optional `default_action`, `tables` map (`schema.table` keys). +- Per table: `default_action`, ordered `rules`, and `row_overrides`. +- Precedence: `row_overrides` (exact PK) > first matching `rule` > `table.default_action` > global `default_action`. Rules must declare at least one selector to avoid match‑all accidents. + +### Selectors +- `pk_in`: list of PK values, and/or ranges (`from`/`to`) for simple PKs; composite PKs use ordered tuples. +- `diff_type`: one or more of `row_mismatch`, `missing_on_n1`, `missing_on_n2`, `deleted_on_n1`, `deleted_on_n2`. +- `columns_changed`: match if any listed column differs. +- `when`: predicate over `n1.` / `n2.`; supports `= != < <= > >=`, `IN (...)`, `IS [NOT] NULL`, `AND/OR/NOT`, parentheses, string/number/bool/null literals. + +### Actions (verbs) +- `keep_n1`, `keep_n2`: copy that side. +- `apply_from` with `from: n1|n2` and `mode: replace|upsert|insert` (default `replace`). `insert` is only valid on missing rows. +- `bidirectional`: copy both ways (only meaningful on mismatches/missing rows). 
+- `custom`: provide `custom_row` (with optional `{{n1.col}}` / `{{n2.col}}` templating) and/or helpers: + - `helpers.coalesce_priority: [n1, n2]` fills remaining columns from the first non‑NULL source. + - `helpers.pick_freshest: { key: updated_at, tie: n1|n2 }` fills remaining columns from the fresher side (numeric, string, or RFC3339 timestamps). +- `skip`: do nothing. +- `delete`: remove the row (from the side(s) where it exists). +- Spock metadata in plans: you can reference `commit_ts` and `node_origin` in `when` predicates and `custom_row` templates (e.g., `{{n1.commit_ts}}`, `when: "n2.node_origin = 'n3'"`). These fields are injected during diff collection (via `pg_xact_commit_timestamp(xmin)` and `spock.xact_commit_timestamp_origin(xmin)`), not stored in your table. + +### Compatibility checks (fail fast) +- `keep_n1`/`apply_from n1` require a row on n1; `keep_n2`/`apply_from n2` require a row on n2. +- `apply_from mode: insert` is invalid on `row_mismatch`. +- `diff_type` is validated; incompatible action/diff combos are rejected at parse time and at execution. +- `custom` requires `custom_row` or `helpers`. + +## Running with a repair file + +```bash +./ace table-repair \ + --diff-file=public_customers_diffs-20251210.json \ + --repair-file=repair.yaml \ + --dry-run # optional +``` + +Notes: +- When `--repair-file` is provided, `--source-of-truth` is not required. +- Dry-run and reports include rule usage counts per node. +- `--fix-nulls` and `--bidirectional` cannot be combined with a repair file (they’re separate modes). + +## Minimal skeleton + +```yaml +version: 1 +default_action: + type: skip +tables: + public.customers: + default_action: + type: keep_n1 + rules: + - name: eu_to_n2 + diff_type: [row_mismatch, missing_on_n2] + when: "n1.region = 'eu'" + action: + type: keep_n2 + - name: coalesce_contact + columns_changed: [email, phone] + action: + type: custom + helpers: + coalesce_priority: [n1, n2] + row_overrides: + - name: fix_customer_42 + pk: { id: 42 } + action: + type: custom + custom_row: + id: 42 + status: "vip" + email: "{{n1.email}}" +``` diff --git a/docs/commands/repair/index.md b/docs/commands/repair/index.md index 98e6b54..9fe4d0d 100644 --- a/docs/commands/repair/index.md +++ b/docs/commands/repair/index.md @@ -10,6 +10,7 @@ The `table-repair` command fixes data inconsistencies identified by `table-diff` - **NULL-only drift**: Use `--fix-nulls` to cross-fill NULL columns without a single source-of-truth. - **Network partition repair**: Re‑align nodes after a partition. - **Temporary node outage**: Catch a lagging node up. +- **Catastrophic node failure recovery**: Use origin-scoped diffs plus recovery-mode repair to reconcile survivors when a node fails mid-replication. See [Using ACE for catastrophic node failure recovery](../../using-ace-for-catastrophic-node-failure-recovery.md). **Safety & audit features** @@ -17,6 +18,7 @@ The `table-repair` command fixes data inconsistencies identified by `table-diff` - **Report generation**: write a detailed audit of actions taken. - **Upsert‑only**: prevent deletions on divergent nodes. - **Transaction safety**: changes are atomic; partial failures are rolled back. +- **Advanced repair plans**: drive repairs from a versioned YAML/JSON file with per-table rules, overrides, and custom rows. See [Advanced repair](advanced-repair.md) and [Examples](advanced-repair-examples.md). 
**Helpful Tips** diff --git a/docs/commands/repair/table-repair.md b/docs/commands/repair/table-repair.md index 2bf5b08..fd99093 100644 --- a/docs/commands/repair/table-repair.md +++ b/docs/commands/repair/table-repair.md @@ -25,12 +25,20 @@ Performs repairs on tables of divergent nodes based on the diff report generated | `--generate-report` | `-g` | Write a JSON repair report to `reports//repair_report_.json` | `false` | | `--insert-only` | `-i` | Only insert missing rows; skip updates/deletes | `false` | | `--upsert-only` | `-P` | Insert or update rows; skip deletes | `false` | +| `--repair-file ` | `-p` | Path to an advanced repair plan (YAML/JSON). Overrides `--source-of-truth` and uses rule-based actions. | | | `--fix-nulls` | `-X` | Fill NULL columns on each node using non-NULL values from its peers (no source-of-truth needed) | `false` | | `--bidirectional` | `-Z` | Perform insert-only repairs in both directions | `false` | | `--fire-triggers` | `-t` | Execute triggers (otherwise runs with `session_replication_role='replica'`) | `false` | +| `--recovery-mode` | | Enable recovery-mode repair when the diff was generated with `--only-origin`; can auto-select a source of truth using Spock LSNs | `false` | | `--quiet` | `-q` | Suppress non-essential logging | `false` | | `--debug` | `-v` | Enable verbose logging | `false` | +### Advanced repair plans +- Use `--repair-file` to drive repairs from a plan. Source-of-truth becomes optional; the plan sets per-row decisions. +- Mutually exclusive with `--bidirectional` and `--fix-nulls` (those are separate modes). +- Dry-run and reports include rule usage counts per node. +- See [Advanced repair](advanced-repair.md) for grammar and [Examples](advanced-repair-examples.md) for recipes. + ## Example ```sh @@ -42,6 +50,13 @@ Performs repairs on tables of divergent nodes based on the diff report generated Diff reports share the same prefix generated by `table-diff` (for example `public_customers_large_diffs-20250718134542.json`). When you request a dry run or report, ACE also writes JSON summaries under `reports//repair_report_.json` (or `dry_run_report_<...>.json`). +### Recovery-mode behavior + +- If the diff file indicates `only_origin`, `table-repair` refuses to run unless `--recovery-mode` is set. +- In recovery-mode, if no `--source-of-truth` is provided, ACE probes surviving nodes for the failed node’s Spock origin LSN (preferred) and slot LSN (fallback) and picks the highest. Ties or missing LSNs require you to provide `--source-of-truth`. +- The chosen source (auto or explicit) is recorded in the repair report along with the LSN probes. +- Recovery-mode still accepts `--repair-file`; the plan is applied after the source of truth is determined. If no repair file is provided, ACE performs a standard repair with the recovery-mode source selection. + ## Sample Output ``` diff --git a/docs/using-ace-for-cnf-recovery.md b/docs/using-ace-for-cnf-recovery.md new file mode 100644 index 0000000..d7f58b4 --- /dev/null +++ b/docs/using-ace-for-cnf-recovery.md @@ -0,0 +1,29 @@ +# Using ACE for Catastrophic Node Failure Recovery + +Catastrophic node failures (CNFs) leave a cluster with one node abruptly down mid‑replication. The failed node’s transactions may be partially replicated; survivors can drift. ACE helps you scope and repair the drift by focusing on the failed node’s origin and an agreed cutoff, then repairing from the best survivor. + +## What you need +- Spock metadata available on survivors (node names, origins, slots). 
+- A cutoff for the failed node’s commits (timestamp/LSN) so you can ignore churn after failure. +- Optional: a repair plan (YAML/JSON) if you need rule‑based actions (upsert‑only, coalescing, etc.). + +## Workflow +1) **Capture an origin-scoped diff** + Run `table-diff` on survivors with `--only-origin ` and optionally `--until ` to fence at the last known commit from the failed node. You can combine a `--table-filter` to limit scope. The diff summary records `only_origin`, resolved node name, `until`, raw/effective filters, and an `origin_only` flag. + +2) **Choose source of truth (SoT)** + In `table-repair --recovery-mode`, ACE: + - Refuses origin-only diffs unless recovery-mode is set. + - If `--source-of-truth` is not provided, probes survivors for Spock origin LSN (preferred) and slot LSN (fallback) for the failed node and picks the highest. Ties or missing LSNs require an explicit SoT. + - Records the SoT choice and LSN probes in the repair report. + +3) **Repair** + Run `table-repair --recovery-mode ... --diff-file= [--repair-file plan.yaml]`. A repair plan is optional; if provided, it is applied after SoT selection. + +4) **Validate** + Re-run `table-diff` (optionally without `--only-origin`) to confirm survivors converge. Repeat per table or filter chunk if needed. + +## Notes and cautions +- If origin/slot LSNs are absent on survivors, auto-selection will fail; provide `--source-of-truth` explicitly. +- The `until` fence should reflect the failed node’s last trusted commit to avoid including churn after failure. +- Advanced plans are allowed in recovery-mode; use them for upsert-only/coalesce patterns instead of default delete/update behavior. diff --git a/internal/server/handler.go b/internal/api/http/handler.go similarity index 97% rename from internal/server/handler.go rename to internal/api/http/handler.go index d223d9e..3f176cb 100644 --- a/internal/server/handler.go +++ b/internal/api/http/handler.go @@ -9,7 +9,9 @@ import ( "strings" "time" - "github.com/pgedge/ace/internal/core" + "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/internal/consistency/mtree" + "github.com/pgedge/ace/internal/consistency/repair" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/taskstore" @@ -43,6 +45,7 @@ type tableRepairRequest struct { DBName string `json:"dbname"` Nodes []string `json:"nodes"` DiffFile string `json:"diff_file"` + RepairPlan string `json:"repair_plan"` SourceOfTruth string `json:"source_of_truth"` Quiet bool `json:"quiet"` DryRun bool `json:"dry_run"` @@ -198,7 +201,7 @@ func (s *APIServer) handleTableDiff(w http.ResponseWriter, r *http.Request) { return } - task := core.NewTableDiffTask() + task := diff.NewTableDiffTask() task.ClusterName = cluster task.QualifiedTableName = tableName task.DBName = strings.TrimSpace(req.DBName) @@ -344,7 +347,7 @@ func (s *APIServer) handleTableRerun(w http.ResponseWriter, r *http.Request) { return } - task := core.NewTableDiffTask() + task := diff.NewTableDiffTask() task.Mode = "rerun" task.ClusterName = cluster task.DiffFilePath = diffFile @@ -415,12 +418,13 @@ func (s *APIServer) handleTableRepair(w http.ResponseWriter, r *http.Request) { return } - task := core.NewTableRepairTask() + task := repair.NewTableRepairTask() task.ClusterName = cluster task.QualifiedTableName = tableName task.DBName = strings.TrimSpace(req.DBName) task.Nodes = s.resolveNodes(req.Nodes) task.DiffFilePath = diffFile + task.RepairPlanPath = strings.TrimSpace(req.RepairPlan) 
task.SourceOfTruth = strings.TrimSpace(req.SourceOfTruth) task.QuietMode = req.Quiet task.DryRun = req.DryRun @@ -487,7 +491,7 @@ func (s *APIServer) handleSpockDiff(w http.ResponseWriter, r *http.Request) { return } - task := core.NewSpockDiffTask() + task := diff.NewSpockDiffTask() task.ClusterName = cluster task.DBName = strings.TrimSpace(req.DBName) task.Nodes = s.resolveNodes(req.Nodes) @@ -555,7 +559,7 @@ func (s *APIServer) handleSchemaDiff(w http.ResponseWriter, r *http.Request) { return } - task := core.NewSchemaDiffTask() + task := diff.NewSchemaDiffTask() task.ClusterName = cluster task.SchemaName = schema task.DBName = strings.TrimSpace(req.DBName) @@ -636,7 +640,7 @@ func (s *APIServer) handleRepsetDiff(w http.ResponseWriter, r *http.Request) { return } - task := core.NewRepsetDiffTask() + task := diff.NewRepsetDiffTask() task.ClusterName = cluster task.RepsetName = repset task.DBName = strings.TrimSpace(req.DBName) @@ -670,7 +674,7 @@ func (s *APIServer) handleRepsetDiff(w http.ResponseWriter, r *http.Request) { if err := s.enqueueTask(task.TaskID, func(ctx context.Context) error { task.Ctx = ctx - return core.RepsetDiff(task) + return diff.RepsetDiff(task) }); err != nil { logger.Error("failed to enqueue repset-diff task: %v", err) writeError(w, http.StatusInternalServerError, "unable to enqueue task") @@ -724,7 +728,7 @@ func (s *APIServer) handleMtreeInit(w http.ResponseWriter, r *http.Request) { return } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = cluster task.DBName = strings.TrimSpace(req.DBName) task.Nodes = s.resolveNodes(req.Nodes) @@ -782,7 +786,7 @@ func (s *APIServer) handleMtreeTeardown(w http.ResponseWriter, r *http.Request) return } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = cluster task.DBName = strings.TrimSpace(req.DBName) task.Nodes = s.resolveNodes(req.Nodes) @@ -845,7 +849,7 @@ func (s *APIServer) handleMtreeTeardownTable(w http.ResponseWriter, r *http.Requ return } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = cluster task.QualifiedTableName = table task.DBName = strings.TrimSpace(req.DBName) @@ -909,7 +913,7 @@ func (s *APIServer) handleMtreeBuild(w http.ResponseWriter, r *http.Request) { return } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = cluster task.QualifiedTableName = table task.DBName = strings.TrimSpace(req.DBName) @@ -989,7 +993,7 @@ func (s *APIServer) handleMtreeUpdate(w http.ResponseWriter, r *http.Request) { return } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = cluster task.QualifiedTableName = table task.DBName = strings.TrimSpace(req.DBName) @@ -1064,7 +1068,7 @@ func (s *APIServer) handleMtreeDiff(w http.ResponseWriter, r *http.Request) { return } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = cluster task.QualifiedTableName = table task.DBName = strings.TrimSpace(req.DBName) diff --git a/internal/server/server.go b/internal/api/http/server.go similarity index 100% rename from internal/server/server.go rename to internal/api/http/server.go diff --git a/internal/server/validator.go b/internal/api/http/validator.go similarity index 100% rename from internal/server/validator.go rename to internal/api/http/validator.go diff --git a/internal/server/validator_test.go b/internal/api/http/validator_test.go similarity index 100% rename from internal/server/validator_test.go rename 
to internal/api/http/validator_test.go diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 4c53e86..a820b22 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -26,10 +26,12 @@ import ( "github.com/charmbracelet/log" "github.com/google/uuid" - "github.com/pgedge/ace/internal/cdc" - "github.com/pgedge/ace/internal/core" - "github.com/pgedge/ace/internal/scheduler" - "github.com/pgedge/ace/internal/server" + "github.com/pgedge/ace/internal/api/http" + "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/internal/consistency/mtree" + "github.com/pgedge/ace/internal/consistency/repair" + "github.com/pgedge/ace/internal/infra/cdc" + "github.com/pgedge/ace/internal/jobs" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/logger" "github.com/urfave/cli/v2" @@ -132,6 +134,21 @@ func SetupCLI() *cli.App { Usage: "Where clause expression to use while diffing tables", Value: "", }, + &cli.StringFlag{ + Name: "only-origin", + Usage: "Restrict diff to rows whose node_origin matches this Spock node id or name", + Value: "", + }, + &cli.StringFlag{ + Name: "until", + Usage: "Optional commit timestamp upper bound (RFC3339) for rows to include", + Value: "", + }, + &cli.BoolFlag{ + Name: "ensure-pgcrypto", + Usage: "Ensure pgcrypto extension is installed on each node before diffing", + Value: false, + }, &cli.BoolFlag{ Name: "schedule", Aliases: []string{"S"}, @@ -155,6 +172,11 @@ func SetupCLI() *cli.App { Usage: "Path to the diff file (required)", Required: true, }, + &cli.StringFlag{ + Name: "repair-file", + Aliases: []string{"p"}, + Usage: "Path to the advanced repair file (YAML/JSON); skips source-of-truth requirement", + }, &cli.StringFlag{ Name: "source-of-truth", Aliases: []string{"r"}, @@ -196,6 +218,11 @@ func SetupCLI() *cli.App { Usage: "Whether to perform repairs in both directions. 
Can be used only with the insert-only option", Value: false, }, + &cli.BoolFlag{ + Name: "recovery-mode", + Usage: "Enable recovery-mode repair using origin-only diffs", + Value: false, + }, &cli.BoolFlag{ Name: "fix-nulls", Aliases: []string{"X"}, @@ -841,7 +868,7 @@ func TableDiffCLI(ctx *cli.Context) error { return fmt.Errorf("invalid block size '%s': %w", blockSizeStr, err) } - task := core.NewTableDiffTask() + task := diff.NewTableDiffTask() task.ClusterName = clusterName task.QualifiedTableName = positional[0] task.DBName = ctx.String("dbname") @@ -850,9 +877,12 @@ func TableDiffCLI(ctx *cli.Context) error { task.CompareUnitSize = ctx.Int("compare-unit-size") task.Output = strings.ToLower(ctx.String("output")) task.Nodes = ctx.String("nodes") + task.EnsurePgcrypto = ctx.Bool("ensure-pgcrypto") scheduleEnabled := ctx.Bool("schedule") scheduleEvery := ctx.String("every") task.TableFilter = ctx.String("table-filter") + task.OnlyOrigin = ctx.String("only-origin") + task.Until = ctx.String("until") task.QuietMode = ctx.Bool("quiet") task.OverrideBlockSize = ctx.Bool("override-block-size") task.Ctx = context.Background() @@ -907,7 +937,7 @@ func MtreeInitCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = clusterName task.DBName = ctx.String("dbname") task.Nodes = ctx.String("nodes") @@ -928,7 +958,7 @@ func MtreeListenCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = clusterName task.DBName = ctx.String("dbname") task.Nodes = ctx.String("nodes") @@ -976,7 +1006,7 @@ func MtreeTeardownCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = clusterName task.DBName = ctx.String("dbname") task.Nodes = ctx.String("nodes") @@ -997,7 +1027,7 @@ func MtreeTeardownTableCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = clusterName task.QualifiedTableName = positional[0] task.DBName = ctx.String("dbname") @@ -1025,7 +1055,7 @@ func MtreeBuildCLI(ctx *cli.Context) error { return fmt.Errorf("invalid block size '%s': %w", blockSizeStr, err) } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = clusterName task.QualifiedTableName = positional[0] task.DBName = ctx.String("dbname") @@ -1062,7 +1092,7 @@ func MtreeUpdateCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = clusterName task.QualifiedTableName = positional[0] task.DBName = ctx.String("dbname") @@ -1094,7 +1124,7 @@ func MtreeDiffCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = clusterName task.QualifiedTableName = positional[0] task.DBName = ctx.String("dbname") @@ -1127,7 +1157,7 @@ func TableRerunCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewTableDiffTask() + task := diff.NewTableDiffTask() task.TaskID = uuid.NewString() task.Mode = "rerun" task.ClusterName = clusterName @@ -1150,10 +1180,11 @@ func TableRepairCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewTableRepairTask() + task := repair.NewTableRepairTask() task.ClusterName = clusterName task.QualifiedTableName = 
positional[0] task.DiffFilePath = ctx.String("diff-file") + task.RepairPlanPath = ctx.String("repair-file") task.DBName = ctx.String("dbname") task.Nodes = ctx.String("nodes") task.SourceOfTruth = ctx.String("source-of-truth") @@ -1167,6 +1198,7 @@ func TableRepairCLI(ctx *cli.Context) error { task.FixNulls = ctx.Bool("fix-nulls") task.Bidirectional = ctx.Bool("bidirectional") task.GenerateReport = ctx.Bool("generate-report") + task.RecoveryMode = ctx.Bool("recovery-mode") if err := task.ValidateAndPrepare(); err != nil { return fmt.Errorf("validation failed: %w", err) @@ -1185,7 +1217,7 @@ func SpockDiffCLI(ctx *cli.Context) error { if err != nil { return err } - task := core.NewSpockDiffTask() + task := diff.NewSpockDiffTask() task.ClusterName = clusterName task.DBName = ctx.String("dbname") task.Nodes = ctx.String("nodes") @@ -1218,7 +1250,7 @@ func SchemaDiffCLI(ctx *cli.Context) error { return fmt.Errorf("invalid block size '%s': %w", blockSizeStr, err) } - task := core.NewSchemaDiffTask() + task := diff.NewSchemaDiffTask() task.ClusterName = clusterName task.SchemaName = positional[0] task.DBName = ctx.String("dbname") @@ -1300,7 +1332,7 @@ func RepsetDiffCLI(ctx *cli.Context) error { scheduleEnabled := ctx.Bool("schedule") scheduleEvery := ctx.String("every") - task := core.NewRepsetDiffTask() + task := diff.NewRepsetDiffTask() task.ClusterName = clusterName task.RepsetName = positional[0] task.DBName = ctx.String("dbname") @@ -1324,7 +1356,7 @@ func RepsetDiffCLI(ctx *cli.Context) error { if err := task.RunChecks(true); err != nil { return fmt.Errorf("checks failed: %w", err) } - if err := core.RepsetDiff(task); err != nil { + if err := diff.RepsetDiff(task); err != nil { return fmt.Errorf("error during repset diff: %w", err) } return nil @@ -1347,7 +1379,7 @@ func RepsetDiffCLI(ctx *cli.Context) error { if err := runTask.RunChecks(true); err != nil { return fmt.Errorf("checks failed: %w", err) } - if err := core.RepsetDiff(runTask); err != nil { + if err := diff.RepsetDiff(runTask); err != nil { return fmt.Errorf("execution failed: %w", err) } return nil diff --git a/internal/core/repset_diff.go b/internal/consistency/diff/repset_diff.go similarity index 99% rename from internal/core/repset_diff.go rename to internal/consistency/diff/repset_diff.go index d417b8a..eea2938 100644 --- a/internal/core/repset_diff.go +++ b/internal/consistency/diff/repset_diff.go @@ -9,7 +9,7 @@ // // /////////////////////////////////////////////////////////////////////////// -package core +package diff import ( "bufio" @@ -24,7 +24,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + "github.com/pgedge/ace/internal/infra/db" utils "github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/taskstore" diff --git a/internal/core/schema_diff.go b/internal/consistency/diff/schema_diff.go similarity index 99% rename from internal/core/schema_diff.go rename to internal/consistency/diff/schema_diff.go index f7943f7..b6c7982 100644 --- a/internal/core/schema_diff.go +++ b/internal/consistency/diff/schema_diff.go @@ -9,7 +9,7 @@ // // /////////////////////////////////////////////////////////////////////////// -package core +package diff import ( "context" @@ -23,7 +23,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + "github.com/pgedge/ace/internal/infra/db" utils 
"github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/taskstore" diff --git a/internal/core/spock_diff.go b/internal/consistency/diff/spock_diff.go similarity index 99% rename from internal/core/spock_diff.go rename to internal/consistency/diff/spock_diff.go index 6d365a8..cb918bb 100644 --- a/internal/core/spock_diff.go +++ b/internal/consistency/diff/spock_diff.go @@ -9,7 +9,7 @@ // // /////////////////////////////////////////////////////////////////////////// -package core +package diff import ( "context" @@ -25,7 +25,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + "github.com/pgedge/ace/internal/infra/db" utils "github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/taskstore" diff --git a/internal/core/table_diff.go b/internal/consistency/diff/table_diff.go similarity index 88% rename from internal/core/table_diff.go rename to internal/consistency/diff/table_diff.go index a52a889..4577c34 100644 --- a/internal/core/table_diff.go +++ b/internal/consistency/diff/table_diff.go @@ -9,13 +9,14 @@ // // /////////////////////////////////////////////////////////////////////////// -package core +package diff import ( "context" "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "maps" "math" @@ -34,7 +35,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + "github.com/pgedge/ace/internal/infra/db" utils "github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/logger" @@ -72,6 +73,11 @@ type TableDiffTask struct { TableFilter string QuietMode bool + OnlyOrigin string + Until string + + EffectiveFilter string + Mode string OverrideBlockSize bool @@ -80,6 +86,8 @@ type TableDiffTask struct { InvokeMethod string ClientRole string + EnsurePgcrypto bool + DiffSummary map[string]string SkipDBUpdate bool @@ -92,6 +100,8 @@ type TableDiffTask struct { blockHashSQLCache map[hashBoundsKey]string blockHashSQLMu sync.Mutex + SpockNodeNames map[string]string + CompareUnitSize int MaxDiffRows int64 @@ -104,6 +114,9 @@ type TableDiffTask struct { totalDiffRows atomic.Int64 diffLimitTriggered atomic.Bool + resolvedOnlyOrigin string + untilTime *time.Time + Ctx context.Context } @@ -164,6 +177,101 @@ func (t *TableDiffTask) incrementDiffRowsLocked(delta int) bool { return false } +func (t *TableDiffTask) loadSpockNodeNames() error { + if t.SpockNodeNames != nil { + return nil + } + + var firstPool *pgxpool.Pool + for _, pool := range t.Pools { + firstPool = pool + break + } + + if firstPool == nil { + t.SpockNodeNames = make(map[string]string) + return fmt.Errorf("no connection pool available to load spock node names") + } + + names, err := queries.GetSpockNodeNames(t.Ctx, firstPool) + if err != nil { + t.SpockNodeNames = make(map[string]string) + return err + } + + t.SpockNodeNames = names + return nil +} + +func (t *TableDiffTask) resolveOnlyOrigin() error { + if strings.TrimSpace(t.OnlyOrigin) == "" { + return nil + } + if len(t.SpockNodeNames) == 0 { + return fmt.Errorf("unable to resolve --only-origin: spock node names not available") + } + + orig := strings.TrimSpace(t.OnlyOrigin) + // direct match on id + if _, ok := t.SpockNodeNames[orig]; ok { + t.resolvedOnlyOrigin = orig + return nil + } + + // match on name + for id, name := range t.SpockNodeNames { + if name == orig { 
+ t.resolvedOnlyOrigin = id + return nil + } + } + + return fmt.Errorf("unable to resolve only-origin %q to a spock node id", t.OnlyOrigin) +} + +func (t *TableDiffTask) buildEffectiveFilter() (string, error) { + if t.untilTime == nil && strings.TrimSpace(t.Until) != "" { + parsed, err := time.Parse(time.RFC3339, strings.TrimSpace(t.Until)) + if err != nil { + return "", fmt.Errorf("invalid value for --until (expected RFC3339 timestamp): %w", err) + } + t.untilTime = &parsed + } + + var parts []string + trimmed := strings.TrimSpace(t.TableFilter) + if trimmed != "" { + parts = append(parts, fmt.Sprintf("(%s)", trimmed)) + } + + if t.resolvedOnlyOrigin != "" { + escaped := strings.ReplaceAll(t.resolvedOnlyOrigin, "'", "''") + parts = append(parts, fmt.Sprintf("(to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' = '%s')", escaped)) + } + + if t.untilTime != nil { + parts = append(parts, fmt.Sprintf("(pg_xact_commit_timestamp(xmin) <= '%s'::timestamptz)", t.untilTime.Format(time.RFC3339))) + } + + if len(parts) == 0 { + return "", nil + } + return strings.Join(parts, " AND "), nil +} + +func (t *TableDiffTask) withSpockMetadata(row map[string]any) map[string]any { + row["node_origin"] = utils.TranslateNodeOrigin(row["node_origin"], t.SpockNodeNames) + return utils.AddSpockMetadata(row) +} + +func isPgcryptoMissing(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "digest(") || strings.Contains(msg, "function digest") || strings.Contains(msg, "pgcrypto") +} + type RecursiveDiffTask struct { Node1Name string Node2Name string @@ -187,6 +295,63 @@ type hashBoundsKey struct { hasUpper bool } +func extractPlanRowEstimate(planJSON []byte) (int64, error) { + var plans []map[string]any + if err := json.Unmarshal(planJSON, &plans); err != nil { + return 0, fmt.Errorf("failed to parse EXPLAIN output: %w", err) + } + if len(plans) == 0 { + return 0, fmt.Errorf("EXPLAIN returned no plan") + } + + rootPlan, ok := plans[0]["Plan"].(map[string]any) + if !ok { + return 0, fmt.Errorf("unexpected EXPLAIN plan format") + } + if rowsVal, ok := rootPlan["Plan Rows"]; ok { + if rowsFloat, ok := rowsVal.(float64); ok { + return int64(rowsFloat), nil + } + } + return 0, fmt.Errorf("plan rows not found in EXPLAIN output") +} + +func (t *TableDiffTask) estimateRowCount(pool *pgxpool.Pool, nodeName string) (int64, error) { + schemaIdent := pgx.Identifier{t.Schema}.Sanitize() + tableIdent := pgx.Identifier{t.Table}.Sanitize() + + query := fmt.Sprintf("EXPLAIN (FORMAT JSON) SELECT 1 FROM %s.%s", schemaIdent, tableIdent) + if strings.TrimSpace(t.EffectiveFilter) != "" { + query = fmt.Sprintf("%s WHERE %s", query, t.EffectiveFilter) + } + + var planJSON []byte + if err := pool.QueryRow(t.Ctx, query).Scan(&planJSON); err != nil { + return 0, fmt.Errorf("failed to estimate row count on node %s: %w", nodeName, err) + } + + return extractPlanRowEstimate(planJSON) +} + +func (t *TableDiffTask) ensureFilterHasRows(pool *pgxpool.Pool, nodeName string) error { + if strings.TrimSpace(t.EffectiveFilter) == "" { + return nil + } + + schemaIdent := pgx.Identifier{t.Schema}.Sanitize() + tableIdent := pgx.Identifier{t.Table}.Sanitize() + sql := fmt.Sprintf("SELECT 1 FROM %s.%s WHERE %s LIMIT 1", schemaIdent, tableIdent, t.EffectiveFilter) + + var one int + if err := pool.QueryRow(t.Ctx, sql).Scan(&one); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("table filter produced no rows") + } + return fmt.Errorf("failed to validate table 
filter on node %s: %w", nodeName, err) + } + return nil +} + type RangeResults map[string]HashResult func (t *TableDiffTask) getBlockHashSQL(hasLower, hasUpper bool) (string, error) { @@ -203,7 +368,7 @@ func (t *TableDiffTask) getBlockHashSQL(hasLower, hasUpper bool) (string, error) return sql, nil } - query, err := queries.BlockHashSQL(t.Schema, t.Table, t.Key, "TD_BLOCK_HASH" /* mode */, hasLower, hasUpper) + query, err := queries.BlockHashSQL(t.Schema, t.Table, t.Key, "TD_BLOCK_HASH" /* mode */, hasLower, hasUpper, t.EffectiveFilter) if err != nil { return "", err } @@ -304,6 +469,10 @@ func (t *TableDiffTask) fetchRows(nodeName string, r Range) ([]types.OrderedMap, var conditions []string paramIndex := 1 + if strings.TrimSpace(t.EffectiveFilter) != "" { + conditions = append(conditions, fmt.Sprintf("(%s)", t.EffectiveFilter)) + } + if r.Start != nil { startVal := r.Start if len(t.Key) == 1 { @@ -435,6 +604,8 @@ func (t *TableDiffTask) fetchRows(nodeName string, r Range) ([]types.OrderedMap, } else { processedVal = nil } + case time.Time: + processedVal = v case string: processedVal = v case int8, int16, int32, int64, int, @@ -539,6 +710,14 @@ func (t *TableDiffTask) Validate() error { return fmt.Errorf("table-diff currently supports only json and html output formats") } + if trimmed := strings.TrimSpace(t.Until); trimmed != "" { + parsed, err := time.Parse(time.RFC3339, trimmed) + if err != nil { + return fmt.Errorf("invalid value for --until (expected RFC3339 timestamp): %w", err) + } + t.untilTime = &parsed + } + nodeList, err := utils.ParseNodes(t.Nodes) if err != nil { return fmt.Errorf("nodes should be a comma-separated list of nodenames. E.g., nodes=\"n1,n2\". Error: %w", err) @@ -637,12 +816,6 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) (err error) { if t.BaseTable == "" { t.BaseTable = table } - var filteredViewName string - if t.TableFilter != "" { - filteredViewName = buildFilteredViewName(t.TaskID, table) - t.FilteredViewName = filteredViewName - } - for _, nodeInfo := range t.ClusterNodes { hostname, _ := nodeInfo["Name"].(string) hostIP, _ := nodeInfo["PublicIP"].(string) @@ -716,30 +889,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) (err error) { hostMap[hostIP+":"+port] = hostname if t.TableFilter != "" { - sanitisedViewName := pgx.Identifier{filteredViewName}.Sanitize() - sanitisedSchema := pgx.Identifier{schema}.Sanitize() - sanitisedTable := pgx.Identifier{table}.Sanitize() - viewSQL := fmt.Sprintf("CREATE MATERIALIZED VIEW IF NOT EXISTS %s AS SELECT * FROM %s.%s WHERE %s", - sanitisedViewName, sanitisedSchema, sanitisedTable, t.TableFilter) - - _, err = conn.Exec(t.Ctx, viewSQL) - if err != nil { - return fmt.Errorf("failed to create filtered view: %w", err) - } - t.FilteredViewCreated = true - - hasRowsSQL := fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM %s) AS has_rows", sanitisedViewName) - var hasRows bool - err = conn.QueryRow(t.Ctx, hasRowsSQL).Scan(&hasRows) - if err != nil { - return fmt.Errorf("failed to check if view has rows: %w", err) - } - - if !hasRows { - return fmt.Errorf("table filter produced no rows") - } - - t.FilteredViewCreated = true + logger.Info("Applying table filter for diff: %s", t.TableFilter) } } @@ -789,10 +939,6 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) (err error) { return err } - if t.TableFilter != "" { - t.Table = filteredViewName - } - return nil } @@ -851,6 +997,9 @@ func (t *TableDiffTask) CloneForSchedule(ctx context.Context) *TableDiffTask { cloned.InvokeMethod = t.InvokeMethod 
cloned.CompareUnitSize = t.CompareUnitSize cloned.MaxDiffRows = t.MaxDiffRows + cloned.EnsurePgcrypto = t.EnsurePgcrypto + cloned.OnlyOrigin = t.OnlyOrigin + cloned.Until = t.Until cloned.Ctx = ctx return cloned } @@ -1058,35 +1207,41 @@ func (t *TableDiffTask) ExecuteTask() (err error) { } t.Pools = pools + if t.EnsurePgcrypto { + for name, pool := range t.Pools { + if err := queries.EnsurePgcrypto(t.Ctx, pool); err != nil { + return fmt.Errorf("failed to ensure pgcrypto on node %s: %w", name, err) + } + } + } + + if err := t.loadSpockNodeNames(); err != nil { + logger.Warn("table-diff: unable to load spock node names; using raw node_origin values: %v", err) + } + + if err := t.resolveOnlyOrigin(); err != nil { + return err + } + effectiveFilter, err := t.buildEffectiveFilter() + if err != nil { + return err + } + t.EffectiveFilter = effectiveFilter + t.DiffSummary["effective_filter"] = effectiveFilter + if _, err = t.getBlockHashSQL(true, true); err != nil { return fmt.Errorf("failed to build block-hash SQL: %w", err) } var maxCount int64 var maxNode string - var totalEstimatedRowsAcrossNodes int64 for name, pool := range pools { - var count int64 - // TODO: Estimates cannot be used on views. But we can't run a count(*) - // on millions of rows either. Need to find a better way to do this. - if t.TableFilter == "" { - count, err = queries.GetRowCountEstimate(t.Ctx, pool, t.Schema, t.Table) - if err != nil { - return fmt.Errorf("failed to render estimate row count query: %w", err) - } - } else { - sanitisedSchema := pgx.Identifier{t.Schema}.Sanitize() - sanitisedTable := pgx.Identifier{t.Table}.Sanitize() - countQuerySQL := fmt.Sprintf("SELECT COUNT(*) FROM %s.%s", sanitisedSchema, sanitisedTable) - logger.Debug("[%s] Executing count query for filtered table: %s", name, countQuerySQL) - err = pool.QueryRow(t.Ctx, countQuerySQL).Scan(&count) - if err != nil { - return fmt.Errorf("failed to get row count for %s.%s on node %s (query: %s): %w", t.Schema, t.Table, name, countQuerySQL, err) - } + count, err := t.estimateRowCount(pool, name) + if err != nil { + return fmt.Errorf("failed to estimate row count for %s on node %s: %w", t.QualifiedTableName, name, err) } - totalEstimatedRowsAcrossNodes += int64(count) logger.Debug("Table contains %d rows (estimated) on %s", count, name) if count > maxCount { maxCount = count @@ -1096,6 +1251,9 @@ func (t *TableDiffTask) ExecuteTask() (err error) { if maxNode == "" { return fmt.Errorf("unable to determine node with highest row count (or any row counts)") } + if err := t.ensureFilterHasRows(pools[maxNode], maxNode); err != nil { + return err + } t.DiffResult = types.DiffOutput{ NodeDiffs: make(map[string]types.DiffByNodePair), @@ -1103,6 +1261,7 @@ func (t *TableDiffTask) ExecuteTask() (err error) { Schema: t.Schema, Table: t.BaseTable, TableFilter: t.TableFilter, + EffectiveFilter: t.EffectiveFilter, Nodes: t.NodeList, BlockSize: t.BlockSize, CompareUnitSize: t.CompareUnitSize, @@ -1111,6 +1270,22 @@ func (t *TableDiffTask) ExecuteTask() (err error) { StartTime: startTime.Format(time.RFC3339), TotalRowsChecked: int64(maxCount), DiffRowsCount: make(map[string]int), + OnlyOrigin: t.resolvedOnlyOrigin, + OnlyOriginResolved: func() string { + if t.resolvedOnlyOrigin != "" && t.SpockNodeNames != nil { + if name, ok := t.SpockNodeNames[t.resolvedOnlyOrigin]; ok { + return name + } + } + return "" + }(), + Until: func() string { + if t.untilTime != nil { + return t.untilTime.Format(time.RFC3339) + } + return strings.TrimSpace(t.Until) + }(), + OriginOnly: 
t.resolvedOnlyOrigin != "", }, } @@ -1152,7 +1327,7 @@ func (t *TableDiffTask) ExecuteTask() (err error) { ntileCount = 1 } - querySQL, err := queries.GeneratePkeyOffsetsQuery(t.Schema, t.Table, t.Key, sampleMethod, samplePercent, ntileCount) + querySQL, err := queries.GeneratePkeyOffsetsQuery(t.Schema, t.Table, t.Key, sampleMethod, samplePercent, ntileCount, t.EffectiveFilter) logger.Debug("Generated offsets query: %s", querySQL) if err != nil { return fmt.Errorf("failed to generate offsets query: %w", err) @@ -1435,7 +1610,11 @@ func (t *TableDiffTask) hashRange( if err != nil { duration := time.Since(startTime) logger.Debug("[%s] ERROR after %v for range Start=%v, End=%v (using query: '%s', args: %v): %v", node, duration, r.Start, r.End, query, args, err) - return "", fmt.Errorf("BlockHash query failed for %s range %v-%v: %w", node, r.Start, r.End, err) + baseErr := fmt.Errorf("BlockHash query failed for %s range %v-%v: %w", node, r.Start, r.End, err) + if isPgcryptoMissing(err) { + return "", fmt.Errorf("%w; pgcrypto extension not installed. Re-run with --ensure-pgcrypto or install via CREATE EXTENSION pgcrypto", baseErr) + } + return "", baseErr } duration := time.Since(startTime) @@ -1709,7 +1888,7 @@ func (t *TableDiffTask) recursiveDiff( break } rowAsMap := utils.OrderedMapToMap(row) - rowWithMeta := utils.AddSpockMetadata(rowAsMap) + rowWithMeta := t.withSpockMetadata(rowAsMap) rowAsOrderedMap := utils.MapToOrderedMap(rowWithMeta, t.Cols) t.DiffResult.NodeDiffs[pairKey].Rows[node1Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node1Name], rowAsOrderedMap) currentDiffRowsForPair++ @@ -1726,7 +1905,7 @@ func (t *TableDiffTask) recursiveDiff( break } rowAsMap := utils.OrderedMapToMap(row) - rowWithMeta := utils.AddSpockMetadata(rowAsMap) + rowWithMeta := t.withSpockMetadata(rowAsMap) rowAsOrderedMap := utils.MapToOrderedMap(rowWithMeta, t.Cols) t.DiffResult.NodeDiffs[pairKey].Rows[node2Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node2Name], rowAsOrderedMap) currentDiffRowsForPair++ @@ -1744,12 +1923,12 @@ func (t *TableDiffTask) recursiveDiff( break } node1DataAsMap := utils.OrderedMapToMap(modRow.Node1Data) - node1DataWithMeta := utils.AddSpockMetadata(node1DataAsMap) + node1DataWithMeta := t.withSpockMetadata(node1DataAsMap) node1DataAsOrderedMap := utils.MapToOrderedMap(node1DataWithMeta, t.Cols) t.DiffResult.NodeDiffs[pairKey].Rows[node1Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node1Name], node1DataAsOrderedMap) node2DataAsMap := utils.OrderedMapToMap(modRow.Node2Data) - node2DataWithMeta := utils.AddSpockMetadata(node2DataAsMap) + node2DataWithMeta := t.withSpockMetadata(node2DataAsMap) node2DataAsOrderedMap := utils.MapToOrderedMap(node2DataWithMeta, t.Cols) t.DiffResult.NodeDiffs[pairKey].Rows[node2Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node2Name], node2DataAsOrderedMap) currentDiffRowsForPair++ diff --git a/internal/core/table_rerun.go b/internal/consistency/diff/table_rerun.go similarity index 97% rename from internal/core/table_rerun.go rename to internal/consistency/diff/table_rerun.go index 739e0c7..ffa936c 100644 --- a/internal/core/table_rerun.go +++ b/internal/consistency/diff/table_rerun.go @@ -9,7 +9,7 @@ // // /////////////////////////////////////////////////////////////////////////// -package core +package diff import ( "context" @@ -26,7 +26,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + 
"github.com/pgedge/ace/internal/infra/db" utils "github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" @@ -79,6 +79,10 @@ func (t *TableDiffTask) ExecuteRerunTask() error { } }() + if err := t.loadSpockNodeNames(); err != nil { + logger.Warn("table-diff rerun: unable to load spock node names; using raw node_origin values: %v", err) + } + // Collect all unique primary keys from the original diff report allPkeys, err := t.collectPkeysFromDiff() if err != nil { @@ -370,12 +374,12 @@ func (t *TableDiffTask) reCompareDiffs(fetchedRowsByNode map[string]map[string]t persistentDiffCount++ if nowOnNode1 { rowAsMap := utils.OrderedMapToMap(newRow1) - rowWithMeta := utils.AddSpockMetadata(rowAsMap) + rowWithMeta := t.withSpockMetadata(rowAsMap) newDiffsForPair.Rows[node1] = append(newDiffsForPair.Rows[node1], utils.MapToOrderedMap(rowWithMeta, t.Cols)) } if nowOnNode2 { rowAsMap := utils.OrderedMapToMap(newRow2) - rowWithMeta := utils.AddSpockMetadata(rowAsMap) + rowWithMeta := t.withSpockMetadata(rowAsMap) newDiffsForPair.Rows[node2] = append(newDiffsForPair.Rows[node2], utils.MapToOrderedMap(rowWithMeta, t.Cols)) } } diff --git a/internal/core/merkle.go b/internal/consistency/mtree/merkle.go similarity index 97% rename from internal/core/merkle.go rename to internal/consistency/mtree/merkle.go index ffb617f..3e58ac1 100644 --- a/internal/core/merkle.go +++ b/internal/consistency/mtree/merkle.go @@ -9,7 +9,7 @@ // // /////////////////////////////////////////////////////////////////////////// -package core +package mtree import ( "context" @@ -36,8 +36,8 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" - "github.com/pgedge/ace/internal/cdc" + "github.com/pgedge/ace/internal/infra/cdc" + "github.com/pgedge/ace/internal/infra/db" utils "github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/logger" @@ -88,6 +88,7 @@ type MerkleTreeTask struct { diffMutex sync.Mutex diffRowKeySets map[string]map[string]map[string]struct{} StartTime time.Time + SpockNodeNames map[string]string Ctx context.Context } @@ -489,6 +490,35 @@ func (m *MerkleTreeTask) processWorkItem(work CompareRangesWorkItem, pool1, pool return nil } +func (m *MerkleTreeTask) loadSpockNodeNames() error { + if m.SpockNodeNames != nil { + return nil + } + + var lastErr error + for _, nodeInfo := range m.ClusterNodes { + pool, err := auth.GetClusterNodeConnection(m.Ctx, nodeInfo, m.connOpts()) + if err != nil { + lastErr = err + continue + } + names, err := queries.GetSpockNodeNames(m.Ctx, pool) + pool.Close() + if err != nil { + lastErr = err + continue + } + m.SpockNodeNames = names + return nil + } + + m.SpockNodeNames = make(map[string]string) + if lastErr != nil { + return lastErr + } + return fmt.Errorf("no nodes available to load spock node names") +} + func (m *MerkleTreeTask) appendDiffs(nodePairKey string, work CompareRangesWorkItem, pr1, pr2 []types.OrderedMap) error { diffResult, err := utils.CompareRowSets(pr1, pr2, m.Key, m.Cols) if err != nil { @@ -557,6 +587,11 @@ func (m *MerkleTreeTask) addRowToDiff(nodePairKey, nodeName string, row types.Or m.diffRowKeySets = make(map[string]map[string]map[string]struct{}) } + rowMap := utils.OrderedMapToMap(row) + rowMap["node_origin"] = utils.TranslateNodeOrigin(rowMap["node_origin"], m.SpockNodeNames) + rowWithMeta := utils.AddSpockMetadata(rowMap) + orderedRow := utils.MapToOrderedMap(rowWithMeta, 
m.Cols) + pairSet, ok := m.diffRowKeySets[nodePairKey] if !ok { pairSet = make(map[string]map[string]struct{}) @@ -569,7 +604,7 @@ func (m *MerkleTreeTask) addRowToDiff(nodePairKey, nodeName string, row types.Or pairSet[nodeName] = nodeSet } - key, err := m.buildRowKey(row) + key, err := m.buildRowKey(orderedRow) if err != nil { return false, err } @@ -579,7 +614,7 @@ func (m *MerkleTreeTask) addRowToDiff(nodePairKey, nodeName string, row types.Or } nodeSet[key] = struct{}{} - m.DiffResult.NodeDiffs[nodePairKey].Rows[nodeName] = append(m.DiffResult.NodeDiffs[nodePairKey].Rows[nodeName], row) + m.DiffResult.NodeDiffs[nodePairKey].Rows[nodeName] = append(m.DiffResult.NodeDiffs[nodePairKey].Rows[nodeName], orderedRow) return true, nil } @@ -663,7 +698,8 @@ func buildFetchRowsSQLSimple(tableName, pk string, orderBy string, keys []any) ( args[i] = keys[i] } where := fmt.Sprintf("%s IN (%s)", pgx.Identifier{pk}.Sanitize(), strings.Join(placeholders, ",")) - q := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY %s", tableName, where, orderBy) + selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *" + q := fmt.Sprintf("SELECT %s FROM %s WHERE %s ORDER BY %s", selectCols, tableName, where, orderBy) return q, args } @@ -685,7 +721,8 @@ func buildFetchRowsSQLComposite(tableName string, pk []string, orderBy string, k tuples = append(tuples, fmt.Sprintf("(%s)", strings.Join(ph, ","))) } where := fmt.Sprintf("( %s ) IN ( %s )", strings.Join(tupleCols, ","), strings.Join(tuples, ",")) - q := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY %s", tableName, where, orderBy) + selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *" + q := fmt.Sprintf("SELECT %s FROM %s WHERE %s ORDER BY %s", selectCols, tableName, where, orderBy) return q, args } @@ -1130,7 +1167,7 @@ func (m *MerkleTreeTask) RunChecks(skipValidation bool) error { } defer pool.Close() - if _, err := pool.Exec(m.Ctx, "CREATE EXTENSION IF NOT EXISTS pgcrypto;"); err != nil { + if err := queries.EnsurePgcrypto(m.Ctx, pool); err != nil { return fmt.Errorf("failed to ensure pgcrypto is installed on %s: %w", nodeInfo["Name"], err) } tx, err := pool.Begin(m.Ctx) @@ -1261,7 +1298,7 @@ func (m *MerkleTreeTask) BuildMtree() (err error) { keyColumns := m.Key - offsetsQuery, err := queries.GeneratePkeyOffsetsQuery(m.Schema, m.Table, keyColumns, sampleMethod, samplePercent, numBlocks) + offsetsQuery, err := queries.GeneratePkeyOffsetsQuery(m.Schema, m.Table, keyColumns, sampleMethod, samplePercent, numBlocks, "") if err != nil { return fmt.Errorf("failed to generate pkey offsets query: %w", err) } @@ -1789,6 +1826,9 @@ func (m *MerkleTreeTask) DiffMtree() (err error) { if err = m.UpdateMtree(true); err != nil { return fmt.Errorf("failed to update merkle tree before diff: %w", err) } + if err := m.loadSpockNodeNames(); err != nil { + logger.Warn("mtree diff: unable to load spock node names; using raw node_origin values: %v", err) + } nodePairs := getNodePairs(m.ClusterNodes) mtreeTableIdentifier := pgx.Identifier{aceSchema(), fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table)} mtreeTableName := mtreeTableIdentifier.Sanitize() diff --git a/internal/consistency/repair/executor.go b/internal/consistency/repair/executor.go new file mode 100644 index 0000000..be42a75 --- /dev/null +++ b/internal/consistency/repair/executor.go @@ -0,0 +1,760 @@ +// 
/////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2025, pgEdge (https://www.pgedge.com/) +// +// This software is released under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package repair + +import ( + "fmt" + "reflect" + "strings" + "time" + + "github.com/pgedge/ace/internal/consistency/repair/plan" + utils "github.com/pgedge/ace/pkg/common" + "github.com/pgedge/ace/pkg/types" +) + +type planDiffRow struct { + pkStr string + pkMap map[string]any + node1Name string + node2Name string + n1Row map[string]any + n2Row map[string]any + n1Meta map[string]any + n2Meta map[string]any + diffType string // missing_on_n1, missing_on_n2, row_mismatch + columnsChanged []string // only set for mismatches +} + +func CalculatePlanRepairSets(task *TableRepairTask) (map[string]map[string]map[string]any, map[string]map[string]map[string]any, error) { + if task.RepairPlan == nil { + return nil, nil, fmt.Errorf("repair plan is nil") + } + + tableKey := fmt.Sprintf("%s.%s", task.Schema, task.Table) + tablePlan, ok := task.RepairPlan.Tables[tableKey] + if !ok { + return nil, nil, fmt.Errorf("repair plan does not contain table %s", tableKey) + } + + fullRowsToUpsert := make(map[string]map[string]map[string]any) // nodeName -> pkey -> rowData + fullRowsToDelete := make(map[string]map[string]map[string]any) // nodeName -> pkey -> rowData + task.planRuleMatches = make(map[string]map[string]string) + + for pairKey, diffs := range task.RawDiffs.NodeDiffs { + nodes := strings.Split(pairKey, "/") + if len(nodes) != 2 { + continue + } + node1Name, node2Name := nodes[0], nodes[1] + diffRows, err := buildPlanDiffRows(task, node1Name, node2Name, diffs) + if err != nil { + return nil, nil, err + } + + for _, d := range diffRows { + action, ruleName, err := resolvePlanAction(task, tablePlan, d) + if err != nil { + return nil, nil, err + } + if action == nil { + return nil, nil, fmt.Errorf("no action resolved for row %s", d.pkStr) + } + if err := ensureActionCompatible(d, action); err != nil { + return nil, nil, err + } + if err := applyPlanAction(d, action, ruleName, fullRowsToUpsert, fullRowsToDelete, task.planRuleMatches); err != nil { + return nil, nil, err + } + } + } + + return fullRowsToUpsert, fullRowsToDelete, nil +} + +func buildPlanDiffRows(task *TableRepairTask, node1Name, node2Name string, diffs types.DiffByNodePair) ([]planDiffRow, error) { + node1Rows := diffs.Rows[node1Name] + node2Rows := diffs.Rows[node2Name] + + node1RowsByPKey := make(map[string]map[string]any) + node1PkMaps := make(map[string]map[string]any) + node1MetaByPKey := make(map[string]map[string]any) + for _, row := range node1Rows { + pkeyStr, err := utils.StringifyOrderedMapKey(row, task.Key) + if err != nil { + return nil, fmt.Errorf("error stringifying pkey for row on %s: %w", node1Name, err) + } + raw := utils.OrderedMapToMap(row) + meta := extractSpockMeta(raw) + cleanRow := utils.StripSpockMetadata(raw) + node1RowsByPKey[pkeyStr] = cleanRow + node1MetaByPKey[pkeyStr] = meta + pkMap, err := extractPkMap(task.Key, cleanRow) + if err != nil { + return nil, fmt.Errorf("node %s row %s: %w", node1Name, pkeyStr, err) + } + node1PkMaps[pkeyStr] = pkMap + } + + node2RowsByPKey := make(map[string]map[string]any) + node2PkMaps := make(map[string]map[string]any) + node2MetaByPKey := make(map[string]map[string]any) + for _, row := range node2Rows 
{ + pkeyStr, err := utils.StringifyOrderedMapKey(row, task.Key) + if err != nil { + return nil, fmt.Errorf("error stringifying pkey for row on %s: %w", node2Name, err) + } + raw := utils.OrderedMapToMap(row) + meta := extractSpockMeta(raw) + cleanRow := utils.StripSpockMetadata(raw) + node2RowsByPKey[pkeyStr] = cleanRow + node2MetaByPKey[pkeyStr] = meta + pkMap, err := extractPkMap(task.Key, cleanRow) + if err != nil { + return nil, fmt.Errorf("node %s row %s: %w", node2Name, pkeyStr, err) + } + node2PkMaps[pkeyStr] = pkMap + } + + seen := make(map[string]struct{}) + var diffRows []planDiffRow + + for pkStr, row := range node1RowsByPKey { + diffRows = append(diffRows, planDiffRow{ + pkStr: pkStr, + pkMap: node1PkMaps[pkStr], + node1Name: node1Name, + node2Name: node2Name, + n1Row: row, + n2Row: node2RowsByPKey[pkStr], + n1Meta: node1MetaByPKey[pkStr], + n2Meta: node2MetaByPKey[pkStr], + diffType: func() string { + if _, ok := node2RowsByPKey[pkStr]; ok { + return "row_mismatch" + } + return "missing_on_n2" + }(), + columnsChanged: computeChangedColumns(row, node2RowsByPKey[pkStr]), + }) + seen[pkStr] = struct{}{} + } + + for pkStr, row := range node2RowsByPKey { + if _, ok := seen[pkStr]; ok { + continue + } + diffRows = append(diffRows, planDiffRow{ + pkStr: pkStr, + pkMap: node2PkMaps[pkStr], + node1Name: node1Name, + node2Name: node2Name, + n1Row: node1RowsByPKey[pkStr], + n2Row: row, + n1Meta: node1MetaByPKey[pkStr], + n2Meta: node2MetaByPKey[pkStr], + diffType: "missing_on_n1", + columnsChanged: computeChangedColumns(node1RowsByPKey[pkStr], row), + }) + } + + return diffRows, nil +} + +func extractPkMap(keys []string, row map[string]any) (map[string]any, error) { + pk := make(map[string]any) + for _, k := range keys { + val, ok := row[k] + if !ok { + return nil, fmt.Errorf("missing primary key column %s", k) + } + pk[k] = val + } + return pk, nil +} + +func computeChangedColumns(n1, n2 map[string]any) []string { + if n1 == nil || n2 == nil { + return nil + } + colSet := make(map[string]struct{}) + for k := range n1 { + colSet[k] = struct{}{} + } + for k := range n2 { + colSet[k] = struct{}{} + } + + var changed []string + for col := range colSet { + if !reflect.DeepEqual(n1[col], n2[col]) { + changed = append(changed, col) + } + } + return changed +} + +func resolvePlanAction(task *TableRepairTask, tablePlan planner.RepairTablePlan, row planDiffRow) (*planner.RepairPlanAction, string, error) { + // Row overrides + for idx, override := range tablePlan.RowOverrides { + if pkMatchesOverride(task.Key, row.pkMap, override.PK) { + name := override.Name + if strings.TrimSpace(name) == "" { + name = fmt.Sprintf("row_override_%d", idx) + } + if err := validateActionCompatibility(&override.Action, row); err != nil { + return nil, "", err + } + return &override.Action, name, nil + } + } + + // Rules in order + for idx, rule := range tablePlan.Rules { + if !matchPKIn(rule.PKIn, row.pkMap, task.Key, task.SimplePrimaryKey) { + continue + } + if len(rule.DiffTypes) > 0 && !containsString(rule.DiffTypes, row.diffType) { + continue + } + if len(rule.ColumnsChanged) > 0 && !columnsIntersect(rule.ColumnsChanged, row.columnsChanged) { + continue + } + if rule.CompiledWhen() != nil { + ok, err := rule.CompiledWhen().Eval(func(source, column string) (any, bool) { + return lookupValue(source, column, row) + }) + if err != nil { + return nil, "", err + } + if !ok { + continue + } + } + if err := validateActionCompatibility(&rule.Action, row); err != nil { + return nil, "", err + } + name := rule.Name + if 
strings.TrimSpace(name) == "" { + name = fmt.Sprintf("rule_%d", idx) + } + return &rule.Action, name, nil + } + + if tablePlan.DefaultAction != nil { + if err := validateActionCompatibility(tablePlan.DefaultAction, row); err != nil { + return nil, "", err + } + return tablePlan.DefaultAction, "table_default", nil + } + if task.RepairPlan.DefaultAction != nil { + if err := validateActionCompatibility(task.RepairPlan.DefaultAction, row); err != nil { + return nil, "", err + } + return task.RepairPlan.DefaultAction, "global_default", nil + } + + return nil, "", fmt.Errorf("no default action configured in repair plan") +} + +func pkMatchesOverride(pkOrder []string, rowPk map[string]any, overridePk map[string]any) bool { + if len(overridePk) != len(pkOrder) { + return false + } + for _, col := range pkOrder { + if !reflect.DeepEqual(rowPk[col], overridePk[col]) { + return false + } + } + return true +} + +func matchPKIn(matchers []planner.RepairPKMatcher, rowPk map[string]any, pkOrder []string, simple bool) bool { + if len(matchers) == 0 { + return true + } + + for _, m := range matchers { + if simple { + val := rowPk[pkOrder[0]] + for _, eq := range m.Equals { + if reflect.DeepEqual(eq, val) { + return true + } + } + if m.Range != nil { + if compareRange(val, m.Range.From, m.Range.To) { + return true + } + } + } else { + // Composite PK: support equals with tuple order + for _, eq := range m.Equals { + tuple, ok := eq.([]any) + if !ok || len(tuple) != len(pkOrder) { + continue + } + all := true + for i, col := range pkOrder { + if !reflect.DeepEqual(rowPk[col], tuple[i]) { + all = false + break + } + } + if all { + return true + } + } + } + } + return false +} + +func compareRange(val, from, to any) bool { + ln, lok := asFloat(val) + fn, fok := asFloat(from) + tn, tok := asFloat(to) + if lok && fok && tok { + return ln >= fn && ln <= tn + } + + ls, lsok := val.(string) + fs, fsok := from.(string) + ts, tsok := to.(string) + if lsok && fsok && tsok { + return ls >= fs && ls <= ts + } + return false +} + +func asFloat(v any) (float64, bool) { + switch n := v.(type) { + case int: + return float64(n), true + case int8: + return float64(n), true + case int16: + return float64(n), true + case int32: + return float64(n), true + case int64: + return float64(n), true + case uint: + return float64(n), true + case uint8: + return float64(n), true + case uint16: + return float64(n), true + case uint32: + return float64(n), true + case uint64: + return float64(n), true + case float32: + return float64(n), true + case float64: + return n, true + default: + return 0, false + } +} + +func columnsIntersect(a, b []string) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + set := make(map[string]struct{}) + for _, col := range b { + set[col] = struct{}{} + } + for _, col := range a { + if _, ok := set[col]; ok { + return true + } + } + return false +} + +func containsString(list []string, target string) bool { + for _, v := range list { + if v == target { + return true + } + } + return false +} + +func applyPlanAction(row planDiffRow, action *planner.RepairPlanAction, ruleName string, upserts, deletes map[string]map[string]map[string]any, matches map[string]map[string]string) error { + switch action.Type { + case planner.RepairActionKeepN1: + if row.n1Row == nil { + return fmt.Errorf("keep_n1 specified but row missing on %s (pk %s)", row.node1Name, row.pkStr) + } + addRow(upserts, row.node2Name, row.pkStr, copyMap(row.n1Row)) + recordMatch(matches, row.node2Name, row.pkStr, ruleName) + case 
planner.RepairActionKeepN2: + if row.n2Row == nil { + return fmt.Errorf("keep_n2 specified but row missing on %s (pk %s)", row.node2Name, row.pkStr) + } + addRow(upserts, row.node1Name, row.pkStr, copyMap(row.n2Row)) + recordMatch(matches, row.node1Name, row.pkStr, ruleName) + case planner.RepairActionApplyFrom: + from := strings.ToLower(strings.TrimSpace(action.From)) + mode := action.Mode + if mode == "" { + mode = planner.RepairApplyModeReplace + } + switch from { + case "n1": + if row.n1Row == nil { + return fmt.Errorf("apply_from n1 specified but row missing on %s (pk %s)", row.node1Name, row.pkStr) + } + if mode == planner.RepairApplyModeInsert && row.n2Row != nil { + return nil + } + addRow(upserts, row.node2Name, row.pkStr, copyMap(row.n1Row)) + recordMatch(matches, row.node2Name, row.pkStr, ruleName) + case "n2": + if row.n2Row == nil { + return fmt.Errorf("apply_from n2 specified but row missing on %s (pk %s)", row.node2Name, row.pkStr) + } + if mode == planner.RepairApplyModeInsert && row.n1Row != nil { + return nil + } + addRow(upserts, row.node1Name, row.pkStr, copyMap(row.n2Row)) + recordMatch(matches, row.node1Name, row.pkStr, ruleName) + default: + return fmt.Errorf("apply_from requires from to be n1 or n2 (pk %s)", row.pkStr) + } + case planner.RepairActionBidirectional: + if row.n1Row != nil { + addRow(upserts, row.node2Name, row.pkStr, copyMap(row.n1Row)) + recordMatch(matches, row.node2Name, row.pkStr, ruleName) + } + if row.n2Row != nil { + addRow(upserts, row.node1Name, row.pkStr, copyMap(row.n2Row)) + recordMatch(matches, row.node1Name, row.pkStr, ruleName) + } + case planner.RepairActionCustom: + customRow, err := buildCustomRow(action, row) + if err != nil { + return err + } + addRow(upserts, row.node1Name, row.pkStr, copyMap(customRow)) + addRow(upserts, row.node2Name, row.pkStr, copyMap(customRow)) + recordMatch(matches, row.node1Name, row.pkStr, ruleName) + recordMatch(matches, row.node2Name, row.pkStr, ruleName) + case planner.RepairActionSkip: + return nil + case planner.RepairActionDelete: + if row.n1Row != nil { + addRow(deletes, row.node1Name, row.pkStr, row.n1Row) + recordMatch(matches, row.node1Name, row.pkStr, ruleName) + } + if row.n2Row != nil { + addRow(deletes, row.node2Name, row.pkStr, row.n2Row) + recordMatch(matches, row.node2Name, row.pkStr, ruleName) + } + default: + return fmt.Errorf("action %s not supported in execution", action.Type) + } + return nil +} + +func addRow(target map[string]map[string]map[string]any, node, pk string, row map[string]any) { + if target[node] == nil { + target[node] = make(map[string]map[string]any) + } + target[node][pk] = row +} + +func recordMatch(matches map[string]map[string]string, node, pk, rule string) { + if matches == nil { + return + } + if matches[node] == nil { + matches[node] = make(map[string]string) + } + matches[node][pk] = rule +} + +func copyMap(src map[string]any) map[string]any { + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func extractSpockMeta(row map[string]any) map[string]any { + meta := make(map[string]any) + if rawMeta, ok := row["_spock_metadata_"].(map[string]any); ok { + for k, v := range rawMeta { + meta[k] = v + } + } + if val, ok := row["commit_ts"]; ok { + meta["commit_ts"] = val + } + if val, ok := row["node_origin"]; ok { + meta["node_origin"] = val + } + return meta +} + +func lookupValue(source, column string, row planDiffRow) (any, bool) { + switch source { + case "n1": + if row.n1Row != nil { + if v, ok := 
row.n1Row[column]; ok { + return v, true + } + } + if row.n1Meta != nil { + if v, ok := row.n1Meta[column]; ok { + return v, true + } + } + case "n2": + if row.n2Row != nil { + if v, ok := row.n2Row[column]; ok { + return v, true + } + } + if row.n2Meta != nil { + if v, ok := row.n2Meta[column]; ok { + return v, true + } + } + } + return nil, false +} + +func lookupTemplate(inner string, row planDiffRow) (any, bool) { + parts := strings.Split(inner, ".") + if len(parts) != 2 { + return nil, false + } + source := strings.ToLower(strings.TrimSpace(parts[0])) + col := strings.TrimSpace(parts[1]) + return lookupValue(source, col, row) +} + +func validateActionCompatibility(action *planner.RepairPlanAction, row planDiffRow) error { + switch action.Type { + case planner.RepairActionKeepN1: + if row.n1Row == nil { + return fmt.Errorf("keep_n1 used on row missing from n1 (pk %s)", row.pkStr) + } + case planner.RepairActionKeepN2: + if row.n2Row == nil { + return fmt.Errorf("keep_n2 used on row missing from n2 (pk %s)", row.pkStr) + } + case planner.RepairActionApplyFrom: + mode := action.Mode + if mode == planner.RepairApplyModeInsert && row.diffType == "row_mismatch" { + return fmt.Errorf("apply_from insert mode cannot be used on mismatched rows (pk %s)", row.pkStr) + } + if strings.TrimSpace(action.From) == "n1" && row.n1Row == nil { + return fmt.Errorf("apply_from n1 used on row missing from n1 (pk %s)", row.pkStr) + } + if strings.TrimSpace(action.From) == "n2" && row.n2Row == nil { + return fmt.Errorf("apply_from n2 used on row missing from n2 (pk %s)", row.pkStr) + } + case planner.RepairActionCustom: + // allow; buildCustomRow will enforce availability + case planner.RepairActionBidirectional, planner.RepairActionSkip, planner.RepairActionDelete: + return nil + default: + return fmt.Errorf("unsupported action type %s", action.Type) + } + return nil +} + +func buildCustomRow(action *planner.RepairPlanAction, row planDiffRow) (map[string]any, error) { + if action.CustomRow == nil && action.CustomHelpers == nil { + return nil, fmt.Errorf("custom action missing custom_row or helpers (pk %s)", row.pkStr) + } + + result := make(map[string]any) + // Start from explicit custom_row with template substitutions. + for k, v := range action.CustomRow { + result[k] = substituteValue(v, row) + } + + // Coalesce helper fills missing columns from priority order. + if action.CustomHelpers != nil && len(action.CustomHelpers.CoalescePriority) > 0 { + cols := unionKeys(row.n1Row, row.n2Row) + for col := range cols { + if _, exists := result[col]; exists { + continue + } + for _, src := range action.CustomHelpers.CoalescePriority { + switch src { + case "n1": + if row.n1Row != nil { + if val, ok := row.n1Row[col]; ok && val != nil { + result[col] = val + break + } + } + case "n2": + if row.n2Row != nil { + if val, ok := row.n2Row[col]; ok && val != nil { + result[col] = val + break + } + } + } + } + } + } + + // pick_freshest helper backfills remaining columns from the chosen side. + if action.CustomHelpers != nil && action.CustomHelpers.PickFreshest != nil { + pick := action.CustomHelpers.PickFreshest + winner := freshestSide(pick.Key, pick.Tie, row) + var source map[string]any + switch winner { + case "n1": + source = row.n1Row + case "n2": + source = row.n2Row + } + if source != nil { + for col, val := range source { + if _, exists := result[col]; !exists { + result[col] = val + } + } + } + } + + // Ensure PK columns are present using existing PK map. 
+ for k, v := range row.pkMap { + if _, exists := result[k]; !exists { + result[k] = v + } + } + + return result, nil +} + +func substituteValue(val any, row planDiffRow) any { + s, ok := val.(string) + if !ok { + return val + } + str := strings.TrimSpace(s) + if strings.HasPrefix(str, "{{") && strings.HasSuffix(str, "}}") { + inner := strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(str, "{{"), "}}")) + if val, ok := lookupTemplate(inner, row); ok { + return val + } + } + return val +} + +func unionKeys(maps ...map[string]any) map[string]struct{} { + out := make(map[string]struct{}) + for _, m := range maps { + for k := range m { + out[k] = struct{}{} + } + } + return out +} + +func freshestSide(key string, tie string, row planDiffRow) string { + val1 := row.n1Row + val2 := row.n2Row + var k1, k2 any + if val1 != nil { + k1 = val1[key] + } + if val2 != nil { + k2 = val2[key] + } + n1Ok := k1 != nil + n2Ok := k2 != nil + if !n1Ok && n2Ok { + return "n2" + } + if n1Ok && !n2Ok { + return "n1" + } + if !n1Ok && !n2Ok { + return strings.TrimSpace(strings.ToLower(tie)) + } + + if f1, ok := asFloat(k1); ok { + if f2, ok2 := asFloat(k2); ok2 { + if f1 > f2 { + return "n1" + } else if f2 > f1 { + return "n2" + } + return strings.TrimSpace(strings.ToLower(tie)) + } + } + + if s1, ok := k1.(string); ok { + if s2, ok2 := k2.(string); ok2 { + if s1 > s2 { + return "n1" + } else if s2 > s1 { + return "n2" + } + return strings.TrimSpace(strings.ToLower(tie)) + } + } + + // Try parsing timestamps if values are strings + if s1, ok := k1.(string); ok { + if s2, ok2 := k2.(string); ok2 { + if t1, err1 := time.Parse(time.RFC3339, s1); err1 == nil { + if t2, err2 := time.Parse(time.RFC3339, s2); err2 == nil { + if t1.After(t2) { + return "n1" + } else if t2.After(t1) { + return "n2" + } + return strings.TrimSpace(strings.ToLower(tie)) + } + } + } + } + + return strings.TrimSpace(strings.ToLower(tie)) +} + +func ensureActionCompatible(row planDiffRow, action *planner.RepairPlanAction) error { + switch row.diffType { + case "missing_on_n2": + if action.Type == planner.RepairActionKeepN2 || (action.Type == planner.RepairActionApplyFrom && strings.TrimSpace(strings.ToLower(action.From)) == "n2") { + return fmt.Errorf("action %s requires n2 row but diff is missing_on_n2 (pk %s)", action.Type, row.pkStr) + } + case "missing_on_n1": + if action.Type == planner.RepairActionKeepN1 || (action.Type == planner.RepairActionApplyFrom && strings.TrimSpace(strings.ToLower(action.From)) == "n1") { + return fmt.Errorf("action %s requires n1 row but diff is missing_on_n1 (pk %s)", action.Type, row.pkStr) + } + } + + if action.Type == planner.RepairActionApplyFrom && action.Mode == planner.RepairApplyModeInsert && row.diffType == "row_mismatch" { + return fmt.Errorf("apply_from insert mode cannot be used on mismatched rows (pk %s)", row.pkStr) + } + + if action.Type == planner.RepairActionCustom && action.CustomRow == nil && action.CustomHelpers == nil { + return fmt.Errorf("custom action missing custom_row/helpers (pk %s)", row.pkStr) + } + + return nil +} diff --git a/internal/consistency/repair/plan/parser/parser.go b/internal/consistency/repair/plan/parser/parser.go new file mode 100644 index 0000000..399c8a0 --- /dev/null +++ b/internal/consistency/repair/plan/parser/parser.go @@ -0,0 +1,702 @@ +// /////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2025, pgEdge (https://www.pgedge.com/) +// +// This software is released 
under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package parser + +import ( + "fmt" + "strconv" + "strings" + "unicode" +) + +// WhenExpr represents a compiled predicate over diff row values (n1./n2.). +type WhenExpr struct { + root exprNode +} + +// ValueProvider returns the value for a given source ("n1" or "n2") and column. +type ValueProvider func(source, column string) (any, bool) + +// Eval evaluates the predicate against the provided values. +func (w *WhenExpr) Eval(provider ValueProvider) (bool, error) { + if w == nil || w.root == nil { + return true, nil + } + val, err := w.root.eval(provider) + if err != nil { + return false, err + } + b, ok := val.(bool) + if !ok { + return false, fmt.Errorf("when expression did not produce a boolean result") + } + return b, nil +} + +// CompileWhenExpression parses a restricted predicate language: +// - Identifiers: n1.col, n2.col +// - Literals: strings ('x'), numbers, booleans, NULL +// - Operators: = != < <= > >=, IN (...), IS [NOT] NULL, AND, OR, NOT, parentheses. +func CompileWhenExpression(expr string) (*WhenExpr, error) { + lex := newLexer(expr) + tokens, err := lex.lex() + if err != nil { + return nil, err + } + p := &parser{tokens: tokens} + root, err := p.parseExpression() + if err != nil { + return nil, err + } + return &WhenExpr{root: root}, nil +} + +type tokenType int + +const ( + tokEOF tokenType = iota + tokIdent + tokString + tokNumber + tokBool + tokNull + tokAnd + tokOr + tokNot + tokIn + tokIs + tokLParen + tokRParen + tokComma + tokEq + tokNeq + tokLt + tokLte + tokGt + tokGte +) + +type token struct { + typ tokenType + lit string + pos int + value any +} + +type lexer struct { + input string + pos int +} + +func newLexer(input string) *lexer { + return &lexer{input: input} +} + +func (l *lexer) lex() ([]token, error) { + var toks []token + for { + l.skipSpace() + if l.pos >= len(l.input) { + toks = append(toks, token{typ: tokEOF, pos: l.pos}) + return toks, nil + } + ch := l.input[l.pos] + switch ch { + case '(': + toks = append(toks, token{typ: tokLParen, lit: "(", pos: l.pos}) + l.pos++ + case ')': + toks = append(toks, token{typ: tokRParen, lit: ")", pos: l.pos}) + l.pos++ + case ',': + toks = append(toks, token{typ: tokComma, lit: ",", pos: l.pos}) + l.pos++ + case '=': + toks = append(toks, token{typ: tokEq, lit: "=", pos: l.pos}) + l.pos++ + case '!': + if l.peek("=") { + toks = append(toks, token{typ: tokNeq, lit: "!=", pos: l.pos}) + l.pos += 2 + } else { + return nil, fmt.Errorf("unexpected '!' 
at pos %d", l.pos) + } + case '<': + if l.peek("=") { + toks = append(toks, token{typ: tokLte, lit: "<=", pos: l.pos}) + l.pos += 2 + } else { + toks = append(toks, token{typ: tokLt, lit: "<", pos: l.pos}) + l.pos++ + } + case '>': + if l.peek("=") { + toks = append(toks, token{typ: tokGte, lit: ">=", pos: l.pos}) + l.pos += 2 + } else { + toks = append(toks, token{typ: tokGt, lit: ">", pos: l.pos}) + l.pos++ + } + case '\'': + tok, err := l.scanString() + if err != nil { + return nil, err + } + toks = append(toks, tok) + default: + if unicode.IsDigit(rune(ch)) || ch == '-' { + tok, err := l.scanNumber() + if err != nil { + return nil, err + } + toks = append(toks, tok) + continue + } + if unicode.IsLetter(rune(ch)) || ch == '_' { + tok, err := l.scanIdent() + if err != nil { + return nil, err + } + toks = append(toks, tok) + continue + } + return nil, fmt.Errorf("unexpected character '%c' at pos %d", ch, l.pos) + } + } +} + +func (l *lexer) peek(next string) bool { + return strings.HasPrefix(l.input[l.pos:], next) +} + +func (l *lexer) skipSpace() { + for l.pos < len(l.input) && unicode.IsSpace(rune(l.input[l.pos])) { + l.pos++ + } +} + +func (l *lexer) scanString() (token, error) { + start := l.pos + l.pos++ // skip opening quote + var sb strings.Builder + for l.pos < len(l.input) { + ch := l.input[l.pos] + if ch == '\'' { + if l.pos+1 < len(l.input) && l.input[l.pos+1] == '\'' { + sb.WriteByte('\'') + l.pos += 2 + continue + } + l.pos++ + return token{typ: tokString, lit: l.input[start:l.pos], pos: start, value: sb.String()}, nil + } + sb.WriteByte(ch) + l.pos++ + } + return token{}, fmt.Errorf("unterminated string starting at pos %d", start) +} + +func (l *lexer) scanNumber() (token, error) { + start := l.pos + hasDot := false + if l.input[l.pos] == '-' { + l.pos++ + } + for l.pos < len(l.input) { + ch := l.input[l.pos] + if ch == '.' { + if hasDot { + break + } + hasDot = true + l.pos++ + continue + } + if !unicode.IsDigit(rune(ch)) { + break + } + l.pos++ + } + text := l.input[start:l.pos] + num, err := strconv.ParseFloat(text, 64) + if err != nil { + return token{}, fmt.Errorf("invalid number %q at pos %d", text, start) + } + return token{typ: tokNumber, lit: text, pos: start, value: num}, nil +} + +func (l *lexer) scanIdent() (token, error) { + start := l.pos + for l.pos < len(l.input) { + ch := l.input[l.pos] + if unicode.IsLetter(rune(ch)) || unicode.IsDigit(rune(ch)) || ch == '_' || ch == '.' 
{ + l.pos++ + continue + } + break + } + text := l.input[start:l.pos] + low := strings.ToLower(text) + switch low { + case "and": + return token{typ: tokAnd, lit: text, pos: start}, nil + case "or": + return token{typ: tokOr, lit: text, pos: start}, nil + case "not": + return token{typ: tokNot, lit: text, pos: start}, nil + case "in": + return token{typ: tokIn, lit: text, pos: start}, nil + case "is": + return token{typ: tokIs, lit: text, pos: start}, nil + case "true", "false": + val := low == "true" + return token{typ: tokBool, lit: text, pos: start, value: val}, nil + case "null": + return token{typ: tokNull, lit: text, pos: start, value: nil}, nil + default: + return token{typ: tokIdent, lit: text, pos: start}, nil + } +} + +type parser struct { + tokens []token + pos int +} + +func (p *parser) current() token { + if p.pos >= len(p.tokens) { + return token{typ: tokEOF, pos: p.pos} + } + return p.tokens[p.pos] +} + +func (p *parser) advance() { + if p.pos < len(p.tokens) { + p.pos++ + } +} + +func (p *parser) expect(tt tokenType) (token, error) { + tok := p.current() + if tok.typ != tt { + return tok, fmt.Errorf("unexpected token %q at pos %d, expected %v", tok.lit, tok.pos, tt) + } + p.advance() + return tok, nil +} + +func (p *parser) parseExpression() (exprNode, error) { + node, err := p.parseOr() + if err != nil { + return nil, err + } + if p.current().typ != tokEOF { + return nil, fmt.Errorf("unexpected trailing token %q at pos %d", p.current().lit, p.current().pos) + } + return node, nil +} + +func (p *parser) parseOr() (exprNode, error) { + left, err := p.parseAnd() + if err != nil { + return nil, err + } + for p.current().typ == tokOr { + p.advance() + right, err := p.parseAnd() + if err != nil { + return nil, err + } + left = &logicalNode{op: tokOr, left: left, right: right} + } + return left, nil +} + +func (p *parser) parseAnd() (exprNode, error) { + left, err := p.parseUnary() + if err != nil { + return nil, err + } + for p.current().typ == tokAnd { + p.advance() + right, err := p.parseUnary() + if err != nil { + return nil, err + } + left = &logicalNode{op: tokAnd, left: left, right: right} + } + return left, nil +} + +func (p *parser) parseUnary() (exprNode, error) { + if p.current().typ == tokNot { + p.advance() + child, err := p.parseUnary() + if err != nil { + return nil, err + } + return ¬Node{child: child}, nil + } + return p.parsePrimary() +} + +func (p *parser) parsePrimary() (exprNode, error) { + if p.current().typ == tokLParen { + p.advance() + node, err := p.parseOr() + if err != nil { + return nil, err + } + if _, err := p.expect(tokRParen); err != nil { + return nil, err + } + return node, nil + } + return p.parseComparison() +} + +func (p *parser) parseComparison() (exprNode, error) { + left, err := p.parseValue() + if err != nil { + return nil, err + } + + tok := p.current() + switch tok.typ { + case tokEq, tokNeq, tokLt, tokLte, tokGt, tokGte: + p.advance() + right, err := p.parseValue() + if err != nil { + return nil, err + } + return &comparisonNode{op: tok.typ, left: left, right: right}, nil + case tokIn: + p.advance() + if _, err := p.expect(tokLParen); err != nil { + return nil, err + } + var values []exprNode + for { + val, err := p.parseValue() + if err != nil { + return nil, err + } + values = append(values, val) + if p.current().typ == tokComma { + p.advance() + continue + } + if p.current().typ == tokRParen { + p.advance() + break + } + return nil, fmt.Errorf("expected ',' or ')' at pos %d", p.current().pos) + } + return &inNode{value: left, set: 
values}, nil + case tokIs: + p.advance() + negate := false + if p.current().typ == tokNot { + negate = true + p.advance() + } + if _, err := p.expect(tokNull); err != nil { + return nil, err + } + return &isNullNode{value: left, negate: negate}, nil + default: + return nil, fmt.Errorf("expected comparison operator after value at pos %d", tok.pos) + } +} + +func (p *parser) parseValue() (exprNode, error) { + tok := p.current() + switch tok.typ { + case tokIdent: + p.advance() + source, col, err := splitIdent(tok.lit) + if err != nil { + return nil, fmt.Errorf("invalid identifier %q at pos %d: %w", tok.lit, tok.pos, err) + } + return &identNode{source: source, column: col}, nil + case tokString, tokNumber, tokBool, tokNull: + p.advance() + return &literalNode{value: tok.value}, nil + case tokLParen: + // allow nested parentheses via primary path + return p.parsePrimary() + default: + return nil, fmt.Errorf("unexpected token %q at pos %d", tok.lit, tok.pos) + } +} + +func splitIdent(lit string) (string, string, error) { + parts := strings.Split(lit, ".") + if len(parts) != 2 { + return "", "", fmt.Errorf("identifier must be n1. or n2.") + } + source := strings.ToLower(strings.TrimSpace(parts[0])) + if source != "n1" && source != "n2" { + return "", "", fmt.Errorf("identifier prefix must be n1 or n2") + } + col := strings.TrimSpace(parts[1]) + if col == "" { + return "", "", fmt.Errorf("column name cannot be empty") + } + return source, col, nil +} + +// AST nodes and evaluation +type exprNode interface { + eval(ValueProvider) (any, error) +} + +type identNode struct { + source string + column string +} + +func (n *identNode) eval(provider ValueProvider) (any, error) { + if provider == nil { + return nil, fmt.Errorf("no value provider available") + } + val, _ := provider(n.source, n.column) + return val, nil +} + +type literalNode struct { + value any +} + +func (n *literalNode) eval(_ ValueProvider) (any, error) { + return n.value, nil +} + +type logicalNode struct { + op tokenType + left exprNode + right exprNode +} + +func (n *logicalNode) eval(provider ValueProvider) (any, error) { + lv, err := n.left.eval(provider) + if err != nil { + return nil, err + } + lb, ok := lv.(bool) + if !ok { + return nil, fmt.Errorf("left side of logical op is not boolean") + } + if n.op == tokAnd && !lb { + return false, nil + } + if n.op == tokOr && lb { + return true, nil + } + + rv, err := n.right.eval(provider) + if err != nil { + return nil, err + } + rb, ok := rv.(bool) + if !ok { + return nil, fmt.Errorf("right side of logical op is not boolean") + } + + switch n.op { + case tokAnd: + return lb && rb, nil + case tokOr: + return lb || rb, nil + default: + return nil, fmt.Errorf("unknown logical operator") + } +} + +type notNode struct { + child exprNode +} + +func (n *notNode) eval(provider ValueProvider) (any, error) { + val, err := n.child.eval(provider) + if err != nil { + return nil, err + } + b, ok := val.(bool) + if !ok { + return nil, fmt.Errorf("NOT operand is not boolean") + } + return !b, nil +} + +type comparisonNode struct { + op tokenType + left exprNode + right exprNode +} + +func (n *comparisonNode) eval(provider ValueProvider) (any, error) { + lv, err := n.left.eval(provider) + if err != nil { + return nil, err + } + rv, err := n.right.eval(provider) + if err != nil { + return nil, err + } + return compareValues(n.op, lv, rv) +} + +type inNode struct { + value exprNode + set []exprNode +} + +func (n *inNode) eval(provider ValueProvider) (any, error) { + val, err := n.value.eval(provider) 
+ if err != nil { + return nil, err + } + for _, el := range n.set { + ev, err := el.eval(provider) + if err != nil { + return nil, err + } + ok, err := compareValues(tokEq, val, ev) + if err != nil { + return nil, err + } + if ok { + return true, nil + } + } + return false, nil +} + +type isNullNode struct { + value exprNode + negate bool +} + +func (n *isNullNode) eval(provider ValueProvider) (any, error) { + val, err := n.value.eval(provider) + if err != nil { + return nil, err + } + isNull := val == nil + if n.negate { + return !isNull, nil + } + return isNull, nil +} + +func compareValues(op tokenType, left any, right any) (bool, error) { + // nil handling + if left == nil || right == nil { + switch op { + case tokEq: + return left == nil && right == nil, nil + case tokNeq: + return !(left == nil && right == nil), nil + default: + return false, fmt.Errorf("cannot compare nil with operator %v", op) + } + } + + ln, lok := asNumber(left) + rn, rok := asNumber(right) + if lok && rok { + switch op { + case tokEq: + return ln == rn, nil + case tokNeq: + return ln != rn, nil + case tokLt: + return ln < rn, nil + case tokLte: + return ln <= rn, nil + case tokGt: + return ln > rn, nil + case tokGte: + return ln >= rn, nil + default: + return false, fmt.Errorf("unsupported numeric operator %v", op) + } + } + + ls, lsok := left.(string) + rs, rsok := right.(string) + if lsok && rsok { + switch op { + case tokEq: + return ls == rs, nil + case tokNeq: + return ls != rs, nil + case tokLt: + return ls < rs, nil + case tokLte: + return ls <= rs, nil + case tokGt: + return ls > rs, nil + case tokGte: + return ls >= rs, nil + default: + return false, fmt.Errorf("unsupported string operator %v", op) + } + } + + lb, lbok := left.(bool) + rb, rbok := right.(bool) + if lbok && rbok { + switch op { + case tokEq: + return lb == rb, nil + case tokNeq: + return lb != rb, nil + default: + return false, fmt.Errorf("unsupported boolean operator %v", op) + } + } + + return false, fmt.Errorf("cannot compare values of different or unsupported types (%T vs %T)", left, right) +} + +func asNumber(v any) (float64, bool) { + switch n := v.(type) { + case int: + return float64(n), true + case int8: + return float64(n), true + case int16: + return float64(n), true + case int32: + return float64(n), true + case int64: + return float64(n), true + case uint: + return float64(n), true + case uint8: + return float64(n), true + case uint16: + return float64(n), true + case uint32: + return float64(n), true + case uint64: + return float64(n), true + case float32: + return float64(n), true + case float64: + return n, true + default: + return 0, false + } +} diff --git a/internal/consistency/repair/plan/planner.go b/internal/consistency/repair/plan/planner.go new file mode 100644 index 0000000..655a4b4 --- /dev/null +++ b/internal/consistency/repair/plan/planner.go @@ -0,0 +1,351 @@ +// /////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2025, pgEdge (https://www.pgedge.com/) +// +// This software is released under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package planner + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/pgedge/ace/internal/consistency/repair/plan/parser" + "gopkg.in/yaml.v3" +) + +// Advanced repair-file schema notes: +// - Versioned YAML/JSON document describing fallbacks, ordered 
rules, and explicit row overrides. +// - Precedence: row_overrides > first matching rule > table.default_action > global default_action. +// - Selectors: pk_in (values or ranges for simple PKs), diff_type, columns_changed, when (restricted expr over n1./n2.). +// - Actions: keep_n1/keep_n2, apply_from {from: n1|n2, mode: replace|upsert|insert}, bidirectional, custom {custom_row or helpers}, +// skip, delete (optional). +// - Rules must declare at least one selector to avoid match-all accidents; tables can inherit the global default. +// +// This scaffolding parses and validates the file; execution wiring comes next. +const RepairPlanSchemaVersion = 1 + +type RepairActionType string + +const ( + RepairActionKeepN1 RepairActionType = "keep_n1" + RepairActionKeepN2 RepairActionType = "keep_n2" + RepairActionApplyFrom RepairActionType = "apply_from" + RepairActionBidirectional RepairActionType = "bidirectional" + RepairActionCustom RepairActionType = "custom" + RepairActionSkip RepairActionType = "skip" + RepairActionDelete RepairActionType = "delete" +) + +type RepairApplyMode string + +const ( + RepairApplyModeReplace RepairApplyMode = "replace" + RepairApplyModeUpsert RepairApplyMode = "upsert" + RepairApplyModeInsert RepairApplyMode = "insert" +) + +// RepairPlanFile represents a versioned repair-file describing defaults, rules, and per-row overrides. +type RepairPlanFile struct { + Version int `json:"version" yaml:"version"` + DefaultAction *RepairPlanAction `json:"default_action,omitempty" yaml:"default_action,omitempty"` + Tables map[string]RepairTablePlan `json:"tables,omitempty" yaml:"tables,omitempty"` // key: qualified table name + Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty"` + Notes string `json:"notes,omitempty" yaml:"notes,omitempty"` +} + +type RepairTablePlan struct { + DefaultAction *RepairPlanAction `json:"default_action,omitempty" yaml:"default_action,omitempty"` + Rules []RepairRule `json:"rules,omitempty" yaml:"rules,omitempty"` + RowOverrides []RepairRowOverride `json:"row_overrides,omitempty" yaml:"row_overrides,omitempty"` +} + +// RepairRule is evaluated in order; the first matching rule wins. +type RepairRule struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + PKIn []RepairPKMatcher `json:"pk_in,omitempty" yaml:"pk_in,omitempty"` + DiffTypes []string `json:"diff_type,omitempty" yaml:"diff_type,omitempty"` + ColumnsChanged []string `json:"columns_changed,omitempty" yaml:"columns_changed,omitempty"` + When string `json:"when,omitempty" yaml:"when,omitempty"` // restricted expression over n1./n2. values + Action RepairPlanAction `json:"action" yaml:"action"` + + compiledWhen *parser.WhenExpr `json:"-" yaml:"-"` +} + +// RepairRowOverride matches an exact PK (map for composite keys) and wins before any rule. +type RepairRowOverride struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + PK map[string]any `json:"pk" yaml:"pk"` + Action RepairPlanAction `json:"action" yaml:"action"` +} + +// RepairPKMatcher supports exact PKs and ranges (range only for simple PKs). +type RepairPKMatcher struct { + Equals []any `json:"equals,omitempty" yaml:"equals,omitempty"` // []any for simple PK, []any slice for composite PK values + Range *RepairPKRange `json:"range,omitempty" yaml:"range,omitempty"` +} + +type RepairPKRange struct { + From any `json:"from" yaml:"from"` + To any `json:"to" yaml:"to"` +} + +// RepairPlanAction captures the small, memorable action set. 
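// For illustration, a few action shapes as they would appear in a plan file (column names and values here are assumed examples, not part of this patch):
//
//   action: {type: keep_n1}
//   action: {type: apply_from, from: n2, mode: upsert}
//   action: {type: custom, custom_row: {status: "resolved", amount: "{{ n1.amount }}"}}
//   action: {type: custom, helpers: {coalesce_priority: [n1, n2], pick_freshest: {key: updated_at, tie: n1}}}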
+type RepairPlanAction struct { + Type RepairActionType `json:"type" yaml:"type"` + From string `json:"from,omitempty" yaml:"from,omitempty"` // n1 or n2 for apply_from + Mode RepairApplyMode `json:"mode,omitempty" yaml:"mode,omitempty"` // replace|upsert|insert for apply_from + CustomRow map[string]any `json:"custom_row,omitempty" yaml:"custom_row,omitempty"` + CustomHelpers *CustomHelperSpec `json:"helpers,omitempty" yaml:"helpers,omitempty"` +} + +type CustomHelperSpec struct { + CoalescePriority []string `json:"coalesce_priority,omitempty" yaml:"coalesce_priority,omitempty"` // e.g., ["n1","n2"] + PickFreshest *PickFreshestSpec `json:"pick_freshest,omitempty" yaml:"pick_freshest,omitempty"` +} + +type PickFreshestSpec struct { + Key string `json:"key" yaml:"key"` // column name used for freshness comparison + Tie string `json:"tie,omitempty" yaml:"tie"` // n1 or n2 +} + +// LoadRepairPlanFile parses YAML or JSON into a RepairPlanFile and performs lightweight validation. +func LoadRepairPlanFile(planPath string) (*RepairPlanFile, error) { + raw, err := os.ReadFile(planPath) + if err != nil { + return nil, fmt.Errorf("read repair plan %s: %w", planPath, err) + } + + plan := &RepairPlanFile{} + + // Try YAML first (superset), then JSON for clearer errors when input is strictly JSON. + yamlErr := yaml.Unmarshal(raw, plan) + if yamlErr != nil { + if jsonErr := json.Unmarshal(raw, plan); jsonErr != nil { + return nil, fmt.Errorf("parse repair plan %s as yaml (%v) or json (%v)", planPath, yamlErr, jsonErr) + } + } + + if plan.Version == 0 { + plan.Version = RepairPlanSchemaVersion + } + + if err := plan.Validate(); err != nil { + return nil, fmt.Errorf("invalid repair plan %s: %w", planPath, err) + } + + return plan, nil +} + +// Validate performs static checks. It intentionally keeps defaults permissive; execution-time checks will enforce column existence. 
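// For illustration only, a minimal plan that these checks accept (the table, column, and key values are assumed examples):
//
//   version: 1
//   tables:
//     public.customers:
//       default_action: {type: keep_n1}
//       rules:
//         - name: backfill_missing
//           diff_type: [missing_on_n2]
//           action: {type: apply_from, from: n1, mode: insert}
//       row_overrides:
//         - pk: {id: 42}
//           action: {type: skip}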
+func (plan *RepairPlanFile) Validate() error { + if plan == nil { + return fmt.Errorf("repair plan is nil") + } + if plan.Version <= 0 { + return fmt.Errorf("version must be greater than zero") + } + if plan.DefaultAction != nil { + if err := validateRepairAction(plan.DefaultAction, "global", "default_action"); err != nil { + return err + } + } + if len(plan.Tables) == 0 { + return fmt.Errorf("at least one table entry is required") + } + + validDiffTypes := map[string]struct{}{ + "row_mismatch": {}, + "missing_on_n1": {}, + "missing_on_n2": {}, + "deleted_on_n1": {}, + "deleted_on_n2": {}, + } + + for tableName, tablePlan := range plan.Tables { + if tablePlan.DefaultAction != nil { + if err := validateRepairAction(tablePlan.DefaultAction, tableName, "table.default_action"); err != nil { + return err + } + } + + for i := range tablePlan.RowOverrides { + override := &tablePlan.RowOverrides[i] + if len(override.PK) == 0 { + return fmt.Errorf("table %s row_override[%d] must provide pk map", tableName, i) + } + if err := validateRepairAction(&override.Action, tableName, fmt.Sprintf("row_override[%d]", i)); err != nil { + return err + } + } + + for i := range tablePlan.Rules { + rule := &tablePlan.Rules[i] + if !ruleHasSelector(rule) { + return fmt.Errorf("table %s rule[%d]%s must specify at least one selector (pk_in, diff_type, columns_changed, when)", + tableName, i, ruleLabel(rule)) + } + if err := validatePKMatchers(rule.PKIn, tableName, ruleLabel(rule)); err != nil { + return err + } + if err := validateDiffTypes(rule.DiffTypes, validDiffTypes, tableName, ruleLabel(rule)); err != nil { + return err + } + if err := validateRepairAction(&rule.Action, tableName, fmt.Sprintf("rule%s", ruleLabel(rule))); err != nil { + return err + } + if len(rule.DiffTypes) > 0 { + if err := validateActionCompatibility(&rule.Action, rule.DiffTypes, tableName, ruleLabel(rule)); err != nil { + return err + } + } + if strings.TrimSpace(rule.When) != "" { + expr, err := parser.CompileWhenExpression(rule.When) + if err != nil { + return fmt.Errorf("table %s rule%s: invalid when expression: %w", tableName, ruleLabel(rule), err) + } + rule.compiledWhen = expr + } + } + } + + return nil +} + +func ruleLabel(rule *RepairRule) string { + if rule == nil || strings.TrimSpace(rule.Name) == "" { + return "" + } + return fmt.Sprintf(" (%s)", strings.TrimSpace(rule.Name)) +} + +func ruleHasSelector(rule *RepairRule) bool { + if rule == nil { + return false + } + return len(rule.PKIn) > 0 || + len(rule.DiffTypes) > 0 || + len(rule.ColumnsChanged) > 0 || + strings.TrimSpace(rule.When) != "" +} + +// CompiledWhen returns the parsed predicate, if present. 
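// For illustration, a rule such as
//
//   when: "n1.updated_at > n2.updated_at AND n2.status IN ('stale', 'error')"
//
// (column names assumed) is compiled via parser.CompileWhenExpression and evaluated per diff row,
// mirroring how the repair executor wires it up:
//
//   ok, err := rule.CompiledWhen().Eval(func(source, column string) (any, bool) {
//       // return the value of `column` from the n1 or n2 side of the current diff row
//       return lookupValue(source, column, row)
//   })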
+func (r *RepairRule) CompiledWhen() *parser.WhenExpr { + return r.compiledWhen +} + +func validatePKMatchers(matchers []RepairPKMatcher, tableName, label string) error { + for i := range matchers { + matcher := matchers[i] + if len(matcher.Equals) == 0 && matcher.Range == nil { + return fmt.Errorf("table %s rule%s pk_in[%d] must specify equals or range", tableName, label, i) + } + if matcher.Range != nil { + if matcher.Range.From == nil || matcher.Range.To == nil { + return fmt.Errorf("table %s rule%s pk_in[%d] range requires both from and to", tableName, label, i) + } + } + } + return nil +} + +func validateDiffTypes(diffTypes []string, allowed map[string]struct{}, tableName, label string) error { + for i, dt := range diffTypes { + dt = strings.TrimSpace(dt) + if dt == "" { + return fmt.Errorf("table %s rule%s diff_type[%d] cannot be empty", tableName, label, i) + } + if _, ok := allowed[dt]; !ok { + return fmt.Errorf("table %s rule%s diff_type[%d] must be one of %v", tableName, label, i, keys(allowed)) + } + } + return nil +} + +func keys(m map[string]struct{}) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} + +func validateRepairAction(action *RepairPlanAction, tableName, location string) error { + if action == nil { + return fmt.Errorf("table %s %s action cannot be nil", tableName, location) + } + + switch action.Type { + case RepairActionKeepN1, RepairActionKeepN2, RepairActionBidirectional, RepairActionSkip, RepairActionDelete: + return nil + case RepairActionApplyFrom: + if action.From == "" { + return fmt.Errorf("table %s %s: apply_from requires 'from' (n1 or n2)", tableName, location) + } + if action.Mode != "" && + action.Mode != RepairApplyModeReplace && + action.Mode != RepairApplyModeUpsert && + action.Mode != RepairApplyModeInsert { + return fmt.Errorf("table %s %s: apply_from mode must be one of replace|upsert|insert", tableName, location) + } + return nil + case RepairActionCustom: + if len(action.CustomRow) == 0 && action.CustomHelpers == nil { + return fmt.Errorf("table %s %s: custom requires custom_row or helpers", tableName, location) + } + if action.CustomHelpers != nil { + if len(action.CustomHelpers.CoalescePriority) > 0 { + for _, src := range action.CustomHelpers.CoalescePriority { + if src != "n1" && src != "n2" { + return fmt.Errorf("table %s %s: coalesce_priority entries must be n1 or n2", tableName, location) + } + } + } + if action.CustomHelpers.PickFreshest != nil { + if strings.TrimSpace(action.CustomHelpers.PickFreshest.Key) == "" { + return fmt.Errorf("table %s %s: pick_freshest requires key", tableName, location) + } + if tie := strings.TrimSpace(action.CustomHelpers.PickFreshest.Tie); tie != "" && tie != "n1" && tie != "n2" { + return fmt.Errorf("table %s %s: pick_freshest.tie must be n1 or n2", tableName, location) + } + } + } + return nil + default: + return fmt.Errorf("table %s %s: unsupported action type %q", tableName, location, action.Type) + } +} + +func validateActionCompatibility(action *RepairPlanAction, diffTypes []string, tableName, label string) error { + for _, dt := range diffTypes { + normalized := strings.TrimSpace(strings.ToLower(dt)) + switch normalized { + case "missing_on_n1", "deleted_on_n1": + if action.Type == RepairActionKeepN1 { + return fmt.Errorf("table %s rule%s: keep_n1 incompatible with diff_type %s", tableName, label, dt) + } + if action.Type == RepairActionApplyFrom && strings.TrimSpace(strings.ToLower(action.From)) == "n1" { + return fmt.Errorf("table %s 
rule%s: apply_from n1 incompatible with diff_type %s", tableName, label, dt) + } + case "missing_on_n2", "deleted_on_n2": + if action.Type == RepairActionKeepN2 { + return fmt.Errorf("table %s rule%s: keep_n2 incompatible with diff_type %s", tableName, label, dt) + } + if action.Type == RepairActionApplyFrom && strings.TrimSpace(strings.ToLower(action.From)) == "n2" { + return fmt.Errorf("table %s rule%s: apply_from n2 incompatible with diff_type %s", tableName, label, dt) + } + case "row_mismatch": + if action.Type == RepairActionApplyFrom && action.Mode == RepairApplyModeInsert { + return fmt.Errorf("table %s rule%s: apply_from insert incompatible with diff_type %s", tableName, label, dt) + } + } + } + return nil +} diff --git a/internal/core/table_repair.go b/internal/consistency/repair/table_repair.go similarity index 88% rename from internal/core/table_repair.go rename to internal/consistency/repair/table_repair.go index e566f94..2037971 100644 --- a/internal/core/table_repair.go +++ b/internal/consistency/repair/table_repair.go @@ -9,7 +9,7 @@ // // /////////////////////////////////////////////////////////////////////////// -package core +package repair import ( "context" @@ -25,10 +25,12 @@ import ( "time" "github.com/google/uuid" + "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + "github.com/pgedge/ace/internal/consistency/repair/plan" + "github.com/pgedge/ace/internal/infra/db" utils "github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/taskstore" @@ -56,6 +58,9 @@ type TableRepairTask struct { DiffFilePath string SourceOfTruth string + RepairPlanPath string + RepairPlan *planner.RepairPlanFile + QuietMode bool DryRun bool InsertOnly bool @@ -64,6 +69,7 @@ type TableRepairTask struct { GenerateReport bool FixNulls bool // TBD Bidirectional bool + RecoveryMode bool InvokeMethod string // TBD ClientRole string // TBD @@ -78,6 +84,12 @@ type TableRepairTask struct { RawDiffs types.DiffOutput report *RepairReport + planRuleMatches map[string]map[string]string // populated when using repair plans + + autoSelectedSourceOfTruth string + autoSelectionFailedNode string + autoSelectionDetails map[string]map[string]string + Ctx context.Context } @@ -149,6 +161,8 @@ func (tr *TableRepairTask) checkRepairOptionsCompatibility() error { {tr.FixNulls && tr.InsertOnly, "insert_only and fix_nulls cannot be used together"}, {tr.FixNulls && tr.UpsertOnly, "upsert_only and fix_nulls cannot be used together"}, {tr.InsertOnly && tr.UpsertOnly, "insert_only and upsert_only cannot be used together"}, + {strings.TrimSpace(tr.RepairPlanPath) != "" && tr.FixNulls, "repair-file and fix_nulls cannot be used together"}, + {strings.TrimSpace(tr.RepairPlanPath) != "" && tr.Bidirectional, "repair-file and bidirectional cannot be used together"}, } for _, rule := range incompatibleOptions { @@ -160,6 +174,14 @@ func (tr *TableRepairTask) checkRepairOptionsCompatibility() error { } func (tr *TableRepairTask) checkIfSourceOfTruthIsNeeded() bool { + // Advanced repair plans can encode SOT choices per rule, so skip mandatory SoT when a plan is supplied. 
+ if strings.TrimSpace(tr.RepairPlanPath) != "" || tr.RepairPlan != nil { + return false + } + if tr.RecoveryMode { + // in recovery mode we'll auto-select if missing + return false + } casesNotNeeded := []bool{ tr.FixNulls, tr.Bidirectional && tr.InsertOnly, @@ -267,6 +289,15 @@ func (t *TableRepairTask) ValidateAndPrepare() error { if t.RawDiffs.Summary.TableFilter != "" { logger.Info("Diff file was generated with table filter: %s", t.RawDiffs.Summary.TableFilter) } + if strings.TrimSpace(t.RawDiffs.Summary.OnlyOrigin) != "" && !t.RecoveryMode { + return fmt.Errorf("diff file indicates origin-only comparison; re-run table-repair with --recovery-mode or provide an explicit source_of_truth") + } + + if strings.TrimSpace(t.RepairPlanPath) != "" { + if err := t.loadRepairPlan(strings.TrimSpace(t.RepairPlanPath)); err != nil { + return err + } + } if t.RawDiffs.NodeDiffs == nil { return fmt.Errorf("invalid diff file format: missing 'diffs' field or it's not a map") @@ -298,6 +329,25 @@ func (t *TableRepairTask) ValidateAndPrepare() error { t.ClusterNodes = clusterNodes + if strings.TrimSpace(t.RawDiffs.Summary.OnlyOrigin) != "" && t.RecoveryMode && t.SourceOfTruth == "" { + failedNode := strings.TrimSpace(t.RawDiffs.Summary.OnlyOriginResolved) + if failedNode == "" { + failedNode = strings.TrimSpace(t.RawDiffs.Summary.OnlyOrigin) + } + if failedNode == "" { + return fmt.Errorf("recovery-mode requires failed node information in diff summary") + } + selected, details, err := t.autoSelectSourceOfTruth(failedNode, involvedNodeNames) + if err != nil { + return err + } + t.autoSelectedSourceOfTruth = selected + t.autoSelectionFailedNode = failedNode + t.autoSelectionDetails = details + t.SourceOfTruth = selected + logger.Info("table-repair: recovery-mode selected %s as source_of_truth (failed node: %s)", selected, failedNode) + } + // Repair needs these privileges. Perhaps we can pare this down depending // on the repair options, but for now we'll keep it as is. 
requiredPrivileges := types.UserPrivileges{ @@ -414,6 +464,21 @@ func (t *TableRepairTask) ValidateAndPrepare() error { return nil } +func (t *TableRepairTask) loadRepairPlan(planPath string) error { + plan, err := planner.LoadRepairPlanFile(planPath) + if err != nil { + return fmt.Errorf("load repair plan: %w", err) + } + + tableKey := fmt.Sprintf("%s.%s", t.Schema, t.Table) + if _, ok := plan.Tables[tableKey]; !ok { + return fmt.Errorf("repair plan %s does not include table %s", planPath, tableKey) + } + + t.RepairPlan = plan + return nil +} + func (t *TableRepairTask) initialiseReport() *RepairReport { report := &RepairReport{ Changes: make(map[string]any), @@ -430,18 +495,28 @@ func (t *TableRepairTask) initialiseReport() *RepairReport { report.Timestamp = now.Format("2006-01-02 15:04:05") + fmt.Sprintf(".%03d", now.Nanosecond()/1e6) report.SuppliedArgs = map[string]any{ - "cluster_name": t.ClusterName, - "diff_file_path": t.DiffFilePath, - "source_of_truth": t.SourceOfTruth, - "table_name": t.QualifiedTableName, - "dbname": t.DBName, - "dry_run": t.DryRun, - "quiet": t.QuietMode, - "insert_only": t.InsertOnly, - "upsert_only": t.UpsertOnly, - "fire_triggers": t.FireTriggers, - "generate_report": t.GenerateReport, - "bidirectional": t.Bidirectional, + "cluster_name": t.ClusterName, + "diff_file_path": t.DiffFilePath, + "repair_plan_path": t.RepairPlanPath, + "source_of_truth": t.SourceOfTruth, + "table_name": t.QualifiedTableName, + "dbname": t.DBName, + "dry_run": t.DryRun, + "quiet": t.QuietMode, + "insert_only": t.InsertOnly, + "upsert_only": t.UpsertOnly, + "fire_triggers": t.FireTriggers, + "generate_report": t.GenerateReport, + "bidirectional": t.Bidirectional, + "recovery_mode": t.RecoveryMode, + } + + if t.autoSelectedSourceOfTruth != "" { + report.Changes["auto_source_of_truth"] = map[string]any{ + "selected": t.autoSelectedSourceOfTruth, + "failed_node": t.autoSelectionFailedNode, + "lsn_probe": t.autoSelectionDetails, + } } dbInfoForReport := t.Database @@ -516,15 +591,16 @@ func (t *TableRepairTask) Run(skipValidation bool) (err error) { } ctx := map[string]any{ - "qualified_table": t.QualifiedTableName, - "diff_file": t.DiffFilePath, - "source_of_truth": t.SourceOfTruth, - "dry_run": t.DryRun, - "insert_only": t.InsertOnly, - "upsert_only": t.UpsertOnly, - "fire_triggers": t.FireTriggers, - "bidirectional": t.Bidirectional, - "generate_report": t.GenerateReport, + "qualified_table": t.QualifiedTableName, + "diff_file": t.DiffFilePath, + "repair_plan_path": t.RepairPlanPath, + "source_of_truth": t.SourceOfTruth, + "dry_run": t.DryRun, + "insert_only": t.InsertOnly, + "upsert_only": t.UpsertOnly, + "fire_triggers": t.FireTriggers, + "bidirectional": t.Bidirectional, + "generate_report": t.GenerateReport, } record := taskstore.Record{ @@ -1194,6 +1270,8 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { * most users are on spock 4.0 or above. */ + skipDeletes := (t.UpsertOnly || t.InsertOnly) && t.RepairPlan == nil + for nodeName := range divergentNodes { logger.Info("Processing repairs for divergent node: %s", nodeName) divergentPool, ok := t.Pools[nodeName] @@ -1265,7 +1343,7 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { // TODO: DROP PRIVILEGES HERE! 
// Process deletes first - if !t.UpsertOnly && !t.InsertOnly { + if !skipDeletes { nodeDeletes := fullDeletes[nodeName] if len(nodeDeletes) > 0 { deletedCount, err := executeDeletes(t.Ctx, tx, t, nodeDeletes) @@ -1287,6 +1365,9 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { rows = append(rows, row) } t.report.Changes[nodeName].(map[string]any)["deleted_rows"] = rows + if t.RepairPlan != nil && len(t.planRuleMatches[nodeName]) > 0 { + t.report.Changes[nodeName].(map[string]any)["rule_matches"] = t.planRuleMatches[nodeName] + } } } } @@ -1335,11 +1416,10 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { for _, row := range nodeUpserts { rows = append(rows, row) } - changeType := "upserted_rows" - if t.InsertOnly { - changeType = "inserted_rows" + t.report.Changes[nodeName].(map[string]any)["upserted_rows"] = rows + if t.RepairPlan != nil && len(t.planRuleMatches[nodeName]) > 0 { + t.report.Changes[nodeName].(map[string]any)["rule_matches"] = t.planRuleMatches[nodeName] } - t.report.Changes[nodeName].(map[string]any)[changeType] = rows } } @@ -1956,7 +2036,11 @@ func getDryRunOutput(task *TableRepairTask) (string, error) { } if len(fullUpserts) == 0 && len(fullDeletes) == 0 { - sb.WriteString(" All nodes are in sync with the source of truth. No repairs needed.\n") + if task.RepairPlan != nil { + sb.WriteString(" All nodes are in sync according to the repair plan. No repairs needed.\n") + } else { + sb.WriteString(" All nodes are in sync with the source of truth. No repairs needed.\n") + } } else { // To ensure a consistent output order for nodes var nodeNames []string @@ -1994,16 +2078,21 @@ func getDryRunOutput(task *TableRepairTask) (string, error) { for _, row := range deletes { rows = append(rows, row) } - if !task.UpsertOnly && !task.InsertOnly { + if task.RepairPlan != nil { + nodeChanges["would_delete"] = rows + } else if !task.UpsertOnly && !task.InsertOnly { nodeChanges["would_delete"] = rows } else { nodeChanges["skipped_deletes"] = rows } } + if task.RepairPlan != nil && len(task.planRuleMatches[nodeName]) > 0 { + nodeChanges["rule_matches"] = task.planRuleMatches[nodeName] + } task.report.Changes[nodeName] = nodeChanges } - if !task.UpsertOnly && !task.InsertOnly { + if task.RepairPlan != nil || (!task.UpsertOnly && !task.InsertOnly) { sb.WriteString(fmt.Sprintf(" Node %s: Would attempt to UPSERT %d rows and DELETE %d rows.\n", nodeName, len(upserts), len(deletes))) } else if task.InsertOnly { sb.WriteString(fmt.Sprintf(" Node %s: Would attempt to INSERT %d rows.\n", nodeName, len(upserts))) @@ -2016,6 +2105,17 @@ func getDryRunOutput(task *TableRepairTask) (string, error) { sb.WriteString(fmt.Sprintf(" Additionally, %d rows exist on %s that are not on %s (deletes skipped).\n", len(deletes), nodeName, task.SourceOfTruth)) } } + if task.RepairPlan != nil && len(task.planRuleMatches[nodeName]) > 0 { + ruleCounts := make(map[string]int) + for _, ruleName := range task.planRuleMatches[nodeName] { + ruleCounts[ruleName]++ + } + var parts []string + for rule, count := range ruleCounts { + parts = append(parts, fmt.Sprintf("%s=%d", rule, count)) + } + sb.WriteString(fmt.Sprintf(" Rule usage: %s\n", strings.Join(parts, ", "))) + } } } } @@ -2024,6 +2124,13 @@ func getDryRunOutput(task *TableRepairTask) (string, error) { } func calculateRepairSets(task *TableRepairTask) (map[string]map[string]map[string]any, map[string]map[string]map[string]any, error) { + if task.RepairPlan != nil { + return CalculatePlanRepairSets(task) + 
} + return calculateRepairSetsWithSourceOfTruth(task) +} + +func calculateRepairSetsWithSourceOfTruth(task *TableRepairTask) (map[string]map[string]map[string]any, map[string]map[string]map[string]any, error) { fullRowsToUpsert := make(map[string]map[string]map[string]any) // nodeName -> string(pkey) -> rowData fullRowsToDelete := make(map[string]map[string]map[string]any) // nodeName -> string(pkey) -> rowData @@ -2085,3 +2192,106 @@ func calculateRepairSets(task *TableRepairTask) (map[string]map[string]map[strin } return fullRowsToUpsert, fullRowsToDelete, nil } + +func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survivor string) (originLSN *uint64, slotLSN *uint64, err error) { + var originStr *string + originStr, err = queries.GetSpockOriginLSNForNode(t.Ctx, pool, failedNode, survivor) + if err != nil { + return nil, nil, fmt.Errorf("failed to fetch origin lsn on %s: %w", survivor, err) + } + if originStr != nil { + if val, parseErr := pglogrepl.ParseLSN(*originStr); parseErr == nil { + tmp := uint64(val) + originLSN = &tmp + } + } + + var slotStr *string + slotStr, err = queries.GetSpockSlotLSNForNode(t.Ctx, pool, failedNode) + if err != nil { + return originLSN, nil, fmt.Errorf("failed to fetch slot lsn on %s: %w", survivor, err) + } + if slotStr != nil { + if val, parseErr := pglogrepl.ParseLSN(*slotStr); parseErr == nil { + tmp := uint64(val) + slotLSN = &tmp + } + } + + return originLSN, slotLSN, nil +} + +func formatLSN(val *uint64) string { + if val == nil { + return "" + } + return pglogrepl.LSN(*val).String() +} + +func (t *TableRepairTask) autoSelectSourceOfTruth(failedNode string, involved map[string]bool) (string, map[string]map[string]string, error) { + lsnDetails := make(map[string]map[string]string) + + type candidate struct { + node string + val uint64 + valType string + } + var best *candidate + + for _, nodeInfo := range t.ClusterNodes { + nodeName, _ := nodeInfo["Name"].(string) + if nodeName == "" || nodeName == failedNode { + continue + } + if len(involved) > 0 && !involved[nodeName] { + continue + } + + pool, err := auth.GetClusterNodeConnection(t.Ctx, nodeInfo, t.connOpts()) + if err != nil { + logger.Warn("recovery-mode: failed to connect to %s for LSN probe: %v", nodeName, err) + continue + } + originLSN, slotLSN, err := t.fetchLSNsForNode(pool, failedNode, nodeName) + if err != nil { + logger.Warn("recovery-mode: failed to fetch LSNs on %s: %v", nodeName, err) + pool.Close() + continue + } + + if t.Pools[nodeName] == nil { + t.Pools[nodeName] = pool + } else { + pool.Close() + } + + lsnDetails[nodeName] = map[string]string{ + "origin_lsn": formatLSN(originLSN), + "slot_lsn": formatLSN(slotLSN), + } + + var candidateVal uint64 + var candidateType string + if originLSN != nil { + candidateVal = *originLSN + candidateType = "origin" + } else if slotLSN != nil { + candidateVal = *slotLSN + candidateType = "slot" + } else { + continue + } + + if best == nil || candidateVal > best.val { + best = &candidate{node: nodeName, val: candidateVal, valType: candidateType} + } else if candidateVal == best.val { + return "", lsnDetails, fmt.Errorf("nodes %s and %s have identical %s LSNs; specify source_of_truth explicitly", best.node, nodeName, candidateType) + } + } + + if best == nil { + return "", lsnDetails, fmt.Errorf("unable to determine source_of_truth in recovery-mode: no LSNs available for failed node %s", failedNode) + } + + return best.node, lsnDetails, nil +} diff --git a/internal/cdc/listen.go b/internal/infra/cdc/listen.go similarity 
index 99% rename from internal/cdc/listen.go rename to internal/infra/cdc/listen.go index bad98db..cf5a2be 100644 --- a/internal/cdc/listen.go +++ b/internal/infra/cdc/listen.go @@ -24,7 +24,7 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + "github.com/pgedge/ace/internal/infra/db" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/logger" ) diff --git a/internal/cdc/setup.go b/internal/infra/cdc/setup.go similarity index 98% rename from internal/cdc/setup.go rename to internal/infra/cdc/setup.go index 225e849..997c41f 100644 --- a/internal/cdc/setup.go +++ b/internal/infra/cdc/setup.go @@ -17,7 +17,7 @@ import ( "github.com/jackc/pglogrepl" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/auth" + "github.com/pgedge/ace/internal/infra/db" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/logger" ) diff --git a/internal/auth/auth.go b/internal/infra/db/auth.go similarity index 100% rename from internal/auth/auth.go rename to internal/infra/db/auth.go diff --git a/internal/scheduler/config.go b/internal/jobs/config.go similarity index 98% rename from internal/scheduler/config.go rename to internal/jobs/config.go index 6d38b4d..488c798 100644 --- a/internal/scheduler/config.go +++ b/internal/jobs/config.go @@ -7,7 +7,7 @@ import ( "strings" "time" - "github.com/pgedge/ace/internal/core" + "github.com/pgedge/ace/internal/consistency/diff" "github.com/pgedge/ace/pkg/config" ) @@ -92,7 +92,7 @@ func buildTableDiffJob(cfg *config.Config, def config.JobDef, spec scheduleSpec) return Job{}, fmt.Errorf("table_name is required for table-diff jobs") } - base := core.NewTableDiffTask() + base := diff.NewTableDiffTask() base.ClusterName = selectCluster(cfg, def.ClusterName) base.QualifiedTableName = def.TableName base.DBName = stringArg(def.Args, "dbname") @@ -160,7 +160,7 @@ func buildSchemaDiffJob(cfg *config.Config, def config.JobDef, spec scheduleSpec return Job{}, fmt.Errorf("schema_name is required for schema-diff jobs") } - base := core.NewSchemaDiffTask() + base := diff.NewSchemaDiffTask() base.ClusterName = selectCluster(cfg, def.ClusterName) base.SchemaName = def.SchemaName base.DBName = stringArg(def.Args, "dbname") @@ -232,7 +232,7 @@ func buildRepsetDiffJob(cfg *config.Config, def config.JobDef, spec scheduleSpec return Job{}, fmt.Errorf("repset_name is required for repset-diff jobs") } - base := core.NewRepsetDiffTask() + base := diff.NewRepsetDiffTask() base.ClusterName = selectCluster(cfg, def.ClusterName) base.RepsetName = def.RepsetName base.DBName = stringArg(def.Args, "dbname") @@ -286,7 +286,7 @@ func buildRepsetDiffJob(cfg *config.Config, def config.JobDef, spec scheduleSpec if err := runTask.RunChecks(true); err != nil { return fmt.Errorf("checks failed: %w", err) } - if err := core.RepsetDiff(runTask); err != nil { + if err := diff.RepsetDiff(runTask); err != nil { return fmt.Errorf("execution failed: %w", err) } return nil diff --git a/internal/scheduler/scheduler.go b/internal/jobs/scheduler.go similarity index 100% rename from internal/scheduler/scheduler.go rename to internal/jobs/scheduler.go diff --git a/mkdocs.yml b/mkdocs.yml index 66058d3..ec75b74 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ nav: - Scheduling ACE Runs: scheduling.md - The ACE API: api.md - Improving ACE Performance: performance.md + - Using ACE for Failure Recovery: docs/using-ace-for-cnf-recovery.md - Architecture and Design: - ACE 
table-diff Architecture: design/table_diff.md - Merkle Tree Architecture: design/merkle.md diff --git a/pkg/common/utils.go b/pkg/common/utils.go index d88ac97..4fb6d23 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -833,18 +833,37 @@ func AddSpockMetadata(row map[string]any) map[string]any { return nil } metadata := make(map[string]any) - if commitTs, ok := row["commit_ts"]; ok { - metadata["commit_ts"] = commitTs - delete(row, "commit_ts") - } if nodeOrigin, ok := row["node_origin"]; ok { metadata["node_origin"] = nodeOrigin delete(row, "node_origin") } + if commitTs, ok := row["commit_ts"]; ok { + metadata["commit_ts"] = commitTs + delete(row, "commit_ts") + } row["_spock_metadata_"] = metadata return row } +func TranslateNodeOrigin(raw any, nodeNames map[string]string) any { + if raw == nil { + return nil + } + origin := strings.TrimSpace(fmt.Sprintf("%v", raw)) + if origin == "" { + return nil + } + if origin == "0" { + return "local" + } + if nodeNames != nil { + if name, ok := nodeNames[origin]; ok { + return name + } + } + return raw +} + func StripSpockMetadata(row map[string]any) map[string]any { newRow := make(map[string]any) for k, v := range row { diff --git a/pkg/types/types.go b/pkg/types/types.go index 993ac2a..d638d75 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -113,12 +113,17 @@ type DiffSummary struct { StartTime string `json:"start_time"` EndTime string `json:"end_time"` TimeTaken string `json:"time_taken"` - DiffRowsCount map[string]int `json:"diff_rows_count"` // Key: "nodeA/nodeB", Value: count of differing rows + DiffRowsCount map[string]int `json:"diff_rows_count"` // Key: "nodeA/nodeB", Value: count of differing rows DiffRowLimitReached bool `json:"diff_row_limit_reached"` TotalRowsChecked int64 `json:"total_rows_checked"` // Estimated InitialRangesCount int `json:"initial_ranges_count"` MismatchedRangesCount int `json:"mismatched_ranges_count"` PrimaryKey []string `json:"primary_key"` + EffectiveFilter string `json:"effective_filter,omitempty"` + OnlyOrigin string `json:"only_origin,omitempty"` + OnlyOriginResolved string `json:"only_origin_resolved,omitempty"` + Until string `json:"until,omitempty"` + OriginOnly bool `json:"origin_only,omitempty"` } type KVPair struct { diff --git a/tests/integration/cdc_busy_table_test.go b/tests/integration/cdc_busy_table_test.go index 43005b8..e644888 100644 --- a/tests/integration/cdc_busy_table_test.go +++ b/tests/integration/cdc_busy_table_test.go @@ -22,7 +22,7 @@ import ( "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/cdc" + "github.com/pgedge/ace/internal/infra/cdc" "github.com/pgedge/ace/pkg/config" "github.com/stretchr/testify/require" ) diff --git a/tests/integration/crash_recovery_test.go b/tests/integration/crash_recovery_test.go new file mode 100644 index 0000000..5d958f6 --- /dev/null +++ b/tests/integration/crash_recovery_test.go @@ -0,0 +1,227 @@ +package integration + +import ( + "context" + "fmt" + "log" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/internal/consistency/repair" + "github.com/stretchr/testify/require" +) + +func TestTableDiffOnlyOriginWithUntil(t *testing.T) { + ctx := context.Background() + + // Light netem delay on n3 to mimic slower replication + if err := addNetemDelay(ctx, serviceN3, "200ms"); err != nil { + t.Logf("warning: failed to add netem delay on %s: %v (continuing)", 
+			serviceN3, err)
+	}
+	defer removeNetemDelay(ctx, serviceN3)
+
+	tableName := "crash_customers"
+	qualified := fmt.Sprintf("%s.%s", testSchema, tableName)
+
+	// Create table across nodes and add to default repset
+	for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool, pgCluster.Node3Pool} {
+		nodeName := pgCluster.ClusterNodes[i]["Name"].(string)
+		createSQL := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (id INT PRIMARY KEY, payload TEXT);`, qualified)
+		if _, err := pool.Exec(ctx, createSQL); err != nil {
+			t.Fatalf("failed to create table on %s: %v", nodeName, err)
+		}
+		addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualified)
+		if _, err := pool.Exec(ctx, addToRepSetSQL); err != nil {
+			t.Fatalf("failed to add table to repset on %s: %v", nodeName, err)
+		}
+	}
+	t.Cleanup(func() {
+		for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool, pgCluster.Node3Pool} {
+			pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE;", qualified))
+		}
+	})
+
+	// Insert a few rows on n3 and record the last commit timestamp (kept for reference only)
+	insertedIDs := []int{1001, 1002, 1003}
+	var lastCommitTs time.Time
+	for _, id := range insertedIDs {
+		_, err := pgCluster.Node3Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, payload) VALUES ($1, $2)", qualified), id, fmt.Sprintf("from n3 %d", id))
+		if err != nil {
+			t.Fatalf("insert on n3 failed: %v", err)
+		}
+		// capture commit timestamp for this row
+		var ts time.Time
+		if err := pgCluster.Node3Pool.QueryRow(ctx, fmt.Sprintf("SELECT pg_xact_commit_timestamp(xmin) FROM %s WHERE id=$1", qualified), id).Scan(&ts); err != nil {
+			t.Fatalf("failed to read commit ts for id %d: %v", id, err)
+		}
+		if ts.After(lastCommitTs) {
+			lastCommitTs = ts
+		}
+	}
+
+	// Allow replication to catch up on n1
+	assertEventually(t, 30*time.Second, func() error {
+		var count int
+		if err := pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = ANY($1)", qualified), insertedIDs).Scan(&count); err != nil {
+			return err
+		}
+		if count < len(insertedIDs) {
+			return fmt.Errorf("expected %d rows on n1, got %d", len(insertedIDs), count)
+		}
+		return nil
+	})
+
+	// Allow replication to catch up on n2
+	assertEventually(t, 30*time.Second, func() error {
+		var count int
+		if err := pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = ANY($1)", qualified), insertedIDs).Scan(&count); err != nil {
+			return err
+		}
+		if count < len(insertedIDs) {
+			return fmt.Errorf("expected %d rows on n2, got %d", len(insertedIDs), count)
+		}
+		return nil
+	})
+
+	// Simulate failed replication to n2 by deleting one row locally (repair_mode prevents replication)
+	tx, err := pgCluster.Node2Pool.Begin(ctx)
+	require.NoError(t, err, "begin tx on n2")
+	_, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)")
+	require.NoError(t, err, "enable repair_mode on n2")
+	_, err = tx.Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE id=$1", qualified), insertedIDs[0])
+	require.NoError(t, err, "delete row on n2")
+	_, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)")
+	require.NoError(t, err, "disable repair_mode on n2")
+	require.NoError(t, tx.Commit(ctx), "commit delete on n2")
+
+	// Stop n3 to mimic failure
+	if err := stopService(ctx, serviceN3); err != nil {
+		t.Fatalf("failed to stop %s: %v", serviceN3, err)
+	}
+	t.Cleanup(func() {
+		// best-effort restart so subsequent tests are not broken
+		if err := startService(ctx, serviceN3); err != nil {
+			t.Logf("cleanup: failed to restart %s: %v",
+				serviceN3, err)
+		}
+	})
+
+	// Run table-diff on the survivors, scoped to rows that originated on n3, with a generous "until" fence
+	task := diff.NewTableDiffTask()
+	task.ClusterName = pgCluster.ClusterName
+	task.QualifiedTableName = qualified
+	task.DBName = dbName
+	task.Nodes = strings.Join([]string{serviceN1, serviceN2}, ",")
+	task.Output = "json"
+	task.BlockSize = 1000
+	task.CompareUnitSize = 100
+	task.ConcurrencyFactor = 1
+	task.MaxDiffRows = 100
+	task.OnlyOrigin = "n3"
+	fence := time.Now().Add(5 * time.Minute)
+	task.Until = fence.Format(time.RFC3339)
+
+	require.NoError(t, task.Validate())
+	require.NoError(t, task.RunChecks(false))
+	err = task.ExecuteTask()
+	require.NoError(t, err)
+
+	pairKey := serviceN1 + "/" + serviceN2
+	if strings.Compare(serviceN1, serviceN2) > 0 {
+		pairKey = serviceN2 + "/" + serviceN1
+	}
+
+	nodeDiffs, ok := task.DiffResult.NodeDiffs[pairKey]
+	if !ok {
+		t.Fatalf("expected diffs for %s", pairKey)
+	}
+	if got := len(nodeDiffs.Rows[serviceN1]); got != 1 {
+		t.Fatalf("expected exactly 1 row present on %s (missing on %s) for origin n3, got %d", serviceN1, serviceN2, got)
+	}
+	if got := len(nodeDiffs.Rows[serviceN2]); got != 0 {
+		t.Fatalf("expected no rows present only on %s for origin n3 (they should be missing), got %d", serviceN2, got)
+	}
+	if val, _ := nodeDiffs.Rows[serviceN1][0].Get("id"); val != int32(insertedIDs[0]) {
+		t.Fatalf("expected missing id %d on %s, got %v", insertedIDs[0], serviceN2, val)
+	}
+	if metaVal, ok := nodeDiffs.Rows[serviceN1][0].Get("_spock_metadata_"); ok {
+		if m, ok := metaVal.(map[string]any); ok {
+			require.Contains(t, m, "node_origin", "spock metadata should include node_origin")
+		}
+	}
+	if task.DiffResult.Summary.OnlyOrigin == "" {
+		t.Fatalf("diff summary should record only_origin")
+	}
+	if got := task.DiffResult.Summary.Until; got == "" {
+		t.Fatalf("diff summary should record until cutoff")
+	}
+	log.Printf("only-origin diff detected %d row missing on %s (origin n3)", len(nodeDiffs.Rows[serviceN1]), serviceN2)
+
+	// Run recovery-mode repair using the diff file and ensure n2 is healed
+	repairTask := repair.NewTableRepairTask()
+	repairTask.ClusterName = pgCluster.ClusterName
+	repairTask.QualifiedTableName = qualified
+	repairTask.DBName = dbName
+	repairTask.Nodes = strings.Join([]string{serviceN1, serviceN2}, ",")
+	repairTask.DiffFilePath = task.DiffFilePath
+	repairTask.RecoveryMode = true
+	repairTask.SourceOfTruth = serviceN1 // explicit SoT to avoid relying on LSN availability in test container setup
+	repairTask.Ctx = context.Background()
+
+	require.NoError(t, repairTask.ValidateAndPrepare())
+	require.NoError(t, repairTask.Run(true))
+
+	var countAfter int
+	require.NoError(t, pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id=$1", qualified), insertedIDs[0]).Scan(&countAfter))
+	require.Equal(t, 1, countAfter, "recovery repair should restore missing row on n2")
+}
+
+func addNetemDelay(ctx context.Context, service, delay string) error {
+	container, err := pgCluster.Cluster.ServiceContainer(ctx, service)
+	if err != nil {
+		return err
+	}
+	_, _, err = container.Exec(ctx, []string{"tc", "qdisc", "add", "dev", "eth0", "root", "netem", "delay", delay})
+	return err
+}
+
+func removeNetemDelay(ctx context.Context, service string) {
+	container, err := pgCluster.Cluster.ServiceContainer(ctx, service)
+	if err != nil {
+		return
+	}
+	container.Exec(ctx, []string{"tc", "qdisc", "del", "dev", "eth0", "root"})
+}
+
+func stopService(ctx context.Context, service string) error {
+	container, err
:= pgCluster.Cluster.ServiceContainer(ctx, service) + if err != nil { + return err + } + timeout := 5 * time.Second + return container.Stop(ctx, &timeout) +} + +func startService(ctx context.Context, service string) error { + container, err := pgCluster.Cluster.ServiceContainer(ctx, service) + if err != nil { + return err + } + return container.Start(ctx) +} + +func assertEventually(t *testing.T, timeout time.Duration, fn func() error) { + t.Helper() + deadline := time.Now().Add(timeout) + for { + if err := fn(); err == nil { + return + } + if time.Now().After(deadline) { + t.Fatalf("condition not met within %s", timeout) + } + time.Sleep(1 * time.Second) + } +} diff --git a/tests/integration/docker-compose.yaml b/tests/integration/docker-compose.yaml index 62a8247..94b88fe 100644 --- a/tests/integration/docker-compose.yaml +++ b/tests/integration/docker-compose.yaml @@ -38,6 +38,10 @@ configs: { "name": "n2", "hostname": "postgres-n2" + }, + { + "name": "n3", + "hostname": "postgres-n3" } ] } @@ -66,3 +70,15 @@ services: ports: - target: 5432 published: 6433 + postgres-n3: + image: pgedge/pgedge:pg17-latest + environment: + - NODE_NAME=n3 + configs: + - source: db-spec + target: /home/pgedge/db.json + gid: "1020" + uid: "1020" + ports: + - target: 5432 + published: 6434 diff --git a/tests/integration/helpers_test.go b/tests/integration/helpers_test.go index 114ffd8..45302c4 100644 --- a/tests/integration/helpers_test.go +++ b/tests/integration/helpers_test.go @@ -26,7 +26,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - "github.com/pgedge/ace/internal/core" + "github.com/pgedge/ace/internal/consistency/repair" "github.com/stretchr/testify/require" ) @@ -284,8 +284,8 @@ func loadDataFromCSV( return nil } -func newTestTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath string) *core.TableRepairTask { - task := core.NewTableRepairTask() +func newTestTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath string) *repair.TableRepairTask { + task := repair.NewTableRepairTask() task.ClusterName = "test_cluster" task.DBName = dbName task.SourceOfTruth = sourceOfTruthNode diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index b3b3736..d22f97b 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -15,6 +15,7 @@ import ( "context" "encoding/json" "fmt" + "io" "log" "os" "path/filepath" @@ -26,6 +27,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/types" + tcLog "github.com/testcontainers/testcontainers-go/log" "github.com/testcontainers/testcontainers-go/modules/compose" "github.com/testcontainers/testcontainers-go/wait" ) @@ -38,8 +40,10 @@ const ( dbName = "example_db" serviceN1 = "postgres-n1" serviceN2 = "postgres-n2" + serviceN3 = "postgres-n3" hostPortN1 = "6432" hostPortN2 = "6433" + hostPortN3 = "6434" containerPort = "5432/tcp" composeFilePath = "docker-compose.yaml" startupTimeout = 3 * time.Minute @@ -56,6 +60,9 @@ var pgCluster struct { Node2Host string Node2Port string Node2Pool *pgxpool.Pool + Node3Host string + Node3Port string + Node3Pool *pgxpool.Pool ClusterName string ClusterNodes []map[string]any } @@ -172,6 +179,44 @@ func setupPostgresCluster(t *testing.T) error { pgCluster.Node2Pool = poolN2 log.Printf("Successfully connected to %s", serviceN2) + n3Container, err := composeStack.ServiceContainer(context.Background(), serviceN3) + if err != nil { + return fmt.Errorf("failed to get container for service %s: 
%w", serviceN3, err) + } + hostN3, err := n3Container.Host(context.Background()) + if err != nil { + return fmt.Errorf("failed to get host for %s: %w", serviceN3, err) + } + cPortN3, err := nat.NewPort("tcp", strings.Split(containerPort, "/")[0]) + if err != nil { + return fmt.Errorf("failed to create nat.Port for %s: %w", serviceN3, err) + } + portN3Mapped, err := n3Container.MappedPort(context.Background(), cPortN3) + if err != nil { + return fmt.Errorf("failed to get mapped port for %s: %w", serviceN3, err) + } + pgCluster.Node3Host = hostN3 + pgCluster.Node3Port = portN3Mapped.Port() + log.Printf( + "Node 3 (%s) accessible at %s:%s", + serviceN3, + pgCluster.Node3Host, + pgCluster.Node3Port, + ) + + poolN3, err := connectToNode( + pgCluster.Node3Host, + pgCluster.Node3Port, + pgEdgeUser, + pgEdgePassword, + dbName, + ) + if err != nil { + return fmt.Errorf("failed to connect to %s: %w", serviceN3, err) + } + pgCluster.Node3Pool = poolN3 + log.Printf("Successfully connected to %s", serviceN3) + pgCluster.ClusterName = "test_cluster" pgCluster.ClusterNodes = []map[string]any{ { @@ -190,6 +235,14 @@ func setupPostgresCluster(t *testing.T) error { "DBPassword": pgEdgePassword, "DBName": dbName, }, + { + "Name": serviceN3, + "PublicIP": pgCluster.Node3Host, + "Port": float64(6434), + "DBUser": pgEdgeUser, + "DBPassword": pgEdgePassword, + "DBName": dbName, + }, } // Need this for using pg's 'digest' function @@ -200,6 +253,7 @@ func setupPostgresCluster(t *testing.T) error { }{ {serviceN1, pgCluster.Node1Pool}, {serviceN2, pgCluster.Node2Pool}, + {serviceN3, pgCluster.Node3Pool}, } for _, node := range poolsToConfigure { @@ -320,6 +374,7 @@ func TestMain(m *testing.M) { "TESTCONTAINERS_RYUK_DISABLED", "true", ) + tcLog.SetDefault(log.New(io.Discard, "", 0)) if err := config.Init("../../ace.yaml"); err != nil { log.Fatalf("Failed to load config: %v", err) @@ -388,7 +443,7 @@ func setupSharedCustomersTable(tableName string) error { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool, pgCluster.Node3Pool} { if err := createTestTable(ctx, pool, testSchema, tableName); err != nil { return fmt.Errorf("failed to create shared table %s: %w", qualifiedTableName, err) } diff --git a/tests/integration/merkle_tree_test.go b/tests/integration/merkle_tree_test.go index 7d869a9..18163a4 100644 --- a/tests/integration/merkle_tree_test.go +++ b/tests/integration/merkle_tree_test.go @@ -27,8 +27,8 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" - "github.com/pgedge/ace/internal/cdc" - "github.com/pgedge/ace/internal/core" + "github.com/pgedge/ace/internal/consistency/mtree" + "github.com/pgedge/ace/internal/infra/cdc" "github.com/pgedge/ace/pkg/config" "github.com/stretchr/testify/require" ) @@ -98,9 +98,9 @@ func runMerkleTreeTests(t *testing.T, tableName string) { }) } -func newTestMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *core.MerkleTreeTask { +func newTestMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *mtree.MerkleTreeTask { t.Helper() - task := core.NewMerkleTreeTask() + task := mtree.NewMerkleTreeTask() task.ClusterName = "test_cluster" task.DBName = dbName task.QualifiedTableName = qualifiedTableName diff --git a/tests/integration/table_diff_test.go b/tests/integration/table_diff_test.go 
index 39093c5..74adfa7 100644 --- a/tests/integration/table_diff_test.go +++ b/tests/integration/table_diff_test.go @@ -23,7 +23,7 @@ import ( "time" "github.com/jackc/pgx/v5/pgxpool" - "github.com/pgedge/ace/internal/core" + "github.com/pgedge/ace/internal/consistency/diff" "github.com/pgedge/ace/pkg/types" "github.com/stretchr/testify/require" ) @@ -32,8 +32,8 @@ func newTestTableDiffTask( t *testing.T, qualifiedTableName string, nodes []string, -) *core.TableDiffTask { - task := core.NewTableDiffTask() +) *diff.TableDiffTask { + task := diff.NewTableDiffTask() task.ClusterName = "test_cluster" task.DBName = dbName task.QualifiedTableName = qualifiedTableName @@ -112,6 +112,7 @@ func runCustomerTableDiffTests(t *testing.T) { t.Run("DataOnlyOnNode2", testTableDiff_DataOnlyOnNode2) t.Run("ModifiedRows", testTableDiff_ModifiedRows) t.Run("TableFiltering", testTableDiff_TableFiltering) + t.Run("TableFilterNoRows", testTableDiff_TableFilterNoRows) t.Run("MaxDiffRowsLimit", testTableDiff_MaxDiffRowsLimit) } @@ -1068,14 +1069,26 @@ func testTableDiff_TableFiltering(t *testing.T) { t.Errorf("Expected modified row index 3 not found in filtered diffs") } - viewName := tdTask.FilteredViewName - require.NotEmpty(t, viewName, "filtered view name should be recorded") - require.False(t, materializedViewExists(t, ctx, pgCluster.Node1Pool, testSchema, viewName), "Filtered materialized view should be dropped on %s", serviceN1) - require.False(t, materializedViewExists(t, ctx, pgCluster.Node2Pool, testSchema, viewName), "Filtered materialized view should be dropped on %s", serviceN2) + require.Empty(t, tdTask.FilteredViewName, "filtered view should no longer be created") + require.Equal(t, tdTask.TableFilter, tdTask.DiffResult.Summary.TableFilter, "diff summary should record the table filter used") log.Println("TestTableDiff_TableFiltering completed.") } +func testTableDiff_TableFilterNoRows(t *testing.T) { + tableName := "customers" + qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + nodesToCompare := []string{serviceN1, serviceN2} + + tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + tdTask.TableFilter = "index < 0" // no rows satisfy this + + require.NoError(t, tdTask.RunChecks(false)) + err := tdTask.ExecuteTask() + require.Error(t, err, "table-diff should fail when table filter produces no rows") + require.Contains(t, err.Error(), "table filter produced no rows") +} + func testTableDiff_MaxDiffRowsLimit(t *testing.T) { ctx := context.Background() tableName := "customers" @@ -1389,9 +1402,7 @@ func testTableDiff_WithSpockMetadata(t *testing.T, compositeKey bool) { t.Fatal("Expected '_spock_metadata_' key in the diff row for node1") } metaMapN1, _ := metaN1.(map[string]any) - if val, ok := metaMapN1["node_origin"]; ok && val != nil && fmt.Sprintf("%v", val) != "0" { - t.Errorf("Expected 'node_origin' to be 0 for local update on node1, but got %v", val) - } + require.Equal(t, "local", fmt.Sprintf("%v", metaMapN1["node_origin"]), "node_origin should translate to 'local' for node1") diffRowN2 := nodeDiffs.Rows[serviceN2][0] dataN2, _ := diffRowN2.Get("data") @@ -1403,8 +1414,10 @@ func testTableDiff_WithSpockMetadata(t *testing.T, compositeKey bool) { t.Fatal("Expected '_spock_metadata_' key in the diff row for node2") } metaMapN2, _ := metaN2.(map[string]any) - if val, ok := metaMapN2["node_origin"]; !ok || val == nil || val == "" { - t.Errorf("Expected 'node_origin' in spock metadata for node2 to have a valid value, but got %v", val) + require.NotEmpty(t, 
metaMapN2["node_origin"], "node_origin should be populated for replicated rows on node2") + translated := fmt.Sprintf("%v", metaMapN2["node_origin"]) + if translated != serviceN1 && translated != "n1" { + t.Fatalf("node_origin should map to the source node name (%s) or spock node name (n1); got %s", serviceN1, translated) } log.Println("TestTableDiff_WithSpockMetadata completed successfully.")