diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 6aff53b92..7763f2417 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" "runtime" "time" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -17,6 +20,9 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) +// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready +var ErrConnectionShutdown = errors.New("connection shutdown before ready") + // Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// waitForConnectionReady blocks until the connection becomes ready or fails. +// Returns an error if the connection times out, is cancelled, or enters shutdown state. +func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { + conn.Connect() + + state := conn.GetState() + for state != connectivity.Ready && state != connectivity.Shutdown { + if !conn.WaitForStateChange(ctx, state) { + return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) + } + state = conn.GetState() + } + + if state == connectivity.Shutdown { + return ErrConnectionShutdown + } + + return nil +} + // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - conn, err := grpc.DialContext( - connCtx, + conn, err := grpc.NewClient( addr, transportOption, WithCustomDialer(tlsEnabled, component), - grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - log.Printf("DialContext error: %v", err) + return nil, fmt.Errorf("new client: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForConnectionReady(ctx, conn); err != nil { + _ = conn.Close() return nil, err } diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index 96f347c64..479575996 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { +func WithCustomDialer(_ bool, _ string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { - log.Errorf("Failed to dial: %s", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil diff --git a/client/net/conn.go b/client/net/conn.go index 918e7f628..bf54c792d 100644 --- a/client/net/conn.go +++ b/client/net/conn.go @@ -17,8 +17,7 @@ type Conn struct { ID hooks.ConnectionID } -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection. func (c *Conn) Close() error { return closeConn(c.ID, c.Conn) } @@ -29,7 +28,7 @@ type TCPConn struct { ID hooks.ConnectionID } -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. func (c *TCPConn) Close() error { return closeConn(c.ID, c.TCPConn) } @@ -37,13 +36,16 @@ func (c *TCPConn) Close() error { // closeConn is a helper function to close connections and execute close hooks. func closeConn(id hooks.ConnectionID, conn io.Closer) error { err := conn.Close() + cleanupConnID(id) + return err +} +// cleanupConnID executes close hooks for a connection ID. +func cleanupConnID(id hooks.ConnectionID) { closeHooks := hooks.GetCloseHooks() for _, hook := range closeHooks { if err := hook(id); err != nil { log.Errorf("Error executing close hook: %v", err) } } - - return err } diff --git a/client/net/dial.go b/client/net/dial.go index 041a00e5d..17c9ff98a 100644 --- a/client/net/dial.go +++ b/client/net/dial.go @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro } return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil } - if err := conn.Close(); err != nil { log.Errorf("failed to close connection: %v", err) } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go index 2e1eb53d8..1e275013f 100644 --- a/client/net/dialer_dial.go +++ b/client/net/dialer_dial.go @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { + cleanupConnID(connID) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str ips, err := resolver.LookupIPAddr(ctx, host) if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) + return fmt.Errorf("resolve address %s: %w", address, err) } log.Debugf("Dialer resolved IPs for %s: %v", address, ips) diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go index 0bb5ad67d..a150172b4 100644 --- a/client/net/listener_listen.go +++ b/client/net/listener_listen.go @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection. func (c *PacketConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.PacketConn) @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.UDPConn.WriteTo(b, addr) } -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection. func (c *UDPConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.UDPConn) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 076f2532b..520a83e36 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 31f3372c0..5368b57a2 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil }