mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-08 02:14:16 -04:00
Compare commits
17 Commits
add-defaul
...
feature/de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2d7121695 | ||
|
|
3288c4414f | ||
|
|
f3b0439211 | ||
|
|
ae801d77fb | ||
|
|
bf83549db2 | ||
|
|
804a3871fe | ||
|
|
64d1edce27 | ||
|
|
bf0698e5aa | ||
|
|
fc15625963 | ||
|
|
a75dde33b9 | ||
|
|
bb46e438aa | ||
|
|
11ba253ffb | ||
|
|
14fe7c29cb | ||
|
|
158f3aceff | ||
|
|
bfa776c155 | ||
|
|
885b5c68ad | ||
|
|
b1ebac795d |
@@ -4,12 +4,15 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -17,6 +20,9 @@ import (
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||
|
||||
// Backoff returns a backoff configuration for gRPC calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||
conn.Connect()
|
||||
|
||||
state := conn.GetState()
|
||||
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||
if !conn.WaitForStateChange(ctx, state) {
|
||||
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||
}
|
||||
state = conn.GetState()
|
||||
}
|
||||
|
||||
if state == connectivity.Shutdown {
|
||||
return ErrConnectionShutdown
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
}))
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(tlsEnabled, component),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("DialContext error: %v", err)
|
||||
return nil, fmt.Errorf("new client: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -34,7 +35,6 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
engine := c.engine
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if engine != nil && engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.engine = nil
|
||||
}
|
||||
c.engineMutex.Unlock()
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
engine := c.Engine()
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
}
|
||||
c.engineMutex.Lock()
|
||||
defer c.engineMutex.Unlock()
|
||||
|
||||
if c.engine == nil {
|
||||
return nil
|
||||
}
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
shutdownWg sync.WaitGroup
|
||||
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
||||
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
||||
disableSys bool
|
||||
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.ctxCancel()
|
||||
s.shutdownWg.Wait()
|
||||
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
defer s.shutdownWg.Done()
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
@@ -200,8 +200,10 @@ type Engine struct {
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
// dns forwarder port
|
||||
dnsFwdPort uint16
|
||||
@@ -326,10 +328,6 @@ func (e *Engine) Stop() error {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.Close() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
e.close()
|
||||
|
||||
// stop flow manager after wg interface is gone
|
||||
@@ -337,8 +335,6 @@ func (e *Engine) Stop() error {
|
||||
e.flowManager.Close()
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -349,12 +345,52 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
func (e *Engine) calculateShutdownTimeout() time.Duration {
|
||||
peerCount := len(e.peerStore.PeersPubKey())
|
||||
|
||||
baseTimeout := 10 * time.Second
|
||||
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
|
||||
timeout := baseTimeout + perPeerTimeout
|
||||
|
||||
maxTimeout := 30 * time.Second
|
||||
if timeout > maxTimeout {
|
||||
timeout = maxTimeout
|
||||
}
|
||||
|
||||
return timeout
|
||||
}
|
||||
|
||||
// waitWithContext blocks until wg's counter reaches zero or ctx is done,
// whichever happens first.
//
// It returns nil when the WaitGroup finished and ctx.Err() when ctx ended
// first. In the latter case the helper goroutine keeps waiting on wg in the
// background until wg is eventually released.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
	finished := make(chan struct{})

	go func() {
		wg.Wait()
		close(finished)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-finished:
		return nil
	}
}
|
||||
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
@@ -484,14 +520,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
e.shutdownWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
defer e.shutdownWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
e.triggerClientRestart()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
@@ -892,7 +928,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
if err != nil {
|
||||
@@ -950,7 +988,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
@@ -1368,7 +1408,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1724,8 +1766,10 @@ func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
// triggerClientRestart triggers a full client restart by cancelling the client context.
|
||||
// Note: This does NOT just restart the engine - it cancels the entire client context,
|
||||
// which causes the connect client's retry loop to create a completely new engine.
|
||||
func (e *Engine) triggerClientRestart() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -1747,7 +1791,9 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Infof("network monitor stopped")
|
||||
@@ -1757,8 +1803,8 @@ func (e *Engine) startNetworkMonitor() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
log.Infof("Network monitor: detected network change, triggering client restart")
|
||||
e.triggerClientRestart()
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
// Manager handles netflow tracking and logging
|
||||
type Manager struct {
|
||||
mux sync.Mutex
|
||||
shutdownWg sync.WaitGroup
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
go m.receiveACKs(ctx, flowClient)
|
||||
go m.startSender(ctx)
|
||||
m.shutdownWg.Add(2)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
m.receiveACKs(ctx, flowClient)
|
||||
}()
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
m.startSender(ctx)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
// Close cleans up all resources
|
||||
func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if err := m.disableFlow(); err != nil {
|
||||
log.Warnf("failed to disable flow manager: %v", err)
|
||||
}
|
||||
m.mux.Unlock()
|
||||
|
||||
m.shutdownWg.Wait()
|
||||
}
|
||||
|
||||
// GetLogger returns the flow logger
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package networkmonitor
|
||||
|
||||
|
||||
344
client/internal/networkmonitor/check_change_darwin.go
Normal file
344
client/internal/networkmonitor/check_change_darwin.go
Normal file
@@ -0,0 +1,344 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// todo: refactor to not use static functions
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
routeChanged := make(chan struct{})
|
||||
go func() {
|
||||
routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||
close(routeChanged)
|
||||
}()
|
||||
|
||||
wakeUp := make(chan struct{})
|
||||
go func() {
|
||||
wakeUpListen(ctx)
|
||||
close(wakeUp)
|
||||
}()
|
||||
|
||||
gatewayChanged := make(chan string)
|
||||
go func() {
|
||||
gatewayPoll(ctx, nexthopv4, nexthopv6, gatewayChanged)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-routeChanged:
|
||||
log.Infof("route change detected via routing socket")
|
||||
return nil
|
||||
case <-wakeUp:
|
||||
log.Infof("wakeup detected via sleep hash change")
|
||||
return nil
|
||||
case reason := <-gatewayChanged:
|
||||
log.Infof("gateway change detected via polling: %s", reason)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func routeCheck(ctx context.Context, fd int, nexthopv4 systemops.Nexthop, nexthopv6 systemops.Nexthop) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
|
||||
func wakeUpListen(ctx context.Context) {
|
||||
log.Infof("start to watch for system wakeups")
|
||||
var (
|
||||
initialHash uint32
|
||||
err error
|
||||
)
|
||||
|
||||
// Keep retrying until initial sysctl succeeds or context is canceled
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
default:
|
||||
initialHash, err = readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to detect initial sleep time: %v", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
case <-time.After(3 * time.Second):
|
||||
continue
|
||||
}
|
||||
}
|
||||
log.Infof("initial wakeup hash: %d", initialHash)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
lastCheck := time.Now()
|
||||
const maxTickerDrift = 1 * time.Minute
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("context canceled, stopping wakeUpListen")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(lastCheck)
|
||||
|
||||
// If more time passed than expected, system likely slept (informational only)
|
||||
if elapsed > maxTickerDrift {
|
||||
upOut, err := exec.Command("uptime").Output()
|
||||
if err != nil {
|
||||
log.Errorf("failed to run uptime command: %v", err)
|
||||
upOut = []byte("unknown")
|
||||
}
|
||||
log.Infof("Time drift detected (potential wakeup): expected ~5s, actual %s, uptime: %s", elapsed, upOut)
|
||||
|
||||
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
if errV4 == nil {
|
||||
log.Infof("Current IPv4 default gateway: %s via %s", currentV4.IP, currentV4.Intf.Name)
|
||||
} else {
|
||||
log.Debugf("No IPv4 default gateway: %v", errV4)
|
||||
}
|
||||
if errV6 == nil {
|
||||
log.Infof("Current IPv6 default gateway: %s via %s", currentV6.IP, currentV6.Intf.Name)
|
||||
} else {
|
||||
log.Debugf("No IPv6 default gateway: %v", errV6)
|
||||
}
|
||||
}
|
||||
|
||||
newHash, err := readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to read sleep time hash: %v", err)
|
||||
lastCheck = now
|
||||
continue
|
||||
}
|
||||
|
||||
if newHash == initialHash {
|
||||
log.Debugf("no wakeup detected (hash unchanged: %d, time drift: %s)", initialHash, elapsed)
|
||||
lastCheck = now
|
||||
continue
|
||||
}
|
||||
|
||||
upOut, err := exec.Command("uptime").Output()
|
||||
if err != nil {
|
||||
log.Errorf("failed to run uptime command: %v", err)
|
||||
upOut = []byte("unknown")
|
||||
}
|
||||
log.Infof("Wakeup detected via hash change: %d -> %d, uptime: %s", initialHash, newHash, upOut)
|
||||
|
||||
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
if errV4 == nil {
|
||||
log.Infof("Current IPv4 default gateway after wakeup: %s via %s", currentV4.IP, currentV4.Intf.Name)
|
||||
} else {
|
||||
log.Debugf("No IPv4 default gateway after wakeup: %v", errV4)
|
||||
}
|
||||
if errV6 == nil {
|
||||
log.Infof("Current IPv6 default gateway after wakeup: %s via %s", currentV6.IP, currentV6.Intf.Name)
|
||||
} else {
|
||||
log.Debugf("No IPv6 default gateway after wakeup: %v", errV6)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readSleepTimeHash() (uint32, error) {
|
||||
cmd := exec.Command("sysctl", "kern.sleeptime")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to run sysctl: %w", err)
|
||||
}
|
||||
|
||||
h, err := hash(out)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to compute hash: %w", err)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// hash returns the 32-bit FNV-1a checksum of data. The error result only
// relays a failure reported by the hasher's Write.
func hash(data []byte) (uint32, error) {
	digest := fnv.New32a()

	_, err := digest.Write(data)
	if err != nil {
		return 0, err
	}

	return digest.Sum32(), nil
}
|
||||
|
||||
// gatewayPoll polls the default gateway every 5 seconds to detect changes that might be missed by routing socket or wake-up detection.
|
||||
func gatewayPoll(ctx context.Context, initialV4, initialV6 systemops.Nexthop, changed chan<- string) {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
log.Infof("Gateway polling started - initial v4: %s via %v, v6: %s via %v",
|
||||
initialV4.IP, initialV4.Intf, initialV6.IP, initialV6.Intf)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug("context canceled, stopping gateway polling")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
var reason string
|
||||
|
||||
if errV4 == nil && initialV4.IP.IsValid() {
|
||||
if currentV4.IP.Compare(initialV4.IP) != 0 {
|
||||
reason = fmt.Sprintf("IPv4 gateway changed from %s to %s", initialV4.IP, currentV4.IP)
|
||||
log.Infof("Gateway poll detected change: %s", reason)
|
||||
changed <- reason
|
||||
return
|
||||
}
|
||||
if initialV4.Intf != nil && currentV4.Intf != nil && currentV4.Intf.Name != initialV4.Intf.Name {
|
||||
reason = fmt.Sprintf("IPv4 interface changed from %s to %s", initialV4.Intf.Name, currentV4.Intf.Name)
|
||||
log.Infof("Gateway poll detected change: %s", reason)
|
||||
changed <- reason
|
||||
return
|
||||
}
|
||||
} else if errV4 == nil && !initialV4.IP.IsValid() {
|
||||
reason = "IPv4 gateway appeared"
|
||||
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV4.IP)
|
||||
changed <- reason
|
||||
return
|
||||
} else if errV4 != nil && initialV4.IP.IsValid() {
|
||||
reason = "IPv4 gateway disappeared"
|
||||
log.Infof("Gateway poll detected change: %s", reason)
|
||||
changed <- reason
|
||||
return
|
||||
}
|
||||
|
||||
if errV6 == nil && initialV6.IP.IsValid() {
|
||||
if currentV6.IP.Compare(initialV6.IP) != 0 {
|
||||
reason = fmt.Sprintf("IPv6 gateway changed from %s to %s", initialV6.IP, currentV6.IP)
|
||||
log.Infof("Gateway poll detected change: %s", reason)
|
||||
changed <- reason
|
||||
return
|
||||
}
|
||||
if initialV6.Intf != nil && currentV6.Intf != nil && currentV6.Intf.Name != initialV6.Intf.Name {
|
||||
reason = fmt.Sprintf("IPv6 interface changed from %s to %s", initialV6.Intf.Name, currentV6.Intf.Name)
|
||||
log.Infof("Gateway poll detected change: %s", reason)
|
||||
changed <- reason
|
||||
return
|
||||
}
|
||||
} else if errV6 == nil && !initialV6.IP.IsValid() {
|
||||
reason = "IPv6 gateway appeared"
|
||||
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV6.IP)
|
||||
changed <- reason
|
||||
return
|
||||
} else if errV6 != nil && initialV6.IP.IsValid() {
|
||||
reason = "IPv6 gateway disappeared"
|
||||
log.Infof("Gateway poll detected change: %s", reason)
|
||||
changed <- reason
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Gateway poll: no change detected")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||
event := make(chan struct{}, 1)
|
||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||
|
||||
log.Infof("start watching for network changes")
|
||||
// debounce changes
|
||||
timer := time.NewTimer(0)
|
||||
timer.Stop()
|
||||
|
||||
@@ -19,11 +19,11 @@ type SRWatcher struct {
|
||||
signalClient chNotifier
|
||||
relayManager chNotifier
|
||||
|
||||
listeners map[chan struct{}]struct{}
|
||||
mu sync.Mutex
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig ice.Config
|
||||
|
||||
listeners map[chan struct{}]struct{}
|
||||
mu sync.Mutex
|
||||
shutdownWg sync.WaitGroup
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig ice.Config
|
||||
cancelIceMonitor context.CancelFunc
|
||||
}
|
||||
|
||||
@@ -52,7 +52,11 @@ func (w *SRWatcher) Start() {
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
w.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer w.shutdownWg.Done()
|
||||
iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}()
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
@@ -60,14 +64,16 @@ func (w *SRWatcher) Start() {
|
||||
|
||||
func (w *SRWatcher) Close() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.cancelIceMonitor == nil {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.cancelIceMonitor()
|
||||
w.signalClient.SetOnReconnectedListener(nil)
|
||||
w.relayManager.SetOnReconnectedListener(nil)
|
||||
w.mu.Unlock()
|
||||
|
||||
w.shutdownWg.Wait()
|
||||
}
|
||||
|
||||
func (w *SRWatcher) NewListener() chan struct{} {
|
||||
|
||||
@@ -78,6 +78,7 @@ type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
shutdownWg sync.WaitGroup
|
||||
clientNetworks map[route.HAUniqueID]*client.Watcher
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter *server.Router
|
||||
@@ -273,6 +274,7 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
m.stop()
|
||||
m.shutdownWg.Wait()
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.CleanUp()
|
||||
}
|
||||
@@ -474,7 +476,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
}
|
||||
clientNetworkWatcher := client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.Start()
|
||||
m.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
clientNetworkWatcher.Start()
|
||||
}()
|
||||
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
||||
}
|
||||
|
||||
@@ -516,7 +522,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
}
|
||||
clientNetworkWatcher = client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.Start()
|
||||
m.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
clientNetworkWatcher.Start()
|
||||
}()
|
||||
}
|
||||
update := client.RoutesUpdate{
|
||||
UpdateSerial: updateSerial,
|
||||
|
||||
@@ -105,11 +105,31 @@ func (r *SysOps) FlushMarkedRoutes() error {
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
if prefix.IsSingleIP() {
|
||||
log.Debugf("Adding single IP route: %s via %s", prefix, formatNexthop(nexthop))
|
||||
}
|
||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
|
||||
if prefix.IsSingleIP() {
|
||||
log.Debugf("Removing single IP route: %s via %s", prefix, formatNexthop(nexthop))
|
||||
}
|
||||
|
||||
if err := r.routeSocket(unix.RTM_DELETE, prefix, nexthop); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if prefix.IsSingleIP() {
|
||||
log.Debugf("Route removal completed for %s, verifying...", prefix)
|
||||
if exists := r.verifyRouteRemoved(prefix); exists {
|
||||
log.Warnf("Route %s still exists in routing table after removal", prefix)
|
||||
} else {
|
||||
log.Debugf("Verified route %s successfully removed", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
@@ -276,3 +296,51 @@ func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
|
||||
|
||||
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
||||
}
|
||||
|
||||
// formatNexthop returns a string representation of the nexthop for logging.
|
||||
func formatNexthop(nexthop Nexthop) string {
|
||||
if nexthop.IP.IsValid() {
|
||||
return nexthop.IP.String()
|
||||
}
|
||||
if nexthop.Intf != nil {
|
||||
return nexthop.Intf.Name
|
||||
}
|
||||
return "direct"
|
||||
}
|
||||
|
||||
// verifyRouteRemoved checks if a route still exists in the routing table.
|
||||
func (r *SysOps) verifyRouteRemoved(prefix netip.Prefix) bool {
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
log.Debugf("Failed to fetch RIB for route verification: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
log.Debugf("Failed to parse RIB for route verification: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
rtMsg, ok := msg.(*route.RouteMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
routeInfo, err := MsgToRoute(rtMsg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if routeInfo.Dst == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -17,8 +17,7 @@ type Conn struct {
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *Conn) Close() error {
|
||||
return closeConn(c.ID, c.Conn)
|
||||
}
|
||||
@@ -29,7 +28,7 @@ type TCPConn struct {
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *TCPConn) Close() error {
|
||||
return closeConn(c.ID, c.TCPConn)
|
||||
}
|
||||
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
|
||||
// closeConn is a helper function to close connections and execute close hooks.
|
||||
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||
err := conn.Close()
|
||||
cleanupConnID(id)
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanupConnID executes close hooks for a connection ID.
|
||||
func cleanupConnID(id hooks.ConnectionID) {
|
||||
closeHooks := hooks.GetCloseHooks()
|
||||
for _, hook := range closeHooks {
|
||||
if err := hook(id); err != nil {
|
||||
log.Errorf("Error executing close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
|
||||
}
|
||||
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
cleanupConnID(connID)
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
|
||||
|
||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||
return fmt.Errorf("resolve address %s: %w", address, err)
|
||||
}
|
||||
|
||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||
|
||||
@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
return c.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *PacketConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.PacketConn)
|
||||
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
return c.UDPConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *UDPConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
|
||||
@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("create connection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("create connection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user