Compare commits

...

17 Commits

Author SHA1 Message Date
Viktor Liu
b2d7121695 Add logging around routes 2025-10-28 14:20:15 +01:00
Viktor Liu
3288c4414f Merge branch 'shutdown-block' into feature/detect-mac-wakeup 2025-10-27 23:20:25 +01:00
Viktor Liu
f3b0439211 Merge branch 'main' into feature/detect-mac-wakeup 2025-10-27 23:20:19 +01:00
Viktor Liu
ae801d77fb Block on all subsystems on shutdown 2025-10-27 23:15:47 +01:00
Viktor Liu
bf83549db2 Merge branch 'main' into feature/detect-mac-wakeup 2025-10-23 17:09:06 +02:00
Viktor Liu
804a3871fe Merge branch 'fix-deprecated-grpc' into feature/detect-mac-wakeup 2025-10-23 17:08:12 +02:00
Viktor Liu
64d1edce27 Merge branch 'bsd-route-cleanup' into feature/detect-mac-wakeup 2025-10-23 17:07:22 +02:00
Viktor Liu
bf0698e5aa Clean up bsd routes independently of the state file 2025-10-23 16:42:23 +02:00
Viktor Liu
fc15625963 Clean up failed conn hooks 2025-10-23 15:35:45 +02:00
Viktor Liu
a75dde33b9 Fix happy eyeballs for grpc 2025-10-23 13:14:37 +02:00
Viktor Liu
bb46e438aa Add gw change polling and time drift detection (informational only) 2025-10-21 12:19:35 +02:00
Zoltán Papp
11ba253ffb Merge branch 'main' into feature/detect-mac-wakeup 2025-10-16 17:16:19 +02:00
Zoltan Papp
14fe7c29cb Change log levels 2025-10-14 18:13:52 +02:00
Zoltan Papp
158f3aceff Handle better sleep period 2025-10-14 12:24:39 +02:00
Zoltan Papp
bfa776c155 Update client/internal/networkmonitor/check_change_darwin.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-14 12:18:00 +02:00
Zoltan Papp
885b5c68ad Update client/internal/networkmonitor/monitor.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-14 12:17:51 +02:00
Zoltan Papp
b1ebac795d Extend Darwin network monitoring with wakeup detection 2025-10-14 12:14:08 +02:00
18 changed files with 586 additions and 76 deletions

View File

@@ -4,12 +4,15 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"runtime"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -17,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}))
}
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
conn, err := grpc.NewClient(
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
}

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
<-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()
backOff.Reset()
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
}
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
if c.engine == nil {
return nil
}
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
}

View File

@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
shutdownWg sync.WaitGroup
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
// This is different from ServiceEnable=false from management which completely disables the DNS service.
disableSys bool
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server
func (s *DefaultServer) Stop() {
s.ctxCancel()
s.shutdownWg.Wait()
s.mux.Lock()
defer s.mux.Unlock()
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.applyHostConfig()
s.shutdownWg.Add(1)
go func() {
// persist dns state right away
defer s.shutdownWg.Done()
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}

View File

@@ -200,8 +200,10 @@ type Engine struct {
flowManager nftypes.FlowManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
wgIfaceMonitor *WGIfaceMonitor
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
shutdownWg sync.WaitGroup
// dns forwarder port
dnsFwdPort uint16
@@ -326,10 +328,6 @@ func (e *Engine) Stop() error {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
// stop flow manager after wg interface is gone
@@ -337,8 +335,6 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
@@ -349,12 +345,52 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
func (e *Engine) calculateShutdownTimeout() time.Duration {
peerCount := len(e.peerStore.PeersPubKey())
baseTimeout := 10 * time.Second
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
timeout := baseTimeout + perPeerTimeout
maxTimeout := 30 * time.Second
if timeout > maxTimeout {
timeout = maxTimeout
}
return timeout
}
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
@@ -484,14 +520,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
e.shutdownWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
e.triggerClientRestart()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
@@ -892,7 +928,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// blocking
err = e.sshServer.Start()
if err != nil {
@@ -950,7 +988,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
@@ -1368,7 +1408,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
@@ -1724,8 +1766,10 @@ func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
// triggerClientRestart triggers a full client restart by cancelling the client context.
// Note: This does NOT just restart the engine - it cancels the entire client context,
// which causes the connect client's retry loop to create a completely new engine.
func (e *Engine) triggerClientRestart() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1747,7 +1791,9 @@ func (e *Engine) startNetworkMonitor() {
}
e.networkMonitor = networkmonitor.New()
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
@@ -1757,8 +1803,8 @@ func (e *Engine) startNetworkMonitor() {
return
}
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
log.Infof("Network monitor: detected network change, triggering client restart")
e.triggerClientRestart()
}()
}

View File

@@ -24,6 +24,7 @@ import (
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
go m.receiveACKs(ctx, flowClient)
go m.startSender(ctx)
m.shutdownWg.Add(2)
go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient)
}()
go func() {
defer m.shutdownWg.Done()
m.startSender(ctx)
}()
return nil
}
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
// Close cleans up all resources
func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if err := m.disableFlow(); err != nil {
log.Warnf("failed to disable flow manager: %v", err)
}
m.mux.Unlock()
m.shutdownWg.Wait()
}
// GetLogger returns the flow logger

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
//go:build dragonfly || freebsd || netbsd || openbsd
package networkmonitor

