From 4eee7cebe784c8b890df4427d1c38d2b183bf094 Mon Sep 17 00:00:00 2001 From: carmysouschef Date: Mon, 19 May 2025 14:47:53 +0700 Subject: [PATCH 01/34] Dev/resharing (#1) * Add resharing functionality * Add resharing functionality with new session management and event handling * Refactor resharing session management to support old and new participants, enhance event handling, and update dependencies in go.mod * Enhance resharing session by introducing new resharing parameters, updating message handling for resharing, and improving logging for better traceability. * Refactor event consumer and session management to use specific resharing message handling methods, improving clarity and functionality in resharing processes. * Update wallet ID in resharing example and refactor session management by removing unused methods and improving unsubscribe handling for broadcast and direct subscriptions. * Refactor event consumer to improve resharing event handling by updating error channel references and enhancing logging for session completion. * Refactor event consumer and resharing session to improve context management and logging, enhancing clarity in session completion handling. * Enhance resharing functionality by introducing support for EDDSA key type, updating resharing session management, and improving message handling for resharing events. This includes the addition of new resharing session methods and error channel updates for better clarity and functionality. * Add IsReshared field to KeyInfo struct and update key generation and resharing sessions to utilize this field. Enhance session management by differentiating between key generation and resharing based on the IsReshared status. * Update wallet ID and key path in resharing and signing examples; remove unused session management method in event consumer. Enhance logging initialization in migration scripts. * Refactor key saving methods in resharing sessions to improve code clarity and error handling. Introduce SaveKeyData and SaveKeyInfo methods for streamlined key and key info storage, ensuring proper logging and error management. * Refactor topic and key formatting in ECDSA and EDDSA resharing sessions for improved clarity and consistency. Introduce constants for topic and key formats, enhancing maintainability and readability of the code. --------- Co-authored-by: carmy Co-authored-by: vietddude --- .gitignore | 2 + cmd/mpcium/main.go | 4 + docker-compose.yaml | 4 +- examples/generate/main.go | 8 +- examples/reshare/main.go | 49 ++++++ go.mod | 1 - go.sum | 2 - pkg/client/client.go | 72 ++++++-- pkg/eventconsumer/event_consumer.go | 191 ++++++++++++++++++--- pkg/eventconsumer/events.go | 2 +- pkg/identity/identity.go | 3 +- pkg/keyinfo/keyinfo.go | 1 + pkg/mpc/ecdsa_keygen_session.go | 1 + pkg/mpc/ecdsa_resharing_session.go | 203 +++++++++++++++++++++++ pkg/mpc/ecdsa_rounds.go | 73 ++++++-- pkg/mpc/eddsa_keygen_session.go | 1 + pkg/mpc/eddsa_resharing_session.go | 176 ++++++++++++++++++++ pkg/mpc/eddsa_rounds.go | 50 +++++- pkg/mpc/eddsa_signing_session.go | 3 +- pkg/mpc/node.go | 119 ++++++++++++- pkg/mpc/session.go | 161 ++++++++++++++++-- pkg/types/initiator_msg.go | 30 ++++ pkg/types/tss.go | 22 +++ scripts/migration/add-key-type/main.go | 2 +- scripts/migration/update-keyinfo/main.go | 2 +- setup_identities.sh | 59 +++++++ setup_initiator.sh | 41 +++++ 27 files changed, 1188 insertions(+), 94 deletions(-) create mode 100644 examples/reshare/main.go create mode 100644 pkg/mpc/ecdsa_resharing_session.go create mode 100644 pkg/mpc/eddsa_resharing_session.go create mode 100755 setup_identities.sh create mode 100755 setup_initiator.sh diff --git a/.gitignore b/.gitignore index 751dea4..6ee9767 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ identity/ event_initiator.identity.json event_initiator.key event_initiator.key.age +config.yaml +peers.json diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index cff7b70..33812ff 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -132,12 +132,15 @@ func runNode(ctx context.Context, c *cli.Command) error { mqManager := messaging.NewNATsMessageQueueManager("mpc", []string{ "mpc.mpc_keygen_success.*", event.SigningResultTopic, + "mpc.mpc_resharing_success.*", }, natsConn) genKeySuccessQueue := mqManager.NewMessageQueue("mpc_keygen_success") defer genKeySuccessQueue.Close() singingResultQueue := mqManager.NewMessageQueue("signing_result") defer singingResultQueue.Close() + resharingResultQueue := mqManager.NewMessageQueue("mpc_resharing_success") + defer resharingResultQueue.Close() logger.Info("Node is running", "peerID", nodeID, "name", nodeName) @@ -161,6 +164,7 @@ func runNode(ctx context.Context, c *cli.Command) error { pubsub, genKeySuccessQueue, singingResultQueue, + resharingResultQueue, identityStore, ) eventConsumer.Run() diff --git a/docker-compose.yaml b/docker-compose.yaml index 7aebc38..bdced5f 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -3,7 +3,7 @@ version: "3" services: nats-server: image: nats:latest - container_name: nats-server + container_name: nats-server-mpcium command: -js --http_port 8222 ports: - "4222:4222" @@ -14,7 +14,7 @@ services: consul: image: consul:1.15.4 - container_name: consul + container_name: consul-mpcium ports: - "8500:8500" - "8601:8600/udp" diff --git a/examples/generate/main.go b/examples/generate/main.go index fb004ed..ae3e6e9 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -7,20 +7,18 @@ import ( "syscall" "github.com/fystack/mpcium/pkg/client" - "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/mpc" "github.com/google/uuid" "github.com/nats-io/nats.go" - "github.com/spf13/viper" ) func main() { const environment = "development" - config.InitViperConfig() + // config.InitViperConfig() logger.Init(environment, false) - natsURL := viper.GetString("nats.url") + natsURL := "nats://localhost:4222" natsConn, err := nats.Connect(natsURL) if err != nil { logger.Fatal("Failed to connect to NATS", err) @@ -30,7 +28,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./event_initiator.key", + KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", }) err = mpcClient.OnWalletCreationResult(func(event mpc.KeygenSuccessEvent) { logger.Info("Received wallet creation result", "event", event) diff --git a/examples/reshare/main.go b/examples/reshare/main.go new file mode 100644 index 0000000..4f33ea8 --- /dev/null +++ b/examples/reshare/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/mpc" + "github.com/nats-io/nats.go" +) + +func main() { + const environment = "development" + // config.InitViperConfig() + logger.Init(environment, false) + + natsURL := "nats://localhost:4222" + natsConn, err := nats.Connect(natsURL) + if err != nil { + logger.Fatal("Failed to connect to NATS", err) + } + defer natsConn.Drain() // drain inflight msgs + defer natsConn.Close() + + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", + }) + err = mpcClient.OnResharingResult(func(event mpc.ResharingSuccessEvent) { + logger.Info("Received resharing result", "event", event) + }) + if err != nil { + logger.Fatal("Failed to subscribe to resharing results", err) + } + + walletID := "892122fd-f2f4-46dc-be25-6fd0b83dff60" + if err := mpcClient.Resharing(walletID, 2); err != nil { + logger.Fatal("Resharing failed", err) + } + logger.Info("Resharing sent, awaiting result...", "walletID", walletID) + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + <-stop + + fmt.Println("Shutting down.") +} diff --git a/go.mod b/go.mod index fb45e42..8a4b6ed 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/spf13/viper v1.18.0 github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v3 v3.3.2 - go.uber.org/mock v0.3.0 golang.org/x/term v0.31.0 ) diff --git a/go.sum b/go.sum index 3c448e3..cdd2f60 100644 --- a/go.sum +++ b/go.sum @@ -373,8 +373,6 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= -go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/pkg/client/client.go b/pkg/client/client.go index 6314158..285479d 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -22,6 +22,7 @@ import ( const ( GenerateWalletSuccessTopic = "mpc.mpc_keygen_success.*" // wildcard to listen to all success events + ResharingResultTopic = "mpc.mpc_resharing_success.*" ) type MPCClient interface { @@ -30,14 +31,18 @@ type MPCClient interface { SignTransaction(msg *types.SignTxMessage) error OnSignResult(callback func(event event.SigningResultEvent)) error + + Resharing(walletID string, newThreshold int, keyType types.KeyType) error + OnResharingResult(callback func(event mpc.ResharingSuccessEvent)) error } type mpcClient struct { - signingStream messaging.StreamPubsub - pubsub messaging.PubSub - genKeySuccessQueue messaging.MessageQueue - signResultQueue messaging.MessageQueue - privKey ed25519.PrivateKey + signingStream messaging.StreamPubsub + pubsub messaging.PubSub + genKeySuccessQueue messaging.MessageQueue + signResultQueue messaging.MessageQueue + resharingResultQueue messaging.MessageQueue + privKey ed25519.PrivateKey } // Options defines configuration options for creating a new MPCClient @@ -123,17 +128,20 @@ func NewMPCClient(opts Options) MPCClient { manager := messaging.NewNATsMessageQueueManager("mpc", []string{ "mpc.mpc_keygen_success.*", "mpc.signing_result.*", + "mpc.mpc_resharing_success.*", }, opts.NatsConn) genKeySuccessQueue := manager.NewMessageQueue("mpc_keygen_success") signResultQueue := manager.NewMessageQueue("signing_result") + resharingResultQueue := manager.NewMessageQueue("mpc_resharing_success") return &mpcClient{ - signingStream: signingStream, - pubsub: pubsub, - genKeySuccessQueue: genKeySuccessQueue, - signResultQueue: signResultQueue, - privKey: priv, + signingStream: signingStream, + pubsub: pubsub, + genKeySuccessQueue: genKeySuccessQueue, + signResultQueue: signResultQueue, + resharingResultQueue: resharingResultQueue, + privKey: priv, } } @@ -242,3 +250,47 @@ func (c *mpcClient) OnSignResult(callback func(event event.SigningResultEvent)) return nil } + +func (c *mpcClient) Resharing(walletID string, newThreshold int, keyType types.KeyType) error { + msg := &types.ResharingMessage{ + WalletID: walletID, + NewThreshold: newThreshold, + KeyType: keyType, + } + + // compute the canonical raw bytes + raw, err := msg.Raw() + if err != nil { + return fmt.Errorf("Resharing: raw payload error: %w", err) + } + // sign + msg.Signature = ed25519.Sign(c.privKey, raw) + + bytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("Resharing: marshal error: %w", err) + } + + if err := c.pubsub.Publish(eventconsumer.MPCResharingEvent, bytes); err != nil { + return fmt.Errorf("Resharing: publish error: %w", err) + } + return nil +} + +func (c *mpcClient) OnResharingResult(callback func(event mpc.ResharingSuccessEvent)) error { + err := c.resharingResultQueue.Dequeue(ResharingResultTopic, func(msg []byte) error { + var event mpc.ResharingSuccessEvent + err := json.Unmarshal(msg, &event) + if err != nil { + return err + } + callback(event) + return nil + }) + + if err != nil { + return fmt.Errorf("OnResharingResult: subscribe error: %w", err) + } + + return nil +} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 2f98210..8cb650f 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -21,8 +21,9 @@ import ( ) const ( - MPCGenerateEvent = "mpc:generate" - MPCSignEvent = "mpc:sign" + MPCGenerateEvent = "mpc:generate" + MPCSignEvent = "mpc:sign" + MPCResharingEvent = "mpc:reshare" ) type EventConsumer interface { @@ -35,11 +36,13 @@ type eventConsumer struct { pubsub messaging.PubSub mpcThreshold int - genKeySucecssQueue messaging.MessageQueue - signingResultQueue messaging.MessageQueue + genKeySucecssQueue messaging.MessageQueue + signingResultQueue messaging.MessageQueue + resharingResultQueue messaging.MessageQueue keyGenerationSub messaging.Subscription signingSub messaging.Subscription + resharingSub messaging.Subscription identityStore identity.Store // Track active sessions with timestamps for cleanup @@ -55,19 +58,21 @@ func NewEventConsumer( pubsub messaging.PubSub, genKeySucecssQueue messaging.MessageQueue, signingResultQueue messaging.MessageQueue, + resharingResultQueue messaging.MessageQueue, identityStore identity.Store, ) EventConsumer { ec := &eventConsumer{ - node: node, - pubsub: pubsub, - genKeySucecssQueue: genKeySucecssQueue, - signingResultQueue: signingResultQueue, - activeSessions: make(map[string]time.Time), - cleanupInterval: 5 * time.Minute, // Run cleanup every 5 minutes - sessionTimeout: 30 * time.Minute, // Consider sessions older than 30 minutes stale - cleanupStopChan: make(chan struct{}), - mpcThreshold: viper.GetInt("mpc_threshold"), - identityStore: identityStore, + node: node, + pubsub: pubsub, + genKeySucecssQueue: genKeySucecssQueue, + signingResultQueue: signingResultQueue, + resharingResultQueue: resharingResultQueue, + activeSessions: make(map[string]time.Time), + cleanupInterval: 5 * time.Minute, // Run cleanup every 5 minutes + sessionTimeout: 30 * time.Minute, // Consider sessions older than 30 minutes stale + cleanupStopChan: make(chan struct{}), + mpcThreshold: viper.GetInt("mpc_threshold"), + identityStore: identityStore, } // Start background cleanup goroutine @@ -87,6 +92,11 @@ func (ec *eventConsumer) Run() { log.Fatal("Failed to consume tx signing event", err) } + err = ec.consumeResharingEvent() + if err != nil { + log.Fatal("Failed to consume resharing event", err) + } + logger.Info("MPC Event consumer started...!") } @@ -362,6 +372,151 @@ func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkIntern } } +func (ec *eventConsumer) consumeResharingEvent() error { + sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { + raw := natMsg.Data + var msg types.ResharingMessage + err := json.Unmarshal(raw, &msg) + if err != nil { + logger.Error("Failed to unmarshal resharing message", err) + return + } + logger.Info("Received resharing event", "walletID", msg.WalletID, "newThreshold", msg.NewThreshold) + + err = ec.identityStore.VerifyInitiatorMessage(&msg) + if err != nil { + logger.Error("Failed to verify initiator message", err) + return + } + + walletID := msg.WalletID + newThreshold := msg.NewThreshold + + // Get new participants + readyPeerIDs := ec.node.GetReadyPeersIncludeSelf() + if len(readyPeerIDs) < newThreshold+1 { + logger.Error("Not enough peers for resharing", nil, "expected", newThreshold+1, "got", len(readyPeerIDs)) + return + } + + var oldPSession, newPSession mpc.IResharingSession + + switch msg.KeyType { + case types.KeyTypeSecp256k1: + // Create resharing oldPSession + oldPSession, err = ec.node.CreateECDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) + if err != nil { + logger.Error("Failed to create resharing session", err) + return + } + newPSession, err = ec.node.CreateECDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) + if err != nil { + logger.Error("Failed to create resharing session", err) + return + } + case types.KeyTypeEd25519: + // Create resharing oldPSession + oldPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) + if err != nil { + logger.Error("Failed to create resharing session", err) + return + } + newPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) + if err != nil { + logger.Error("Failed to create resharing session", err) + return + } + } + + oldPSession.Init() + newPSession.Init() + + oldPSessionCtx, oldPSessionDone := context.WithCancel(context.Background()) + newPSessionCtx, newPSessionDone := context.WithCancel(context.Background()) + + successEvent := &mpc.ResharingSuccessEvent{ + WalletID: walletID, + } + + var wg sync.WaitGroup + wg.Add(2) + + // For old party, we just need to wait for completion + go func() { + for { + select { + case <-oldPSessionCtx.Done(): + wg.Done() + logger.Info("oldPSession done") + return + case err := <-oldPSession.ErrChan(): + if err != nil { + logger.Error("Resharing session error", err) + } + } + } + }() + + // For new party, we need to get the public key + go func() { + for { + select { + case <-newPSessionCtx.Done(): + if msg.KeyType == types.KeyTypeSecp256k1 { + successEvent.ECDSAPubKey = newPSession.GetPubKeyResult() + } else { + successEvent.EDDSAPubKey = newPSession.GetPubKeyResult() + } + wg.Done() + logger.Info("newPSession done") + return + case err := <-newPSession.ErrChan(): + if err != nil { + logger.Error("Resharing session error", err) + } + } + } + }() + + // Start listening for messages + oldPSession.ListenToIncomingResharingMessageAsync() + newPSession.ListenToIncomingResharingMessageAsync() + time.Sleep(1 * time.Second) + + // Start resharing process + go oldPSession.Resharing(oldPSessionDone) + go newPSession.Resharing(newPSessionDone) + + // Wait for both sessions to complete + wg.Wait() + logger.Info("Closing session successfully!", + "event", successEvent) + + successEventBytes, err := json.Marshal(successEvent) + if err != nil { + logger.Error("Failed to marshal resharing success event", err) + return + } + + err = ec.resharingResultQueue.Enqueue(fmt.Sprintf(mpc.TypeResharingSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(mpc.TypeResharingSuccess, walletID), + }) + if err != nil { + logger.Error("Failed to publish resharing result event", err) + return + } + + logger.Info("[COMPLETED RESHARING] Resharing completed successfully", + "walletID", walletID) + }) + + ec.resharingSub = sub + if err != nil { + return err + } + return nil +} + // Add a cleanup routine that runs periodically func (ec *eventConsumer) sessionCleanupRoutine() { ticker := time.NewTicker(ec.cleanupInterval) @@ -399,14 +554,6 @@ func (ec *eventConsumer) addSession(walletID, txID string) { ec.sessionsLock.Unlock() } -// Remove a session from tracking -func (ec *eventConsumer) removeSession(walletID, txID string) { - sessionID := fmt.Sprintf("%s-%s", walletID, txID) - ec.sessionsLock.Lock() - delete(ec.activeSessions, sessionID) - ec.sessionsLock.Unlock() -} - // checkAndTrackSession checks if a session already exists and tracks it if new. // Returns true if the session is a duplicate. func (ec *eventConsumer) checkDuplicateSession(walletID, txID string) bool { diff --git a/pkg/eventconsumer/events.go b/pkg/eventconsumer/events.go index 5b9ca06..4d71714 100644 --- a/pkg/eventconsumer/events.go +++ b/pkg/eventconsumer/events.go @@ -6,7 +6,7 @@ type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" - KeyTypeEd25519 = "ed25519" + KeyTypeEd25519 KeyType = "ed25519" ) // InitiatorMessage is anything that carries a payload to verify and its signature. diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 863a696..4d281c2 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -8,6 +8,7 @@ import ( "io" "os" "path/filepath" + "strings" "sync" "syscall" @@ -264,5 +265,5 @@ func (s *fileStore) VerifyInitiatorMessage(msg types.InitiatorMessage) error { } func partyIDToNodeID(partyID *tss.PartyID) string { - return string(partyID.KeyInt().Bytes()) + return strings.Split(string(partyID.KeyInt().Bytes()), ":")[0] } diff --git a/pkg/keyinfo/keyinfo.go b/pkg/keyinfo/keyinfo.go index 6952e7d..a10529c 100644 --- a/pkg/keyinfo/keyinfo.go +++ b/pkg/keyinfo/keyinfo.go @@ -11,6 +11,7 @@ import ( type KeyInfo struct { ParticipantPeerIDs []string `json:"participant_peer_ids"` Threshold int `json:"threshold"` + IsReshared bool `json:"is_reshared"` } type store struct { diff --git a/pkg/mpc/ecdsa_keygen_session.go b/pkg/mpc/ecdsa_keygen_session.go index 98dee70..0bb7973 100644 --- a/pkg/mpc/ecdsa_keygen_session.go +++ b/pkg/mpc/ecdsa_keygen_session.go @@ -115,6 +115,7 @@ func (s *KeygenSession) GenerateKey(done func()) { keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: s.participantPeerIDs, Threshold: s.threshold, + IsReshared: false, } err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) diff --git a/pkg/mpc/ecdsa_resharing_session.go b/pkg/mpc/ecdsa_resharing_session.go new file mode 100644 index 0000000..2dd2352 --- /dev/null +++ b/pkg/mpc/ecdsa_resharing_session.go @@ -0,0 +1,203 @@ +package mpc + +import ( + "crypto/ecdsa" + "encoding/json" + "fmt" + + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" +) + +const ( + TypeResharingSuccess = "mpc.mpc_resharing_success.%s" +) + +type IResharingSession interface { + ErrChan() <-chan error + ListenToIncomingResharingMessageAsync() + GetPubKeyResult() []byte + Init() + Resharing(done func()) +} + +type ECDSAResharingSession struct { + Session + isOldParty bool + oldPartyIDs []*tss.PartyID + oldThreshold int + newThreshold int + endCh chan *keygen.LocalPartySaveData +} + +type ResharingSuccessEvent struct { + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` +} + +func ECDSANewResharingSession( + walletID string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + participantPeerIDs []string, + selfID *tss.PartyID, + oldPartyIDs []*tss.PartyID, + newPartyIDs []*tss.PartyID, + threshold int, + newThreshold int, + preParams *keygen.LocalPreParams, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + identityStore identity.Store, + isOldParty bool, +) *ECDSAResharingSession { + oldCtx := tss.NewPeerContext(oldPartyIDs) + newCtx := tss.NewPeerContext(newPartyIDs) + reshareParams := tss.NewReSharingParameters( + tss.S256(), + oldCtx, + newCtx, + selfID, + len(oldPartyIDs), + threshold, + len(newPartyIDs), + newThreshold, + ) + return &ECDSAResharingSession{ + Session: Session{ + walletID: walletID, + pubSub: pubSub, + direct: direct, + threshold: newThreshold, + participantPeerIDs: participantPeerIDs, + selfPartyID: selfID, + partyIDs: newPartyIDs, + outCh: make(chan tss.Message), + ErrCh: make(chan error), + preParams: preParams, + reshareParams: reshareParams, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf(TopicFormatResharingBroadcast, "ecdsa", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf(TopicFormatResharingDirect, "ecdsa", nodeID, walletID) + }, + }, + composeKey: func(walletID string) string { + return fmt.Sprintf(KeyFormatEcdsa, walletID) + }, + getRoundFunc: GetEcdsaMsgRound, + resultQueue: resultQueue, + sessionType: SessionTypeEcdsa, + identityStore: identityStore, + }, + isOldParty: isOldParty, + oldPartyIDs: oldPartyIDs, + oldThreshold: threshold, + newThreshold: newThreshold, + endCh: make(chan *keygen.LocalPartySaveData), + } +} + +func (s *ECDSAResharingSession) Init() { + logger.Infof("Initializing resharing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) + var share keygen.LocalPartySaveData + if s.isOldParty { + // Get existing key data for old party + keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to get wallet data from KVStore: %w", err) + return + } + err = json.Unmarshal(keyData, &share) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to unmarshal wallet data: %w", err) + return + } + } else { + // Initialize empty share data for new party + share = keygen.NewLocalPartySaveData(len(s.partyIDs)) + share.LocalPreParams = *s.preParams + } + + s.party = resharing.NewLocalParty(s.reshareParams, share, s.outCh, s.endCh) + logger.Infof("[INITIALIZED] Initialized resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", + s.selfPartyID, s.partyIDs, s.walletID, s.oldThreshold, s.newThreshold) +} + +func (s *ECDSAResharingSession) Resharing(done func()) { + logger.Info("Starting resharing", "walletID", s.walletID, "partyID", s.selfPartyID) + go func() { + if err := s.party.Start(); err != nil { + s.ErrCh <- err + } + }() + + for { + select { + case saveData := <-s.endCh: + keyBytes, err := json.Marshal(saveData) + if err != nil { + s.ErrCh <- err + return + } + + if err := s.SaveKeyData(keyBytes); err != nil { + s.ErrCh <- err + return + } + + // Save key info with resharing flag + if err := s.SaveKeyInfo(true); err != nil { + s.ErrCh <- err + return + } + + // skip for old committee + if saveData.ECDSAPub != nil { + // Get public key + publicKey := saveData.ECDSAPub + pubKey := &ecdsa.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + logger.Error("failed to encode public key", err) + s.ErrCh <- fmt.Errorf("failed to encode public key: %w", err) + return + } + + // Set the public key bytes + s.pubkeyBytes = pubKeyBytes + logger.Info("Generated public key bytes", + "walletID", s.walletID, + "pubKeyBytes", pubKeyBytes) + } + + done() + err = s.Close() + if err != nil { + logger.Error("Failed to close session", err) + } + return + case msg := <-s.outCh: + // Handle the message + s.handleResharingMessage(msg) + } + } +} diff --git a/pkg/mpc/ecdsa_rounds.go b/pkg/mpc/ecdsa_rounds.go index 8e70f7a..e32badb 100644 --- a/pkg/mpc/ecdsa_rounds.go +++ b/pkg/mpc/ecdsa_rounds.go @@ -2,26 +2,35 @@ package mpc import ( "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/common/errors" ) const ( - KEYGEN1 = "KGRound1Message" - KEYGEN2aUnicast = "KGRound2Message1" - KEYGEN2b = "KGRound2Message2" - KEYGEN3 = "KGRound3Message" - KEYSIGN1aUnicast = "SignRound1Message1" - KEYSIGN1b = "SignRound1Message2" - KEYSIGN2Unicast = "SignRound2Message" - KEYSIGN3 = "SignRound3Message" - KEYSIGN4 = "SignRound4Message" - KEYSIGN5 = "SignRound5Message" - KEYSIGN6 = "SignRound6Message" - KEYSIGN7 = "SignRound7Message" - KEYSIGN8 = "SignRound8Message" - KEYSIGN9 = "SignRound9Message" + KEYGEN1 = "KGRound1Message" + KEYGEN2aUnicast = "KGRound2Message1" + KEYGEN2b = "KGRound2Message2" + KEYGEN3 = "KGRound3Message" + KEYSIGN1aUnicast = "SignRound1Message1" + KEYSIGN1b = "SignRound1Message2" + KEYSIGN2Unicast = "SignRound2Message" + KEYSIGN3 = "SignRound3Message" + KEYSIGN4 = "SignRound4Message" + KEYSIGN5 = "SignRound5Message" + KEYSIGN6 = "SignRound6Message" + KEYSIGN7 = "SignRound7Message" + KEYSIGN8 = "SignRound8Message" + KEYSIGN9 = "SignRound9Message" + KEYRESHARING1Unicast = "DGRound1Message" + KEYRESHARING2aUnicast = "DGRound2Message1" + KEYRESHARING2bUnicast = "DGRound2Message2" + KEYRESHARING3aUnicast = "DGRound3Message1" + KEYRESHARING3b = "DGRound3Message2" + KEYRESHARING4a = "DGRound4Message1" + KEYRESHARING4bUnicast = "DGRound4Message2" + TSSKEYGENROUNDS = 4 TSSKEYSIGNROUNDS = 10 ) @@ -113,7 +122,41 @@ func GetEcdsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (Round Index: 9, RoundMsg: KEYSIGN9, }, nil - + case *resharing.DGRound1Message: + return RoundInfo{ + Index: 0, + RoundMsg: KEYRESHARING1Unicast, + }, nil + case *resharing.DGRound2Message1: + return RoundInfo{ + Index: 1, + RoundMsg: KEYRESHARING2aUnicast, + }, nil + case *resharing.DGRound2Message2: + return RoundInfo{ + Index: 2, + RoundMsg: KEYRESHARING2bUnicast, + }, nil + case *resharing.DGRound3Message1: + return RoundInfo{ + Index: 3, + RoundMsg: KEYRESHARING3aUnicast, + }, nil + case *resharing.DGRound3Message2: + return RoundInfo{ + Index: 4, + RoundMsg: KEYRESHARING3b, + }, nil + case *resharing.DGRound4Message1: + return RoundInfo{ + Index: 5, + RoundMsg: KEYRESHARING4a, + }, nil + case *resharing.DGRound4Message2: + return RoundInfo{ + Index: 6, + RoundMsg: KEYRESHARING4bUnicast, + }, nil default: return RoundInfo{}, errors.New("unknown round") } diff --git a/pkg/mpc/eddsa_keygen_session.go b/pkg/mpc/eddsa_keygen_session.go index 7a1b325..944b22b 100644 --- a/pkg/mpc/eddsa_keygen_session.go +++ b/pkg/mpc/eddsa_keygen_session.go @@ -106,6 +106,7 @@ func (s *EDDSAKeygenSession) GenerateKey(done func()) { keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: s.participantPeerIDs, Threshold: s.threshold, + IsReshared: false, } err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) diff --git a/pkg/mpc/eddsa_resharing_session.go b/pkg/mpc/eddsa_resharing_session.go new file mode 100644 index 0000000..06e8dba --- /dev/null +++ b/pkg/mpc/eddsa_resharing_session.go @@ -0,0 +1,176 @@ +package mpc + +import ( + "encoding/json" + "fmt" + + "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/decred/dcrd/dcrec/edwards/v2" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" +) + +type EDDSAResharingSession struct { + Session + isOldParty bool + oldPartyIDs []*tss.PartyID + oldThreshold int + newThreshold int + endCh chan *keygen.LocalPartySaveData +} + +func EDDSANewResharingSession( + walletID string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + participantPeerIDs []string, + selfID *tss.PartyID, + oldPartyIDs []*tss.PartyID, + newPartyIDs []*tss.PartyID, + threshold int, + newThreshold int, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + identityStore identity.Store, + isOldParty bool, +) *EDDSAResharingSession { + oldCtx := tss.NewPeerContext(oldPartyIDs) + newCtx := tss.NewPeerContext(newPartyIDs) + reshareParams := tss.NewReSharingParameters( + tss.Edwards(), + oldCtx, + newCtx, + selfID, + len(oldPartyIDs), + threshold, + len(newPartyIDs), + newThreshold, + ) + return &EDDSAResharingSession{ + Session: Session{ + walletID: walletID, + pubSub: pubSub, + direct: direct, + threshold: newThreshold, + participantPeerIDs: participantPeerIDs, + selfPartyID: selfID, + partyIDs: newPartyIDs, + outCh: make(chan tss.Message), + ErrCh: make(chan error), + reshareParams: reshareParams, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf(TopicFormatResharingBroadcast, "eddsa", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf(TopicFormatResharingDirect, "eddsa", nodeID, walletID) + }, + }, + composeKey: func(walletID string) string { + return fmt.Sprintf(KeyFormatEddsa, walletID) + }, + getRoundFunc: GetEddsaMsgRound, + resultQueue: resultQueue, + sessionType: SessionTypeEddsa, + identityStore: identityStore, + }, + isOldParty: isOldParty, + oldPartyIDs: oldPartyIDs, + oldThreshold: threshold, + newThreshold: newThreshold, + endCh: make(chan *keygen.LocalPartySaveData), + } +} + +func (s *EDDSAResharingSession) Init() { + logger.Infof("Initializing EDDSA resharing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) + var share keygen.LocalPartySaveData + if s.isOldParty { + // Get existing key data for old party + keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) + if err != nil { + fmt.Println("err", err) + s.ErrCh <- fmt.Errorf("failed to get wallet data from KVStore: %w", err) + return + } + err = json.Unmarshal(keyData, &share) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to unmarshal wallet data: %w", err) + return + } + } else { + // Initialize empty share data for new party + share = keygen.NewLocalPartySaveData(len(s.partyIDs)) + } + s.party = resharing.NewLocalParty(s.reshareParams, share, s.outCh, s.endCh) + logger.Infof("[INITIALIZED] Initialized EDDSA resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", + s.selfPartyID, s.partyIDs, s.walletID, s.oldThreshold, s.newThreshold) +} + +func (s *EDDSAResharingSession) Resharing(done func()) { + logger.Info("Starting EDDSA resharing", "walletID", s.walletID, "partyID", s.selfPartyID) + go func() { + if err := s.party.Start(); err != nil { + s.ErrCh <- err + } + }() + + for { + select { + case saveData := <-s.endCh: + // skip for old committee + if saveData.EDDSAPub != nil { + keyBytes, err := json.Marshal(saveData) + if err != nil { + s.ErrCh <- err + return + } + + if err := s.SaveKeyData(keyBytes); err != nil { + s.ErrCh <- err + return + } + + // Save key info with resharing flag + if err := s.SaveKeyInfo(true); err != nil { + s.ErrCh <- err + return + } + + // Get public key + publicKey := saveData.EDDSAPub + pkX, pkY := publicKey.X(), publicKey.Y() + pk := edwards.PublicKey{ + Curve: tss.Edwards(), + X: pkX, + Y: pkY, + } + + pubKeyBytes := pk.SerializeCompressed() + s.pubkeyBytes = pubKeyBytes + + logger.Info("Generated public key bytes", + "walletID", s.walletID, + "pubKeyBytes", pubKeyBytes) + } + + done() + err := s.Close() + if err != nil { + logger.Error("Failed to close session", err) + } + return + case msg := <-s.outCh: + // Handle the message + s.handleResharingMessage(msg) + } + } +} diff --git a/pkg/mpc/eddsa_rounds.go b/pkg/mpc/eddsa_rounds.go index 01519d0..88864f3 100644 --- a/pkg/mpc/eddsa_rounds.go +++ b/pkg/mpc/eddsa_rounds.go @@ -2,6 +2,7 @@ package mpc import ( "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" "github.com/bnb-chain/tss-lib/v2/eddsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/common/errors" @@ -16,14 +17,21 @@ type RoundInfo struct { } const ( - EDDSA_KEYGEN1 = "KGRound1Message" - EDDSA_KEYGEN2aUnicast = "KGRound2Message1" - EDDSA_KEYGEN2b = "KGRound2Message2" - EDDSA_KEYSIGN1 = "SignRound1Message" - EDDSA_KEYSIGN2 = "SignRound2Message" - EDDSA_KEYSIGN3 = "SignRound3Message" + EDDSA_KEYGEN1 = "KGRound1Message" + EDDSA_KEYGEN2aUnicast = "KGRound2Message1" + EDDSA_KEYGEN2b = "KGRound2Message2" + EDDSA_KEYSIGN1 = "SignRound1Message" + EDDSA_KEYSIGN2 = "SignRound2Message" + EDDSA_KEYSIGN3 = "SignRound3Message" + EDDSA_RESHARING1 = "DGRound1Message" + EDDSA_RESHARING2 = "DGRound2Message" + EDDSA_RESHARING3aUnicast = "DGRound3Message1" + EDDSA_RESHARING3bUnicast = "DGRound3Message2" + EDDSA_RESHARING4 = "DGRound4Message" + EDDSA_TSSKEYGENROUNDS = 3 EDDSA_TSSKEYSIGNROUNDS = 3 + EDDSA_RESHARINGROUNDS = 4 ) func GetEddsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) { @@ -68,6 +76,36 @@ func GetEddsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (Round RoundMsg: EDDSA_KEYSIGN3, }, nil + case *resharing.DGRound1Message: + return RoundInfo{ + Index: 0, + RoundMsg: EDDSA_RESHARING1, + }, nil + + case *resharing.DGRound2Message: + return RoundInfo{ + Index: 1, + RoundMsg: EDDSA_RESHARING2, + }, nil + + case *resharing.DGRound3Message1: + return RoundInfo{ + Index: 2, + RoundMsg: EDDSA_RESHARING3aUnicast, + }, nil + + case *resharing.DGRound3Message2: + return RoundInfo{ + Index: 3, + RoundMsg: EDDSA_RESHARING3bUnicast, + }, nil + + case *resharing.DGRound4Message: + return RoundInfo{ + Index: 4, + RoundMsg: EDDSA_RESHARING4, + }, nil + default: return RoundInfo{}, errors.New("unknown round") } diff --git a/pkg/mpc/eddsa_signing_session.go b/pkg/mpc/eddsa_signing_session.go index c421839..ea5103d 100644 --- a/pkg/mpc/eddsa_signing_session.go +++ b/pkg/mpc/eddsa_signing_session.go @@ -9,6 +9,7 @@ import ( "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" "github.com/bnb-chain/tss-lib/v2/eddsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/decred/dcrd/dcrec/edwards/v2" "github.com/fystack/mpcium/pkg/common/errors" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/identity" @@ -16,7 +17,6 @@ import ( "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/decred/dcrd/dcrec/edwards/v2" "github.com/samber/lo" ) @@ -116,7 +116,6 @@ func (s *EDDSASigningSession) Init(tx *big.Int) error { if err != nil { return errors.Wrap(err, "Failed to unmarshal wallet data") } - s.party = signing.NewLocalParty(tx, params, data, s.outCh, s.endCh) s.data = &data s.tx = tx diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index 6c105f4..f61b9c3 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -17,8 +17,9 @@ import ( ) const ( - PurposeKeygen string = "keygen" - PurposeSign string = "sign" + PurposeKeygen string = "keygen" + PurposeSign string = "sign" + PurposeResharing string = "resharing" ) type ID string @@ -39,7 +40,7 @@ type Node struct { func CreatePartyID(nodeID string, label string) *tss.PartyID { partyID := uuid.NewString() - key := big.NewInt(0).SetBytes([]byte(nodeID)) + key := big.NewInt(0).SetBytes([]byte(nodeID + ":" + label)) return tss.NewPartyID(partyID, label, key) } @@ -90,10 +91,6 @@ func (p *Node) ID() string { return p.nodeID } -func composeReadyTopic(nodeID string) string { - return fmt.Sprintf("%s-%s", nodeID, "ready") -} - func (p *Node) CreateKeyGenSession(walletID string, threshold int, successQueue messaging.MessageQueue) (*KeygenSession, error) { if p.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { return nil, fmt.Errorf("Not enough peers to create gen session! Expected %d, got %d", threshold+1, p.peerRegistry.GetReadyPeersCount()) @@ -149,7 +146,17 @@ func (p *Node) CreateSigningSession( resultQueue messaging.MessageQueue, ) (*SigningSession, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) + keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("eddsa:%s", walletID)) + if err != nil { + return nil, fmt.Errorf("failed to get key info: %w", err) + } + var selfPartyID *tss.PartyID + var allPartyIDs []*tss.PartyID + if keyInfo.IsReshared { + selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeResharing, readyPeerIDs) + } else { + selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeKeygen, readyPeerIDs) + } session := NewSigningSession( walletID, txID, @@ -177,7 +184,17 @@ func (p *Node) CreateEDDSASigningSession( resultQueue messaging.MessageQueue, ) (*EDDSASigningSession, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) + keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("eddsa:%s", walletID)) + if err != nil { + return nil, fmt.Errorf("failed to get key info: %w", err) + } + var selfPartyID *tss.PartyID + var allPartyIDs []*tss.PartyID + if keyInfo.IsReshared { + selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeResharing, readyPeerIDs) + } else { + selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeKeygen, readyPeerIDs) + } session := NewEDDSASigningSession( walletID, txID, @@ -196,6 +213,78 @@ func (p *Node) CreateEDDSASigningSession( return session, nil } +func (p *Node) CreateECDSAResharingSession(walletID string, isOldParticipant bool, readyPeerIDs []string, newThreshold int, resultQueue messaging.MessageQueue) (*ECDSAResharingSession, error) { + // Get existing key info to determine old participants + keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("ecdsa:%s", walletID)) + if err != nil { + return nil, fmt.Errorf("failed to get key info: %w", err) + } + + oldSelfPartyID, oldPartyIDs := p.generatePartyIDs(PurposeKeygen, keyInfo.ParticipantPeerIDs) + newSelfPartyID, newPartyIDs := p.generatePartyIDs(PurposeResharing, readyPeerIDs) + + var selfPartyID *tss.PartyID + if isOldParticipant { + selfPartyID = oldSelfPartyID + } else { + selfPartyID = newSelfPartyID + } + + session := ECDSANewResharingSession( + walletID, + p.pubSub, + p.direct, + readyPeerIDs, + selfPartyID, + oldPartyIDs, + newPartyIDs, + keyInfo.Threshold, + newThreshold, + p.ecdsaPreParams, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.identityStore, + isOldParticipant, + ) + return session, nil +} + +func (p *Node) CreeateEDDSAResharingSession(walletID string, isOldParticipant bool, readyPeerIDs []string, newThreshold int, resultQueue messaging.MessageQueue) (*EDDSAResharingSession, error) { + keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("eddsa:%s", walletID)) + if err != nil { + return nil, fmt.Errorf("failed to get key info: %w", err) + } + + oldSelfPartyID, oldPartyIDs := p.generatePartyIDs(PurposeKeygen, keyInfo.ParticipantPeerIDs) + newSelfPartyID, newPartyIDs := p.generatePartyIDs(PurposeResharing, readyPeerIDs) + + var selfPartyID *tss.PartyID + if isOldParticipant { + selfPartyID = oldSelfPartyID + } else { + selfPartyID = newSelfPartyID + } + + session := EDDSANewResharingSession( + walletID, + p.pubSub, + p.direct, + readyPeerIDs, + selfPartyID, + oldPartyIDs, + newPartyIDs, + keyInfo.Threshold, + newThreshold, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.identityStore, + isOldParticipant, + ) + return session, nil +} + func (p *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { var selfPartyID *tss.PartyID partyIDs := make([]*tss.PartyID, len(readyPeerIDs)) @@ -217,3 +306,15 @@ func (p *Node) Close() { logger.Error("Resign failed", err) } } + +func (p *Node) GetKeyInfo(key string) (*keyinfo.KeyInfo, error) { + return p.keyinfoStore.Get(key) +} + +func (p *Node) GetReadyPeersIncludeSelf() []string { + return p.peerRegistry.GetReadyPeersIncludeSelf() +} + +func (p *Node) GetKVStore() kvstore.KVStore { + return p.kvstore +} diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go index f8204f6..76994c2 100644 --- a/pkg/mpc/session.go +++ b/pkg/mpc/session.go @@ -2,6 +2,7 @@ package mpc import ( "fmt" + "slices" "strings" "sync" @@ -21,13 +22,7 @@ var ( ErrNotEnoughParticipants = errors.New("Not enough participants to sign") ) -type TopicComposer struct { - ComposeBroadcastTopic func() string - ComposeDirectTopic func(nodeID string) string -} - -type KeyComposerFn func(id string) string - +// SessionType constants type SessionType string const ( @@ -35,6 +30,25 @@ const ( SessionTypeEddsa SessionType = "session_eddsa" ) +// Topic format constants +const ( + TopicFormatResharingBroadcast = "resharing:broadcast:%s:%s" + TopicFormatResharingDirect = "resharing:direct:%s:%s:%s" +) + +// Key format constants +const ( + KeyFormatEcdsa = "ecdsa:%s" + KeyFormatEddsa = "eddsa:%s" +) + +type TopicComposer struct { + ComposeBroadcastTopic func() string + ComposeDirectTopic func(nodeID string) string +} + +type KeyComposerFn func(id string) string + type Session struct { walletID string pubSub messaging.PubSub @@ -49,7 +63,9 @@ type Session struct { party tss.Party // preParams is nil for EDDSA session - preParams *keygen.LocalPreParams + preParams *keygen.LocalPreParams + // reshareParams is nil for non resharing session + reshareParams *tss.ReSharingParameters kvstore kvstore.KVStore keyinfoStore keyinfo.Store broadcastSub messaging.Subscription @@ -84,7 +100,6 @@ func (s *Session) handleTssMessage(keyshare tss.Message) { s.ErrCh <- err return } - tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) signature, err := s.identityStore.SignMessage(&tssMsg) if err != nil { @@ -118,6 +133,34 @@ func (s *Session) handleTssMessage(keyshare tss.Message) { } } +func (s *Session) handleResharingMessage(msg tss.Message) { + data, routing, err := msg.WireBytes() + if err != nil { + s.ErrCh <- err + return + } + + tssMsg := types.NewTssResharingMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To, routing.IsToOldCommittee, routing.IsToOldAndNewCommittees) + signature, err := s.identityStore.SignMessage(&tssMsg) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to sign message: %w", err) + return + } + tssMsg.Signature = signature + msgBytes, err := types.MarshalTssMessage(&tssMsg) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) + return + } + + // Just send to all intended recipients except self + for _, to := range routing.To { + if to.Id != s.selfPartyID.Id { + s.direct.Send(s.topicComposer.ComposeDirectTopic(PartyIDToNodeID(to)), msgBytes) + } + } +} + func (s *Session) receiveTssMessage(rawMsg []byte) { msg, err := types.UnmarshalTssMessage(rawMsg) if err != nil { @@ -141,7 +184,12 @@ func (s *Session) receiveTssMessage(rawMsg []byte) { return } - logger.Debug(fmt.Sprintf("%s Received message", s.sessionType), "from", msg.From.String(), "to", strings.Join(toIDs, ","), "isBroadcast", msg.IsBroadcast, "round", round.RoundMsg) + logger.Info(fmt.Sprintf("%s Received message", s.sessionType), + "from", msg.From.String(), + "to", strings.Join(toIDs, ","), + "isBroadcast", msg.IsBroadcast, + "round", round.RoundMsg) + isBroadcast := msg.IsBroadcast && len(msg.To) == 0 isToSelf := len(msg.To) == 1 && ComparePartyIDs(msg.To[0], s.selfPartyID) @@ -153,7 +201,46 @@ func (s *Session) receiveTssMessage(rawMsg []byte) { logger.Error("Failed to update party", err, "walletID", s.walletID) return } + } +} +func (s *Session) receiveTssResharingMessage(rawMsg []byte) { + msg, err := types.UnmarshalTssMessage(rawMsg) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to unmarshal message: %w", err) + return + } + err = s.identityStore.VerifyMessage(msg) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to verify message: %w, tampered message", err) + return + } + + toIDs := make([]string, len(msg.To)) + for i, id := range msg.To { + toIDs[i] = id.String() + } + round, err := s.getRoundFunc(msg.MsgBytes, s.selfPartyID, msg.IsBroadcast) + if err != nil { + s.ErrCh <- errors.Wrap(err, "Broken TSS Share") + return + } + + logger.Info(fmt.Sprintf("%s Received resharing message", s.sessionType), + "from", msg.From.String(), + "to", strings.Join(toIDs, ","), + "isBroadcast", msg.IsBroadcast, + "round", round.RoundMsg) + + isToSelf := slices.Contains(toIDs, s.selfPartyID.String()) + if isToSelf { + s.mu.Lock() + defer s.mu.Unlock() + ok, err := s.party.UpdateFromBytes(msg.MsgBytes, msg.From, msg.IsBroadcast) + if !ok || err != nil { + logger.Error("Failed to update party", err, "walletID", s.walletID) + return + } } } @@ -197,14 +284,30 @@ func (s *Session) ListenToIncomingMessageAsync() { } -func (s *Session) Close() error { - err := s.broadcastSub.Unsubscribe() +func (s *Session) ListenToIncomingResharingMessageAsync() { + nodeID := PartyIDToNodeID(s.selfPartyID) + targetID := s.topicComposer.ComposeDirectTopic(nodeID) + sub, err := s.direct.Listen(targetID, func(msg []byte) { + go s.receiveTssResharingMessage(msg) // async for avoid timeout + }) if err != nil { - return err + s.ErrCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", targetID, err) } - err = s.directSub.Unsubscribe() - if err != nil { - return err + s.directSub = sub +} + +func (s *Session) Close() error { + if s.broadcastSub != nil { + err := s.broadcastSub.Unsubscribe() + if err != nil { + return err + } + } + if s.directSub != nil { + err := s.directSub.Unsubscribe() + if err != nil { + return err + } } return nil } @@ -216,3 +319,29 @@ func (s *Session) GetPubKeyResult() []byte { func (s *Session) ErrChan() <-chan error { return s.ErrCh } + +// SaveKeyInfo saves the key info with resharing information +func (s *Session) SaveKeyInfo(isReshared bool) error { + keyInfo := &keyinfo.KeyInfo{ + ParticipantPeerIDs: s.participantPeerIDs, + Threshold: s.threshold, + IsReshared: isReshared, + } + + err := s.keyinfoStore.Save(s.composeKey(s.walletID), keyInfo) + if err != nil { + logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) + return err + } + return nil +} + +// SaveKeyData saves the key data to the kvstore +func (s *Session) SaveKeyData(keyBytes []byte) error { + err := s.kvstore.Put(s.composeKey(s.walletID), keyBytes) + if err != nil { + logger.Error("Failed to save key", err, "walletID", s.walletID) + return err + } + return nil +} diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index edd0bf4..b49b768 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -33,6 +33,36 @@ type SignTxMessage struct { Signature []byte `json:"signature"` } +type ResharingMessage struct { + WalletID string `json:"wallet_id"` + NewThreshold int `json:"new_threshold"` + Signature []byte `json:"signature"` + KeyType KeyType `json:"key_type"` +} + +// InitiatorID implements InitiatorMessage. +func (r *ResharingMessage) InitiatorID() string { + return r.WalletID +} + +// Raw implements InitiatorMessage. +func (r *ResharingMessage) Raw() ([]byte, error) { + // Create a struct with only the fields that should be signed + payload := struct { + WalletID string `json:"wallet_id"` + NewThreshold int `json:"new_threshold"` + }{ + WalletID: r.WalletID, + NewThreshold: r.NewThreshold, + } + return json.Marshal(payload) +} + +// Sig implements InitiatorMessage. +func (r *ResharingMessage) Sig() []byte { + return r.Signature +} + func (m *SignTxMessage) Raw() ([]byte, error) { // omit the Signature field itself when computing the signed‐over data payload := struct { diff --git a/pkg/types/tss.go b/pkg/types/tss.go index 6d61e9c..94559b5 100644 --- a/pkg/types/tss.go +++ b/pkg/types/tss.go @@ -41,6 +41,28 @@ func NewTssMessage( return tssMsg } +func NewTssResharingMessage( + walletID string, + msgBytes []byte, + isBroadcast bool, + from *tss.PartyID, + to []*tss.PartyID, + isToOldCommittee bool, + isToOldAndNewCommittees bool, +) TssMessage { + tssMsg := TssMessage{ + WalletID: walletID, + IsBroadcast: isBroadcast, + MsgBytes: msgBytes, + From: from, + To: to, + IsToOldCommittee: isToOldCommittee, + IsToOldAndNewCommittees: isToOldAndNewCommittees, + } + + return tssMsg +} + func MarshalTssMessage(tssMsg *TssMessage) ([]byte, error) { msgBytes, err := json.Marshal(tssMsg) if err != nil { diff --git a/scripts/migration/add-key-type/main.go b/scripts/migration/add-key-type/main.go index 9891243..a2da004 100644 --- a/scripts/migration/add-key-type/main.go +++ b/scripts/migration/add-key-type/main.go @@ -11,7 +11,7 @@ import ( ) func main() { - logger.Init("production") + logger.Init("production", true) nodeName := flag.String("name", "", "Provide node name") flag.Parse() if *nodeName == "" { diff --git a/scripts/migration/update-keyinfo/main.go b/scripts/migration/update-keyinfo/main.go index 704ff5a..e68134b 100644 --- a/scripts/migration/update-keyinfo/main.go +++ b/scripts/migration/update-keyinfo/main.go @@ -13,7 +13,7 @@ import ( // script to add key type prefix ecdsa for existing keys func main() { config.InitViperConfig() - logger.Init("production") + logger.Init("production", true) appConfig := config.LoadConfig() logger.Info("App config", "config", appConfig) diff --git a/setup_identities.sh b/setup_identities.sh new file mode 100755 index 0000000..53ed0a3 --- /dev/null +++ b/setup_identities.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# Number of nodes to create (default is 3) +NUM_NODES=3 + +echo "πŸš€ Setting up Node Identities..." + +# Create node directories and copy config files +echo "πŸ“ Creating node directories..." +for i in $(seq 0 $((NUM_NODES-1))); do + mkdir -p "node$i/identity" + if [ ! -f "node$i/config.yaml" ]; then + cp config.yaml "node$i/" + fi + if [ ! -f "node$i/peers.json" ]; then + cp peers.json "node$i/" + fi +done + +# Generate identity for each node +echo "πŸ”‘ Generating identities for each node..." +for i in $(seq 0 $((NUM_NODES-1))); do + echo "πŸ“ Generating identity for node$i..." + cd "node$i" + mpcium-cli generate-identity --node "node$i" + cd .. +done + +# Distribute identity files to all nodes +echo "πŸ”„ Distributing identity files across nodes..." +for i in $(seq 0 $((NUM_NODES-1))); do + for j in $(seq 0 $((NUM_NODES-1))); do + if [ $i != $j ]; then + echo "πŸ“‹ Copying node${i}_identity.json to node$j..." + cp "node$i/identity/node${i}_identity.json" "node$j/identity/" + fi + done +done + +echo "✨ Node identities setup complete!" +echo +echo "πŸ“‚ Created folder structure:" +echo "β”œβ”€β”€ node0" +echo "β”‚ β”œβ”€β”€ config.yaml" +echo "β”‚ β”œβ”€β”€ identity/" +echo "β”‚ └── peers.json" +echo "β”œβ”€β”€ node1" +echo "β”‚ β”œβ”€β”€ config.yaml" +echo "β”‚ β”œβ”€β”€ identity/" +echo "β”‚ └── peers.json" +echo "└── node2" +echo " β”œβ”€β”€ config.yaml" +echo " β”œβ”€β”€ identity/" +echo " └── peers.json" +echo +echo "βœ… You can now start your nodes with:" +echo "cd node0 && mpcium start -n node0" +echo "cd node1 && mpcium start -n node1" +echo "cd node2 && mpcium start -n node2" \ No newline at end of file diff --git a/setup_initiator.sh b/setup_initiator.sh new file mode 100755 index 0000000..de37e07 --- /dev/null +++ b/setup_initiator.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +echo "πŸš€ Setting up Event Initiator..." + +# Generate the event initiator +echo "πŸ“ Generating event initiator..." +mpcium-cli generate-initiator + +# Extract the public key from the generated file +if [ -f "event_initiator.identity.json" ]; then + PUBLIC_KEY=$(grep -o '"public_key": *"[^"]*"' event_initiator.identity.json | cut -d'"' -f4) + + if [ -n "$PUBLIC_KEY" ]; then + echo "πŸ”‘ Found public key: $PUBLIC_KEY" + + # Update config.yaml + if [ -f "config.yaml" ]; then + echo "πŸ“ Updating config.yaml..." + # Check if event_initiator_pubkey already exists + if grep -q "event_initiator_pubkey:" config.yaml; then + # Replace existing line + sed -i "s/event_initiator_pubkey: .*/event_initiator_pubkey: \"$PUBLIC_KEY\"/" config.yaml + else + # Add new line + echo "event_initiator_pubkey: \"$PUBLIC_KEY\"" >> config.yaml + fi + echo "βœ… Successfully updated config.yaml" + else + echo "❌ Error: config.yaml not found. Please create it first." + exit 1 + fi + else + echo "❌ Error: Could not extract public key from event_initiator.identity.json" + exit 1 + fi +else + echo "❌ Error: event_initiator.identity.json not found" + exit 1 +fi + +echo "✨ Event Initiator setup complete!" \ No newline at end of file From d731d2a91042138a2d42680dba3b31ab93da5fdd Mon Sep 17 00:00:00 2001 From: vietddude Date: Sat, 7 Jun 2025 01:18:59 +0700 Subject: [PATCH 02/34] Adds MPC party package --- pkg/mpc/party/base.go | 53 +++++++++++++++++++++++++ pkg/mpc/party/ecdsa.go | 88 ++++++++++++++++++++++++++++++++++++++++++ pkg/mpc/party/eddsa.go | 86 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 227 insertions(+) create mode 100644 pkg/mpc/party/base.go create mode 100644 pkg/mpc/party/ecdsa.go create mode 100644 pkg/mpc/party/eddsa.go diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go new file mode 100644 index 0000000..a5a4c00 --- /dev/null +++ b/pkg/mpc/party/base.go @@ -0,0 +1,53 @@ +package party + +import ( + "context" + "encoding/json" + + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" +) + +type party struct { + walletID string + localParty tss.Party + partyID *tss.PartyID + partyIDs []*tss.PartyID + threshold int +} + +type PartyInterface interface { + PartyID() *tss.PartyID + GetOutCh() chan tss.Message + UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) +} + +func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int) *party { + return &party{walletID, nil, partyID, partyIDs, threshold} +} + +// runParty handles the common party execution loop +func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh <-chan T, finish func([]byte)) { + go func() { + if err := party.Start(); err != nil { + logger.Error("Failed to start party", err) + } + }() + + for { + select { + case <-ctx.Done(): + return + case msg := <-s.GetOutCh(): + send(msg) + case result := <-endCh: + bz, err := json.Marshal(result) + if err != nil { + logger.Error("Failed to marshal result", err) + return + } + finish(bz) + return + } + } +} diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go new file mode 100644 index 0000000..cb481ad --- /dev/null +++ b/pkg/mpc/party/ecdsa.go @@ -0,0 +1,88 @@ +package party + +import ( + "context" + "errors" + "math/big" + + "github.com/bnb-chain/tss-lib/v2/common" + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" + "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" +) + +type ECDSASession struct { + party + prepareParams keygen.LocalPreParams + reshareParams *tss.ReSharingParameters + saveData *keygen.LocalPartySaveData + outCh chan tss.Message +} + +func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, + prepareParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData) *ECDSASession { + return &ECDSASession{ + party: *NewParty(walletID, partyID, partyIDs, threshold), + prepareParams: prepareParams, + reshareParams: reshareParams, + saveData: saveData, + outCh: make(chan tss.Message, 1000), + } +} + +func (s *ECDSASession) PartyID() *tss.PartyID { + return s.partyID +} + +func (s *ECDSASession) GetOutCh() chan tss.Message { + return s.outCh +} + +func (s *ECDSASession) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { + ok, err := s.localParty.UpdateFromBytes(msgBytes, from, isBroadcast) + if err != nil { + return false, err + } + return ok, nil +} + +func (s *ECDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + end := make(chan *keygen.LocalPartySaveData) + params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := keygen.NewLocalParty(params, s.outCh, end, s.prepareParams) + runParty(s, ctx, party, send, end, finish) +} + +func (s *ECDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + logger.Error("Save data is nil", errors.New("save data is nil")) + return + } + end := make(chan *common.SignatureData) + params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := signing.NewLocalParty(msg, params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} + +func (s *ECDSASession) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, + oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + logger.Error("Save data is nil", errors.New("save data is nil")) + return + } + end := make(chan *keygen.LocalPartySaveData) + params := tss.NewReSharingParameters( + tss.S256(), + tss.NewPeerContext(oldPartyIDs), + tss.NewPeerContext(newPartyIDs), + s.partyID, + len(oldPartyIDs), + len(newPartyIDs), + oldThreshold, + newThreshold, + ) + party := resharing.NewLocalParty(params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go new file mode 100644 index 0000000..e5c3331 --- /dev/null +++ b/pkg/mpc/party/eddsa.go @@ -0,0 +1,86 @@ +package party + +import ( + "context" + "errors" + "math/big" + + "github.com/bnb-chain/tss-lib/v2/common" + "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" + "github.com/bnb-chain/tss-lib/v2/eddsa/signing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" +) + +type EDDSASession struct { + party + reshareParams *tss.ReSharingParameters + saveData *keygen.LocalPartySaveData + outCh chan tss.Message +} + +func NewEDDASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, + reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData) *EDDSASession { + return &EDDSASession{ + party: *NewParty(walletID, partyID, partyIDs, threshold), + reshareParams: reshareParams, + saveData: saveData, + outCh: make(chan tss.Message, 1000), + } +} + +func (s *EDDSASession) PartyID() *tss.PartyID { + return s.partyID +} + +func (s *EDDSASession) GetOutCh() chan tss.Message { + return s.outCh +} + +func (s *EDDSASession) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { + ok, err := s.localParty.UpdateFromBytes(msgBytes, from, isBroadcast) + if err != nil { + return false, err + } + return ok, nil +} + +func (s *EDDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + end := make(chan *keygen.LocalPartySaveData) + params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := keygen.NewLocalParty(params, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} + +func (s *EDDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + logger.Error("Save data is nil", errors.New("save data is nil")) + return + } + end := make(chan *common.SignatureData) + params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + party := signing.NewLocalParty(msg, params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} + +func (s *EDDSASession) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, + oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { + if s.saveData == nil { + logger.Error("Save data is nil", errors.New("save data is nil")) + return + } + end := make(chan *keygen.LocalPartySaveData) + params := tss.NewReSharingParameters( + tss.S256(), + tss.NewPeerContext(oldPartyIDs), + tss.NewPeerContext(newPartyIDs), + s.partyID, + len(oldPartyIDs), + len(newPartyIDs), + oldThreshold, + newThreshold, + ) + party := resharing.NewLocalParty(params, *s.saveData, s.outCh, end) + runParty(s, ctx, party, send, end, finish) +} From fd70149c6ba7ae0a97acf91f9907886ea580705a Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 9 Jun 2025 23:58:38 +0700 Subject: [PATCH 03/34] Refactors MPC party handling into sessions Introduces a session-based architecture for managing MPC party interactions. This change removes the direct logging from the party implementations and introduces an error channel for handling errors during the MPC process, allowing errors to be propagated to the session layer. This change also introduces base session structs and implements EcdsaSession. --- pkg/mpc/party/base.go | 15 ++- pkg/mpc/party/ecdsa.go | 28 ++--- pkg/mpc/party/eddsa.go | 25 ++-- pkg/mpc/session/base.go | 213 +++++++++++++++++++++++++++++++++++ pkg/mpc/session/constants.go | 6 + pkg/mpc/session/ecdsa.go | 22 ++++ pkg/mpc/session/eddsa.go | 0 pkg/mpc/session/utils.go | 7 ++ 8 files changed, 284 insertions(+), 32 deletions(-) create mode 100644 pkg/mpc/session/base.go create mode 100644 pkg/mpc/session/constants.go create mode 100644 pkg/mpc/session/ecdsa.go create mode 100644 pkg/mpc/session/eddsa.go create mode 100644 pkg/mpc/session/utils.go diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index a5a4c00..00b305b 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -5,7 +5,6 @@ import ( "encoding/json" "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/logger" ) type party struct { @@ -14,23 +13,29 @@ type party struct { partyID *tss.PartyID partyIDs []*tss.PartyID threshold int + errCh chan error } type PartyInterface interface { PartyID() *tss.PartyID GetOutCh() chan tss.Message UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) + GetErrCh() chan error } -func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int) *party { - return &party{walletID, nil, partyID, partyIDs, threshold} +func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, errCh chan error) *party { + return &party{walletID, nil, partyID, partyIDs, threshold, errCh} +} + +func (p *party) GetErrCh() chan error { + return p.errCh } // runParty handles the common party execution loop func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh <-chan T, finish func([]byte)) { go func() { if err := party.Start(); err != nil { - logger.Error("Failed to start party", err) + s.GetErrCh() <- err } }() @@ -43,7 +48,7 @@ func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, sen case result := <-endCh: bz, err := json.Marshal(result) if err != nil { - logger.Error("Failed to marshal result", err) + s.GetErrCh() <- err return } finish(bz) diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index cb481ad..e45b0b7 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -10,10 +10,9 @@ import ( "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/logger" ) -type ECDSASession struct { +type ECDSAParty struct { party prepareParams keygen.LocalPreParams reshareParams *tss.ReSharingParameters @@ -21,10 +20,10 @@ type ECDSASession struct { outCh chan tss.Message } -func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, - prepareParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData) *ECDSASession { - return &ECDSASession{ - party: *NewParty(walletID, partyID, partyIDs, threshold), +func NewECDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, + prepareParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData, errCh chan error) *ECDSAParty { + return &ECDSAParty{ + party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), prepareParams: prepareParams, reshareParams: reshareParams, saveData: saveData, @@ -32,32 +31,33 @@ func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.Part } } -func (s *ECDSASession) PartyID() *tss.PartyID { +func (s *ECDSAParty) PartyID() *tss.PartyID { return s.partyID } -func (s *ECDSASession) GetOutCh() chan tss.Message { +func (s *ECDSAParty) GetOutCh() chan tss.Message { return s.outCh } -func (s *ECDSASession) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { +func (s *ECDSAParty) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { ok, err := s.localParty.UpdateFromBytes(msgBytes, from, isBroadcast) if err != nil { + s.GetErrCh() <- err return false, err } return ok, nil } -func (s *ECDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { +func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData) params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := keygen.NewLocalParty(params, s.outCh, end, s.prepareParams) runParty(s, ctx, party, send, end, finish) } -func (s *ECDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { +func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - logger.Error("Save data is nil", errors.New("save data is nil")) + s.GetErrCh() <- errors.New("save data is nil") return } end := make(chan *common.SignatureData) @@ -66,10 +66,10 @@ func (s *ECDSASession) StartSigning(ctx context.Context, msg *big.Int, send func runParty(s, ctx, party, send, end, finish) } -func (s *ECDSASession) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, +func (s *ECDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - logger.Error("Save data is nil", errors.New("save data is nil")) + s.GetErrCh() <- errors.New("save data is nil") return } end := make(chan *keygen.LocalPartySaveData) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index e5c3331..743071d 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -10,10 +10,9 @@ import ( "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" "github.com/bnb-chain/tss-lib/v2/eddsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/logger" ) -type EDDSASession struct { +type EDDSAParty struct { party reshareParams *tss.ReSharingParameters saveData *keygen.LocalPartySaveData @@ -21,24 +20,24 @@ type EDDSASession struct { } func NewEDDASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, - reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData) *EDDSASession { - return &EDDSASession{ - party: *NewParty(walletID, partyID, partyIDs, threshold), + reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData, errCh chan error) *EDDSAParty { + return &EDDSAParty{ + party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), reshareParams: reshareParams, saveData: saveData, outCh: make(chan tss.Message, 1000), } } -func (s *EDDSASession) PartyID() *tss.PartyID { +func (s *EDDSAParty) PartyID() *tss.PartyID { return s.partyID } -func (s *EDDSASession) GetOutCh() chan tss.Message { +func (s *EDDSAParty) GetOutCh() chan tss.Message { return s.outCh } -func (s *EDDSASession) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { +func (s *EDDSAParty) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { ok, err := s.localParty.UpdateFromBytes(msgBytes, from, isBroadcast) if err != nil { return false, err @@ -46,16 +45,16 @@ func (s *EDDSASession) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBro return ok, nil } -func (s *EDDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { +func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData) params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := keygen.NewLocalParty(params, s.outCh, end) runParty(s, ctx, party, send, end, finish) } -func (s *EDDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { +func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - logger.Error("Save data is nil", errors.New("save data is nil")) + s.GetErrCh() <- errors.New("save data is nil") return } end := make(chan *common.SignatureData) @@ -64,10 +63,10 @@ func (s *EDDSASession) StartSigning(ctx context.Context, msg *big.Int, send func runParty(s, ctx, party, send, end, finish) } -func (s *EDDSASession) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, +func (s *EDDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - logger.Error("Save data is nil", errors.New("save data is nil")) + s.GetErrCh() <- errors.New("save data is nil") return } end := make(chan *keygen.LocalPartySaveData) diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go new file mode 100644 index 0000000..f218acd --- /dev/null +++ b/pkg/mpc/session/base.go @@ -0,0 +1,213 @@ +package session + +import ( + "fmt" + "slices" + "sync" + + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/party" + "github.com/fystack/mpcium/pkg/types" +) + +type Curve string + +type Purpose string + +const ( + CurveSecp256k1 Curve = "secp256k1" + CurveEd25519 Curve = "ed25519" + + PurposeKeygen Purpose = "keygen" + PurposeSign Purpose = "sign" + PurposeReshare Purpose = "reshare" +) + +type TopicComposer struct { + ComposeBroadcastTopic func() string + ComposeDirectTopic func(nodeID string) string +} + +type KeyComposerFn func(id string) string + +type session struct { + walletID string + party party.PartyInterface + + broadcastSub messaging.Subscription + directSub messaging.Subscription + pubSub messaging.PubSub + direct messaging.DirectMessaging + + identityStore identity.Store + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + + topicComposer *TopicComposer + composeKey KeyComposerFn + mu sync.Mutex +} + +func NewSession( + curve Curve, + purpose Purpose, + walletID string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + identityStore identity.Store, + kvstore kvstore.KVStore, +) *session { + return &session{ + walletID: walletID, + pubSub: pubSub, + direct: direct, + identityStore: identityStore, + kvstore: kvstore, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf(KeygenBroadcastTopic, walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf(KeygenDirectTopic, nodeID, walletID) + }, + }, + composeKey: func(id string) string { + return fmt.Sprintf("%s/%s", purpose, id) + }, + } +} + +func (s *session) SetParty(party party.PartyInterface) { + s.party = party +} + +func (s *session) Send(msg tss.Message) { + data, routing, err := msg.WireBytes() + if err != nil { + logger.Error("Failed to wire bytes", err) + return + } + + tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) + signature, err := s.identityStore.SignMessage(&tssMsg) + if err != nil { + logger.Error("Failed to sign message", err) + return + } + tssMsg.Signature = signature + msgBytes, err := types.MarshalTssMessage(&tssMsg) + if err != nil { + logger.Error("Failed to marshal message", err) + return + } + + if routing.IsBroadcast && len(routing.To) == 0 { + err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msgBytes) + if err != nil { + logger.Error("Failed to publish message", err) + return + } + } else { + for _, to := range routing.To { + nodeID := partyIDToNodeID(to) + topic := s.topicComposer.ComposeDirectTopic(nodeID) + err := s.direct.Send(topic, msgBytes) + if err != nil { + logger.Error("Failed to send message", err) + return + } + } + } +} + +func (s *session) Receive(rawMsg []byte) { + msg, err := types.UnmarshalTssMessage(rawMsg) + if err != nil { + logger.Error("Failed to unmarshal message", err) + return + } + + err = s.identityStore.VerifyMessage(msg) + if err != nil { + logger.Error("Failed to verify message", err) + return + } + + toIDs := make([]string, len(msg.To)) + for i, id := range msg.To { + toIDs[i] = id.String() + } + + isBroadcast := msg.IsBroadcast && len(msg.To) == 0 + isToSelf := slices.Contains(toIDs, s.party.PartyID().String()) + + if isBroadcast || isToSelf { + s.mu.Lock() + defer s.mu.Unlock() + ok, err := s.party.UpdateFromBytes(msg.MsgBytes, msg.From, msg.IsBroadcast) + if !ok || err != nil { + logger.Error("Failed to update party", err) + return + } + } +} + +// func (s *session) Listen() { +// broadcast := func() { +// sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { +// msg := natMsg.Data +// s.receiveTssMessage(msg) +// }) + +// if err != nil { +// s.ErrCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) +// return +// } + +// s.broadcastSub = sub +// } + +// direct := func() { +// sub, err := s.direct.Listen(s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), func(msg []byte) { +// s.receiveTssMessage(msg) +// }) + +// if err != nil { +// s.ErrCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), err) +// return +// } + +// s.directSub = sub +// } + +// go broadcast() +// go direct() +// } + +func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) { + + keyInfo := keyinfo.KeyInfo{ + ParticipantPeerIDs: participantPeerIDs, + Threshold: threshold, + IsReshared: isReshared, + } + + err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) + if err != nil { + logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) + return + } + + err = s.kvstore.Put(s.composeKey(s.walletID), data) + if err != nil { + logger.Error("Failed to save key", err, "walletID", s.walletID) + return + } + + return nil +} diff --git a/pkg/mpc/session/constants.go b/pkg/mpc/session/constants.go new file mode 100644 index 0000000..745cd04 --- /dev/null +++ b/pkg/mpc/session/constants.go @@ -0,0 +1,6 @@ +package session + +const ( + KeygenBroadcastTopic = "keygen:broadcast:%s" + KeygenDirectTopic = "keygen:direct:%s:%s" +) \ No newline at end of file diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go new file mode 100644 index 0000000..88117ad --- /dev/null +++ b/pkg/mpc/session/ecdsa.go @@ -0,0 +1,22 @@ +package session + +import ( + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/party" +) + +type EcdsaSession struct { + *session +} + +func NewECDSASession(walletID string, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore) *ECDSASession { + s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore) + party := party.NewECDSAParty(walletID, s.PartyID(), s.PartyIDs(), s.threshold, s.prepareParams, s.reshareParams, s.saveData) + s.SetParty(party) + return &ECDSASession{ + session: s, + party: party, + } +} diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go new file mode 100644 index 0000000..e69de29 diff --git a/pkg/mpc/session/utils.go b/pkg/mpc/session/utils.go new file mode 100644 index 0000000..ad17c57 --- /dev/null +++ b/pkg/mpc/session/utils.go @@ -0,0 +1,7 @@ +package session + +import "github.com/bnb-chain/tss-lib/v2/tss" + +func partyIDToNodeID(partyID *tss.PartyID) string { + return string(partyID.KeyInt().Bytes()) +} From d53a2bdab1a8e66a5a4b54fac2debcefa0a06e42 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 10 Jun 2025 17:58:18 +0700 Subject: [PATCH 04/34] Refactor MPC package to use node-based architecture This commit refactors the MPC package to introduce a node-based architecture, enhancing the organization and clarity of the code. The changes include: - Replacing direct references to the `mpc` package with a new `node` package for better modularity. - Implementing a `Node` struct that encapsulates peer management and session handling. - Updating event consumer and session management to utilize the new node structure, improving error handling and message processing. - Removing unused methods and enhancing logging for better traceability. These changes aim to streamline the MPC process and improve maintainability. --- cmd/mpcium/main.go | 9 +- pkg/eventconsumer/event_consumer.go | 720 +++++++++++++--------------- pkg/mpc/node/node.go | 130 +++++ pkg/mpc/node/registry.go | 207 ++++++++ pkg/mpc/party/base.go | 51 +- pkg/mpc/party/ecdsa.go | 24 +- pkg/mpc/party/eddsa.go | 8 +- pkg/mpc/session/base.go | 99 ++-- pkg/mpc/session/ecdsa.go | 16 +- pkg/mpc/session/eddsa.go | 1 + pkg/types/initiator_msg.go | 2 +- 11 files changed, 776 insertions(+), 491 deletions(-) create mode 100644 pkg/mpc/node/node.go create mode 100644 pkg/mpc/node/registry.go diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 33812ff..563fa12 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -18,7 +18,7 @@ import ( "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/mpc/node" "github.com/hashicorp/consul/api" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -145,19 +145,18 @@ func runNode(ctx context.Context, c *cli.Command) error { logger.Info("Node is running", "peerID", nodeID, "name", nodeName) peerNodeIDs := GetPeerIDs(peers) - peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV()) + peerRegistry := node.NewRegistry(nodeID, peerNodeIDs, consulClient.KV()) - mpcNode := mpc.NewNode( + mpcNode := node.NewNode( nodeID, peerNodeIDs, pubsub, directMessaging, badgerKV, keyinfoStore, - peerRegistry, identityStore, + peerRegistry, ) - defer mpcNode.Close() eventConsumer := eventconsumer.NewEventConsumer( mpcNode, diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 8cb650f..f3e3352 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -3,18 +3,15 @@ package eventconsumer import ( "context" "encoding/json" - "errors" "fmt" "log" - "math/big" "sync" "time" - "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/mpc/node" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -32,7 +29,7 @@ type EventConsumer interface { } type eventConsumer struct { - node *mpc.Node + node *node.Node pubsub messaging.PubSub mpcThreshold int @@ -54,7 +51,7 @@ type eventConsumer struct { } func NewEventConsumer( - node *mpc.Node, + node *node.Node, pubsub messaging.PubSub, genKeySucecssQueue messaging.MessageQueue, signingResultQueue messaging.MessageQueue, @@ -87,15 +84,15 @@ func (ec *eventConsumer) Run() { log.Fatal("Failed to consume key reconstruction event", err) } - err = ec.consumeTxSigningEvent() - if err != nil { - log.Fatal("Failed to consume tx signing event", err) - } + // err = ec.consumeTxSigningEvent() + // if err != nil { + // log.Fatal("Failed to consume tx signing event", err) + // } - err = ec.consumeResharingEvent() - if err != nil { - log.Fatal("Failed to consume resharing event", err) - } + // err = ec.consumeResharingEvent() + // if err != nil { + // log.Fatal("Failed to consume resharing event", err) + // } logger.Info("MPC Event consumer started...!") } @@ -118,404 +115,349 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { } walletID := msg.WalletID - session, err := ec.node.CreateKeyGenSession(walletID, ec.mpcThreshold, ec.genKeySucecssQueue) - if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) - return - } - eddsaSession, err := ec.node.CreateEDDSAKeyGenSession(walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + session, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) if err != nil { logger.Error("Failed to create key generation session", err, "walletID", walletID) return } - session.Init() - eddsaSession.Init() - - ctx, done := context.WithCancel(context.Background()) - ctxEddsa, doneEddsa := context.WithCancel(context.Background()) - - successEvent := &mpc.KeygenSuccessEvent{ - WalletID: walletID, - } - - var wg sync.WaitGroup - wg.Add(2) - go func() { - for { - select { - case <-ctx.Done(): - successEvent.ECDSAPubKey = session.GetPubKeyResult() - wg.Done() - return - case err := <-session.ErrCh: - logger.Error("Keygen session error", err) - } - } - }() - - go func() { - for { - select { - case <-ctxEddsa.Done(): - successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() - wg.Done() - return - case err := <-eddsaSession.ErrCh: - logger.Error("Keygen session error", err) - } - } - }() - - session.ListenToIncomingMessageAsync() - eddsaSession.ListenToIncomingMessageAsync() - // TODO: replace sleep with distributed lock - time.Sleep(1 * time.Second) - - go session.GenerateKey(done) - go eddsaSession.GenerateKey(doneEddsa) - - wg.Wait() - logger.Info("Closing session successfully!", "event", successEvent) - - successEventBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal keygen success event", err) - return - } - - err = ec.genKeySucecssQueue.Enqueue(fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), + go session.StartKeygen(context.Background(), session.Send, func(data []byte) { + logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) }) - if err != nil { - logger.Error("Failed to publish key generation success message", err) - return - } - + go session.Listen() logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) - - }) - - ec.keyGenerationSub = sub - if err != nil { - return err - } - return nil -} - -func (ec *eventConsumer) consumeTxSigningEvent() error { - sub, err := ec.pubsub.Subscribe(MPCSignEvent, func(natMsg *nats.Msg) { - raw := natMsg.Data - var msg types.SignTxMessage - err := json.Unmarshal(raw, &msg) - if err != nil { - logger.Error("Failed to unmarshal signing message", err) - return - } - - err = ec.identityStore.VerifyInitiatorMessage(&msg) - if err != nil { - logger.Error("Failed to verify initiator message", err) - return - } - - logger.Info( - "Received signing event", - "waleltID", - msg.WalletID, - "type", - msg.KeyType, - "tx", - msg.TxID, - "Id", - ec.node.ID(), - ) - - // Check for duplicate session and track if new - if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { - natMsg.Term() - return - } - - var session mpc.ISigningSession - switch msg.KeyType { - case types.KeyTypeSecp256k1: - session, err = ec.node.CreateSigningSession( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - ec.mpcThreshold, - ec.signingResultQueue, - ) - case types.KeyTypeEd25519: - session, err = ec.node.CreateEDDSASigningSession( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - ec.mpcThreshold, - ec.signingResultQueue, - ) - - } - - if err != nil { - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to create signing session", - natMsg, - ) - return - } - - txBigInt := new(big.Int).SetBytes(msg.Tx) - err = session.Init(txBigInt) - if err != nil { - if errors.Is(err, mpc.ErrNotEnoughParticipants) { - logger.Info("RETRY LATER: Not enough participants to sign") - //Return for retry later - return - } - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to init signing session", - natMsg, - ) - return - } - - // Mark session as already processed - ec.addSession(msg.WalletID, msg.TxID) - - ctx, done := context.WithCancel(context.Background()) go func() { for { select { - case <-ctx.Done(): - return - case err := <-session.ErrChan(): + case err := <-session.ErrCh(): if err != nil { - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to sign tx", - natMsg, - ) - return + logger.Error("Key generation session error", err) } } } }() - - session.ListenToIncomingMessageAsync() - // TODO: use consul distributed lock here, only sign after all nodes has already completed listing to incoming message async - // The purpose of the sleep is to be ensuring that the node has properly set up its message listeners - // before it starts the signing process. If the signing process starts sending messages before other nodes - // have set up their listeners, those messages might be missed, potentially causing the signing process to fail. - // One solution: - // The messaging includes mechanisms for direct point-to-point communication (in point2point.go). - // The nodes could explicitly coordinate through request-response patterns before starting signing - time.Sleep(1 * time.Second) - - onSuccess := func(data []byte) { - done() - if natMsg.Reply != "" { - err = ec.pubsub.Publish(natMsg.Reply, data) - if err != nil { - logger.Error("Failed to publish reply", err) - } else { - logger.Info("Reply to the original message", "reply", natMsg.Reply) - } - } - } - go session.Sign(onSuccess) }) - ec.signingSub = sub + ec.keyGenerationSub = sub if err != nil { return err } - return nil } -func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkInternalCode string, err error, errMsg string, natMsg *nats.Msg) { - logger.Error("Signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) - signingResult := event.SigningResultEvent{ - ResultType: event.SigningResultTypeError, - NetworkInternalCode: NetworkInternalCode, - WalletID: walletID, - TxID: txID, - ErrorReason: errMsg, - } - - signingResultBytes, err := json.Marshal(signingResult) - if err != nil { - logger.Error("Failed to marshal signing result event", err) - return - } - - natMsg.Ack() - err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: txID, - }) - if err != nil { - logger.Error("Failed to publish signing result event", err) - return - } -} - -func (ec *eventConsumer) consumeResharingEvent() error { - sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { - raw := natMsg.Data - var msg types.ResharingMessage - err := json.Unmarshal(raw, &msg) - if err != nil { - logger.Error("Failed to unmarshal resharing message", err) - return - } - logger.Info("Received resharing event", "walletID", msg.WalletID, "newThreshold", msg.NewThreshold) - - err = ec.identityStore.VerifyInitiatorMessage(&msg) - if err != nil { - logger.Error("Failed to verify initiator message", err) - return - } - - walletID := msg.WalletID - newThreshold := msg.NewThreshold - - // Get new participants - readyPeerIDs := ec.node.GetReadyPeersIncludeSelf() - if len(readyPeerIDs) < newThreshold+1 { - logger.Error("Not enough peers for resharing", nil, "expected", newThreshold+1, "got", len(readyPeerIDs)) - return - } - - var oldPSession, newPSession mpc.IResharingSession - - switch msg.KeyType { - case types.KeyTypeSecp256k1: - // Create resharing oldPSession - oldPSession, err = ec.node.CreateECDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) - if err != nil { - logger.Error("Failed to create resharing session", err) - return - } - newPSession, err = ec.node.CreateECDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) - if err != nil { - logger.Error("Failed to create resharing session", err) - return - } - case types.KeyTypeEd25519: - // Create resharing oldPSession - oldPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) - if err != nil { - logger.Error("Failed to create resharing session", err) - return - } - newPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) - if err != nil { - logger.Error("Failed to create resharing session", err) - return - } - } - - oldPSession.Init() - newPSession.Init() - - oldPSessionCtx, oldPSessionDone := context.WithCancel(context.Background()) - newPSessionCtx, newPSessionDone := context.WithCancel(context.Background()) - - successEvent := &mpc.ResharingSuccessEvent{ - WalletID: walletID, - } - - var wg sync.WaitGroup - wg.Add(2) - - // For old party, we just need to wait for completion - go func() { - for { - select { - case <-oldPSessionCtx.Done(): - wg.Done() - logger.Info("oldPSession done") - return - case err := <-oldPSession.ErrChan(): - if err != nil { - logger.Error("Resharing session error", err) - } - } - } - }() - - // For new party, we need to get the public key - go func() { - for { - select { - case <-newPSessionCtx.Done(): - if msg.KeyType == types.KeyTypeSecp256k1 { - successEvent.ECDSAPubKey = newPSession.GetPubKeyResult() - } else { - successEvent.EDDSAPubKey = newPSession.GetPubKeyResult() - } - wg.Done() - logger.Info("newPSession done") - return - case err := <-newPSession.ErrChan(): - if err != nil { - logger.Error("Resharing session error", err) - } - } - } - }() - - // Start listening for messages - oldPSession.ListenToIncomingResharingMessageAsync() - newPSession.ListenToIncomingResharingMessageAsync() - time.Sleep(1 * time.Second) - - // Start resharing process - go oldPSession.Resharing(oldPSessionDone) - go newPSession.Resharing(newPSessionDone) - - // Wait for both sessions to complete - wg.Wait() - logger.Info("Closing session successfully!", - "event", successEvent) - - successEventBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal resharing success event", err) - return - } - - err = ec.resharingResultQueue.Enqueue(fmt.Sprintf(mpc.TypeResharingSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(mpc.TypeResharingSuccess, walletID), - }) - if err != nil { - logger.Error("Failed to publish resharing result event", err) - return - } - - logger.Info("[COMPLETED RESHARING] Resharing completed successfully", - "walletID", walletID) - }) - - ec.resharingSub = sub - if err != nil { - return err - } - return nil -} +// func (ec *eventConsumer) consumeTxSigningEvent() error { +// sub, err := ec.pubsub.Subscribe(MPCSignEvent, func(natMsg *nats.Msg) { +// raw := natMsg.Data +// var msg types.SignTxMessage +// err := json.Unmarshal(raw, &msg) +// if err != nil { +// logger.Error("Failed to unmarshal signing message", err) +// return +// } + +// err = ec.identityStore.VerifyInitiatorMessage(&msg) +// if err != nil { +// logger.Error("Failed to verify initiator message", err) +// return +// } + +// logger.Info( +// "Received signing event", +// "waleltID", +// msg.WalletID, +// "type", +// msg.KeyType, +// "tx", +// msg.TxID, +// "Id", +// ec.node.ID(), +// ) + +// // Check for duplicate session and track if new +// if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { +// natMsg.Term() +// return +// } + +// var session mpc.ISigningSession +// switch msg.KeyType { +// case types.KeyTypeSecp256k1: +// session, err = ec.node.CreateSigningSession( +// msg.WalletID, +// msg.TxID, +// msg.NetworkInternalCode, +// ec.mpcThreshold, +// ec.signingResultQueue, +// ) +// case types.KeyTypeEd25519: +// session, err = ec.node.CreateEDDSASigningSession( +// msg.WalletID, +// msg.TxID, +// msg.NetworkInternalCode, +// ec.mpcThreshold, +// ec.signingResultQueue, +// ) + +// } + +// if err != nil { +// ec.handleSigningSessionError( +// msg.WalletID, +// msg.TxID, +// msg.NetworkInternalCode, +// err, +// "Failed to create signing session", +// natMsg, +// ) +// return +// } + +// txBigInt := new(big.Int).SetBytes(msg.Tx) +// err = session.Init(txBigInt) +// if err != nil { +// if errors.Is(err, mpc.ErrNotEnoughParticipants) { +// logger.Info("RETRY LATER: Not enough participants to sign") +// //Return for retry later +// return +// } +// ec.handleSigningSessionError( +// msg.WalletID, +// msg.TxID, +// msg.NetworkInternalCode, +// err, +// "Failed to init signing session", +// natMsg, +// ) +// return +// } + +// // Mark session as already processed +// ec.addSession(msg.WalletID, msg.TxID) + +// ctx, done := context.WithCancel(context.Background()) +// go func() { +// for { +// select { +// case <-ctx.Done(): +// return +// case err := <-session.ErrChan(): +// if err != nil { +// ec.handleSigningSessionError( +// msg.WalletID, +// msg.TxID, +// msg.NetworkInternalCode, +// err, +// "Failed to sign tx", +// natMsg, +// ) +// return +// } +// } +// } +// }() + +// session.ListenToIncomingMessageAsync() +// // TODO: use consul distributed lock here, only sign after all nodes has already completed listing to incoming message async +// // The purpose of the sleep is to be ensuring that the node has properly set up its message listeners +// // before it starts the signing process. If the signing process starts sending messages before other nodes +// // have set up their listeners, those messages might be missed, potentially causing the signing process to fail. +// // One solution: +// // The messaging includes mechanisms for direct point-to-point communication (in point2point.go). +// // The nodes could explicitly coordinate through request-response patterns before starting signing +// time.Sleep(1 * time.Second) + +// onSuccess := func(data []byte) { +// done() +// if natMsg.Reply != "" { +// err = ec.pubsub.Publish(natMsg.Reply, data) +// if err != nil { +// logger.Error("Failed to publish reply", err) +// } else { +// logger.Info("Reply to the original message", "reply", natMsg.Reply) +// } +// } +// } +// go session.Sign(onSuccess) +// }) + +// ec.signingSub = sub +// if err != nil { +// return err +// } + +// return nil +// } + +// func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkInternalCode string, err error, errMsg string, natMsg *nats.Msg) { +// logger.Error("Signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) +// signingResult := event.SigningResultEvent{ +// ResultType: event.SigningResultTypeError, +// NetworkInternalCode: NetworkInternalCode, +// WalletID: walletID, +// TxID: txID, +// ErrorReason: errMsg, +// } + +// signingResultBytes, err := json.Marshal(signingResult) +// if err != nil { +// logger.Error("Failed to marshal signing result event", err) +// return +// } + +// natMsg.Ack() +// err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ +// IdempotententKey: txID, +// }) +// if err != nil { +// logger.Error("Failed to publish signing result event", err) +// return +// } +// } + +// func (ec *eventConsumer) consumeResharingEvent() error { +// sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { +// raw := natMsg.Data +// var msg types.ResharingMessage +// err := json.Unmarshal(raw, &msg) +// if err != nil { +// logger.Error("Failed to unmarshal resharing message", err) +// return +// } +// logger.Info("Received resharing event", "walletID", msg.WalletID, "newThreshold", msg.NewThreshold) + +// err = ec.identityStore.VerifyInitiatorMessage(&msg) +// if err != nil { +// logger.Error("Failed to verify initiator message", err) +// return +// } + +// walletID := msg.WalletID +// newThreshold := msg.NewThreshold + +// // Get new participants +// readyPeerIDs := ec.node.GetReadyPeersIncludeSelf() +// if len(readyPeerIDs) < newThreshold+1 { +// logger.Error("Not enough peers for resharing", nil, "expected", newThreshold+1, "got", len(readyPeerIDs)) +// return +// } + +// var oldPSession, newPSession mpc.IResharingSession + +// switch msg.KeyType { +// case types.KeyTypeSecp256k1: +// // Create resharing oldPSession +// oldPSession, err = ec.node.CreateECDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) +// if err != nil { +// logger.Error("Failed to create resharing session", err) +// return +// } +// newPSession, err = ec.node.CreateECDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) +// if err != nil { +// logger.Error("Failed to create resharing session", err) +// return +// } +// case types.KeyTypeEd25519: +// // Create resharing oldPSession +// oldPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) +// if err != nil { +// logger.Error("Failed to create resharing session", err) +// return +// } +// newPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) +// if err != nil { +// logger.Error("Failed to create resharing session", err) +// return +// } +// } + +// oldPSession.Init() +// newPSession.Init() + +// oldPSessionCtx, oldPSessionDone := context.WithCancel(context.Background()) +// newPSessionCtx, newPSessionDone := context.WithCancel(context.Background()) + +// successEvent := &mpc.ResharingSuccessEvent{ +// WalletID: walletID, +// } + +// var wg sync.WaitGroup +// wg.Add(2) + +// // For old party, we just need to wait for completion +// go func() { +// for { +// select { +// case <-oldPSessionCtx.Done(): +// wg.Done() +// logger.Info("oldPSession done") +// return +// case err := <-oldPSession.ErrChan(): +// if err != nil { +// logger.Error("Resharing session error", err) +// } +// } +// } +// }() + +// // For new party, we need to get the public key +// go func() { +// for { +// select { +// case <-newPSessionCtx.Done(): +// if msg.KeyType == types.KeyTypeSecp256k1 { +// successEvent.ECDSAPubKey = newPSession.GetPubKeyResult() +// } else { +// successEvent.EDDSAPubKey = newPSession.GetPubKeyResult() +// } +// wg.Done() +// logger.Info("newPSession done") +// return +// case err := <-newPSession.ErrChan(): +// if err != nil { +// logger.Error("Resharing session error", err) +// } +// } +// } +// }() + +// // Start listening for messages +// oldPSession.ListenToIncomingResharingMessageAsync() +// newPSession.ListenToIncomingResharingMessageAsync() +// time.Sleep(1 * time.Second) + +// // Start resharing process +// go oldPSession.Resharing(oldPSessionDone) +// go newPSession.Resharing(newPSessionDone) + +// // Wait for both sessions to complete +// wg.Wait() +// logger.Info("Closing session successfully!", +// "event", successEvent) + +// successEventBytes, err := json.Marshal(successEvent) +// if err != nil { +// logger.Error("Failed to marshal resharing success event", err) +// return +// } + +// err = ec.resharingResultQueue.Enqueue(fmt.Sprintf(mpc.TypeResharingSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ +// IdempotententKey: fmt.Sprintf(mpc.TypeResharingSuccess, walletID), +// }) +// if err != nil { +// logger.Error("Failed to publish resharing result event", err) +// return +// } + +// logger.Info("[COMPLETED RESHARING] Resharing completed successfully", +// "walletID", walletID) +// }) + +// ec.resharingSub = sub +// if err != nil { +// return err +// } +// return nil +// } // Add a cleanup routine that runs periodically func (ec *eventConsumer) sessionCleanupRoutine() { diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go new file mode 100644 index 0000000..c5467f5 --- /dev/null +++ b/pkg/mpc/node/node.go @@ -0,0 +1,130 @@ +package node + +import ( + "encoding/json" + "fmt" + "math/big" + "time" + + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/session" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" +) + +type Node struct { + nodeID string + peerIDs []string + + pubSub messaging.PubSub + direct messaging.DirectMessaging + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + identityStore identity.Store + + peerRegistry *registry +} + +func NewNode(nodeID string, peerIDs []string, pubSub messaging.PubSub, direct messaging.DirectMessaging, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, identityStore identity.Store, peerRegistry *registry) *Node { + go peerRegistry.WatchPeersReady() + + return &Node{ + nodeID: nodeID, + peerIDs: peerIDs, + pubSub: pubSub, + direct: direct, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + identityStore: identityStore, + peerRegistry: peerRegistry, + } +} + +func (n *Node) CreateKeygenSession(_ types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (*session.ECDSASession, error) { + if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { + return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) + } + + readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() + selfPartyID, allPartyIDs := n.generatePartyIDs("keygen", readyPeerIDs) + + preparams, err := n.getECDSAPreParams(false) + if err != nil { + return nil, fmt.Errorf("failed to get preparams: %w", err) + } + logger.Info("Preparams loaded") + + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + preparams, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + ) + + return ecdsaSession, nil +} + +func (n *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { + var selfPartyID *tss.PartyID + partyIDs := make([]*tss.PartyID, len(readyPeerIDs)) + for i, peerID := range readyPeerIDs { + if peerID == n.nodeID { + selfPartyID = createPartyID(peerID, purpose) + partyIDs[i] = selfPartyID + } else { + partyIDs[i] = createPartyID(peerID, purpose) + } + } + allPartyIDs := tss.SortPartyIDs(partyIDs, 0) + return selfPartyID, allPartyIDs +} + +func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error) { + var path string + if isOldParty { + path = fmt.Sprintf("preparams.old.%s", n.nodeID) + } else { + path = fmt.Sprintf("preparams.%s", n.nodeID) + } + + preparamsBytes, _ := n.kvstore.Get(path) + // if err != nil { + // return nil, err + // } + + if preparamsBytes == nil { + preparams, err := keygen.GeneratePreParams(5 * time.Minute) + if err != nil { + return nil, err + } + preparamsBytes, err = json.Marshal(preparams) + if err != nil { + return nil, err + } + n.kvstore.Put(path, preparamsBytes) + return preparams, nil + } + + var preparams keygen.LocalPreParams + if err := json.Unmarshal(preparamsBytes, &preparams); err != nil { + return nil, err + } + return &preparams, nil +} + +func createPartyID(nodeID string, label string) *tss.PartyID { + partyID := uuid.NewString() + key := big.NewInt(0).SetBytes([]byte(nodeID + ":" + label)) + return tss.NewPartyID(partyID, label, key) +} diff --git a/pkg/mpc/node/registry.go b/pkg/mpc/node/registry.go new file mode 100644 index 0000000..98c65fa --- /dev/null +++ b/pkg/mpc/node/registry.go @@ -0,0 +1,207 @@ +package node + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/fystack/mpcium/pkg/infra" + "github.com/fystack/mpcium/pkg/logger" + "github.com/hashicorp/consul/api" + "github.com/samber/lo" +) + +const ( + ReadinessCheckPeriod = 1 * time.Second +) + +type PeerRegistry interface { + Ready() error + ArePeersReady() bool + WatchPeersReady() + // Resign is called by the node when it is going to shutdown + Resign() error + GetReadyPeersCount() int64 + GetReadyPeersIncludeSelf() []string // get ready peers include self +} + +type registry struct { + nodeID string + peerNodeIDs []string + readyMap map[string]bool + readyCount int64 + mu sync.RWMutex + ready bool // ready is true when all peers are ready + + consulKV infra.ConsulKV +} + +func NewRegistry( + nodeID string, + peerNodeIDs []string, + consulKV infra.ConsulKV, +) *registry { + return ®istry{ + consulKV: consulKV, + nodeID: nodeID, + peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), + readyMap: make(map[string]bool), + readyCount: 1, // self + } +} + +func getPeerIDsExceptSelf(nodeID string, peerNodeIDs []string) []string { + peerIDs := make([]string, 0, len(peerNodeIDs)) + for _, peerID := range peerNodeIDs { + if peerID != nodeID { + peerIDs = append(peerIDs, peerID) + } + } + return peerIDs +} + +func (r *registry) readyKey(nodeID string) string { + return fmt.Sprintf("ready/%s", nodeID) +} + +func (r *registry) registerReadyPairs(peerIDs []string) { + for _, peerID := range peerIDs { + ready, exist := r.readyMap[peerID] + if !exist { + atomic.AddInt64(&r.readyCount, 1) + logger.Info("Register", "peerID", peerID) + } else if !ready { + atomic.AddInt64(&r.readyCount, 1) + logger.Info("Reconnecting...", "peerID", peerID) + } + + r.readyMap[peerID] = true + } + + if len(peerIDs) == len(r.peerNodeIDs) && !r.ready { + r.mu.Lock() + r.ready = true + r.mu.Unlock() + logger.Info("ALL PEERS ARE READY! Starting to accept MPC requests") + } + +} + +// Ready is called by the node when it complete generate preparams and starting to accept +// incoming requests +func (r *registry) Ready() error { + k := r.readyKey(r.nodeID) + + kv := &api.KVPair{ + Key: k, + Value: []byte("true"), + } + + _, err := r.consulKV.Put(kv, nil) + if err != nil { + return fmt.Errorf("Put ready key failed: %w", err) + } + + return nil +} + +func (r *registry) WatchPeersReady() { + ticker := time.NewTicker(ReadinessCheckPeriod) + go r.logReadyStatus() + // first tick is executed immediately + for ; true; <-ticker.C { + pairs, _, err := r.consulKV.List("ready/", nil) + if err != nil { + logger.Error("List ready keys failed", err) + } + + newReadyPeerIDs := r.getReadyPeersFromKVStore(pairs) + if len(newReadyPeerIDs) != len(r.peerNodeIDs) { + r.mu.Lock() + r.ready = false + r.mu.Unlock() + + var readyPeerIDs []string + for peerID, isReady := range r.readyMap { + if isReady { + readyPeerIDs = append(readyPeerIDs, peerID) + } + } + + disconnecteds, _ := lo.Difference(readyPeerIDs, newReadyPeerIDs) + if len(disconnecteds) > 0 { + for _, peerID := range disconnecteds { + logger.Warn("Peer disconnected!", "peerID", peerID) + r.readyMap[peerID] = false + atomic.AddInt64(&r.readyCount, -1) + } + + } + + } + r.registerReadyPairs(newReadyPeerIDs) + } + +} + +func (r *registry) logReadyStatus() { + for { + time.Sleep(5 * time.Second) + if !r.ArePeersReady() { + logger.Info("Peers are not ready yet", "ready", r.GetReadyPeersCount(), "expected", len(r.peerNodeIDs)+1) + } + } +} + +func (r *registry) GetReadyPeersCount() int64 { + return atomic.LoadInt64(&r.readyCount) +} + +func (r *registry) GetReadyPeersIncludeSelf() []string { + var peerIDs []string + for peerID, isReady := range r.readyMap { + if isReady { + peerIDs = append(peerIDs, peerID) + } + } + + peerIDs = append(peerIDs, r.nodeID) // append self + return peerIDs +} + +func (r *registry) getReadyPeersFromKVStore(kvPairs api.KVPairs) []string { + var peers []string + for _, k := range kvPairs { + var peerNodeID string + _, err := fmt.Sscanf(k.Key, "ready/%s", &peerNodeID) + if err != nil { + logger.Error("Parse ready key failed", err) + } + if peerNodeID == r.nodeID { + continue + } + + peers = append(peers, peerNodeID) + } + + return peers +} + +func (r *registry) ArePeersReady() bool { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.ready +} + +func (r *registry) Resign() error { + k := r.readyKey(r.nodeID) + + _, err := r.consulKV.Delete(k, nil) + if err != nil { + return fmt.Errorf("Delete ready key failed: %w", err) + } + + return nil +} diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 00b305b..bd1b4b8 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -3,31 +3,58 @@ package party import ( "context" "encoding/json" + "math/big" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/types" ) type party struct { walletID string + threshold int localParty tss.Party partyID *tss.PartyID partyIDs []*tss.PartyID - threshold int + inCh chan types.TssMessage + outCh chan tss.Message errCh chan error } type PartyInterface interface { + StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) + StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) + StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) + PartyID() *tss.PartyID - GetOutCh() chan tss.Message - UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) - GetErrCh() chan error + Party() tss.Party + InCh() chan types.TssMessage + OutCh() chan tss.Message + ErrCh() chan error } func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, errCh chan error) *party { - return &party{walletID, nil, partyID, partyIDs, threshold, errCh} + inCh := make(chan types.TssMessage, 1000) + outCh := make(chan tss.Message, 1000) + return &party{walletID, threshold, nil, partyID, partyIDs, inCh, outCh, errCh} +} + +func (p *party) PartyID() *tss.PartyID { + return p.partyID +} + +func (p *party) Party() tss.Party { + return p.localParty } -func (p *party) GetErrCh() chan error { +func (p *party) InCh() chan types.TssMessage { + return p.inCh +} + +func (p *party) OutCh() chan tss.Message { + return p.outCh +} + +func (p *party) ErrCh() chan error { return p.errCh } @@ -35,7 +62,7 @@ func (p *party) GetErrCh() chan error { func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh <-chan T, finish func([]byte)) { go func() { if err := party.Start(); err != nil { - s.GetErrCh() <- err + s.ErrCh() <- err } }() @@ -43,12 +70,18 @@ func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, sen select { case <-ctx.Done(): return - case msg := <-s.GetOutCh(): + case in := <-s.InCh(): + ok, err := party.UpdateFromBytes(in.MsgBytes, in.From, in.IsBroadcast) + if !ok || err != nil { + s.ErrCh() <- err + return + } + case msg := <-s.OutCh(): send(msg) case result := <-endCh: bz, err := json.Marshal(result) if err != nil { - s.GetErrCh() <- err + s.ErrCh() <- err return } finish(bz) diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index e45b0b7..acce81e 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -21,33 +21,15 @@ type ECDSAParty struct { } func NewECDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, - prepareParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData, errCh chan error) *ECDSAParty { + prepareParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, errCh chan error) *ECDSAParty { return &ECDSAParty{ party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), prepareParams: prepareParams, reshareParams: reshareParams, - saveData: saveData, outCh: make(chan tss.Message, 1000), } } -func (s *ECDSAParty) PartyID() *tss.PartyID { - return s.partyID -} - -func (s *ECDSAParty) GetOutCh() chan tss.Message { - return s.outCh -} - -func (s *ECDSAParty) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { - ok, err := s.localParty.UpdateFromBytes(msgBytes, from, isBroadcast) - if err != nil { - s.GetErrCh() <- err - return false, err - } - return ok, nil -} - func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData) params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) @@ -57,7 +39,7 @@ func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), fi func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - s.GetErrCh() <- errors.New("save data is nil") + s.ErrCh() <- errors.New("save data is nil") return } end := make(chan *common.SignatureData) @@ -69,7 +51,7 @@ func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(t func (s *ECDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - s.GetErrCh() <- errors.New("save data is nil") + s.ErrCh() <- errors.New("save data is nil") return } end := make(chan *keygen.LocalPartySaveData) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index 743071d..07e71fc 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -33,10 +33,6 @@ func (s *EDDSAParty) PartyID() *tss.PartyID { return s.partyID } -func (s *EDDSAParty) GetOutCh() chan tss.Message { - return s.outCh -} - func (s *EDDSAParty) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { ok, err := s.localParty.UpdateFromBytes(msgBytes, from, isBroadcast) if err != nil { @@ -54,7 +50,7 @@ func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), fi func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - s.GetErrCh() <- errors.New("save data is nil") + s.ErrCh() <- errors.New("save data is nil") return } end := make(chan *common.SignatureData) @@ -66,7 +62,7 @@ func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(t func (s *EDDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { - s.GetErrCh() <- errors.New("save data is nil") + s.ErrCh() <- errors.New("save data is nil") return } end := make(chan *keygen.LocalPartySaveData) diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index f218acd..50e0168 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -9,10 +9,10 @@ import ( "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/party" "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" ) type Curve string @@ -50,7 +50,9 @@ type session struct { topicComposer *TopicComposer composeKey KeyComposerFn - mu sync.Mutex + + mu sync.Mutex + errCh chan error } func NewSession( @@ -62,54 +64,45 @@ func NewSession( identityStore identity.Store, kvstore kvstore.KVStore, ) *session { + errCh := make(chan error, 1000) return &session{ walletID: walletID, pubSub: pubSub, direct: direct, identityStore: identityStore, kvstore: kvstore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf(KeygenBroadcastTopic, walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf(KeygenDirectTopic, nodeID, walletID) - }, - }, - composeKey: func(id string) string { - return fmt.Sprintf("%s/%s", purpose, id) - }, + errCh: errCh, } } -func (s *session) SetParty(party party.PartyInterface) { - s.party = party +func (s *session) ErrCh() chan error { + return s.errCh } func (s *session) Send(msg tss.Message) { data, routing, err := msg.WireBytes() if err != nil { - logger.Error("Failed to wire bytes", err) + s.errCh <- fmt.Errorf("Failed to wire bytes: %w", err) return } tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) signature, err := s.identityStore.SignMessage(&tssMsg) if err != nil { - logger.Error("Failed to sign message", err) + s.errCh <- fmt.Errorf("Failed to sign message: %w", err) return } tssMsg.Signature = signature msgBytes, err := types.MarshalTssMessage(&tssMsg) if err != nil { - logger.Error("Failed to marshal message", err) + s.errCh <- fmt.Errorf("Failed to marshal message: %w", err) return } if routing.IsBroadcast && len(routing.To) == 0 { err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msgBytes) if err != nil { - logger.Error("Failed to publish message", err) + s.errCh <- fmt.Errorf("Failed to publish message: %w", err) return } } else { @@ -118,23 +111,23 @@ func (s *session) Send(msg tss.Message) { topic := s.topicComposer.ComposeDirectTopic(nodeID) err := s.direct.Send(topic, msgBytes) if err != nil { - logger.Error("Failed to send message", err) + s.errCh <- fmt.Errorf("Failed to send message: %w", err) return } } } } -func (s *session) Receive(rawMsg []byte) { +func (s *session) receive(rawMsg []byte) { msg, err := types.UnmarshalTssMessage(rawMsg) if err != nil { - logger.Error("Failed to unmarshal message", err) + s.errCh <- fmt.Errorf("Failed to unmarshal message: %w", err) return } err = s.identityStore.VerifyMessage(msg) if err != nil { - logger.Error("Failed to verify message", err) + s.errCh <- fmt.Errorf("Failed to verify message: %w", err) return } @@ -149,45 +142,41 @@ func (s *session) Receive(rawMsg []byte) { if isBroadcast || isToSelf { s.mu.Lock() defer s.mu.Unlock() - ok, err := s.party.UpdateFromBytes(msg.MsgBytes, msg.From, msg.IsBroadcast) - if !ok || err != nil { - logger.Error("Failed to update party", err) - return - } + s.party.InCh() <- *msg } } -// func (s *session) Listen() { -// broadcast := func() { -// sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { -// msg := natMsg.Data -// s.receiveTssMessage(msg) -// }) +func (s *session) Listen() { + broadcast := func() { + sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { + msg := natMsg.Data + s.receive(msg) + }) -// if err != nil { -// s.ErrCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) -// return -// } + if err != nil { + s.errCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) + return + } -// s.broadcastSub = sub -// } + s.broadcastSub = sub + } -// direct := func() { -// sub, err := s.direct.Listen(s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), func(msg []byte) { -// s.receiveTssMessage(msg) -// }) + direct := func() { + sub, err := s.direct.Listen(s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), func(msg []byte) { + s.receive(msg) + }) -// if err != nil { -// s.ErrCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), err) -// return -// } + if err != nil { + s.errCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), err) + return + } -// s.directSub = sub -// } + s.directSub = sub + } -// go broadcast() -// go direct() -// } + go broadcast() + go direct() +} func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) { @@ -199,13 +188,13 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) if err != nil { - logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) + s.errCh <- fmt.Errorf("Failed to save keyinfo: %w", err) return } err = s.kvstore.Put(s.composeKey(s.walletID), data) if err != nil { - logger.Error("Failed to save key", err, "walletID", s.walletID) + s.errCh <- fmt.Errorf("Failed to save key: %w", err) return } diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index 88117ad..963fa56 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -1,22 +1,28 @@ package session import ( + "context" + + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/party" ) -type EcdsaSession struct { +type ECDSASession struct { *session } -func NewECDSASession(walletID string, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore) *ECDSASession { +func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, prepareParams *keygen.LocalPreParams, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore) *ECDSASession { s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore) - party := party.NewECDSAParty(walletID, s.PartyID(), s.PartyIDs(), s.threshold, s.prepareParams, s.reshareParams, s.saveData) - s.SetParty(party) + s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, *prepareParams, nil, s.errCh) return &ECDSASession{ session: s, - party: party, } } + +func (s *ECDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + s.party.StartKeygen(ctx, send, finish) +} diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go index e69de29..ab87616 100644 --- a/pkg/mpc/session/eddsa.go +++ b/pkg/mpc/session/eddsa.go @@ -0,0 +1 @@ +package session diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index b49b768..b014b1a 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -6,7 +6,7 @@ type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" - KeyTypeEd25519 = "ed25519" + KeyTypeEd25519 KeyType = "ed25519" ) // InitiatorMessage is anything that carries a payload to verify and its signature. From 78fb104a792005986921a10826222acc30286efc Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 11 Jun 2025 16:23:02 +0700 Subject: [PATCH 05/34] Refactor signing and event consumer logic for improved clarity This commit refines the signing process and event consumer logic within the MPC package. Key changes include: - Updated the main signing example to use a local NATS URL and modified the wallet ID and network code for consistency. - Enhanced the event consumer to properly handle signing events, including improved error handling and logging. - Introduced a session interface for better management of signing and key generation sessions, allowing for more modular and maintainable code. - Refactored session methods to streamline the process of saving keys and handling errors, improving overall clarity and functionality. These changes aim to enhance the robustness and maintainability of the MPC signing process. --- examples/sign/main.go | 15 +- pkg/eventconsumer/event_consumer.go | 306 +++++++++++++--------------- pkg/mpc/node/node.go | 43 +++- pkg/mpc/party/base.go | 34 ++-- pkg/mpc/party/ecdsa.go | 32 ++- pkg/mpc/party/eddsa.go | 14 +- pkg/mpc/session/base.go | 68 +++++-- pkg/mpc/session/ecdsa.go | 51 ++++- 8 files changed, 324 insertions(+), 239 deletions(-) diff --git a/examples/sign/main.go b/examples/sign/main.go index d5e2410..a9ef4f9 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -7,21 +7,20 @@ import ( "syscall" "github.com/fystack/mpcium/pkg/client" - "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" "github.com/google/uuid" "github.com/nats-io/nats.go" - "github.com/spf13/viper" ) func main() { const environment = "dev" - config.InitViperConfig() + // config.InitViperConfig() logger.Init(environment, true) - natsURL := viper.GetString("nats.url") + // natsURL := viper.GetString("nats.url") + natsURL := "nats://localhost:4222" natsConn, err := nats.Connect(natsURL) if err != nil { logger.Fatal("Failed to connect to NATS", err) @@ -31,7 +30,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./event_initiator.key", + KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", }) // 2) Once wallet exists, immediately fire a SignTransaction @@ -39,9 +38,9 @@ func main() { dummyTx := []byte("deadbeef") // replace with real transaction bytes txMsg := &types.SignTxMessage{ - KeyType: types.KeyTypeEd25519, - WalletID: "77dd7e23-9d5c-4ff1-8759-f119d1b19b36", - NetworkInternalCode: "solana-devnet", + KeyType: types.KeyTypeSecp256k1, + WalletID: "0bf609ad-63ed-4713-a673-e09d43f316d3", + NetworkInternalCode: "sepolia-devnet", TxID: txID, Tx: dummyTx, } diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index f3e3352..e89e3b2 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -5,12 +5,15 @@ import ( "encoding/json" "fmt" "log" + "math/big" "sync" "time" + "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc" "github.com/fystack/mpcium/pkg/mpc/node" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" @@ -84,10 +87,10 @@ func (ec *eventConsumer) Run() { log.Fatal("Failed to consume key reconstruction event", err) } - // err = ec.consumeTxSigningEvent() - // if err != nil { - // log.Fatal("Failed to consume tx signing event", err) - // } + err = ec.consumeTxSigningEvent() + if err != nil { + log.Fatal("Failed to consume tx signing event", err) + } // err = ec.consumeResharingEvent() // if err != nil { @@ -121,19 +124,43 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { return } - go session.StartKeygen(context.Background(), session.Send, func(data []byte) { - logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) - }) - go session.Listen() - logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) + // Start listening for messages first + go session.Listen(ec.node.ID()) + + // Start the key generation process go func() { - for { - select { - case err := <-session.ErrCh(): - if err != nil { - logger.Error("Key generation session error", err) - } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + session.StartKeygen(ctx, session.Send, func(data []byte) { + cancel() + session.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) + + successEvent := &mpc.KeygenSuccessEvent{ + WalletID: walletID, + ECDSAPubKey: session.GetPublicKey(data), + } + + successEventBytes, err := json.Marshal(successEvent) + if err != nil { + logger.Error("Failed to marshal keygen success event", err) + return } + + err = ec.genKeySucecssQueue.Enqueue(fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), + }) + if err != nil { + logger.Error("Failed to publish key generation success message", err) + return + } + logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID, "data", len(data)) + }) + }() + + // Handle errors from the session + go func() { + for err := range session.ErrCh() { + logger.Error("Error from session", err) + return } }() }) @@ -145,174 +172,115 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { return nil } -// func (ec *eventConsumer) consumeTxSigningEvent() error { -// sub, err := ec.pubsub.Subscribe(MPCSignEvent, func(natMsg *nats.Msg) { -// raw := natMsg.Data -// var msg types.SignTxMessage -// err := json.Unmarshal(raw, &msg) -// if err != nil { -// logger.Error("Failed to unmarshal signing message", err) -// return -// } - -// err = ec.identityStore.VerifyInitiatorMessage(&msg) -// if err != nil { -// logger.Error("Failed to verify initiator message", err) -// return -// } +func (ec *eventConsumer) consumeTxSigningEvent() error { + sub, err := ec.pubsub.Subscribe(MPCSignEvent, func(natMsg *nats.Msg) { + raw := natMsg.Data + var msg types.SignTxMessage + err := json.Unmarshal(raw, &msg) + if err != nil { + logger.Error("Failed to unmarshal signing message", err) + return + } -// logger.Info( -// "Received signing event", -// "waleltID", -// msg.WalletID, -// "type", -// msg.KeyType, -// "tx", -// msg.TxID, -// "Id", -// ec.node.ID(), -// ) - -// // Check for duplicate session and track if new -// if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { -// natMsg.Term() -// return -// } + err = ec.identityStore.VerifyInitiatorMessage(&msg) + if err != nil { + logger.Error("Failed to verify initiator message", err) + return + } -// var session mpc.ISigningSession -// switch msg.KeyType { -// case types.KeyTypeSecp256k1: -// session, err = ec.node.CreateSigningSession( -// msg.WalletID, -// msg.TxID, -// msg.NetworkInternalCode, -// ec.mpcThreshold, -// ec.signingResultQueue, -// ) -// case types.KeyTypeEd25519: -// session, err = ec.node.CreateEDDSASigningSession( -// msg.WalletID, -// msg.TxID, -// msg.NetworkInternalCode, -// ec.mpcThreshold, -// ec.signingResultQueue, -// ) + logger.Info("Received signing event", "msg", msg) -// } + // Check for duplicate session and track if new + if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { + natMsg.Term() + return + } -// if err != nil { -// ec.handleSigningSessionError( -// msg.WalletID, -// msg.TxID, -// msg.NetworkInternalCode, -// err, -// "Failed to create signing session", -// natMsg, -// ) -// return -// } + signingSession, err := ec.node.CreateSigningSession( + msg.KeyType, + msg.WalletID, + msg.TxID, + ec.mpcThreshold, + ec.signingResultQueue, + ) -// txBigInt := new(big.Int).SetBytes(msg.Tx) -// err = session.Init(txBigInt) -// if err != nil { -// if errors.Is(err, mpc.ErrNotEnoughParticipants) { -// logger.Info("RETRY LATER: Not enough participants to sign") -// //Return for retry later -// return -// } -// ec.handleSigningSessionError( -// msg.WalletID, -// msg.TxID, -// msg.NetworkInternalCode, -// err, -// "Failed to init signing session", -// natMsg, -// ) -// return -// } + if err != nil { + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to create signing session", + natMsg, + ) + return + } -// // Mark session as already processed -// ec.addSession(msg.WalletID, msg.TxID) + go signingSession.Listen(ec.node.ID()) -// ctx, done := context.WithCancel(context.Background()) -// go func() { -// for { -// select { -// case <-ctx.Done(): -// return -// case err := <-session.ErrChan(): -// if err != nil { -// ec.handleSigningSessionError( -// msg.WalletID, -// msg.TxID, -// msg.NetworkInternalCode, -// err, -// "Failed to sign tx", -// natMsg, -// ) -// return -// } -// } -// } -// }() + txBigInt := new(big.Int).SetBytes(msg.Tx) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + signingSession.StartSigning(ctx, txBigInt, signingSession.Send, func(data []byte) { + cancel() + logger.Info("Signing completed", "walletID", msg.WalletID, "txID", msg.TxID, "data", len(data)) + }) + }() -// session.ListenToIncomingMessageAsync() -// // TODO: use consul distributed lock here, only sign after all nodes has already completed listing to incoming message async -// // The purpose of the sleep is to be ensuring that the node has properly set up its message listeners -// // before it starts the signing process. If the signing process starts sending messages before other nodes -// // have set up their listeners, those messages might be missed, potentially causing the signing process to fail. -// // One solution: -// // The messaging includes mechanisms for direct point-to-point communication (in point2point.go). -// // The nodes could explicitly coordinate through request-response patterns before starting signing -// time.Sleep(1 * time.Second) + // Mark session as already processed + ec.addSession(msg.WalletID, msg.TxID) -// onSuccess := func(data []byte) { -// done() -// if natMsg.Reply != "" { -// err = ec.pubsub.Publish(natMsg.Reply, data) -// if err != nil { -// logger.Error("Failed to publish reply", err) -// } else { -// logger.Info("Reply to the original message", "reply", natMsg.Reply) -// } -// } -// } -// go session.Sign(onSuccess) -// }) + go func() { + for err := range signingSession.ErrCh() { + logger.Error("Error from session", err) + if err != nil { + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to sign tx", + natMsg, + ) + return + } + } + }() + }) -// ec.signingSub = sub -// if err != nil { -// return err -// } + ec.signingSub = sub + if err != nil { + return err + } -// return nil -// } + return nil +} -// func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkInternalCode string, err error, errMsg string, natMsg *nats.Msg) { -// logger.Error("Signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) -// signingResult := event.SigningResultEvent{ -// ResultType: event.SigningResultTypeError, -// NetworkInternalCode: NetworkInternalCode, -// WalletID: walletID, -// TxID: txID, -// ErrorReason: errMsg, -// } +func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkInternalCode string, err error, errMsg string, natMsg *nats.Msg) { + logger.Error("signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) + signingResult := event.SigningResultEvent{ + ResultType: event.SigningResultTypeError, + NetworkInternalCode: NetworkInternalCode, + WalletID: walletID, + TxID: txID, + ErrorReason: errMsg, + } -// signingResultBytes, err := json.Marshal(signingResult) -// if err != nil { -// logger.Error("Failed to marshal signing result event", err) -// return -// } + signingResultBytes, err := json.Marshal(signingResult) + if err != nil { + logger.Error("failed to marshal signing result event", err) + return + } -// natMsg.Ack() -// err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ -// IdempotententKey: txID, -// }) -// if err != nil { -// logger.Error("Failed to publish signing result event", err) -// return -// } -// } + natMsg.Ack() + err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: txID, + }) + if err != nil { + logger.Error("Failed to publish signing result event", err) + return + } +} // func (ec *eventConsumer) consumeResharingEvent() error { // sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index c5467f5..f468276 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -46,14 +46,17 @@ func NewNode(nodeID string, peerIDs []string, pubSub messaging.PubSub, direct me } } -func (n *Node) CreateKeygenSession(_ types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (*session.ECDSASession, error) { +func (n *Node) ID() string { + return n.nodeID +} + +func (n *Node) CreateKeygenSession(_ types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) } readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() selfPartyID, allPartyIDs := n.generatePartyIDs("keygen", readyPeerIDs) - preparams, err := n.getECDSAPreParams(false) if err != nil { return nil, fmt.Errorf("failed to get preparams: %w", err) @@ -65,16 +68,50 @@ func (n *Node) CreateKeygenSession(_ types.KeyType, walletID string, threshold i selfPartyID, allPartyIDs, threshold, - preparams, + *preparams, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + ) + + return ecdsaSession, nil +} + +func (n *Node) CreateSigningSession(_ types.KeyType, walletID string, txID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { + if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { + return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) + } + + readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() + selfPartyID, allPartyIDs := n.generatePartyIDs("keygen", readyPeerIDs) + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + keygen.LocalPreParams{}, n.pubSub, n.direct, n.identityStore, n.kvstore, + n.keyinfoStore, ) + saveData, err := ecdsaSession.GetSaveData() + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + + ecdsaSession.SetSaveData(saveData) return ecdsaSession, nil } +func (n *Node) GetReadyPeersIncludeSelf() []string { + return n.peerRegistry.GetReadyPeersIncludeSelf() +} + func (n *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { var selfPartyID *tss.PartyID partyIDs := make([]*tss.PartyID, len(readyPeerIDs)) diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index bd1b4b8..9da9d3f 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -6,18 +6,18 @@ import ( "math/big" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" ) type party struct { - walletID string - threshold int - localParty tss.Party - partyID *tss.PartyID - partyIDs []*tss.PartyID - inCh chan types.TssMessage - outCh chan tss.Message - errCh chan error + walletID string + threshold int + partyID *tss.PartyID + partyIDs []*tss.PartyID + inCh chan types.TssMessage + outCh chan tss.Message + errCh chan error } type PartyInterface interface { @@ -26,7 +26,7 @@ type PartyInterface interface { StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) PartyID() *tss.PartyID - Party() tss.Party + SetSaveData(saveData []byte) InCh() chan types.TssMessage OutCh() chan tss.Message ErrCh() chan error @@ -35,17 +35,13 @@ type PartyInterface interface { func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, errCh chan error) *party { inCh := make(chan types.TssMessage, 1000) outCh := make(chan tss.Message, 1000) - return &party{walletID, threshold, nil, partyID, partyIDs, inCh, outCh, errCh} + return &party{walletID, threshold, partyID, partyIDs, inCh, outCh, errCh} } func (p *party) PartyID() *tss.PartyID { return p.partyID } -func (p *party) Party() tss.Party { - return p.localParty -} - func (p *party) InCh() chan types.TssMessage { return p.inCh } @@ -59,13 +55,17 @@ func (p *party) ErrCh() chan error { } // runParty handles the common party execution loop -func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh <-chan T, finish func([]byte)) { +func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh chan T, finish func([]byte)) { + // Start the party in a goroutine go func() { + logger.Info("Starting party", "partyID", s.PartyID().String()) if err := party.Start(); err != nil { s.ErrCh() <- err + return } }() + // Main message handling loop for { select { case <-ctx.Done(): @@ -76,8 +76,8 @@ func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, sen s.ErrCh() <- err return } - case msg := <-s.OutCh(): - send(msg) + case out := <-s.OutCh(): + send(out) case result := <-endCh: bz, err := json.Marshal(result) if err != nil { diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index acce81e..4729e05 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -2,7 +2,9 @@ package party import ( "context" + "encoding/json" "errors" + "fmt" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -14,26 +16,38 @@ import ( type ECDSAParty struct { party - prepareParams keygen.LocalPreParams + preParams keygen.LocalPreParams reshareParams *tss.ReSharingParameters saveData *keygen.LocalPartySaveData - outCh chan tss.Message } func NewECDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, - prepareParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, errCh chan error) *ECDSAParty { + preParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, errCh chan error) *ECDSAParty { return &ECDSAParty{ party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), - prepareParams: prepareParams, + preParams: preParams, reshareParams: reshareParams, - outCh: make(chan tss.Message, 1000), } } +func (s *ECDSAParty) SetSaveData(saveData []byte) { + localSaveData := &keygen.LocalPartySaveData{} + err := json.Unmarshal(saveData, localSaveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed deserializing shares: %w", err) + return + } + localSaveData.ECDSAPub.SetCurve(tss.S256()) + for _, xj := range localSaveData.BigXj { + xj.SetCurve(tss.S256()) + } + s.saveData = localSaveData +} + func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { - end := make(chan *keygen.LocalPartySaveData) + end := make(chan *keygen.LocalPartySaveData, 1) params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) - party := keygen.NewLocalParty(params, s.outCh, end, s.prepareParams) + party := keygen.NewLocalParty(params, s.outCh, end, s.preParams) runParty(s, ctx, party, send, end, finish) } @@ -42,7 +56,7 @@ func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(t s.ErrCh() <- errors.New("save data is nil") return } - end := make(chan *common.SignatureData) + end := make(chan *common.SignatureData, 1) params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := signing.NewLocalParty(msg, params, *s.saveData, s.outCh, end) runParty(s, ctx, party, send, end, finish) @@ -54,7 +68,7 @@ func (s *ECDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs s.ErrCh() <- errors.New("save data is nil") return } - end := make(chan *keygen.LocalPartySaveData) + end := make(chan *keygen.LocalPartySaveData, 1) params := tss.NewReSharingParameters( tss.S256(), tss.NewPeerContext(oldPartyIDs), diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index 07e71fc..8d06f51 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -16,7 +16,6 @@ type EDDSAParty struct { party reshareParams *tss.ReSharingParameters saveData *keygen.LocalPartySaveData - outCh chan tss.Message } func NewEDDASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, @@ -25,20 +24,11 @@ func NewEDDASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.Party party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), reshareParams: reshareParams, saveData: saveData, - outCh: make(chan tss.Message, 1000), } } -func (s *EDDSAParty) PartyID() *tss.PartyID { - return s.partyID -} - -func (s *EDDSAParty) UpdateFromBytes(msgBytes []byte, from *tss.PartyID, isBroadcast bool) (bool, error) { - ok, err := s.localParty.UpdateFromBytes(msgBytes, from, isBroadcast) - if err != nil { - return false, err - } - return ok, nil +func (s *EDDSAParty) SetSaveData(saveData []byte) { + // s.saveData = saveData.(*keygen.LocalPartySaveData) } func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 50e0168..b88e9ea 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -1,7 +1,9 @@ package session import ( + "context" "fmt" + "math/big" "slices" "sync" @@ -9,6 +11,7 @@ import ( "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/party" "github.com/fystack/mpcium/pkg/types" @@ -35,6 +38,17 @@ type TopicComposer struct { type KeyComposerFn func(id string) string +type Session interface { + StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) + StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) + GetSaveData() ([]byte, error) + GetPublicKey(data []byte) []byte + Send(msg tss.Message) + Listen(nodeID string) + SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) + ErrCh() chan error +} + type session struct { walletID string party party.PartyInterface @@ -63,6 +77,7 @@ func NewSession( direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, ) *session { errCh := make(chan error, 1000) return &session{ @@ -71,6 +86,7 @@ func NewSession( direct: direct, identityStore: identityStore, kvstore: kvstore, + keyinfoStore: keyinfoStore, errCh: errCh, } } @@ -82,27 +98,29 @@ func (s *session) ErrCh() chan error { func (s *session) Send(msg tss.Message) { data, routing, err := msg.WireBytes() if err != nil { - s.errCh <- fmt.Errorf("Failed to wire bytes: %w", err) + s.errCh <- fmt.Errorf("failed to wire bytes: %w", err) return } tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) signature, err := s.identityStore.SignMessage(&tssMsg) if err != nil { - s.errCh <- fmt.Errorf("Failed to sign message: %w", err) + s.errCh <- fmt.Errorf("failed to sign message: %w", err) return } tssMsg.Signature = signature msgBytes, err := types.MarshalTssMessage(&tssMsg) if err != nil { - s.errCh <- fmt.Errorf("Failed to marshal message: %w", err) + s.errCh <- fmt.Errorf("failed to marshal message: %w", err) return } + logger.Debug("Sending message", "from", routing.From, "to", routing.To, "isBroadcast", routing.IsBroadcast) + if routing.IsBroadcast && len(routing.To) == 0 { err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msgBytes) if err != nil { - s.errCh <- fmt.Errorf("Failed to publish message: %w", err) + s.errCh <- fmt.Errorf("failed to publish message: %w", err) return } } else { @@ -111,7 +129,7 @@ func (s *session) Send(msg tss.Message) { topic := s.topicComposer.ComposeDirectTopic(nodeID) err := s.direct.Send(topic, msgBytes) if err != nil { - s.errCh <- fmt.Errorf("Failed to send message: %w", err) + s.errCh <- fmt.Errorf("failed to send message: %w", err) return } } @@ -121,13 +139,18 @@ func (s *session) Send(msg tss.Message) { func (s *session) receive(rawMsg []byte) { msg, err := types.UnmarshalTssMessage(rawMsg) if err != nil { - s.errCh <- fmt.Errorf("Failed to unmarshal message: %w", err) + s.errCh <- fmt.Errorf("failed to unmarshal message: %w", err) return } err = s.identityStore.VerifyMessage(msg) if err != nil { - s.errCh <- fmt.Errorf("Failed to verify message: %w", err) + s.errCh <- fmt.Errorf("failed to verify message: %w", err) + return + } + + // Skip messages from self + if msg.From.String() == s.party.PartyID().String() { return } @@ -140,13 +163,14 @@ func (s *session) receive(rawMsg []byte) { isToSelf := slices.Contains(toIDs, s.party.PartyID().String()) if isBroadcast || isToSelf { + logger.Debug("Received message", "from", msg.From, "to", msg.To, "isBroadcast", msg.IsBroadcast, "isToSelf", isToSelf) s.mu.Lock() defer s.mu.Unlock() s.party.InCh() <- *msg } } -func (s *session) Listen() { +func (s *session) Listen(nodeID string) { broadcast := func() { sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { msg := natMsg.Data @@ -154,7 +178,7 @@ func (s *session) Listen() { }) if err != nil { - s.errCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) + s.errCh <- fmt.Errorf("failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) return } @@ -162,12 +186,12 @@ func (s *session) Listen() { } direct := func() { - sub, err := s.direct.Listen(s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), func(msg []byte) { + sub, err := s.direct.Listen(s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, "keygen")), func(msg []byte) { s.receive(msg) }) if err != nil { - s.errCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), err) + s.errCh <- fmt.Errorf("failed to subscribe to direct topic %s: %w", s.topicComposer.ComposeDirectTopic(s.party.PartyID().String()), err) return } @@ -179,24 +203,32 @@ func (s *session) Listen() { } func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) { - keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: participantPeerIDs, Threshold: threshold, IsReshared: isReshared, } - - err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) + composeKey := s.composeKey(s.walletID) + err = s.keyinfoStore.Save(composeKey, &keyInfo) if err != nil { - s.errCh <- fmt.Errorf("Failed to save keyinfo: %w", err) + s.errCh <- fmt.Errorf("failed to save keyinfo: %w", err) return } - err = s.kvstore.Put(s.composeKey(s.walletID), data) + err = s.kvstore.Put(composeKey, data) if err != nil { - s.errCh <- fmt.Errorf("Failed to save key: %w", err) + s.errCh <- fmt.Errorf("failed to save key: %w", err) return } - + logger.Info("Saved key", "walletID", s.walletID, "threshold", threshold, "isReshared", isReshared, "data", len(data)) return nil } + +func (s *session) GetSaveData() ([]byte, error) { + composeKey := s.composeKey(s.walletID) + data, err := s.kvstore.Get(composeKey) + if err != nil { + return nil, fmt.Errorf("failed to get key: %w", err) + } + return data, nil +} diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index 963fa56..e8668d2 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -2,10 +2,16 @@ package session import ( "context" + "crypto/ecdsa" + "encoding/json" + "fmt" + "math/big" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/encoding" "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/party" @@ -15,14 +21,53 @@ type ECDSASession struct { *session } -func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, prepareParams *keygen.LocalPreParams, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore) *ECDSASession { - s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore) - s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, *prepareParams, nil, s.errCh) +func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, preParams keygen.LocalPreParams, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store) *ECDSASession { + s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) + s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, preParams, nil, s.errCh) + s.topicComposer = &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("keygen:broadcast:ecdsa:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("keygen:direct:ecdsa:%s:%s", nodeID, walletID) + }, + } + s.composeKey = func(walletID string) string { + return fmt.Sprintf("ecdsa:%s", walletID) + } return &ECDSASession{ session: s, } } +func (s *ECDSASession) SetSaveData(saveBytes []byte) { + s.party.SetSaveData(saveBytes) +} + func (s *ECDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { s.party.StartKeygen(ctx, send, finish) } + +func (s *ECDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + s.party.StartSigning(ctx, msg, send, finish) +} + +func (s *ECDSASession) GetPublicKey(data []byte) []byte { + saveData := &keygen.LocalPartySaveData{} + err := json.Unmarshal(data, saveData) + if err != nil { + return nil + } + + publicKey := saveData.ECDSAPub + pubKey := &ecdsa.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return nil + } + return pubKeyBytes +} From 7837f90f9ea752b1375221388d52379f0d214c8e Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 11 Jun 2025 16:44:11 +0700 Subject: [PATCH 06/34] Refactor event handling and session management in MPC package This commit introduces significant changes to the event handling and session management within the MPC package. Key updates include: - Replaced references to `mpc` events with a new `event` package, enhancing modularity and clarity. - Updated the `OnWalletCreationResult` and `OnResharingResult` methods to utilize the new event types, improving type safety and consistency. - Enhanced the resharing process by introducing a `KeyType` parameter in the `Resharing` method, allowing for better flexibility in key management. - Removed deprecated ECDSA and EDDSA session files, streamlining the codebase and focusing on the new architecture. These changes aim to improve the maintainability and clarity of the MPC package, facilitating future enhancements and reducing complexity. --- examples/generate/main.go | 4 +- examples/reshare/main.go | 7 +- examples/sign/main.go | 2 +- pkg/client/client.go | 290 +++++++++++------------ pkg/event/event.go | 53 +++++ pkg/event/sign.go | 35 --- pkg/eventconsumer/event_consumer.go | 7 +- pkg/mpc/ecdsa_keygen_session.go | 152 ------------ pkg/mpc/ecdsa_resharing_session.go | 203 ---------------- pkg/mpc/ecdsa_rounds.go | 163 ------------- pkg/mpc/ecdsa_signing_session.go | 204 ---------------- pkg/mpc/eddsa_keygen_session.go | 138 ----------- pkg/mpc/eddsa_resharing_session.go | 176 -------------- pkg/mpc/eddsa_rounds.go | 112 --------- pkg/mpc/eddsa_signing_session.go | 187 --------------- pkg/mpc/key_type.go | 8 - pkg/mpc/node.go | 320 ------------------------- pkg/mpc/node_test.go | 41 ---- pkg/mpc/registry.go | 207 ----------------- pkg/mpc/session.go | 347 ---------------------------- 20 files changed, 192 insertions(+), 2464 deletions(-) create mode 100644 pkg/event/event.go delete mode 100644 pkg/mpc/ecdsa_keygen_session.go delete mode 100644 pkg/mpc/ecdsa_resharing_session.go delete mode 100644 pkg/mpc/ecdsa_rounds.go delete mode 100644 pkg/mpc/ecdsa_signing_session.go delete mode 100644 pkg/mpc/eddsa_keygen_session.go delete mode 100644 pkg/mpc/eddsa_resharing_session.go delete mode 100644 pkg/mpc/eddsa_rounds.go delete mode 100644 pkg/mpc/eddsa_signing_session.go delete mode 100644 pkg/mpc/key_type.go delete mode 100644 pkg/mpc/node.go delete mode 100644 pkg/mpc/node_test.go delete mode 100644 pkg/mpc/registry.go delete mode 100644 pkg/mpc/session.go diff --git a/examples/generate/main.go b/examples/generate/main.go index ae3e6e9..f4c936c 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -7,8 +7,8 @@ import ( "syscall" "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/mpc" "github.com/google/uuid" "github.com/nats-io/nats.go" ) @@ -30,7 +30,7 @@ func main() { NatsConn: natsConn, KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", }) - err = mpcClient.OnWalletCreationResult(func(event mpc.KeygenSuccessEvent) { + err = mpcClient.OnWalletCreationResult(func(event event.KeygenSuccessEvent) { logger.Info("Received wallet creation result", "event", event) }) if err != nil { diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 4f33ea8..9a51baf 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -7,8 +7,9 @@ import ( "syscall" "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" ) @@ -29,7 +30,7 @@ func main() { NatsConn: natsConn, KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", }) - err = mpcClient.OnResharingResult(func(event mpc.ResharingSuccessEvent) { + err = mpcClient.OnResharingResult(func(event event.ResharingSuccessEvent) { logger.Info("Received resharing result", "event", event) }) if err != nil { @@ -37,7 +38,7 @@ func main() { } walletID := "892122fd-f2f4-46dc-be25-6fd0b83dff60" - if err := mpcClient.Resharing(walletID, 2); err != nil { + if err := mpcClient.Resharing(walletID, 2, types.KeyTypeSecp256k1); err != nil { logger.Fatal("Resharing failed", err) } logger.Info("Resharing sent, awaiting result...", "walletID", walletID) diff --git a/examples/sign/main.go b/examples/sign/main.go index a9ef4f9..3cf67b5 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -40,7 +40,7 @@ func main() { txMsg := &types.SignTxMessage{ KeyType: types.KeyTypeSecp256k1, WalletID: "0bf609ad-63ed-4713-a673-e09d43f316d3", - NetworkInternalCode: "sepolia-devnet", + NetworkInternalCode: "ethereum-sepolia", TxID: txID, Tx: dummyTx, } diff --git a/pkg/client/client.go b/pkg/client/client.go index 285479d..3d5986f 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "os" - "path/filepath" "strings" "filippo.io/age" @@ -15,25 +14,38 @@ import ( "github.com/fystack/mpcium/pkg/eventconsumer" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/mpc" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" ) const ( - GenerateWalletSuccessTopic = "mpc.mpc_keygen_success.*" // wildcard to listen to all success events - ResharingResultTopic = "mpc.mpc_resharing_success.*" + defaultKeyPath = "./event_initiator.key" + keyFileExt = ".age" + + // NATS stream names + mpcSigningStream = "mpc-signing" + + // NATS queue names + mpcKeygenSuccessQueue = "mpc_keygen_success" + mpcSigningResultQueue = "signing_result" + mpcResharingSuccessQueue = "mpc_resharing_success" + + // NATS subjects + mpcSigningRequestSubject = "mpc.signing_request.*" + mpcKeygenSuccessSubject = "mpc.mpc_keygen_success.*" + mpcSigningResultSubject = "mpc.signing_result.*" + mpcResharingSuccessSubject = "mpc.mpc_resharing_success.*" ) type MPCClient interface { CreateWallet(walletID string) error - OnWalletCreationResult(callback func(event mpc.KeygenSuccessEvent)) error + OnWalletCreationResult(callback func(event.KeygenSuccessEvent)) error SignTransaction(msg *types.SignTxMessage) error - OnSignResult(callback func(event event.SigningResultEvent)) error + OnSignResult(callback func(event.SigningResultEvent)) error Resharing(walletID string, newThreshold int, keyType types.KeyType) error - OnResharingResult(callback func(event mpc.ResharingSuccessEvent)) error + OnResharingResult(callback func(event.ResharingSuccessEvent)) error } type mpcClient struct { @@ -47,141 +59,60 @@ type mpcClient struct { // Options defines configuration options for creating a new MPCClient type Options struct { - // NATS connection - NatsConn *nats.Conn - - // Key path options - KeyPath string // Path to unencrypted key (default: "./event_initiator.key") - - // Encryption options + NatsConn *nats.Conn + KeyPath string // Path to unencrypted key (default: "./event_initiator.key") Encrypted bool // Whether the key is encrypted Password string // Password for encrypted key } // NewMPCClient creates a new MPC client using the provided options. -// It reads the Ed25519 private key from disk and sets up messaging connections. -// If the key is encrypted (.age file), decryption options must be provided in the config. func NewMPCClient(opts Options) MPCClient { - // Set default paths if not provided + // Set default key path if not provided if opts.KeyPath == "" { - opts.KeyPath = filepath.Join(".", "event_initiator.key") + opts.KeyPath = defaultKeyPath } - if strings.HasSuffix(opts.KeyPath, ".age") { + // Auto-detect encryption based on file extension + if strings.HasSuffix(opts.KeyPath, keyFileExt) { opts.Encrypted = true } - var privHexBytes []byte - var err error - - // Check if key file exists - if _, err := os.Stat(opts.KeyPath); err == nil { - if opts.Encrypted { - // Encrypted key exists, try to decrypt it - if opts.Password == "" { - logger.Fatal("Encrypted key found but no decryption option provided", nil) - } - - // Read encrypted file - encryptedBytes, err := os.ReadFile(opts.KeyPath) - if err != nil { - logger.Fatal("Failed to read encrypted private key file", err) - } - - // Decrypt the key using the provided password - privHexBytes, err = decryptPrivateKey(encryptedBytes, opts.Password) - if err != nil { - logger.Fatal("Failed to decrypt private key", err) - } - } else { - // Unencrypted key exists, read it normally - privHexBytes, err = os.ReadFile(opts.KeyPath) - if err != nil { - logger.Fatal("Failed to read private key file", err) - } - } - } else { - logger.Fatal("No private key file found", nil) - } - - privHex := string(privHexBytes) - // Decode private key from hex - privSeed, err := hex.DecodeString(privHex) - if err != nil { - fmt.Println("Failed to decode private key hex:", err) - os.Exit(1) - } - - // Reconstruct full Ed25519 private key from seed - priv := ed25519.NewKeyFromSeed(privSeed) - - // 2) Create the PubSub for both publish & subscribe - signingStream, err := messaging.NewJetStreamPubSub(opts.NatsConn, "mpc-signing", []string{ - "mpc.signing_request.*", - }) - if err != nil { - logger.Fatal("Failed to create JetStream PubSub", err) - } + // Load private key + privKey := loadPrivateKey(opts) + // Initialize messaging components + signingStream := initSigningStream(opts.NatsConn) pubsub := messaging.NewNATSPubSub(opts.NatsConn) - - manager := messaging.NewNATsMessageQueueManager("mpc", []string{ - "mpc.mpc_keygen_success.*", - "mpc.signing_result.*", - "mpc.mpc_resharing_success.*", - }, opts.NatsConn) - - genKeySuccessQueue := manager.NewMessageQueue("mpc_keygen_success") - signResultQueue := manager.NewMessageQueue("signing_result") - resharingResultQueue := manager.NewMessageQueue("mpc_resharing_success") + manager := initMessageQueueManager(opts.NatsConn) return &mpcClient{ signingStream: signingStream, pubsub: pubsub, - genKeySuccessQueue: genKeySuccessQueue, - signResultQueue: signResultQueue, - resharingResultQueue: resharingResultQueue, - privKey: priv, + genKeySuccessQueue: manager.NewMessageQueue(mpcKeygenSuccessQueue), + signResultQueue: manager.NewMessageQueue(mpcSigningResultQueue), + resharingResultQueue: manager.NewMessageQueue(mpcResharingSuccessQueue), + privKey: privKey, } } -// decryptPrivateKey decrypts the encrypted private key using the provided password -func decryptPrivateKey(encryptedData []byte, password string) ([]byte, error) { - // Create an age identity (decryption key) from the password - identity, err := age.NewScryptIdentity(password) - if err != nil { - return nil, fmt.Errorf("failed to create identity from password: %w", err) - } - - // Create a reader from the encrypted data - decrypter, err := age.Decrypt(strings.NewReader(string(encryptedData)), identity) - if err != nil { - return nil, fmt.Errorf("failed to create decrypter: %w", err) - } - - // Read the decrypted data - decryptedData, err := io.ReadAll(decrypter) - if err != nil { - return nil, fmt.Errorf("failed to read decrypted data: %w", err) - } - - return decryptedData, nil +func initMessageQueueManager(natsConn *nats.Conn) *messaging.NATsMessageQueueManager { + return messaging.NewNATsMessageQueueManager("mpc", []string{ + mpcKeygenSuccessSubject, + mpcSigningResultSubject, + mpcResharingSuccessSubject, + }, natsConn) } // CreateWallet generates a GenerateKeyMessage, signs it, and publishes it. func (c *mpcClient) CreateWallet(walletID string) error { - // build the message - msg := &types.GenerateKeyMessage{ - WalletID: walletID, - } - // compute the canonical raw bytes + msg := &types.GenerateKeyMessage{WalletID: walletID} + raw, err := msg.Raw() if err != nil { return fmt.Errorf("CreateWallet: raw payload error: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) + msg.Signature = ed25519.Sign(c.privKey, raw) bytes, err := json.Marshal(msg) if err != nil { return fmt.Errorf("CreateWallet: marshal error: %w", err) @@ -193,35 +124,17 @@ func (c *mpcClient) CreateWallet(walletID string) error { return nil } -// The callback will be invoked whenever a wallet creation result is received. -func (c *mpcClient) OnWalletCreationResult(callback func(event mpc.KeygenSuccessEvent)) error { - err := c.genKeySuccessQueue.Dequeue(GenerateWalletSuccessTopic, func(msg []byte) error { - var event mpc.KeygenSuccessEvent - err := json.Unmarshal(msg, &event) - if err != nil { - return err - } - callback(event) - return nil - }) - - if err != nil { - return fmt.Errorf("OnWalletCreationResult: subscribe error: %w", err) - } - - return nil +func (c *mpcClient) OnWalletCreationResult(callback func(event.KeygenSuccessEvent)) error { + return c.handleQueueEvent(c.genKeySuccessQueue, event.KeygenSuccessEventTopic, callback) } -// SignTransaction builds a SignTxMessage, signs it, and publishes it. func (c *mpcClient) SignTransaction(msg *types.SignTxMessage) error { - // compute the canonical raw bytes (omitting Signature field) raw, err := msg.Raw() if err != nil { return fmt.Errorf("SignTransaction: raw payload error: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) + msg.Signature = ed25519.Sign(c.privKey, raw) bytes, err := json.Marshal(msg) if err != nil { return fmt.Errorf("SignTransaction: marshal error: %w", err) @@ -233,22 +146,8 @@ func (c *mpcClient) SignTransaction(msg *types.SignTxMessage) error { return nil } -func (c *mpcClient) OnSignResult(callback func(event event.SigningResultEvent)) error { - err := c.signResultQueue.Dequeue(event.SigningResultCompleteTopic, func(msg []byte) error { - var event event.SigningResultEvent - err := json.Unmarshal(msg, &event) - if err != nil { - return err - } - callback(event) - return nil - }) - - if err != nil { - return fmt.Errorf("OnSignResult: subscribe error: %w", err) - } - - return nil +func (c *mpcClient) OnSignResult(callback func(event.SigningResultEvent)) error { + return c.handleQueueEvent(c.signResultQueue, event.SigningResultCompleteTopic, callback) } func (c *mpcClient) Resharing(walletID string, newThreshold int, keyType types.KeyType) error { @@ -258,14 +157,12 @@ func (c *mpcClient) Resharing(walletID string, newThreshold int, keyType types.K KeyType: keyType, } - // compute the canonical raw bytes raw, err := msg.Raw() if err != nil { return fmt.Errorf("Resharing: raw payload error: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) + msg.Signature = ed25519.Sign(c.privKey, raw) bytes, err := json.Marshal(msg) if err != nil { return fmt.Errorf("Resharing: marshal error: %w", err) @@ -277,20 +174,91 @@ func (c *mpcClient) Resharing(walletID string, newThreshold int, keyType types.K return nil } -func (c *mpcClient) OnResharingResult(callback func(event mpc.ResharingSuccessEvent)) error { - err := c.resharingResultQueue.Dequeue(ResharingResultTopic, func(msg []byte) error { - var event mpc.ResharingSuccessEvent - err := json.Unmarshal(msg, &event) - if err != nil { - return err +func (c *mpcClient) OnResharingResult(callback func(event.ResharingSuccessEvent)) error { + return c.handleQueueEvent(c.resharingResultQueue, event.ResharingSuccessEventTopic, callback) +} + +// Generic handler for queue events +func (c *mpcClient) handleQueueEvent(queue messaging.MessageQueue, topic string, callback interface{}) error { + return queue.Dequeue(topic, func(msg []byte) error { + switch cb := callback.(type) { + case func(event.KeygenSuccessEvent): + var event event.KeygenSuccessEvent + if err := json.Unmarshal(msg, &event); err != nil { + return err + } + cb(event) + case func(event.SigningResultEvent): + var event event.SigningResultEvent + if err := json.Unmarshal(msg, &event); err != nil { + return err + } + cb(event) + case func(event.ResharingSuccessEvent): + var event event.ResharingSuccessEvent + if err := json.Unmarshal(msg, &event); err != nil { + return err + } + cb(event) + default: + return fmt.Errorf("unsupported callback type") } - callback(event) return nil }) +} + +func loadPrivateKey(opts Options) ed25519.PrivateKey { + if _, err := os.Stat(opts.KeyPath); os.IsNotExist(err) { + logger.Fatal("No private key file found", nil) + } + + var privHexBytes []byte + var err error + + if opts.Encrypted { + if opts.Password == "" { + logger.Fatal("Encrypted key found but no decryption option provided", nil) + } + privHexBytes, err = loadEncryptedKey(opts.KeyPath, opts.Password) + } else { + privHexBytes, err = os.ReadFile(opts.KeyPath) + } if err != nil { - return fmt.Errorf("OnResharingResult: subscribe error: %w", err) + logger.Fatal("Failed to read private key file", err) } - return nil + privSeed, err := hex.DecodeString(string(privHexBytes)) + if err != nil { + logger.Fatal("Failed to decode private key hex", err) + } + + return ed25519.NewKeyFromSeed(privSeed) +} + +func loadEncryptedKey(keyPath, password string) ([]byte, error) { + encryptedBytes, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted key file: %w", err) + } + + identity, err := age.NewScryptIdentity(password) + if err != nil { + return nil, fmt.Errorf("failed to create identity from password: %w", err) + } + + decrypter, err := age.Decrypt(strings.NewReader(string(encryptedBytes)), identity) + if err != nil { + return nil, fmt.Errorf("failed to create decrypter: %w", err) + } + + return io.ReadAll(decrypter) +} + +func initSigningStream(natsConn *nats.Conn) messaging.StreamPubsub { + stream, err := messaging.NewJetStreamPubSub(natsConn, mpcSigningStream, []string{mpcSigningRequestSubject}) + if err != nil { + logger.Fatal("Failed to create JetStream PubSub", err) + } + return stream } diff --git a/pkg/event/event.go b/pkg/event/event.go new file mode 100644 index 0000000..4cb4763 --- /dev/null +++ b/pkg/event/event.go @@ -0,0 +1,53 @@ +package event + +const ( + KeygenSuccessEventTopic = "mpc.keygen.success.*" + ResharingSuccessEventTopic = "mpc.resharing.success.*" +) + +type KeygenSuccessEvent struct { + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` +} + +type SigningResultEvent struct { + ResultType SigningResultType `json:"result_type"` + ErrorReason string `json:"error_reason"` + IsTimeout bool `json:"is_timeout"` + NetworkInternalCode string `json:"network_internal_code"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + R []byte `json:"r"` + S []byte `json:"s"` + SignatureRecovery []byte `json:"signature_recovery"` + + // TODO: define two separate events for eddsa and ecdsa + Signature []byte `json:"signature"` +} + +type SigningResultSuccessEvent struct { + NetworkInternalCode string `json:"network_internal_code"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + R []byte `json:"r"` + S []byte `json:"s"` + SignatureRecovery []byte `json:"signature_recovery"` + + // TODO: define two separate events for eddsa and ecdsa + Signature []byte `json:"signature"` +} + +type SigningResultErrorEvent struct { + NetworkInternalCode string `json:"network_internal_code"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + ErrorReason string `json:"error_reason"` + IsTimeout bool `json:"is_timeout"` +} + +type ResharingSuccessEvent struct { + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` +} diff --git a/pkg/event/sign.go b/pkg/event/sign.go index cb8d53d..bece990 100644 --- a/pkg/event/sign.go +++ b/pkg/event/sign.go @@ -17,38 +17,3 @@ const ( SigningResultTypeSuccess SigningResultTypeError ) - -type SigningResultEvent struct { - ResultType SigningResultType `json:"result_type"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - R []byte `json:"r"` - S []byte `json:"s"` - SignatureRecovery []byte `json:"signature_recovery"` - - // TODO: define two separate events for eddsa and ecdsa - Signature []byte `json:"signature"` -} - -type SigningResultSuccessEvent struct { - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - R []byte `json:"r"` - S []byte `json:"s"` - SignatureRecovery []byte `json:"signature_recovery"` - - // TODO: define two separate events for eddsa and ecdsa - Signature []byte `json:"signature"` -} - -type SigningResultErrorEvent struct { - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` -} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index e89e3b2..d7cf80a 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -13,7 +13,6 @@ import ( "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/mpc" "github.com/fystack/mpcium/pkg/mpc/node" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" @@ -134,7 +133,7 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { cancel() session.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) - successEvent := &mpc.KeygenSuccessEvent{ + successEvent := &event.KeygenSuccessEvent{ WalletID: walletID, ECDSAPubKey: session.GetPublicKey(data), } @@ -145,8 +144,8 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { return } - err = ec.genKeySucecssQueue.Enqueue(fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), + err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ + IdempotententKey: event.KeygenSuccessEventTopic, }) if err != nil { logger.Error("Failed to publish key generation success message", err) diff --git a/pkg/mpc/ecdsa_keygen_session.go b/pkg/mpc/ecdsa_keygen_session.go deleted file mode 100644 index 0bb7973..0000000 --- a/pkg/mpc/ecdsa_keygen_session.go +++ /dev/null @@ -1,152 +0,0 @@ -package mpc - -import ( - "crypto/ecdsa" - "encoding/json" - "fmt" - - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/encoding" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" -) - -const ( - TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" -) - -type KeygenSession struct { - Session - endCh chan *keygen.LocalPartySaveData -} - -type KeygenSuccessEvent struct { - WalletID string `json:"wallet_id"` - ECDSAPubKey []byte `json:"ecdsa_pub_key"` - EDDSAPubKey []byte `json:"eddsa_pub_key"` -} - -func NewKeygenSession( - walletID string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - preParams *keygen.LocalPreParams, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *KeygenSession { - return &KeygenSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - preParams: preParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("keygen:broadcast:ecdsa:%s", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:ecdsa:%s:%s", nodeID, walletID) - }, - }, - composeKey: func(walletID string) string { - return fmt.Sprintf("ecdsa:%s", walletID) - }, - getRoundFunc: GetEcdsaMsgRound, - resultQueue: resultQueue, - sessionType: SessionTypeEcdsa, - identityStore: identityStore, - }, - endCh: make(chan *keygen.LocalPartySaveData), - } -} - -func (s *KeygenSession) Init() { - logger.Infof("Initializing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.S256(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - s.party = keygen.NewLocalParty(params, s.outCh, s.endCh, *s.preParams) - logger.Infof("[INITIALIZED] Initialized session successfully partyID: %s, peerIDs %s, walletID %s, threshold = %d", s.selfPartyID, s.partyIDs, s.walletID, s.threshold) -} - -func (s *KeygenSession) GenerateKey(done func()) { - logger.Info("Starting to generate key", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case saveData := <-s.endCh: - keyBytes, err := json.Marshal(saveData) - if err != nil { - s.ErrCh <- err - return - } - - err = s.kvstore.Put(s.composeKey(s.walletID), keyBytes) - if err != nil { - logger.Error("Failed to save key", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - keyInfo := keyinfo.KeyInfo{ - ParticipantPeerIDs: s.participantPeerIDs, - Threshold: s.threshold, - IsReshared: false, - } - - err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) - if err != nil { - logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - publicKey := saveData.ECDSAPub - - pubKey := &ecdsa.PublicKey{ - Curve: publicKey.Curve(), - X: publicKey.X(), - Y: publicKey.Y(), - } - - pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) - if err != nil { - logger.Error("failed to encode public key", err) - s.ErrCh <- fmt.Errorf("failed to encode public key: %w", err) - return - } - s.pubkeyBytes = pubKeyBytes - done() - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - // done() - return - } - } -} diff --git a/pkg/mpc/ecdsa_resharing_session.go b/pkg/mpc/ecdsa_resharing_session.go deleted file mode 100644 index 2dd2352..0000000 --- a/pkg/mpc/ecdsa_resharing_session.go +++ /dev/null @@ -1,203 +0,0 @@ -package mpc - -import ( - "crypto/ecdsa" - "encoding/json" - "fmt" - - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/encoding" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" -) - -const ( - TypeResharingSuccess = "mpc.mpc_resharing_success.%s" -) - -type IResharingSession interface { - ErrChan() <-chan error - ListenToIncomingResharingMessageAsync() - GetPubKeyResult() []byte - Init() - Resharing(done func()) -} - -type ECDSAResharingSession struct { - Session - isOldParty bool - oldPartyIDs []*tss.PartyID - oldThreshold int - newThreshold int - endCh chan *keygen.LocalPartySaveData -} - -type ResharingSuccessEvent struct { - WalletID string `json:"wallet_id"` - ECDSAPubKey []byte `json:"ecdsa_pub_key"` - EDDSAPubKey []byte `json:"eddsa_pub_key"` -} - -func ECDSANewResharingSession( - walletID string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - oldPartyIDs []*tss.PartyID, - newPartyIDs []*tss.PartyID, - threshold int, - newThreshold int, - preParams *keygen.LocalPreParams, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, - isOldParty bool, -) *ECDSAResharingSession { - oldCtx := tss.NewPeerContext(oldPartyIDs) - newCtx := tss.NewPeerContext(newPartyIDs) - reshareParams := tss.NewReSharingParameters( - tss.S256(), - oldCtx, - newCtx, - selfID, - len(oldPartyIDs), - threshold, - len(newPartyIDs), - newThreshold, - ) - return &ECDSAResharingSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: newThreshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: newPartyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - preParams: preParams, - reshareParams: reshareParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf(TopicFormatResharingBroadcast, "ecdsa", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf(TopicFormatResharingDirect, "ecdsa", nodeID, walletID) - }, - }, - composeKey: func(walletID string) string { - return fmt.Sprintf(KeyFormatEcdsa, walletID) - }, - getRoundFunc: GetEcdsaMsgRound, - resultQueue: resultQueue, - sessionType: SessionTypeEcdsa, - identityStore: identityStore, - }, - isOldParty: isOldParty, - oldPartyIDs: oldPartyIDs, - oldThreshold: threshold, - newThreshold: newThreshold, - endCh: make(chan *keygen.LocalPartySaveData), - } -} - -func (s *ECDSAResharingSession) Init() { - logger.Infof("Initializing resharing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - var share keygen.LocalPartySaveData - if s.isOldParty { - // Get existing key data for old party - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to get wallet data from KVStore: %w", err) - return - } - err = json.Unmarshal(keyData, &share) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to unmarshal wallet data: %w", err) - return - } - } else { - // Initialize empty share data for new party - share = keygen.NewLocalPartySaveData(len(s.partyIDs)) - share.LocalPreParams = *s.preParams - } - - s.party = resharing.NewLocalParty(s.reshareParams, share, s.outCh, s.endCh) - logger.Infof("[INITIALIZED] Initialized resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", - s.selfPartyID, s.partyIDs, s.walletID, s.oldThreshold, s.newThreshold) -} - -func (s *ECDSAResharingSession) Resharing(done func()) { - logger.Info("Starting resharing", "walletID", s.walletID, "partyID", s.selfPartyID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - select { - case saveData := <-s.endCh: - keyBytes, err := json.Marshal(saveData) - if err != nil { - s.ErrCh <- err - return - } - - if err := s.SaveKeyData(keyBytes); err != nil { - s.ErrCh <- err - return - } - - // Save key info with resharing flag - if err := s.SaveKeyInfo(true); err != nil { - s.ErrCh <- err - return - } - - // skip for old committee - if saveData.ECDSAPub != nil { - // Get public key - publicKey := saveData.ECDSAPub - pubKey := &ecdsa.PublicKey{ - Curve: publicKey.Curve(), - X: publicKey.X(), - Y: publicKey.Y(), - } - - pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) - if err != nil { - logger.Error("failed to encode public key", err) - s.ErrCh <- fmt.Errorf("failed to encode public key: %w", err) - return - } - - // Set the public key bytes - s.pubkeyBytes = pubKeyBytes - logger.Info("Generated public key bytes", - "walletID", s.walletID, - "pubKeyBytes", pubKeyBytes) - } - - done() - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - return - case msg := <-s.outCh: - // Handle the message - s.handleResharingMessage(msg) - } - } -} diff --git a/pkg/mpc/ecdsa_rounds.go b/pkg/mpc/ecdsa_rounds.go deleted file mode 100644 index e32badb..0000000 --- a/pkg/mpc/ecdsa_rounds.go +++ /dev/null @@ -1,163 +0,0 @@ -package mpc - -import ( - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" - "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" -) - -const ( - KEYGEN1 = "KGRound1Message" - KEYGEN2aUnicast = "KGRound2Message1" - KEYGEN2b = "KGRound2Message2" - KEYGEN3 = "KGRound3Message" - KEYSIGN1aUnicast = "SignRound1Message1" - KEYSIGN1b = "SignRound1Message2" - KEYSIGN2Unicast = "SignRound2Message" - KEYSIGN3 = "SignRound3Message" - KEYSIGN4 = "SignRound4Message" - KEYSIGN5 = "SignRound5Message" - KEYSIGN6 = "SignRound6Message" - KEYSIGN7 = "SignRound7Message" - KEYSIGN8 = "SignRound8Message" - KEYSIGN9 = "SignRound9Message" - KEYRESHARING1Unicast = "DGRound1Message" - KEYRESHARING2aUnicast = "DGRound2Message1" - KEYRESHARING2bUnicast = "DGRound2Message2" - KEYRESHARING3aUnicast = "DGRound3Message1" - KEYRESHARING3b = "DGRound3Message2" - KEYRESHARING4a = "DGRound4Message1" - KEYRESHARING4bUnicast = "DGRound4Message2" - - TSSKEYGENROUNDS = 4 - TSSKEYSIGNROUNDS = 10 -) - -func GetEcdsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) { - parsedMsg, err := tss.ParseWireMessage(msg, partyID, isBroadcast) - if err != nil { - return RoundInfo{}, err - } - switch parsedMsg.Content().(type) { - case *keygen.KGRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: KEYGEN1, - }, nil - - case *keygen.KGRound2Message1: - return RoundInfo{ - Index: 1, - RoundMsg: KEYGEN2aUnicast, - }, nil - - case *keygen.KGRound2Message2: - return RoundInfo{ - Index: 2, - RoundMsg: KEYGEN2b, - }, nil - - case *keygen.KGRound3Message: - return RoundInfo{ - Index: 3, - RoundMsg: KEYGEN3, - }, nil - - case *signing.SignRound1Message1: - return RoundInfo{ - Index: 0, - RoundMsg: KEYSIGN1aUnicast, - }, nil - - case *signing.SignRound1Message2: - return RoundInfo{ - Index: 1, - RoundMsg: KEYSIGN1b, - }, nil - - case *signing.SignRound2Message: - return RoundInfo{ - Index: 2, - RoundMsg: KEYSIGN2Unicast, - }, nil - - case *signing.SignRound3Message: - return RoundInfo{ - Index: 3, - RoundMsg: KEYSIGN3, - }, nil - - case *signing.SignRound4Message: - return RoundInfo{ - Index: 4, - RoundMsg: KEYSIGN4, - }, nil - - case *signing.SignRound5Message: - return RoundInfo{ - Index: 5, - RoundMsg: KEYSIGN5, - }, nil - - case *signing.SignRound6Message: - return RoundInfo{ - Index: 6, - RoundMsg: KEYSIGN6, - }, nil - - case *signing.SignRound7Message: - return RoundInfo{ - Index: 7, - RoundMsg: KEYSIGN7, - }, nil - case *signing.SignRound8Message: - return RoundInfo{ - Index: 8, - RoundMsg: KEYSIGN8, - }, nil - case *signing.SignRound9Message: - return RoundInfo{ - Index: 9, - RoundMsg: KEYSIGN9, - }, nil - case *resharing.DGRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: KEYRESHARING1Unicast, - }, nil - case *resharing.DGRound2Message1: - return RoundInfo{ - Index: 1, - RoundMsg: KEYRESHARING2aUnicast, - }, nil - case *resharing.DGRound2Message2: - return RoundInfo{ - Index: 2, - RoundMsg: KEYRESHARING2bUnicast, - }, nil - case *resharing.DGRound3Message1: - return RoundInfo{ - Index: 3, - RoundMsg: KEYRESHARING3aUnicast, - }, nil - case *resharing.DGRound3Message2: - return RoundInfo{ - Index: 4, - RoundMsg: KEYRESHARING3b, - }, nil - case *resharing.DGRound4Message1: - return RoundInfo{ - Index: 5, - RoundMsg: KEYRESHARING4a, - }, nil - case *resharing.DGRound4Message2: - return RoundInfo{ - Index: 6, - RoundMsg: KEYRESHARING4bUnicast, - }, nil - default: - return RoundInfo{}, errors.New("unknown round") - } -} diff --git a/pkg/mpc/ecdsa_signing_session.go b/pkg/mpc/ecdsa_signing_session.go deleted file mode 100644 index c0dcbaf..0000000 --- a/pkg/mpc/ecdsa_signing_session.go +++ /dev/null @@ -1,204 +0,0 @@ -package mpc - -import ( - "crypto/ecdsa" - "encoding/json" - "fmt" - "math/big" - - "github.com/bnb-chain/tss-lib/v2/common" - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" - "github.com/fystack/mpcium/pkg/event" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/samber/lo" -) - -// Ecdsa signing session -type SigningSession struct { - Session - endCh chan *common.SignatureData - data *keygen.LocalPartySaveData - tx *big.Int - txID string - networkInternalCode string -} - -type ISession interface { - ErrChan() <-chan error - ListenToIncomingMessageAsync() -} - -type ISigningSession interface { - ISession - - Init(tx *big.Int) error - Sign(onSuccess func(data []byte)) -} - -func NewSigningSession( - walletID string, - txID string, - networkInternalCode string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - preParams *keygen.LocalPreParams, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *SigningSession { - return &SigningSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - preParams: preParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("sign:ecdsa:broadcast:%s:%s", walletID, txID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("sign:ecdsa:direct:%s:%s", nodeID, txID) - }, - }, - composeKey: func(waleltID string) string { - return fmt.Sprintf("ecdsa:%s", waleltID) - }, - getRoundFunc: GetEcdsaMsgRound, - resultQueue: resultQueue, - identityStore: identityStore, - }, - endCh: make(chan *common.SignatureData), - txID: txID, - networkInternalCode: networkInternalCode, - } -} - -func (s *SigningSession) Init(tx *big.Int) error { - logger.Infof("Initializing signing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.S256(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get wallet data from KVStore") - } - - keyInfo, err := s.keyinfoStore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get key info data") - } - - if len(s.participantPeerIDs) < keyInfo.Threshold+1 { - logger.Warn("Not enough participants to sign", "participants", s.participantPeerIDs, "expected", keyInfo.Threshold+1) - return ErrNotEnoughParticipants - } - - // check if t+1 participants are present - result := lo.Intersect(s.participantPeerIDs, keyInfo.ParticipantPeerIDs) - if len(result) < keyInfo.Threshold+1 { - return fmt.Errorf( - "Incompatible peerIDs to participate in signing. Current participants: %v, expected participants: %v", - s.participantPeerIDs, - keyInfo.ParticipantPeerIDs, - ) - } - - logger.Info("Have enough participants to sign", "participants", s.participantPeerIDs) - // Check if all the participants of the key are present - var data keygen.LocalPartySaveData - err = json.Unmarshal(keyData, &data) - if err != nil { - return errors.Wrap(err, "Failed to unmarshal wallet data") - } - - s.party = signing.NewLocalParty(tx, params, data, s.outCh, s.endCh) - s.data = &data - s.tx = tx - logger.Info("Initialized sigining session successfully!") - return nil -} - -func (s *SigningSession) Sign(onSuccess func(data []byte)) { - logger.Info("Starting signing", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case sig := <-s.endCh: - publicKey := *s.data.ECDSAPub - pk := ecdsa.PublicKey{ - Curve: publicKey.Curve(), - X: publicKey.X(), - Y: publicKey.Y(), - } - - ok := ecdsa.Verify(&pk, s.tx.Bytes(), new(big.Int).SetBytes(sig.R), new(big.Int).SetBytes(sig.S)) - if !ok { - s.ErrCh <- errors.New("Failed to verify signature") - return - } - - r := event.SigningResultEvent{ - ResultType: event.SigningResultTypeSuccess, - NetworkInternalCode: s.networkInternalCode, - WalletID: s.walletID, - TxID: s.txID, - R: sig.R, - S: sig.S, - SignatureRecovery: sig.SignatureRecovery, - } - - bytes, err := json.Marshal(r) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to marshal raw signature") - return - } - - err = s.resultQueue.Enqueue(event.SigningResultCompleteTopic, bytes, &messaging.EnqueueOptions{ - IdempotententKey: s.txID, - }) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to publish sign success message") - - return - } - - logger.Info("[SIGN] Sign successfully", "walletID", s.walletID) - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - - onSuccess(bytes) - return - } - - } -} diff --git a/pkg/mpc/eddsa_keygen_session.go b/pkg/mpc/eddsa_keygen_session.go deleted file mode 100644 index 944b22b..0000000 --- a/pkg/mpc/eddsa_keygen_session.go +++ /dev/null @@ -1,138 +0,0 @@ -package mpc - -import ( - "encoding/json" - "fmt" - - "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/decred/dcrd/dcrec/edwards/v2" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" -) - -type EDDSAKeygenSession struct { - Session - endCh chan *keygen.LocalPartySaveData -} - -type EDDSAKeygenSuccessEvent struct { - WalletID string `json:"wallet_id"` - PubKey []byte `json:"pub_key"` -} - -func NewEDDSAKeygenSession( - walletID string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *EDDSAKeygenSession { - return &EDDSAKeygenSession{Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("keygen:broadcast:eddsa:%s", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:eddsa:%s:%s", nodeID, walletID) - }, - }, - composeKey: func(waleltID string) string { - return fmt.Sprintf("eddsa:%s", waleltID) - }, - getRoundFunc: GetEddsaMsgRound, - resultQueue: resultQueue, - sessionType: SessionTypeEddsa, - identityStore: identityStore, - }, - endCh: make(chan *keygen.LocalPartySaveData), - } -} - -func (s *EDDSAKeygenSession) Init() { - logger.Infof("Initializing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.Edwards(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - s.party = keygen.NewLocalParty(params, s.outCh, s.endCh) - logger.Infof("[INITIALIZED] Initialized session successfully partyID: %s, peerIDs %s, walletID %s, threshold = %d", s.selfPartyID, s.partyIDs, s.walletID, s.threshold) -} - -func (s *EDDSAKeygenSession) GenerateKey(done func()) { - logger.Info("Starting to generate key", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case saveData := <-s.endCh: - keyBytes, err := json.Marshal(saveData) - if err != nil { - s.ErrCh <- err - return - } - - err = s.kvstore.Put(s.composeKey(s.walletID), keyBytes) - if err != nil { - logger.Error("Failed to save key", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - keyInfo := keyinfo.KeyInfo{ - ParticipantPeerIDs: s.participantPeerIDs, - Threshold: s.threshold, - IsReshared: false, - } - - err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) - if err != nil { - logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) - s.ErrCh <- err - return - } - - publicKey := saveData.EDDSAPub - pkX, pkY := publicKey.X(), publicKey.Y() - pk := edwards.PublicKey{ - Curve: tss.Edwards(), - X: pkX, - Y: pkY, - } - - pubKeyBytes := pk.SerializeCompressed() - s.pubkeyBytes = pubKeyBytes - - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - done() - return - } - } -} diff --git a/pkg/mpc/eddsa_resharing_session.go b/pkg/mpc/eddsa_resharing_session.go deleted file mode 100644 index 06e8dba..0000000 --- a/pkg/mpc/eddsa_resharing_session.go +++ /dev/null @@ -1,176 +0,0 @@ -package mpc - -import ( - "encoding/json" - "fmt" - - "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" - "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/decred/dcrd/dcrec/edwards/v2" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" -) - -type EDDSAResharingSession struct { - Session - isOldParty bool - oldPartyIDs []*tss.PartyID - oldThreshold int - newThreshold int - endCh chan *keygen.LocalPartySaveData -} - -func EDDSANewResharingSession( - walletID string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - oldPartyIDs []*tss.PartyID, - newPartyIDs []*tss.PartyID, - threshold int, - newThreshold int, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, - isOldParty bool, -) *EDDSAResharingSession { - oldCtx := tss.NewPeerContext(oldPartyIDs) - newCtx := tss.NewPeerContext(newPartyIDs) - reshareParams := tss.NewReSharingParameters( - tss.Edwards(), - oldCtx, - newCtx, - selfID, - len(oldPartyIDs), - threshold, - len(newPartyIDs), - newThreshold, - ) - return &EDDSAResharingSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: newThreshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: newPartyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - reshareParams: reshareParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf(TopicFormatResharingBroadcast, "eddsa", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf(TopicFormatResharingDirect, "eddsa", nodeID, walletID) - }, - }, - composeKey: func(walletID string) string { - return fmt.Sprintf(KeyFormatEddsa, walletID) - }, - getRoundFunc: GetEddsaMsgRound, - resultQueue: resultQueue, - sessionType: SessionTypeEddsa, - identityStore: identityStore, - }, - isOldParty: isOldParty, - oldPartyIDs: oldPartyIDs, - oldThreshold: threshold, - newThreshold: newThreshold, - endCh: make(chan *keygen.LocalPartySaveData), - } -} - -func (s *EDDSAResharingSession) Init() { - logger.Infof("Initializing EDDSA resharing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - var share keygen.LocalPartySaveData - if s.isOldParty { - // Get existing key data for old party - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - fmt.Println("err", err) - s.ErrCh <- fmt.Errorf("failed to get wallet data from KVStore: %w", err) - return - } - err = json.Unmarshal(keyData, &share) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to unmarshal wallet data: %w", err) - return - } - } else { - // Initialize empty share data for new party - share = keygen.NewLocalPartySaveData(len(s.partyIDs)) - } - s.party = resharing.NewLocalParty(s.reshareParams, share, s.outCh, s.endCh) - logger.Infof("[INITIALIZED] Initialized EDDSA resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", - s.selfPartyID, s.partyIDs, s.walletID, s.oldThreshold, s.newThreshold) -} - -func (s *EDDSAResharingSession) Resharing(done func()) { - logger.Info("Starting EDDSA resharing", "walletID", s.walletID, "partyID", s.selfPartyID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - select { - case saveData := <-s.endCh: - // skip for old committee - if saveData.EDDSAPub != nil { - keyBytes, err := json.Marshal(saveData) - if err != nil { - s.ErrCh <- err - return - } - - if err := s.SaveKeyData(keyBytes); err != nil { - s.ErrCh <- err - return - } - - // Save key info with resharing flag - if err := s.SaveKeyInfo(true); err != nil { - s.ErrCh <- err - return - } - - // Get public key - publicKey := saveData.EDDSAPub - pkX, pkY := publicKey.X(), publicKey.Y() - pk := edwards.PublicKey{ - Curve: tss.Edwards(), - X: pkX, - Y: pkY, - } - - pubKeyBytes := pk.SerializeCompressed() - s.pubkeyBytes = pubKeyBytes - - logger.Info("Generated public key bytes", - "walletID", s.walletID, - "pubKeyBytes", pubKeyBytes) - } - - done() - err := s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - return - case msg := <-s.outCh: - // Handle the message - s.handleResharingMessage(msg) - } - } -} diff --git a/pkg/mpc/eddsa_rounds.go b/pkg/mpc/eddsa_rounds.go deleted file mode 100644 index 88864f3..0000000 --- a/pkg/mpc/eddsa_rounds.go +++ /dev/null @@ -1,112 +0,0 @@ -package mpc - -import ( - "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" - "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" - "github.com/bnb-chain/tss-lib/v2/eddsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" -) - -type GetRoundFunc func(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) - -type RoundInfo struct { - Index int - RoundMsg string - MsgIdentifier string -} - -const ( - EDDSA_KEYGEN1 = "KGRound1Message" - EDDSA_KEYGEN2aUnicast = "KGRound2Message1" - EDDSA_KEYGEN2b = "KGRound2Message2" - EDDSA_KEYSIGN1 = "SignRound1Message" - EDDSA_KEYSIGN2 = "SignRound2Message" - EDDSA_KEYSIGN3 = "SignRound3Message" - EDDSA_RESHARING1 = "DGRound1Message" - EDDSA_RESHARING2 = "DGRound2Message" - EDDSA_RESHARING3aUnicast = "DGRound3Message1" - EDDSA_RESHARING3bUnicast = "DGRound3Message2" - EDDSA_RESHARING4 = "DGRound4Message" - - EDDSA_TSSKEYGENROUNDS = 3 - EDDSA_TSSKEYSIGNROUNDS = 3 - EDDSA_RESHARINGROUNDS = 4 -) - -func GetEddsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) { - parsedMsg, err := tss.ParseWireMessage(msg, partyID, isBroadcast) - if err != nil { - return RoundInfo{}, err - } - switch parsedMsg.Content().(type) { - case *keygen.KGRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYGEN1, - }, nil - - case *keygen.KGRound2Message1: - return RoundInfo{ - Index: 1, - RoundMsg: EDDSA_KEYGEN2aUnicast, - }, nil - - case *keygen.KGRound2Message2: - return RoundInfo{ - Index: 2, - RoundMsg: EDDSA_KEYGEN2b, - }, nil - - case *signing.SignRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYSIGN1, - }, nil - - case *signing.SignRound2Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYSIGN2, - }, nil - - case *signing.SignRound3Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_KEYSIGN3, - }, nil - - case *resharing.DGRound1Message: - return RoundInfo{ - Index: 0, - RoundMsg: EDDSA_RESHARING1, - }, nil - - case *resharing.DGRound2Message: - return RoundInfo{ - Index: 1, - RoundMsg: EDDSA_RESHARING2, - }, nil - - case *resharing.DGRound3Message1: - return RoundInfo{ - Index: 2, - RoundMsg: EDDSA_RESHARING3aUnicast, - }, nil - - case *resharing.DGRound3Message2: - return RoundInfo{ - Index: 3, - RoundMsg: EDDSA_RESHARING3bUnicast, - }, nil - - case *resharing.DGRound4Message: - return RoundInfo{ - Index: 4, - RoundMsg: EDDSA_RESHARING4, - }, nil - - default: - return RoundInfo{}, errors.New("unknown round") - } -} diff --git a/pkg/mpc/eddsa_signing_session.go b/pkg/mpc/eddsa_signing_session.go deleted file mode 100644 index ea5103d..0000000 --- a/pkg/mpc/eddsa_signing_session.go +++ /dev/null @@ -1,187 +0,0 @@ -package mpc - -import ( - "encoding/json" - "fmt" - "math/big" - - "github.com/bnb-chain/tss-lib/v2/common" - "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" - "github.com/bnb-chain/tss-lib/v2/eddsa/signing" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/decred/dcrd/dcrec/edwards/v2" - "github.com/fystack/mpcium/pkg/common/errors" - "github.com/fystack/mpcium/pkg/event" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/samber/lo" -) - -type EDDSASigningSession struct { - Session - endCh chan *common.SignatureData - data *keygen.LocalPartySaveData - tx *big.Int - txID string - networkInternalCode string -} - -func NewEDDSASigningSession( - walletID string, - txID string, - networkInternalCode string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - participantPeerIDs []string, - selfID *tss.PartyID, - partyIDs []*tss.PartyID, - threshold int, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *EDDSASigningSession { - return &EDDSASigningSession{ - Session: Session{ - walletID: walletID, - pubSub: pubSub, - direct: direct, - threshold: threshold, - participantPeerIDs: participantPeerIDs, - selfPartyID: selfID, - partyIDs: partyIDs, - outCh: make(chan tss.Message), - ErrCh: make(chan error), - // preParams: preParams, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("sign:eddsa:broadcast:%s:%s", walletID, txID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("sign:eddsa:direct:%s:%s", nodeID, txID) - }, - }, - composeKey: func(waleltID string) string { - return fmt.Sprintf("eddsa:%s", waleltID) - }, - getRoundFunc: GetEddsaMsgRound, - resultQueue: resultQueue, - identityStore: identityStore, - }, - endCh: make(chan *common.SignatureData), - txID: txID, - networkInternalCode: networkInternalCode, - } -} - -func (s *EDDSASigningSession) Init(tx *big.Int) error { - logger.Infof("Initializing signing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) - ctx := tss.NewPeerContext(s.partyIDs) - params := tss.NewParameters(tss.Edwards(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get wallet data from KVStore") - } - - keyInfo, err := s.keyinfoStore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get key info data") - } - - if len(s.participantPeerIDs) < keyInfo.Threshold+1 { - logger.Warn("Not enough participants to sign, expected %d, got %d", keyInfo.Threshold+1, len(s.participantPeerIDs)) - return ErrNotEnoughParticipants - } - - // check if t+1 participants are present - result := lo.Intersect(s.participantPeerIDs, keyInfo.ParticipantPeerIDs) - if len(result) < keyInfo.Threshold+1 { - return fmt.Errorf( - "Incompatible peerIDs to participate in signing. Current participants: %v, expected participants: %v", - s.participantPeerIDs, - keyInfo.ParticipantPeerIDs, - ) - } - - logger.Info("Have enough participants to sign", "participants", s.participantPeerIDs) - // Check if all the participants of the key are present - var data keygen.LocalPartySaveData - err = json.Unmarshal(keyData, &data) - if err != nil { - return errors.Wrap(err, "Failed to unmarshal wallet data") - } - s.party = signing.NewLocalParty(tx, params, data, s.outCh, s.endCh) - s.data = &data - s.tx = tx - logger.Info("Initialized sigining session successfully!") - return nil -} - -func (s *EDDSASigningSession) Sign(onSuccess func(data []byte)) { - logger.Info("Starting signing", "walletID", s.walletID) - go func() { - if err := s.party.Start(); err != nil { - s.ErrCh <- err - } - }() - - for { - - select { - case msg := <-s.outCh: - s.handleTssMessage(msg) - case sig := <-s.endCh: - publicKey := *s.data.EDDSAPub - pk := edwards.PublicKey{ - Curve: tss.Edwards(), - X: publicKey.X(), - Y: publicKey.Y(), - } - - ok := edwards.Verify(&pk, s.tx.Bytes(), new(big.Int).SetBytes(sig.R), new(big.Int).SetBytes(sig.S)) - if !ok { - s.ErrCh <- errors.New("Failed to verify signature") - return - } - - r := event.SigningResultEvent{ - ResultType: event.SigningResultTypeSuccess, - NetworkInternalCode: s.networkInternalCode, - WalletID: s.walletID, - TxID: s.txID, - Signature: sig.Signature, - } - - bytes, err := json.Marshal(r) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to marshal raw signature") - return - } - - err = s.resultQueue.Enqueue(event.SigningResultCompleteTopic, bytes, &messaging.EnqueueOptions{ - IdempotententKey: s.txID, - }) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Failed to publish sign success message") - return - } - - logger.Info("[SIGN] Sign successfully", "walletID", s.walletID) - - err = s.Close() - if err != nil { - logger.Error("Failed to close session", err) - } - - onSuccess(bytes) - return - } - - } -} diff --git a/pkg/mpc/key_type.go b/pkg/mpc/key_type.go deleted file mode 100644 index 756efa8..0000000 --- a/pkg/mpc/key_type.go +++ /dev/null @@ -1,8 +0,0 @@ -package mpc - -type KeyType string - -const ( - KeyTypeSecp256k1 KeyType = "secp256k1" - KeyTypeEd25519 = "ed25519" -) diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go deleted file mode 100644 index f61b9c3..0000000 --- a/pkg/mpc/node.go +++ /dev/null @@ -1,320 +0,0 @@ -package mpc - -import ( - "bytes" - "fmt" - "math/big" - "time" - - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/google/uuid" -) - -const ( - PurposeKeygen string = "keygen" - PurposeSign string = "sign" - PurposeResharing string = "resharing" -) - -type ID string - -type Node struct { - nodeID string - peerIDs []string - - pubSub messaging.PubSub - direct messaging.DirectMessaging - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - ecdsaPreParams *keygen.LocalPreParams - identityStore identity.Store - - peerRegistry PeerRegistry -} - -func CreatePartyID(nodeID string, label string) *tss.PartyID { - partyID := uuid.NewString() - key := big.NewInt(0).SetBytes([]byte(nodeID + ":" + label)) - return tss.NewPartyID(partyID, label, key) -} - -func PartyIDToNodeID(partyID *tss.PartyID) string { - return string(partyID.KeyInt().Bytes()) -} - -func ComparePartyIDs(x, y *tss.PartyID) bool { - return bytes.Equal(x.KeyInt().Bytes(), y.KeyInt().Bytes()) -} - -func ComposeReadyKey(nodeID string) string { - return fmt.Sprintf("ready/%s", nodeID) -} - -func NewNode( - nodeID string, - peerIDs []string, - pubSub messaging.PubSub, - direct messaging.DirectMessaging, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - peerRegistry PeerRegistry, - identityStore identity.Store, -) *Node { - preParams, err := keygen.GeneratePreParams(5 * time.Minute) - if err != nil { - logger.Fatal("Generate pre params failed", err) - } - logger.Info("Starting new node, preparams is generated successfully!") - - go peerRegistry.WatchPeersReady() - - return &Node{ - nodeID: nodeID, - peerIDs: peerIDs, - pubSub: pubSub, - direct: direct, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - ecdsaPreParams: preParams, - peerRegistry: peerRegistry, - identityStore: identityStore, - } -} - -func (p *Node) ID() string { - return p.nodeID -} - -func (p *Node) CreateKeyGenSession(walletID string, threshold int, successQueue messaging.MessageQueue) (*KeygenSession, error) { - if p.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { - return nil, fmt.Errorf("Not enough peers to create gen session! Expected %d, got %d", threshold+1, p.peerRegistry.GetReadyPeersCount()) - } - - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - session := NewKeygenSession( - walletID, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.ecdsaPreParams, - p.kvstore, - p.keyinfoStore, - successQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) CreateEDDSAKeyGenSession(walletID string, threshold int, successQueue messaging.MessageQueue) (*EDDSAKeygenSession, error) { - if p.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { - return nil, fmt.Errorf("Not enough peers to create gen session! Expected %d, got %d", threshold+1, p.peerRegistry.GetReadyPeersCount()) - } - - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - session := NewEDDSAKeygenSession( - walletID, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.kvstore, - p.keyinfoStore, - successQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) CreateSigningSession( - walletID string, - txID string, - networkInternalCode string, - threshold int, - resultQueue messaging.MessageQueue, -) (*SigningSession, error) { - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("eddsa:%s", walletID)) - if err != nil { - return nil, fmt.Errorf("failed to get key info: %w", err) - } - var selfPartyID *tss.PartyID - var allPartyIDs []*tss.PartyID - if keyInfo.IsReshared { - selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeResharing, readyPeerIDs) - } else { - selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - } - session := NewSigningSession( - walletID, - txID, - networkInternalCode, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.ecdsaPreParams, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) CreateEDDSASigningSession( - walletID string, - txID string, - networkInternalCode string, - threshold int, - resultQueue messaging.MessageQueue, -) (*EDDSASigningSession, error) { - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("eddsa:%s", walletID)) - if err != nil { - return nil, fmt.Errorf("failed to get key info: %w", err) - } - var selfPartyID *tss.PartyID - var allPartyIDs []*tss.PartyID - if keyInfo.IsReshared { - selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeResharing, readyPeerIDs) - } else { - selfPartyID, allPartyIDs = p.generatePartyIDs(PurposeKeygen, readyPeerIDs) - } - session := NewEDDSASigningSession( - walletID, - txID, - networkInternalCode, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - allPartyIDs, - threshold, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - ) - return session, nil -} - -func (p *Node) CreateECDSAResharingSession(walletID string, isOldParticipant bool, readyPeerIDs []string, newThreshold int, resultQueue messaging.MessageQueue) (*ECDSAResharingSession, error) { - // Get existing key info to determine old participants - keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("ecdsa:%s", walletID)) - if err != nil { - return nil, fmt.Errorf("failed to get key info: %w", err) - } - - oldSelfPartyID, oldPartyIDs := p.generatePartyIDs(PurposeKeygen, keyInfo.ParticipantPeerIDs) - newSelfPartyID, newPartyIDs := p.generatePartyIDs(PurposeResharing, readyPeerIDs) - - var selfPartyID *tss.PartyID - if isOldParticipant { - selfPartyID = oldSelfPartyID - } else { - selfPartyID = newSelfPartyID - } - - session := ECDSANewResharingSession( - walletID, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - oldPartyIDs, - newPartyIDs, - keyInfo.Threshold, - newThreshold, - p.ecdsaPreParams, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - isOldParticipant, - ) - return session, nil -} - -func (p *Node) CreeateEDDSAResharingSession(walletID string, isOldParticipant bool, readyPeerIDs []string, newThreshold int, resultQueue messaging.MessageQueue) (*EDDSAResharingSession, error) { - keyInfo, err := p.keyinfoStore.Get(fmt.Sprintf("eddsa:%s", walletID)) - if err != nil { - return nil, fmt.Errorf("failed to get key info: %w", err) - } - - oldSelfPartyID, oldPartyIDs := p.generatePartyIDs(PurposeKeygen, keyInfo.ParticipantPeerIDs) - newSelfPartyID, newPartyIDs := p.generatePartyIDs(PurposeResharing, readyPeerIDs) - - var selfPartyID *tss.PartyID - if isOldParticipant { - selfPartyID = oldSelfPartyID - } else { - selfPartyID = newSelfPartyID - } - - session := EDDSANewResharingSession( - walletID, - p.pubSub, - p.direct, - readyPeerIDs, - selfPartyID, - oldPartyIDs, - newPartyIDs, - keyInfo.Threshold, - newThreshold, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - isOldParticipant, - ) - return session, nil -} - -func (p *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { - var selfPartyID *tss.PartyID - partyIDs := make([]*tss.PartyID, len(readyPeerIDs)) - for i, peerID := range readyPeerIDs { - if peerID == p.nodeID { - selfPartyID = CreatePartyID(peerID, purpose) - partyIDs[i] = selfPartyID - } else { - partyIDs[i] = CreatePartyID(peerID, purpose) - } - } - allPartyIDs := tss.SortPartyIDs(partyIDs, 0) - return selfPartyID, allPartyIDs -} - -func (p *Node) Close() { - err := p.peerRegistry.Resign() - if err != nil { - logger.Error("Resign failed", err) - } -} - -func (p *Node) GetKeyInfo(key string) (*keyinfo.KeyInfo, error) { - return p.keyinfoStore.Get(key) -} - -func (p *Node) GetReadyPeersIncludeSelf() []string { - return p.peerRegistry.GetReadyPeersIncludeSelf() -} - -func (p *Node) GetKVStore() kvstore.KVStore { - return p.kvstore -} diff --git a/pkg/mpc/node_test.go b/pkg/mpc/node_test.go deleted file mode 100644 index dafacda..0000000 --- a/pkg/mpc/node_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package mpc - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -// func TestCreateKeyGenSession(t *testing.T) { -// nodeID := uuid.NewString() - -// peerIDs := []string{ -// nodeID, -// uuid.NewString(), -// uuid.NewString(), -// } -// ctrl := gomock.NewController(t) -// defer ctrl.Finish() -// pubsub := mock.NewMockPubSub(ctrl) -// direct := mock.NewMockDirectMessaging(ctrl) - -// node := NewNode(nodeID, peerIDs, pubsub, direct) - -// session, err := node.CreateKeyGenSession() - -// assert.NoError(t, err) -// assert.Len(t, session.PartyIDs(), 3, "Length of partyIDs should be equal") -// assert.NotNil(t, session.PartyID()) - -// for i, partyID := range session.PartyIDs() { -// // check sortedID -// assert.Equal(t, partyID.Index, i, "Index should be equal") -// } - -// } - -func TestPartyIDToNodeID(t *testing.T) { - partyID := CreatePartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen") - nodeID := PartyIDToNodeID(partyID) - assert.Equal(t, nodeID, "4d8cb873-dc86-4776-b6f6-cf5c668f6468", "NodeID should be equal") -} diff --git a/pkg/mpc/registry.go b/pkg/mpc/registry.go deleted file mode 100644 index dba9b55..0000000 --- a/pkg/mpc/registry.go +++ /dev/null @@ -1,207 +0,0 @@ -package mpc - -import ( - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/fystack/mpcium/pkg/infra" - "github.com/fystack/mpcium/pkg/logger" - "github.com/hashicorp/consul/api" - "github.com/samber/lo" -) - -const ( - ReadinessCheckPeriod = 1 * time.Second -) - -type PeerRegistry interface { - Ready() error - ArePeersReady() bool - WatchPeersReady() - // Resign is called by the node when it is going to shutdown - Resign() error - GetReadyPeersCount() int64 - GetReadyPeersIncludeSelf() []string // get ready peers include self -} - -type registry struct { - nodeID string - peerNodeIDs []string - readyMap map[string]bool - readyCount int64 - mu sync.RWMutex - ready bool // ready is true when all peers are ready - - consulKV infra.ConsulKV -} - -func NewRegistry( - nodeID string, - peerNodeIDs []string, - consulKV infra.ConsulKV, -) *registry { - return ®istry{ - consulKV: consulKV, - nodeID: nodeID, - peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), - readyMap: make(map[string]bool), - readyCount: 1, // self - } -} - -func getPeerIDsExceptSelf(nodeID string, peerNodeIDs []string) []string { - peerIDs := make([]string, 0, len(peerNodeIDs)) - for _, peerID := range peerNodeIDs { - if peerID != nodeID { - peerIDs = append(peerIDs, peerID) - } - } - return peerIDs -} - -func (r *registry) readyKey(nodeID string) string { - return fmt.Sprintf("ready/%s", nodeID) -} - -func (r *registry) registerReadyPairs(peerIDs []string) { - for _, peerID := range peerIDs { - ready, exist := r.readyMap[peerID] - if !exist { - atomic.AddInt64(&r.readyCount, 1) - logger.Info("Register", "peerID", peerID) - } else if !ready { - atomic.AddInt64(&r.readyCount, 1) - logger.Info("Reconnecting...", "peerID", peerID) - } - - r.readyMap[peerID] = true - } - - if len(peerIDs) == len(r.peerNodeIDs) && !r.ready { - r.mu.Lock() - r.ready = true - r.mu.Unlock() - logger.Info("ALL PEERS ARE READY! Starting to accept MPC requests") - } - -} - -// Ready is called by the node when it complete generate preparams and starting to accept -// incoming requests -func (r *registry) Ready() error { - k := r.readyKey(r.nodeID) - - kv := &api.KVPair{ - Key: k, - Value: []byte("true"), - } - - _, err := r.consulKV.Put(kv, nil) - if err != nil { - return fmt.Errorf("Put ready key failed: %w", err) - } - - return nil -} - -func (r *registry) WatchPeersReady() { - ticker := time.NewTicker(ReadinessCheckPeriod) - go r.logReadyStatus() - // first tick is executed immediately - for ; true; <-ticker.C { - pairs, _, err := r.consulKV.List("ready/", nil) - if err != nil { - logger.Error("List ready keys failed", err) - } - - newReadyPeerIDs := r.getReadyPeersFromKVStore(pairs) - if len(newReadyPeerIDs) != len(r.peerNodeIDs) { - r.mu.Lock() - r.ready = false - r.mu.Unlock() - - var readyPeerIDs []string - for peerID, isReady := range r.readyMap { - if isReady { - readyPeerIDs = append(readyPeerIDs, peerID) - } - } - - disconnecteds, _ := lo.Difference(readyPeerIDs, newReadyPeerIDs) - if len(disconnecteds) > 0 { - for _, peerID := range disconnecteds { - logger.Warn("Peer disconnected!", "peerID", peerID) - r.readyMap[peerID] = false - atomic.AddInt64(&r.readyCount, -1) - } - - } - - } - r.registerReadyPairs(newReadyPeerIDs) - } - -} - -func (r *registry) logReadyStatus() { - for { - time.Sleep(5 * time.Second) - if !r.ArePeersReady() { - logger.Info("Peers are not ready yet", "ready", r.GetReadyPeersCount(), "expected", len(r.peerNodeIDs)+1) - } - } -} - -func (r *registry) GetReadyPeersCount() int64 { - return atomic.LoadInt64(&r.readyCount) -} - -func (r *registry) GetReadyPeersIncludeSelf() []string { - var peerIDs []string - for peerID, isReady := range r.readyMap { - if isReady { - peerIDs = append(peerIDs, peerID) - } - } - - peerIDs = append(peerIDs, r.nodeID) // append self - return peerIDs -} - -func (r *registry) getReadyPeersFromKVStore(kvPairs api.KVPairs) []string { - var peers []string - for _, k := range kvPairs { - var peerNodeID string - _, err := fmt.Sscanf(k.Key, "ready/%s", &peerNodeID) - if err != nil { - logger.Error("Parse ready key failed", err) - } - if peerNodeID == r.nodeID { - continue - } - - peers = append(peers, peerNodeID) - } - - return peers -} - -func (r *registry) ArePeersReady() bool { - r.mu.RLock() - defer r.mu.RUnlock() - - return r.ready -} - -func (r *registry) Resign() error { - k := r.readyKey(r.nodeID) - - _, err := r.consulKV.Delete(k, nil) - if err != nil { - return fmt.Errorf("Delete ready key failed: %w", err) - } - - return nil -} diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go deleted file mode 100644 index 76994c2..0000000 --- a/pkg/mpc/session.go +++ /dev/null @@ -1,347 +0,0 @@ -package mpc - -import ( - "fmt" - "slices" - "strings" - "sync" - - "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" - "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/common/errors" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/types" - "github.com/nats-io/nats.go" -) - -var ( - ErrNotEnoughParticipants = errors.New("Not enough participants to sign") -) - -// SessionType constants -type SessionType string - -const ( - SessionTypeEcdsa SessionType = "session_ecdsa" - SessionTypeEddsa SessionType = "session_eddsa" -) - -// Topic format constants -const ( - TopicFormatResharingBroadcast = "resharing:broadcast:%s:%s" - TopicFormatResharingDirect = "resharing:direct:%s:%s:%s" -) - -// Key format constants -const ( - KeyFormatEcdsa = "ecdsa:%s" - KeyFormatEddsa = "eddsa:%s" -) - -type TopicComposer struct { - ComposeBroadcastTopic func() string - ComposeDirectTopic func(nodeID string) string -} - -type KeyComposerFn func(id string) string - -type Session struct { - walletID string - pubSub messaging.PubSub - direct messaging.DirectMessaging - threshold int - participantPeerIDs []string - selfPartyID *tss.PartyID - // IDs of all parties in the session including self - partyIDs []*tss.PartyID - outCh chan tss.Message - ErrCh chan error - party tss.Party - - // preParams is nil for EDDSA session - preParams *keygen.LocalPreParams - // reshareParams is nil for non resharing session - reshareParams *tss.ReSharingParameters - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - broadcastSub messaging.Subscription - directSub messaging.Subscription - resultQueue messaging.MessageQueue - identityStore identity.Store - - topicComposer *TopicComposer - composeKey KeyComposerFn - getRoundFunc GetRoundFunc - mu sync.Mutex - // After the session is done, the key will be stored pubkeyBytes - pubkeyBytes []byte - sessionType SessionType -} - -func (s *Session) PartyID() *tss.PartyID { - return s.selfPartyID -} - -func (s *Session) PartyIDs() []*tss.PartyID { - return s.partyIDs -} - -func (s *Session) PartyCount() int { - return len(s.partyIDs) -} - -func (s *Session) handleTssMessage(keyshare tss.Message) { - data, routing, err := keyshare.WireBytes() - if err != nil { - s.ErrCh <- err - return - } - tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) - signature, err := s.identityStore.SignMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to sign message: %w", err) - return - } - tssMsg.Signature = signature - msg, err := types.MarshalTssMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) - return - } - - if routing.IsBroadcast && len(routing.To) == 0 { - err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msg) - if err != nil { - s.ErrCh <- err - return - } - } else { - for _, to := range routing.To { - nodeID := PartyIDToNodeID(to) - topic := s.topicComposer.ComposeDirectTopic(nodeID) - err := s.direct.Send(topic, msg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to send direct message to %s: %w", topic, err) - } - - } - - } -} - -func (s *Session) handleResharingMessage(msg tss.Message) { - data, routing, err := msg.WireBytes() - if err != nil { - s.ErrCh <- err - return - } - - tssMsg := types.NewTssResharingMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To, routing.IsToOldCommittee, routing.IsToOldAndNewCommittees) - signature, err := s.identityStore.SignMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to sign message: %w", err) - return - } - tssMsg.Signature = signature - msgBytes, err := types.MarshalTssMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) - return - } - - // Just send to all intended recipients except self - for _, to := range routing.To { - if to.Id != s.selfPartyID.Id { - s.direct.Send(s.topicComposer.ComposeDirectTopic(PartyIDToNodeID(to)), msgBytes) - } - } -} - -func (s *Session) receiveTssMessage(rawMsg []byte) { - msg, err := types.UnmarshalTssMessage(rawMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to unmarshal message: %w", err) - return - } - err = s.identityStore.VerifyMessage(msg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to verify message: %w, tampered message", err) - return - } - - toIDs := make([]string, len(msg.To)) - for i, id := range msg.To { - toIDs[i] = id.String() - } - - round, err := s.getRoundFunc(msg.MsgBytes, s.selfPartyID, msg.IsBroadcast) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Broken TSS Share") - return - } - - logger.Info(fmt.Sprintf("%s Received message", s.sessionType), - "from", msg.From.String(), - "to", strings.Join(toIDs, ","), - "isBroadcast", msg.IsBroadcast, - "round", round.RoundMsg) - - isBroadcast := msg.IsBroadcast && len(msg.To) == 0 - isToSelf := len(msg.To) == 1 && ComparePartyIDs(msg.To[0], s.selfPartyID) - - if isBroadcast || isToSelf { - s.mu.Lock() - defer s.mu.Unlock() - ok, err := s.party.UpdateFromBytes(msg.MsgBytes, msg.From, msg.IsBroadcast) - if !ok || err != nil { - logger.Error("Failed to update party", err, "walletID", s.walletID) - return - } - } -} - -func (s *Session) receiveTssResharingMessage(rawMsg []byte) { - msg, err := types.UnmarshalTssMessage(rawMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to unmarshal message: %w", err) - return - } - err = s.identityStore.VerifyMessage(msg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to verify message: %w, tampered message", err) - return - } - - toIDs := make([]string, len(msg.To)) - for i, id := range msg.To { - toIDs[i] = id.String() - } - round, err := s.getRoundFunc(msg.MsgBytes, s.selfPartyID, msg.IsBroadcast) - if err != nil { - s.ErrCh <- errors.Wrap(err, "Broken TSS Share") - return - } - - logger.Info(fmt.Sprintf("%s Received resharing message", s.sessionType), - "from", msg.From.String(), - "to", strings.Join(toIDs, ","), - "isBroadcast", msg.IsBroadcast, - "round", round.RoundMsg) - - isToSelf := slices.Contains(toIDs, s.selfPartyID.String()) - if isToSelf { - s.mu.Lock() - defer s.mu.Unlock() - ok, err := s.party.UpdateFromBytes(msg.MsgBytes, msg.From, msg.IsBroadcast) - if !ok || err != nil { - logger.Error("Failed to update party", err, "walletID", s.walletID) - return - } - } -} - -func (s *Session) SendReplySignSuccess(natMsg *nats.Msg) { - msg := natMsg.Data - s.mu.Lock() - defer s.mu.Unlock() - - err := s.pubSub.Publish(natMsg.Reply, msg) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to reply sign sucess message: %w", err) - return - } - logger.Info("Sent reply sign sucess message", "reply", natMsg.Reply) -} - -func (s *Session) ListenToIncomingMessageAsync() { - go func() { - sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { - msg := natMsg.Data - s.receiveTssMessage(msg) - }) - - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) - return - } - - s.broadcastSub = sub - }() - - nodeID := PartyIDToNodeID(s.selfPartyID) - targetID := s.topicComposer.ComposeDirectTopic(nodeID) - sub, err := s.direct.Listen(targetID, func(msg []byte) { - go s.receiveTssMessage(msg) // async for avoid timeout - }) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", targetID, err) - } - s.directSub = sub - -} - -func (s *Session) ListenToIncomingResharingMessageAsync() { - nodeID := PartyIDToNodeID(s.selfPartyID) - targetID := s.topicComposer.ComposeDirectTopic(nodeID) - sub, err := s.direct.Listen(targetID, func(msg []byte) { - go s.receiveTssResharingMessage(msg) // async for avoid timeout - }) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", targetID, err) - } - s.directSub = sub -} - -func (s *Session) Close() error { - if s.broadcastSub != nil { - err := s.broadcastSub.Unsubscribe() - if err != nil { - return err - } - } - if s.directSub != nil { - err := s.directSub.Unsubscribe() - if err != nil { - return err - } - } - return nil -} - -func (s *Session) GetPubKeyResult() []byte { - return s.pubkeyBytes -} - -func (s *Session) ErrChan() <-chan error { - return s.ErrCh -} - -// SaveKeyInfo saves the key info with resharing information -func (s *Session) SaveKeyInfo(isReshared bool) error { - keyInfo := &keyinfo.KeyInfo{ - ParticipantPeerIDs: s.participantPeerIDs, - Threshold: s.threshold, - IsReshared: isReshared, - } - - err := s.keyinfoStore.Save(s.composeKey(s.walletID), keyInfo) - if err != nil { - logger.Error("Failed to save keyinfo", err, "walletID", s.walletID) - return err - } - return nil -} - -// SaveKeyData saves the key data to the kvstore -func (s *Session) SaveKeyData(keyBytes []byte) error { - err := s.kvstore.Put(s.composeKey(s.walletID), keyBytes) - if err != nil { - logger.Error("Failed to save key", err, "walletID", s.walletID) - return err - } - return nil -} From 89eefc62974e908b0a01290d8c952821f6683c24 Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 11 Jun 2025 17:48:26 +0700 Subject: [PATCH 07/34] Implement signing result event handling and add save data methods This commit introduces the handling of signing result events within the event consumer, enhancing the signature verification process. Key updates include: - Added `SigningResultEvent` struct to encapsulate signing results, including wallet ID, transaction ID, and signature details. - Implemented signature verification logic in the event consumer, improving error handling and logging for signing operations. - Introduced `GetSaveData` methods in both ECDSAParty and EDDSAParty to facilitate serialization of party state data. - Added `VerifySignature` method in the ECDSASession to validate signatures against stored public keys, enhancing security and reliability. These changes aim to improve the robustness and maintainability of the signing process within the MPC package. --- pkg/event/event.go | 20 ------------- pkg/eventconsumer/event_consumer.go | 32 +++++++++++++++++++++ pkg/mpc/party/base.go | 1 + pkg/mpc/party/ecdsa.go | 9 ++++++ pkg/mpc/party/eddsa.go | 11 ++++++++ pkg/mpc/session/base.go | 3 ++ pkg/mpc/session/constants.go | 6 ---- pkg/mpc/session/ecdsa.go | 44 +++++++++++++++++++++++++++++ 8 files changed, 100 insertions(+), 26 deletions(-) delete mode 100644 pkg/mpc/session/constants.go diff --git a/pkg/event/event.go b/pkg/event/event.go index 4cb4763..c430c2d 100644 --- a/pkg/event/event.go +++ b/pkg/event/event.go @@ -26,26 +26,6 @@ type SigningResultEvent struct { Signature []byte `json:"signature"` } -type SigningResultSuccessEvent struct { - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - R []byte `json:"r"` - S []byte `json:"s"` - SignatureRecovery []byte `json:"signature_recovery"` - - // TODO: define two separate events for eddsa and ecdsa - Signature []byte `json:"signature"` -} - -type SigningResultErrorEvent struct { - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` -} - type ResharingSuccessEvent struct { WalletID string `json:"wallet_id"` ECDSAPubKey []byte `json:"ecdsa_pub_key"` diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index d7cf80a..8f9a2bb 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -222,6 +222,38 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) signingSession.StartSigning(ctx, txBigInt, signingSession.Send, func(data []byte) { cancel() + fmt.Println("data", data) + ok, r, s, signatureRecovery, err := signingSession.VerifySignature(msg.Tx, data) + if err != nil || !ok { + logger.Error("Failed to verify signature", err) + return + } + + signingResult := event.SigningResultEvent{ + WalletID: msg.WalletID, + TxID: msg.TxID, + NetworkInternalCode: msg.NetworkInternalCode, + ResultType: event.SigningResultTypeSuccess, + Signature: data, + R: r, + S: s, + SignatureRecovery: signatureRecovery, + } + + signingResultBytes, err := json.Marshal(signingResult) + if err != nil { + logger.Error("Failed to marshal signing result event", err) + return + } + + err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: event.SigningResultCompleteTopic, + }) + if err != nil { + logger.Error("Failed to publish signing result event", err) + return + } + logger.Info("Signing completed", "walletID", msg.WalletID, "txID", msg.TxID, "data", len(data)) }) }() diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 9da9d3f..896b80c 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -26,6 +26,7 @@ type PartyInterface interface { StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) PartyID() *tss.PartyID + GetSaveData() []byte SetSaveData(saveData []byte) InCh() chan types.TssMessage OutCh() chan tss.Message diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index 4729e05..b661596 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -30,6 +30,15 @@ func NewECDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyI } } +func (s *ECDSAParty) GetSaveData() []byte { + saveData, err := json.Marshal(s.saveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed serializing shares: %w", err) + return nil + } + return saveData +} + func (s *ECDSAParty) SetSaveData(saveData []byte) { localSaveData := &keygen.LocalPartySaveData{} err := json.Unmarshal(saveData, localSaveData) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index 8d06f51..946319c 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -2,7 +2,9 @@ package party import ( "context" + "encoding/json" "errors" + "fmt" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -27,6 +29,15 @@ func NewEDDASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.Party } } +func (s *EDDSAParty) GetSaveData() []byte { + saveData, err := json.Marshal(s.saveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed serializing shares: %w", err) + return nil + } + return saveData +} + func (s *EDDSAParty) SetSaveData(saveData []byte) { // s.saveData = saveData.(*keygen.LocalPartySaveData) } diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index b88e9ea..63c489f 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -41,8 +41,11 @@ type KeyComposerFn func(id string) string type Session interface { StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) + GetSaveData() ([]byte, error) GetPublicKey(data []byte) []byte + VerifySignature(msg []byte, signature []byte) (bool, []byte, []byte, []byte, error) + Send(msg tss.Message) Listen(nodeID string) SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) diff --git a/pkg/mpc/session/constants.go b/pkg/mpc/session/constants.go deleted file mode 100644 index 745cd04..0000000 --- a/pkg/mpc/session/constants.go +++ /dev/null @@ -1,6 +0,0 @@ -package session - -const ( - KeygenBroadcastTopic = "keygen:broadcast:%s" - KeygenDirectTopic = "keygen:direct:%s:%s" -) \ No newline at end of file diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index e8668d2..fc03f4a 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -4,9 +4,11 @@ import ( "context" "crypto/ecdsa" "encoding/json" + "errors" "fmt" "math/big" + "github.com/bnb-chain/tss-lib/v2/common" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/encoding" @@ -71,3 +73,45 @@ func (s *ECDSASession) GetPublicKey(data []byte) []byte { } return pubKeyBytes } + +func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (bool, []byte, []byte, []byte, error) { + signatureData := &common.SignatureData{} + err := json.Unmarshal(signature, signatureData) + if err != nil { + return false, nil, nil, nil, fmt.Errorf("failed to unmarshal signature data: %w", err) + } + + data := s.party.GetSaveData() + if data == nil { + return false, nil, nil, nil, errors.New("save data is nil") + } + + saveData := &keygen.LocalPartySaveData{} + err = json.Unmarshal(data, saveData) + if err != nil { + return false, nil, nil, nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + + if saveData.ECDSAPub == nil { + return false, nil, nil, nil, errors.New("ECDSA public key is nil") + } + + publicKey := saveData.ECDSAPub + pk := &ecdsa.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + // Convert signature components to big integers + r := new(big.Int).SetBytes(signatureData.R) + sigS := new(big.Int).SetBytes(signatureData.S) + + // Verify the signature + ok := ecdsa.Verify(pk, msg, r, sigS) + if !ok { + return false, nil, nil, nil, errors.New("signature verification failed") + } + + return true, signatureData.R, signatureData.S, signatureData.SignatureRecovery, nil +} From aa4f3d36d8869c87678a7127ad6c3855138a3486 Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 11 Jun 2025 17:51:54 +0700 Subject: [PATCH 08/34] Refactor signature verification in event consumer and session management This commit updates the signature verification process within the event consumer and session management. Key changes include: - Modified the `VerifySignature` method in the `ECDSASession` to return a structured `common.SignatureData` instead of multiple return values, enhancing clarity and usability. - Updated the event consumer to utilize the new signature verification method, improving error handling and logging. - Removed unnecessary variables in the event consumer, streamlining the signature handling process. These changes aim to improve the maintainability and robustness of the signature verification workflow in the MPC package. --- pkg/eventconsumer/event_consumer.go | 11 +++++------ pkg/mpc/session/base.go | 3 ++- pkg/mpc/session/ecdsa.go | 14 +++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 8f9a2bb..9e6c963 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -222,9 +222,8 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) signingSession.StartSigning(ctx, txBigInt, signingSession.Send, func(data []byte) { cancel() - fmt.Println("data", data) - ok, r, s, signatureRecovery, err := signingSession.VerifySignature(msg.Tx, data) - if err != nil || !ok { + signatureData, err := signingSession.VerifySignature(msg.Tx, data) + if err != nil { logger.Error("Failed to verify signature", err) return } @@ -235,9 +234,9 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { NetworkInternalCode: msg.NetworkInternalCode, ResultType: event.SigningResultTypeSuccess, Signature: data, - R: r, - S: s, - SignatureRecovery: signatureRecovery, + R: signatureData.R, + S: signatureData.S, + SignatureRecovery: signatureData.SignatureRecovery, } signingResultBytes, err := json.Marshal(signingResult) diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 63c489f..c146309 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -7,6 +7,7 @@ import ( "slices" "sync" + "github.com/bnb-chain/tss-lib/v2/common" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/keyinfo" @@ -44,7 +45,7 @@ type Session interface { GetSaveData() ([]byte, error) GetPublicKey(data []byte) []byte - VerifySignature(msg []byte, signature []byte) (bool, []byte, []byte, []byte, error) + VerifySignature(msg []byte, signature []byte) (common.SignatureData, error) Send(msg tss.Message) Listen(nodeID string) diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index fc03f4a..4021c69 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -74,26 +74,26 @@ func (s *ECDSASession) GetPublicKey(data []byte) []byte { return pubKeyBytes } -func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (bool, []byte, []byte, []byte, error) { +func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (common.SignatureData, error) { signatureData := &common.SignatureData{} err := json.Unmarshal(signature, signatureData) if err != nil { - return false, nil, nil, nil, fmt.Errorf("failed to unmarshal signature data: %w", err) + return common.SignatureData{}, fmt.Errorf("failed to unmarshal signature data: %w", err) } data := s.party.GetSaveData() if data == nil { - return false, nil, nil, nil, errors.New("save data is nil") + return common.SignatureData{}, errors.New("save data is nil") } saveData := &keygen.LocalPartySaveData{} err = json.Unmarshal(data, saveData) if err != nil { - return false, nil, nil, nil, fmt.Errorf("failed to unmarshal save data: %w", err) + return common.SignatureData{}, fmt.Errorf("failed to unmarshal save data: %w", err) } if saveData.ECDSAPub == nil { - return false, nil, nil, nil, errors.New("ECDSA public key is nil") + return common.SignatureData{}, errors.New("ECDSA public key is nil") } publicKey := saveData.ECDSAPub @@ -110,8 +110,8 @@ func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (bool, []by // Verify the signature ok := ecdsa.Verify(pk, msg, r, sigS) if !ok { - return false, nil, nil, nil, errors.New("signature verification failed") + return common.SignatureData{}, errors.New("signature verification failed") } - return true, signatureData.R, signatureData.S, signatureData.SignatureRecovery, nil + return *signatureData, nil } From 0e45e35bf763a5bfa3e91a171e354918a91e11d2 Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 11 Jun 2025 23:40:06 +0700 Subject: [PATCH 09/34] Update docker-compose and refactor event topics and key path This commit includes the following changes: - Removed the version declaration from the `docker-compose.yaml` file for simplification. - Updated the key path in `main.go` to a relative path for better portability. - Refactored event topic constants in `event.go` to include the `mpc_` prefix for consistency. - Modified the idempotent key in `event_consumer.go` to use a formatted string based on the wallet ID, enhancing clarity in message handling. These changes aim to improve the organization and maintainability of the codebase. --- docker-compose.yaml | 2 -- examples/generate/main.go | 2 +- pkg/event/event.go | 6 ++++-- pkg/eventconsumer/event_consumer.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index bdced5f..7a80bf0 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,5 +1,3 @@ -version: "3" - services: nats-server: image: nats:latest diff --git a/examples/generate/main.go b/examples/generate/main.go index f4c936c..46f47e3 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -28,7 +28,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", + KeyPath: "./../../event_initiator.key", }) err = mpcClient.OnWalletCreationResult(func(event event.KeygenSuccessEvent) { logger.Info("Received wallet creation result", "event", event) diff --git a/pkg/event/event.go b/pkg/event/event.go index c430c2d..d74b7a5 100644 --- a/pkg/event/event.go +++ b/pkg/event/event.go @@ -1,8 +1,10 @@ package event const ( - KeygenSuccessEventTopic = "mpc.keygen.success.*" - ResharingSuccessEventTopic = "mpc.resharing.success.*" + KeygenSuccessEventTopic = "mpc.mpc_keygen_success.*" + ResharingSuccessEventTopic = "mpc.mpc_resharing_success.*" + + TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" ) type KeygenSuccessEvent struct { diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 9e6c963..e03c2df 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -145,7 +145,7 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { } err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: event.KeygenSuccessEventTopic, + IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), }) if err != nil { logger.Error("Failed to publish key generation success message", err) From 7d9fd5e99a462d7cafd6f3ba4c29500f97ceb7c2 Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 12 Jun 2025 00:29:20 +0700 Subject: [PATCH 10/34] Enhance MPC session management and event handling This commit introduces several improvements to the MPC package, including: - Updated the `CreateKeygenSession` method to handle both ECDSA and EDDSA key types, improving flexibility in session creation. - Added a new `EDDSASession` struct to manage EDDSA key generation and signing processes, enhancing modularity. - Refactored the event consumer to support the new EDDSA session, including improved error handling and logging for key generation and signing events. - Modified the `VerifySignature` method in both ECDSASession and EDDSASession to return structured data, improving clarity and usability. These changes aim to enhance the maintainability and robustness of the MPC package, facilitating future enhancements. --- examples/sign/main.go | 4 +- pkg/event/event.go | 1 + pkg/eventconsumer/event_consumer.go | 91 ++++++++++++++-------- pkg/mpc/node/node.go | 116 ++++++++++++++++++---------- pkg/mpc/party/eddsa.go | 14 ++-- pkg/mpc/session/base.go | 2 +- pkg/mpc/session/ecdsa.go | 14 ++-- pkg/mpc/session/eddsa.go | 113 +++++++++++++++++++++++++++ 8 files changed, 265 insertions(+), 90 deletions(-) diff --git a/examples/sign/main.go b/examples/sign/main.go index 3cf67b5..b69337d 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -30,7 +30,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", + KeyPath: "./../../event_initiator.key", }) // 2) Once wallet exists, immediately fire a SignTransaction @@ -39,7 +39,7 @@ func main() { txMsg := &types.SignTxMessage{ KeyType: types.KeyTypeSecp256k1, - WalletID: "0bf609ad-63ed-4713-a673-e09d43f316d3", + WalletID: "9af13a60-9aa9-4069-ba3f-bd6d821c8905", NetworkInternalCode: "ethereum-sepolia", TxID: txID, Tx: dummyTx, diff --git a/pkg/event/event.go b/pkg/event/event.go index d74b7a5..335a7ab 100644 --- a/pkg/event/event.go +++ b/pkg/event/event.go @@ -5,6 +5,7 @@ const ( ResharingSuccessEventTopic = "mpc.mpc_resharing_success.*" TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" + TypeSigningResultComplete = "mpc.mpc_signing_result_complete.%s.%s" ) type KeygenSuccessEvent struct { diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index e03c2df..4a09d5e 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -117,51 +117,76 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { } walletID := msg.WalletID - session, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + ecdsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) if err != nil { logger.Error("Failed to create key generation session", err, "walletID", walletID) return } - // Start listening for messages first - go session.Listen(ec.node.ID()) + eddsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeEd25519, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + if err != nil { + logger.Error("Failed to create key generation session", err, "walletID", walletID) + return + } - // Start the key generation process - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) - session.StartKeygen(ctx, session.Send, func(data []byte) { - cancel() - session.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) + // Start listening for messages first + go ecdsaSession.Listen(ec.node.ID()) + go eddsaSession.Listen(ec.node.ID()) + successEvent := &event.KeygenSuccessEvent{ + WalletID: walletID, + } - successEvent := &event.KeygenSuccessEvent{ - WalletID: walletID, - ECDSAPubKey: session.GetPublicKey(data), - } + var wg sync.WaitGroup + wg.Add(2) - successEventBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal keygen success event", err) + // Handle errors from the session + go func() { + for { + select { + case err := <-ecdsaSession.ErrCh(): + logger.Error("Error from ECDSA session", err) return - } - - err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), - }) - if err != nil { - logger.Error("Failed to publish key generation success message", err) + case err := <-eddsaSession.ErrCh(): + logger.Error("Error from EDDSA session", err) return } - logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID, "data", len(data)) - }) - }() - - // Handle errors from the session - go func() { - for err := range session.ErrCh() { - logger.Error("Error from session", err) - return } }() + + // Start the key generation process + ecdsaCtx, ecdsaCancel := context.WithTimeout(context.Background(), 30*time.Second) + go ecdsaSession.StartKeygen(ecdsaCtx, ecdsaSession.Send, func(data []byte) { + ecdsaCancel() + ecdsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) + successEvent.ECDSAPubKey = ecdsaSession.GetPublicKey(data) + wg.Done() + }) + + eddsaCtx, eddsaCancel := context.WithTimeout(context.Background(), 30*time.Second) + go eddsaSession.StartKeygen(eddsaCtx, eddsaSession.Send, func(data []byte) { + eddsaCancel() + eddsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) + successEvent.EDDSAPubKey = eddsaSession.GetPublicKey(data) + wg.Done() + }) + + wg.Wait() + + // Marshal the success event + successEventBytes, err := json.Marshal(successEvent) + if err != nil { + logger.Error("Failed to marshal keygen success event", err) + return + } + + err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ + IdempotententKey: event.KeygenSuccessEventTopic, + }) + if err != nil { + logger.Error("Failed to publish key generation success message", err) + return + } + logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) }) ec.keyGenerationSub = sub @@ -246,7 +271,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { } err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: event.SigningResultCompleteTopic, + IdempotententKey: fmt.Sprintf(event.TypeSigningResultComplete, msg.WalletID, msg.TxID), }) if err != nil { logger.Error("Failed to publish signing result event", err) diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index f468276..60ca832 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -50,62 +50,98 @@ func (n *Node) ID() string { return n.nodeID } -func (n *Node) CreateKeygenSession(_ types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { +func (n *Node) CreateKeygenSession(keyType types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) } readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() selfPartyID, allPartyIDs := n.generatePartyIDs("keygen", readyPeerIDs) - preparams, err := n.getECDSAPreParams(false) - if err != nil { - return nil, fmt.Errorf("failed to get preparams: %w", err) + switch keyType { + case types.KeyTypeSecp256k1: + preparams, err := n.getECDSAPreParams(false) + if err != nil { + return nil, fmt.Errorf("failed to get preparams: %w", err) + } + logger.Info("Preparams loaded") + + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + *preparams, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + ) + + return ecdsaSession, nil + case types.KeyTypeEd25519: + eddsaSession := session.NewEDDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + ) + return eddsaSession, nil + default: + return nil, fmt.Errorf("invalid key type: %s", keyType) } - logger.Info("Preparams loaded") - - ecdsaSession := session.NewECDSASession( - walletID, - selfPartyID, - allPartyIDs, - threshold, - *preparams, - n.pubSub, - n.direct, - n.identityStore, - n.kvstore, - n.keyinfoStore, - ) - - return ecdsaSession, nil } -func (n *Node) CreateSigningSession(_ types.KeyType, walletID string, txID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { +func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) } readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() selfPartyID, allPartyIDs := n.generatePartyIDs("keygen", readyPeerIDs) - ecdsaSession := session.NewECDSASession( - walletID, - selfPartyID, - allPartyIDs, - threshold, - keygen.LocalPreParams{}, - n.pubSub, - n.direct, - n.identityStore, - n.kvstore, - n.keyinfoStore, - ) - saveData, err := ecdsaSession.GetSaveData() - if err != nil { - return nil, fmt.Errorf("failed to get save data: %w", err) - } - - ecdsaSession.SetSaveData(saveData) + switch keyType { + case types.KeyTypeSecp256k1: + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + keygen.LocalPreParams{}, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + ) + saveData, err := ecdsaSession.GetSaveData() + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } - return ecdsaSession, nil + ecdsaSession.SetSaveData(saveData) + + return ecdsaSession, nil + case types.KeyTypeEd25519: + eddsaSession := session.NewEDDSASession( + walletID, + selfPartyID, + allPartyIDs, + threshold, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + ) + return eddsaSession, nil + default: + return nil, fmt.Errorf("invalid key type: %s", keyType) + } } func (n *Node) GetReadyPeersIncludeSelf() []string { diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index 946319c..8ede17e 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -20,7 +20,7 @@ type EDDSAParty struct { saveData *keygen.LocalPartySaveData } -func NewEDDASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, +func NewEDDAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData, errCh chan error) *EDDSAParty { return &EDDSAParty{ party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), @@ -43,8 +43,8 @@ func (s *EDDSAParty) SetSaveData(saveData []byte) { } func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { - end := make(chan *keygen.LocalPartySaveData) - params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + end := make(chan *keygen.LocalPartySaveData, 1) + params := tss.NewParameters(tss.Edwards(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := keygen.NewLocalParty(params, s.outCh, end) runParty(s, ctx, party, send, end, finish) } @@ -54,8 +54,8 @@ func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(t s.ErrCh() <- errors.New("save data is nil") return } - end := make(chan *common.SignatureData) - params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) + end := make(chan *common.SignatureData, 1) + params := tss.NewParameters(tss.Edwards(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := signing.NewLocalParty(msg, params, *s.saveData, s.outCh, end) runParty(s, ctx, party, send, end, finish) } @@ -66,9 +66,9 @@ func (s *EDDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs s.ErrCh() <- errors.New("save data is nil") return } - end := make(chan *keygen.LocalPartySaveData) + end := make(chan *keygen.LocalPartySaveData, 1) params := tss.NewReSharingParameters( - tss.S256(), + tss.Edwards(), tss.NewPeerContext(oldPartyIDs), tss.NewPeerContext(newPartyIDs), s.partyID, diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index c146309..d5d62b3 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -45,7 +45,7 @@ type Session interface { GetSaveData() ([]byte, error) GetPublicKey(data []byte) []byte - VerifySignature(msg []byte, signature []byte) (common.SignatureData, error) + VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) Send(msg tss.Message) Listen(nodeID string) diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index 4021c69..8546259 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -74,26 +74,26 @@ func (s *ECDSASession) GetPublicKey(data []byte) []byte { return pubKeyBytes } -func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (common.SignatureData, error) { +func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) { signatureData := &common.SignatureData{} err := json.Unmarshal(signature, signatureData) if err != nil { - return common.SignatureData{}, fmt.Errorf("failed to unmarshal signature data: %w", err) + return nil, fmt.Errorf("failed to unmarshal signature data: %w", err) } data := s.party.GetSaveData() if data == nil { - return common.SignatureData{}, errors.New("save data is nil") + return nil, errors.New("save data is nil") } saveData := &keygen.LocalPartySaveData{} err = json.Unmarshal(data, saveData) if err != nil { - return common.SignatureData{}, fmt.Errorf("failed to unmarshal save data: %w", err) + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) } if saveData.ECDSAPub == nil { - return common.SignatureData{}, errors.New("ECDSA public key is nil") + return nil, errors.New("ECDSA public key is nil") } publicKey := saveData.ECDSAPub @@ -110,8 +110,8 @@ func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (common.Sig // Verify the signature ok := ecdsa.Verify(pk, msg, r, sigS) if !ok { - return common.SignatureData{}, errors.New("signature verification failed") + return nil, errors.New("signature verification failed") } - return *signatureData, nil + return signatureData, nil } diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go index ab87616..a4956a3 100644 --- a/pkg/mpc/session/eddsa.go +++ b/pkg/mpc/session/eddsa.go @@ -1 +1,114 @@ package session + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/big" + + "github.com/bnb-chain/tss-lib/v2/common" + "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/decred/dcrd/dcrec/edwards/v2" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/party" +) + +type EDDSASession struct { + *session +} + +func NewEDDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store) *EDDSASession { + s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) + s.party = party.NewEDDAParty(walletID, partyID, partyIDs, threshold, nil, nil, s.errCh) + s.topicComposer = &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("keygen:broadcast:eddsa:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("keygen:direct:eddsa:%s:%s", nodeID, walletID) + }, + } + s.composeKey = func(walletID string) string { + return fmt.Sprintf("eddsa:%s", walletID) + } + return &EDDSASession{ + session: s, + } +} + +func (s *EDDSASession) SetSaveData(saveBytes []byte) { + s.party.SetSaveData(saveBytes) +} + +func (s *EDDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { + s.party.StartKeygen(ctx, send, finish) +} + +func (s *EDDSASession) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { + s.party.StartSigning(ctx, msg, send, finish) +} + +func (s *EDDSASession) GetPublicKey(data []byte) []byte { + saveData := &keygen.LocalPartySaveData{} + err := json.Unmarshal(data, saveData) + if err != nil { + return nil + } + + publicKey := saveData.EDDSAPub + pubKey := &edwards.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + pubKeyBytes := pubKey.SerializeCompressed() + return pubKeyBytes +} + +func (s *EDDSASession) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) { + signatureData := &common.SignatureData{} + err := json.Unmarshal(signature, signatureData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal signature data: %w", err) + } + + data := s.party.GetSaveData() + if data == nil { + return nil, errors.New("save data is nil") + } + + saveData := &keygen.LocalPartySaveData{} + err = json.Unmarshal(data, saveData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + + if saveData.EDDSAPub == nil { + return nil, errors.New("EDDSA public key is nil") + } + + publicKey := saveData.EDDSAPub + pk := &edwards.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + // Convert signature components to big integers + r := new(big.Int).SetBytes(signatureData.R) + sigS := new(big.Int).SetBytes(signatureData.S) + + // Verify the signature + ok := edwards.Verify(pk, msg, r, sigS) + if !ok { + return nil, errors.New("signature verification failed") + } + + return signatureData, nil +} From fc0877fb66361bc11f25836d7b84745f86b6b2fa Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 12 Jun 2025 10:21:21 +0700 Subject: [PATCH 11/34] Enhance save data handling in ECDSA and EDDSA parties This commit introduces improvements to the save data management in the MPC package, including: - Updated the `SetSaveData` method in both `ECDSAParty` and `EDDSAParty` to correctly handle deserialization of save data, ensuring proper assignment of local save data structures. - Added a call to retrieve and set save data in the `CreateSigningSession` method of the `Node` struct, enhancing session management. These changes aim to improve the reliability and clarity of the save data handling process within the MPC package. --- pkg/mpc/node/node.go | 8 ++++++++ pkg/mpc/party/ecdsa.go | 6 +++--- pkg/mpc/party/eddsa.go | 14 ++++++++++++-- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index 60ca832..e10bb2f 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -138,6 +138,14 @@ func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID n.kvstore, n.keyinfoStore, ) + + saveData, err := eddsaSession.GetSaveData() + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + + eddsaSession.SetSaveData(saveData) + return eddsaSession, nil default: return nil, fmt.Errorf("invalid key type: %s", keyType) diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index b661596..00a62c2 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -40,8 +40,8 @@ func (s *ECDSAParty) GetSaveData() []byte { } func (s *ECDSAParty) SetSaveData(saveData []byte) { - localSaveData := &keygen.LocalPartySaveData{} - err := json.Unmarshal(saveData, localSaveData) + var localSaveData keygen.LocalPartySaveData + err := json.Unmarshal(saveData, &localSaveData) if err != nil { s.ErrCh() <- fmt.Errorf("failed deserializing shares: %w", err) return @@ -50,7 +50,7 @@ func (s *ECDSAParty) SetSaveData(saveData []byte) { for _, xj := range localSaveData.BigXj { xj.SetCurve(tss.S256()) } - s.saveData = localSaveData + s.saveData = &localSaveData } func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index 8ede17e..b2a6405 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -38,8 +38,18 @@ func (s *EDDSAParty) GetSaveData() []byte { return saveData } -func (s *EDDSAParty) SetSaveData(saveData []byte) { - // s.saveData = saveData.(*keygen.LocalPartySaveData) +func (s *EDDSAParty) SetSaveData(shareData []byte) { + var localSaveData keygen.LocalPartySaveData + err := json.Unmarshal(shareData, &localSaveData) + if err != nil { + s.ErrCh() <- fmt.Errorf("failed deserializing shares: %w", err) + return + } + localSaveData.EDDSAPub.SetCurve(tss.Edwards()) + for _, xj := range localSaveData.BigXj { + xj.SetCurve(tss.Edwards()) + } + s.saveData = &localSaveData } func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { From 7c650053c34134ef19cd17208b0977bd4375283d Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 12 Jun 2025 14:11:44 +0700 Subject: [PATCH 12/34] Implement resharing event handling and enhance session management This commit introduces the following changes to the MPC package: - Added a new `consumeResharingEvent` method in the `eventConsumer` to handle resharing events, improving event processing and logging. - Updated the `Node` struct to include a `CreateResharingSession` method, facilitating the creation of resharing sessions for both ECDSA and EDDSA key types. - Refactored session methods to support resharing, including updates to the `Listen` method to differentiate between key generation and resharing sessions. - Enhanced error handling and logging throughout the resharing process, ensuring better traceability and reliability. These changes aim to improve the robustness and maintainability of the MPC package, particularly in managing resharing events and sessions. --- cmd/mpcium/main.go | 4 +- go.mod | 4 +- pkg/event/event.go | 1 + pkg/eventconsumer/event_consumer.go | 292 +++++++++++++--------------- pkg/messaging/point2point.go | 2 +- pkg/mpc/node/node.go | 68 +++++-- pkg/mpc/party/base.go | 7 +- pkg/mpc/party/ecdsa.go | 16 +- pkg/mpc/party/eddsa.go | 2 +- pkg/mpc/session/base.go | 21 +- pkg/mpc/session/ecdsa.go | 15 +- pkg/mpc/session/eddsa.go | 14 +- 12 files changed, 243 insertions(+), 203 deletions(-) diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 563fa12..5d1c93a 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -87,11 +87,11 @@ func runNode(ctx context.Context, c *cli.Command) error { nodeName := c.String("name") decryptPrivateKey := c.Bool("decrypt-private-key") usePrompts := c.Bool("prompt-credentials") - debug := c.Bool("debug") + // debug := c.Bool("debug") config.InitViperConfig() environment := viper.GetString("environment") - logger.Init(environment, debug) + logger.Init(environment, true) // Handle configuration based on prompt flag if usePrompts { diff --git a/go.mod b/go.mod index 8a4b6ed..04268d9 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( github.com/rs/zerolog v1.31.0 github.com/samber/lo v1.39.0 github.com/spf13/viper v1.18.0 - github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v3 v3.3.2 golang.org/x/term v0.31.0 ) @@ -30,7 +29,6 @@ require ( github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect github.com/btcsuite/btcutil v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.0 // indirect @@ -66,7 +64,6 @@ require ( github.com/otiai10/primes v0.0.0-20210501021515-f1b2be525a11 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect @@ -74,6 +71,7 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.opencensus.io v0.24.0 // indirect go.uber.org/atomic v1.9.0 // indirect diff --git a/pkg/event/event.go b/pkg/event/event.go index 335a7ab..0d333dd 100644 --- a/pkg/event/event.go +++ b/pkg/event/event.go @@ -6,6 +6,7 @@ const ( TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" TypeSigningResultComplete = "mpc.mpc_signing_result_complete.%s.%s" + TypeResharingSuccess = "mpc.mpc_resharing_success.%s" ) type KeygenSuccessEvent struct { diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 4a09d5e..026a631 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -91,10 +91,10 @@ func (ec *eventConsumer) Run() { log.Fatal("Failed to consume tx signing event", err) } - // err = ec.consumeResharingEvent() - // if err != nil { - // log.Fatal("Failed to consume resharing event", err) - // } + err = ec.consumeResharingEvent() + if err != nil { + log.Fatal("Failed to consume resharing event", err) + } logger.Info("MPC Event consumer started...!") } @@ -130,8 +130,8 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { } // Start listening for messages first - go ecdsaSession.Listen(ec.node.ID()) - go eddsaSession.Listen(ec.node.ID()) + go ecdsaSession.Listen(ec.node.ID(), false) + go eddsaSession.Listen(ec.node.ID(), false) successEvent := &event.KeygenSuccessEvent{ WalletID: walletID, } @@ -158,7 +158,12 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { go ecdsaSession.StartKeygen(ecdsaCtx, ecdsaSession.Send, func(data []byte) { ecdsaCancel() ecdsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) - successEvent.ECDSAPubKey = ecdsaSession.GetPublicKey(data) + ecdsaPubKey, err := ecdsaSession.GetPublicKey(data) + if err != nil { + logger.Error("Failed to get ECDSA public key", err) + return + } + successEvent.ECDSAPubKey = ecdsaPubKey wg.Done() }) @@ -166,7 +171,12 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { go eddsaSession.StartKeygen(eddsaCtx, eddsaSession.Send, func(data []byte) { eddsaCancel() eddsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) - successEvent.EDDSAPubKey = eddsaSession.GetPublicKey(data) + eddsaPubKey, err := eddsaSession.GetPublicKey(data) + if err != nil { + logger.Error("Failed to get EDDSA public key", err) + return + } + successEvent.EDDSAPubKey = eddsaPubKey wg.Done() }) @@ -180,7 +190,7 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { } err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: event.KeygenSuccessEventTopic, + IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), }) if err != nil { logger.Error("Failed to publish key generation success message", err) @@ -240,11 +250,11 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return } - go signingSession.Listen(ec.node.ID()) + go signingSession.Listen(ec.node.ID(), false) txBigInt := new(big.Int).SetBytes(msg.Tx) go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) signingSession.StartSigning(ctx, txBigInt, signingSession.Send, func(data []byte) { cancel() signatureData, err := signingSession.VerifySignature(msg.Tx, data) @@ -311,6 +321,113 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return nil } +func (ec *eventConsumer) consumeResharingEvent() error { + sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { + raw := natMsg.Data + var msg types.ResharingMessage + err := json.Unmarshal(raw, &msg) + if err != nil { + logger.Error("Failed to unmarshal resharing message", err) + return + } + logger.Info("Received resharing event", "walletID", msg.WalletID, "oldThreshold", ec.mpcThreshold, "newThreshold", msg.NewThreshold) + + err = ec.identityStore.VerifyInitiatorMessage(&msg) + if err != nil { + logger.Error("Failed to verify initiator message", err) + return + } + oldSession, err := ec.node.CreateResharingSession( + true, + msg.KeyType, + msg.WalletID, + ec.mpcThreshold, + ec.resharingResultQueue, + ) + if err != nil { + logger.Error("Failed to create resharing session", err) + return + } + + newSession, err := ec.node.CreateResharingSession( + false, + msg.KeyType, + msg.WalletID, + msg.NewThreshold, + ec.resharingResultQueue, + ) + if err != nil { + logger.Error("Failed to create resharing session", err) + return + } + + go oldSession.Listen(ec.node.ID(), false) + go newSession.Listen(ec.node.ID(), true) + + successEvent := &event.ResharingSuccessEvent{ + WalletID: msg.WalletID, + } + + var wg sync.WaitGroup + wg.Add(2) + + // Handle errors from the session + go func() { + for { + select { + case err := <-oldSession.ErrCh(): + logger.Error("Error from ECDSA session", err) + return + case err := <-newSession.ErrCh(): + logger.Error("Error from EDDSA session", err) + return + } + } + }() + + oldCtx, oldCancel := context.WithTimeout(context.Background(), 30*time.Second) + go oldSession.StartResharing(oldCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, oldSession.Send, func(data []byte) { + oldCancel() + wg.Done() + }) + + newCtx, newCancel := context.WithTimeout(context.Background(), 30*time.Second) + go newSession.StartResharing(newCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, newSession.Send, func(data []byte) { + newCancel() + newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, true, data) + ecdsaPubKey, err := newSession.GetPublicKey(data) + if err != nil { + logger.Error("Failed to get ECDSA public key", err) + return + } + successEvent.ECDSAPubKey = ecdsaPubKey + wg.Done() + }) + + wg.Wait() + + successEventBytes, err := json.Marshal(successEvent) + if err != nil { + logger.Error("Failed to marshal resharing success event", err) + return + } + + err = ec.resharingResultQueue.Enqueue(event.ResharingSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeResharingSuccess, msg.WalletID), + }) + if err != nil { + logger.Error("Failed to publish resharing result event", err) + return + } + }) + + ec.resharingSub = sub + if err != nil { + return err + } + return nil +} + func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkInternalCode string, err error, errMsg string, natMsg *nats.Msg) { logger.Error("signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) signingResult := event.SigningResultEvent{ @@ -337,151 +454,6 @@ func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkIntern } } -// func (ec *eventConsumer) consumeResharingEvent() error { -// sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { -// raw := natMsg.Data -// var msg types.ResharingMessage -// err := json.Unmarshal(raw, &msg) -// if err != nil { -// logger.Error("Failed to unmarshal resharing message", err) -// return -// } -// logger.Info("Received resharing event", "walletID", msg.WalletID, "newThreshold", msg.NewThreshold) - -// err = ec.identityStore.VerifyInitiatorMessage(&msg) -// if err != nil { -// logger.Error("Failed to verify initiator message", err) -// return -// } - -// walletID := msg.WalletID -// newThreshold := msg.NewThreshold - -// // Get new participants -// readyPeerIDs := ec.node.GetReadyPeersIncludeSelf() -// if len(readyPeerIDs) < newThreshold+1 { -// logger.Error("Not enough peers for resharing", nil, "expected", newThreshold+1, "got", len(readyPeerIDs)) -// return -// } - -// var oldPSession, newPSession mpc.IResharingSession - -// switch msg.KeyType { -// case types.KeyTypeSecp256k1: -// // Create resharing oldPSession -// oldPSession, err = ec.node.CreateECDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) -// if err != nil { -// logger.Error("Failed to create resharing session", err) -// return -// } -// newPSession, err = ec.node.CreateECDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) -// if err != nil { -// logger.Error("Failed to create resharing session", err) -// return -// } -// case types.KeyTypeEd25519: -// // Create resharing oldPSession -// oldPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, true, readyPeerIDs, newThreshold, ec.resharingResultQueue) -// if err != nil { -// logger.Error("Failed to create resharing session", err) -// return -// } -// newPSession, err = ec.node.CreeateEDDSAResharingSession(walletID, false, readyPeerIDs, newThreshold, ec.resharingResultQueue) -// if err != nil { -// logger.Error("Failed to create resharing session", err) -// return -// } -// } - -// oldPSession.Init() -// newPSession.Init() - -// oldPSessionCtx, oldPSessionDone := context.WithCancel(context.Background()) -// newPSessionCtx, newPSessionDone := context.WithCancel(context.Background()) - -// successEvent := &mpc.ResharingSuccessEvent{ -// WalletID: walletID, -// } - -// var wg sync.WaitGroup -// wg.Add(2) - -// // For old party, we just need to wait for completion -// go func() { -// for { -// select { -// case <-oldPSessionCtx.Done(): -// wg.Done() -// logger.Info("oldPSession done") -// return -// case err := <-oldPSession.ErrChan(): -// if err != nil { -// logger.Error("Resharing session error", err) -// } -// } -// } -// }() - -// // For new party, we need to get the public key -// go func() { -// for { -// select { -// case <-newPSessionCtx.Done(): -// if msg.KeyType == types.KeyTypeSecp256k1 { -// successEvent.ECDSAPubKey = newPSession.GetPubKeyResult() -// } else { -// successEvent.EDDSAPubKey = newPSession.GetPubKeyResult() -// } -// wg.Done() -// logger.Info("newPSession done") -// return -// case err := <-newPSession.ErrChan(): -// if err != nil { -// logger.Error("Resharing session error", err) -// } -// } -// } -// }() - -// // Start listening for messages -// oldPSession.ListenToIncomingResharingMessageAsync() -// newPSession.ListenToIncomingResharingMessageAsync() -// time.Sleep(1 * time.Second) - -// // Start resharing process -// go oldPSession.Resharing(oldPSessionDone) -// go newPSession.Resharing(newPSessionDone) - -// // Wait for both sessions to complete -// wg.Wait() -// logger.Info("Closing session successfully!", -// "event", successEvent) - -// successEventBytes, err := json.Marshal(successEvent) -// if err != nil { -// logger.Error("Failed to marshal resharing success event", err) -// return -// } - -// err = ec.resharingResultQueue.Enqueue(fmt.Sprintf(mpc.TypeResharingSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ -// IdempotententKey: fmt.Sprintf(mpc.TypeResharingSuccess, walletID), -// }) -// if err != nil { -// logger.Error("Failed to publish resharing result event", err) -// return -// } - -// logger.Info("[COMPLETED RESHARING] Resharing completed successfully", -// "walletID", walletID) -// }) - -// ec.resharingSub = sub -// if err != nil { -// return err -// } -// return nil -// } - // Add a cleanup routine that runs periodically func (ec *eventConsumer) sessionCleanupRoutine() { ticker := time.NewTicker(ec.cleanupInterval) @@ -519,6 +491,14 @@ func (ec *eventConsumer) addSession(walletID, txID string) { ec.sessionsLock.Unlock() } +// Remove a session from tracking +func (ec *eventConsumer) removeSession(walletID, txID string) { + sessionID := fmt.Sprintf("%s-%s", walletID, txID) + ec.sessionsLock.Lock() + delete(ec.activeSessions, sessionID) + ec.sessionsLock.Unlock() +} + // checkAndTrackSession checks if a session already exists and tracks it if new. // Returns true if the session is a duplicate. func (ec *eventConsumer) checkDuplicateSession(walletID, txID string) bool { diff --git a/pkg/messaging/point2point.go b/pkg/messaging/point2point.go index a9af8f7..143b99c 100644 --- a/pkg/messaging/point2point.go +++ b/pkg/messaging/point2point.go @@ -37,7 +37,7 @@ func (d *natsDirectMessaging) Send(id string, message []byte) error { retry.Delay(50*time.Millisecond), retry.DelayType(retry.FixedDelay), retry.OnRetry(func(n uint, err error) { - logger.Error("Failed to send direct message message", err, "retryCount", retryCount) + logger.Error("Failed to send direct message message", err, "retryCount", retryCount, "target", id) }), ) diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index e10bb2f..f2e40f1 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -63,8 +63,6 @@ func (n *Node) CreateKeygenSession(keyType types.KeyType, walletID string, thres if err != nil { return nil, fmt.Errorf("failed to get preparams: %w", err) } - logger.Info("Preparams loaded") - ecdsaSession := session.NewECDSASession( walletID, selfPartyID, @@ -152,23 +150,65 @@ func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID } } +func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { + if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { + return nil, fmt.Errorf("not enough peers to create resharing session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) + } + readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() + var selfPartyID *tss.PartyID + var partyIDs []*tss.PartyID + if isOldParty { + selfPartyID, partyIDs = n.generatePartyIDs("keygen", readyPeerIDs) + } else { + selfPartyID, partyIDs = n.generatePartyIDs("resharing", readyPeerIDs) + } + + switch keyType { + case types.KeyTypeSecp256k1: + preparams, err := n.getECDSAPreParams(isOldParty) + if err != nil { + return nil, fmt.Errorf("failed to get preparams: %w", err) + } + ecdsaSession := session.NewECDSASession(walletID, selfPartyID, partyIDs, threshold, *preparams, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) + saveData, err := ecdsaSession.GetSaveData() + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + ecdsaSession.SetSaveData(saveData) + return ecdsaSession, nil + case types.KeyTypeEd25519: + eddsaSession := session.NewEDDSASession(walletID, selfPartyID, partyIDs, threshold, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) + saveData, err := eddsaSession.GetSaveData() + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + eddsaSession.SetSaveData(saveData) + return eddsaSession, nil + default: + return nil, fmt.Errorf("invalid key type: %s", keyType) + } +} + func (n *Node) GetReadyPeersIncludeSelf() []string { return n.peerRegistry.GetReadyPeersIncludeSelf() } func (n *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { - var selfPartyID *tss.PartyID - partyIDs := make([]*tss.PartyID, len(readyPeerIDs)) - for i, peerID := range readyPeerIDs { + // Pre-allocate slice with exact size needed + partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) + + // Create all party IDs in one pass + for _, peerID := range readyPeerIDs { + partyID := createPartyID(peerID, purpose) if peerID == n.nodeID { - selfPartyID = createPartyID(peerID, purpose) - partyIDs[i] = selfPartyID - } else { - partyIDs[i] = createPartyID(peerID, purpose) + self = partyID } + partyIDs = append(partyIDs, partyID) } - allPartyIDs := tss.SortPartyIDs(partyIDs, 0) - return selfPartyID, allPartyIDs + + // Sort party IDs in place + all = tss.SortPartyIDs(partyIDs, 0) + return } func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error) { @@ -180,11 +220,8 @@ func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error } preparamsBytes, _ := n.kvstore.Get(path) - // if err != nil { - // return nil, err - // } - if preparamsBytes == nil { + logger.Info("Generating preparams", "isOldParty", isOldParty) preparams, err := keygen.GeneratePreParams(5 * time.Minute) if err != nil { return nil, err @@ -201,6 +238,7 @@ func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error if err := json.Unmarshal(preparamsBytes, &preparams); err != nil { return nil, err } + logger.Info("Preparams loaded", "isOldParty", isOldParty) return &preparams, nil } diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 896b80c..20026b3 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -23,9 +23,10 @@ type party struct { type PartyInterface interface { StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) - StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) + StartResharing(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) PartyID() *tss.PartyID + PartyIDs() []*tss.PartyID GetSaveData() []byte SetSaveData(saveData []byte) InCh() chan types.TssMessage @@ -43,6 +44,10 @@ func (p *party) PartyID() *tss.PartyID { return p.partyID } +func (p *party) PartyIDs() []*tss.PartyID { + return p.partyIDs +} + func (p *party) InCh() chan types.TssMessage { return p.inCh } diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index 00a62c2..e1a8f52 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -16,17 +16,15 @@ import ( type ECDSAParty struct { party - preParams keygen.LocalPreParams - reshareParams *tss.ReSharingParameters - saveData *keygen.LocalPartySaveData + preParams keygen.LocalPreParams + saveData *keygen.LocalPartySaveData } func NewECDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, - preParams keygen.LocalPreParams, reshareParams *tss.ReSharingParameters, errCh chan error) *ECDSAParty { + preParams keygen.LocalPreParams, errCh chan error) *ECDSAParty { return &ECDSAParty{ - party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), - preParams: preParams, - reshareParams: reshareParams, + party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), + preParams: preParams, } } @@ -71,7 +69,7 @@ func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(t runParty(s, ctx, party, send, end, finish) } -func (s *ECDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, +func (s *ECDSAParty) StartResharing(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { s.ErrCh() <- errors.New("save data is nil") @@ -84,8 +82,8 @@ func (s *ECDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs tss.NewPeerContext(newPartyIDs), s.partyID, len(oldPartyIDs), - len(newPartyIDs), oldThreshold, + len(newPartyIDs), newThreshold, ) party := resharing.NewLocalParty(params, *s.saveData, s.outCh, end) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index b2a6405..47f2e16 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -70,7 +70,7 @@ func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(t runParty(s, ctx, party, send, end, finish) } -func (s *EDDSAParty) StartReshare(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, +func (s *EDDSAParty) StartResharing(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) { if s.saveData == nil { s.ErrCh() <- errors.New("save data is nil") diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index d5d62b3..caabf47 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -42,13 +42,14 @@ type KeyComposerFn func(id string) string type Session interface { StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) - + StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) GetSaveData() ([]byte, error) - GetPublicKey(data []byte) []byte + GetPublicKey(data []byte) ([]byte, error) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) + PartyIDs() []*tss.PartyID Send(msg tss.Message) - Listen(nodeID string) + Listen(nodeID string, isResharing bool) SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) ErrCh() chan error } @@ -94,7 +95,9 @@ func NewSession( errCh: errCh, } } - +func (s *session) PartyIDs() []*tss.PartyID { + return s.party.PartyIDs() +} func (s *session) ErrCh() chan error { return s.errCh } @@ -174,7 +177,13 @@ func (s *session) receive(rawMsg []byte) { } } -func (s *session) Listen(nodeID string) { +func (s *session) Listen(nodeID string, isResharingParty bool) { + var directTopic string + if isResharingParty { + directTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, "resharing")) + } else { + directTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, "keygen")) + } broadcast := func() { sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { msg := natMsg.Data @@ -190,7 +199,7 @@ func (s *session) Listen(nodeID string) { } direct := func() { - sub, err := s.direct.Listen(s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, "keygen")), func(msg []byte) { + sub, err := s.direct.Listen(directTopic, func(msg []byte) { s.receive(msg) }) diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index 8546259..b567519 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -25,7 +25,7 @@ type ECDSASession struct { func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, preParams keygen.LocalPreParams, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store) *ECDSASession { s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) - s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, preParams, nil, s.errCh) + s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, preParams, s.errCh) s.topicComposer = &TopicComposer{ ComposeBroadcastTopic: func() string { return fmt.Sprintf("keygen:broadcast:ecdsa:%s", walletID) @@ -54,13 +54,16 @@ func (s *ECDSASession) StartSigning(ctx context.Context, msg *big.Int, send func s.party.StartSigning(ctx, msg, send, finish) } -func (s *ECDSASession) GetPublicKey(data []byte) []byte { +func (s *ECDSASession) StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) { + s.party.StartResharing(ctx, oldPartyIDs, newPartyIDs, oldThreshold, newThreshold, send, finish) +} + +func (s *ECDSASession) GetPublicKey(data []byte) ([]byte, error) { saveData := &keygen.LocalPartySaveData{} err := json.Unmarshal(data, saveData) if err != nil { - return nil + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) } - publicKey := saveData.ECDSAPub pubKey := &ecdsa.PublicKey{ Curve: publicKey.Curve(), @@ -69,9 +72,9 @@ func (s *ECDSASession) GetPublicKey(data []byte) []byte { } pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) if err != nil { - return nil + return nil, fmt.Errorf("failed to encode public key: %w", err) } - return pubKeyBytes + return pubKeyBytes, nil } func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) { diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go index a4956a3..6b56ca7 100644 --- a/pkg/mpc/session/eddsa.go +++ b/pkg/mpc/session/eddsa.go @@ -53,11 +53,19 @@ func (s *EDDSASession) StartSigning(ctx context.Context, msg *big.Int, send func s.party.StartSigning(ctx, msg, send, finish) } -func (s *EDDSASession) GetPublicKey(data []byte) []byte { +func (s *EDDSASession) StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) { + s.party.StartResharing(ctx, oldPartyIDs, newPartyIDs, oldThreshold, newThreshold, send, finish) +} + +func (s *EDDSASession) GetPublicKey(data []byte) ([]byte, error) { saveData := &keygen.LocalPartySaveData{} err := json.Unmarshal(data, saveData) if err != nil { - return nil + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + + if saveData.EDDSAPub == nil { + return nil, errors.New("EDDSA public key is nil") } publicKey := saveData.EDDSAPub @@ -68,7 +76,7 @@ func (s *EDDSASession) GetPublicKey(data []byte) []byte { } pubKeyBytes := pubKey.SerializeCompressed() - return pubKeyBytes + return pubKeyBytes, nil } func (s *EDDSASession) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) { From e6fadebeccd7659b0f17fdccecd0f50bed053c48 Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 12 Jun 2025 16:22:50 +0700 Subject: [PATCH 13/34] Enhance logging and save data management in resharing process This commit introduces several improvements to the MPC package, including: - Updated the `runNode` function to allow debug logging based on the command-line flag, enhancing the configurability of logging. - Improved the `consumeResharingEvent` method to build a local save data subset for resharing sessions, ensuring proper handling of session data. - Refactored the `CreateResharingSession` method to differentiate between old and new party sessions, enhancing session management. - Added logging for successful resharing completion, improving traceability of resharing events. These changes aim to enhance the robustness and maintainability of the MPC package, particularly in managing resharing events and session data. --- cmd/mpcium/main.go | 4 +- pkg/eventconsumer/event_consumer.go | 12 ++++- pkg/mpc/node/node.go | 19 +++++-- pkg/mpc/party/ecdsa.go | 4 -- pkg/mpc/party/eddsa.go | 4 -- pkg/mpc/session/base.go | 79 ++++++++++++++++------------- pkg/mpc/session/ecdsa.go | 34 +++++++++++-- pkg/mpc/session/eddsa.go | 13 +++-- 8 files changed, 111 insertions(+), 58 deletions(-) diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 5d1c93a..563fa12 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -87,11 +87,11 @@ func runNode(ctx context.Context, c *cli.Command) error { nodeName := c.String("name") decryptPrivateKey := c.Bool("decrypt-private-key") usePrompts := c.Bool("prompt-credentials") - // debug := c.Bool("debug") + debug := c.Bool("debug") config.InitViperConfig() environment := viper.GetString("environment") - logger.Init(environment, true) + logger.Init(environment, debug) // Handle configuration based on prompt flag if usePrompts { diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 026a631..0536bf7 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -337,6 +337,7 @@ func (ec *eventConsumer) consumeResharingEvent() error { logger.Error("Failed to verify initiator message", err) return } + oldSession, err := ec.node.CreateResharingSession( true, msg.KeyType, @@ -394,8 +395,14 @@ func (ec *eventConsumer) consumeResharingEvent() error { newCtx, newCancel := context.WithTimeout(context.Background(), 30*time.Second) go newSession.StartResharing(newCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, newSession.Send, func(data []byte) { newCancel() - newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, true, data) - ecdsaPubKey, err := newSession.GetPublicKey(data) + // Rebuild the save data for and attach to old session + subsetData, err := newSession.BuildLocalSaveDataSubset(data, oldSession.PartyIDs()) + if err != nil { + logger.Error("Failed to build local save data subset", err) + return + } + newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, true, subsetData) + ecdsaPubKey, err := newSession.GetPublicKey(subsetData) if err != nil { logger.Error("Failed to get ECDSA public key", err) return @@ -419,6 +426,7 @@ func (ec *eventConsumer) consumeResharingEvent() error { logger.Error("Failed to publish resharing result event", err) return } + logger.Info("[COMPLETED RESH] Resharing completed successfully", "walletID", msg.WalletID) }) ec.resharingSub = sub diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index f2e40f1..3ef9f2e 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -170,11 +170,22 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa return nil, fmt.Errorf("failed to get preparams: %w", err) } ecdsaSession := session.NewECDSASession(walletID, selfPartyID, partyIDs, threshold, *preparams, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) - saveData, err := ecdsaSession.GetSaveData() - if err != nil { - return nil, fmt.Errorf("failed to get save data: %w", err) + if isOldParty { + saveData, err := ecdsaSession.GetSaveData() + if err != nil { + return nil, fmt.Errorf("failed to get save data: %w", err) + } + ecdsaSession.SetSaveData(saveData) + } else { + // Initialize new save data for new parties + saveData := keygen.NewLocalPartySaveData(len(partyIDs)) + saveData.LocalPreParams = *preparams + saveDataBytes, err := json.Marshal(saveData) + if err != nil { + return nil, fmt.Errorf("failed to marshal save data: %w", err) + } + ecdsaSession.SetSaveData(saveDataBytes) } - ecdsaSession.SetSaveData(saveData) return ecdsaSession, nil case types.KeyTypeEd25519: eddsaSession := session.NewEDDSASession(walletID, selfPartyID, partyIDs, threshold, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index e1a8f52..f520b01 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -44,10 +44,6 @@ func (s *ECDSAParty) SetSaveData(saveData []byte) { s.ErrCh() <- fmt.Errorf("failed deserializing shares: %w", err) return } - localSaveData.ECDSAPub.SetCurve(tss.S256()) - for _, xj := range localSaveData.BigXj { - xj.SetCurve(tss.S256()) - } s.saveData = &localSaveData } diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index 47f2e16..32a66e3 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -45,10 +45,6 @@ func (s *EDDSAParty) SetSaveData(shareData []byte) { s.ErrCh() <- fmt.Errorf("failed deserializing shares: %w", err) return } - localSaveData.EDDSAPub.SetCurve(tss.Edwards()) - for _, xj := range localSaveData.BigXj { - xj.SetCurve(tss.Edwards()) - } s.saveData = &localSaveData } diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index caabf47..4982e1e 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -43,10 +43,12 @@ type Session interface { StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) + GetSaveData() ([]byte, error) GetPublicKey(data []byte) ([]byte, error) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) + BuildLocalSaveDataSubset(sourceData []byte, sortedIDs tss.SortedPartyIDs) ([]byte, error) PartyIDs() []*tss.PartyID Send(msg tss.Message) Listen(nodeID string, isResharing bool) @@ -95,9 +97,11 @@ func NewSession( errCh: errCh, } } + func (s *session) PartyIDs() []*tss.PartyID { return s.party.PartyIDs() } + func (s *session) ErrCh() chan error { return s.errCh } @@ -143,40 +147,6 @@ func (s *session) Send(msg tss.Message) { } } -func (s *session) receive(rawMsg []byte) { - msg, err := types.UnmarshalTssMessage(rawMsg) - if err != nil { - s.errCh <- fmt.Errorf("failed to unmarshal message: %w", err) - return - } - - err = s.identityStore.VerifyMessage(msg) - if err != nil { - s.errCh <- fmt.Errorf("failed to verify message: %w", err) - return - } - - // Skip messages from self - if msg.From.String() == s.party.PartyID().String() { - return - } - - toIDs := make([]string, len(msg.To)) - for i, id := range msg.To { - toIDs[i] = id.String() - } - - isBroadcast := msg.IsBroadcast && len(msg.To) == 0 - isToSelf := slices.Contains(toIDs, s.party.PartyID().String()) - - if isBroadcast || isToSelf { - logger.Debug("Received message", "from", msg.From, "to", msg.To, "isBroadcast", msg.IsBroadcast, "isToSelf", isToSelf) - s.mu.Lock() - defer s.mu.Unlock() - s.party.InCh() <- *msg - } -} - func (s *session) Listen(nodeID string, isResharingParty bool) { var directTopic string if isResharingParty { @@ -228,6 +198,9 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared return } + fmt.Printf("key info: %+v\n", keyInfo) + fmt.Printf("compose key: %s\n", composeKey) + err = s.kvstore.Put(composeKey, data) if err != nil { s.errCh <- fmt.Errorf("failed to save key: %w", err) @@ -237,6 +210,10 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared return nil } +func (s *session) SetSaveData(saveBytes []byte) { + s.party.SetSaveData(saveBytes) +} + func (s *session) GetSaveData() ([]byte, error) { composeKey := s.composeKey(s.walletID) data, err := s.kvstore.Get(composeKey) @@ -245,3 +222,37 @@ func (s *session) GetSaveData() ([]byte, error) { } return data, nil } + +func (s *session) receive(rawMsg []byte) { + msg, err := types.UnmarshalTssMessage(rawMsg) + if err != nil { + s.errCh <- fmt.Errorf("failed to unmarshal message: %w", err) + return + } + + err = s.identityStore.VerifyMessage(msg) + if err != nil { + s.errCh <- fmt.Errorf("failed to verify message: %w", err) + return + } + + // Skip messages from self + if msg.From.String() == s.party.PartyID().String() { + return + } + + toIDs := make([]string, len(msg.To)) + for i, id := range msg.To { + toIDs[i] = id.String() + } + + isBroadcast := msg.IsBroadcast && len(msg.To) == 0 + isToSelf := slices.Contains(toIDs, s.party.PartyID().String()) + + if isBroadcast || isToSelf { + logger.Debug("Received message", "from", msg.From, "to", msg.To, "isBroadcast", msg.IsBroadcast, "isToSelf", isToSelf) + s.mu.Lock() + defer s.mu.Unlock() + s.party.InCh() <- *msg + } +} diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index b567519..ed1a882 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -42,10 +42,6 @@ func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.Part } } -func (s *ECDSASession) SetSaveData(saveBytes []byte) { - s.party.SetSaveData(saveBytes) -} - func (s *ECDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { s.party.StartKeygen(ctx, send, finish) } @@ -118,3 +114,33 @@ func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (*common.Si return signatureData, nil } + +func (s *ECDSASession) BuildLocalSaveDataSubset(sourceData []byte, sortedIDs tss.SortedPartyIDs) ([]byte, error) { + saveData := &keygen.LocalPartySaveData{} + err := json.Unmarshal(sourceData, saveData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + return json.Marshal(buildLocalSaveDataSubset(*saveData, sortedIDs)) +} + +func buildLocalSaveDataSubset(sourceData keygen.LocalPartySaveData, sortedIDs tss.SortedPartyIDs) keygen.LocalPartySaveData { + if len(sortedIDs) != len(sourceData.Ks) { + return keygen.LocalPartySaveData{} + } + newData := keygen.NewLocalPartySaveData(sortedIDs.Len()) + newData.LocalPreParams = sourceData.LocalPreParams + newData.LocalSecrets = sourceData.LocalSecrets + newData.ECDSAPub = sourceData.ECDSAPub + + // Map directly based on sorted ID order + for j := range sortedIDs { + newData.Ks[j] = sourceData.Ks[j] + newData.NTildej[j] = sourceData.NTildej[j] + newData.H1j[j] = sourceData.H1j[j] + newData.H2j[j] = sourceData.H2j[j] + newData.BigXj[j] = sourceData.BigXj[j] + newData.PaillierPKs[j] = sourceData.PaillierPKs[j] + } + return newData +} diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go index 6b56ca7..c0a57e1 100644 --- a/pkg/mpc/session/eddsa.go +++ b/pkg/mpc/session/eddsa.go @@ -41,10 +41,6 @@ func NewEDDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.Part } } -func (s *EDDSASession) SetSaveData(saveBytes []byte) { - s.party.SetSaveData(saveBytes) -} - func (s *EDDSASession) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { s.party.StartKeygen(ctx, send, finish) } @@ -120,3 +116,12 @@ func (s *EDDSASession) VerifySignature(msg []byte, signature []byte) (*common.Si return signatureData, nil } + +func (s *EDDSASession) BuildLocalSaveDataSubset(sourceData []byte, sortedIDs tss.SortedPartyIDs) ([]byte, error) { + saveData := &keygen.LocalPartySaveData{} + err := json.Unmarshal(sourceData, saveData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal save data: %w", err) + } + return json.Marshal(keygen.BuildLocalSaveDataSubset(*saveData, sortedIDs)) +} From 89b42464b9af9f885936f20653bf25802f299aa8 Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 12 Jun 2025 18:01:06 +0700 Subject: [PATCH 14/34] Refactor party ID handling and enhance session message logging This commit introduces several key changes to the MPC package: - Updated the `partyIDToNodeID` function to utilize the `Moniker` field of `tss.PartyID`, improving clarity in party ID representation. - Refactored the `generatePartyIDs` function to maintain its functionality while ensuring consistency in party ID creation. - Enhanced logging in the `Send` and `Listen` methods of the session to provide clearer output for message routing, including broadcasts and direct messages. These changes aim to improve the maintainability and clarity of the MPC package, particularly in party ID management and session messaging. --- pkg/identity/identity.go | 2 +- pkg/mpc/node/node.go | 41 ++++++++++++++++++++-------------------- pkg/mpc/session/base.go | 10 +++++----- pkg/mpc/session/utils.go | 2 +- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 4d281c2..0ee5800 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -265,5 +265,5 @@ func (s *fileStore) VerifyInitiatorMessage(msg types.InitiatorMessage) error { } func partyIDToNodeID(partyID *tss.PartyID) string { - return strings.Split(string(partyID.KeyInt().Bytes()), ":")[0] + return strings.Split(partyID.Moniker, ":")[0] } diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index 3ef9f2e..6e9cb04 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -204,24 +204,6 @@ func (n *Node) GetReadyPeersIncludeSelf() []string { return n.peerRegistry.GetReadyPeersIncludeSelf() } -func (n *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { - // Pre-allocate slice with exact size needed - partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) - - // Create all party IDs in one pass - for _, peerID := range readyPeerIDs { - partyID := createPartyID(peerID, purpose) - if peerID == n.nodeID { - self = partyID - } - partyIDs = append(partyIDs, partyID) - } - - // Sort party IDs in place - all = tss.SortPartyIDs(partyIDs, 0) - return -} - func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error) { var path string if isOldParty { @@ -253,8 +235,27 @@ func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error return &preparams, nil } +func (n *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { + // Pre-allocate slice with exact size needed + partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) + + // Create all party IDs in one pass + for _, peerID := range readyPeerIDs { + partyID := createPartyID(peerID, purpose) + if peerID == n.nodeID { + self = partyID + } + partyIDs = append(partyIDs, partyID) + } + + // Sort party IDs in place + all = tss.SortPartyIDs(partyIDs, 0) + return +} + func createPartyID(nodeID string, label string) *tss.PartyID { partyID := uuid.NewString() - key := big.NewInt(0).SetBytes([]byte(nodeID + ":" + label)) - return tss.NewPartyID(partyID, label, key) + key := big.NewInt(0).SetBytes([]byte(partyID)) + moniker := nodeID + ":" + label + return tss.NewPartyID(partyID, moniker, key) } diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 4982e1e..80a62f5 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -125,10 +125,10 @@ func (s *session) Send(msg tss.Message) { s.errCh <- fmt.Errorf("failed to marshal message: %w", err) return } - - logger.Debug("Sending message", "from", routing.From, "to", routing.To, "isBroadcast", routing.IsBroadcast) + // fmt.Printf("Sending message from %s to %s isBroadcast %v\n", routing.From, routing.To, routing.IsBroadcast) if routing.IsBroadcast && len(routing.To) == 0 { + fmt.Printf("sending broadcast message to %s\n", s.topicComposer.ComposeBroadcastTopic()) err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msgBytes) if err != nil { s.errCh <- fmt.Errorf("failed to publish message: %w", err) @@ -138,6 +138,7 @@ func (s *session) Send(msg tss.Message) { for _, to := range routing.To { nodeID := partyIDToNodeID(to) topic := s.topicComposer.ComposeDirectTopic(nodeID) + fmt.Printf("sending direct message to %s\n", topic) err := s.direct.Send(topic, msgBytes) if err != nil { s.errCh <- fmt.Errorf("failed to send message: %w", err) @@ -155,6 +156,7 @@ func (s *session) Listen(nodeID string, isResharingParty bool) { directTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, "keygen")) } broadcast := func() { + fmt.Printf("subscribing to broadcast topic %s\n", s.topicComposer.ComposeBroadcastTopic()) sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { msg := natMsg.Data s.receive(msg) @@ -169,6 +171,7 @@ func (s *session) Listen(nodeID string, isResharingParty bool) { } direct := func() { + fmt.Printf("subscribing to direct topic %s\n", directTopic) sub, err := s.direct.Listen(directTopic, func(msg []byte) { s.receive(msg) }) @@ -198,9 +201,6 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared return } - fmt.Printf("key info: %+v\n", keyInfo) - fmt.Printf("compose key: %s\n", composeKey) - err = s.kvstore.Put(composeKey, data) if err != nil { s.errCh <- fmt.Errorf("failed to save key: %w", err) diff --git a/pkg/mpc/session/utils.go b/pkg/mpc/session/utils.go index ad17c57..33740a5 100644 --- a/pkg/mpc/session/utils.go +++ b/pkg/mpc/session/utils.go @@ -3,5 +3,5 @@ package session import "github.com/bnb-chain/tss-lib/v2/tss" func partyIDToNodeID(partyID *tss.PartyID) string { - return string(partyID.KeyInt().Bytes()) + return partyID.Moniker } From 02205051bccaebe07e9cb0a15e48f2c6dc91d99d Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 13 Jun 2025 15:15:16 +0700 Subject: [PATCH 15/34] Enhance version management and session handling in MPC package This commit introduces several key updates to the MPC package: - Added a `DefaultVersion` constant to standardize version handling across key generation and signing sessions. - Updated the `SaveKey` method to accept a version parameter, improving clarity in key storage. - Refactored the `CreateSigningSession` and `CreateResharingSession` methods to incorporate party versioning, enhancing session management. - Improved error handling and logging in the event consumer for key generation and signing events, ensuring better traceability. These changes aim to enhance the maintainability and robustness of the MPC package, particularly in managing session versions and key storage. --- pkg/eventconsumer/event_consumer.go | 48 +++++++++++++++++++---------- pkg/keyinfo/keyinfo.go | 2 +- pkg/mpc/node/node.go | 40 ++++++++++++++++++------ pkg/mpc/party/base.go | 20 ++++++------ pkg/mpc/party/eddsa.go | 2 +- pkg/mpc/session/base.go | 23 ++++++-------- pkg/mpc/session/ecdsa.go | 34 ++------------------ pkg/mpc/session/eddsa.go | 13 ++------ 8 files changed, 87 insertions(+), 95 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 0536bf7..8778b83 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -23,6 +23,9 @@ const ( MPCGenerateEvent = "mpc:generate" MPCSignEvent = "mpc:sign" MPCResharingEvent = "mpc:reshare" + + // Default version for keygen + DefaultVersion int = 0 ) type EventConsumer interface { @@ -157,7 +160,7 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { ecdsaCtx, ecdsaCancel := context.WithTimeout(context.Background(), 30*time.Second) go ecdsaSession.StartKeygen(ecdsaCtx, ecdsaSession.Send, func(data []byte) { ecdsaCancel() - ecdsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) + ecdsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) ecdsaPubKey, err := ecdsaSession.GetPublicKey(data) if err != nil { logger.Error("Failed to get ECDSA public key", err) @@ -170,7 +173,7 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { eddsaCtx, eddsaCancel := context.WithTimeout(context.Background(), 30*time.Second) go eddsaSession.StartKeygen(eddsaCtx, eddsaSession.Send, func(data []byte) { eddsaCancel() - eddsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, false, data) + eddsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) eddsaPubKey, err := eddsaSession.GetPublicKey(data) if err != nil { logger.Error("Failed to get EDDSA public key", err) @@ -230,10 +233,21 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return } + // Add session to tracking before starting + ec.addSession(msg.WalletID, msg.TxID) + + partyVersion, err := ec.node.GetPartyVersion(msg.KeyType, msg.WalletID) + if err != nil { + logger.Error("Failed to get party version", err) + ec.removeSession(msg.WalletID, msg.TxID) + return + } + signingSession, err := ec.node.CreateSigningSession( msg.KeyType, msg.WalletID, msg.TxID, + partyVersion, ec.mpcThreshold, ec.signingResultQueue, ) @@ -247,6 +261,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { "Failed to create signing session", natMsg, ) + ec.removeSession(msg.WalletID, msg.TxID) return } @@ -260,6 +275,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { signatureData, err := signingSession.VerifySignature(msg.Tx, data) if err != nil { logger.Error("Failed to verify signature", err) + ec.removeSession(msg.WalletID, msg.TxID) return } @@ -277,6 +293,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { signingResultBytes, err := json.Marshal(signingResult) if err != nil { logger.Error("Failed to marshal signing result event", err) + ec.removeSession(msg.WalletID, msg.TxID) return } @@ -285,16 +302,15 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { }) if err != nil { logger.Error("Failed to publish signing result event", err) + ec.removeSession(msg.WalletID, msg.TxID) return } logger.Info("Signing completed", "walletID", msg.WalletID, "txID", msg.TxID, "data", len(data)) + ec.removeSession(msg.WalletID, msg.TxID) }) }() - // Mark session as already processed - ec.addSession(msg.WalletID, msg.TxID) - go func() { for err := range signingSession.ErrCh() { logger.Error("Error from session", err) @@ -307,6 +323,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { "Failed to sign tx", natMsg, ) + ec.removeSession(msg.WalletID, msg.TxID) return } } @@ -337,12 +354,18 @@ func (ec *eventConsumer) consumeResharingEvent() error { logger.Error("Failed to verify initiator message", err) return } + partyVersion, err := ec.node.GetPartyVersion(msg.KeyType, msg.WalletID) + if err != nil { + logger.Error("Failed to get party version", err) + return + } oldSession, err := ec.node.CreateResharingSession( true, msg.KeyType, msg.WalletID, ec.mpcThreshold, + partyVersion, ec.resharingResultQueue, ) if err != nil { @@ -355,6 +378,7 @@ func (ec *eventConsumer) consumeResharingEvent() error { msg.KeyType, msg.WalletID, msg.NewThreshold, + partyVersion, // Increment inside the session ec.resharingResultQueue, ) if err != nil { @@ -372,22 +396,20 @@ func (ec *eventConsumer) consumeResharingEvent() error { var wg sync.WaitGroup wg.Add(2) - // Handle errors from the session go func() { for { select { case err := <-oldSession.ErrCh(): logger.Error("Error from ECDSA session", err) - return case err := <-newSession.ErrCh(): logger.Error("Error from EDDSA session", err) - return } } }() oldCtx, oldCancel := context.WithTimeout(context.Background(), 30*time.Second) go oldSession.StartResharing(oldCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, oldSession.Send, func(data []byte) { + fmt.Printf("old session done\n") oldCancel() wg.Done() }) @@ -395,18 +417,12 @@ func (ec *eventConsumer) consumeResharingEvent() error { newCtx, newCancel := context.WithTimeout(context.Background(), 30*time.Second) go newSession.StartResharing(newCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, newSession.Send, func(data []byte) { newCancel() - // Rebuild the save data for and attach to old session - subsetData, err := newSession.BuildLocalSaveDataSubset(data, oldSession.PartyIDs()) - if err != nil { - logger.Error("Failed to build local save data subset", err) - return - } - newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, true, subsetData) - ecdsaPubKey, err := newSession.GetPublicKey(subsetData) + ecdsaPubKey, err := newSession.GetPublicKey(data) if err != nil { logger.Error("Failed to get ECDSA public key", err) return } + newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, partyVersion+1, data) successEvent.ECDSAPubKey = ecdsaPubKey wg.Done() }) diff --git a/pkg/keyinfo/keyinfo.go b/pkg/keyinfo/keyinfo.go index a10529c..b98e781 100644 --- a/pkg/keyinfo/keyinfo.go +++ b/pkg/keyinfo/keyinfo.go @@ -11,7 +11,7 @@ import ( type KeyInfo struct { ParticipantPeerIDs []string `json:"participant_peer_ids"` Threshold int `json:"threshold"` - IsReshared bool `json:"is_reshared"` + Version uint16 `json:"version"` } type store struct { diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index 6e9cb04..07e6fef 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "math/big" + "strconv" "time" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" @@ -18,6 +19,8 @@ import ( "github.com/google/uuid" ) +const DefaultVersion = 0 + type Node struct { nodeID string peerIDs []string @@ -56,7 +59,7 @@ func (n *Node) CreateKeygenSession(keyType types.KeyType, walletID string, thres } readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := n.generatePartyIDs("keygen", readyPeerIDs) + selfPartyID, allPartyIDs := n.generatePartyIDs(session.PurposeKeygen, readyPeerIDs, DefaultVersion) switch keyType { case types.KeyTypeSecp256k1: preparams, err := n.getECDSAPreParams(false) @@ -95,13 +98,13 @@ func (n *Node) CreateKeygenSession(keyType types.KeyType, walletID string, thres } } -func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { +func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID string, partyVersion int, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { return nil, fmt.Errorf("not enough peers to create gen session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) } readyPeerIDs := n.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := n.generatePartyIDs("keygen", readyPeerIDs) + selfPartyID, allPartyIDs := n.generatePartyIDs(session.PurposeSign, readyPeerIDs, partyVersion) switch keyType { case types.KeyTypeSecp256k1: ecdsaSession := session.NewECDSASession( @@ -150,7 +153,7 @@ func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID } } -func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, walletID string, threshold int, successQueue messaging.MessageQueue) (session.Session, error) { +func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, walletID string, threshold int, partyVersion int, successQueue messaging.MessageQueue) (session.Session, error) { if n.peerRegistry.GetReadyPeersCount() < int64(threshold+1) { return nil, fmt.Errorf("not enough peers to create resharing session! expected %d, got %d", threshold+1, n.peerRegistry.GetReadyPeersCount()) } @@ -158,9 +161,9 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa var selfPartyID *tss.PartyID var partyIDs []*tss.PartyID if isOldParty { - selfPartyID, partyIDs = n.generatePartyIDs("keygen", readyPeerIDs) + selfPartyID, partyIDs = n.generatePartyIDs(session.PurposeKeygen, readyPeerIDs, partyVersion) } else { - selfPartyID, partyIDs = n.generatePartyIDs("resharing", readyPeerIDs) + selfPartyID, partyIDs = n.generatePartyIDs(session.PurposeReshare, readyPeerIDs, partyVersion+1) // Increment version for new parties } switch keyType { @@ -204,6 +207,23 @@ func (n *Node) GetReadyPeersIncludeSelf() []string { return n.peerRegistry.GetReadyPeersIncludeSelf() } +func (n *Node) GetPartyVersion(keyType types.KeyType, walletID string) (int, error) { + var walletKey string + switch keyType { + case types.KeyTypeSecp256k1: + walletKey = fmt.Sprintf("ecdsa:%s", walletID) + case types.KeyTypeEd25519: + walletKey = fmt.Sprintf("eddsa:%s", walletID) + default: + return 0, fmt.Errorf("invalid key type: %s", keyType) + } + keyInfo, err := n.keyinfoStore.Get(walletKey) + if err != nil { + return 0, err + } + return int(keyInfo.Version), nil +} + func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error) { var path string if isOldParty { @@ -235,13 +255,13 @@ func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error return &preparams, nil } -func (n *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { +func (n *Node) generatePartyIDs(purpose session.Purpose, readyPeerIDs []string, version int) (self *tss.PartyID, all []*tss.PartyID) { // Pre-allocate slice with exact size needed partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) // Create all party IDs in one pass for _, peerID := range readyPeerIDs { - partyID := createPartyID(peerID, purpose) + partyID := createPartyID(peerID, string(purpose), version) if peerID == n.nodeID { self = partyID } @@ -253,9 +273,9 @@ func (n *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *ts return } -func createPartyID(nodeID string, label string) *tss.PartyID { +func createPartyID(nodeID string, label string, version int) *tss.PartyID { partyID := uuid.NewString() - key := big.NewInt(0).SetBytes([]byte(partyID)) moniker := nodeID + ":" + label + key := big.NewInt(0).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version))) return tss.NewPartyID(partyID, moniker, key) } diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 20026b3..ea74980 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -10,16 +10,6 @@ import ( "github.com/fystack/mpcium/pkg/types" ) -type party struct { - walletID string - threshold int - partyID *tss.PartyID - partyIDs []*tss.PartyID - inCh chan types.TssMessage - outCh chan tss.Message - errCh chan error -} - type PartyInterface interface { StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) @@ -34,6 +24,16 @@ type PartyInterface interface { ErrCh() chan error } +type party struct { + walletID string + threshold int + partyID *tss.PartyID + partyIDs []*tss.PartyID + inCh chan types.TssMessage + outCh chan tss.Message + errCh chan error +} + func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, errCh chan error) *party { inCh := make(chan types.TssMessage, 1000) outCh := make(chan tss.Message, 1000) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index 32a66e3..c728d74 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -79,8 +79,8 @@ func (s *EDDSAParty) StartResharing(ctx context.Context, oldPartyIDs, newPartyID tss.NewPeerContext(newPartyIDs), s.partyID, len(oldPartyIDs), - len(newPartyIDs), oldThreshold, + len(newPartyIDs), newThreshold, ) party := resharing.NewLocalParty(params, *s.saveData, s.outCh, end) diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 80a62f5..93583bc 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -48,11 +48,10 @@ type Session interface { GetPublicKey(data []byte) ([]byte, error) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) - BuildLocalSaveDataSubset(sourceData []byte, sortedIDs tss.SortedPartyIDs) ([]byte, error) PartyIDs() []*tss.PartyID Send(msg tss.Message) Listen(nodeID string, isResharing bool) - SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) + SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) ErrCh() chan error } @@ -125,10 +124,9 @@ func (s *session) Send(msg tss.Message) { s.errCh <- fmt.Errorf("failed to marshal message: %w", err) return } - // fmt.Printf("Sending message from %s to %s isBroadcast %v\n", routing.From, routing.To, routing.IsBroadcast) + logger.Debug("Sending message", "from", routing.From, "to", routing.To, "isBroadcast", routing.IsBroadcast) if routing.IsBroadcast && len(routing.To) == 0 { - fmt.Printf("sending broadcast message to %s\n", s.topicComposer.ComposeBroadcastTopic()) err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msgBytes) if err != nil { s.errCh <- fmt.Errorf("failed to publish message: %w", err) @@ -138,7 +136,6 @@ func (s *session) Send(msg tss.Message) { for _, to := range routing.To { nodeID := partyIDToNodeID(to) topic := s.topicComposer.ComposeDirectTopic(nodeID) - fmt.Printf("sending direct message to %s\n", topic) err := s.direct.Send(topic, msgBytes) if err != nil { s.errCh <- fmt.Errorf("failed to send message: %w", err) @@ -149,14 +146,13 @@ func (s *session) Send(msg tss.Message) { } func (s *session) Listen(nodeID string, isResharingParty bool) { - var directTopic string + var selfDirectTopic string if isResharingParty { - directTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, "resharing")) + selfDirectTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, PurposeReshare)) } else { - directTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, "keygen")) + selfDirectTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, PurposeKeygen)) } broadcast := func() { - fmt.Printf("subscribing to broadcast topic %s\n", s.topicComposer.ComposeBroadcastTopic()) sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { msg := natMsg.Data s.receive(msg) @@ -171,8 +167,7 @@ func (s *session) Listen(nodeID string, isResharingParty bool) { } direct := func() { - fmt.Printf("subscribing to direct topic %s\n", directTopic) - sub, err := s.direct.Listen(directTopic, func(msg []byte) { + sub, err := s.direct.Listen(selfDirectTopic, func(msg []byte) { s.receive(msg) }) @@ -188,11 +183,11 @@ func (s *session) Listen(nodeID string, isResharingParty bool) { go direct() } -func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared bool, data []byte) (err error) { +func (s *session) SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) { keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: participantPeerIDs, Threshold: threshold, - IsReshared: isReshared, + Version: uint16(version), } composeKey := s.composeKey(s.walletID) err = s.keyinfoStore.Save(composeKey, &keyInfo) @@ -206,7 +201,7 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, isReshared s.errCh <- fmt.Errorf("failed to save key: %w", err) return } - logger.Info("Saved key", "walletID", s.walletID, "threshold", threshold, "isReshared", isReshared, "data", len(data)) + logger.Info("Saved key", "walletID", s.walletID, "threshold", threshold, "version", version, "data", len(data)) return nil } diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index ed1a882..0cf71f7 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -28,10 +28,10 @@ func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.Part s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, preParams, s.errCh) s.topicComposer = &TopicComposer{ ComposeBroadcastTopic: func() string { - return fmt.Sprintf("keygen:broadcast:ecdsa:%s", walletID) + return fmt.Sprintf("broadcast:ecdsa:%s", walletID) }, ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:ecdsa:%s:%s", nodeID, walletID) + return fmt.Sprintf("direct:ecdsa:%s:%s", nodeID, walletID) }, } s.composeKey = func(walletID string) string { @@ -114,33 +114,3 @@ func (s *ECDSASession) VerifySignature(msg []byte, signature []byte) (*common.Si return signatureData, nil } - -func (s *ECDSASession) BuildLocalSaveDataSubset(sourceData []byte, sortedIDs tss.SortedPartyIDs) ([]byte, error) { - saveData := &keygen.LocalPartySaveData{} - err := json.Unmarshal(sourceData, saveData) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal save data: %w", err) - } - return json.Marshal(buildLocalSaveDataSubset(*saveData, sortedIDs)) -} - -func buildLocalSaveDataSubset(sourceData keygen.LocalPartySaveData, sortedIDs tss.SortedPartyIDs) keygen.LocalPartySaveData { - if len(sortedIDs) != len(sourceData.Ks) { - return keygen.LocalPartySaveData{} - } - newData := keygen.NewLocalPartySaveData(sortedIDs.Len()) - newData.LocalPreParams = sourceData.LocalPreParams - newData.LocalSecrets = sourceData.LocalSecrets - newData.ECDSAPub = sourceData.ECDSAPub - - // Map directly based on sorted ID order - for j := range sortedIDs { - newData.Ks[j] = sourceData.Ks[j] - newData.NTildej[j] = sourceData.NTildej[j] - newData.H1j[j] = sourceData.H1j[j] - newData.H2j[j] = sourceData.H2j[j] - newData.BigXj[j] = sourceData.BigXj[j] - newData.PaillierPKs[j] = sourceData.PaillierPKs[j] - } - return newData -} diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go index c0a57e1..a77b8ce 100644 --- a/pkg/mpc/session/eddsa.go +++ b/pkg/mpc/session/eddsa.go @@ -27,10 +27,10 @@ func NewEDDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.Part s.party = party.NewEDDAParty(walletID, partyID, partyIDs, threshold, nil, nil, s.errCh) s.topicComposer = &TopicComposer{ ComposeBroadcastTopic: func() string { - return fmt.Sprintf("keygen:broadcast:eddsa:%s", walletID) + return fmt.Sprintf("broadcast:eddsa:%s", walletID) }, ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:eddsa:%s:%s", nodeID, walletID) + return fmt.Sprintf("direct:eddsa:%s:%s", nodeID, walletID) }, } s.composeKey = func(walletID string) string { @@ -116,12 +116,3 @@ func (s *EDDSASession) VerifySignature(msg []byte, signature []byte) (*common.Si return signatureData, nil } - -func (s *EDDSASession) BuildLocalSaveDataSubset(sourceData []byte, sortedIDs tss.SortedPartyIDs) ([]byte, error) { - saveData := &keygen.LocalPartySaveData{} - err := json.Unmarshal(sourceData, saveData) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal save data: %w", err) - } - return json.Marshal(keygen.BuildLocalSaveDataSubset(*saveData, sortedIDs)) -} From 3eb72890e34d3e0c7ba04d6a77b76a4756585ee1 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 13 Jun 2025 15:25:19 +0700 Subject: [PATCH 16/34] Refactor container names and improve event consumer error handling This commit includes the following changes: - Updated the container names in `docker-compose.yaml` for clarity and consistency. - Refactored the `eventConsumer` methods to replace `GetPartyVersion` with `GetKeyInfoVersion`, enhancing the accuracy of version management in transaction signing and resharing events. - Improved error handling and logging in the `consumeTxSigningEvent` and `consumeResharingEvent` methods, ensuring better traceability of issues. These changes aim to enhance the maintainability and clarity of the codebase, particularly in container management and event processing. --- docker-compose.yaml | 4 ++-- pkg/eventconsumer/event_consumer.go | 15 ++++++++------- pkg/mpc/node/node.go | 13 ++++++++++++- pkg/mpc/session/base.go | 7 +++++++ 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 7a80bf0..88b818d 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,7 +1,7 @@ services: nats-server: image: nats:latest - container_name: nats-server-mpcium + container_name: nats-server command: -js --http_port 8222 ports: - "4222:4222" @@ -12,7 +12,7 @@ services: consul: image: consul:1.15.4 - container_name: consul-mpcium + container_name: consul ports: - "8500:8500" - "8601:8600/udp" diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 8778b83..1513a6f 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -236,7 +236,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { // Add session to tracking before starting ec.addSession(msg.WalletID, msg.TxID) - partyVersion, err := ec.node.GetPartyVersion(msg.KeyType, msg.WalletID) + keyInfoVersion, err := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) if err != nil { logger.Error("Failed to get party version", err) ec.removeSession(msg.WalletID, msg.TxID) @@ -247,7 +247,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { msg.KeyType, msg.WalletID, msg.TxID, - partyVersion, + keyInfoVersion, ec.mpcThreshold, ec.signingResultQueue, ) @@ -354,7 +354,7 @@ func (ec *eventConsumer) consumeResharingEvent() error { logger.Error("Failed to verify initiator message", err) return } - partyVersion, err := ec.node.GetPartyVersion(msg.KeyType, msg.WalletID) + keyInfoVersion, err := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) if err != nil { logger.Error("Failed to get party version", err) return @@ -365,7 +365,7 @@ func (ec *eventConsumer) consumeResharingEvent() error { msg.KeyType, msg.WalletID, ec.mpcThreshold, - partyVersion, + keyInfoVersion, ec.resharingResultQueue, ) if err != nil { @@ -378,7 +378,7 @@ func (ec *eventConsumer) consumeResharingEvent() error { msg.KeyType, msg.WalletID, msg.NewThreshold, - partyVersion, // Increment inside the session + keyInfoVersion, // Increment inside the session ec.resharingResultQueue, ) if err != nil { @@ -409,7 +409,7 @@ func (ec *eventConsumer) consumeResharingEvent() error { oldCtx, oldCancel := context.WithTimeout(context.Background(), 30*time.Second) go oldSession.StartResharing(oldCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, oldSession.Send, func(data []byte) { - fmt.Printf("old session done\n") + // Old session is done, no need to save oldCancel() wg.Done() }) @@ -417,12 +417,13 @@ func (ec *eventConsumer) consumeResharingEvent() error { newCtx, newCancel := context.WithTimeout(context.Background(), 30*time.Second) go newSession.StartResharing(newCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, newSession.Send, func(data []byte) { newCancel() + // Only save for new parties ecdsaPubKey, err := newSession.GetPublicKey(data) if err != nil { logger.Error("Failed to get ECDSA public key", err) return } - newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, partyVersion+1, data) + newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, keyInfoVersion+1, data) successEvent.ECDSAPubKey = ecdsaPubKey wg.Done() }) diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index 07e6fef..bd0252c 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -19,6 +19,7 @@ import ( "github.com/google/uuid" ) +// DefaultVersion is the default version for keygen and resharing const DefaultVersion = 0 type Node struct { @@ -181,6 +182,7 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa ecdsaSession.SetSaveData(saveData) } else { // Initialize new save data for new parties + // Reduce the loading time by pre-allocating the save data saveData := keygen.NewLocalPartySaveData(len(partyIDs)) saveData.LocalPreParams = *preparams saveDataBytes, err := json.Marshal(saveData) @@ -207,7 +209,7 @@ func (n *Node) GetReadyPeersIncludeSelf() []string { return n.peerRegistry.GetReadyPeersIncludeSelf() } -func (n *Node) GetPartyVersion(keyType types.KeyType, walletID string) (int, error) { +func (n *Node) GetKeyInfoVersion(keyType types.KeyType, walletID string) (int, error) { var walletKey string switch keyType { case types.KeyTypeSecp256k1: @@ -224,6 +226,8 @@ func (n *Node) GetPartyVersion(keyType types.KeyType, walletID string) (int, err return int(keyInfo.Version), nil } +// For ecdsa, we need to generate preparams for each party +// Load preparams from kvstore if exists, otherwise generate and save to kvstore func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error) { var path string if isOldParty { @@ -255,6 +259,9 @@ func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error return &preparams, nil } +// generatePartyIDs generates the party IDs for the given purpose and version +// It returns the self party ID and all party IDs +// It also sorts the party IDs in place func (n *Node) generatePartyIDs(purpose session.Purpose, readyPeerIDs []string, version int) (self *tss.PartyID, all []*tss.PartyID) { // Pre-allocate slice with exact size needed partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) @@ -273,6 +280,10 @@ func (n *Node) generatePartyIDs(purpose session.Purpose, readyPeerIDs []string, return } +// createPartyID creates a new party ID for the given node ID, label and version +// It returns the party ID: random string +// Moniker: for routing messages +// Key: for mpc internal use (need persistent storage) func createPartyID(nodeID string, label string, version int) *tss.PartyID { partyID := uuid.NewString() moniker := nodeID + ":" + label diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 93583bc..1f4e1ec 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -105,6 +105,8 @@ func (s *session) ErrCh() chan error { return s.errCh } +// Send is a wrapper around the party's Send method +// It signs the message and sends it to the remote party func (s *session) Send(msg tss.Message) { data, routing, err := msg.WireBytes() if err != nil { @@ -145,6 +147,8 @@ func (s *session) Send(msg tss.Message) { } } +// Listen is a wrapper around the party's Listen method +// It subscribes to the broadcast and self direct topics func (s *session) Listen(nodeID string, isResharingParty bool) { var selfDirectTopic string if isResharingParty { @@ -183,6 +187,7 @@ func (s *session) Listen(nodeID string, isResharingParty bool) { go direct() } +// SaveKey saves the key to the keyinfo store and the kvstore func (s *session) SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) { keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: participantPeerIDs, @@ -209,6 +214,7 @@ func (s *session) SetSaveData(saveBytes []byte) { s.party.SetSaveData(saveBytes) } +// GetSaveData gets the key from the kvstore func (s *session) GetSaveData() ([]byte, error) { composeKey := s.composeKey(s.walletID) data, err := s.kvstore.Get(composeKey) @@ -218,6 +224,7 @@ func (s *session) GetSaveData() ([]byte, error) { return data, nil } +// receive is a helper function that receives a message from the party func (s *session) receive(rawMsg []byte) { msg, err := types.UnmarshalTssMessage(rawMsg) if err != nil { From 0627e59d53e74db155d7b83d0297623fbf041ce9 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 13 Jun 2025 15:28:08 +0700 Subject: [PATCH 17/34] Integrate Viper configuration and update NATS URL handling This commit introduces the following changes across multiple example files: - Enabled Viper configuration initialization by uncommenting `config.InitViperConfig()`, allowing for dynamic configuration management. - Updated the NATS URL retrieval to use `viper.GetString("nats.url")` instead of a hardcoded value, enhancing flexibility in environment configuration. - Adjusted the key path in the `reshare` example to a relative path for improved portability. These changes aim to enhance the configurability and maintainability of the example applications. --- examples/generate/main.go | 6 ++++-- examples/reshare/main.go | 8 +++++--- examples/sign/main.go | 7 ++++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/generate/main.go b/examples/generate/main.go index 46f47e3..69cd9bd 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -7,18 +7,20 @@ import ( "syscall" "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/google/uuid" "github.com/nats-io/nats.go" + "github.com/spf13/viper" ) func main() { const environment = "development" - // config.InitViperConfig() + config.InitViperConfig() logger.Init(environment, false) - natsURL := "nats://localhost:4222" + natsURL := viper.GetString("nats.url") natsConn, err := nats.Connect(natsURL) if err != nil { logger.Fatal("Failed to connect to NATS", err) diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 9a51baf..44ce7bb 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -7,18 +7,20 @@ import ( "syscall" "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" + "github.com/spf13/viper" ) func main() { const environment = "development" - // config.InitViperConfig() + config.InitViperConfig() logger.Init(environment, false) - natsURL := "nats://localhost:4222" + natsURL := viper.GetString("nats.url") natsConn, err := nats.Connect(natsURL) if err != nil { logger.Fatal("Failed to connect to NATS", err) @@ -28,7 +30,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "/home/viet/Documents/other/mpcium/event_initiator.key", + KeyPath: "./../../event_initiator.key", }) err = mpcClient.OnResharingResult(func(event event.ResharingSuccessEvent) { logger.Info("Received resharing result", "event", event) diff --git a/examples/sign/main.go b/examples/sign/main.go index b69337d..dd322b7 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -7,20 +7,21 @@ import ( "syscall" "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" "github.com/google/uuid" "github.com/nats-io/nats.go" + "github.com/spf13/viper" ) func main() { const environment = "dev" - // config.InitViperConfig() + config.InitViperConfig() logger.Init(environment, true) - // natsURL := viper.GetString("nats.url") - natsURL := "nats://localhost:4222" + natsURL := viper.GetString("nats.url") natsConn, err := nats.Connect(natsURL) if err != nil { logger.Fatal("Failed to connect to NATS", err) From e1666da34f8885d8aa81b13662cd20d3ff075e53 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 13 Jun 2025 17:08:13 +0700 Subject: [PATCH 18/34] Refactor session initialization and error handling in MPC package This commit includes the following changes: - Updated the `NewEDDAParty` and `NewECDSASession` functions to remove the redundant `curve` parameter, simplifying session initialization. - Refactored error handling in the `consumeTxSigningEvent` method to streamline the error processing flow, enhancing clarity and maintainability. These changes aim to improve the readability and efficiency of the MPC package, particularly in session management and error handling. --- examples/generate/main.go | 2 +- examples/reshare/main.go | 2 +- examples/sign/main.go | 2 +- pkg/eventconsumer/event_consumer.go | 21 +++++++++------------ pkg/mpc/party/eddsa.go | 2 +- pkg/mpc/session/base.go | 1 - pkg/mpc/session/ecdsa.go | 2 +- pkg/mpc/session/eddsa.go | 4 ++-- 8 files changed, 16 insertions(+), 20 deletions(-) diff --git a/examples/generate/main.go b/examples/generate/main.go index 69cd9bd..c8e64bf 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -30,7 +30,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./../../event_initiator.key", + KeyPath: "./event_initiator.key", }) err = mpcClient.OnWalletCreationResult(func(event event.KeygenSuccessEvent) { logger.Info("Received wallet creation result", "event", event) diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 44ce7bb..b05b8c6 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -30,7 +30,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./../../event_initiator.key", + KeyPath: "./event_initiator.key", }) err = mpcClient.OnResharingResult(func(event event.ResharingSuccessEvent) { logger.Info("Received resharing result", "event", event) diff --git a/examples/sign/main.go b/examples/sign/main.go index dd322b7..e0c0d8f 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -31,7 +31,7 @@ func main() { mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./../../event_initiator.key", + KeyPath: "./event_initiator.key", }) // 2) Once wallet exists, immediately fire a SignTransaction diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 1513a6f..a4927b8 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -314,18 +314,15 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { go func() { for err := range signingSession.ErrCh() { logger.Error("Error from session", err) - if err != nil { - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to sign tx", - natMsg, - ) - ec.removeSession(msg.WalletID, msg.TxID) - return - } + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to sign tx", + natMsg, + ) + ec.removeSession(msg.WalletID, msg.TxID) } }() }) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index c728d74..cd6c981 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -20,7 +20,7 @@ type EDDSAParty struct { saveData *keygen.LocalPartySaveData } -func NewEDDAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, +func NewEDDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, reshareParams *tss.ReSharingParameters, saveData *keygen.LocalPartySaveData, errCh chan error) *EDDSAParty { return &EDDSAParty{ party: *NewParty(walletID, partyID, partyIDs, threshold, errCh), diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 1f4e1ec..4772185 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -76,7 +76,6 @@ type session struct { } func NewSession( - curve Curve, purpose Purpose, walletID string, pubSub messaging.PubSub, diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index 0cf71f7..f883bf7 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -24,7 +24,7 @@ type ECDSASession struct { } func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, preParams keygen.LocalPreParams, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store) *ECDSASession { - s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) + s := NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, preParams, s.errCh) s.topicComposer = &TopicComposer{ ComposeBroadcastTopic: func() string { diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go index a77b8ce..4cbf65c 100644 --- a/pkg/mpc/session/eddsa.go +++ b/pkg/mpc/session/eddsa.go @@ -23,8 +23,8 @@ type EDDSASession struct { } func NewEDDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store) *EDDSASession { - s := NewSession(CurveSecp256k1, PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) - s.party = party.NewEDDAParty(walletID, partyID, partyIDs, threshold, nil, nil, s.errCh) + s := NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) + s.party = party.NewEDDSAParty(walletID, partyID, partyIDs, threshold, nil, nil, s.errCh) s.topicComposer = &TopicComposer{ ComposeBroadcastTopic: func() string { return fmt.Sprintf("broadcast:eddsa:%s", walletID) From 6c3de006cd9257988c46608b68fafdea4de7fd8d Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 13 Jun 2025 17:11:30 +0700 Subject: [PATCH 19/34] Add Close method to Node and defer its call in runNode This commit introduces a new `Close` method in the `Node` struct to handle peer registry resignation, enhancing resource management. Additionally, the `runNode` function is updated to defer the `Close` method call, ensuring proper cleanup of resources when the node operation completes. These changes aim to improve the robustness and maintainability of the MPC package by ensuring resources are released appropriately. --- cmd/mpcium/main.go | 1 + pkg/mpc/node/node.go | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 563fa12..ee02d2c 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -157,6 +157,7 @@ func runNode(ctx context.Context, c *cli.Command) error { identityStore, peerRegistry, ) + defer mpcNode.Close() eventConsumer := eventconsumer.NewEventConsumer( mpcNode, diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index bd0252c..91342e9 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -205,6 +205,13 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa } } +func (p *Node) Close() { + err := p.peerRegistry.Resign() + if err != nil { + logger.Error("Resign failed", err) + } +} + func (n *Node) GetReadyPeersIncludeSelf() []string { return n.peerRegistry.GetReadyPeersIncludeSelf() } From 057d99077a7304985f2a17ac9854ae01d64771ed Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 13 Jun 2025 17:32:43 +0700 Subject: [PATCH 20/34] Implement PreloadPreParams method and refactor party ID handling This commit introduces a new `PreloadPreParams` method in the `Node` struct to preload ECDSA preparameters, enhancing the initialization process. Additionally, the `partyIDToNodeID` function is renamed to `getRoutingFromPartyID` for improved clarity in routing party ID to node ID mapping. The `Send` and `Listen` methods in the session have been updated to utilize the new function, ensuring consistency in message routing. These changes aim to improve the maintainability and clarity of the MPC package, particularly in session initialization and party ID management. --- cmd/mpcium/main.go | 2 ++ pkg/mpc/node/node.go | 12 ++++++++++++ pkg/mpc/session/base.go | 6 +++--- pkg/mpc/session/utils.go | 3 ++- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index ee02d2c..ca533c4 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -157,6 +157,8 @@ func runNode(ctx context.Context, c *cli.Command) error { identityStore, peerRegistry, ) + // Preload preparams for the first time + mpcNode.PreloadPreParams() defer mpcNode.Close() eventConsumer := eventconsumer.NewEventConsumer( diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index 91342e9..a554054 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -233,6 +233,18 @@ func (n *Node) GetKeyInfoVersion(keyType types.KeyType, walletID string) (int, e return int(keyInfo.Version), nil } +// PreloadPreParams preloads the preparams for the first time +func (n *Node) PreloadPreParams() { + _, err := n.getECDSAPreParams(false) + if err != nil { + logger.Error("Failed to get preparams", err) + } + _, err = n.getECDSAPreParams(true) + if err != nil { + logger.Error("Failed to get preparams", err) + } +} + // For ecdsa, we need to generate preparams for each party // Load preparams from kvstore if exists, otherwise generate and save to kvstore func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error) { diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 4772185..675122b 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -135,7 +135,7 @@ func (s *session) Send(msg tss.Message) { } } else { for _, to := range routing.To { - nodeID := partyIDToNodeID(to) + nodeID := getRoutingFromPartyID(to) topic := s.topicComposer.ComposeDirectTopic(nodeID) err := s.direct.Send(topic, msgBytes) if err != nil { @@ -151,9 +151,9 @@ func (s *session) Send(msg tss.Message) { func (s *session) Listen(nodeID string, isResharingParty bool) { var selfDirectTopic string if isResharingParty { - selfDirectTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, PurposeReshare)) + selfDirectTopic = s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) } else { - selfDirectTopic = s.topicComposer.ComposeDirectTopic(fmt.Sprintf("%s:%s", nodeID, PurposeKeygen)) + selfDirectTopic = s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) } broadcast := func() { sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { diff --git a/pkg/mpc/session/utils.go b/pkg/mpc/session/utils.go index 33740a5..581abc8 100644 --- a/pkg/mpc/session/utils.go +++ b/pkg/mpc/session/utils.go @@ -2,6 +2,7 @@ package session import "github.com/bnb-chain/tss-lib/v2/tss" -func partyIDToNodeID(partyID *tss.PartyID) string { +// Moniker saves the routing partyID to nodeID mapping +func getRoutingFromPartyID(partyID *tss.PartyID) string { return partyID.Moniker } From 81840aa39b9b5c2e1a323e686d245c8a7d76dc66 Mon Sep 17 00:00:00 2001 From: vietddude Date: Sat, 14 Jun 2025 21:07:27 +0700 Subject: [PATCH 21/34] Refactor key info version handling in event consumer and node This commit updates the `GetKeyInfoVersion` method in the `Node` struct to return -1 for invalid key types and errors, enhancing error handling. In the `consumeResharingEvent` method of the `eventConsumer`, the retrieval of the key info version is simplified by ignoring errors, defaulting to -1 if no key is found. These changes aim to improve the robustness and clarity of version management in the MPC package. --- pkg/eventconsumer/event_consumer.go | 8 +++----- pkg/mpc/node/node.go | 11 ++++++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index a4927b8..d1ca4aa 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -351,11 +351,9 @@ func (ec *eventConsumer) consumeResharingEvent() error { logger.Error("Failed to verify initiator message", err) return } - keyInfoVersion, err := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) - if err != nil { - logger.Error("Failed to get party version", err) - return - } + + // Default is -1 if no key found + keyInfoVersion, _ := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) oldSession, err := ec.node.CreateResharingSession( true, diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index a554054..9feaba9 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -224,11 +224,11 @@ func (n *Node) GetKeyInfoVersion(keyType types.KeyType, walletID string) (int, e case types.KeyTypeEd25519: walletKey = fmt.Sprintf("eddsa:%s", walletID) default: - return 0, fmt.Errorf("invalid key type: %s", keyType) + return -1, fmt.Errorf("invalid key type: %s", keyType) } keyInfo, err := n.keyinfoStore.Get(walletKey) if err != nil { - return 0, err + return -1, err } return int(keyInfo.Version), nil } @@ -306,6 +306,11 @@ func (n *Node) generatePartyIDs(purpose session.Purpose, readyPeerIDs []string, func createPartyID(nodeID string, label string, version int) *tss.PartyID { partyID := uuid.NewString() moniker := nodeID + ":" + label - key := big.NewInt(0).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version))) + var key *big.Int + if version == -1 { + key = big.NewInt(0).SetBytes([]byte(nodeID)) + } else { + key = big.NewInt(0).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version))) + } return tss.NewPartyID(partyID, moniker, key) } From 2916a71aec99ea5c6726335d2916e3e3c03f3e1f Mon Sep 17 00:00:00 2001 From: vietddude Date: Sun, 15 Jun 2025 13:46:05 +0700 Subject: [PATCH 22/34] Update key version for backward compatibility This commit modifies the handling of key versions across multiple files in the MPC package. The `KeyInfo` struct now uses an `int` for the version field instead of `uint16`, and the `DefaultVersion` constant is updated from 0 to 1. Additionally, the `GetKeyInfoVersion` method in the `Node` struct now returns 0 for invalid key types and errors, improving error handling consistency. The `consumeResharingEvent` method in the `eventConsumer` is also updated to reflect these changes, ensuring better clarity and robustness in version management. --- pkg/eventconsumer/event_consumer.go | 4 ++-- pkg/keyinfo/keyinfo.go | 2 +- pkg/mpc/node/node.go | 8 ++++---- pkg/mpc/session/base.go | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index d1ca4aa..c8b168f 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -352,9 +352,9 @@ func (ec *eventConsumer) consumeResharingEvent() error { return } - // Default is -1 if no key found + // Default is 0 if no keyVersion found keyInfoVersion, _ := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) - + fmt.Println("keyInfoVersion", keyInfoVersion) oldSession, err := ec.node.CreateResharingSession( true, msg.KeyType, diff --git a/pkg/keyinfo/keyinfo.go b/pkg/keyinfo/keyinfo.go index b98e781..49d4c7f 100644 --- a/pkg/keyinfo/keyinfo.go +++ b/pkg/keyinfo/keyinfo.go @@ -11,7 +11,7 @@ import ( type KeyInfo struct { ParticipantPeerIDs []string `json:"participant_peer_ids"` Threshold int `json:"threshold"` - Version uint16 `json:"version"` + Version int `json:"version"` } type store struct { diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index 9feaba9..c359b73 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -20,7 +20,7 @@ import ( ) // DefaultVersion is the default version for keygen and resharing -const DefaultVersion = 0 +const DefaultVersion = 1 type Node struct { nodeID string @@ -224,11 +224,11 @@ func (n *Node) GetKeyInfoVersion(keyType types.KeyType, walletID string) (int, e case types.KeyTypeEd25519: walletKey = fmt.Sprintf("eddsa:%s", walletID) default: - return -1, fmt.Errorf("invalid key type: %s", keyType) + return 0, fmt.Errorf("invalid key type: %s", keyType) } keyInfo, err := n.keyinfoStore.Get(walletKey) if err != nil { - return -1, err + return 0, err } return int(keyInfo.Version), nil } @@ -307,7 +307,7 @@ func createPartyID(nodeID string, label string, version int) *tss.PartyID { partyID := uuid.NewString() moniker := nodeID + ":" + label var key *big.Int - if version == -1 { + if version == 0 { key = big.NewInt(0).SetBytes([]byte(nodeID)) } else { key = big.NewInt(0).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version))) diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 675122b..57e33dd 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -191,7 +191,7 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, version in keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: participantPeerIDs, Threshold: threshold, - Version: uint16(version), + Version: version, } composeKey := s.composeKey(s.walletID) err = s.keyinfoStore.Save(composeKey, &keyInfo) From 9c3161ff73050489176f6afe50a289e9f358f2e1 Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 16 Jun 2025 16:11:50 +0700 Subject: [PATCH 23/34] Refactor event handling and message classification in MPC package This commit introduces several key updates across the MPC package, including: - Updated the `TypeResharingSuccess` constant to include a version number for better version tracking. - Modified the `Listen` method in the session interface to remove the `isResharing` parameter, simplifying the method signature. - Implemented `ClassifyMsg` methods in both `ECDSAParty` and `EDDSAParty` to classify messages and determine their round and broadcast status. - Enhanced logging in the `Send` and `receive` methods to include message round information, improving traceability. These changes aim to improve the clarity and maintainability of the MPC package, particularly in event processing and message handling. --- pkg/event/event.go | 2 +- pkg/eventconsumer/event_consumer.go | 2 +- pkg/mpc/party/base.go | 3 +- pkg/mpc/party/ecdsa.go | 17 +++++++++ pkg/mpc/party/ecdsa_round.go | 56 +++++++++++++++++++++++++++++ pkg/mpc/party/eddsa.go | 17 +++++++++ pkg/mpc/party/eddsa_round.go | 38 ++++++++++++++++++++ pkg/mpc/session/base.go | 31 ++++++++++------ 8 files changed, 152 insertions(+), 14 deletions(-) create mode 100644 pkg/mpc/party/ecdsa_round.go create mode 100644 pkg/mpc/party/eddsa_round.go diff --git a/pkg/event/event.go b/pkg/event/event.go index 0d333dd..1def269 100644 --- a/pkg/event/event.go +++ b/pkg/event/event.go @@ -6,7 +6,7 @@ const ( TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" TypeSigningResultComplete = "mpc.mpc_signing_result_complete.%s.%s" - TypeResharingSuccess = "mpc.mpc_resharing_success.%s" + TypeResharingSuccess = "mpc.mpc_resharing_success.%s.%d" ) type KeygenSuccessEvent struct { diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index c8b168f..07d6cae 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -25,7 +25,7 @@ const ( MPCResharingEvent = "mpc:reshare" // Default version for keygen - DefaultVersion int = 0 + DefaultVersion int = 1 ) type EventConsumer interface { diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index ea74980..78d1283 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -19,6 +19,7 @@ type PartyInterface interface { PartyIDs() []*tss.PartyID GetSaveData() []byte SetSaveData(saveData []byte) + ClassifyMsg(msgBytes []byte) (uint8, bool, error) InCh() chan types.TssMessage OutCh() chan tss.Message ErrCh() chan error @@ -62,7 +63,7 @@ func (p *party) ErrCh() chan error { // runParty handles the common party execution loop func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh chan T, finish func([]byte)) { - // Start the party in a goroutine + // Start the party in a goroutine to handle errors go func() { logger.Info("Starting party", "partyID", s.PartyID().String()) if err := party.Start(); err != nil { diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index f520b01..b79f42f 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -12,6 +12,8 @@ import ( "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/golang/protobuf/ptypes/any" + "google.golang.org/protobuf/proto" ) type ECDSAParty struct { @@ -47,6 +49,21 @@ func (s *ECDSAParty) SetSaveData(saveData []byte) { s.saveData = &localSaveData } +func (s *ECDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { + msg := &any.Any{} + if err := proto.Unmarshal(msgBytes, msg); err != nil { + return 0, false, err + } + + _, isBroadcast := ecdsaBroadcastMessages[msg.TypeUrl] + + round := ecdsaMsgURL2Round[msg.TypeUrl] + if round > 4 { + round = round - 4 + } + return round, isBroadcast, nil +} + func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData, 1) params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) diff --git a/pkg/mpc/party/ecdsa_round.go b/pkg/mpc/party/ecdsa_round.go new file mode 100644 index 0000000..0f38c05 --- /dev/null +++ b/pkg/mpc/party/ecdsa_round.go @@ -0,0 +1,56 @@ +package party + +var ( + ecdsaMsgURL2Round = map[string]uint8{ + // DKG + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound1Message": 1, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound2Message1": 2, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound2Message2": 3, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound3Message": 4, + + // Signing + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound1Message1": 5, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound1Message2": 6, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound2Message": 7, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound3Message": 8, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound4Message": 9, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound5Message": 10, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound6Message": 11, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound7Message": 12, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound8Message": 13, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound9Message": 14, + + // Resharing + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound1Message": 15, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message1": 16, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message2": 17, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound3Message1": 18, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound3Message2": 19, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound4Message1": 20, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound4Message2": 21, + } + + ecdsaBroadcastMessages = map[string]struct{}{ + // DKG + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound1Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound2Message2": {}, + "type.googleapis.com/binance.tsslib.ecdsa.keygen.KGRound3Message": {}, + + // Signing + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound1Message2": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound3Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound4Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound5Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound6Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound7Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound8Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.signing.SignRound9Message": {}, + + // Resharing + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound1Message": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message1": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound2Message2": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound3Message1": {}, + "type.googleapis.com/binance.tsslib.ecdsa.resharing.DGRound4Message2": {}, + } +) diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index cd6c981..f90741c 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -12,6 +12,8 @@ import ( "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" "github.com/bnb-chain/tss-lib/v2/eddsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/golang/protobuf/ptypes/any" + "google.golang.org/protobuf/proto" ) type EDDSAParty struct { @@ -48,6 +50,21 @@ func (s *EDDSAParty) SetSaveData(shareData []byte) { s.saveData = &localSaveData } +func (s *EDDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { + msg := &any.Any{} + if err := proto.Unmarshal(msgBytes, msg); err != nil { + return 0, false, err + } + + _, isBroadcast := eddsaBroadcastMessages[msg.TypeUrl] + + round := eddsaMsgURL2Round[msg.TypeUrl] + if round > 4 { + round = round - 4 + } + return round, isBroadcast, nil +} + func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData, 1) params := tss.NewParameters(tss.Edwards(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) diff --git a/pkg/mpc/party/eddsa_round.go b/pkg/mpc/party/eddsa_round.go new file mode 100644 index 0000000..7f0c874 --- /dev/null +++ b/pkg/mpc/party/eddsa_round.go @@ -0,0 +1,38 @@ +package party + +var ( + eddsaMsgURL2Round = map[string]uint8{ + // DKG + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound1Message": 1, + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound2Message1": 2, + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound2Message2": 3, + + // Signing + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound1Message": 4, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound2Message": 5, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound3Message": 6, + + // Resharing + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound1Message": 7, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound2Message": 8, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound3Message1": 9, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound3Message2": 10, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound4Message": 11, + } + + eddsaBroadcastMessages = map[string]struct{}{ + // DKG + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound1Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.keygen.KGRound2Message2": {}, + + // Signing + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound1Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound2Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.signing.SignRound3Message": {}, + + // Resharing + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound1Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound2Message": {}, + "type.googleapis.com/binance.tsslib.eddsa.resharing.DGRound4Message": {}, + } +) diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 57e33dd..bb8ed0e 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -50,7 +50,7 @@ type Session interface { PartyIDs() []*tss.PartyID Send(msg tss.Message) - Listen(nodeID string, isResharing bool) + Listen(nodeID string) SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) ErrCh() chan error } @@ -125,7 +125,16 @@ func (s *session) Send(msg tss.Message) { s.errCh <- fmt.Errorf("failed to marshal message: %w", err) return } - logger.Debug("Sending message", "from", routing.From, "to", routing.To, "isBroadcast", routing.IsBroadcast) + round, _, err := s.party.ClassifyMsg(data) + if err != nil { + s.errCh <- fmt.Errorf("failed to classify message: %w", err) + return + } + toNodeIDs := make([]string, len(routing.To)) + for i, to := range routing.To { + toNodeIDs[i] = getRoutingFromPartyID(to) + } + logger.Debug("Sending message", "from", routing.From.Moniker, "to", toNodeIDs, "isBroadcast", routing.IsBroadcast, "round", round) if routing.IsBroadcast && len(routing.To) == 0 { err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msgBytes) @@ -148,13 +157,8 @@ func (s *session) Send(msg tss.Message) { // Listen is a wrapper around the party's Listen method // It subscribes to the broadcast and self direct topics -func (s *session) Listen(nodeID string, isResharingParty bool) { - var selfDirectTopic string - if isResharingParty { - selfDirectTopic = s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) - } else { - selfDirectTopic = s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) - } +func (s *session) Listen(nodeID string) { + selfDirectTopic := s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) broadcast := func() { sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { msg := natMsg.Data @@ -206,7 +210,7 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, version in return } logger.Info("Saved key", "walletID", s.walletID, "threshold", threshold, "version", version, "data", len(data)) - return nil + return } func (s *session) SetSaveData(saveBytes []byte) { @@ -251,7 +255,12 @@ func (s *session) receive(rawMsg []byte) { isToSelf := slices.Contains(toIDs, s.party.PartyID().String()) if isBroadcast || isToSelf { - logger.Debug("Received message", "from", msg.From, "to", msg.To, "isBroadcast", msg.IsBroadcast, "isToSelf", isToSelf) + round, _, err := s.party.ClassifyMsg(msg.MsgBytes) + if err != nil { + s.errCh <- fmt.Errorf("failed to classify message: %w", err) + return + } + logger.Debug("Received message", "from", msg.From.Moniker, "round", round, "isBroadcast", msg.IsBroadcast, "isToSelf", isToSelf) s.mu.Lock() defer s.mu.Unlock() s.party.InCh() <- *msg From 2a62b9e43a786d993e12d5ba4de2ad1774e8529d Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 16 Jun 2025 17:53:11 +0700 Subject: [PATCH 24/34] Refactor Listen method in session interface and event consumer This commit updates the `Listen` method in the session interface to remove the `nodeID` parameter, simplifying its signature. Corresponding changes are made in the `eventConsumer` to reflect this update, enhancing the clarity and maintainability of event handling across the MPC package. These modifications aim to streamline the listening process for various session types. --- pkg/eventconsumer/event_consumer.go | 11 +++++------ pkg/mpc/session/base.go | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 07d6cae..c41f5b7 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -133,8 +133,8 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { } // Start listening for messages first - go ecdsaSession.Listen(ec.node.ID(), false) - go eddsaSession.Listen(ec.node.ID(), false) + go ecdsaSession.Listen() + go eddsaSession.Listen() successEvent := &event.KeygenSuccessEvent{ WalletID: walletID, } @@ -201,7 +201,6 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { } logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) }) - ec.keyGenerationSub = sub if err != nil { return err @@ -265,7 +264,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return } - go signingSession.Listen(ec.node.ID(), false) + go signingSession.Listen() txBigInt := new(big.Int).SetBytes(msg.Tx) go func() { @@ -381,8 +380,8 @@ func (ec *eventConsumer) consumeResharingEvent() error { return } - go oldSession.Listen(ec.node.ID(), false) - go newSession.Listen(ec.node.ID(), true) + go oldSession.Listen() + go newSession.Listen() successEvent := &event.ResharingSuccessEvent{ WalletID: msg.WalletID, diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index bb8ed0e..25cb1ec 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -50,7 +50,7 @@ type Session interface { PartyIDs() []*tss.PartyID Send(msg tss.Message) - Listen(nodeID string) + Listen() SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) ErrCh() chan error } @@ -157,7 +157,7 @@ func (s *session) Send(msg tss.Message) { // Listen is a wrapper around the party's Listen method // It subscribes to the broadcast and self direct topics -func (s *session) Listen(nodeID string) { +func (s *session) Listen() { selfDirectTopic := s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) broadcast := func() { sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { From 34941508c04c360ede01ae637feeaa9b0afbf66f Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 16 Jun 2025 18:02:16 +0700 Subject: [PATCH 25/34] Refactor key generation event handling in event consumer This commit enhances the `consumeKeyGenerationEvent` method by introducing a context with a timeout for better control over key generation processes. The event handling logic is refactored to improve readability and maintainability, including the use of goroutines for concurrent processing of ECDSA and EDDSA sessions. Additionally, error handling and logging are improved to ensure better traceability of issues during key generation. These changes aim to streamline the key generation workflow and enhance the robustness of the event consumer. --- pkg/eventconsumer/event_consumer.go | 168 +++++++++++++++------------- 1 file changed, 89 insertions(+), 79 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index c41f5b7..e30ebcb 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -101,110 +101,120 @@ func (ec *eventConsumer) Run() { logger.Info("MPC Event consumer started...!") } - func (ec *eventConsumer) consumeKeyGenerationEvent() error { sub, err := ec.pubsub.Subscribe(MPCGenerateEvent, func(natMsg *nats.Msg) { - raw := natMsg.Data - var msg types.GenerateKeyMessage - err := json.Unmarshal(raw, &msg) - if err != nil { - logger.Error("Failed to unmarshal signing message", err) - return - } - logger.Info("Received key generation event", "msg", msg) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - err = ec.identityStore.VerifyInitiatorMessage(&msg) - if err != nil { - logger.Error("Failed to verify initiator message", err) - return - } + done := make(chan struct{}) - walletID := msg.WalletID - ecdsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) - if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) - return - } + go func() { + defer close(done) - eddsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeEd25519, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) - if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) - return - } + raw := natMsg.Data + var msg types.GenerateKeyMessage + if err := json.Unmarshal(raw, &msg); err != nil { + logger.Error("Failed to unmarshal signing message", err) + return + } + logger.Info("Received key generation event", "msg", msg) - // Start listening for messages first - go ecdsaSession.Listen() - go eddsaSession.Listen() - successEvent := &event.KeygenSuccessEvent{ - WalletID: walletID, - } + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { + logger.Error("Failed to verify initiator message", err) + return + } - var wg sync.WaitGroup - wg.Add(2) + walletID := msg.WalletID + ecdsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + if err != nil { + logger.Error("Failed to create key generation session", err, "walletID", walletID) + return + } + eddsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeEd25519, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + if err != nil { + logger.Error("Failed to create key generation session", err, "walletID", walletID) + return + } - // Handle errors from the session - go func() { - for { - select { - case err := <-ecdsaSession.ErrCh(): - logger.Error("Error from ECDSA session", err) - return - case err := <-eddsaSession.ErrCh(): - logger.Error("Error from EDDSA session", err) - return + // Start listening for messages + go ecdsaSession.Listen() + go eddsaSession.Listen() + + successEvent := &event.KeygenSuccessEvent{WalletID: walletID} + var wg sync.WaitGroup + wg.Add(2) + + // session error monitoring + go func() { + for { + select { + case err := <-ecdsaSession.ErrCh(): + logger.Error("Error from ECDSA session", err) + return + case err := <-eddsaSession.ErrCh(): + logger.Error("Error from EDDSA session", err) + return + case <-ctx.Done(): + return + } + } + }() + + go ecdsaSession.StartKeygen(ctx, ecdsaSession.Send, func(data []byte) { + ecdsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) + if pubKey, err := ecdsaSession.GetPublicKey(data); err == nil { + successEvent.ECDSAPubKey = pubKey + } + wg.Done() + }) + + go eddsaSession.StartKeygen(ctx, eddsaSession.Send, func(data []byte) { + eddsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) + if pubKey, err := eddsaSession.GetPublicKey(data); err == nil { + successEvent.EDDSAPubKey = pubKey } + wg.Done() + }) + + wg.Wait() + + select { + case <-ctx.Done(): + logger.Warn("Keygen timed out", "walletID", walletID) + return + default: + // all done } - }() - // Start the key generation process - ecdsaCtx, ecdsaCancel := context.WithTimeout(context.Background(), 30*time.Second) - go ecdsaSession.StartKeygen(ecdsaCtx, ecdsaSession.Send, func(data []byte) { - ecdsaCancel() - ecdsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) - ecdsaPubKey, err := ecdsaSession.GetPublicKey(data) + successEventBytes, err := json.Marshal(successEvent) if err != nil { - logger.Error("Failed to get ECDSA public key", err) + logger.Error("Failed to marshal keygen success event", err) return } - successEvent.ECDSAPubKey = ecdsaPubKey - wg.Done() - }) - eddsaCtx, eddsaCancel := context.WithTimeout(context.Background(), 30*time.Second) - go eddsaSession.StartKeygen(eddsaCtx, eddsaSession.Send, func(data []byte) { - eddsaCancel() - eddsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) - eddsaPubKey, err := eddsaSession.GetPublicKey(data) + err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), + }) if err != nil { - logger.Error("Failed to get EDDSA public key", err) + logger.Error("Failed to publish key generation success message", err) return } - successEvent.EDDSAPubKey = eddsaPubKey - wg.Done() - }) - - wg.Wait() - // Marshal the success event - successEventBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal keygen success event", err) - return - } + logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) + }() - err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), - }) - if err != nil { - logger.Error("Failed to publish key generation success message", err) - return + select { + case <-ctx.Done(): + logger.Warn("Keygen handler exceeded timeout") + case <-done: + // done successfully } - logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) }) - ec.keyGenerationSub = sub + if err != nil { return err } + ec.keyGenerationSub = sub return nil } From 52f24384a3f17707d9bc58085376430e19602f0c Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 17 Jun 2025 11:45:21 +0700 Subject: [PATCH 26/34] Implement Close method for PartyInterface and Session to manage resource cleanup --- pkg/eventconsumer/event_consumer.go | 364 ++++++++++++++-------------- pkg/mpc/party/base.go | 6 + pkg/mpc/session/base.go | 25 +- 3 files changed, 210 insertions(+), 185 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index e30ebcb..15d7905 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -14,6 +14,7 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/node" + "github.com/fystack/mpcium/pkg/mpc/session" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -103,118 +104,101 @@ func (ec *eventConsumer) Run() { } func (ec *eventConsumer) consumeKeyGenerationEvent() error { sub, err := ec.pubsub.Subscribe(MPCGenerateEvent, func(natMsg *nats.Msg) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - done := make(chan struct{}) - - go func() { - defer close(done) + if err := ec.handleKeyGenerationEvent(ctx, natMsg.Data); err != nil { + logger.Error("Failed to handle key generation event", err) + } + }) - raw := natMsg.Data - var msg types.GenerateKeyMessage - if err := json.Unmarshal(raw, &msg); err != nil { - logger.Error("Failed to unmarshal signing message", err) - return - } - logger.Info("Received key generation event", "msg", msg) + if err != nil { + return err + } - if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { - logger.Error("Failed to verify initiator message", err) - return - } + ec.keyGenerationSub = sub + return nil +} - walletID := msg.WalletID - ecdsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) - if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) - return - } - eddsaSession, err := ec.node.CreateKeygenSession(types.KeyTypeEd25519, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) - if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) - return - } +func (ec *eventConsumer) handleKeyGenerationEvent(ctx context.Context, raw []byte) error { + var msg types.GenerateKeyMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return fmt.Errorf("unmarshal message: %w", err) + } + logger.Info("Received key generation event", "msg", msg) - // Start listening for messages - go ecdsaSession.Listen() - go eddsaSession.Listen() - - successEvent := &event.KeygenSuccessEvent{WalletID: walletID} - var wg sync.WaitGroup - wg.Add(2) - - // session error monitoring - go func() { - for { - select { - case err := <-ecdsaSession.ErrCh(): - logger.Error("Error from ECDSA session", err) - return - case err := <-eddsaSession.ErrCh(): - logger.Error("Error from EDDSA session", err) - return - case <-ctx.Done(): - return - } - } - }() + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { + return fmt.Errorf("verify initiator: %w", err) + } - go ecdsaSession.StartKeygen(ctx, ecdsaSession.Send, func(data []byte) { - ecdsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) - if pubKey, err := ecdsaSession.GetPublicKey(data); err == nil { - successEvent.ECDSAPubKey = pubKey - } - wg.Done() - }) + walletID := msg.WalletID + successEvent := &event.KeygenSuccessEvent{WalletID: walletID} + var wg sync.WaitGroup - go eddsaSession.StartKeygen(ctx, eddsaSession.Send, func(data []byte) { - eddsaSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) - if pubKey, err := eddsaSession.GetPublicKey(data); err == nil { - successEvent.EDDSAPubKey = pubKey - } - wg.Done() - }) + // Start ECDSA and EDDSA sessions + for _, keyType := range []types.KeyType{types.KeyTypeSecp256k1, types.KeyTypeEd25519} { + kgSession, err := ec.node.CreateKeygenSession(keyType, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + if err != nil { + return fmt.Errorf("create %v session: %w", keyType, err) + } - wg.Wait() + defer kgSession.Close() - select { - case <-ctx.Done(): - logger.Warn("Keygen timed out", "walletID", walletID) - return - default: - // all done - } + go kgSession.Listen() + wg.Add(1) - successEventBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal keygen success event", err) - return - } + go func(s session.Session, kt types.KeyType) { + defer wg.Done() - err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), + s.StartKeygen(ctx, s.Send, func(data []byte) { + err := s.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) + if err != nil { + logger.Error("Failed to save key", err) + } + logger.Info("Saved key", "type", kt, "walletID", walletID, "threshold", ec.mpcThreshold, "version", DefaultVersion, "data", len(data)) + if pubKey, err := s.GetPublicKey(data); err == nil { + switch kt { + case types.KeyTypeSecp256k1: + successEvent.ECDSAPubKey = pubKey + case types.KeyTypeEd25519: + successEvent.EDDSAPubKey = pubKey + } + } }) + if err != nil { - logger.Error("Failed to publish key generation success message", err) - return + logger.Error("Keygen failed", err, "keyType", kt) } + }(kgSession, keyType) + } - logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) - }() + doneCh := make(chan struct{}) + go func() { + wg.Wait() + close(doneCh) + }() + + select { + case <-ctx.Done(): + logger.Warn("Keygen timed out", "walletID", walletID) + return ctx.Err() + case <-doneCh: + // All keygens done + } - select { - case <-ctx.Done(): - logger.Warn("Keygen handler exceeded timeout") - case <-done: - // done successfully - } - }) + successBytes, err := json.Marshal(successEvent) + if err != nil { + return fmt.Errorf("marshal keygen error: %w", err) + } + err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successBytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), + }) if err != nil { - return err + return fmt.Errorf("enqueue keygen error: %w", err) } - ec.keyGenerationSub = sub + + logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) return nil } @@ -346,114 +330,126 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { func (ec *eventConsumer) consumeResharingEvent() error { sub, err := ec.pubsub.Subscribe(MPCResharingEvent, func(natMsg *nats.Msg) { - raw := natMsg.Data - var msg types.ResharingMessage - err := json.Unmarshal(raw, &msg) - if err != nil { - logger.Error("Failed to unmarshal resharing message", err) - return - } - logger.Info("Received resharing event", "walletID", msg.WalletID, "oldThreshold", ec.mpcThreshold, "newThreshold", msg.NewThreshold) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() - err = ec.identityStore.VerifyInitiatorMessage(&msg) - if err != nil { - logger.Error("Failed to verify initiator message", err) - return + if err := ec.handleReshareEvent(ctx, natMsg.Data); err != nil { + logger.Error("Failed to handle resharing event", err) } + }) + if err != nil { + return err + } - // Default is 0 if no keyVersion found - keyInfoVersion, _ := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) - fmt.Println("keyInfoVersion", keyInfoVersion) - oldSession, err := ec.node.CreateResharingSession( - true, - msg.KeyType, - msg.WalletID, - ec.mpcThreshold, - keyInfoVersion, - ec.resharingResultQueue, - ) - if err != nil { - logger.Error("Failed to create resharing session", err) - return - } + ec.resharingSub = sub + return nil +} - newSession, err := ec.node.CreateResharingSession( - false, - msg.KeyType, - msg.WalletID, - msg.NewThreshold, - keyInfoVersion, // Increment inside the session - ec.resharingResultQueue, - ) - if err != nil { - logger.Error("Failed to create resharing session", err) - return - } +func (ec *eventConsumer) handleReshareEvent(ctx context.Context, raw []byte) error { + var msg types.ResharingMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return fmt.Errorf("unmarshal message: %w", err) + } + logger.Info("Received resharing event", + "walletID", msg.WalletID, + "oldThreshold", ec.mpcThreshold, + "newThreshold", msg.NewThreshold) - go oldSession.Listen() - go newSession.Listen() + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { + return fmt.Errorf("verify initiator: %w", err) + } - successEvent := &event.ResharingSuccessEvent{ - WalletID: msg.WalletID, - } + keyInfoVersion, _ := ec.node.GetKeyInfoVersion(msg.KeyType, msg.WalletID) - var wg sync.WaitGroup - wg.Add(2) + oldSession, err := ec.node.CreateResharingSession(true, msg.KeyType, msg.WalletID, ec.mpcThreshold, keyInfoVersion, ec.resharingResultQueue) + if err != nil { + return fmt.Errorf("create old session: %w", err) + } - go func() { - for { - select { - case err := <-oldSession.ErrCh(): - logger.Error("Error from ECDSA session", err) - case err := <-newSession.ErrCh(): - logger.Error("Error from EDDSA session", err) - } - } - }() + newSession, err := ec.node.CreateResharingSession(false, msg.KeyType, msg.WalletID, msg.NewThreshold, keyInfoVersion, ec.resharingResultQueue) + if err != nil { + return fmt.Errorf("create new session: %w", err) + } - oldCtx, oldCancel := context.WithTimeout(context.Background(), 30*time.Second) - go oldSession.StartResharing(oldCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, oldSession.Send, func(data []byte) { - // Old session is done, no need to save - oldCancel() - wg.Done() - }) - - newCtx, newCancel := context.WithTimeout(context.Background(), 30*time.Second) - go newSession.StartResharing(newCtx, oldSession.PartyIDs(), newSession.PartyIDs(), ec.mpcThreshold, msg.NewThreshold, newSession.Send, func(data []byte) { - newCancel() - // Only save for new parties - ecdsaPubKey, err := newSession.GetPublicKey(data) - if err != nil { - logger.Error("Failed to get ECDSA public key", err) - return - } - newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, keyInfoVersion+1, data) - successEvent.ECDSAPubKey = ecdsaPubKey - wg.Done() - }) + go oldSession.Listen() + go newSession.Listen() - wg.Wait() + successEvent := &event.ResharingSuccessEvent{WalletID: msg.WalletID} - successEventBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal resharing success event", err) - return - } + var wg sync.WaitGroup + wg.Add(2) - err = ec.resharingResultQueue.Enqueue(event.ResharingSuccessEventTopic, successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(event.TypeResharingSuccess, msg.WalletID), - }) - if err != nil { - logger.Error("Failed to publish resharing result event", err) - return + // Error monitor + go func() { + for { + select { + case err := <-oldSession.ErrCh(): + logger.Error("Error from old session", err) + case err := <-newSession.ErrCh(): + logger.Error("Error from new session", err) + } } - logger.Info("[COMPLETED RESH] Resharing completed successfully", "walletID", msg.WalletID) - }) + }() + + // Start old session + go func() { + ctxOld, cancelOld := context.WithCancel(ctx) + defer cancelOld() + oldSession.StartResharing(ctxOld, + oldSession.PartyIDs(), + newSession.PartyIDs(), + ec.mpcThreshold, + msg.NewThreshold, + oldSession.Send, + func([]byte) { wg.Done() }, + ) + }() + + // Start new session + go func() { + ctxNew, cancelNew := context.WithCancel(ctx) + defer cancelNew() + newSession.StartResharing(ctxNew, + oldSession.PartyIDs(), + newSession.PartyIDs(), + ec.mpcThreshold, + msg.NewThreshold, + newSession.Send, + func(data []byte) { + if pubKey, err := newSession.GetPublicKey(data); err == nil { + newSession.SaveKey(ec.node.GetReadyPeersIncludeSelf(), msg.NewThreshold, keyInfoVersion+1, data) + if msg.KeyType == types.KeyTypeSecp256k1 { + successEvent.ECDSAPubKey = pubKey + } else { + successEvent.EDDSAPubKey = pubKey + } + } else { + logger.Error("Failed to get public key", err) + } + wg.Done() + }, + ) + }() - ec.resharingSub = sub + wg.Wait() + + eventBytes, err := json.Marshal(successEvent) if err != nil { - return err + return fmt.Errorf("marshal success event: %w", err) } + + err = ec.resharingResultQueue.Enqueue( + event.ResharingSuccessEventTopic, + eventBytes, + &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeResharingSuccess, msg.WalletID, keyInfoVersion+1), + }, + ) + if err != nil { + return fmt.Errorf("enqueue resharing success: %w", err) + } + + logger.Info("[COMPLETED RESH] Resharing completed successfully", "walletID", msg.WalletID) return nil } diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 78d1283..98feaf8 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -23,6 +23,7 @@ type PartyInterface interface { InCh() chan types.TssMessage OutCh() chan tss.Message ErrCh() chan error + Close() } type party struct { @@ -61,6 +62,11 @@ func (p *party) ErrCh() chan error { return p.errCh } +func (p *party) Close() { + close(p.inCh) + close(p.outCh) +} + // runParty handles the common party execution loop func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh chan T, finish func([]byte)) { // Start the party in a goroutine to handle errors diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 25cb1ec..0662b13 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -53,6 +53,7 @@ type Session interface { Listen() SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) ErrCh() chan error + Close() } type session struct { @@ -209,7 +210,6 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, version in s.errCh <- fmt.Errorf("failed to save key: %w", err) return } - logger.Info("Saved key", "walletID", s.walletID, "threshold", threshold, "version", version, "data", len(data)) return } @@ -227,6 +227,29 @@ func (s *session) GetSaveData() ([]byte, error) { return data, nil } +func (s *session) Close() { + // Close subscriptions first + if s.broadcastSub != nil { + s.broadcastSub.Unsubscribe() + } + if s.directSub != nil { + s.directSub.Unsubscribe() + } + + // Close party + if s.party != nil { + s.party.Close() + } + + // Close error channel last + select { + case <-s.errCh: + // Channel already closed + default: + close(s.errCh) + } +} + // receive is a helper function that receives a message from the party func (s *session) receive(rawMsg []byte) { msg, err := types.UnmarshalTssMessage(rawMsg) From 8b340faf88c21feeb8364f0275932c5d8d6b71d4 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 17 Jun 2025 13:27:54 +0700 Subject: [PATCH 27/34] Update GetSaveData method to accept party version for improved data management --- pkg/mpc/node/node.go | 8 ++++---- pkg/mpc/session/base.go | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index c359b73..e78583a 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -120,7 +120,7 @@ func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID n.kvstore, n.keyinfoStore, ) - saveData, err := ecdsaSession.GetSaveData() + saveData, err := ecdsaSession.GetSaveData(partyVersion) if err != nil { return nil, fmt.Errorf("failed to get save data: %w", err) } @@ -141,7 +141,7 @@ func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID n.keyinfoStore, ) - saveData, err := eddsaSession.GetSaveData() + saveData, err := eddsaSession.GetSaveData(partyVersion) if err != nil { return nil, fmt.Errorf("failed to get save data: %w", err) } @@ -175,7 +175,7 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa } ecdsaSession := session.NewECDSASession(walletID, selfPartyID, partyIDs, threshold, *preparams, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) if isOldParty { - saveData, err := ecdsaSession.GetSaveData() + saveData, err := ecdsaSession.GetSaveData(partyVersion) if err != nil { return nil, fmt.Errorf("failed to get save data: %w", err) } @@ -194,7 +194,7 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa return ecdsaSession, nil case types.KeyTypeEd25519: eddsaSession := session.NewEDDSASession(walletID, selfPartyID, partyIDs, threshold, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) - saveData, err := eddsaSession.GetSaveData() + saveData, err := eddsaSession.GetSaveData(partyVersion) if err != nil { return nil, fmt.Errorf("failed to get save data: %w", err) } diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 0662b13..0af8a55 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -44,7 +44,7 @@ type Session interface { StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) - GetSaveData() ([]byte, error) + GetSaveData(version int) ([]byte, error) GetPublicKey(data []byte) ([]byte, error) VerifySignature(msg []byte, signature []byte) (*common.SignatureData, error) @@ -205,7 +205,7 @@ func (s *session) SaveKey(participantPeerIDs []string, threshold int, version in return } - err = s.kvstore.Put(composeKey, data) + err = s.kvstore.Put(fmt.Sprintf("%s-%d", composeKey, version), data) if err != nil { s.errCh <- fmt.Errorf("failed to save key: %w", err) return @@ -218,9 +218,9 @@ func (s *session) SetSaveData(saveBytes []byte) { } // GetSaveData gets the key from the kvstore -func (s *session) GetSaveData() ([]byte, error) { +func (s *session) GetSaveData(version int) ([]byte, error) { composeKey := s.composeKey(s.walletID) - data, err := s.kvstore.Get(composeKey) + data, err := s.kvstore.Get(fmt.Sprintf("%s-%d", composeKey, version)) if err != nil { return nil, fmt.Errorf("failed to get key: %w", err) } From 5dd1082048bbbc490eb425dc6e0775c0728c1ec0 Mon Sep 17 00:00:00 2001 From: anhthii Date: Thu, 26 Jun 2025 20:50:27 +0700 Subject: [PATCH 28/34] Improve format, rename finish -> onComplete --- pkg/mpc/party/base.go | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 98feaf8..72a73b4 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -11,9 +11,17 @@ import ( ) type PartyInterface interface { - StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) - StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) - StartResharing(ctx context.Context, oldPartyIDs, newPartyIDs []*tss.PartyID, oldThreshold, newThreshold int, send func(tss.Message), finish func([]byte)) + StartKeygen(ctx context.Context, send func(tss.Message), onComplete func([]byte)) + StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), onComplete func([]byte)) + StartResharing( + ctx context.Context, + oldPartyIDs, + newPartyIDs []*tss.PartyID, + oldThreshold, + newThreshold int, + send func(tss.Message), + onComplete func([]byte), + ) PartyID() *tss.PartyID PartyIDs() []*tss.PartyID @@ -68,7 +76,14 @@ func (p *party) Close() { } // runParty handles the common party execution loop -func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, send func(tss.Message), endCh chan T, finish func([]byte)) { +func runParty[T any]( + s PartyInterface, + ctx context.Context, + party tss.Party, + send func(tss.Message), + endCh chan T, + onComplete func([]byte), +) { // Start the party in a goroutine to handle errors go func() { logger.Info("Starting party", "partyID", s.PartyID().String()) @@ -92,12 +107,12 @@ func runParty[T any](s PartyInterface, ctx context.Context, party tss.Party, sen case out := <-s.OutCh(): send(out) case result := <-endCh: - bz, err := json.Marshal(result) + bytes, err := json.Marshal(result) if err != nil { s.ErrCh() <- err return } - finish(bz) + onComplete(bytes) return } } From cda8da2e48af740ed7c665431eefe269b8460c31 Mon Sep 17 00:00:00 2001 From: anhthii Date: Thu, 26 Jun 2025 20:55:08 +0700 Subject: [PATCH 29/34] Rename PartyInterface -> Party --- pkg/mpc/party/base.go | 4 ++-- pkg/mpc/session/base.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 72a73b4..9bb4f2f 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -10,7 +10,7 @@ import ( "github.com/fystack/mpcium/pkg/types" ) -type PartyInterface interface { +type Party interface { StartKeygen(ctx context.Context, send func(tss.Message), onComplete func([]byte)) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), onComplete func([]byte)) StartResharing( @@ -77,7 +77,7 @@ func (p *party) Close() { // runParty handles the common party execution loop func runParty[T any]( - s PartyInterface, + s Party, ctx context.Context, party tss.Party, send func(tss.Message), diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 0af8a55..59d7b03 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -58,7 +58,7 @@ type Session interface { type session struct { walletID string - party party.PartyInterface + party party.Party broadcastSub messaging.Subscription directSub messaging.Subscription From 5945f8f1f76ab8d4159924b2c04b076923168b72 Mon Sep 17 00:00:00 2001 From: anhthii Date: Thu, 26 Jun 2025 20:37:50 +0700 Subject: [PATCH 30/34] Experimental refactoring --- examples/generate/main.go | 26 +++++++-- pkg/eventconsumer/event_consumer.go | 83 +++++++++++++++++++++-------- pkg/messaging/message_queue.go | 8 ++- pkg/mpc/node/node.go | 1 - pkg/mpc/party/base.go | 11 +++- pkg/mpc/party/ecdsa.go | 12 +++++ pkg/mpc/party/eddsa.go | 14 +++++ pkg/mpc/session/base.go | 28 +++++++--- 8 files changed, 146 insertions(+), 37 deletions(-) diff --git a/examples/generate/main.go b/examples/generate/main.go index c8e64bf..f7d3972 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -1,10 +1,12 @@ package main import ( + "flag" "fmt" "os" "os/signal" "syscall" + "time" "github.com/fystack/mpcium/pkg/client" "github.com/fystack/mpcium/pkg/config" @@ -17,6 +19,11 @@ import ( func main() { const environment = "development" + + // Parse the -n flag + numWallets := flag.Int("n", 1, "Number of wallets to generate") + flag.Parse() + config.InitViperConfig() logger.Init(environment, false) @@ -25,13 +32,14 @@ func main() { if err != nil { logger.Fatal("Failed to connect to NATS", err) } - defer natsConn.Drain() // drain inflight msgs + defer natsConn.Drain() defer natsConn.Close() mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, KeyPath: "./event_initiator.key", }) + err = mpcClient.OnWalletCreationResult(func(event event.KeygenSuccessEvent) { logger.Info("Received wallet creation result", "event", event) }) @@ -39,11 +47,19 @@ func main() { logger.Fatal("Failed to subscribe to wallet-creation results", err) } - walletID := uuid.New().String() - if err := mpcClient.CreateWallet(walletID); err != nil { - logger.Fatal("CreateWallet failed", err) + for i := 0; i < *numWallets; i++ { + walletID := uuid.New().String() + if err := mpcClient.CreateWallet(walletID); err != nil { + logger.Error("CreateWallet failed", err) + continue + } + time.Sleep(100 * time.Millisecond) + logger.Info("CreateWallet sent", "walletID", walletID) } - logger.Info("CreateWallet sent, awaiting result...", "walletID", walletID) + + logger.Info("All CreateWallet requests sent, awaiting results...") + + // Wait for shutdown signal stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) <-stop diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 15d7905..ef34362 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -26,7 +26,9 @@ const ( MPCResharingEvent = "mpc:reshare" // Default version for keygen - DefaultVersion int = 1 + DefaultVersion int = 1 + SessionTimeout = 3 * time.Minute + MaxConcurrentSessions = 5 ) type EventConsumer interface { @@ -34,6 +36,11 @@ type EventConsumer interface { Close() error } +func Elaps(start time.Time, text string) { + elapsed := time.Since(start) + fmt.Printf("%s, Elapsed time: %d ms\n", text, elapsed.Milliseconds()) +} + type eventConsumer struct { node *node.Node pubsub messaging.PubSub @@ -47,6 +54,7 @@ type eventConsumer struct { signingSub messaging.Subscription resharingSub messaging.Subscription identityStore identity.Store + sessionLimiter chan struct{} // acts as a pool of session tokens // Track active sessions with timestamps for cleanup activeSessions map[string]time.Time // Maps "walletID-txID" to creation time @@ -103,13 +111,25 @@ func (ec *eventConsumer) Run() { logger.Info("MPC Event consumer started...!") } func (ec *eventConsumer) consumeKeyGenerationEvent() error { + // Create session limiter channel with capacity 5 + ec.sessionLimiter = make(chan struct{}, MaxConcurrentSessions) sub, err := ec.pubsub.Subscribe(MPCGenerateEvent, func(natMsg *nats.Msg) { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() + logger.Info("Received key generation event", "subject", natMsg.Subject) + // This blocks if max sessions are already running + ec.sessionLimiter <- struct{}{} - if err := ec.handleKeyGenerationEvent(ctx, natMsg.Data); err != nil { - logger.Error("Failed to handle key generation event", err) - } + go func(data []byte) { + defer func() { + <-ec.sessionLimiter // release slot + }() + + ctx, cancel := context.WithTimeout(context.Background(), SessionTimeout) + defer cancel() + + if err := ec.handleKeyGenerationEvent(ctx, data); err != nil { + logger.Error("Failed to handle key generation event", err) + } + }(natMsg.Data) }) if err != nil { @@ -137,25 +157,42 @@ func (ec *eventConsumer) handleKeyGenerationEvent(ctx context.Context, raw []byt // Start ECDSA and EDDSA sessions for _, keyType := range []types.KeyType{types.KeyTypeSecp256k1, types.KeyTypeEd25519} { - kgSession, err := ec.node.CreateKeygenSession(keyType, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + s, err := ec.node.CreateKeygenSession(keyType, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) if err != nil { return fmt.Errorf("create %v session: %w", keyType, err) } - - defer kgSession.Close() - - go kgSession.Listen() + start := time.Now() + s.Listen() + Elaps(start, "Listen") wg.Add(1) - go func(s session.Session, kt types.KeyType) { defer wg.Done() - + defer s.Close() + start := time.Now() s.StartKeygen(ctx, s.Send, func(data []byte) { - err := s.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data) + err := s.SaveKey( + ec.node.GetReadyPeersIncludeSelf(), + ec.mpcThreshold, + DefaultVersion, + data, + ) if err != nil { logger.Error("Failed to save key", err) } - logger.Info("Saved key", "type", kt, "walletID", walletID, "threshold", ec.mpcThreshold, "version", DefaultVersion, "data", len(data)) + logger.Info( + "[KEY GEN]", + "type", + kt, + "walletID", + walletID, + "threshold", + ec.mpcThreshold, + "version", + DefaultVersion, + "data", + len(data), + ) + if pubKey, err := s.GetPublicKey(data); err == nil { switch kt { case types.KeyTypeSecp256k1: @@ -164,12 +201,11 @@ func (ec *eventConsumer) handleKeyGenerationEvent(ctx context.Context, raw []byt successEvent.EDDSAPubKey = pubKey } } + }) - if err != nil { - logger.Error("Keygen failed", err, "keyType", kt) - } - }(kgSession, keyType) + Elaps(start, string(kt)) + }(s, keyType) } doneCh := make(chan struct{}) @@ -182,8 +218,7 @@ func (ec *eventConsumer) handleKeyGenerationEvent(ctx context.Context, raw []byt case <-ctx.Done(): logger.Warn("Keygen timed out", "walletID", walletID) return ctx.Err() - case <-doneCh: - // All keygens done + case <-doneCh: // All keygens done } successBytes, err := json.Marshal(successEvent) @@ -191,7 +226,10 @@ func (ec *eventConsumer) handleKeyGenerationEvent(ctx context.Context, raw []byt return fmt.Errorf("marshal keygen error: %w", err) } - err = ec.genKeySucecssQueue.Enqueue(event.KeygenSuccessEventTopic, successBytes, &messaging.EnqueueOptions{ + err = ec.genKeySucecssQueue.Enqueue(fmt.Sprintf( + event.TypeGenerateWalletSuccess, + walletID, + ), successBytes, &messaging.EnqueueOptions{ IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), }) if err != nil { @@ -200,6 +238,7 @@ func (ec *eventConsumer) handleKeyGenerationEvent(ctx context.Context, raw []byt logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) return nil + } func (ec *eventConsumer) consumeTxSigningEvent() error { diff --git a/pkg/messaging/message_queue.go b/pkg/messaging/message_queue.go index d5a1c7f..60242df 100644 --- a/pkg/messaging/message_queue.go +++ b/pkg/messaging/message_queue.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "time" "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" @@ -57,7 +58,7 @@ func NewNATsMessageQueueManager(queueName string, subjectWildCards []string, nc Name: queueName, Description: "Stream for " + queueName, Subjects: subjectWildCards, - MaxBytes: 1024, + MaxBytes: 100_000_000, // Light Production (Low Traffic) 100_000_000 (100 MB) Storage: jetstream.FileStorage, Retention: jetstream.WorkQueuePolicy, }) @@ -81,7 +82,10 @@ func (m *NATsMessageQueueManager) NewMessageQueue(consumerName string) MessageQu cfg := jetstream.ConsumerConfig{ Name: consumerName, Durable: consumerName, - MaxAckPending: 4, + MaxAckPending: 1000, + // If a message isn't acked within AckWait, it will be redelivered up to MaxDelive + AckWait: 60 * time.Second, + AckPolicy: jetstream.AckExplicitPolicy, FilterSubjects: []string{ consumerWildCard, }, diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index e78583a..5340778 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -274,7 +274,6 @@ func (n *Node) getECDSAPreParams(isOldParty bool) (*keygen.LocalPreParams, error if err := json.Unmarshal(preparamsBytes, &preparams); err != nil { return nil, err } - logger.Info("Preparams loaded", "isOldParty", isOldParty) return &preparams, nil } diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 9bb4f2f..684126d 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "math/big" + "time" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/logger" @@ -23,6 +24,7 @@ type Party interface { onComplete func([]byte), ) + WalletID() string PartyID() *tss.PartyID PartyIDs() []*tss.PartyID GetSaveData() []byte @@ -50,6 +52,10 @@ func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, th return &party{walletID, threshold, partyID, partyIDs, inCh, outCh, errCh} } +func (p *party) WalletID() string { + return p.walletID +} + func (p *party) PartyID() *tss.PartyID { return p.partyID } @@ -86,11 +92,14 @@ func runParty[T any]( ) { // Start the party in a goroutine to handle errors go func() { - logger.Info("Starting party", "partyID", s.PartyID().String()) + start := time.Now() + logger.Info("[Starting] party", "walletID", s.WalletID()) if err := party.Start(); err != nil { s.ErrCh() <- err return } + elapsed := time.Since(start) + logger.Info("[Closing] party", "walletID", s.WalletID(), "elapsed", elapsed.Milliseconds()) }() // Main message handling loop diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index b79f42f..54cb7dd 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -6,12 +6,14 @@ import ( "errors" "fmt" "math/big" + "time" "github.com/bnb-chain/tss-lib/v2/common" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" "github.com/golang/protobuf/ptypes/any" "google.golang.org/protobuf/proto" ) @@ -56,6 +58,7 @@ func (s *ECDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { } _, isBroadcast := ecdsaBroadcastMessages[msg.TypeUrl] + // logger.Info("ClassifyMsg", "typeUrl", msg.TypeUrl, "isBroadcast", isBroadcast) round := ecdsaMsgURL2Round[msg.TypeUrl] if round > 4 { @@ -66,9 +69,18 @@ func (s *ECDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData, 1) + // Time the initialization of TSS parameters and party + initStart := time.Now() + initElapsed := time.Since(initStart) params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := keygen.NewLocalParty(params, s.outCh, end, s.preParams) + logger.Info("[Starting ECDSA] key generation", "walletID", s.walletID, "initElapsed", initElapsed.Milliseconds()) + + // Time the runParty execution + runStart := time.Now() runParty(s, ctx, party, send, end, finish) + runElapsed := time.Since(runStart) + logger.Info("[Finished ECDSA] key generation run", "walletID", s.walletID, "runElapsed", runElapsed.Milliseconds()) } func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index f90741c..e9bf697 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -6,12 +6,14 @@ import ( "errors" "fmt" "math/big" + "time" "github.com/bnb-chain/tss-lib/v2/common" "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" "github.com/bnb-chain/tss-lib/v2/eddsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/logger" "github.com/golang/protobuf/ptypes/any" "google.golang.org/protobuf/proto" ) @@ -67,9 +69,21 @@ func (s *EDDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData, 1) + + // Measure time to initialize the party + initStart := time.Now() params := tss.NewParameters(tss.Edwards(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := keygen.NewLocalParty(params, s.outCh, end) + initElapsed := time.Since(initStart) + + logger.Info("[Starting EDDSA] key generation", "walletID", s.walletID, "initElapsed", initElapsed) + + // Measure time to run the party + runStart := time.Now() runParty(s, ctx, party, send, end, finish) + runElapsed := time.Since(runStart) + + logger.Info("[Finished EDDSA] key generation run", "walletID", s.walletID, "runElapsed", runElapsed) } func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 59d7b03..e6c0966 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -40,9 +40,17 @@ type TopicComposer struct { type KeyComposerFn func(id string) string type Session interface { - StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) - StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) - StartResharing(ctx context.Context, oldPartyIDs []*tss.PartyID, newPartyIDs []*tss.PartyID, oldThreshold int, newThreshold int, send func(tss.Message), finish func([]byte)) + StartKeygen(ctx context.Context, send func(tss.Message), onComplete func([]byte)) + StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), onComplete func([]byte)) + StartResharing( + ctx context.Context, + oldPartyIDs []*tss.PartyID, + newPartyIDs []*tss.PartyID, + oldThreshold int, + newThreshold int, + send func(tss.Message), + onComplete func([]byte), + ) GetSaveData(version int) ([]byte, error) GetPublicKey(data []byte) ([]byte, error) @@ -159,11 +167,17 @@ func (s *session) Send(msg tss.Message) { // Listen is a wrapper around the party's Listen method // It subscribes to the broadcast and self direct topics func (s *session) Listen() { + var wg sync.WaitGroup + wg.Add(2) + selfDirectTopic := s.topicComposer.ComposeDirectTopic(getRoutingFromPartyID(s.party.PartyID())) + broadcastTopic := s.topicComposer.ComposeBroadcastTopic() + broadcast := func() { - sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { + defer wg.Done() + sub, err := s.pubSub.Subscribe(broadcastTopic, func(natMsg *nats.Msg) { msg := natMsg.Data - s.receive(msg) + go s.receive(msg) }) if err != nil { @@ -175,8 +189,9 @@ func (s *session) Listen() { } direct := func() { + defer wg.Done() sub, err := s.direct.Listen(selfDirectTopic, func(msg []byte) { - s.receive(msg) + go s.receive(msg) }) if err != nil { @@ -189,6 +204,7 @@ func (s *session) Listen() { go broadcast() go direct() + wg.Wait() } // SaveKey saves the key to the keyinfo store and the kvstore From 05c9053f5bb6f76b10fb193456feccc6427bb09b Mon Sep 17 00:00:00 2001 From: anhthii Date: Thu, 26 Jun 2025 22:45:56 +0700 Subject: [PATCH 31/34] Add tsslimiter queue --- Makefile.local | 11 ++ clean_logs.sh | 16 +++ cmd/mpcium/main.go | 1 + go.mod | 12 +- go.sum | 8 ++ pkg/common/concurrency/utils.go | 49 ++++++++ pkg/eventconsumer/event_consumer.go | 178 +++++++++++++++++----------- pkg/monitoring/recorder.go | 77 ++++++++++++ pkg/mpc/node/node.go | 46 ++++++- pkg/mpc/party/ecdsa.go | 35 +++++- pkg/mpc/party/eddsa.go | 35 +++++- pkg/mpc/session/base.go | 44 +++++++ pkg/mpc/session/ecdsa.go | 17 ++- pkg/mpc/session/eddsa.go | 16 ++- pkg/tsslimiter/queue.go | 91 ++++++++++++++ pkg/tsslimiter/queue_test.go | 119 +++++++++++++++++++ pkg/tsslimiter/tsslimiter.go | 120 +++++++++++++++++++ 17 files changed, 781 insertions(+), 94 deletions(-) create mode 100644 Makefile.local create mode 100755 clean_logs.sh create mode 100644 pkg/common/concurrency/utils.go create mode 100644 pkg/monitoring/recorder.go create mode 100644 pkg/tsslimiter/queue.go create mode 100644 pkg/tsslimiter/queue_test.go create mode 100644 pkg/tsslimiter/tsslimiter.go diff --git a/Makefile.local b/Makefile.local new file mode 100644 index 0000000..f350478 --- /dev/null +++ b/Makefile.local @@ -0,0 +1,11 @@ +.PHONY: clean new + +clean: + @# only kill the window if it exists + @if tmux list-windows -F "#{window_name}" \ + | grep -qw "^mpcium$$"; then \ + tmux kill-window -t mpcium; \ + fi + +new: clean + @tmuxifier load-window mpcium diff --git a/clean_logs.sh b/clean_logs.sh new file mode 100755 index 0000000..96364a3 --- /dev/null +++ b/clean_logs.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Directories to clean under +nodes=("node0" "node1" "node2") + +for dir in "${nodes[@]}"; do + identity_dir="$dir" + echo "Cleaning .txt files in $identity_dir..." + if [ -d "$identity_dir" ]; then + find "$identity_dir" -type f -name "*.txt" -print -delete + else + echo "Directory $identity_dir not found" + fi +done + +echo "βœ… Cleanup complete." diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index ca533c4..cb6ff39 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -156,6 +156,7 @@ func runNode(ctx context.Context, c *cli.Command) error { keyinfoStore, identityStore, peerRegistry, + consulClient.KV(), ) // Preload preparams for the first time mpcNode.PreloadPreParams() diff --git a/go.mod b/go.mod index 04268d9..6132720 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( github.com/bnb-chain/tss-lib/v2 v2.0.2 github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 github.com/dgraph-io/badger/v4 v4.2.0 + github.com/golang-queue/queue v0.4.0 + github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/hashicorp/consul/api v1.26.1 github.com/mitchellh/mapstructure v1.5.0 @@ -17,8 +19,11 @@ require ( github.com/rs/zerolog v1.31.0 github.com/samber/lo v1.39.0 github.com/spf13/viper v1.18.0 + github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v3 v3.3.2 + go.uber.org/mock v0.5.2 golang.org/x/term v0.31.0 + google.golang.org/protobuf v1.36.6 ) require ( @@ -29,6 +34,7 @@ require ( github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect github.com/btcsuite/btcutil v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.0 // indirect @@ -37,7 +43,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v1.2.4 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect github.com/google/flatbuffers v1.12.1 // indirect github.com/google/go-cmp v0.7.0 // indirect @@ -53,6 +58,7 @@ require ( github.com/hashicorp/serf v0.10.1 // indirect github.com/ipfs/go-log v1.0.5 // indirect github.com/ipfs/go-log/v2 v2.1.3 // indirect + github.com/jpillora/backoff v1.0.0 // indirect github.com/klauspost/compress v1.17.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect @@ -64,6 +70,7 @@ require ( github.com/otiai10/primes v0.0.0-20210501021515-f1b2be525a11 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect @@ -71,11 +78,9 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.opencensus.io v0.24.0 // indirect go.uber.org/atomic v1.9.0 // indirect - go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/crypto v0.37.0 // indirect @@ -83,7 +88,6 @@ require ( golang.org/x/net v0.39.0 // indirect golang.org/x/sys v0.32.0 // indirect golang.org/x/text v0.24.0 // indirect - google.golang.org/protobuf v1.36.6 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index cdd2f60..2ff79c0 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuy github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/appleboy/com v0.3.0 h1:omze/tJPyi2YVH+m23GSrCGt90A+4vQNpEYBW+GuSr4= +github.com/appleboy/com v0.3.0/go.mod h1:kByEI3/vzI5GM1+O5QdBHLsXaOsmFsJcOpCSgASi4sg= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJA= @@ -107,6 +109,8 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-queue/queue v0.4.0 h1:vsOvW4Wqb7Ow5+tKnlZD0PbLf4MLEO1e5C7DV8BDfBg= +github.com/golang-queue/queue v0.4.0/go.mod h1:bZobuNN7gnumxi9LRGihr7y7quDeBZZAvfPcC+H5dzg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc= github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= @@ -203,6 +207,8 @@ github.com/ipfs/go-log/v2 v2.1.3 h1:1iS3IU7aXRlbgUpN8yTTpJ53NXYjAe37vcI5+5nYrzk= github.com/ipfs/go-log/v2 v2.1.3/go.mod h1:/8d0SH3Su5Ooc31QlL1WysJhvyOTDCjcCZ9Axpmri6g= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -373,6 +379,8 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/pkg/common/concurrency/utils.go b/pkg/common/concurrency/utils.go new file mode 100644 index 0000000..d668869 --- /dev/null +++ b/pkg/common/concurrency/utils.go @@ -0,0 +1,49 @@ +package concurrency + +import ( + "runtime" +) + +// GetVirtualCoreCount returns the number of logical CPUs (virtual cores) available on the system. +// This includes physical cores *and* hyperthreads. +func GetVirtualCoreCount() int { + return runtime.NumCPU() +} + +// GetTSSConcurrencyLimit returns the recommended maximum number of concurrent TSS sessions. +// It estimates the number of *physical* cores by dividing the virtual core count by 2, +// because each physical core typically has 2 logical threads due to hyperthreading. +// +// Threshold signing (e.g., ECDSA) is CPU-bound and does not benefit much from hyperthreads, +// so we limit concurrency based on physical core estimates. +func GetTSSConcurrencyLimit() int { + logicalCores := GetVirtualCoreCount() + + // Estimate physical cores by dividing virtual CPUs by 2 + estimatedPhysicalCores := logicalCores / 2 + if estimatedPhysicalCores < 1 { + estimatedPhysicalCores = 1 // always allow at least one session + } + + return calculateAllowedSessions(estimatedPhysicalCores) +} + +// calculateAllowedSessions maps physical core count to safe TSS concurrency limits. +// You can tune these thresholds depending on your latency and throughput requirements. +func calculateAllowedSessions(coreCount int) int { + switch { + case coreCount <= 2: + return 1 + case coreCount <= 4: + return 2 + case coreCount <= 8: + return 3 + case coreCount <= 12: + return 5 + case coreCount <= 16: + return 6 + default: + // For large systems, reserve some headroom for OS, logs, GC, etc. + return coreCount / 2 + } +} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index ef34362..b9d4f18 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -9,12 +9,14 @@ import ( "sync" "time" + "github.com/fystack/mpcium/pkg/common/concurrency" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/monitoring" "github.com/fystack/mpcium/pkg/mpc/node" - "github.com/fystack/mpcium/pkg/mpc/session" + "github.com/fystack/mpcium/pkg/tsslimiter" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -27,8 +29,10 @@ const ( // Default version for keygen DefaultVersion int = 1 - SessionTimeout = 3 * time.Minute + SessionTimeout = 1 * time.Minute MaxConcurrentSessions = 5 + // how long the entire handler will wait for *all* sessions + publishing: + HandlerTimeout = 2 * time.Minute ) type EventConsumer interface { @@ -46,7 +50,7 @@ type eventConsumer struct { pubsub messaging.PubSub mpcThreshold int - genKeySucecssQueue messaging.MessageQueue + genKeySuccessQueue messaging.MessageQueue signingResultQueue messaging.MessageQueue resharingResultQueue messaging.MessageQueue @@ -54,7 +58,6 @@ type eventConsumer struct { signingSub messaging.Subscription resharingSub messaging.Subscription identityStore identity.Store - sessionLimiter chan struct{} // acts as a pool of session tokens // Track active sessions with timestamps for cleanup activeSessions map[string]time.Time // Maps "walletID-txID" to creation time @@ -62,20 +65,24 @@ type eventConsumer struct { cleanupInterval time.Duration // How often to run cleanup sessionTimeout time.Duration // How long before a session is considered stale cleanupStopChan chan struct{} // Signal to stop cleanup goroutine + limiterQueue tsslimiter.Queue } func NewEventConsumer( node *node.Node, pubsub messaging.PubSub, - genKeySucecssQueue messaging.MessageQueue, + genKeySuccessQueue messaging.MessageQueue, signingResultQueue messaging.MessageQueue, resharingResultQueue messaging.MessageQueue, identityStore identity.Store, ) EventConsumer { + limiter := tsslimiter.NewWeightedLimiter(concurrency.GetTSSConcurrencyLimit()) + limiterQueue := tsslimiter.NewWeightedQueue(limiter, 100) + ec := &eventConsumer{ node: node, pubsub: pubsub, - genKeySucecssQueue: genKeySucecssQueue, + genKeySuccessQueue: genKeySuccessQueue, signingResultQueue: signingResultQueue, resharingResultQueue: resharingResultQueue, activeSessions: make(map[string]time.Time), @@ -84,6 +91,7 @@ func NewEventConsumer( cleanupStopChan: make(chan struct{}), mpcThreshold: viper.GetInt("mpc_threshold"), identityStore: identityStore, + limiterQueue: limiterQueue, } // Start background cleanup goroutine @@ -112,24 +120,18 @@ func (ec *eventConsumer) Run() { } func (ec *eventConsumer) consumeKeyGenerationEvent() error { // Create session limiter channel with capacity 5 - ec.sessionLimiter = make(chan struct{}, MaxConcurrentSessions) sub, err := ec.pubsub.Subscribe(MPCGenerateEvent, func(natMsg *nats.Msg) { logger.Info("Received key generation event", "subject", natMsg.Subject) // This blocks if max sessions are already running - ec.sessionLimiter <- struct{}{} - + // go func(data []byte) { go func(data []byte) { - defer func() { - <-ec.sessionLimiter // release slot - }() - - ctx, cancel := context.WithTimeout(context.Background(), SessionTimeout) - defer cancel() + // Ack the message immediately to prevent redelivery from JetStream. This is critical. - if err := ec.handleKeyGenerationEvent(ctx, data); err != nil { + if err := ec.handleKeyGenerationEvent(context.Background(), data); err != nil { logger.Error("Failed to handle key generation event", err) } }(natMsg.Data) + // }(natMsg.Data) }) if err != nil { @@ -140,105 +142,129 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { return nil } -func (ec *eventConsumer) handleKeyGenerationEvent(ctx context.Context, raw []byte) error { +func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, raw []byte) error { + // 1) decode and verify var msg types.GenerateKeyMessage if err := json.Unmarshal(raw, &msg); err != nil { return fmt.Errorf("unmarshal message: %w", err) } - logger.Info("Received key generation event", "msg", msg) - if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { return fmt.Errorf("verify initiator: %w", err) } walletID := msg.WalletID successEvent := &event.KeygenSuccessEvent{WalletID: walletID} + + // 2) give this handler its own timeout + handlerCtx, handlerCancel := context.WithTimeout(parentCtx, HandlerTimeout) + defer handlerCancel() + + // wait for the sessions to return (even if they timed out) var wg sync.WaitGroup + // wait for *both* callbacks to fire before publishing + var cbWg sync.WaitGroup + cbWg.Add(2) - // Start ECDSA and EDDSA sessions + var eventMutex sync.Mutex + + // 3) enqueue ECDSA & EDDSA jobs for _, keyType := range []types.KeyType{types.KeyTypeSecp256k1, types.KeyTypeEd25519} { - s, err := ec.node.CreateKeygenSession(keyType, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + keyType := keyType + + s, err := ec.node.CreateKeygenSession(keyType, walletID, ec.mpcThreshold, ec.genKeySuccessQueue) if err != nil { return fmt.Errorf("create %v session: %w", keyType, err) } - start := time.Now() s.Listen() - Elaps(start, "Listen") + wg.Add(1) - go func(s session.Session, kt types.KeyType) { + run := func() { defer wg.Done() defer s.Close() - start := time.Now() - s.StartKeygen(ctx, s.Send, func(data []byte) { - err := s.SaveKey( + + // give each session its own shorter timeout + sessionCtx, sessionCancel := context.WithTimeout(handlerCtx, SessionTimeout) + defer sessionCancel() + + s.StartKeygen(sessionCtx, s.Send, func(data []byte) { + defer cbWg.Done() // signal that this keyType actually called back + + logger.Info("[callback] StartKeygen fired", "walletID", walletID, "keyType", keyType) + + // save the share + if err := s.SaveKey( ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data, - ) - if err != nil { - logger.Error("Failed to save key", err) + ); err != nil { + logger.Error("Failed to save key", err, "walletID", walletID, "keyType", keyType) } - logger.Info( - "[KEY GEN]", - "type", - kt, - "walletID", - walletID, - "threshold", - ec.mpcThreshold, - "version", - DefaultVersion, - "data", - len(data), - ) + // extract & record the pubkey if pubKey, err := s.GetPublicKey(data); err == nil { - switch kt { + eventMutex.Lock() + switch keyType { case types.KeyTypeSecp256k1: successEvent.ECDSAPubKey = pubKey case types.KeyTypeEd25519: successEvent.EDDSAPubKey = pubKey } + eventMutex.Unlock() + } else { + logger.Error("Failed to get public key", err, "walletID", walletID, "keyType", keyType) } - }) + } - Elaps(start, string(kt)) - }(s, keyType) + var sessionType tsslimiter.SessionType + if keyType == types.KeyTypeSecp256k1 { + sessionType = tsslimiter.SessionKeygenECDSA + } else { + sessionType = tsslimiter.SessionKeygenEDDSA + } + + ec.limiterQueue.Enqueue(tsslimiter.SessionJob{ + Type: sessionType, + Run: run, + }) } - doneCh := make(chan struct{}) + // 4) wait for both session goroutines to return + wg.Wait() + + // 5) now wait for both callbacks (or handler timeout) + doneCb := make(chan struct{}) go func() { - wg.Wait() - close(doneCh) + cbWg.Wait() + close(doneCb) }() select { - case <-ctx.Done(): - logger.Warn("Keygen timed out", "walletID", walletID) - return ctx.Err() - case <-doneCh: // All keygens done + case <-handlerCtx.Done(): + logger.Warn("Keygen callbacks did not all fire before timeout", "walletID", walletID) + return handlerCtx.Err() + case <-doneCb: + // both callbacks have run } + // 6) marshal & publish success successBytes, err := json.Marshal(successEvent) if err != nil { - return fmt.Errorf("marshal keygen error: %w", err) + return fmt.Errorf("marshal success event: %w", err) } - - err = ec.genKeySucecssQueue.Enqueue(fmt.Sprintf( - event.TypeGenerateWalletSuccess, - walletID, - ), successBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), - }) - if err != nil { - return fmt.Errorf("enqueue keygen error: %w", err) + if err := ec.genKeySuccessQueue.Enqueue( + fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), + successBytes, + &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), + }, + ); err != nil { + return fmt.Errorf("enqueue success event: %w", err) } logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) return nil - } func (ec *eventConsumer) consumeTxSigningEvent() error { @@ -586,14 +612,24 @@ func (ec *eventConsumer) Close() error { // Signal cleanup routine to stop close(ec.cleanupStopChan) - err := ec.keyGenerationSub.Unsubscribe() - if err != nil { - return err + if ec.keyGenerationSub != nil { + if err := ec.keyGenerationSub.Unsubscribe(); err != nil { + return err + } } - err = ec.signingSub.Unsubscribe() - if err != nil { - return err + if ec.signingSub != nil { + if err := ec.signingSub.Unsubscribe(); err != nil { + return err + } } + if ec.resharingSub != nil { + if err := ec.resharingSub.Unsubscribe(); err != nil { + return err + } + } + + // Ensure all monitoring logs are written to disk before exiting. + monitoring.Close() return nil } diff --git a/pkg/monitoring/recorder.go b/pkg/monitoring/recorder.go new file mode 100644 index 0000000..4e6786f --- /dev/null +++ b/pkg/monitoring/recorder.go @@ -0,0 +1,77 @@ +package monitoring + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" + + "github.com/fystack/mpcium/pkg/logger" +) + +// KeygenTimestamps holds the structured data for a single key generation event. +type KeygenTimestamps struct { + WalletID string `json:"wallet_id"` + NodeID string `json:"node_id"` + KeyType string `json:"key_type"` + StartTime time.Time `json:"start_time"` + CompletionTime time.Time `json:"completion_time"` + InitDurationMs int64 `json:"init_duration_ms"` + RunDurationMs int64 `json:"run_duration_ms"` +} + +var ( + logFile *os.File + logOnce sync.Once + logMux sync.Mutex +) + +// initLogFile initializes the log file for appending. It ensures this only happens once. +func initLogFile() { + logOnce.Do(func() { + logDir := "monitoring" + if err := os.MkdirAll(logDir, 0755); err != nil { + logger.Error("Failed to create monitoring directory", err) + return + } + + var err error + logFile, err = os.OpenFile(filepath.Join(logDir, "keygen_times.jsonl"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + logger.Error("Failed to open keygen log file", err) + } + }) +} + +// RecordKeygenCompletion marshals the timestamp data to JSON and writes it to the log file. +func RecordKeygenCompletion(data KeygenTimestamps) { + initLogFile() + if logFile == nil { + return // Initialization failed + } + + logMux.Lock() + defer logMux.Unlock() + + line, err := json.Marshal(data) + if err != nil { + logger.Error("Failed to marshal keygen timestamp data", err) + return + } + + if _, err := logFile.Write(append(line, '\n')); err != nil { + logger.Error("Failed to write to keygen log file", err) + } +} + +// Close ensures the log file is synced and closed gracefully. +func Close() { + logMux.Lock() + defer logMux.Unlock() + + if logFile != nil { + logFile.Sync() + logFile.Close() + } +} diff --git a/pkg/mpc/node/node.go b/pkg/mpc/node/node.go index 5340778..6917a7f 100644 --- a/pkg/mpc/node/node.go +++ b/pkg/mpc/node/node.go @@ -10,6 +10,7 @@ import ( "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" @@ -33,9 +34,20 @@ type Node struct { identityStore identity.Store peerRegistry *registry + consulKV infra.ConsulKV } -func NewNode(nodeID string, peerIDs []string, pubSub messaging.PubSub, direct messaging.DirectMessaging, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, identityStore identity.Store, peerRegistry *registry) *Node { +func NewNode( + nodeID string, + peerIDs []string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + identityStore identity.Store, + peerRegistry *registry, + consulKV infra.ConsulKV, +) *Node { go peerRegistry.WatchPeersReady() return &Node{ @@ -47,6 +59,7 @@ func NewNode(nodeID string, peerIDs []string, pubSub messaging.PubSub, direct me keyinfoStore: keyinfoStore, identityStore: identityStore, peerRegistry: peerRegistry, + consulKV: consulKV, } } @@ -78,6 +91,7 @@ func (n *Node) CreateKeygenSession(keyType types.KeyType, walletID string, thres n.identityStore, n.kvstore, n.keyinfoStore, + n.consulKV, ) return ecdsaSession, nil @@ -92,6 +106,7 @@ func (n *Node) CreateKeygenSession(keyType types.KeyType, walletID string, thres n.identityStore, n.kvstore, n.keyinfoStore, + n.consulKV, ) return eddsaSession, nil default: @@ -119,6 +134,7 @@ func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID n.identityStore, n.kvstore, n.keyinfoStore, + n.consulKV, ) saveData, err := ecdsaSession.GetSaveData(partyVersion) if err != nil { @@ -139,6 +155,7 @@ func (n *Node) CreateSigningSession(keyType types.KeyType, walletID string, txID n.identityStore, n.kvstore, n.keyinfoStore, + n.consulKV, ) saveData, err := eddsaSession.GetSaveData(partyVersion) @@ -173,7 +190,19 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa if err != nil { return nil, fmt.Errorf("failed to get preparams: %w", err) } - ecdsaSession := session.NewECDSASession(walletID, selfPartyID, partyIDs, threshold, *preparams, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) + ecdsaSession := session.NewECDSASession( + walletID, + selfPartyID, + partyIDs, + threshold, + *preparams, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) if isOldParty { saveData, err := ecdsaSession.GetSaveData(partyVersion) if err != nil { @@ -193,7 +222,18 @@ func (n *Node) CreateResharingSession(isOldParty bool, keyType types.KeyType, wa } return ecdsaSession, nil case types.KeyTypeEd25519: - eddsaSession := session.NewEDDSASession(walletID, selfPartyID, partyIDs, threshold, n.pubSub, n.direct, n.identityStore, n.kvstore, n.keyinfoStore) + eddsaSession := session.NewEDDSASession( + walletID, + selfPartyID, + partyIDs, + threshold, + n.pubSub, + n.direct, + n.identityStore, + n.kvstore, + n.keyinfoStore, + n.consulKV, + ) saveData, err := eddsaSession.GetSaveData(partyVersion) if err != nil { return nil, fmt.Errorf("failed to get save data: %w", err) diff --git a/pkg/mpc/party/ecdsa.go b/pkg/mpc/party/ecdsa.go index 54cb7dd..6f2881c 100644 --- a/pkg/mpc/party/ecdsa.go +++ b/pkg/mpc/party/ecdsa.go @@ -14,14 +14,17 @@ import ( "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/monitoring" "github.com/golang/protobuf/ptypes/any" "google.golang.org/protobuf/proto" ) type ECDSAParty struct { party - preParams keygen.LocalPreParams - saveData *keygen.LocalPartySaveData + preParams keygen.LocalPreParams + saveData *keygen.LocalPartySaveData + KeygenStart time.Time + KeygenCompletion time.Time } func NewECDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, @@ -70,17 +73,37 @@ func (s *ECDSAParty) ClassifyMsg(msgBytes []byte) (uint8, bool, error) { func (s *ECDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), finish func([]byte)) { end := make(chan *keygen.LocalPartySaveData, 1) // Time the initialization of TSS parameters and party - initStart := time.Now() - initElapsed := time.Since(initStart) + s.KeygenStart = time.Now() params := tss.NewParameters(tss.S256(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := keygen.NewLocalParty(params, s.outCh, end, s.preParams) - logger.Info("[Starting ECDSA] key generation", "walletID", s.walletID, "initElapsed", initElapsed.Milliseconds()) + initElapsed := time.Since(s.KeygenStart) + logger.Info("[Starting ECDSA] key generation", + "walletID", s.walletID, + "initElapsed", initElapsed.Milliseconds(), + "startTime", s.KeygenStart.Format(time.RFC3339), + ) // Time the runParty execution runStart := time.Now() runParty(s, ctx, party, send, end, finish) + s.KeygenCompletion = time.Now() runElapsed := time.Since(runStart) - logger.Info("[Finished ECDSA] key generation run", "walletID", s.walletID, "runElapsed", runElapsed.Milliseconds()) + logger.Info("[Finished ECDSA] key generation run", + "walletID", s.walletID, + "runElapsed", runElapsed.Milliseconds(), + "completionTime", s.KeygenCompletion.Format(time.RFC3339), + ) + + // Record the completion event + monitoring.RecordKeygenCompletion(monitoring.KeygenTimestamps{ + WalletID: s.walletID, + NodeID: s.partyID.Id, + KeyType: "ECDSA", + StartTime: s.KeygenStart, + CompletionTime: s.KeygenCompletion, + InitDurationMs: initElapsed.Milliseconds(), + RunDurationMs: runElapsed.Milliseconds(), + }) } func (s *ECDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { diff --git a/pkg/mpc/party/eddsa.go b/pkg/mpc/party/eddsa.go index e9bf697..5d086b1 100644 --- a/pkg/mpc/party/eddsa.go +++ b/pkg/mpc/party/eddsa.go @@ -14,14 +14,17 @@ import ( "github.com/bnb-chain/tss-lib/v2/eddsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/monitoring" "github.com/golang/protobuf/ptypes/any" "google.golang.org/protobuf/proto" ) type EDDSAParty struct { party - reshareParams *tss.ReSharingParameters - saveData *keygen.LocalPartySaveData + reshareParams *tss.ReSharingParameters + saveData *keygen.LocalPartySaveData + KeygenStart time.Time + KeygenCompletion time.Time } func NewEDDSAParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, @@ -71,19 +74,39 @@ func (s *EDDSAParty) StartKeygen(ctx context.Context, send func(tss.Message), fi end := make(chan *keygen.LocalPartySaveData, 1) // Measure time to initialize the party - initStart := time.Now() + s.KeygenStart = time.Now() params := tss.NewParameters(tss.Edwards(), tss.NewPeerContext(s.partyIDs), s.partyID, len(s.partyIDs), s.threshold) party := keygen.NewLocalParty(params, s.outCh, end) - initElapsed := time.Since(initStart) + initElapsed := time.Since(s.KeygenStart) - logger.Info("[Starting EDDSA] key generation", "walletID", s.walletID, "initElapsed", initElapsed) + logger.Info("[Starting EDDSA] key generation", + "walletID", s.walletID, + "initElapsed", initElapsed.Milliseconds(), + "startTime", s.KeygenStart.Format(time.RFC3339), + ) // Measure time to run the party runStart := time.Now() runParty(s, ctx, party, send, end, finish) + s.KeygenCompletion = time.Now() runElapsed := time.Since(runStart) - logger.Info("[Finished EDDSA] key generation run", "walletID", s.walletID, "runElapsed", runElapsed) + logger.Info("[Finished EDDSA] key generation run", + "walletID", s.walletID, + "runElapsed", runElapsed.Milliseconds(), + "completionTime", s.KeygenCompletion.Format(time.RFC3339), + ) + + // Record the completion event + monitoring.RecordKeygenCompletion(monitoring.KeygenTimestamps{ + WalletID: s.walletID, + NodeID: s.partyID.Id, + KeyType: "EDDSA", + StartTime: s.KeygenStart, + CompletionTime: s.KeygenCompletion, + InitDurationMs: initElapsed.Milliseconds(), + RunDurationMs: runElapsed.Milliseconds(), + }) } func (s *EDDSAParty) StartSigning(ctx context.Context, msg *big.Int, send func(tss.Message), finish func([]byte)) { diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index e6c0966..d6f5c8e 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -6,16 +6,19 @@ import ( "math/big" "slices" "sync" + "time" "github.com/bnb-chain/tss-lib/v2/common" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/party" "github.com/fystack/mpcium/pkg/types" + "github.com/hashicorp/consul/api" "github.com/nats-io/nats.go" ) @@ -60,6 +63,7 @@ type Session interface { Send(msg tss.Message) Listen() SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) + WaitForReady(ctx context.Context, sessionID string) error ErrCh() chan error Close() } @@ -79,6 +83,7 @@ type session struct { topicComposer *TopicComposer composeKey KeyComposerFn + consulKV infra.ConsulKV mu sync.Mutex errCh chan error @@ -92,6 +97,7 @@ func NewSession( identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, + consulKV infra.ConsulKV, ) *session { errCh := make(chan error, 1000) return &session{ @@ -102,6 +108,7 @@ func NewSession( kvstore: kvstore, keyinfoStore: keyinfoStore, errCh: errCh, + consulKV: consulKV, } } @@ -113,6 +120,43 @@ func (s *session) ErrCh() chan error { return s.errCh } +func (s *session) WaitForReady(ctx context.Context, sessionID string) error { + // build our Consul prefix + prefix := fmt.Sprintf("tss-ready/%s/%s/", s.walletID, sessionID) + + // 1) publish our ready flag + myKey := prefix + s.party.PartyID().String() + if _, err := s.consulKV.Put(&api.KVPair{ + Key: myKey, + Value: []byte("true"), + }, nil); err != nil { + return fmt.Errorf("failed to write ready flag: %w", err) + } + + // 2) poll until we see everyone + total := len(s.party.PartyIDs()) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + pairs, _, err := s.consulKV.List(prefix, nil) + if err != nil { + logger.Error("error listing readiness keys", err) + continue + } + if len(pairs) >= total { + logger.Info("[READY] peers ready", "have", len(pairs), "need", total, "walletID", s.walletID) + return nil + } + logger.Info("[READY] Waiting for peers ready", "wallet", s.walletID, "have", len(pairs), "need", total) + } + } +} + // Send is a wrapper around the party's Send method // It signs the message and sends it to the remote party func (s *session) Send(msg tss.Message) { diff --git a/pkg/mpc/session/ecdsa.go b/pkg/mpc/session/ecdsa.go index f883bf7..a03882e 100644 --- a/pkg/mpc/session/ecdsa.go +++ b/pkg/mpc/session/ecdsa.go @@ -13,6 +13,7 @@ import ( "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/encoding" "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/messaging" @@ -23,8 +24,20 @@ type ECDSASession struct { *session } -func NewECDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, preParams keygen.LocalPreParams, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store) *ECDSASession { - s := NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) +func NewECDSASession( + walletID string, + partyID *tss.PartyID, + partyIDs []*tss.PartyID, + threshold int, + preParams keygen.LocalPreParams, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + identityStore identity.Store, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + consulKV infra.ConsulKV, +) *ECDSASession { + s := NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore, consulKV) s.party = party.NewECDSAParty(walletID, partyID, partyIDs, threshold, preParams, s.errCh) s.topicComposer = &TopicComposer{ ComposeBroadcastTopic: func() string { diff --git a/pkg/mpc/session/eddsa.go b/pkg/mpc/session/eddsa.go index 4cbf65c..65a162d 100644 --- a/pkg/mpc/session/eddsa.go +++ b/pkg/mpc/session/eddsa.go @@ -12,6 +12,7 @@ import ( "github.com/bnb-chain/tss-lib/v2/tss" "github.com/decred/dcrd/dcrec/edwards/v2" "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/infra" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/messaging" @@ -22,8 +23,19 @@ type EDDSASession struct { *session } -func NewEDDSASession(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, pubSub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store) *EDDSASession { - s := NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore) +func NewEDDSASession( + walletID string, + partyID *tss.PartyID, + partyIDs []*tss.PartyID, + threshold int, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + identityStore identity.Store, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + consulKV infra.ConsulKV, +) *EDDSASession { + s := NewSession(PurposeKeygen, walletID, pubSub, direct, identityStore, kvstore, keyinfoStore, consulKV) s.party = party.NewEDDSAParty(walletID, partyID, partyIDs, threshold, nil, nil, s.errCh) s.topicComposer = &TopicComposer{ ComposeBroadcastTopic: func() string { diff --git a/pkg/tsslimiter/queue.go b/pkg/tsslimiter/queue.go new file mode 100644 index 0000000..c170853 --- /dev/null +++ b/pkg/tsslimiter/queue.go @@ -0,0 +1,91 @@ +package tsslimiter + +import ( + "sync" + + "github.com/fystack/mpcium/pkg/logger" +) + +// SessionJob represents a queued job with type and execution logic +type SessionJob struct { + Type SessionType + Run func() +} + +// Queue defines the interface for a job queue that manages TSS session jobs. +type Queue interface { + // Enqueue adds a new session job to the queue for processing. + Enqueue(job SessionJob) + + // Stop gracefully shuts down the queue and waits for background workers to finish. + Stop() +} + +// WeightedQueue buffers and processes session jobs using the WeightedLimiter +type WeightedQueue struct { + queue chan SessionJob + limiter *WeightedLimiter + stopChan chan struct{} + wg sync.WaitGroup +} + +// NewWeightedQueue initializes a buffered job queue +func NewWeightedQueue(limiter *WeightedLimiter, bufferSize int) *WeightedQueue { + q := &WeightedQueue{ + queue: make(chan SessionJob, bufferSize), + limiter: limiter, + stopChan: make(chan struct{}), + } + + // Start the background worker to process queue + q.wg.Add(1) + go q.run() + return q +} + +// Enqueue adds a job to the queue +func (q *WeightedQueue) Enqueue(job SessionJob) { + q.queue <- job +} + +// run continuously processes jobs based on limiter capacity, logging counters +func (q *WeightedQueue) run() { + defer q.wg.Done() + + for { + select { + case job := <-q.queue: + // Log queue length and limiter state before acquire + usedBefore, max := q.limiter.Stats() + logger.Info("Before Acquire", "usedPoints", usedBefore, "maxPoints", max, "pendingJobs", len(q.queue)) + + // Block until we can acquire budget + q.limiter.Acquire(job.Type) + + // Log limiter state after acquire + usedAfter, _ := q.limiter.Stats() + logger.Info("After Acquire", "usedPoints", usedAfter, "jobType", job.Type) + + // Launch job + q.wg.Add(1) + go func(j SessionJob) { + defer q.wg.Done() + defer q.limiter.Release(j.Type) + + usedExec, _ := q.limiter.Stats() + logger.Info("Executing Job", "usedPoints", usedExec, "jobType", j.Type) + j.Run() + logger.Info("Pending Jobs", "num", len(q.queue)) + }(job) + + case <-q.stopChan: + return + } + } +} + +// Stop shuts down the queue processing loop and waits for running jobs +func (q *WeightedQueue) Stop() { + close(q.stopChan) + q.wg.Wait() +} diff --git a/pkg/tsslimiter/queue_test.go b/pkg/tsslimiter/queue_test.go new file mode 100644 index 0000000..77ed427 --- /dev/null +++ b/pkg/tsslimiter/queue_test.go @@ -0,0 +1,119 @@ +package tsslimiter_test + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/fystack/mpcium/pkg/tsslimiter" + "github.com/stretchr/testify/assert" +) + +func TestWeightedQueue_SingleJobExecution(t *testing.T) { + limiter := tsslimiter.NewWeightedLimiter(2) // 2 cores = 200 points + queue := tsslimiter.NewWeightedQueue(limiter, 10) + defer queue.Stop() + + var executed int32 = 0 + + job := tsslimiter.SessionJob{ + Type: tsslimiter.SessionSignECDSA, + Run: func() { + atomic.AddInt32(&executed, 1) + }, + } + + queue.Enqueue(job) + time.Sleep(200 * time.Millisecond) // Give time to process + + assert.Equal(t, int32(1), executed, "Expected job to execute") +} + +func TestWeightedQueue_RespectsConcurrency(t *testing.T) { + limiter := tsslimiter.NewWeightedLimiter(1) // 1 core = 100 points + queue := tsslimiter.NewWeightedQueue(limiter, 10) + defer queue.Stop() + + var executing int32 = 0 + var completed int32 = 0 + + // 3 jobs each costing 100 (keygen) β†’ only 1 should run at a time + for i := 0; i < 3; i++ { + queue.Enqueue(tsslimiter.SessionJob{ + Type: tsslimiter.SessionKeygenECDSA, + Run: func() { + current := atomic.AddInt32(&executing, 1) + assert.LessOrEqual(t, current, int32(1), "Too many concurrent jobs running") + time.Sleep(100 * time.Millisecond) + atomic.AddInt32(&executing, -1) + atomic.AddInt32(&completed, 1) + }, + }) + } + + time.Sleep(500 * time.Millisecond) + assert.Equal(t, int32(3), completed, "All jobs should complete sequentially") +} + +func TestWeightedQueue_MixedSessions(t *testing.T) { + limiter := tsslimiter.NewWeightedLimiter(2) // 2 cores = 200 points + queue := tsslimiter.NewWeightedQueue(limiter, 10) + defer queue.Stop() + + var completed int32 = 0 + + // Sign (40) + Keygen (100) + Sign (40) = 180 total, fits under 200 + queue.Enqueue(tsslimiter.SessionJob{ + Type: tsslimiter.SessionSignECDSA, + Run: func() { + time.Sleep(50 * time.Millisecond) + atomic.AddInt32(&completed, 1) + }, + }) + queue.Enqueue(tsslimiter.SessionJob{ + Type: tsslimiter.SessionKeygenECDSA, + Run: func() { + time.Sleep(50 * time.Millisecond) + atomic.AddInt32(&completed, 1) + }, + }) + queue.Enqueue(tsslimiter.SessionJob{ + Type: tsslimiter.SessionSignECDSA, + Run: func() { + time.Sleep(50 * time.Millisecond) + atomic.AddInt32(&completed, 1) + }, + }) + + time.Sleep(300 * time.Millisecond) + assert.Equal(t, int32(3), completed, "All mixed jobs should run within capacity") +} + +func TestWeightedQueue_BackpressureBuffering(t *testing.T) { + limiter := tsslimiter.NewWeightedLimiter(1) // 1 core = 100 + queue := tsslimiter.NewWeightedQueue(limiter, 10) + defer queue.Stop() + + var completed int32 = 0 + + // First job blocks the CPU + queue.Enqueue(tsslimiter.SessionJob{ + Type: tsslimiter.SessionKeygenECDSA, + Run: func() { + time.Sleep(150 * time.Millisecond) + atomic.AddInt32(&completed, 1) + }, + }) + + // Second job should wait in the queue + queue.Enqueue(tsslimiter.SessionJob{ + + Type: tsslimiter.SessionSignECDSA, + Run: func() { + atomic.AddInt32(&completed, 1) + }, + }) + + time.Sleep(400 * time.Millisecond) + assert.Equal(t, int32(2), completed, "Both jobs should run in sequence due to backpressure") +} diff --git a/pkg/tsslimiter/tsslimiter.go b/pkg/tsslimiter/tsslimiter.go new file mode 100644 index 0000000..e868f18 --- /dev/null +++ b/pkg/tsslimiter/tsslimiter.go @@ -0,0 +1,120 @@ +package tsslimiter + +import ( + "sync" + "time" + + "github.com/fystack/mpcium/pkg/logger" +) + +type SessionType int + +const ( + SessionKeygenECDSA SessionType = iota + SessionReshareECDSA + SessionSignECDSA + SessionKeygenEDDSA + SessionReshareEDDSA + SessionSignEDDSA + SessionKeygenCombined +) + +// sessionCosts defines the estimated CPU cost (in points) of each session type. +// The values are based on practical benchmarks using tss-lib (ECDSA over secp256k1), +// where 100 points = 100% of a physical CPU core. +// +// These costs allow us to model CPU pressure and prevent overload by setting +// a total max budget equal to the number of physical cores Γ— 100 points. +// +// For example, on a 4-core CPU: +// +// - maxPoints = 400 +// +// - You could run 1 keygen (100) + 10 sign sessions (30 Γ— 10 = 300) +// +// - Or 4 resharing sessions (80 Γ— 4 = 320) + 2 sign sessions (30 Γ— 2 = 60) +// +// Note: These values are conservative to maintain low latency and avoid timeouts. +var sessionCosts = map[SessionType]int{ + SessionKeygenECDSA: 100, // Full core + SessionReshareECDSA: 70, + SessionSignECDSA: 40, + SessionKeygenEDDSA: 25, // ~25% of core + SessionReshareEDDSA: 20, + SessionSignEDDSA: 15, + SessionKeygenCombined: 125, // ECDSA (100) + EDDSA (25) +} + +type Limiter interface { + // TryAcquire attempts to acquire resources for the given session type. + // Returns true if successful, false otherwise. + TryAcquire(t SessionType) bool + + // Acquire blocks until it successfully acquires resources for the session type. + Acquire(t SessionType) + + // Release frees the resources for the given session type. + Release(t SessionType) + Stats() (int, int) +} + +type WeightedLimiter struct { + mu sync.Mutex + usedPoints int + maxPoints int +} + +// NewWeightedLimiter creates a limiter with maxPoints = maxSessionsAllowed * 100 +func NewWeightedLimiter(maxSessions int) *WeightedLimiter { + return &WeightedLimiter{ + maxPoints: maxSessions * 100, + } +} + +func (l *WeightedLimiter) TryAcquire(t SessionType) bool { + l.mu.Lock() + defer l.mu.Unlock() + + logger.Info("TryAcquire....", "sessionType", t, "usedPoints", l.usedPoints, "maxPoints", l.maxPoints) + cost := sessionCosts[t] + if l.usedPoints+cost > l.maxPoints { + return false + } + + logger.Info("DOneACQUIRE") + l.usedPoints += cost + return true +} + +func (l *WeightedLimiter) Acquire(t SessionType) { + cost := sessionCosts[t] + + for { + l.mu.Lock() + if l.usedPoints+cost <= l.maxPoints { + l.usedPoints += cost + l.mu.Unlock() + return + } + l.mu.Unlock() + time.Sleep(50 * time.Millisecond) // backoff + } +} + +func (l *WeightedLimiter) Release(t SessionType) { + l.mu.Lock() + defer l.mu.Unlock() + + cost := sessionCosts[t] + l.usedPoints -= cost + if l.usedPoints < 0 { + l.usedPoints = 0 + } + logger.Info("Release", "sessionType", t, "usedPoints", l.usedPoints, "maxPoints", l.maxPoints) +} + +func (l *WeightedLimiter) Stats() (int, int) { + l.mu.Lock() + defer l.mu.Unlock() + return l.usedPoints, l.maxPoints +} From 8d79d1c7ae8244622bde9f5afb68102544bdbe51 Mon Sep 17 00:00:00 2001 From: Poseidon-G Date: Tue, 24 Jun 2025 20:01:37 +0700 Subject: [PATCH 32/34] Update gen request using jetstream --- cmd/mpcium/main.go | 8 ++ go.mod | 10 +- go.sum | 20 ++-- pkg/client/client.go | 12 ++- pkg/eventconsumer/keygen_consumer.go | 143 +++++++++++++++++++++++++++ pkg/messaging/message_queue.go | 70 +++++++++++++ pkg/messaging/pubsub.go | 2 +- 7 files changed, 242 insertions(+), 23 deletions(-) create mode 100644 pkg/eventconsumer/keygen_consumer.go diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index cb6ff39..14c805b 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -135,6 +135,8 @@ func runNode(ctx context.Context, c *cli.Command) error { "mpc.mpc_resharing_success.*", }, natsConn) + genkeyRequestQueue := mqManager.NewMessagePullSubscriber("mpc_keygen_request") + defer genkeyRequestQueue.Close() genKeySuccessQueue := mqManager.NewMessageQueue("mpc_keygen_success") defer genKeySuccessQueue.Close() singingResultQueue := mqManager.NewMessageQueue("signing_result") @@ -181,6 +183,7 @@ func runNode(ctx context.Context, c *cli.Command) error { timeoutConsumer.Run() defer timeoutConsumer.Close() signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingStream, pubsub) + keygenConsumer := eventconsumer.NewKeygenConsumer(natsConn, genkeyRequestQueue, pubsub) // Make the node ready before starting the signing consumer peerRegistry.Ready() @@ -195,6 +198,11 @@ func runNode(ctx context.Context, c *cli.Command) error { cancel() }() + fmt.Print("Run keygen consumer") + if err := keygenConsumer.Run(appContext); err != nil { + logger.Error("error running keygen consumer:", err) + } + if err := signingConsumer.Run(appContext); err != nil { logger.Error("error running consumer:", err) } diff --git a/go.mod b/go.mod index 6132720..60cc498 100644 --- a/go.mod +++ b/go.mod @@ -10,18 +10,16 @@ require ( github.com/bnb-chain/tss-lib/v2 v2.0.2 github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 github.com/dgraph-io/badger/v4 v4.2.0 - github.com/golang-queue/queue v0.4.0 github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/hashicorp/consul/api v1.26.1 github.com/mitchellh/mapstructure v1.5.0 - github.com/nats-io/nats.go v1.31.0 + github.com/nats-io/nats.go v1.43.0 github.com/rs/zerolog v1.31.0 github.com/samber/lo v1.39.0 github.com/spf13/viper v1.18.0 github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v3 v3.3.2 - go.uber.org/mock v0.5.2 golang.org/x/term v0.31.0 google.golang.org/protobuf v1.36.6 ) @@ -58,13 +56,12 @@ require ( github.com/hashicorp/serf v0.10.1 // indirect github.com/ipfs/go-log v1.0.5 // indirect github.com/ipfs/go-log/v2 v2.1.3 // indirect - github.com/jpillora/backoff v1.0.0 // indirect - github.com/klauspost/compress v1.17.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect - github.com/nats-io/nkeys v0.4.6 // indirect + github.com/nats-io/nkeys v0.4.11 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/otiai10/primes v0.0.0-20210501021515-f1b2be525a11 // indirect @@ -81,6 +78,7 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect go.opencensus.io v0.24.0 // indirect go.uber.org/atomic v1.9.0 // indirect + go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/crypto v0.37.0 // indirect diff --git a/go.sum b/go.sum index 2ff79c0..7930c50 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuy github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/appleboy/com v0.3.0 h1:omze/tJPyi2YVH+m23GSrCGt90A+4vQNpEYBW+GuSr4= -github.com/appleboy/com v0.3.0/go.mod h1:kByEI3/vzI5GM1+O5QdBHLsXaOsmFsJcOpCSgASi4sg= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJA= @@ -109,8 +107,6 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-queue/queue v0.4.0 h1:vsOvW4Wqb7Ow5+tKnlZD0PbLf4MLEO1e5C7DV8BDfBg= -github.com/golang-queue/queue v0.4.0/go.mod h1:bZobuNN7gnumxi9LRGihr7y7quDeBZZAvfPcC+H5dzg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc= github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= @@ -207,8 +203,6 @@ github.com/ipfs/go-log/v2 v2.1.3 h1:1iS3IU7aXRlbgUpN8yTTpJ53NXYjAe37vcI5+5nYrzk= github.com/ipfs/go-log/v2 v2.1.3/go.mod h1:/8d0SH3Su5Ooc31QlL1WysJhvyOTDCjcCZ9Axpmri6g= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= -github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= -github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -216,8 +210,8 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= -github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= -github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -262,10 +256,10 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/nats-io/nats.go v1.31.0 h1:/WFBHEc/dOKBF6qf1TZhrdEfTmOZ5JzdJ+Y3m6Y/p7E= -github.com/nats-io/nats.go v1.31.0/go.mod h1:di3Bm5MLsoB4Bx61CBTsxuarI36WbhAwOm8QrW39+i8= -github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY= -github.com/nats-io/nkeys v0.4.6/go.mod h1:4DxZNzenSVd1cYQoAa8948QY3QDjrHfcfVADymtkpts= +github.com/nats-io/nats.go v1.43.0 h1:uRFZ2FEoRvP64+UUhaTokyS18XBCR/xM2vQZKO4i8ug= +github.com/nats-io/nats.go v1.43.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= +github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0= +github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= @@ -379,8 +373,6 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= -go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/pkg/client/client.go b/pkg/client/client.go index 3d5986f..519edf6 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -29,12 +29,14 @@ const ( mpcKeygenSuccessQueue = "mpc_keygen_success" mpcSigningResultQueue = "signing_result" mpcResharingSuccessQueue = "mpc_resharing_success" + mpcKeygenRequestQueue = "mpc_keygen_request" // NATS subjects mpcSigningRequestSubject = "mpc.signing_request.*" mpcKeygenSuccessSubject = "mpc.mpc_keygen_success.*" mpcSigningResultSubject = "mpc.signing_result.*" mpcResharingSuccessSubject = "mpc.mpc_resharing_success.*" + mpcKeygenRequestSubject = "mpc.mpc_keygen_request.*" ) type MPCClient interface { @@ -54,7 +56,9 @@ type mpcClient struct { genKeySuccessQueue messaging.MessageQueue signResultQueue messaging.MessageQueue resharingResultQueue messaging.MessageQueue - privKey ed25519.PrivateKey + genKeyRequestQueue messaging.MessageQueue + + privKey ed25519.PrivateKey } // Options defines configuration options for creating a new MPCClient @@ -91,6 +95,7 @@ func NewMPCClient(opts Options) MPCClient { genKeySuccessQueue: manager.NewMessageQueue(mpcKeygenSuccessQueue), signResultQueue: manager.NewMessageQueue(mpcSigningResultQueue), resharingResultQueue: manager.NewMessageQueue(mpcResharingSuccessQueue), + genKeyRequestQueue: manager.NewMessagePullSubscriber(mpcKeygenRequestQueue), privKey: privKey, } } @@ -100,6 +105,7 @@ func initMessageQueueManager(natsConn *nats.Conn) *messaging.NATsMessageQueueMan mpcKeygenSuccessSubject, mpcSigningResultSubject, mpcResharingSuccessSubject, + mpcKeygenRequestSubject, }, natsConn) } @@ -118,7 +124,9 @@ func (c *mpcClient) CreateWallet(walletID string) error { return fmt.Errorf("CreateWallet: marshal error: %w", err) } - if err := c.pubsub.Publish(eventconsumer.MPCGenerateEvent, bytes); err != nil { + if err := c.genKeyRequestQueue.Enqueue(mpcKeygenRequestSubject, bytes, &messaging.EnqueueOptions{ + IdempotententKey: fmt.Sprintf("%s.%s", eventconsumer.MPCGenerateEvent, walletID), + }); err != nil { return fmt.Errorf("CreateWallet: publish error: %w", err) } return nil diff --git a/pkg/eventconsumer/keygen_consumer.go b/pkg/eventconsumer/keygen_consumer.go new file mode 100644 index 0000000..891b1dc --- /dev/null +++ b/pkg/eventconsumer/keygen_consumer.go @@ -0,0 +1,143 @@ +package eventconsumer + +import ( + "context" + "errors" + "time" + + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +const ( + // Maximum time to wait for a keygen response. + keygenResponseTimeout = 30 * time.Second + // How often to poll for the reply message. + keygenPollingInterval = 500 * time.Millisecond +) + +// KeygenConsumer represents a consumer that processes keygen events. +type KeygenConsumer interface { + // Run starts the consumer and blocks until the provided context is canceled. + Run(ctx context.Context) error + // Close performs a graceful shutdown of the consumer. + Close() error +} + +// keygenConsumer implements KeygenConsumer. +type keygenConsumer struct { + natsConn *nats.Conn + pubsub messaging.PubSub + keygenRequestQueue messaging.MessageQueue + + // jsSub holds the JetStream subscription, so it can be cleaned up during Close(). + jsSub messaging.Subscription +} + +// NewKeygenConsumer returns a new instance of KeygenConsumer. +func NewKeygenConsumer(natsConn *nats.Conn, keygenRequestQueue messaging.MessageQueue, pubsub messaging.PubSub) KeygenConsumer { + return &keygenConsumer{ + natsConn: natsConn, + pubsub: pubsub, + keygenRequestQueue: keygenRequestQueue, + } +} + +// Run subscribes to keygen events and processes them until the context is canceled. +func (sc *keygenConsumer) Run(ctx context.Context) error { + logger.Info("Starting key generation event consumer") + + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + logger.Info("Stopping key generation event processing") + return + + case <-ticker.C: + logger.Info("Calling to fetch key generation events...") + + // No need for a separate fetch context since the fetch operation + // is synchronous and completes before we'd cancel it + err := sc.keygenRequestQueue.Fetch(2, func(msg jetstream.Msg) error { + sc.handleKeygenEvent(msg) + return nil + }) + + if err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + logger.Error("Error fetching key generation events", err) + } + } + } + } + }() + + return nil +} + +func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { + // Create a reply inbox to receive the keygen event response. + replyInbox := nats.NewInbox() + + // Use a synchronous subscription for the reply inbox. + replySub, err := sc.natsConn.SubscribeSync(replyInbox) + if err != nil { + logger.Error("KeygenConsumer: Failed to subscribe to reply inbox", err) + _ = msg.Nak() + return + } + defer func() { + if err := replySub.Unsubscribe(); err != nil { + logger.Warn("KeygenConsumer: Failed to unsubscribe from reply inbox", err) + } + }() + + // Publish the keygen event with the reply inbox. + if err := sc.pubsub.PublishWithReply(MPCGenerateEvent, replyInbox, msg.Data()); err != nil { + logger.Error("KeygenConsumer: Failed to publish keygen event with reply", err) + _ = msg.Nak() + return + } + + // Poll for the reply message until timeout. + deadline := time.Now().Add(keygenResponseTimeout) + for time.Now().Before(deadline) { + replyMsg, err := replySub.NextMsg(keygenPollingInterval) + if err != nil { + // If timeout occurs, continue trying. + if err == nats.ErrTimeout { + continue + } + logger.Error("KeygenConsumer: Error receiving reply message", err) + break + } + if replyMsg != nil { + logger.Info("KeygenConsumer: Completed keygen event reply received") + if ackErr := msg.Ack(); ackErr != nil { + logger.Error("KeygenConsumer: ACK failed", ackErr) + } + return + } + } + + logger.Warn("KeygenConsumer: Timeout waiting for keygen event response") + _ = msg.Nak() +} + +// Close unsubscribes from the JetStream subject and cleans up resources. +func (sc *keygenConsumer) Close() error { + if sc.jsSub != nil { + if err := sc.jsSub.Unsubscribe(); err != nil { + logger.Error("KeygenConsumer: Failed to unsubscribe from JetStream", err) + return err + } + logger.Info("KeygenConsumer: Unsubscribed from JetStream") + } + return nil +} diff --git a/pkg/messaging/message_queue.go b/pkg/messaging/message_queue.go index 60242df..d706b2b 100644 --- a/pkg/messaging/message_queue.go +++ b/pkg/messaging/message_queue.go @@ -18,6 +18,7 @@ var ( type MessageQueue interface { Enqueue(topic string, message []byte, options *EnqueueOptions) error Dequeue(topic string, handler func(message []byte) error) error + Fetch(batch int, handler func(msg jetstream.Msg) error) error Close() } @@ -32,6 +33,13 @@ type msgQueue struct { consumerContext jetstream.ConsumeContext } +type msgPull struct { + consumerName string + js jetstream.JetStream + consumer jetstream.Consumer + fetchMaxWait time.Duration +} + type NATsMessageQueueManager struct { queueName string js jetstream.JetStream @@ -158,3 +166,65 @@ func (mq *msgQueue) Close() { func (n *msgQueue) handleReconnect(nc *nats.Conn) { logger.Info("NATS: Reconnected to NATS") } + +func (m *NATsMessageQueueManager) NewMessagePullSubscriber(consumerName string) MessageQueue { + mq := &msgQueue{ + consumerName: consumerName, + js: m.js, + } + consumerWildCard := fmt.Sprintf("%s.%s.*", m.queueName, consumerName) + cfg := jetstream.ConsumerConfig{ + Name: consumerName, + Durable: consumerName, + MaxAckPending: 1000, + // If a message isn't acked within AckWait, it will be redelivered up to MaxDelive + AckWait: 180 * time.Second, + MaxWaiting: 1000, + AckPolicy: jetstream.AckExplicitPolicy, + FilterSubjects: []string{ + consumerWildCard, + }, + MaxDeliver: 3, + MaxRequestBatch: 10, + } + + logger.Info("Creating pull consumer for subject", "config", cfg) + consumer, err := m.js.CreateOrUpdateConsumer(context.Background(), m.queueName, cfg) + if err != nil { + logger.Fatal("Error creating JetStream consumer: ", err) + } + + mq.consumer = consumer + return mq +} + +func (mq *msgQueue) Fetch(batch int, handler func(msg jetstream.Msg) error) error { + msgs, err := mq.consumer.Fetch(batch, jetstream.FetchMaxWait(2*time.Minute)) + if err != nil { + return fmt.Errorf("error fetching messages: %w", err) + } + + for msg := range msgs.Messages() { + meta, _ := msg.Metadata() + logger.Debug("Received message", "meta", meta) + err := handler(msg) + if err != nil { + if errors.Is(err, ErrPermament) { + logger.Info("Permanent error on message", "subject", msg.Subject) + msg.Term() + continue + } + + logger.Error("Error handling message: ", err) + msg.Nak() + continue + } + + logger.Debug("Message Acknowledged", "subject", msg.Subject) + err = msg.Ack() + if err != nil { + logger.Error("Error acknowledging message: ", err) + } + } + return nil +} diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 9860e02..4e64a13 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -17,7 +17,7 @@ type Subscription interface { type PubSub interface { Publish(topic string, message []byte) error - PublishWithReply(ttopic, reply string, data []byte) error + PublishWithReply(topic, reply string, data []byte) error Subscribe(topic string, handler func(msg *nats.Msg)) (Subscription, error) } From c0f9a1135bb190f8b2010e26181870885f2bb3bc Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 18 Jun 2025 16:55:00 +0700 Subject: [PATCH 33/34] Refactor key retrieval in GetSaveData method and adjust key generation event handling --- pkg/mpc/session/base.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index d6f5c8e..845d8d5 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -279,8 +279,14 @@ func (s *session) SetSaveData(saveBytes []byte) { // GetSaveData gets the key from the kvstore func (s *session) GetSaveData(version int) ([]byte, error) { + var key string composeKey := s.composeKey(s.walletID) - data, err := s.kvstore.Get(fmt.Sprintf("%s-%d", composeKey, version)) + if version == 0 { + key = composeKey + } else { + key = fmt.Sprintf("%s-%d", composeKey, version) + } + data, err := s.kvstore.Get(key) if err != nil { return nil, fmt.Errorf("failed to get key: %w", err) } From aeb1249ae526e4fd6f97d6e8fc11cb95b8fb1520 Mon Sep 17 00:00:00 2001 From: anhthii Date: Thu, 26 Jun 2025 20:59:22 +0700 Subject: [PATCH 34/34] Fix concurrency issue --- pkg/eventconsumer/event_consumer.go | 202 ++++++++++++++------------- pkg/eventconsumer/keygen_consumer.go | 49 ++++--- pkg/messaging/message_queue.go | 21 +-- pkg/mpc/party/base.go | 115 +++++++++------ pkg/mpc/session/base.go | 51 +++++-- pkg/tsslimiter/queue.go | 29 +++- pkg/tsslimiter/tsslimiter.go | 28 ++-- 7 files changed, 295 insertions(+), 200 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index b9d4f18..3671789 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -16,6 +16,7 @@ import ( "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/monitoring" "github.com/fystack/mpcium/pkg/mpc/node" + "github.com/fystack/mpcium/pkg/mpc/session" "github.com/fystack/mpcium/pkg/tsslimiter" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" @@ -29,10 +30,10 @@ const ( // Default version for keygen DefaultVersion int = 1 - SessionTimeout = 1 * time.Minute + SessionTimeout = 15 * time.Second MaxConcurrentSessions = 5 // how long the entire handler will wait for *all* sessions + publishing: - HandlerTimeout = 2 * time.Minute + HandlerTimeout = 20 * time.Second ) type EventConsumer interface { @@ -77,7 +78,8 @@ func NewEventConsumer( identityStore identity.Store, ) EventConsumer { limiter := tsslimiter.NewWeightedLimiter(concurrency.GetTSSConcurrencyLimit()) - limiterQueue := tsslimiter.NewWeightedQueue(limiter, 100) + bufferSize := 100 + limiterQueue := tsslimiter.NewWeightedQueue(limiter, bufferSize) ec := &eventConsumer{ node: node, @@ -122,16 +124,17 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { // Create session limiter channel with capacity 5 sub, err := ec.pubsub.Subscribe(MPCGenerateEvent, func(natMsg *nats.Msg) { logger.Info("Received key generation event", "subject", natMsg.Subject) - // This blocks if max sessions are already running - // go func(data []byte) { - go func(data []byte) { - // Ack the message immediately to prevent redelivery from JetStream. This is critical. - - if err := ec.handleKeyGenerationEvent(context.Background(), data); err != nil { + job := tsslimiter.SessionJob{ + Type: tsslimiter.SessionKeygenCombined, + Run: func() error { + return ec.handleKeyGenerationEvent(context.Background(), natMsg) + }, + OnError: func(err error) { logger.Error("Failed to handle key generation event", err) - } - }(natMsg.Data) - // }(natMsg.Data) + }, + Name: fmt.Sprintf("keygen-%s", string(natMsg.Data)), + } + ec.limiterQueue.Enqueue(job) }) if err != nil { @@ -142,7 +145,11 @@ func (ec *eventConsumer) consumeKeyGenerationEvent() error { return nil } -func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, raw []byte) error { +func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, natMsg *nats.Msg) error { + raw := natMsg.Data + ctx, handlerCancel := context.WithTimeout(parentCtx, HandlerTimeout) + defer handlerCancel() + // 1) decode and verify var msg types.GenerateKeyMessage if err := json.Unmarshal(raw, &msg); err != nil { @@ -155,104 +162,97 @@ func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, raw walletID := msg.WalletID successEvent := &event.KeygenSuccessEvent{WalletID: walletID} - // 2) give this handler its own timeout - handlerCtx, handlerCancel := context.WithTimeout(parentCtx, HandlerTimeout) - defer handlerCancel() + // 2) prepare both sessions + s0, err := ec.node.CreateKeygenSession(types.KeyTypeSecp256k1, walletID, ec.mpcThreshold, ec.genKeySuccessQueue) + if err != nil { + return fmt.Errorf("create ECDSA session: %w", err) + } - // wait for the sessions to return (even if they timed out) - var wg sync.WaitGroup - // wait for *both* callbacks to fire before publishing - var cbWg sync.WaitGroup - cbWg.Add(2) + s1, err := ec.node.CreateKeygenSession(types.KeyTypeEd25519, walletID, ec.mpcThreshold, ec.genKeySuccessQueue) + if err != nil { + s0.Close() + return fmt.Errorf("create EDDSA session: %w", err) + } - var eventMutex sync.Mutex + s0.Listen(ctx) + s1.Listen(ctx) - // 3) enqueue ECDSA & EDDSA jobs - for _, keyType := range []types.KeyType{types.KeyTypeSecp256k1, types.KeyTypeEd25519} { - keyType := keyType + defer s0.Close() + defer s1.Close() - s, err := ec.node.CreateKeygenSession(keyType, walletID, ec.mpcThreshold, ec.genKeySuccessQueue) - if err != nil { - return fmt.Errorf("create %v session: %w", keyType, err) + runKeygen := func(s session.Session, keyType types.KeyType) error { + sessionCtx, sessionCancel := context.WithTimeout(ctx, SessionTimeout) + defer sessionCancel() + + // // 1. Wait for all parties to be ready to start + if err := s.WaitForReady(sessionCtx, fmt.Sprintf("KEYGEN-start:%s", keyType)); err != nil { + return fmt.Errorf("failed to wait for ready: %w", err) } - s.Listen() - - wg.Add(1) - run := func() { - defer wg.Done() - defer s.Close() - - // give each session its own shorter timeout - sessionCtx, sessionCancel := context.WithTimeout(handlerCtx, SessionTimeout) - defer sessionCancel() - - s.StartKeygen(sessionCtx, s.Send, func(data []byte) { - defer cbWg.Done() // signal that this keyType actually called back - - logger.Info("[callback] StartKeygen fired", "walletID", walletID, "keyType", keyType) - - // save the share - if err := s.SaveKey( - ec.node.GetReadyPeersIncludeSelf(), - ec.mpcThreshold, - DefaultVersion, - data, - ); err != nil { - logger.Error("Failed to save key", err, "walletID", walletID, "keyType", keyType) - } - // extract & record the pubkey - if pubKey, err := s.GetPublicKey(data); err == nil { - eventMutex.Lock() - switch keyType { - case types.KeyTypeSecp256k1: - successEvent.ECDSAPubKey = pubKey - case types.KeyTypeEd25519: - successEvent.EDDSAPubKey = pubKey - } - eventMutex.Unlock() - } else { - logger.Error("Failed to get public key", err, "walletID", walletID, "keyType", keyType) + doneCh := make(chan error, 1) + + // 2. Start the key generation protocol + s.StartKeygen(sessionCtx, s.Send, func(data []byte) { + logger.Info("[callback] StartKeygen fired", "walletID", walletID, "keyType", keyType) + if err := s.SaveKey(ec.node.GetReadyPeersIncludeSelf(), ec.mpcThreshold, DefaultVersion, data); err != nil { + logger.Error("Failed to save key", err, "walletID", walletID, "keyType", keyType) + doneCh <- err + return + } + + if pubKey, err := s.GetPublicKey(data); err == nil { + switch keyType { + case types.KeyTypeSecp256k1: + successEvent.ECDSAPubKey = pubKey + case types.KeyTypeEd25519: + successEvent.EDDSAPubKey = pubKey } - }) - } + } else { + logger.Error("Failed to get public key", err, "walletID", walletID, "keyType", keyType) + doneCh <- err + return + } - var sessionType tsslimiter.SessionType - if keyType == types.KeyTypeSecp256k1 { - sessionType = tsslimiter.SessionKeygenECDSA - } else { - sessionType = tsslimiter.SessionKeygenEDDSA - } + // // 3. Wait for all parties to confirm completion + if err := s.WaitForReady(sessionCtx, fmt.Sprintf("KEYGEN-complete:%s", keyType)); err != nil { + doneCh <- fmt.Errorf("failed to wait for completion: %w", err) + return + } - ec.limiterQueue.Enqueue(tsslimiter.SessionJob{ - Type: sessionType, - Run: run, + doneCh <- nil }) - } - - // 4) wait for both session goroutines to return - wg.Wait() - // 5) now wait for both callbacks (or handler timeout) - doneCb := make(chan struct{}) - go func() { - cbWg.Wait() - close(doneCb) - }() + select { + case err := <-doneCh: + if err != nil { + return fmt.Errorf("keygen onComplete failed: %w", err) + } + return nil + case err := <-s.ErrCh(): + return fmt.Errorf("session error during keygen: %w", err) + case <-sessionCtx.Done(): + return fmt.Errorf("keygen timed out: %w", sessionCtx.Err()) + } + } - select { - case <-handlerCtx.Done(): - logger.Warn("Keygen callbacks did not all fire before timeout", "walletID", walletID) - return handlerCtx.Err() - case <-doneCb: - // both callbacks have run + logger.Info("Starting ECDSA key generation...", "walletID", walletID) + if err := runKeygen(s0, types.KeyTypeSecp256k1); err != nil { + return fmt.Errorf("ECDSA keygen failed: %w", err) } + logger.Info("ECDSA key generation completed.", "walletID", walletID) - // 6) marshal & publish success + logger.Info("Starting EDDSA key generation...", "walletID", walletID) + if err := runKeygen(s1, types.KeyTypeEd25519); err != nil { + return fmt.Errorf("EDDSA keygen failed: %w", err) + } + logger.Info("EDDSA key generation completed.", "walletID", walletID) + // 3) Send reply to keygen consumer after both keygens complete + // 4) marshal & publish success successBytes, err := json.Marshal(successEvent) if err != nil { return fmt.Errorf("marshal success event: %w", err) } + if err := ec.genKeySuccessQueue.Enqueue( fmt.Sprintf(event.TypeGenerateWalletSuccess, walletID), successBytes, @@ -263,6 +263,15 @@ func (ec *eventConsumer) handleKeyGenerationEvent(parentCtx context.Context, raw return fmt.Errorf("enqueue success event: %w", err) } + if natMsg.Reply != "" { + err = ec.pubsub.Publish(natMsg.Reply, successBytes) + if err != nil { + logger.Error("Failed to publish reply", err) + } else { + logger.Info("Reply sent to keygen consumer", "reply", natMsg.Reply, "walletID", walletID) + } + } + logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) return nil } @@ -323,7 +332,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return } - go signingSession.Listen() + go signingSession.Listen(context.Background()) txBigInt := new(big.Int).SetBytes(msg.Tx) go func() { @@ -366,6 +375,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { logger.Info("Signing completed", "walletID", msg.WalletID, "txID", msg.TxID, "data", len(data)) ec.removeSession(msg.WalletID, msg.TxID) + }) }() @@ -436,8 +446,8 @@ func (ec *eventConsumer) handleReshareEvent(ctx context.Context, raw []byte) err return fmt.Errorf("create new session: %w", err) } - go oldSession.Listen() - go newSession.Listen() + go oldSession.Listen(context.Background()) + go newSession.Listen(context.Background()) successEvent := &event.ResharingSuccessEvent{WalletID: msg.WalletID} diff --git a/pkg/eventconsumer/keygen_consumer.go b/pkg/eventconsumer/keygen_consumer.go index 891b1dc..c810d46 100644 --- a/pkg/eventconsumer/keygen_consumer.go +++ b/pkg/eventconsumer/keygen_consumer.go @@ -13,9 +13,9 @@ import ( const ( // Maximum time to wait for a keygen response. - keygenResponseTimeout = 30 * time.Second + keygenResponseTimeout = 90 * time.Second // How often to poll for the reply message. - keygenPollingInterval = 500 * time.Millisecond + keygenPollingInterval = 1 * time.Second ) // KeygenConsumer represents a consumer that processes keygen events. @@ -50,7 +50,18 @@ func (sc *keygenConsumer) Run(ctx context.Context) error { logger.Info("Starting key generation event consumer") go func() { - ticker := time.NewTicker(30 * time.Second) + // Initial fetch + logger.Info("Calling to fetch key generation events...") + err := sc.keygenRequestQueue.Fetch(5, func(msg jetstream.Msg) error { + sc.handleKeygenEvent(msg) + return nil + }) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + logger.Error("Error fetching key generation events", err) + } + + // Then start the ticker + ticker := time.NewTicker(15 * time.Second) defer ticker.Stop() for { @@ -62,17 +73,12 @@ func (sc *keygenConsumer) Run(ctx context.Context) error { case <-ticker.C: logger.Info("Calling to fetch key generation events...") - // No need for a separate fetch context since the fetch operation - // is synchronous and completes before we'd cancel it - err := sc.keygenRequestQueue.Fetch(2, func(msg jetstream.Msg) error { + err := sc.keygenRequestQueue.Fetch(5, func(msg jetstream.Msg) error { sc.handleKeygenEvent(msg) return nil }) - - if err != nil { - if !errors.Is(err, context.DeadlineExceeded) { - logger.Error("Error fetching key generation events", err) - } + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + logger.Error("Error fetching key generation events", err) } } } @@ -89,28 +95,22 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { replySub, err := sc.natsConn.SubscribeSync(replyInbox) if err != nil { logger.Error("KeygenConsumer: Failed to subscribe to reply inbox", err) - _ = msg.Nak() + _ = msg.Term() return } - defer func() { - if err := replySub.Unsubscribe(); err != nil { - logger.Warn("KeygenConsumer: Failed to unsubscribe from reply inbox", err) - } - }() + defer replySub.Unsubscribe() // Publish the keygen event with the reply inbox. if err := sc.pubsub.PublishWithReply(MPCGenerateEvent, replyInbox, msg.Data()); err != nil { logger.Error("KeygenConsumer: Failed to publish keygen event with reply", err) - _ = msg.Nak() + _ = msg.Term() return } - // Poll for the reply message until timeout. deadline := time.Now().Add(keygenResponseTimeout) for time.Now().Before(deadline) { replyMsg, err := replySub.NextMsg(keygenPollingInterval) if err != nil { - // If timeout occurs, continue trying. if err == nats.ErrTimeout { continue } @@ -119,15 +119,18 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { } if replyMsg != nil { logger.Info("KeygenConsumer: Completed keygen event reply received") - if ackErr := msg.Ack(); ackErr != nil { - logger.Error("KeygenConsumer: ACK failed", ackErr) + if err := msg.Ack(); err != nil && !messaging.IsAlreadyAcknowledged(err) { + logger.Error("KeygenConsumer: ACK failed", err) } return } } + // Timeout logger.Warn("KeygenConsumer: Timeout waiting for keygen event response") - _ = msg.Nak() + if err := msg.Term(); err != nil { + logger.Error("KeygenConsumer: Failed to terminate message", err) + } } // Close unsubscribes from the JetStream subject and cleans up resources. diff --git a/pkg/messaging/message_queue.go b/pkg/messaging/message_queue.go index d706b2b..d1a5f95 100644 --- a/pkg/messaging/message_queue.go +++ b/pkg/messaging/message_queue.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/fystack/mpcium/pkg/logger" @@ -163,10 +164,6 @@ func (mq *msgQueue) Close() { } } -func (n *msgQueue) handleReconnect(nc *nats.Conn) { - logger.Info("NATS: Reconnected to NATS") -} - func (m *NATsMessageQueueManager) NewMessagePullSubscriber(consumerName string) MessageQueue { mq := &msgQueue{ consumerName: consumerName, @@ -184,7 +181,7 @@ func (m *NATsMessageQueueManager) NewMessagePullSubscriber(consumerName string) FilterSubjects: []string{ consumerWildCard, }, - MaxDeliver: 3, + MaxDeliver: 1, MaxRequestBatch: 10, } @@ -199,14 +196,15 @@ func (m *NATsMessageQueueManager) NewMessagePullSubscriber(consumerName string) } func (mq *msgQueue) Fetch(batch int, handler func(msg jetstream.Msg) error) error { - msgs, err := mq.consumer.Fetch(batch, jetstream.FetchMaxWait(2*time.Minute)) + // Reduced fetch timeout from 2 minutes to 30 seconds for faster processing + msgs, err := mq.consumer.Fetch(batch, jetstream.FetchMaxWait(30*time.Second)) if err != nil { return fmt.Errorf("error fetching messages: %w", err) } for msg := range msgs.Messages() { meta, _ := msg.Metadata() - logger.Debug("Received message", "meta", meta) + logger.Debug("Received message", "meta", meta) // Changed to Debug to reduce log noise err := handler(msg) if err != nil { if errors.Is(err, ErrPermament) { @@ -220,11 +218,16 @@ func (mq *msgQueue) Fetch(batch int, handler func(msg jetstream.Msg) error) erro continue } - logger.Debug("Message Acknowledged", "subject", msg.Subject) err = msg.Ack() if err != nil { - logger.Error("Error acknowledging message: ", err) + if !IsAlreadyAcknowledged(err) { + logger.Error("Error acknowledging message:", err) + } } } return nil } + +func IsAlreadyAcknowledged(err error) bool { + return err != nil && strings.Contains(err.Error(), nats.ErrMsgAlreadyAckd.Error()) +} diff --git a/pkg/mpc/party/base.go b/pkg/mpc/party/base.go index 684126d..ccb5094 100644 --- a/pkg/mpc/party/base.go +++ b/pkg/mpc/party/base.go @@ -3,11 +3,12 @@ package party import ( "context" "encoding/json" + "errors" + "fmt" "math/big" - "time" + "sync" "github.com/bnb-chain/tss-lib/v2/tss" - "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" ) @@ -33,7 +34,6 @@ type Party interface { InCh() chan types.TssMessage OutCh() chan tss.Message ErrCh() chan error - Close() } type party struct { @@ -44,12 +44,28 @@ type party struct { inCh chan types.TssMessage outCh chan tss.Message errCh chan error + + ctx context.Context + cancel context.CancelFunc + closeOnce sync.Once } -func NewParty(walletID string, partyID *tss.PartyID, partyIDs []*tss.PartyID, threshold int, errCh chan error) *party { - inCh := make(chan types.TssMessage, 1000) - outCh := make(chan tss.Message, 1000) - return &party{walletID, threshold, partyID, partyIDs, inCh, outCh, errCh} +func NewParty( + walletID string, + partyID *tss.PartyID, + partyIDs []*tss.PartyID, + threshold int, + errCh chan error, +) *party { + return &party{ + walletID: walletID, + threshold: threshold, + partyID: partyID, + partyIDs: partyIDs, + inCh: make(chan types.TssMessage, 1000), + outCh: make(chan tss.Message, 1000), + errCh: errCh, + } } func (p *party) WalletID() string { @@ -64,64 +80,79 @@ func (p *party) PartyIDs() []*tss.PartyID { return p.partyIDs } -func (p *party) InCh() chan types.TssMessage { - return p.inCh -} - -func (p *party) OutCh() chan tss.Message { - return p.outCh -} - -func (p *party) ErrCh() chan error { - return p.errCh -} - -func (p *party) Close() { - close(p.inCh) - close(p.outCh) -} +func (p *party) InCh() chan types.TssMessage { return p.inCh } +func (p *party) OutCh() chan tss.Message { return p.outCh } +func (p *party) ErrCh() chan error { return p.errCh } // runParty handles the common party execution loop +// startPartyLoop runs a TSS party, handling messages, errors, and completion. func runParty[T any]( s Party, ctx context.Context, party tss.Party, send func(tss.Message), - endCh chan T, + endCh <-chan T, onComplete func([]byte), ) { - // Start the party in a goroutine to handle errors + // safe error reporter + safeErr := func(err error) { + select { + case s.ErrCh() <- err: + case <-ctx.Done(): + } + } + + // start the tss party logic go func() { - start := time.Now() - logger.Info("[Starting] party", "walletID", s.WalletID()) + defer func() { + if r := recover(); r != nil { + safeErr(fmt.Errorf("panic in party.Start: %v", r)) + } + }() if err := party.Start(); err != nil { - s.ErrCh() <- err - return + safeErr(err) } - elapsed := time.Since(start) - logger.Info("[Closing] party", "walletID", s.WalletID(), "elapsed", elapsed.Milliseconds()) }() - // Main message handling loop + // main handling loop for { select { case <-ctx.Done(): + if ctx.Err() != context.Canceled { + safeErr(fmt.Errorf("party timed out: %w", ctx.Err())) + } return - case in := <-s.InCh(): - ok, err := party.UpdateFromBytes(in.MsgBytes, in.From, in.IsBroadcast) - if !ok || err != nil { - s.ErrCh() <- err + + case inMsg, ok := <-s.InCh(): + if !ok { + return + } + ok2, err := party.UpdateFromBytes(inMsg.MsgBytes, inMsg.From, inMsg.IsBroadcast) + if err != nil || !ok2 { + safeErr(errors.New("UpdateFromBytes failed")) + return + } + + case outMsg, ok := <-s.OutCh(): + if !ok { + return + } + // respect cancellation before invoking callback + if ctx.Err() != nil { + return + } + send(outMsg) + + case result, ok := <-endCh: + if !ok { return } - case out := <-s.OutCh(): - send(out) - case result := <-endCh: - bytes, err := json.Marshal(result) + bts, err := json.Marshal(result) if err != nil { - s.ErrCh() <- err + safeErr(err) return } - onComplete(bytes) + onComplete(bts) return } } diff --git a/pkg/mpc/session/base.go b/pkg/mpc/session/base.go index 845d8d5..4fad2ca 100644 --- a/pkg/mpc/session/base.go +++ b/pkg/mpc/session/base.go @@ -61,7 +61,7 @@ type Session interface { PartyIDs() []*tss.PartyID Send(msg tss.Message) - Listen() + Listen(ctx context.Context) SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) WaitForReady(ctx context.Context, sessionID string) error ErrCh() chan error @@ -81,6 +81,9 @@ type session struct { kvstore kvstore.KVStore keyinfoStore keyinfo.Store + msgBuffer chan []byte + workerCount int + topicComposer *TopicComposer composeKey KeyComposerFn consulKV infra.ConsulKV @@ -109,6 +112,7 @@ func NewSession( keyinfoStore: keyinfoStore, errCh: errCh, consulKV: consulKV, + msgBuffer: make(chan []byte, 100), // Buffer for 100 messages } } @@ -135,9 +139,8 @@ func (s *session) WaitForReady(ctx context.Context, sessionID string) error { // 2) poll until we see everyone total := len(s.party.PartyIDs()) - ticker := time.NewTicker(100 * time.Millisecond) + ticker := time.NewTicker(50 * time.Millisecond) defer ticker.Stop() - for { select { case <-ctx.Done(): @@ -149,10 +152,10 @@ func (s *session) WaitForReady(ctx context.Context, sessionID string) error { continue } if len(pairs) >= total { - logger.Info("[READY] peers ready", "have", len(pairs), "need", total, "walletID", s.walletID) + logger.Debug("[READY] peers ready", "have", len(pairs), "need", total, "walletID", s.walletID) return nil } - logger.Info("[READY] Waiting for peers ready", "wallet", s.walletID, "have", len(pairs), "need", total) + logger.Debug("[READY] Waiting for peers ready", "wallet", s.walletID, "have", len(pairs), "need", total) } } } @@ -210,7 +213,9 @@ func (s *session) Send(msg tss.Message) { // Listen is a wrapper around the party's Listen method // It subscribes to the broadcast and self direct topics -func (s *session) Listen() { +func (s *session) Listen(ctx context.Context) { + go s.startIncomingMessageWorker(ctx) + var wg sync.WaitGroup wg.Add(2) @@ -221,7 +226,13 @@ func (s *session) Listen() { defer wg.Done() sub, err := s.pubSub.Subscribe(broadcastTopic, func(natMsg *nats.Msg) { msg := natMsg.Data - go s.receive(msg) + select { + + case <-ctx.Done(): + return + default: + s.msgBuffer <- msg + } }) if err != nil { @@ -235,7 +246,12 @@ func (s *session) Listen() { direct := func() { defer wg.Done() sub, err := s.direct.Listen(selfDirectTopic, func(msg []byte) { - go s.receive(msg) + select { + case <-ctx.Done(): + return + default: + s.msgBuffer <- msg + } }) if err != nil { @@ -251,6 +267,20 @@ func (s *session) Listen() { wg.Wait() } +func (s *session) startIncomingMessageWorker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-s.msgBuffer: + if !ok { + return + } + s.receive(msg) + } + } +} + // SaveKey saves the key to the keyinfo store and the kvstore func (s *session) SaveKey(participantPeerIDs []string, threshold int, version int, data []byte) (err error) { keyInfo := keyinfo.KeyInfo{ @@ -302,9 +332,8 @@ func (s *session) Close() { s.directSub.Unsubscribe() } - // Close party - if s.party != nil { - s.party.Close() + if s.msgBuffer != nil { + close(s.msgBuffer) } // Close error channel last diff --git a/pkg/tsslimiter/queue.go b/pkg/tsslimiter/queue.go index c170853..7e12fa2 100644 --- a/pkg/tsslimiter/queue.go +++ b/pkg/tsslimiter/queue.go @@ -6,10 +6,13 @@ import ( "github.com/fystack/mpcium/pkg/logger" ) -// SessionJob represents a queued job with type and execution logic +// SessionJob represents a queued job with type, execution logic, and optional error callback +// Run should return an error if execution fails. type SessionJob struct { - Type SessionType - Run func() + Type SessionType + Run func() error + OnError func(error) + Name string } // Queue defines the interface for a job queue that manages TSS session jobs. @@ -61,10 +64,18 @@ func (q *WeightedQueue) run() { // Block until we can acquire budget q.limiter.Acquire(job.Type) + // if !ok { + // logger.Info("Failed to Acquire", "jobType", job.Type, "name", job.Name) + // // Notify via OnError callback if provided + // if job.OnError != nil { + // job.OnError(fmt.Errorf("tsslimiter: failed to acquire budget for job type %v, job %s", job.Type, job.Name)) + // } + // continue + // } // Log limiter state after acquire usedAfter, _ := q.limiter.Stats() - logger.Info("After Acquire", "usedPoints", usedAfter, "jobType", job.Type) + logger.Info("After Acquire", "usedPoints", usedAfter, "jobType", job.Type, "name", job.Name) // Launch job q.wg.Add(1) @@ -73,8 +84,14 @@ func (q *WeightedQueue) run() { defer q.limiter.Release(j.Type) usedExec, _ := q.limiter.Stats() - logger.Info("Executing Job", "usedPoints", usedExec, "jobType", j.Type) - j.Run() + logger.Info("Executing Job", "usedPoints", usedExec, "jobType", j.Type, "name", job.Name) + + err := j.Run() + if err != nil && j.OnError != nil { + // Call the error handler for this job + j.OnError(err) + } + logger.Info("Pending Jobs", "num", len(q.queue)) }(job) diff --git a/pkg/tsslimiter/tsslimiter.go b/pkg/tsslimiter/tsslimiter.go index e868f18..2f34b25 100644 --- a/pkg/tsslimiter/tsslimiter.go +++ b/pkg/tsslimiter/tsslimiter.go @@ -2,7 +2,6 @@ package tsslimiter import ( "sync" - "time" "github.com/fystack/mpcium/pkg/logger" ) @@ -36,13 +35,13 @@ const ( // // Note: These values are conservative to maintain low latency and avoid timeouts. var sessionCosts = map[SessionType]int{ - SessionKeygenECDSA: 100, // Full core + SessionKeygenECDSA: 75, // Full core SessionReshareECDSA: 70, SessionSignECDSA: 40, SessionKeygenEDDSA: 25, // ~25% of core SessionReshareEDDSA: 20, SessionSignEDDSA: 15, - SessionKeygenCombined: 125, // ECDSA (100) + EDDSA (25) + SessionKeygenCombined: 100, // ECDSA (100) + EDDSA (25) } type Limiter interface { @@ -62,13 +61,16 @@ type WeightedLimiter struct { mu sync.Mutex usedPoints int maxPoints int + cond *sync.Cond } // NewWeightedLimiter creates a limiter with maxPoints = maxSessionsAllowed * 100 func NewWeightedLimiter(maxSessions int) *WeightedLimiter { - return &WeightedLimiter{ + l := &WeightedLimiter{ maxPoints: maxSessions * 100, } + l.cond = sync.NewCond(&l.mu) + return l } func (l *WeightedLimiter) TryAcquire(t SessionType) bool { @@ -89,16 +91,14 @@ func (l *WeightedLimiter) TryAcquire(t SessionType) bool { func (l *WeightedLimiter) Acquire(t SessionType) { cost := sessionCosts[t] - for { - l.mu.Lock() - if l.usedPoints+cost <= l.maxPoints { - l.usedPoints += cost - l.mu.Unlock() - return - } - l.mu.Unlock() - time.Sleep(50 * time.Millisecond) // backoff + l.mu.Lock() + defer l.mu.Unlock() + + for l.usedPoints+cost > l.maxPoints { + l.cond.Wait() } + + l.usedPoints += cost } func (l *WeightedLimiter) Release(t SessionType) { @@ -110,7 +110,9 @@ func (l *WeightedLimiter) Release(t SessionType) { if l.usedPoints < 0 { l.usedPoints = 0 } + logger.Info("Release", "sessionType", t, "usedPoints", l.usedPoints, "maxPoints", l.maxPoints) + l.cond.Broadcast() // Wake up waiting goroutines } func (l *WeightedLimiter) Stats() (int, int) {