Compare commits

...

8 Commits

Author SHA1 Message Date
Zoltan Papp
d68a4a7d21 Replace the grpc header key to NetBird specific 2023-06-28 12:58:05 +02:00
Zoltan Papp
ca1722ed10 Handle the stream sending in thread safe way 2023-06-28 02:08:09 +02:00
Zoltan Papp
649dbf2bed Change log line 2023-06-26 12:47:06 +02:00
Zoltan Papp
70076b98d2 Fix metadata preparation in signal 2023-06-23 13:07:14 +02:00
Zoltan Papp
551455f314 Handle keep alive in signal server 2023-06-23 13:07:12 +02:00
pzoli
a6431e053b Update protobuf 2023-06-23 13:03:34 +02:00
Zoltan Papp
520c7b5d37 Move keepalive out of mgm pkg 2023-06-23 13:02:50 +02:00
Zoltan Papp
e376541745 Add grpc keep alive for management service 2023-06-23 13:01:31 +02:00
12 changed files with 1542 additions and 2116 deletions

View File

@@ -888,7 +888,6 @@ func (e *Engine) receiveSignalEvents() {
err := e.signal.Receive(func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn := e.peerConns[msg.Key]
if conn == nil {
return fmt.Errorf("wrongly addressed message %s", msg.Key)

5
keepalive/client.go Normal file
View File

@@ -0,0 +1,5 @@
package keepalive
func IsKeepAliveMsg(body []byte) bool {
return len(body) == 0
}

155
keepalive/keep_alive.go Normal file
View File

@@ -0,0 +1,155 @@
package keepalive
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
const (
GrpcVersionHeaderKey = "x-netbird-version"
reversProxyHeaderKey = "x-netbird-peer"
keepAliveInterval = 30 * time.Second
)
type KeepAlive struct {
sync.RWMutex
ticker *time.Ticker
done chan struct{}
streams map[string]*ioMonitor
keepAliveMsg interface{}
}
func NewKeepAlive(keepAliveMsg interface{}) *KeepAlive {
ka := &KeepAlive{
ticker: time.NewTicker(1 * time.Second),
done: make(chan struct{}),
streams: make(map[string]*ioMonitor),
keepAliveMsg: keepAliveMsg,
}
go ka.start()
return ka
}
func (k *KeepAlive) StreamInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
address, supported := k.keepAliveIsSupported(stream.Context())
if !supported {
return handler(srv, stream)
}
m := &ioMonitor{
sync.Mutex{},
sync.Mutex{},
stream,
time.Now(),
}
k.addIoMonitor(address, m)
return handler(srv, m)
}
}
func (k *KeepAlive) UnaryInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
address, supported := k.keepAliveIsSupported(ctx)
if supported {
k.updateLastSeen(address)
}
return handler(ctx, req)
}
}
func (k *KeepAlive) Stop() {
select {
case k.done <- struct{}{}:
k.ticker.Stop()
return
default:
}
}
func (k *KeepAlive) start() {
for {
select {
case <-k.done:
return
case t := <-k.ticker.C:
k.checkKeepAlive(t)
}
}
}
func (k *KeepAlive) checkKeepAlive(now time.Time) {
k.Lock()
defer k.Unlock()
for addr, m := range k.streams {
if k.isKeepAliveOutDated(now, m) {
continue
}
err := k.sendKeepAlive(m)
if err != nil {
log.Debugf("stop keepalive for: %s", addr)
delete(k.streams, addr)
}
}
}
func (k *KeepAlive) keepAliveIsSupported(ctx context.Context) (string, bool) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Warnf("metadata not found")
return "", false
}
peerAddress := k.addressFromHeader(md)
if peerAddress == "" {
log.Debugf("peer is not using reverse proxy")
return "", false
}
if len(md.Get(GrpcVersionHeaderKey)) == 0 {
log.Debugf("version info not found")
return "", false
}
return peerAddress, true
}
func (k *KeepAlive) addIoMonitor(address string, m *ioMonitor) {
k.Lock()
defer k.Unlock()
log.Debugf("add stream address for keepalive list: %s", address)
k.streams[address] = m
}
func (k *KeepAlive) sendKeepAlive(m *ioMonitor) error {
return m.sendMsg(k.keepAliveMsg)
}
func (k *KeepAlive) updateLastSeen(address string) {
k.RLock()
m, ok := k.streams[address]
k.RUnlock()
if !ok {
return
}
m.updateLastSeen()
}
func (k *KeepAlive) addressFromHeader(md metadata.MD) string {
peer := md.Get(reversProxyHeaderKey)
if len(peer) == 0 {
return ""
}
return peer[0]
}
func (k *KeepAlive) isKeepAliveOutDated(now time.Time, m *ioMonitor) bool {
return now.Sub(m.getLastSeen()) < keepAliveInterval
}