View File

@@ -0,0 +1,344 @@
//go:build darwin && !ios
package networkmonitor
import (
"context"
"errors"
"fmt"
"hash/fnv"
"net/netip"
"os/exec"
"syscall"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// todo: refactor to not use static functions
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %v", err)
}
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}()
routeChanged := make(chan struct{})
go func() {
routeCheck(ctx, fd, nexthopv4, nexthopv6)
close(routeChanged)
}()
wakeUp := make(chan struct{})
go func() {
wakeUpListen(ctx)
close(wakeUp)
}()
gatewayChanged := make(chan string)
go func() {
gatewayPoll(ctx, nexthopv4, nexthopv6, gatewayChanged)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-routeChanged:
log.Infof("route change detected via routing socket")
return nil
case <-wakeUp:
log.Infof("wakeup detected via sleep hash change")
return nil
case reason := <-gatewayChanged:
log.Infof("gateway change detected via polling: %s", reason)
return nil
}
}
func routeCheck(ctx context.Context, fd int, nexthopv4 systemops.Nexthop, nexthopv6 systemops.Nexthop) {
for {
if ctx.Err() != nil {
return
}
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
}
func wakeUpListen(ctx context.Context) {
log.Infof("start to watch for system wakeups")
var (
initialHash uint32
err error
)
// Keep retrying until initial sysctl succeeds or context is canceled
for {
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
default:
initialHash, err = readSleepTimeHash()
if err != nil {
log.Errorf("failed to detect initial sleep time: %v", err)
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
case <-time.After(3 * time.Second):
continue
}
}
log.Infof("initial wakeup hash: %d", initialHash)
break
}
break
}
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
lastCheck := time.Now()
const maxTickerDrift = 1 * time.Minute
for {
select {
case <-ctx.Done():
log.Info("context canceled, stopping wakeUpListen")
return
case <-ticker.C:
now := time.Now()
elapsed := now.Sub(lastCheck)
// If more time passed than expected, system likely slept (informational only)
if elapsed > maxTickerDrift {
upOut, err := exec.Command("uptime").Output()
if err != nil {
log.Errorf("failed to run uptime command: %v", err)
upOut = []byte("unknown")
}
log.Infof("Time drift detected (potential wakeup): expected ~5s, actual %s, uptime: %s", elapsed, upOut)
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
if errV4 == nil {
log.Infof("Current IPv4 default gateway: %s via %s", currentV4.IP, currentV4.Intf.Name)
} else {
log.Debugf("No IPv4 default gateway: %v", errV4)
}
if errV6 == nil {
log.Infof("Current IPv6 default gateway: %s via %s", currentV6.IP, currentV6.Intf.Name)
} else {
log.Debugf("No IPv6 default gateway: %v", errV6)
}
}
newHash, err := readSleepTimeHash()
if err != nil {
log.Errorf("failed to read sleep time hash: %v", err)
lastCheck = now
continue
}
if newHash == initialHash {
log.Debugf("no wakeup detected (hash unchanged: %d, time drift: %s)", initialHash, elapsed)
lastCheck = now
continue
}
upOut, err := exec.Command("uptime").Output()
if err != nil {
log.Errorf("failed to run uptime command: %v", err)
upOut = []byte("unknown")
}
log.Infof("Wakeup detected via hash change: %d -> %d, uptime: %s", initialHash, newHash, upOut)
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
if errV4 == nil {
log.Infof("Current IPv4 default gateway after wakeup: %s via %s", currentV4.IP, currentV4.Intf.Name)
} else {
log.Debugf("No IPv4 default gateway after wakeup: %v", errV4)
}
if errV6 == nil {
log.Infof("Current IPv6 default gateway after wakeup: %s via %s", currentV6.IP, currentV6.Intf.Name)
} else {
log.Debugf("No IPv6 default gateway after wakeup: %v", errV6)
}
return
}
}
}
func readSleepTimeHash() (uint32, error) {
cmd := exec.Command("sysctl", "kern.sleeptime")
out, err := cmd.Output()
if err != nil {
return 0, fmt.Errorf("failed to run sysctl: %w", err)
}
h, err := hash(out)
if err != nil {
return 0, fmt.Errorf("failed to compute hash: %w", err)
}
return h, nil
}
func hash(data []byte) (uint32, error) {
hasher := fnv.New32a()
if _, err := hasher.Write(data); err != nil {
return 0, err
}
return hasher.Sum32(), nil
}
// gatewayPoll polls the default gateway every 5 seconds to detect changes that might be missed by routing socket or wake-up detection.
func gatewayPoll(ctx context.Context, initialV4, initialV6 systemops.Nexthop, changed chan<- string) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
log.Infof("Gateway polling started - initial v4: %s via %v, v6: %s via %v",
initialV4.IP, initialV4.Intf, initialV6.IP, initialV6.Intf)
for {
select {
case <-ctx.Done():
log.Debug("context canceled, stopping gateway polling")
return
case <-ticker.C:
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
var reason string
if errV4 == nil && initialV4.IP.IsValid() {
if currentV4.IP.Compare(initialV4.IP) != 0 {
reason = fmt.Sprintf("IPv4 gateway changed from %s to %s", initialV4.IP, currentV4.IP)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
if initialV4.Intf != nil && currentV4.Intf != nil && currentV4.Intf.Name != initialV4.Intf.Name {
reason = fmt.Sprintf("IPv4 interface changed from %s to %s", initialV4.Intf.Name, currentV4.Intf.Name)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
} else if errV4 == nil && !initialV4.IP.IsValid() {
reason = "IPv4 gateway appeared"
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV4.IP)
changed <- reason
return
} else if errV4 != nil && initialV4.IP.IsValid() {
reason = "IPv4 gateway disappeared"
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
if errV6 == nil && initialV6.IP.IsValid() {
if currentV6.IP.Compare(initialV6.IP) != 0 {
reason = fmt.Sprintf("IPv6 gateway changed from %s to %s", initialV6.IP, currentV6.IP)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
if initialV6.Intf != nil && currentV6.Intf != nil && currentV6.Intf.Name != initialV6.Intf.Name {
reason = fmt.Sprintf("IPv6 interface changed from %s to %s", initialV6.Intf.Name, currentV6.Intf.Name)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
} else if errV6 == nil && !initialV6.IP.IsValid() {
reason = "IPv6 gateway appeared"
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV6.IP)
changed <- reason
return
} else if errV6 != nil && initialV6.IP.IsValid() {
reason = "IPv6 gateway disappeared"
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
log.Debugf("Gateway poll: no change detected")
}
}
}

View File

@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
log.Infof("start watching for network changes")
// debounce changes
timer := time.NewTimer(0)
timer.Stop()

View File

@@ -19,11 +19,11 @@ type SRWatcher struct {
signalClient chNotifier
relayManager chNotifier
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
listeners map[chan struct{}]struct{}
mu sync.Mutex
shutdownWg sync.WaitGroup
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
cancelIceMonitor context.CancelFunc
}
@@ -52,7 +52,11 @@ func (w *SRWatcher) Start() {
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
w.shutdownWg.Add(1)
go func() {
defer w.shutdownWg.Done()
iceMonitor.Start(ctx, w.onICEChanged)
}()
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)
@@ -60,14 +64,16 @@ func (w *SRWatcher) Start() {
func (w *SRWatcher) Close() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor == nil {
w.mu.Unlock()
return
}
w.cancelIceMonitor()
w.signalClient.SetOnReconnectedListener(nil)
w.relayManager.SetOnReconnectedListener(nil)
w.mu.Unlock()
w.shutdownWg.Wait()
}
func (w *SRWatcher) NewListener() chan struct{} {

View File

@@ -78,6 +78,7 @@ type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
shutdownWg sync.WaitGroup
clientNetworks map[route.HAUniqueID]*client.Watcher
routeSelector *routeselector.RouteSelector
serverRouter *server.Router
@@ -273,6 +274,7 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop()
m.shutdownWg.Wait()
if m.serverRouter != nil {
m.serverRouter.CleanUp()
}
@@ -474,7 +476,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
}
clientNetworkWatcher := client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
}
@@ -516,7 +522,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
}
clientNetworkWatcher = client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
}
update := client.RoutesUpdate{
UpdateSerial: updateSerial,

View File

@@ -105,11 +105,31 @@ func (r *SysOps) FlushMarkedRoutes() error {
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
if prefix.IsSingleIP() {
log.Debugf("Adding single IP route: %s via %s", prefix, formatNexthop(nexthop))
}
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
if prefix.IsSingleIP() {
log.Debugf("Removing single IP route: %s via %s", prefix, formatNexthop(nexthop))
}
if err := r.routeSocket(unix.RTM_DELETE, prefix, nexthop); err != nil {
return err
}
if prefix.IsSingleIP() {
log.Debugf("Route removal completed for %s, verifying...", prefix)
if exists := r.verifyRouteRemoved(prefix); exists {
log.Warnf("Route %s still exists in routing table after removal", prefix)
} else {
log.Debugf("Verified route %s successfully removed", prefix)
}
}
return nil
}
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
@@ -276,3 +296,51 @@ func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
}
// formatNexthop returns a string representation of the nexthop for logging.
func formatNexthop(nexthop Nexthop) string {
if nexthop.IP.IsValid() {
return nexthop.IP.String()
}
if nexthop.Intf != nil {
return nexthop.Intf.Name
}
return "direct"
}
// verifyRouteRemoved checks if a route still exists in the routing table.
func (r *SysOps) verifyRouteRemoved(prefix netip.Prefix) bool {
rib, err := retryFetchRIB()
if err != nil {
log.Debugf("Failed to fetch RIB for route verification: %v", err)
return false
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
log.Debugf("Failed to parse RIB for route verification: %v", err)
return false
}
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
continue
}
if routeInfo.Dst == prefix {
return true
}
}
return false
}

View File

@@ -17,8 +17,7 @@ type Conn struct {
ID hooks.ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn)
}
@@ -29,7 +28,7 @@ type TCPConn struct {
ID hooks.ConnectionID
}
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn)
}
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
// closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close()
cleanupConnID(id)
return err
}
// cleanupConnID executes close hooks for a connection ID.
func cleanupConnID(id hooks.ConnectionID) {
closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks {
if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err)
}
}
return err
}

View File

@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
}
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}

View File

@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
cleanupConnID(connID)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
return fmt.Errorf("resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)

View File

@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn)
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn)

View File

@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}

View File

@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}