From 2172d6f1b96dde5e036841b4f1ef25bb1a514d7d Mon Sep 17 00:00:00 2001 From: Mikhail Bragin Date: Thu, 22 Jul 2021 15:23:24 +0200 Subject: [PATCH] Extract common server encryption logic (#65) * refactor: extract common message encryption logic * refactor: move letsencrypt logic to common * refactor: rename common package to encryption * test: add encryption tests --- cmd/management.go | 35 +------ cmd/root.go | 33 ------ cmd/signal.go | 6 +- {signal => encryption}/encryption.go | 6 +- encryption/encryption_suite_test.go | 13 +++ encryption/encryption_test.go | 60 +++++++++++ encryption/letsencrypt.go | 40 ++++++++ encryption/message.go | 40 ++++++++ encryption/testprotos/generate.sh | 2 + encryption/testprotos/testproto.pb.go | 142 ++++++++++++++++++++++++++ encryption/testprotos/testproto.proto | 9 ++ management/management_test.go | 16 +-- management/message.go | 44 -------- management/server.go | 17 ++- signal/client.go | 16 +-- signal/encryption_test.go | 5 +- 16 files changed, 343 insertions(+), 141 deletions(-) rename {signal => encryption}/encryption.go (83%) create mode 100644 encryption/encryption_suite_test.go create mode 100644 encryption/encryption_test.go create mode 100644 encryption/letsencrypt.go create mode 100644 encryption/message.go create mode 100755 encryption/testprotos/generate.sh create mode 100644 encryption/testprotos/testproto.pb.go create mode 100644 encryption/testprotos/testproto.proto delete mode 100644 management/message.go diff --git a/cmd/management.go b/cmd/management.go index 5d3ca30cf..c4ee1490a 100644 --- a/cmd/management.go +++ b/cmd/management.go @@ -1,21 +1,18 @@ package cmd import ( - "crypto/tls" "flag" "fmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/wiretrustee/wiretrustee/encryption" mgmt "github.com/wiretrustee/wiretrustee/management" mgmtProto "github.com/wiretrustee/wiretrustee/management/proto" - "golang.org/x/crypto/acme/autocert" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "net" - "net/http" "os" - "path/filepath" "time" ) @@ -52,34 +49,8 @@ var ( var opts []grpc.ServerOption if mgmtLetsencryptDomain != "" { - - certDir := filepath.Join(mgmtDataDir, "letsencrypt") - - if _, err := os.Stat(certDir); os.IsNotExist(err) { - err = os.MkdirAll(certDir, os.ModeDir) - if err != nil { - log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err) - } - } - - log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", mgmtLetsencryptDomain, certDir) - - certManager := autocert.Manager{ - Prompt: autocert.AcceptTOS, - Cache: autocert.DirCache(certDir), - HostPolicy: autocert.HostWhitelist(mgmtLetsencryptDomain), - } - tls := &tls.Config{GetCertificate: certManager.GetCertificate} - - credentials := credentials.NewTLS(tls) - opts = append(opts, grpc.Creds(credentials)) - - // listener to handle Let's encrypt certificate challenge - go func() { - if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil { - log.Fatalf("failed to serve letsencrypt handler: %v", err) - } - }() + transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(mgmtDataDir, mgmtLetsencryptDomain)) + opts = append(opts, grpc.Creds(transportCredentials)) } opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) diff --git a/cmd/root.go b/cmd/root.go index ef1b6f0ee..698f1a576 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,16 +1,11 @@ package cmd import ( - "crypto/tls" "fmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "golang.org/x/crypto/acme/autocert" - "google.golang.org/grpc/credentials" - "net/http" "os" "os/signal" - "path/filepath" "runtime" ) @@ -79,31 +74,3 @@ func InitLog(logLevel string) { } log.SetLevel(level) } - -func enableLetsEncrypt(datadir string, letsencryptDomain string) credentials.TransportCredentials { - certDir := filepath.Join(datadir, "letsencrypt") - - if _, err := os.Stat(certDir); os.IsNotExist(err) { - err = os.MkdirAll(certDir, os.ModeDir) - if err != nil { - log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err) - } - } - - log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir) - - certManager := autocert.Manager{ - Prompt: autocert.AcceptTOS, - Cache: autocert.DirCache(certDir), - HostPolicy: autocert.HostWhitelist(letsencryptDomain), - } - - // listener to handle Let's encrypt certificate challenge - go func() { - if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil { - log.Fatalf("failed to serve letsencrypt handler: %v", err) - } - }() - - return credentials.NewTLS(&tls.Config{GetCertificate: certManager.GetCertificate}) -} diff --git a/cmd/signal.go b/cmd/signal.go index 487987032..a4c465dc5 100644 --- a/cmd/signal.go +++ b/cmd/signal.go @@ -5,9 +5,11 @@ import ( "fmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/wiretrustee/wiretrustee/encryption" sig "github.com/wiretrustee/wiretrustee/signal" sigProto "github.com/wiretrustee/wiretrustee/signal/proto" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "net" "os" @@ -45,8 +47,8 @@ var ( } var opts []grpc.ServerOption - if mgmtLetsencryptDomain != "" { - transportCredentials := enableLetsEncrypt(signalDataDir, signalLetsencryptDomain) + if signalLetsencryptDomain != "" { + transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(signalDataDir, signalLetsencryptDomain)) opts = append(opts, grpc.Creds(transportCredentials)) } diff --git a/signal/encryption.go b/encryption/encryption.go similarity index 83% rename from signal/encryption.go rename to encryption/encryption.go index 0018e04a2..196c42106 100644 --- a/signal/encryption.go +++ b/encryption/encryption.go @@ -1,4 +1,4 @@ -package signal +package encryption import ( "crypto/rand" @@ -7,9 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// As set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service. -// We want to make sure that the Connection Candidates and other irrelevant (to the Signal Exchange) -// information can't be read anywhere else but the Peer the message is being sent to. +// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service // These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate) // Wireguard keys are used for encryption diff --git a/encryption/encryption_suite_test.go b/encryption/encryption_suite_test.go new file mode 100644 index 000000000..1f05f0765 --- /dev/null +++ b/encryption/encryption_suite_test.go @@ -0,0 +1,13 @@ +package encryption_test + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "testing" +) + +func TestManagement(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Management Service Suite") +} diff --git a/encryption/encryption_test.go b/encryption/encryption_test.go new file mode 100644 index 000000000..fbaf4ca7c --- /dev/null +++ b/encryption/encryption_test.go @@ -0,0 +1,60 @@ +package encryption_test + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/wiretrustee/wiretrustee/encryption" + "github.com/wiretrustee/wiretrustee/encryption/testprotos" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const () + +var _ = Describe("Encryption", func() { + + var ( + encryptionKey wgtypes.Key + decryptionKey wgtypes.Key + ) + + BeforeEach(func() { + var err error + encryptionKey, err = wgtypes.GenerateKey() + Expect(err).NotTo(HaveOccurred()) + decryptionKey, err = wgtypes.GenerateKey() + Expect(err).NotTo(HaveOccurred()) + }) + + Context("decrypting a plain message", func() { + Context("when it was encrypted with Wireguard keys", func() { + Specify("should be successful", func() { + msg := "message" + encryptedMsg, err := encryption.Encrypt([]byte(msg), decryptionKey.PublicKey(), encryptionKey) + Expect(err).NotTo(HaveOccurred()) + + decryptedMsg, err := encryption.Decrypt(encryptedMsg, encryptionKey.PublicKey(), decryptionKey) + Expect(err).NotTo(HaveOccurred()) + + Expect(string(decryptedMsg)).To(BeEquivalentTo(msg)) + }) + }) + }) + + Context("decrypting a protobuf message", func() { + Context("when it was encrypted with Wireguard keys", func() { + Specify("should be successful", func() { + + protoMsg := &testprotos.TestMessage{Body: "message"} + encryptedMsg, err := encryption.EncryptMessage(decryptionKey.PublicKey(), encryptionKey, protoMsg) + Expect(err).NotTo(HaveOccurred()) + + decryptedMsg := &testprotos.TestMessage{} + err = encryption.DecryptMessage(encryptionKey.PublicKey(), decryptionKey, encryptedMsg, decryptedMsg) + Expect(err).NotTo(HaveOccurred()) + + Expect(decryptedMsg.GetBody()).To(BeEquivalentTo(protoMsg.GetBody())) + }) + }) + }) + +}) diff --git a/encryption/letsencrypt.go b/encryption/letsencrypt.go new file mode 100644 index 000000000..5664b6e7a --- /dev/null +++ b/encryption/letsencrypt.go @@ -0,0 +1,40 @@ +package encryption + +import ( + "crypto/tls" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/acme/autocert" + "net/http" + "os" + "path/filepath" +) + +// EnableLetsEncrypt wraps common logic of generating Let's encrypt certificate. +// Includes a HTTP handler and listener to solve the Let's encrypt challenge +func EnableLetsEncrypt(datadir string, letsencryptDomain string) *tls.Config { + certDir := filepath.Join(datadir, "letsencrypt") + + if _, err := os.Stat(certDir); os.IsNotExist(err) { + err = os.MkdirAll(certDir, os.ModeDir) + if err != nil { + log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err) + } + } + + log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir) + + certManager := autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(certDir), + HostPolicy: autocert.HostWhitelist(letsencryptDomain), + } + + // listener to handle Let's encrypt certificate challenge + go func() { + if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil { + log.Fatalf("failed to serve letsencrypt handler: %v", err) + } + }() + + return &tls.Config{GetCertificate: certManager.GetCertificate} +} diff --git a/encryption/message.go b/encryption/message.go new file mode 100644 index 000000000..a646fa679 --- /dev/null +++ b/encryption/message.go @@ -0,0 +1,40 @@ +package encryption + +import ( + pb "github.com/golang/protobuf/proto" //nolint + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// EncryptMessage encrypts a body of the given protobuf Message +func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) { + byteResp, err := pb.Marshal(message) + if err != nil { + log.Errorf("failed marshalling message %v", err) + return nil, err + } + + encryptedBytes, err := Encrypt(byteResp, remotePubKey, ourPrivateKey) + if err != nil { + log.Errorf("failed encrypting SyncResponse %v", err) + return nil, err + } + + return encryptedBytes, nil +} + +// DecryptMessage decrypts an encrypted message into given protobuf Message +func DecryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, encryptedMessage []byte, message pb.Message) error { + decrypted, err := Decrypt(encryptedMessage, remotePubKey, ourPrivateKey) + if err != nil { + log.Warnf("error while decrypting Sync request message from peer %s", remotePubKey.String()) + return err + } + + err = pb.Unmarshal(decrypted, message) + if err != nil { + log.Warnf("error while umarshalling Sync request message from peer %s", remotePubKey.String()) + return err + } + return nil +} diff --git a/encryption/testprotos/generate.sh b/encryption/testprotos/generate.sh new file mode 100755 index 000000000..0ce6ebdea --- /dev/null +++ b/encryption/testprotos/generate.sh @@ -0,0 +1,2 @@ +#!/bin/bash +protoc -I testprotos/ testprotos/testproto.proto --go_out=. \ No newline at end of file diff --git a/encryption/testprotos/testproto.pb.go b/encryption/testprotos/testproto.pb.go new file mode 100644 index 000000000..f3520683c --- /dev/null +++ b/encryption/testprotos/testproto.pb.go @@ -0,0 +1,142 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.12.4 +// source: testproto.proto + +package testprotos + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type TestMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"` +} + +func (x *TestMessage) Reset() { + *x = TestMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_testproto_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TestMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TestMessage) ProtoMessage() {} + +func (x *TestMessage) ProtoReflect() protoreflect.Message { + mi := &file_testproto_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TestMessage.ProtoReflect.Descriptor instead. +func (*TestMessage) Descriptor() ([]byte, []int) { + return file_testproto_proto_rawDescGZIP(), []int{0} +} + +func (x *TestMessage) GetBody() string { + if x != nil { + return x.Body + } + return "" +} + +var File_testproto_proto protoreflect.FileDescriptor + +var file_testproto_proto_rawDesc = []byte{ + 0x0a, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x12, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x22, 0x21, 0x0a, + 0x0b, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, + 0x42, 0x0d, 0x5a, 0x0b, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_testproto_proto_rawDescOnce sync.Once + file_testproto_proto_rawDescData = file_testproto_proto_rawDesc +) + +func file_testproto_proto_rawDescGZIP() []byte { + file_testproto_proto_rawDescOnce.Do(func() { + file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(file_testproto_proto_rawDescData) + }) + return file_testproto_proto_rawDescData +} + +var file_testproto_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_testproto_proto_goTypes = []interface{}{ + (*TestMessage)(nil), // 0: testprotos.TestMessage +} +var file_testproto_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_testproto_proto_init() } +func file_testproto_proto_init() { + if File_testproto_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_testproto_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TestMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_testproto_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_testproto_proto_goTypes, + DependencyIndexes: file_testproto_proto_depIdxs, + MessageInfos: file_testproto_proto_msgTypes, + }.Build() + File_testproto_proto = out.File + file_testproto_proto_rawDesc = nil + file_testproto_proto_goTypes = nil + file_testproto_proto_depIdxs = nil +} diff --git a/encryption/testprotos/testproto.proto b/encryption/testprotos/testproto.proto new file mode 100644 index 000000000..77cf5d633 --- /dev/null +++ b/encryption/testprotos/testproto.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +option go_package = "/testprotos"; + +package testprotos; + +message TestMessage { + string body = 1; +} \ No newline at end of file diff --git a/management/management_test.go b/management/management_test.go index 3baa724a5..e5cfac3ad 100644 --- a/management/management_test.go +++ b/management/management_test.go @@ -4,7 +4,7 @@ import ( "context" pb "github.com/golang/protobuf/proto" //nolint log "github.com/sirupsen/logrus" - "github.com/wiretrustee/wiretrustee/signal" + "github.com/wiretrustee/wiretrustee/encryption" "io" "io/ioutil" "math/rand" @@ -94,7 +94,7 @@ var _ = Describe("Management service", func() { messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key) + encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) Expect(err).NotTo(HaveOccurred()) sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ @@ -106,7 +106,7 @@ var _ = Describe("Management service", func() { encryptedResponse := &mgmtProto.EncryptedMessage{} err = sync.RecvMsg(encryptedResponse) Expect(err).NotTo(HaveOccurred()) - decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key) + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) Expect(err).NotTo(HaveOccurred()) resp := &mgmtProto.SyncResponse{} @@ -127,7 +127,7 @@ var _ = Describe("Management service", func() { messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key) + encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) Expect(err).NotTo(HaveOccurred()) sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ @@ -140,7 +140,7 @@ var _ = Describe("Management service", func() { encryptedResponse := &mgmtProto.EncryptedMessage{} err = sync.RecvMsg(encryptedResponse) Expect(err).NotTo(HaveOccurred()) - decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key) + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) Expect(err).NotTo(HaveOccurred()) resp := &mgmtProto.SyncResponse{} err = pb.Unmarshal(decryptedBytes, resp) @@ -153,7 +153,7 @@ var _ = Describe("Management service", func() { go func() { err = sync.RecvMsg(encryptedResponse) - decryptedBytes, err = signal.Decrypt(encryptedResponse.Body, serverPubKey, key) + decryptedBytes, err = encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) Expect(err).NotTo(HaveOccurred()) resp = &mgmtProto.SyncResponse{} err = pb.Unmarshal(decryptedBytes, resp) @@ -240,7 +240,7 @@ var _ = Describe("Management service", func() { for _, peer := range peers { messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, peer) + encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer) Expect(err).NotTo(HaveOccurred()) // receive stream @@ -261,7 +261,7 @@ var _ = Describe("Management service", func() { } else if err != nil { Expect(err).NotTo(HaveOccurred()) } - decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, peer) + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, peer) Expect(err).NotTo(HaveOccurred()) resp := &mgmtProto.SyncResponse{} diff --git a/management/message.go b/management/message.go deleted file mode 100644 index 46806a68e..000000000 --- a/management/message.go +++ /dev/null @@ -1,44 +0,0 @@ -package management - -import ( - pb "github.com/golang/protobuf/proto" //nolint - log "github.com/sirupsen/logrus" - "github.com/wiretrustee/wiretrustee/management/proto" - "github.com/wiretrustee/wiretrustee/signal" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -// EncryptMessage encrypts a body of the given pn.Message and wraps into proto.EncryptedMessage -func EncryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, message pb.Message) (*proto.EncryptedMessage, error) { - byteResp, err := pb.Marshal(message) - if err != nil { - log.Errorf("failed marshalling message %v", err) - return nil, err - } - - encryptedBytes, err := signal.Encrypt(byteResp, peerKey, serverPrivateKey) - if err != nil { - log.Errorf("failed encrypting SyncResponse %v", err) - return nil, err - } - - return &proto.EncryptedMessage{ - WgPubKey: serverPrivateKey.PublicKey().String(), - Body: encryptedBytes}, nil -} - -//DecryptMessage decrypts an encrypted message (proto.EncryptedMessage) -func DecryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, encryptedMessage *proto.EncryptedMessage, message pb.Message) error { - decrypted, err := signal.Decrypt(encryptedMessage.Body, peerKey, serverPrivateKey) - if err != nil { - log.Warnf("error while decrypting Sync request message from peer %s", peerKey.String()) - return err - } - - err = pb.Unmarshal(decrypted, message) - if err != nil { - log.Warnf("error while umarshalling Sync request message from peer %s", peerKey.String()) - return err - } - return nil -} diff --git a/management/server.go b/management/server.go index 48fd6d148..0a4dca9e9 100644 --- a/management/server.go +++ b/management/server.go @@ -4,6 +4,7 @@ import ( "context" "github.com/golang/protobuf/ptypes/timestamp" log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/encryption" "github.com/wiretrustee/wiretrustee/management/proto" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" @@ -76,7 +77,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } syncReq := &proto.SyncRequest{} - err = DecryptMessage(peerKey, s.wgKey, req, syncReq) + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq) if err != nil { return status.Errorf(codes.InvalidArgument, "invalid request message") } @@ -99,12 +100,15 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } log.Debugf("recevied an update for peer %s", peerKey.String()) - encryptedResp, err := EncryptMessage(peerKey, s.wgKey, update.Update) + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { return status.Errorf(codes.Internal, "failed processing update message") } - err = srv.SendMsg(encryptedResp) + err = srv.SendMsg(&proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }) if err != nil { return status.Errorf(codes.Internal, "failed sending update message") } @@ -200,12 +204,15 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, srv proto.ManagementServic Peers: peers, } - encryptedResp, err := EncryptMessage(peerKey, s.wgKey, plainResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { return status.Errorf(codes.Internal, "error handling request") } - err = srv.Send(encryptedResp) + err = srv.Send(&proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }) if err != nil { log.Errorf("failed sending SyncResponse %v", err) diff --git a/signal/client.go b/signal/client.go index 0542adf8a..bedcaaaf3 100644 --- a/signal/client.go +++ b/signal/client.go @@ -4,8 +4,8 @@ import ( "context" "fmt" "github.com/cenkalti/backoff/v4" - pb "github.com/golang/protobuf/proto" //nolint log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/encryption" "github.com/wiretrustee/wiretrustee/signal/proto" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -162,12 +162,9 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er if err != nil { return nil, err } - decryptedBody, err := Decrypt(msg.GetBody(), remoteKey, c.key) - if err != nil { - return nil, err - } + body := &proto.Body{} - err = pb.Unmarshal(decryptedBody, body) + err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body) if err != nil { return nil, err } @@ -181,16 +178,13 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er // encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) { - body, err := pb.Marshal(msg.GetBody()) - if err != nil { - return nil, err - } + remoteKey, err := wgtypes.ParseKey(msg.RemoteKey) if err != nil { return nil, err } - encryptedBody, err := Encrypt(body, remoteKey, c.key) + encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body) if err != nil { return nil, err } diff --git a/signal/encryption_test.go b/signal/encryption_test.go index 8e617953e..5b02ecc60 100644 --- a/signal/encryption_test.go +++ b/signal/encryption_test.go @@ -1,6 +1,7 @@ package signal import ( + "github.com/wiretrustee/wiretrustee/encryption" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "testing" ) @@ -21,13 +22,13 @@ func TestEncryptDecrypt(t *testing.T) { return } - encryptedMessage, err := Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey) + encryptedMessage, err := encryption.Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey) if err != nil { t.Error(err) return } - decryptedMessage, err := Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey) + decryptedMessage, err := encryption.Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey) if err != nil { t.Error(err) return