35
keepalive/monitor.go Normal file
View File

@@ -0,0 +1,35 @@
package keepalive
import (
"sync"
"time"
"google.golang.org/grpc"
)
type ioMonitor struct {
mu sync.Mutex
streamLock sync.Mutex
grpc.ServerStream
lastSeen time.Time
}
func (l *ioMonitor) sendMsg(m interface{}) error {
l.updateLastSeen()
l.streamLock.Lock()
defer l.streamLock.Unlock()
return l.ServerStream.SendMsg(m)
}
func (l *ioMonitor) updateLastSeen() {
l.mu.Lock()
defer l.mu.Unlock()
l.lastSeen = time.Now()
}
func (l *ioMonitor) getLastSeen() time.Time {
l.mu.Lock()
t := l.lastSeen
l.mu.Unlock()
return t
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
@@ -23,7 +24,9 @@ import (
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
appKeepAlive "github.com/netbirdio/netbird/keepalive"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/version"
)
// ConnStateNotifier is a wrapper interface of the status recorders
@@ -67,6 +70,9 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
realClient := proto.NewManagementServiceClient(conn)
md := metadata.Pairs(appKeepAlive.GrpcVersionHeaderKey, version.NetbirdVersion())
ctx = metadata.NewOutgoingContext(ctx, md)
return &GrpcClient{
key: ourPrivateKey,
realClient: realClient,
@@ -131,6 +137,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
@@ -246,6 +253,10 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
return err
}
if appKeepAlive.IsKeepAliveMsg(update.Body) {
continue
}
log.Debugf("got an update message from Management Service")
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)

View File

@@ -40,6 +40,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
grpcKeepAlive "github.com/netbirdio/netbird/keepalive"
mgmtProto "github.com/netbirdio/netbird/management/proto"
)
@@ -202,11 +203,17 @@ var (
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
ka := grpcKeepAlive.NewKeepAlive(&mgmtProto.KeepAlive{})
defer ka.Stop()
sInterc := grpc.StreamInterceptor(ka.StreamInterceptor())
uInterc := grpc.UnaryInterceptor(ka.UnaryInterceptor())
gRPCOpts = append(gRPCOpts, sInterc, uInterc)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(store)

File diff suppressed because it is too large Load Diff

View File

@@ -329,3 +329,5 @@ message FirewallRule {
ICMP = 4;
}
}
message KeepAlive {}

View File

