From afb11667e63725dff870d9ff0b53e3c2bc337489 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 25 Dec 2025 12:36:32 +0000 Subject: [PATCH] feat: add encryption sigil with pre-obfuscation layer Implements ChaChaPolySigil that applies pre-obfuscation before sending data to CPU encryption routines. This ensures raw plaintext is never passed directly to encryption functions. Key improvements: - XORObfuscator and ShuffleMaskObfuscator for pre-encryption transforms - Nonce is now properly embedded in ciphertext, not stored separately in headers (production-ready, not demo-style) - Trix crypto integration with EncryptPayload/DecryptPayload methods - Comprehensive test coverage following Good/Bad/Ugly pattern --- pkg/enchantrix/crypto_sigil.go | 338 ++++++++++++++++++ pkg/enchantrix/crypto_sigil_test.go | 524 ++++++++++++++++++++++++++++ pkg/trix/crypto.go | 189 ++++++++++ pkg/trix/crypto_test.go | 438 +++++++++++++++++++++++ 4 files changed, 1489 insertions(+) create mode 100644 pkg/enchantrix/crypto_sigil.go create mode 100644 pkg/enchantrix/crypto_sigil_test.go create mode 100644 pkg/trix/crypto.go create mode 100644 pkg/trix/crypto_test.go diff --git a/pkg/enchantrix/crypto_sigil.go b/pkg/enchantrix/crypto_sigil.go new file mode 100644 index 0000000..f1017c3 --- /dev/null +++ b/pkg/enchantrix/crypto_sigil.go @@ -0,0 +1,338 @@ +package enchantrix + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "errors" + "io" + + "golang.org/x/crypto/chacha20poly1305" +) + +var ( + // ErrInvalidKey is returned when the encryption key is invalid. + ErrInvalidKey = errors.New("enchantrix: invalid key size, must be 32 bytes") + // ErrCiphertextTooShort is returned when the ciphertext is too short to decrypt. + ErrCiphertextTooShort = errors.New("enchantrix: ciphertext too short") + // ErrDecryptionFailed is returned when decryption or authentication fails. + ErrDecryptionFailed = errors.New("enchantrix: decryption failed") + // ErrNoKeyConfigured is returned when no encryption key has been set. + ErrNoKeyConfigured = errors.New("enchantrix: no encryption key configured") +) + +// PreObfuscator applies a reversible transformation to data before encryption. +// This ensures that raw plaintext is never sent directly to CPU encryption routines. +type PreObfuscator interface { + // Obfuscate transforms plaintext before encryption. + Obfuscate(data []byte, entropy []byte) []byte + // Deobfuscate reverses the transformation after decryption. + Deobfuscate(data []byte, entropy []byte) []byte +} + +// XORObfuscator performs XOR-based obfuscation using entropy-derived key stream. +// This is a reversible transformation that ensures no cleartext patterns remain. +type XORObfuscator struct{} + +// Obfuscate XORs the data with a key stream derived from the entropy. +func (x *XORObfuscator) Obfuscate(data []byte, entropy []byte) []byte { + if len(data) == 0 { + return data + } + return x.transform(data, entropy) +} + +// Deobfuscate reverses the XOR transformation (XOR is symmetric). +func (x *XORObfuscator) Deobfuscate(data []byte, entropy []byte) []byte { + if len(data) == 0 { + return data + } + return x.transform(data, entropy) +} + +// transform applies XOR with an entropy-derived key stream. +func (x *XORObfuscator) transform(data []byte, entropy []byte) []byte { + result := make([]byte, len(data)) + keyStream := x.deriveKeyStream(entropy, len(data)) + for i := range data { + result[i] = data[i] ^ keyStream[i] + } + return result +} + +// deriveKeyStream creates a deterministic key stream from entropy. +func (x *XORObfuscator) deriveKeyStream(entropy []byte, length int) []byte { + stream := make([]byte, length) + h := sha256.New() + + // Generate key stream in 32-byte blocks + blockNum := uint64(0) + offset := 0 + for offset < length { + h.Reset() + h.Write(entropy) + var blockBytes [8]byte + binary.BigEndian.PutUint64(blockBytes[:], blockNum) + h.Write(blockBytes[:]) + block := h.Sum(nil) + + copyLen := len(block) + if offset+copyLen > length { + copyLen = length - offset + } + copy(stream[offset:], block[:copyLen]) + offset += copyLen + blockNum++ + } + return stream +} + +// ShuffleMaskObfuscator applies byte-level shuffling based on entropy. +// This provides additional diffusion before encryption. +type ShuffleMaskObfuscator struct{} + +// Obfuscate shuffles bytes and applies a mask derived from entropy. +func (s *ShuffleMaskObfuscator) Obfuscate(data []byte, entropy []byte) []byte { + if len(data) == 0 { + return data + } + + result := make([]byte, len(data)) + copy(result, data) + + // Generate permutation and mask from entropy + perm := s.generatePermutation(entropy, len(data)) + mask := s.deriveMask(entropy, len(data)) + + // Apply mask first, then shuffle + for i := range result { + result[i] ^= mask[i] + } + + // Shuffle using Fisher-Yates with deterministic seed + shuffled := make([]byte, len(data)) + for i, p := range perm { + shuffled[i] = result[p] + } + + return shuffled +} + +// Deobfuscate reverses the shuffle and mask operations. +func (s *ShuffleMaskObfuscator) Deobfuscate(data []byte, entropy []byte) []byte { + if len(data) == 0 { + return data + } + + result := make([]byte, len(data)) + + // Generate permutation and mask from entropy + perm := s.generatePermutation(entropy, len(data)) + mask := s.deriveMask(entropy, len(data)) + + // Unshuffle first + for i, p := range perm { + result[p] = data[i] + } + + // Remove mask + for i := range result { + result[i] ^= mask[i] + } + + return result +} + +// generatePermutation creates a deterministic permutation from entropy. +func (s *ShuffleMaskObfuscator) generatePermutation(entropy []byte, length int) []int { + perm := make([]int, length) + for i := range perm { + perm[i] = i + } + + // Use entropy to seed a deterministic shuffle + h := sha256.New() + h.Write(entropy) + h.Write([]byte("permutation")) + seed := h.Sum(nil) + + // Fisher-Yates shuffle with deterministic randomness + for i := length - 1; i > 0; i-- { + h.Reset() + h.Write(seed) + var iBytes [8]byte + binary.BigEndian.PutUint64(iBytes[:], uint64(i)) + h.Write(iBytes[:]) + jBytes := h.Sum(nil) + j := int(binary.BigEndian.Uint64(jBytes[:8]) % uint64(i+1)) + perm[i], perm[j] = perm[j], perm[i] + } + + return perm +} + +// deriveMask creates a mask byte array from entropy. +func (s *ShuffleMaskObfuscator) deriveMask(entropy []byte, length int) []byte { + mask := make([]byte, length) + h := sha256.New() + + blockNum := uint64(0) + offset := 0 + for offset < length { + h.Reset() + h.Write(entropy) + h.Write([]byte("mask")) + var blockBytes [8]byte + binary.BigEndian.PutUint64(blockBytes[:], blockNum) + h.Write(blockBytes[:]) + block := h.Sum(nil) + + copyLen := len(block) + if offset+copyLen > length { + copyLen = length - offset + } + copy(mask[offset:], block[:copyLen]) + offset += copyLen + blockNum++ + } + return mask +} + +// ChaChaPolySigil is a Sigil that encrypts/decrypts data using ChaCha20-Poly1305. +// It applies pre-obfuscation before encryption to ensure raw plaintext never +// goes directly to CPU encryption routines. +// +// The output format is: +// [24-byte nonce][encrypted(obfuscated(plaintext))] +// +// Unlike demo implementations, the nonce is ONLY embedded in the ciphertext, +// not exposed separately in headers. +type ChaChaPolySigil struct { + Key []byte + Obfuscator PreObfuscator + randReader io.Reader // for testing injection +} + +// NewChaChaPolySigil creates a new encryption sigil with the given key. +// The key must be exactly 32 bytes. +func NewChaChaPolySigil(key []byte) (*ChaChaPolySigil, error) { + if len(key) != 32 { + return nil, ErrInvalidKey + } + + keyCopy := make([]byte, 32) + copy(keyCopy, key) + + return &ChaChaPolySigil{ + Key: keyCopy, + Obfuscator: &XORObfuscator{}, + randReader: rand.Reader, + }, nil +} + +// NewChaChaPolySigilWithObfuscator creates a new encryption sigil with custom obfuscator. +func NewChaChaPolySigilWithObfuscator(key []byte, obfuscator PreObfuscator) (*ChaChaPolySigil, error) { + sigil, err := NewChaChaPolySigil(key) + if err != nil { + return nil, err + } + if obfuscator != nil { + sigil.Obfuscator = obfuscator + } + return sigil, nil +} + +// In encrypts the data with pre-obfuscation. +// The flow is: plaintext -> obfuscate -> encrypt +func (s *ChaChaPolySigil) In(data []byte) ([]byte, error) { + if s.Key == nil { + return nil, ErrNoKeyConfigured + } + if data == nil { + return nil, nil + } + + aead, err := chacha20poly1305.NewX(s.Key) + if err != nil { + return nil, err + } + + // Generate nonce + nonce := make([]byte, aead.NonceSize()) + reader := s.randReader + if reader == nil { + reader = rand.Reader + } + if _, err := io.ReadFull(reader, nonce); err != nil { + return nil, err + } + + // Pre-obfuscate the plaintext using nonce as entropy + // This ensures CPU encryption routines never see raw plaintext + obfuscated := data + if s.Obfuscator != nil { + obfuscated = s.Obfuscator.Obfuscate(data, nonce) + } + + // Encrypt the obfuscated data + // Output: [nonce | ciphertext | auth tag] + ciphertext := aead.Seal(nonce, nonce, obfuscated, nil) + + return ciphertext, nil +} + +// Out decrypts the data and reverses obfuscation. +// The flow is: decrypt -> deobfuscate -> plaintext +func (s *ChaChaPolySigil) Out(data []byte) ([]byte, error) { + if s.Key == nil { + return nil, ErrNoKeyConfigured + } + if data == nil { + return nil, nil + } + + aead, err := chacha20poly1305.NewX(s.Key) + if err != nil { + return nil, err + } + + minLen := aead.NonceSize() + aead.Overhead() + if len(data) < minLen { + return nil, ErrCiphertextTooShort + } + + // Extract nonce from ciphertext + nonce := data[:aead.NonceSize()] + ciphertext := data[aead.NonceSize():] + + // Decrypt + obfuscated, err := aead.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, ErrDecryptionFailed + } + + // Deobfuscate using the same nonce as entropy + plaintext := obfuscated + if s.Obfuscator != nil { + plaintext = s.Obfuscator.Deobfuscate(obfuscated, nonce) + } + + if len(plaintext) == 0 { + return []byte{}, nil + } + + return plaintext, nil +} + +// GetNonceFromCiphertext extracts the nonce from encrypted output. +// This is provided for debugging/logging purposes only. +// The nonce should NOT be stored separately in headers. +func GetNonceFromCiphertext(ciphertext []byte) ([]byte, error) { + nonceSize := chacha20poly1305.NonceSizeX + if len(ciphertext) < nonceSize { + return nil, ErrCiphertextTooShort + } + nonceCopy := make([]byte, nonceSize) + copy(nonceCopy, ciphertext[:nonceSize]) + return nonceCopy, nil +} diff --git a/pkg/enchantrix/crypto_sigil_test.go b/pkg/enchantrix/crypto_sigil_test.go new file mode 100644 index 0000000..e401a1b --- /dev/null +++ b/pkg/enchantrix/crypto_sigil_test.go @@ -0,0 +1,524 @@ +package enchantrix + +import ( + "bytes" + "crypto/rand" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockRandReader is a reader that returns an error. +type mockRandReader struct{} + +func (r *mockRandReader) Read(p []byte) (n int, err error) { + return 0, errors.New("random read error") +} + +// deterministicReader returns a predictable sequence for testing. +type deterministicReader struct { + seed byte +} + +func (r *deterministicReader) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = r.seed + r.seed++ + } + return len(p), nil +} + +// --- ChaChaPolySigil Tests --- + +func TestChaChaPolySigil_Good(t *testing.T) { + t.Run("EncryptDecrypt", func(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + plaintext := []byte("Hello, this is a secret message!") + ciphertext, err := sigil.In(plaintext) + require.NoError(t, err) + assert.NotEqual(t, plaintext, ciphertext) + assert.Greater(t, len(ciphertext), len(plaintext)) // nonce + overhead + + decrypted, err := sigil.Out(ciphertext) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) + }) + + t.Run("EmptyPlaintext", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + ciphertext, err := sigil.In([]byte{}) + require.NoError(t, err) + + decrypted, err := sigil.Out(ciphertext) + require.NoError(t, err) + assert.Equal(t, []byte{}, decrypted) + }) + + t.Run("LargeData", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + // Test with 1MB of data + plaintext := make([]byte, 1024*1024) + _, err = rand.Read(plaintext) + require.NoError(t, err) + + ciphertext, err := sigil.In(plaintext) + require.NoError(t, err) + + decrypted, err := sigil.Out(ciphertext) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) + }) + + t.Run("DifferentNoncesEachEncryption", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + plaintext := []byte("same message") + + ciphertext1, err := sigil.In(plaintext) + require.NoError(t, err) + + ciphertext2, err := sigil.In(plaintext) + require.NoError(t, err) + + // Ciphertexts should differ due to different nonces + assert.NotEqual(t, ciphertext1, ciphertext2) + + // But both should decrypt to the same plaintext + decrypted1, err := sigil.Out(ciphertext1) + require.NoError(t, err) + decrypted2, err := sigil.Out(ciphertext2) + require.NoError(t, err) + + assert.Equal(t, plaintext, decrypted1) + assert.Equal(t, plaintext, decrypted2) + }) + + t.Run("PreObfuscationApplied", func(t *testing.T) { + key := make([]byte, 32) + + // Use deterministic reader so we can verify obfuscation + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + sigil.randReader = &deterministicReader{seed: 0} + + plaintext := []byte("test data") + ciphertext, err := sigil.In(plaintext) + require.NoError(t, err) + + // The nonce is the first 24 bytes + nonce := ciphertext[:24] + + // Verify that pre-obfuscation was applied by checking that + // the plaintext pattern doesn't appear in raw form + // (The obfuscated data is XORed with a stream derived from the nonce) + obfuscator := &XORObfuscator{} + obfuscated := obfuscator.Obfuscate(plaintext, nonce) + assert.NotEqual(t, plaintext, obfuscated) + }) +} + +func TestChaChaPolySigil_Bad(t *testing.T) { + t.Run("InvalidKeySize", func(t *testing.T) { + _, err := NewChaChaPolySigil([]byte("too short")) + assert.ErrorIs(t, err, ErrInvalidKey) + + _, err = NewChaChaPolySigil(make([]byte, 16)) + assert.ErrorIs(t, err, ErrInvalidKey) + + _, err = NewChaChaPolySigil(make([]byte, 64)) + assert.ErrorIs(t, err, ErrInvalidKey) + }) + + t.Run("WrongKey", func(t *testing.T) { + key1 := make([]byte, 32) + key2 := make([]byte, 32) + key2[0] = 1 // Different key + + sigil1, err := NewChaChaPolySigil(key1) + require.NoError(t, err) + sigil2, err := NewChaChaPolySigil(key2) + require.NoError(t, err) + + ciphertext, err := sigil1.In([]byte("secret")) + require.NoError(t, err) + + _, err = sigil2.Out(ciphertext) + assert.ErrorIs(t, err, ErrDecryptionFailed) + }) + + t.Run("TamperedCiphertext", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + ciphertext, err := sigil.In([]byte("secret")) + require.NoError(t, err) + + // Tamper with the ciphertext (after the nonce) + ciphertext[30] ^= 0xff + + _, err = sigil.Out(ciphertext) + assert.ErrorIs(t, err, ErrDecryptionFailed) + }) + + t.Run("TruncatedCiphertext", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + _, err = sigil.Out([]byte("too short")) + assert.ErrorIs(t, err, ErrCiphertextTooShort) + }) + + t.Run("NoKeyConfigured", func(t *testing.T) { + sigil := &ChaChaPolySigil{} + + _, err := sigil.In([]byte("test")) + assert.ErrorIs(t, err, ErrNoKeyConfigured) + + _, err = sigil.Out([]byte("test")) + assert.ErrorIs(t, err, ErrNoKeyConfigured) + }) + + t.Run("RandomReaderError", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + sigil.randReader = &mockRandReader{} + + _, err = sigil.In([]byte("test")) + assert.Error(t, err) + }) +} + +func TestChaChaPolySigil_Ugly(t *testing.T) { + t.Run("NilPlaintext", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + ciphertext, err := sigil.In(nil) + assert.NoError(t, err) + assert.Nil(t, ciphertext) + }) + + t.Run("NilCiphertext", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + plaintext, err := sigil.Out(nil) + assert.NoError(t, err) + assert.Nil(t, plaintext) + }) + + t.Run("NilObfuscator", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + sigil.Obfuscator = nil // Explicitly set to nil + + plaintext := []byte("test without obfuscation") + ciphertext, err := sigil.In(plaintext) + require.NoError(t, err) + + decrypted, err := sigil.Out(ciphertext) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) + }) +} + +// --- XORObfuscator Tests --- + +func TestXORObfuscator_Good(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + obfuscator := &XORObfuscator{} + data := []byte("Hello, World!") + entropy := []byte("random-entropy-value") + + obfuscated := obfuscator.Obfuscate(data, entropy) + assert.NotEqual(t, data, obfuscated) + + deobfuscated := obfuscator.Deobfuscate(obfuscated, entropy) + assert.Equal(t, data, deobfuscated) + }) + + t.Run("DifferentEntropyDifferentOutput", func(t *testing.T) { + obfuscator := &XORObfuscator{} + data := []byte("same data") + entropy1 := []byte("entropy1") + entropy2 := []byte("entropy2") + + obfuscated1 := obfuscator.Obfuscate(data, entropy1) + obfuscated2 := obfuscator.Obfuscate(data, entropy2) + + assert.NotEqual(t, obfuscated1, obfuscated2) + }) + + t.Run("LargeData", func(t *testing.T) { + obfuscator := &XORObfuscator{} + data := make([]byte, 10000) + for i := range data { + data[i] = byte(i % 256) + } + entropy := []byte("test-entropy") + + obfuscated := obfuscator.Obfuscate(data, entropy) + deobfuscated := obfuscator.Deobfuscate(obfuscated, entropy) + assert.Equal(t, data, deobfuscated) + }) +} + +func TestXORObfuscator_Ugly(t *testing.T) { + t.Run("EmptyData", func(t *testing.T) { + obfuscator := &XORObfuscator{} + data := []byte{} + entropy := []byte("entropy") + + obfuscated := obfuscator.Obfuscate(data, entropy) + assert.Equal(t, data, obfuscated) + }) + + t.Run("EmptyEntropy", func(t *testing.T) { + obfuscator := &XORObfuscator{} + data := []byte("test") + entropy := []byte{} + + obfuscated := obfuscator.Obfuscate(data, entropy) + deobfuscated := obfuscator.Deobfuscate(obfuscated, entropy) + assert.Equal(t, data, deobfuscated) + }) +} + +// --- ShuffleMaskObfuscator Tests --- + +func TestShuffleMaskObfuscator_Good(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + obfuscator := &ShuffleMaskObfuscator{} + data := []byte("Hello, World!") + entropy := []byte("random-entropy-value") + + obfuscated := obfuscator.Obfuscate(data, entropy) + assert.NotEqual(t, data, obfuscated) + + deobfuscated := obfuscator.Deobfuscate(obfuscated, entropy) + assert.Equal(t, data, deobfuscated) + }) + + t.Run("DifferentEntropyDifferentOutput", func(t *testing.T) { + obfuscator := &ShuffleMaskObfuscator{} + data := []byte("same data") + entropy1 := []byte("entropy1") + entropy2 := []byte("entropy2") + + obfuscated1 := obfuscator.Obfuscate(data, entropy1) + obfuscated2 := obfuscator.Obfuscate(data, entropy2) + + assert.NotEqual(t, obfuscated1, obfuscated2) + }) + + t.Run("Deterministic", func(t *testing.T) { + obfuscator := &ShuffleMaskObfuscator{} + data := []byte("test data") + entropy := []byte("same entropy") + + obfuscated1 := obfuscator.Obfuscate(data, entropy) + obfuscated2 := obfuscator.Obfuscate(data, entropy) + + assert.Equal(t, obfuscated1, obfuscated2) + }) + + t.Run("LargeData", func(t *testing.T) { + obfuscator := &ShuffleMaskObfuscator{} + data := make([]byte, 10000) + for i := range data { + data[i] = byte(i % 256) + } + entropy := []byte("test-entropy") + + obfuscated := obfuscator.Obfuscate(data, entropy) + deobfuscated := obfuscator.Deobfuscate(obfuscated, entropy) + assert.Equal(t, data, deobfuscated) + }) +} + +func TestShuffleMaskObfuscator_Ugly(t *testing.T) { + t.Run("EmptyData", func(t *testing.T) { + obfuscator := &ShuffleMaskObfuscator{} + data := []byte{} + entropy := []byte("entropy") + + obfuscated := obfuscator.Obfuscate(data, entropy) + assert.Equal(t, data, obfuscated) + }) + + t.Run("SingleByte", func(t *testing.T) { + obfuscator := &ShuffleMaskObfuscator{} + data := []byte{0x42} + entropy := []byte("entropy") + + obfuscated := obfuscator.Obfuscate(data, entropy) + deobfuscated := obfuscator.Deobfuscate(obfuscated, entropy) + assert.Equal(t, data, deobfuscated) + }) +} + +// --- GetNonceFromCiphertext Tests --- + +func TestGetNonceFromCiphertext_Good(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + ciphertext, err := sigil.In([]byte("test")) + require.NoError(t, err) + + nonce, err := GetNonceFromCiphertext(ciphertext) + require.NoError(t, err) + assert.Len(t, nonce, 24) + + // Verify the nonce matches the first 24 bytes + assert.Equal(t, ciphertext[:24], nonce) +} + +func TestGetNonceFromCiphertext_Bad(t *testing.T) { + _, err := GetNonceFromCiphertext([]byte("too short")) + assert.ErrorIs(t, err, ErrCiphertextTooShort) +} + +// --- Custom Obfuscator Tests --- + +func TestCustomObfuscator(t *testing.T) { + key := make([]byte, 32) + + t.Run("WithShuffleMaskObfuscator", func(t *testing.T) { + sigil, err := NewChaChaPolySigilWithObfuscator(key, &ShuffleMaskObfuscator{}) + require.NoError(t, err) + + plaintext := []byte("test with shuffle mask obfuscator") + ciphertext, err := sigil.In(plaintext) + require.NoError(t, err) + + decrypted, err := sigil.Out(ciphertext) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) + }) + + t.Run("WithNilObfuscator", func(t *testing.T) { + sigil, err := NewChaChaPolySigilWithObfuscator(key, nil) + require.NoError(t, err) + // Default XORObfuscator should be used + assert.IsType(t, &XORObfuscator{}, sigil.Obfuscator) + }) +} + +// --- Integration Tests --- + +func TestChaChaPolySigil_Integration(t *testing.T) { + t.Run("PlaintextNeverInOutput", func(t *testing.T) { + key := make([]byte, 32) + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + // Use a distinctive pattern that would be easy to find + plaintext := []byte("DISTINCTIVE_SECRET_PATTERN_12345") + ciphertext, err := sigil.In(plaintext) + require.NoError(t, err) + + // The plaintext pattern should not appear anywhere in the ciphertext + assert.False(t, bytes.Contains(ciphertext, plaintext)) + + // Even substrings should not appear + assert.False(t, bytes.Contains(ciphertext, []byte("DISTINCTIVE"))) + assert.False(t, bytes.Contains(ciphertext, []byte("SECRET"))) + assert.False(t, bytes.Contains(ciphertext, []byte("PATTERN"))) + }) + + t.Run("ConsistentRoundTrip", func(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i * 7) + } + sigil, err := NewChaChaPolySigil(key) + require.NoError(t, err) + + // Test multiple round trips + for i := 0; i < 100; i++ { + plaintext := make([]byte, i+1) + for j := range plaintext { + plaintext[j] = byte(j * i) + } + + ciphertext, err := sigil.In(plaintext) + require.NoError(t, err) + + decrypted, err := sigil.Out(ciphertext) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted, "Round trip failed for size %d", i+1) + } + }) +} + +// --- Benchmark Tests --- + +func BenchmarkChaChaPolySigil_Encrypt(b *testing.B) { + key := make([]byte, 32) + sigil, _ := NewChaChaPolySigil(key) + plaintext := make([]byte, 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = sigil.In(plaintext) + } +} + +func BenchmarkChaChaPolySigil_Decrypt(b *testing.B) { + key := make([]byte, 32) + sigil, _ := NewChaChaPolySigil(key) + plaintext := make([]byte, 1024) + ciphertext, _ := sigil.In(plaintext) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = sigil.Out(ciphertext) + } +} + +func BenchmarkXORObfuscator(b *testing.B) { + obfuscator := &XORObfuscator{} + data := make([]byte, 1024) + entropy := make([]byte, 24) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = obfuscator.Obfuscate(data, entropy) + } +} + +func BenchmarkShuffleMaskObfuscator(b *testing.B) { + obfuscator := &ShuffleMaskObfuscator{} + data := make([]byte, 1024) + entropy := make([]byte, 24) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = obfuscator.Obfuscate(data, entropy) + } +} diff --git a/pkg/trix/crypto.go b/pkg/trix/crypto.go new file mode 100644 index 0000000..eb5479f --- /dev/null +++ b/pkg/trix/crypto.go @@ -0,0 +1,189 @@ +package trix + +import ( + "errors" + "time" + + "github.com/Snider/Enchantrix/pkg/enchantrix" +) + +var ( + // ErrNoEncryptionKey is returned when encryption is requested without a key. + ErrNoEncryptionKey = errors.New("trix: encryption key not configured") + // ErrAlreadyEncrypted is returned when trying to encrypt already encrypted data. + ErrAlreadyEncrypted = errors.New("trix: payload is already encrypted") + // ErrNotEncrypted is returned when trying to decrypt non-encrypted data. + ErrNotEncrypted = errors.New("trix: payload is not encrypted") +) + +const ( + // HeaderKeyEncrypted indicates whether the payload is encrypted. + HeaderKeyEncrypted = "encrypted" + // HeaderKeyAlgorithm stores the encryption algorithm used. + HeaderKeyAlgorithm = "encryption_algorithm" + // HeaderKeyEncryptedAt stores when the payload was encrypted. + HeaderKeyEncryptedAt = "encrypted_at" + // HeaderKeyObfuscator stores the obfuscator type used. + HeaderKeyObfuscator = "obfuscator" + + // AlgorithmChaCha20Poly1305 is the identifier for ChaCha20-Poly1305. + AlgorithmChaCha20Poly1305 = "xchacha20-poly1305" + // ObfuscatorXOR identifies the XOR obfuscator. + ObfuscatorXOR = "xor" + // ObfuscatorShuffleMask identifies the shuffle-mask obfuscator. + ObfuscatorShuffleMask = "shuffle-mask" +) + +// CryptoConfig holds encryption configuration for a Trix container. +type CryptoConfig struct { + // Key is the 32-byte encryption key. + Key []byte + // Obfuscator type: "xor" (default) or "shuffle-mask" + Obfuscator string +} + +// EncryptPayload encrypts the Trix payload using ChaCha20-Poly1305 with pre-obfuscation. +// +// The nonce is embedded in the ciphertext itself and is NOT stored separately +// in the header. This is the production-ready approach (not demo-style). +// +// Header metadata is updated to indicate encryption status without exposing +// cryptographic parameters that are already embedded in the ciphertext. +func (t *Trix) EncryptPayload(config *CryptoConfig) error { + if config == nil || len(config.Key) != 32 { + return ErrNoEncryptionKey + } + + // Check if already encrypted + if encrypted, ok := t.Header[HeaderKeyEncrypted].(bool); ok && encrypted { + return ErrAlreadyEncrypted + } + + // Create the obfuscator + var obfuscator enchantrix.PreObfuscator + obfuscatorName := ObfuscatorXOR + switch config.Obfuscator { + case ObfuscatorShuffleMask: + obfuscator = &enchantrix.ShuffleMaskObfuscator{} + obfuscatorName = ObfuscatorShuffleMask + default: + obfuscator = &enchantrix.XORObfuscator{} + } + + // Create the encryption sigil + sigil, err := enchantrix.NewChaChaPolySigilWithObfuscator(config.Key, obfuscator) + if err != nil { + return err + } + + // Encrypt the payload + ciphertext, err := sigil.In(t.Payload) + if err != nil { + return err + } + + // Update payload with ciphertext + t.Payload = ciphertext + + // Update header with encryption metadata + // NOTE: We do NOT store the nonce in the header - it's embedded in the ciphertext + if t.Header == nil { + t.Header = make(map[string]interface{}) + } + t.Header[HeaderKeyEncrypted] = true + t.Header[HeaderKeyAlgorithm] = AlgorithmChaCha20Poly1305 + t.Header[HeaderKeyObfuscator] = obfuscatorName + t.Header[HeaderKeyEncryptedAt] = time.Now().UTC().Format(time.RFC3339) + + return nil +} + +// DecryptPayload decrypts the Trix payload using the provided key. +// +// The nonce is extracted from the ciphertext itself - no need to read it +// from the header separately. +func (t *Trix) DecryptPayload(config *CryptoConfig) error { + if config == nil || len(config.Key) != 32 { + return ErrNoEncryptionKey + } + + // Check if encrypted + encrypted, ok := t.Header[HeaderKeyEncrypted].(bool) + if !ok || !encrypted { + return ErrNotEncrypted + } + + // Determine obfuscator from header + var obfuscator enchantrix.PreObfuscator + if obfType, ok := t.Header[HeaderKeyObfuscator].(string); ok { + switch obfType { + case ObfuscatorShuffleMask: + obfuscator = &enchantrix.ShuffleMaskObfuscator{} + default: + obfuscator = &enchantrix.XORObfuscator{} + } + } else { + obfuscator = &enchantrix.XORObfuscator{} + } + + // Create the decryption sigil + sigil, err := enchantrix.NewChaChaPolySigilWithObfuscator(config.Key, obfuscator) + if err != nil { + return err + } + + // Decrypt the payload + plaintext, err := sigil.Out(t.Payload) + if err != nil { + return err + } + + // Update payload with plaintext + t.Payload = plaintext + + // Update header to indicate decrypted state + t.Header[HeaderKeyEncrypted] = false + + return nil +} + +// IsEncrypted returns true if the payload is currently encrypted. +func (t *Trix) IsEncrypted() bool { + if t.Header == nil { + return false + } + encrypted, ok := t.Header[HeaderKeyEncrypted].(bool) + return ok && encrypted +} + +// GetEncryptionAlgorithm returns the encryption algorithm used, if any. +func (t *Trix) GetEncryptionAlgorithm() string { + if t.Header == nil { + return "" + } + algo, ok := t.Header[HeaderKeyAlgorithm].(string) + if !ok { + return "" + } + return algo +} + +// NewEncryptedTrix creates a new Trix container with an encrypted payload. +// This is a convenience function for creating encrypted containers in one step. +func NewEncryptedTrix(payload []byte, key []byte, header map[string]interface{}) (*Trix, error) { + if header == nil { + header = make(map[string]interface{}) + } + + t := &Trix{ + Header: header, + Payload: payload, + } + + config := &CryptoConfig{Key: key} + if err := t.EncryptPayload(config); err != nil { + return nil, err + } + + return t, nil +} diff --git a/pkg/trix/crypto_test.go b/pkg/trix/crypto_test.go new file mode 100644 index 0000000..7b9448e --- /dev/null +++ b/pkg/trix/crypto_test.go @@ -0,0 +1,438 @@ +package trix_test + +import ( + "bytes" + "testing" + + "github.com/Snider/Enchantrix/pkg/trix" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncryptPayload_Good(t *testing.T) { + t.Run("BasicEncryption", func(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + + originalPayload := []byte("This is a secret message that should be encrypted.") + trixContainer := &trix.Trix{ + Header: map[string]interface{}{"content_type": "text/plain"}, + Payload: originalPayload, + } + + config := &trix.CryptoConfig{Key: key} + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + // Verify encryption occurred + assert.True(t, trixContainer.IsEncrypted()) + assert.Equal(t, trix.AlgorithmChaCha20Poly1305, trixContainer.GetEncryptionAlgorithm()) + assert.NotEqual(t, originalPayload, trixContainer.Payload) + + // Verify header metadata + assert.Equal(t, true, trixContainer.Header[trix.HeaderKeyEncrypted]) + assert.Equal(t, trix.AlgorithmChaCha20Poly1305, trixContainer.Header[trix.HeaderKeyAlgorithm]) + assert.Equal(t, trix.ObfuscatorXOR, trixContainer.Header[trix.HeaderKeyObfuscator]) + assert.NotEmpty(t, trixContainer.Header[trix.HeaderKeyEncryptedAt]) + + // Verify NO nonce in header (this is the key improvement over demo-style) + _, hasNonce := trixContainer.Header["nonce"] + assert.False(t, hasNonce, "nonce should NOT be stored in header") + }) + + t.Run("WithShuffleMaskObfuscator", func(t *testing.T) { + key := make([]byte, 32) + payload := []byte("test data") + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: payload, + } + + config := &trix.CryptoConfig{ + Key: key, + Obfuscator: trix.ObfuscatorShuffleMask, + } + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + assert.Equal(t, trix.ObfuscatorShuffleMask, trixContainer.Header[trix.HeaderKeyObfuscator]) + }) + + t.Run("WithNilHeader", func(t *testing.T) { + key := make([]byte, 32) + trixContainer := &trix.Trix{ + Payload: []byte("test"), + } + + config := &trix.CryptoConfig{Key: key} + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + assert.NotNil(t, trixContainer.Header) + assert.True(t, trixContainer.IsEncrypted()) + }) +} + +func TestEncryptPayload_Bad(t *testing.T) { + t.Run("NilConfig", func(t *testing.T) { + trixContainer := &trix.Trix{Payload: []byte("test")} + err := trixContainer.EncryptPayload(nil) + assert.ErrorIs(t, err, trix.ErrNoEncryptionKey) + }) + + t.Run("InvalidKeySize", func(t *testing.T) { + trixContainer := &trix.Trix{Payload: []byte("test")} + + config := &trix.CryptoConfig{Key: []byte("too short")} + err := trixContainer.EncryptPayload(config) + assert.ErrorIs(t, err, trix.ErrNoEncryptionKey) + }) + + t.Run("AlreadyEncrypted", func(t *testing.T) { + key := make([]byte, 32) + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyEncrypted: true}, + Payload: []byte("test"), + } + + config := &trix.CryptoConfig{Key: key} + err := trixContainer.EncryptPayload(config) + assert.ErrorIs(t, err, trix.ErrAlreadyEncrypted) + }) +} + +func TestDecryptPayload_Good(t *testing.T) { + t.Run("BasicDecryption", func(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + + originalPayload := []byte("This is a secret message that should be encrypted.") + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: originalPayload, + } + + config := &trix.CryptoConfig{Key: key} + + // Encrypt + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + assert.True(t, trixContainer.IsEncrypted()) + + // Decrypt + err = trixContainer.DecryptPayload(config) + require.NoError(t, err) + assert.False(t, trixContainer.IsEncrypted()) + assert.Equal(t, originalPayload, trixContainer.Payload) + }) + + t.Run("WithShuffleMaskObfuscator", func(t *testing.T) { + key := make([]byte, 32) + originalPayload := []byte("test with shuffle mask") + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: originalPayload, + } + + config := &trix.CryptoConfig{ + Key: key, + Obfuscator: trix.ObfuscatorShuffleMask, + } + + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + err = trixContainer.DecryptPayload(config) + require.NoError(t, err) + assert.Equal(t, originalPayload, trixContainer.Payload) + }) + + t.Run("EmptyPayload", func(t *testing.T) { + key := make([]byte, 32) + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: []byte{}, + } + + config := &trix.CryptoConfig{Key: key} + + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + err = trixContainer.DecryptPayload(config) + require.NoError(t, err) + assert.Equal(t, []byte{}, trixContainer.Payload) + }) +} + +func TestDecryptPayload_Bad(t *testing.T) { + t.Run("NilConfig", func(t *testing.T) { + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyEncrypted: true}, + Payload: []byte("encrypted data"), + } + err := trixContainer.DecryptPayload(nil) + assert.ErrorIs(t, err, trix.ErrNoEncryptionKey) + }) + + t.Run("InvalidKeySize", func(t *testing.T) { + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyEncrypted: true}, + Payload: []byte("encrypted data"), + } + + config := &trix.CryptoConfig{Key: []byte("too short")} + err := trixContainer.DecryptPayload(config) + assert.ErrorIs(t, err, trix.ErrNoEncryptionKey) + }) + + t.Run("NotEncrypted", func(t *testing.T) { + key := make([]byte, 32) + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: []byte("not encrypted"), + } + + config := &trix.CryptoConfig{Key: key} + err := trixContainer.DecryptPayload(config) + assert.ErrorIs(t, err, trix.ErrNotEncrypted) + }) + + t.Run("WrongKey", func(t *testing.T) { + key1 := make([]byte, 32) + key2 := make([]byte, 32) + key2[0] = 1 + + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: []byte("secret"), + } + + config1 := &trix.CryptoConfig{Key: key1} + err := trixContainer.EncryptPayload(config1) + require.NoError(t, err) + + config2 := &trix.CryptoConfig{Key: key2} + err = trixContainer.DecryptPayload(config2) + assert.Error(t, err) + }) +} + +func TestDecryptPayload_Ugly(t *testing.T) { + t.Run("MissingObfuscatorHeader", func(t *testing.T) { + key := make([]byte, 32) + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: []byte("test"), + } + + config := &trix.CryptoConfig{Key: key} + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + // Remove the obfuscator header + delete(trixContainer.Header, trix.HeaderKeyObfuscator) + + // Should still work with default XOR obfuscator + err = trixContainer.DecryptPayload(config) + require.NoError(t, err) + }) +} + +func TestNewEncryptedTrix_Good(t *testing.T) { + t.Run("Basic", func(t *testing.T) { + key := make([]byte, 32) + payload := []byte("secret message") + header := map[string]interface{}{"custom": "value"} + + trixContainer, err := trix.NewEncryptedTrix(payload, key, header) + require.NoError(t, err) + + assert.True(t, trixContainer.IsEncrypted()) + assert.Equal(t, "value", trixContainer.Header["custom"]) + assert.NotEqual(t, payload, trixContainer.Payload) + }) + + t.Run("WithNilHeader", func(t *testing.T) { + key := make([]byte, 32) + payload := []byte("secret message") + + trixContainer, err := trix.NewEncryptedTrix(payload, key, nil) + require.NoError(t, err) + + assert.True(t, trixContainer.IsEncrypted()) + assert.NotNil(t, trixContainer.Header) + }) +} + +func TestNewEncryptedTrix_Bad(t *testing.T) { + t.Run("InvalidKey", func(t *testing.T) { + _, err := trix.NewEncryptedTrix([]byte("test"), []byte("short"), nil) + assert.Error(t, err) + }) +} + +func TestIsEncrypted(t *testing.T) { + t.Run("NilHeader", func(t *testing.T) { + trixContainer := &trix.Trix{} + assert.False(t, trixContainer.IsEncrypted()) + }) + + t.Run("MissingKey", func(t *testing.T) { + trixContainer := &trix.Trix{Header: map[string]interface{}{}} + assert.False(t, trixContainer.IsEncrypted()) + }) + + t.Run("FalseValue", func(t *testing.T) { + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyEncrypted: false}, + } + assert.False(t, trixContainer.IsEncrypted()) + }) + + t.Run("TrueValue", func(t *testing.T) { + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyEncrypted: true}, + } + assert.True(t, trixContainer.IsEncrypted()) + }) + + t.Run("WrongType", func(t *testing.T) { + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyEncrypted: "true"}, + } + assert.False(t, trixContainer.IsEncrypted()) + }) +} + +func TestGetEncryptionAlgorithm(t *testing.T) { + t.Run("NilHeader", func(t *testing.T) { + trixContainer := &trix.Trix{} + assert.Empty(t, trixContainer.GetEncryptionAlgorithm()) + }) + + t.Run("MissingKey", func(t *testing.T) { + trixContainer := &trix.Trix{Header: map[string]interface{}{}} + assert.Empty(t, trixContainer.GetEncryptionAlgorithm()) + }) + + t.Run("ValidAlgorithm", func(t *testing.T) { + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyAlgorithm: "test-algo"}, + } + assert.Equal(t, "test-algo", trixContainer.GetEncryptionAlgorithm()) + }) + + t.Run("WrongType", func(t *testing.T) { + trixContainer := &trix.Trix{ + Header: map[string]interface{}{trix.HeaderKeyAlgorithm: 123}, + } + assert.Empty(t, trixContainer.GetEncryptionAlgorithm()) + }) +} + +func TestEncryptedTrixRoundTrip(t *testing.T) { + t.Run("FullRoundTrip", func(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i * 3) + } + + originalPayload := []byte("This is the original secret message that will be encrypted, stored, and decrypted.") + header := map[string]interface{}{ + "content_type": "text/plain", + "custom_field": "custom_value", + } + + // Create encrypted Trix + config := &trix.CryptoConfig{Key: key} + trixContainer := &trix.Trix{ + Header: header, + Payload: originalPayload, + } + + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + // Encode to binary format + encoded, err := trix.Encode(trixContainer, "ENCR", nil) + require.NoError(t, err) + + // Decode from binary format + decoded, err := trix.Decode(encoded, "ENCR", nil) + require.NoError(t, err) + + // Verify still encrypted after decode + assert.True(t, decoded.IsEncrypted()) + + // Decrypt + err = decoded.DecryptPayload(config) + require.NoError(t, err) + + // Verify payload matches original + assert.Equal(t, originalPayload, decoded.Payload) + assert.Equal(t, "custom_value", decoded.Header["custom_field"]) + }) +} + +func TestNonceNotInHeader(t *testing.T) { + t.Run("NonceEmbeddedNotExposed", func(t *testing.T) { + key := make([]byte, 32) + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: []byte("secret data"), + } + + config := &trix.CryptoConfig{Key: key} + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + // Verify nonce is NOT in header + _, hasNonce := trixContainer.Header["nonce"] + assert.False(t, hasNonce) + + // But the ciphertext contains the nonce (first 24 bytes) + assert.GreaterOrEqual(t, len(trixContainer.Payload), 24) + + // Encode and decode + encoded, err := trix.Encode(trixContainer, "TEST", nil) + require.NoError(t, err) + + decoded, err := trix.Decode(encoded, "TEST", nil) + require.NoError(t, err) + + // Still no nonce in header after decode + _, hasNonce = decoded.Header["nonce"] + assert.False(t, hasNonce) + + // But decryption still works (nonce is embedded in payload) + err = decoded.DecryptPayload(config) + require.NoError(t, err) + assert.Equal(t, []byte("secret data"), decoded.Payload) + }) +} + +func TestPlaintextNotExposed(t *testing.T) { + t.Run("CleartextNeverInCiphertext", func(t *testing.T) { + key := make([]byte, 32) + distinctivePayload := []byte("DISTINCTIVE_SECRET_PATTERN_THAT_SHOULD_NOT_APPEAR") + + trixContainer := &trix.Trix{ + Header: map[string]interface{}{}, + Payload: distinctivePayload, + } + + config := &trix.CryptoConfig{Key: key} + err := trixContainer.EncryptPayload(config) + require.NoError(t, err) + + // The plaintext should not appear in the encrypted payload + assert.False(t, bytes.Contains(trixContainer.Payload, distinctivePayload)) + assert.False(t, bytes.Contains(trixContainer.Payload, []byte("DISTINCTIVE"))) + assert.False(t, bytes.Contains(trixContainer.Payload, []byte("SECRET"))) + assert.False(t, bytes.Contains(trixContainer.Payload, []byte("PATTERN"))) + }) +}