[client] Check the client status in the earlier phase (#4509)
This PR improves the NetBird client's status checking by detecting client state changes earlier and by handling the connection lifecycle more carefully. The key improvements:

- Enhanced status detection: added a waitForReady option to StatusRequest for improved client status handling
- Better connection management: improved context handling for the signal and management gRPC connections
- Longer dial timeout: increased the gRPC dial timeout from 3 to 10 seconds for better reliability
- Cleaner error handling: enhanced error propagation and context cancellation in the retry loops

Key Changes

Core status improvements:
- Added an optional waitForReady field to StatusRequest (daemon.proto:190)
- Enhanced the status checking logic to detect client state changes earlier in the connection process
- Improved handling of permanent client exits (give-ups) from the retry loop

Connection and context management:
- Fixed context cancellation in the management and signal client retry mechanisms
- Added proper context propagation for Login operations
- Enhanced gRPC connection handling with better timeout management

Error handling and cleanup:
- Moved feedback channels to upper layers for better separation of concerns
- Improved error handling patterns throughout the client server implementation
- Fixed synchronization issues and removed debug logging
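As a usage illustration (not part of this change), a CLI caller can set the new optional field with a `*bool` helper instead of the inline closure used in the diff below. A minimal sketch, assuming an already-dialed daemon connection; the `queryStatus` wrapper name is hypothetical:

```go
package example

import (
	"context"
	"fmt"
	"time"

	"google.golang.org/protobuf/proto"

	daemon "github.com/netbirdio/netbird/client/proto"
)

// queryStatus asks the daemon for its status; with waitForReady set, the daemon
// may hold the reply until the engine is ready or has permanently given up.
func queryStatus(ctx context.Context, client daemon.DaemonServiceClient) (string, error) {
	ctx, cancel := context.WithTimeout(ctx, time.Minute)
	defer cancel()

	resp, err := client.Status(ctx, &daemon.StatusRequest{
		WaitForReady: proto.Bool(true), // equivalent to the inline *bool closure in the diff
	})
	if err != nil {
		return "", fmt.Errorf("unable to get daemon status: %w", err)
	}
	return resp.GetStatus(), nil
}
```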
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
 
 // DialClientGRPCServer returns client connection to the daemon server.
 func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
-	ctx, cancel := context.WithTimeout(ctx, time.Second*3)
+	ctx, cancel := context.WithTimeout(ctx, time.Second*10)
 	defer cancel()
 
 	return grpc.DialContext(
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
 
 	client := proto.NewDaemonServiceClient(conn)
 
-	status, err := client.Status(ctx, &proto.StatusRequest{})
+	status, err := client.Status(ctx, &proto.StatusRequest{
+		WaitForReady: func() *bool { b := true; return &b }(),
+	})
 	if err != nil {
 		return fmt.Errorf("unable to get daemon status: %v", err)
 	}
@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
 
 	// either startup error (permanent backoff err) or nil err (successful engine up)
 	// TODO: make after-startup backoff err available
-	run := make(chan struct{}, 1)
+	run := make(chan struct{})
 	clientErr := make(chan error, 1)
 	go func() {
 		if err := client.Run(run); err != nil {
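The switch from a buffered to an unbuffered `run` channel in the hunk above turns the readiness signal into a synchronous handoff: the sender blocks until a receiver is actually waiting. A self-contained sketch of the difference, independent of the NetBird code:

```go
package main

import "fmt"

// signal sends a single readiness notification on ch.
func signal(ch chan struct{}) {
	ch <- struct{}{}
}

func main() {
	// Buffered (capacity 1): the send completes even if nobody is receiving yet,
	// so the producer cannot tell whether the signal was observed.
	buffered := make(chan struct{}, 1)
	signal(buffered)

	// Unbuffered: the send blocks until a receiver is ready, so sender and
	// receiver rendezvous at the moment the signal is delivered.
	unbuffered := make(chan struct{})
	go signal(unbuffered)
	<-unbuffered

	fmt.Println("both signals delivered")
}
```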
@@ -58,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
 	return backoff.WithContext(b, ctx)
 }
 
-func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
+func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
 	transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
 	if tlsEnabled {
 		certPool, err := x509.SystemCertPool()
@@ -72,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
 		}))
 	}
 
-	connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
 	defer cancel()
 
 	conn, err := grpc.DialContext(
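Deriving the 30-second dial deadline from the caller's context (instead of `context.Background()`) means an upstream cancellation, for example a stopped retry loop or a daemon shutdown, also aborts an in-flight dial. A small illustrative sketch of the pattern, not the NetBird code itself:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// dialWithDeadline derives a bounded context from the caller's ctx, so the
// operation stops on whichever comes first: the 30s deadline or an upstream cancel.
func dialWithDeadline(ctx context.Context, connect func(context.Context) error) error {
	connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	return connect(connCtx)
}

func main() {
	parent, cancel := context.WithCancel(context.Background())
	cancel() // simulate the caller giving up before the dial finishes

	err := dialWithDeadline(parent, func(ctx context.Context) error {
		<-ctx.Done() // a real dial would return as soon as ctx is cancelled
		return ctx.Err()
	})
	fmt.Println(err) // context canceled
}
```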
@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
 // 	protoc-gen-go v1.36.6
-// 	protoc        v5.29.3
+// 	protoc        v6.32.1
 // source: daemon.proto
 
 package proto
@@ -794,8 +794,10 @@ type StatusRequest struct {
 	state             protoimpl.MessageState `protogen:"open.v1"`
 	GetFullPeerStatus bool                   `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
 	ShouldRunProbes   bool                   `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
-	unknownFields     protoimpl.UnknownFields
-	sizeCache         protoimpl.SizeCache
+	// the UI does not use this yet, but CLIs could use it to wait until the status is ready
+	WaitForReady  *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"`
+	unknownFields protoimpl.UnknownFields
+	sizeCache     protoimpl.SizeCache
 }
 
 func (x *StatusRequest) Reset() {
@@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
 	return false
 }
 
+func (x *StatusRequest) GetWaitForReady() bool {
+	if x != nil && x.WaitForReady != nil {
+		return *x.WaitForReady
+	}
+	return false
+}
+
 type StatusResponse struct {
 	state protoimpl.MessageState `protogen:"open.v1"`
 	// status of the server.
@@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" +
 	"\f_profileNameB\v\n" +
 	"\t_username\"\f\n" +
 	"\n" +
-	"UpResponse\"g\n" +
+	"UpResponse\"\xa1\x01\n" +
 	"\rStatusRequest\x12,\n" +
 	"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
-	"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
+	"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" +
+	"\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" +
+	"\r_waitForReady\"\x82\x01\n" +
 	"\x0eStatusResponse\x12\x16\n" +
 	"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
 	"\n" +
@@ -5231,6 +5242,7 @@ func file_daemon_proto_init() {
 	}
 	file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
 	file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
+	file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
 	file_daemon_proto_msgTypes[26].OneofWrappers = []any{
 		(*PortInfo_Port)(nil),
 		(*PortInfo_Range_)(nil),
@@ -186,6 +186,8 @@ message UpResponse {}
 message StatusRequest{
   bool getFullPeerStatus = 1;
   bool shouldRunProbes = 2;
+  // the UI does not use this yet, but CLIs could use it to wait until the status is ready
+  optional bool waitForReady = 3;
 }
 
 message StatusResponse{
@@ -67,6 +67,7 @@ type Server struct {
 	proto.UnimplementedDaemonServiceServer
 
 	clientRunning     bool // protected by mutex
 	clientRunningChan chan struct{}
+	clientGiveUpChan  chan struct{}
 
 	connectClient *internal.ConnectClient
@@ -106,6 +107,10 @@ func (s *Server) Start() error {
 	s.mutex.Lock()
 	defer s.mutex.Unlock()
 
+	if s.clientRunning {
+		return nil
+	}
+
 	state := internal.CtxGetState(s.rootCtx)
 
 	if err := handlePanicLog(); err != nil {
@@ -175,12 +180,10 @@ func (s *Server) Start() error {
 		return nil
 	}
 
-	if s.clientRunning {
-		return nil
-	}
 	s.clientRunning = true
-	s.clientRunningChan = make(chan struct{}, 1)
-	go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
+	s.clientRunningChan = make(chan struct{})
+	s.clientGiveUpChan = make(chan struct{})
+	go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
 	return nil
 }
 
@@ -211,7 +214,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
 // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
 // mechanism to keep the client connected even when the connection is lost.
 // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
-func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
+func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
 	defer func() {
 		s.mutex.Lock()
 		s.clientRunning = false
@@ -261,6 +264,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
 	if err := backoff.Retry(runOperation, backOff); err != nil {
 		log.Errorf("operation failed: %v", err)
 	}
+
+	if giveUpChan != nil {
+		close(giveUpChan)
+	}
 }
 
 // loginAttempt attempts to login using the provided information. it returns a status in case something fails
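Closing `giveUpChan` (rather than sending a value on it) acts as a broadcast: every goroutine selecting on the channel is released at once, which is what lets both Up and Status observe the same permanent-failure event. A self-contained sketch of the pattern:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	giveUp := make(chan struct{})
	var wg sync.WaitGroup

	// Several independent waiters, e.g. an Up call and a Status call.
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-giveUp // a closed channel unblocks every receiver
			fmt.Printf("waiter %d: client gave up\n", id)
		}(i)
	}

	close(giveUp) // one close releases all waiters; a single send would release only one
	wg.Wait()
}
```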
@@ -379,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
 	if s.actCancel != nil {
 		s.actCancel()
 	}
-	ctx, cancel := context.WithCancel(s.rootCtx)
+	ctx, cancel := context.WithCancel(callerCtx)
 
 	md, ok := metadata.FromIncomingContext(callerCtx)
 	if ok {
@@ -389,11 +396,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
 	s.actCancel = cancel
 	s.mutex.Unlock()
 
-	if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil {
+	if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil {
 		log.Warnf(errRestoreResidualState, err)
 	}
 
-	state := internal.CtxGetState(ctx)
+	state := internal.CtxGetState(s.rootCtx)
 	defer func() {
 		status, err := state.Status()
 		if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) {
@@ -606,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
 // Up starts engine work in the daemon.
 func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
 	s.mutex.Lock()
+	if s.clientRunning {
+		state := internal.CtxGetState(s.rootCtx)
+		status, err := state.Status()
+		if err != nil {
+			s.mutex.Unlock()
+			return nil, err
+		}
+		if status == internal.StatusNeedsLogin {
+			s.actCancel()
+		}
+		s.mutex.Unlock()
+
+		return s.waitForUp(callerCtx)
+	}
 	defer s.mutex.Unlock()
 
 	if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
@@ -621,16 +642,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
 	if err != nil {
 		return nil, err
 	}
 
 	if status != internal.StatusIdle {
 		return nil, fmt.Errorf("up already in progress: current status %s", status)
 	}
 
-	// it should be nil here, but .
+	// it should be nil here, but in case it isn't we cancel it.
 	if s.actCancel != nil {
 		s.actCancel()
 	}
 	ctx, cancel := context.WithCancel(s.rootCtx)
 
 	md, ok := metadata.FromIncomingContext(callerCtx)
 	if ok {
 		ctx = metadata.NewOutgoingContext(ctx, md)
@@ -673,26 +694,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
 	s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
 	s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
 
+	s.clientRunning = true
+	s.clientRunningChan = make(chan struct{})
+	s.clientGiveUpChan = make(chan struct{})
+	go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
+
+	return s.waitForUp(callerCtx)
+}
+
+// todo: handle potential race conditions
+func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
 	timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
 	defer cancel()
 
-	if !s.clientRunning {
-		s.clientRunning = true
-		s.clientRunningChan = make(chan struct{}, 1)
-		go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
-	}
-	for {
-		select {
-		case <-s.clientRunningChan:
-			s.isSessionActive.Store(true)
-			return &proto.UpResponse{}, nil
-		case <-callerCtx.Done():
-			log.Debug("context done, stopping the wait for engine to become ready")
-			return nil, callerCtx.Err()
-		case <-timeoutCtx.Done():
-			log.Debug("up is timed out, stopping the wait for engine to become ready")
-			return nil, timeoutCtx.Err()
-		}
-	}
+	select {
+	case <-s.clientGiveUpChan:
+		return nil, fmt.Errorf("client gave up to connect")
+	case <-s.clientRunningChan:
+		s.isSessionActive.Store(true)
+		return &proto.UpResponse{}, nil
+	case <-callerCtx.Done():
+		log.Debug("context done, stopping the wait for engine to become ready")
+		return nil, callerCtx.Err()
+	case <-timeoutCtx.Done():
+		log.Debug("up is timed out, stopping the wait for engine to become ready")
+		return nil, timeoutCtx.Err()
+	}
 }
@@ -966,12 +992,46 @@ func (s *Server) Status(
 	ctx context.Context,
 	msg *proto.StatusRequest,
 ) (*proto.StatusResponse, error) {
 	if ctx.Err() != nil {
 		return nil, ctx.Err()
 	}
 
 	s.mutex.Lock()
-	defer s.mutex.Unlock()
+	clientRunning := s.clientRunning
+	s.mutex.Unlock()
+
+	if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
+		state := internal.CtxGetState(s.rootCtx)
+		status, err := state.Status()
+		if err != nil {
+			return nil, err
+		}
+
+		if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
+			s.actCancel()
+		}
+
+		ticker := time.NewTicker(1 * time.Second)
+		defer ticker.Stop()
+	loop:
+		for {
+			select {
+			case <-s.clientGiveUpChan:
+				ticker.Stop()
+				break loop
+			case <-s.clientRunningChan:
+				ticker.Stop()
+				break loop
+			case <-ticker.C:
+				status, err := state.Status()
+				if err != nil {
+					continue
+				}
+				if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
+					s.actCancel()
+				}
+				continue
+			case <-ctx.Done():
+				return nil, ctx.Err()
+			}
+		}
+	}
+
 	status, err := internal.CtxGetState(s.rootCtx).Status()
 	if err != nil {
@@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
 	t.Setenv(maxRetryTimeVar, "5s")
 	t.Setenv(retryMultiplierVar, "1")
 
-	s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
+	s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
 	if counter < 3 {
 		t.Fatalf("expected counter > 2, got %d", counter)
 	}
@@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) {
 
 	profName := "default"
 
+	u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
+	require.NoError(t, err)
+
 	ic := profilemanager.ConfigInput{
-		ConfigPath: filepath.Join(tempDir, profName+".json"),
+		ConfigPath:    filepath.Join(tempDir, profName+".json"),
+		ManagementURL: u.String(),
 	}
 
 	_, err = profilemanager.UpdateOrCreateConfig(ic)
@@ -153,16 +157,9 @@
 	}
 
 	s := New(ctx, "console", "", false, false)
 
 	err = s.Start()
 	require.NoError(t, err)
 
-	u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
-	require.NoError(t, err)
-	s.config = &profilemanager.Config{
-		ManagementURL: u,
-	}
-
 	upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
 	defer cancel()
@@ -171,6 +168,7 @@
 		Username: &currUser.Username,
 	}
 	_, err = s.Up(upCtx, upReq)
+	log.Errorf("error from Up: %v", err)
 
 	assert.Contains(t, err.Error(), "context deadline exceeded")
 }
@@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
 
 	operation := func() error {
 		var err error
-		conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
+		conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
 		if err != nil {
 			log.Printf("createConnection error: %v", err)
 			return err
@@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
 
 	operation := func() error {
 		var err error
-		conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
+		conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
 		if err != nil {
 			log.Printf("createConnection error: %v", err)
 			return err