From a1de2b8a986269961226e3997562002f85bc8293 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 22 Jan 2026 13:01:13 +0100 Subject: [PATCH] [management] Move activity store encryption to shared crypt package (#5111) --- go.mod | 2 +- go.sum | 4 +- management/server/activity/store/crypt.go | 136 -------- .../server/activity/store/crypt_test.go | 310 ------------------ management/server/activity/store/migration.go | 9 +- .../server/activity/store/migration_test.go | 5 +- management/server/activity/store/sql_store.go | 9 +- .../server/activity/store/sql_store_test.go | 3 +- util/crypt/crypt_test.go | 139 ++++++++ util/crypt/legacy.go | 71 ++++ util/crypt/legacy_test.go | 164 +++++++++ 11 files changed, 392 insertions(+), 460 deletions(-) delete mode 100644 management/server/activity/store/crypt.go delete mode 100644 management/server/activity/store/crypt_test.go create mode 100644 util/crypt/crypt_test.go create mode 100644 util/crypt/legacy.go create mode 100644 util/crypt/legacy_test.go diff --git a/go.mod b/go.mod index cb16fff52..8ac5613ee 100644 --- a/go.mod +++ b/go.mod @@ -68,7 +68,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 diff --git a/go.sum b/go.sum index c59acbb23..6adc7f7e8 100644 --- a/go.sum +++ b/go.sum @@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/activity/store/crypt.go b/management/server/activity/store/crypt.go deleted file mode 100644 index ce97347d4..000000000 --- a/management/server/activity/store/crypt.go +++ /dev/null @@ -1,136 +0,0 @@ -package store - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/base64" - "errors" -) - -var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} - -type FieldEncrypt struct { - block cipher.Block - gcm cipher.AEAD -} - -func GenerateKey() (string, error) { - key := make([]byte, 32) - _, err := rand.Read(key) - if err != nil { - return "", err - } - readableKey := base64.StdEncoding.EncodeToString(key) - return readableKey, nil -} - -func NewFieldEncrypt(key string) (*FieldEncrypt, error) { - binKey, err := base64.StdEncoding.DecodeString(key) - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(binKey) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - ec := &FieldEncrypt{ - block: block, - gcm: gcm, - } - - return ec, nil -} - -func (ec *FieldEncrypt) LegacyEncrypt(payload string) string { - plainText := pkcs5Padding([]byte(payload)) - cipherText := make([]byte, len(plainText)) - cbc := cipher.NewCBCEncrypter(ec.block, iv) - cbc.CryptBlocks(cipherText, plainText) - return base64.StdEncoding.EncodeToString(cipherText) -} - -// Encrypt encrypts plaintext using AES-GCM -func (ec *FieldEncrypt) Encrypt(payload string) (string, error) { - plaintext := []byte(payload) - nonceSize := ec.gcm.NonceSize() - - nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead()) - if _, err := rand.Read(nonce); err != nil { - return "", err - } - - ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil) - - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) { - cipherText, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return "", err - } - cbc := cipher.NewCBCDecrypter(ec.block, iv) - cbc.CryptBlocks(cipherText, cipherText) - payload, err := pkcs5UnPadding(cipherText) - if err != nil { - return "", err - } - - return string(payload), nil -} - -// Decrypt decrypts ciphertext using AES-GCM -func (ec *FieldEncrypt) Decrypt(data string) (string, error) { - cipherText, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return "", err - } - - nonceSize := ec.gcm.NonceSize() - if len(cipherText) < nonceSize { - return "", errors.New("cipher text too short") - } - - nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:] - plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil) - if err != nil { - return "", err - } - - return string(plainText), nil -} - -func pkcs5Padding(ciphertext []byte) []byte { - padding := aes.BlockSize - len(ciphertext)%aes.BlockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(ciphertext, padText...) -} -func pkcs5UnPadding(src []byte) ([]byte, error) { - srcLen := len(src) - if srcLen == 0 { - return nil, errors.New("input data is empty") - } - - paddingLen := int(src[srcLen-1]) - if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen { - return nil, errors.New("invalid padding size") - } - - // Verify that all padding bytes are the same - for i := 0; i < paddingLen; i++ { - if src[srcLen-1-i] != byte(paddingLen) { - return nil, errors.New("invalid padding") - } - } - - return src[:srcLen-paddingLen], nil -} diff --git a/management/server/activity/store/crypt_test.go b/management/server/activity/store/crypt_test.go deleted file mode 100644 index 700bbcd6b..000000000 --- a/management/server/activity/store/crypt_test.go +++ /dev/null @@ -1,310 +0,0 @@ -package store - -import ( - "bytes" - "testing" -) - -func TestGenerateKey(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted, err := ee.Encrypt(testData) - if err != nil { - t.Fatalf("failed to encrypt data: %s", err) - } - - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - decrypted, err := ee.Decrypt(encrypted) - if err != nil { - t.Fatalf("failed to decrypt data: %s", err) - } - - if decrypted != testData { - t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) - } -} - -func TestGenerateKeyLegacy(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted := ee.LegacyEncrypt(testData) - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - decrypted, err := ee.LegacyDecrypt(encrypted) - if err != nil { - t.Fatalf("failed to decrypt data: %s", err) - } - - if decrypted != testData { - t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) - } -} - -func TestCorruptKey(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted, err := ee.Encrypt(testData) - if err != nil { - t.Fatalf("failed to encrypt data: %s", err) - } - - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - newKey, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - - ee, err = NewFieldEncrypt(newKey) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - res, _ := ee.Decrypt(encrypted) - if res == testData { - t.Fatalf("incorrect decryption, the result is: %s", res) - } -} - -func TestEncryptDecrypt(t *testing.T) { - // Generate a key for encryption/decryption - key, err := GenerateKey() - if err != nil { - t.Fatalf("Failed to generate key: %v", err) - } - - // Initialize the FieldEncrypt with the generated key - ec, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("Failed to create FieldEncrypt: %v", err) - } - - // Test cases - testCases := []struct { - name string - input string - }{ - { - name: "Empty String", - input: "", - }, - { - name: "Short String", - input: "Hello", - }, - { - name: "String with Spaces", - input: "Hello, World!", - }, - { - name: "Long String", - input: "The quick brown fox jumps over the lazy dog.", - }, - { - name: "Unicode Characters", - input: "こんにちは世界", - }, - { - name: "Special Characters", - input: "!@#$%^&*()_+-=[]{}|;':\",./<>?", - }, - { - name: "Numeric String", - input: "1234567890", - }, - { - name: "Repeated Characters", - input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - }, - { - name: "Multi-block String", - input: "This is a longer string that will span multiple blocks in the encryption algorithm.", - }, - { - name: "Non-ASCII and ASCII Mix", - input: "Hello 世界 123", - }, - } - - for _, tc := range testCases { - t.Run(tc.name+" - Legacy", func(t *testing.T) { - // Legacy Encryption - encryptedLegacy := ec.LegacyEncrypt(tc.input) - if encryptedLegacy == "" { - t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input) - } - - // Legacy Decryption - decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy) - if err != nil { - t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err) - } - - // Verify that the decrypted value matches the original input - if decryptedLegacy != tc.input { - t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input) - } - }) - - t.Run(tc.name+" - New", func(t *testing.T) { - // New Encryption - encryptedNew, err := ec.Encrypt(tc.input) - if err != nil { - t.Errorf("Encrypt failed for input '%s': %v", tc.input, err) - } - if encryptedNew == "" { - t.Errorf("Encrypt returned empty string for input '%s'", tc.input) - } - - // New Decryption - decryptedNew, err := ec.Decrypt(encryptedNew) - if err != nil { - t.Errorf("Decrypt failed for input '%s': %v", tc.input, err) - } - - // Verify that the decrypted value matches the original input - if decryptedNew != tc.input { - t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input) - } - }) - } -} - -func TestPKCS5UnPadding(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - expectError bool - }{ - { - name: "Valid Padding", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), - expected: []byte("Hello, World!"), - }, - { - name: "Empty Input", - input: []byte{}, - expectError: true, - }, - { - name: "Padding Length Zero", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), - expectError: true, - }, - { - name: "Padding Length Exceeds Block Size", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), - expectError: true, - }, - { - name: "Padding Length Exceeds Input Length", - input: []byte{5, 5, 5}, - expectError: true, - }, - { - name: "Invalid Padding Bytes", - input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), - expectError: true, - }, - { - name: "Valid Single Byte Padding", - input: append([]byte("Hello, World!"), byte(1)), - expected: []byte("Hello, World!"), - }, - { - name: "Invalid Mixed Padding Bytes", - input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), - expectError: true, - }, - { - name: "Valid Full Block Padding", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("Hello, World!"), - }, - { - name: "Non-Padding Byte at End", - input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), - expectError: true, - }, - { - name: "Valid Padding with Different Text Length", - input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), - expected: []byte("Test"), - }, - { - name: "Padding Length Equal to Input Length", - input: bytes.Repeat([]byte{8}, 8), - expected: []byte{}, - }, - { - name: "Invalid Padding Length Zero (Again)", - input: append([]byte("Test"), byte(0)), - expectError: true, - }, - { - name: "Padding Length Greater Than Input", - input: []byte{10}, - expectError: true, - }, - { - name: "Input Length Not Multiple of Block Size", - input: append([]byte("Invalid Length"), byte(1)), - expected: []byte("Invalid Length"), - }, - { - name: "Valid Padding with Non-ASCII Characters", - input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), - expected: []byte("こんにちは"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs5UnPadding(tt.input) - if tt.expectError { - if err == nil { - t.Errorf("Expected error but got nil") - } - } else { - if err != nil { - t.Errorf("Did not expect error but got: %v", err) - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("Expected output %v, got %v", tt.expected, result) - } - } - }) - } -} diff --git a/management/server/activity/store/migration.go b/management/server/activity/store/migration.go index af19a34eb..d0f165d5f 100644 --- a/management/server/activity/store/migration.go +++ b/management/server/activity/store/migration.go @@ -10,9 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/migration" + "github.com/netbirdio/netbird/util/crypt" ) -func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error { +func migrate(ctx context.Context, crypt *crypt.FieldEncrypt, db *gorm.DB) error { migrations := getMigrations(ctx, crypt) for _, m := range migrations { @@ -26,7 +27,7 @@ func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error { type migrationFunc func(*gorm.DB) error -func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc { +func getMigrations(ctx context.Context, crypt *crypt.FieldEncrypt) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "") @@ -45,7 +46,7 @@ func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc { // migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using // legacy CBC encryption with a static IV to the new GCM encryption method. -func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error { +func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *crypt.FieldEncrypt) error { model := &activity.DeletedUser{} if !db.Migrator().HasTable(model) { @@ -80,7 +81,7 @@ func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *F return nil } -func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error { +func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *crypt.FieldEncrypt) error { var err error var decryptedEmail, decryptedName string diff --git a/management/server/activity/store/migration_test.go b/management/server/activity/store/migration_test.go index e3261d9fa..5c6f5ade8 100644 --- a/management/server/activity/store/migration_test.go +++ b/management/server/activity/store/migration_test.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -40,10 +41,10 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) { db := setupDatabase(t) - key, err := GenerateKey() + key, err := crypt.GenerateKey() require.NoError(t, err, "Failed to generate key") - crypt, err := NewFieldEncrypt(key) + crypt, err := crypt.NewFieldEncrypt(key) require.NoError(t, err, "Failed to initialize FieldEncrypt") t.Run("empty table, no migration required", func(t *testing.T) { diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index ffecb6b8f..db614d0cd 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -45,12 +46,12 @@ type eventWithNames struct { // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { db *gorm.DB - fieldEncrypt *FieldEncrypt + fieldEncrypt *crypt.FieldEncrypt } // NewSqlStore creates a new Store with an event table if not exists. func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) { - crypt, err := NewFieldEncrypt(encryptionKey) + fieldEncrypt, err := crypt.NewFieldEncrypt(encryptionKey) if err != nil { return nil, err @@ -61,7 +62,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St return nil, fmt.Errorf("initialize database: %w", err) } - if err = migrate(ctx, crypt, db); err != nil { + if err = migrate(ctx, fieldEncrypt, db); err != nil { return nil, fmt.Errorf("events database migration: %w", err) } @@ -72,7 +73,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St return &Store{ db: db, - fieldEncrypt: crypt, + fieldEncrypt: fieldEncrypt, }, nil } diff --git a/management/server/activity/store/sql_store_test.go b/management/server/activity/store/sql_store_test.go index 8c0d159df..d723f1623 100644 --- a/management/server/activity/store/sql_store_test.go +++ b/management/server/activity/store/sql_store_test.go @@ -9,11 +9,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/util/crypt" ) func TestNewSqlStore(t *testing.T) { dataDir := t.TempDir() - key, _ := GenerateKey() + key, _ := crypt.GenerateKey() store, err := NewSqlStore(context.Background(), dataDir, key) if err != nil { t.Fatal(err) diff --git a/util/crypt/crypt_test.go b/util/crypt/crypt_test.go new file mode 100644 index 000000000..143a4bbc2 --- /dev/null +++ b/util/crypt/crypt_test.go @@ -0,0 +1,139 @@ +package crypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateKey(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + assert.NotEmpty(t, key) + + _, err = NewFieldEncrypt(key) + assert.NoError(t, err) +} + +func TestNewFieldEncrypt_InvalidKey(t *testing.T) { + tests := []struct { + name string + key string + }{ + {name: "invalid base64", key: "not-valid-base64!!!"}, + {name: "too short", key: "c2hvcnQ="}, + {name: "empty", key: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewFieldEncrypt(tt.key) + assert.Error(t, err) + }) + } +} + +func TestEncryptDecrypt(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + testCases := []struct { + name string + input string + }{ + {name: "Empty String", input: ""}, + {name: "Short String", input: "Hello"}, + {name: "String with Spaces", input: "Hello, World!"}, + {name: "Long String", input: "The quick brown fox jumps over the lazy dog."}, + {name: "Unicode Characters", input: "こんにちは世界"}, + {name: "Special Characters", input: "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + {name: "Numeric String", input: "1234567890"}, + {name: "Email Address", input: "user@example.com"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypted, err := ec.Encrypt(tc.input) + require.NoError(t, err) + + decrypted, err := ec.Decrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, tc.input, decrypted) + }) + } +} + +func TestEncrypt_DifferentCiphertexts(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + plaintext := "same plaintext" + + // Encrypt the same plaintext multiple times + encrypted1, err := ec.Encrypt(plaintext) + require.NoError(t, err) + + encrypted2, err := ec.Encrypt(plaintext) + require.NoError(t, err) + + assert.NotEqual(t, encrypted1, encrypted2, "expected different ciphertexts for same plaintext (random nonce)") + + // Both should decrypt to the same plaintext + decrypted1, err := ec.Decrypt(encrypted1) + require.NoError(t, err) + + decrypted2, err := ec.Decrypt(encrypted2) + require.NoError(t, err) + + assert.Equal(t, plaintext, decrypted1) + assert.Equal(t, plaintext, decrypted2) +} + +func TestDecrypt_InvalidCiphertext(t *testing.T) { + key, err := GenerateKey() + assert.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + assert.NoError(t, err) + + tests := []struct { + name string + ciphertext string + }{ + {name: "invalid base64", ciphertext: "not-valid!!!"}, + {name: "too short", ciphertext: "c2hvcnQ="}, + {name: "corrupted", ciphertext: "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo="}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := ec.Decrypt(tt.ciphertext) + assert.Error(t, err) + assert.Empty(t, payload) + }) + } +} + +func TestDecrypt_WrongKey(t *testing.T) { + key1, _ := GenerateKey() + key2, _ := GenerateKey() + + ec1, _ := NewFieldEncrypt(key1) + ec2, _ := NewFieldEncrypt(key2) + + plaintext := "secret data" + encrypted, _ := ec1.Encrypt(plaintext) + + // Try to decrypt with wrong key + payload, err := ec2.Decrypt(encrypted) + assert.Error(t, err) + assert.Empty(t, payload) +} diff --git a/util/crypt/legacy.go b/util/crypt/legacy.go new file mode 100644 index 000000000..f84e6964f --- /dev/null +++ b/util/crypt/legacy.go @@ -0,0 +1,71 @@ +package crypt + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" +) + +// legacyIV is the static IV used by the legacy CBC encryption. +// Deprecated: This is kept only for backward compatibility with existing encrypted data. +var legacyIV = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} + +// LegacyEncrypt encrypts plaintext using AES-CBC with a static IV. +// Deprecated: Use Encrypt instead. This method is kept only for backward compatibility. +func (f *FieldEncrypt) LegacyEncrypt(plaintext string) string { + padded := pkcs5Padding([]byte(plaintext)) + ciphertext := make([]byte, len(padded)) + cbc := cipher.NewCBCEncrypter(f.block, legacyIV) + cbc.CryptBlocks(ciphertext, padded) + return base64.StdEncoding.EncodeToString(ciphertext) +} + +// LegacyDecrypt decrypts ciphertext that was encrypted using AES-CBC with a static IV. +// Deprecated: This method is kept only for backward compatibility with existing encrypted data. +func (f *FieldEncrypt) LegacyDecrypt(ciphertext string) (string, error) { + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("decode ciphertext: %w", err) + } + + cbc := cipher.NewCBCDecrypter(f.block, legacyIV) + cbc.CryptBlocks(data, data) + + plaintext, err := pkcs5UnPadding(data) + if err != nil { + return "", fmt.Errorf("unpad plaintext: %w", err) + } + + return string(plaintext), nil +} + +// pkcs5Padding adds PKCS#5 padding to the input. +func pkcs5Padding(data []byte) []byte { + padding := aes.BlockSize - len(data)%aes.BlockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + +// pkcs5UnPadding removes PKCS#5 padding from the input. +func pkcs5UnPadding(data []byte) ([]byte, error) { + length := len(data) + if length == 0 { + return nil, fmt.Errorf("input data is empty") + } + + paddingLen := int(data[length-1]) + if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > length { + return nil, fmt.Errorf("invalid padding size") + } + + // Verify that all padding bytes are the same + for i := 0; i < paddingLen; i++ { + if data[length-1-i] != byte(paddingLen) { + return nil, fmt.Errorf("invalid padding") + } + } + + return data[:length-paddingLen], nil +} diff --git a/util/crypt/legacy_test.go b/util/crypt/legacy_test.go new file mode 100644 index 000000000..09b75a71f --- /dev/null +++ b/util/crypt/legacy_test.go @@ -0,0 +1,164 @@ +package crypt + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLegacyEncryptDecrypt(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + encrypted := ec.LegacyEncrypt(testData) + assert.NotEmpty(t, encrypted) + + decrypted, err := ec.LegacyDecrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, testData, decrypted) +} + +func TestLegacyEncryptDecryptVariousInputs(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + testCases := []struct { + name string + input string + }{ + {name: "Empty String", input: ""}, + {name: "Short String", input: "Hello"}, + {name: "String with Spaces", input: "Hello, World!"}, + {name: "Long String", input: "The quick brown fox jumps over the lazy dog."}, + {name: "Unicode Characters", input: "こんにちは世界"}, + {name: "Special Characters", input: "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + {name: "Numeric String", input: "1234567890"}, + {name: "Repeated Characters", input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + {name: "Multi-block String", input: "This is a longer string that will span multiple blocks in the encryption algorithm."}, + {name: "Non-ASCII and ASCII Mix", input: "Hello 世界 123"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypted := ec.LegacyEncrypt(tc.input) + assert.NotEmpty(t, encrypted) + + decrypted, err := ec.LegacyDecrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, tc.input, decrypted) + }) + } +} + +func TestPKCS5UnPadding(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + expectError bool + }{ + { + name: "Valid Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), + expected: []byte("Hello, World!"), + }, + { + name: "Empty Input", + input: []byte{}, + expectError: true, + }, + { + name: "Padding Length Zero", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Block Size", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Input Length", + input: []byte{5, 5, 5}, + expectError: true, + }, + { + name: "Invalid Padding Bytes", + input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), + expectError: true, + }, + { + name: "Valid Single Byte Padding", + input: append([]byte("Hello, World!"), byte(1)), + expected: []byte("Hello, World!"), + }, + { + name: "Invalid Mixed Padding Bytes", + input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), + expectError: true, + }, + { + name: "Valid Full Block Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("Hello, World!"), + }, + { + name: "Non-Padding Byte at End", + input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), + expectError: true, + }, + { + name: "Valid Padding with Different Text Length", + input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), + expected: []byte("Test"), + }, + { + name: "Padding Length Equal to Input Length", + input: bytes.Repeat([]byte{8}, 8), + expected: []byte{}, + }, + { + name: "Invalid Padding Length Zero (Again)", + input: append([]byte("Test"), byte(0)), + expectError: true, + }, + { + name: "Padding Length Greater Than Input", + input: []byte{10}, + expectError: true, + }, + { + name: "Input Length Not Multiple of Block Size", + input: append([]byte("Invalid Length"), byte(1)), + expected: []byte("Invalid Length"), + }, + { + name: "Valid Padding with Non-ASCII Characters", + input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), + expected: []byte("こんにちは"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs5UnPadding(tt.input) + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +}