[client] Check the client status in the earlier phase (#4509)

This PR improves the NetBird client's status checking mechanism by implementing earlier detection of client state changes and better handling of connection lifecycle management. The key improvements focus on:

  • Enhanced status detection - Added waitForReady option to StatusRequest for improved client status handling
  • Better connection management - Improved context handling for signal and management gRPC connections• Reduced connection timeouts - Increased gRPC dial timeout from 3 to 10 seconds for better reliability
  • Cleaner error handling - Enhanced error propagation and context cancellation in retry loops

  Key Changes

  Core Status Improvements:
  - Added waitForReady optional field to StatusRequest proto (daemon.proto:190)
  - Enhanced status checking logic to detect client state changes earlier in the connection process
  - Improved handling of client permanent exit scenarios from retry loops

  Connection & Context Management:
  - Fixed context cancellation in management and signal client retry mechanisms
  - Added proper context propagation for Login operations
  - Enhanced gRPC connection handling with better timeout management

  Error Handling & Cleanup:
  - Moved feedback channels to upper layers for better separation of concerns
  - Improved error handling patterns throughout the client server implementation
  - Fixed synchronization issues and removed debug logging
This commit is contained in:
Zoltan Papp
2025-09-20 22:14:01 +02:00
committed by GitHub
parent e254b4cde5
commit 998fb30e1e
10 changed files with 128 additions and 54 deletions

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
return grpc.DialContext(

View File

@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}

View File

@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {

View File

@@ -58,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
certPool, err := x509.SystemCertPool()
@@ -72,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v5.29.3
// protoc v6.32.1
// source: daemon.proto
package proto
@@ -794,8 +794,10 @@ type StatusRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StatusRequest) Reset() {
@@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
return false
}
func (x *StatusRequest) GetWaitForReady() bool {
if x != nil && x.WaitForReady != nil {
return *x.WaitForReady
}
return false
}
type StatusResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// status of the server.
@@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" +
"\f_profileNameB\v\n" +
"\t_username\"\f\n" +
"\n" +
"UpResponse\"g\n" +
"UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" +
"\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" +
"\r_waitForReady\"\x82\x01\n" +
"\x0eStatusResponse\x12\x16\n" +
"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
"\n" +
@@ -5231,6 +5242,7 @@ func file_daemon_proto_init() {
}
file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
file_daemon_proto_msgTypes[26].OneofWrappers = []any{
(*PortInfo_Port)(nil),
(*PortInfo_Range_)(nil),

View File

@@ -186,6 +186,8 @@ message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
optional bool waitForReady = 3;
}
message StatusResponse{

View File

@@ -67,6 +67,7 @@ type Server struct {
proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
clientGiveUpChan chan struct{}
connectClient *internal.ConnectClient
@@ -106,6 +107,10 @@ func (s *Server) Start() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.clientRunning {
return nil
}
state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil {
@@ -175,12 +180,10 @@ func (s *Server) Start() error {
return nil
}
if s.clientRunning {
return nil
}
s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
@@ -211,7 +214,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
@@ -261,6 +264,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err)
}
if giveUpChan != nil {
close(giveUpChan)
}
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -379,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
if s.actCancel != nil {
s.actCancel()
}
ctx, cancel := context.WithCancel(s.rootCtx)
ctx, cancel := context.WithCancel(callerCtx)
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
@@ -389,11 +396,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.actCancel = cancel
s.mutex.Unlock()
if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil {
if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err)
}
state := internal.CtxGetState(ctx)
state := internal.CtxGetState(s.rootCtx)
defer func() {
status, err := state.Status()
if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) {
@@ -606,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
// Up starts engine work in the daemon.
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
s.mutex.Lock()
if s.clientRunning {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return nil, err
}
if status == internal.StatusNeedsLogin {
s.actCancel()
}
s.mutex.Unlock()
return s.waitForUp(callerCtx)
}
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
@@ -621,16 +642,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
if err != nil {
return nil, err
}
if status != internal.StatusIdle {
return nil, fmt.Errorf("up already in progress: current status %s", status)
}
// it should be nil here, but .
// it should be nil here, but in case it isn't we cancel it.
if s.actCancel != nil {
s.actCancel()
}
ctx, cancel := context.WithCancel(s.rootCtx)
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
ctx = metadata.NewOutgoingContext(ctx, md)
@@ -673,26 +694,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
// todo: handle potential race conditions
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
if !s.clientRunning {
s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
}
for {
select {
case <-s.clientRunningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready")
return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
}
select {
case <-s.clientGiveUpChan:
return nil, fmt.Errorf("client gave up to connect")
case <-s.clientRunningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready")
return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
}
}
@@ -966,12 +992,46 @@ func (s *Server) Status(
ctx context.Context,
msg *proto.StatusRequest,
) (*proto.StatusResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
defer s.mutex.Unlock()
clientRunning := s.clientRunning
s.mutex.Unlock()
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
return nil, err
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-s.clientGiveUpChan:
ticker.Stop()
break loop
case <-s.clientRunningChan:
ticker.Stop()
break loop
case <-ticker.C:
status, err := state.Status()
if err != nil {
continue
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
continue
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
status, err := internal.CtxGetState(s.rootCtx).Status()
if err != nil {

View File

@@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
@@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) {
profName := "default"
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
ic := profilemanager.ConfigInput{
ConfigPath: filepath.Join(tempDir, profName+".json"),
ConfigPath: filepath.Join(tempDir, profName+".json"),
ManagementURL: u.String(),
}
_, err = profilemanager.UpdateOrCreateConfig(ic)
@@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) {
}
s := New(ctx, "console", "", false, false)
err = s.Start()
require.NoError(t, err)
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
s.config = &profilemanager.Config{
ManagementURL: u,
}
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
@@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) {
Username: &currUser.Username,
}
_, err = s.Up(upCtx, upReq)
log.Errorf("error from Up: %v", err)
assert.Contains(t, err.Error(), "context deadline exceeded")
}

View File

@@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err

View File

@@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err