@@ -21,7 +21,9 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
appKeepAlive "github.com/netbirdio/netbird/keepalive"
"github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/version"
)
const defaultSendTimeout = 5 * time.Second
@@ -73,6 +75,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
sigCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
@@ -87,9 +90,14 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
log.Errorf("failed to connect to the signalling server %v", err)
return nil, err
}
log.Debugf("connected to Signal Service: %v", conn.Target())
md := metadata.New(map[string]string{
proto.HeaderId: key.PublicKey().String(), // add key fingerprint to the request header to be identified on the server side
appKeepAlive.GrpcVersionHeaderKey: version.NetbirdVersion(), // add version info to ensure keep alive is supported
})
ctx = metadata.NewOutgoingContext(ctx, md)
return &GrpcClient{
realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx,
@@ -146,7 +154,7 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connect(ctx, c.key.PublicKey().String())
stream, err := c.connect(ctx)
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
return err
@@ -208,13 +216,9 @@ func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
return c.connectedCh
}
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
func (c *GrpcClient) connect(ctx context.Context) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil
// add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key})
metaCtx := metadata.NewOutgoingContext(ctx, md)
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
c.stream = stream
if err != nil {
return nil, err
@@ -366,6 +370,12 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
} else if err != nil {
return err
}
if appKeepAlive.IsKeepAliveMsg(msg.Body) {
log.Tracef("received keepalive")
continue
}
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg)

View File

@@ -4,7 +4,6 @@ import (
"errors"
"flag"
"fmt"
"golang.org/x/crypto/acme/autocert"
"io"
"io/fs"
"net"
@@ -14,15 +13,18 @@ import (
"strings"
"time"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme/autocert"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
appKeepAlive "github.com/netbirdio/netbird/keepalive"
"github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
var (
@@ -93,6 +95,13 @@ var (
}
opts = append(opts, signalKaep, signalKasp)
ka := appKeepAlive.NewKeepAlive(&proto.KeepAlive{})
defer ka.Stop()
sInterc := grpc.StreamInterceptor(ka.StreamInterceptor())
uInterc := grpc.UnaryInterceptor(ka.UnaryInterceptor())
opts = append(opts, sInterc, uInterc)
grpcServer := grpc.NewServer(opts...)
proto.RegisterSignalExchangeServer(grpcServer, server.NewServer())

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.21.9
// protoc v3.21.12
// source: signalexchange.proto
package proto
@@ -347,6 +347,44 @@ func (x *Mode) GetDirect() bool {
return false
}
type KeepAlive struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *KeepAlive) Reset() {
*x = KeepAlive{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *KeepAlive) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*KeepAlive) ProtoMessage() {}
func (x *KeepAlive) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use KeepAlive.ProtoReflect.Descriptor instead.
func (*KeepAlive) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{4}
}
var File_signalexchange_proto protoreflect.FileDescriptor
var file_signalexchange_proto_rawDesc = []byte{
@@ -388,20 +426,20 @@ var file_signalexchange_proto_rawDesc = []byte{
0x45, 0x10, 0x04, 0x22, 0x2e, 0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64,
0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64,
0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72,
0x65, 0x63, 0x74, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20,
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67,
0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53,
0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42,
0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
0x65, 0x63, 0x74, 0x22, 0x0b, 0x0a, 0x09, 0x4b, 0x65, 0x65, 0x70, 0x41, 0x6c, 0x69, 0x76, 0x65,
0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63,
0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e,
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65,
0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06,
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -417,13 +455,14 @@ func file_signalexchange_proto_rawDescGZIP() []byte {
}
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_signalexchange_proto_goTypes = []interface{}{
(Body_Type)(0), // 0: signalexchange.Body.Type
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
(*Message)(nil), // 2: signalexchange.Message
(*Body)(nil), // 3: signalexchange.Body
(*Mode)(nil), // 4: signalexchange.Mode
(*KeepAlive)(nil), // 5: signalexchange.KeepAlive
}
var file_signalexchange_proto_depIdxs = []int32{
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
@@ -494,6 +533,18 @@ func file_signalexchange_proto_init() {
return nil
}
}
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*KeepAlive); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{}
@@ -502,7 +553,7 @@ func file_signalexchange_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_signalexchange_proto_rawDesc,
NumEnums: 1,
NumMessages: 4,
NumMessages: 5,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -62,4 +62,6 @@ message Body {
// Mode indicates a connection mode
message Mode {
optional bool direct = 1;
}
}
message KeepAlive {}