mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[client] Migrate deprecated grpc client code (#4687)
This commit is contained in:
@@ -4,12 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -17,6 +20,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||||
|
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||||
|
|
||||||
// Backoff returns a backoff configuration for gRPC calls
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||||
|
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||||
|
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||||
|
conn.Connect()
|
||||||
|
|
||||||
|
state := conn.GetState()
|
||||||
|
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||||
|
if !conn.WaitForStateChange(ctx, state) {
|
||||||
|
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||||
|
}
|
||||||
|
state = conn.GetState()
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == connectivity.Shutdown {
|
||||||
|
return ErrConnectionShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
conn, err := grpc.NewClient(
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
|
||||||
connCtx,
|
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("DialContext error: %v", err)
|
return nil, fmt.Errorf("new client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" {
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
|||||||
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to dial: %s", err)
|
|
||||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ type Conn struct {
|
|||||||
ID hooks.ConnectionID
|
ID hooks.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return closeConn(c.ID, c.Conn)
|
return closeConn(c.ID, c.Conn)
|
||||||
}
|
}
|
||||||
@@ -29,7 +28,7 @@ type TCPConn struct {
|
|||||||
ID hooks.ConnectionID
|
ID hooks.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *TCPConn) Close() error {
|
func (c *TCPConn) Close() error {
|
||||||
return closeConn(c.ID, c.TCPConn)
|
return closeConn(c.ID, c.TCPConn)
|
||||||
}
|
}
|
||||||
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
|
|||||||
// closeConn is a helper function to close connections and execute close hooks.
|
// closeConn is a helper function to close connections and execute close hooks.
|
||||||
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
|
cleanupConnID(id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupConnID executes close hooks for a connection ID.
|
||||||
|
func cleanupConnID(id hooks.ConnectionID) {
|
||||||
closeHooks := hooks.GetCloseHooks()
|
closeHooks := hooks.GetCloseHooks()
|
||||||
for _, hook := range closeHooks {
|
for _, hook := range closeHooks {
|
||||||
if err := hook(id); err != nil {
|
if err := hook(id); err != nil {
|
||||||
log.Errorf("Error executing close hook: %v", err)
|
log.Errorf("Error executing close hook: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
|
|||||||
}
|
}
|
||||||
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
log.Errorf("failed to close connection: %v", err)
|
log.Errorf("failed to close connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
|||||||
|
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cleanupConnID(connID)
|
||||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
|
|||||||
|
|
||||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
return fmt.Errorf("resolve address %s: %w", address, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|||||||
return c.PacketConn.WriteTo(b, addr)
|
return c.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *PacketConn) Close() error {
|
func (c *PacketConn) Close() error {
|
||||||
defer c.seenAddrs.Clear()
|
defer c.seenAddrs.Clear()
|
||||||
return closeConn(c.ID, c.PacketConn)
|
return closeConn(c.ID, c.PacketConn)
|
||||||
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|||||||
return c.UDPConn.WriteTo(b, addr)
|
return c.UDPConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *UDPConn) Close() error {
|
func (c *UDPConn) Close() error {
|
||||||
defer c.seenAddrs.Clear()
|
defer c.seenAddrs.Clear()
|
||||||
return closeConn(c.ID, c.UDPConn)
|
return closeConn(c.ID, c.UDPConn)
|
||||||
|
|||||||
@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
|||||||
var err error
|
var err error
|
||||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
|
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("createConnection error: %v", err)
|
return fmt.Errorf("create connection: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
|||||||
var err error
|
var err error
|
||||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
|
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("createConnection error: %v", err)
|
return fmt.Errorf("create connection: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user