mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-26 19:12:37 -04:00
Compare commits
7 Commits
vnc-server
...
reduce-emb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e81ce81494 | ||
|
|
eb3aa96257 | ||
|
|
064ec1c832 | ||
|
|
75e408f51c | ||
|
|
5a89e6621b | ||
|
|
06dfa9d4a5 | ||
|
|
45d9ee52c0 |
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -469,6 +470,55 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// WGTuning bundles runtime-adjustable WireGuard knobs exposed by the embed
|
||||
// client. Nil fields are left unchanged; set a non-nil pointer to apply.
|
||||
type WGTuning struct {
|
||||
// PreallocatedBuffersPerPool caps each per-Device WaitPool.
|
||||
// Zero means "unbounded" (no cap). Live-tunable only if the underlying
|
||||
// Device was originally created with a nonzero cap.
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetWGTuning applies the given tuning to this client's live Device.
|
||||
// Startup-only knobs (batch size) must be set via the package-level
|
||||
// setters before Start.
|
||||
func (c *Client) SetWGTuning(t WGTuning) error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetWGTuning(internal.WGTuning{
|
||||
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||
})
|
||||
}
|
||||
|
||||
// SetWGDefaultPreallocatedBuffersPerPool sets the default WaitPool cap
|
||||
// applied to Devices created after this call. Zero disables the cap.
|
||||
// Existing Devices are unaffected; use Client.SetWGTuning for that.
|
||||
func SetWGDefaultPreallocatedBuffersPerPool(n uint32) {
|
||||
wgdevice.SetPreallocatedBuffersPerPool(n)
|
||||
}
|
||||
|
||||
// WGDefaultPreallocatedBuffersPerPool returns the current default WaitPool
|
||||
// cap applied to newly-created Devices.
|
||||
func WGDefaultPreallocatedBuffersPerPool() uint32 {
|
||||
return wgdevice.PreallocatedBuffersPerPool
|
||||
}
|
||||
|
||||
// SetWGDefaultMaxBatchSize sets the default per-Device batch size applied
|
||||
// to Devices created after this call. Zero means "use the bind+tun default"
|
||||
// (NOT unlimited). Must be called before Start to take effect for a new
|
||||
// Client.
|
||||
func SetWGDefaultMaxBatchSize(n uint32) {
|
||||
wgdevice.SetMaxBatchSizeOverride(n)
|
||||
}
|
||||
|
||||
// WGDefaultMaxBatchSize returns the current default batch-size override.
|
||||
// Zero means "no override".
|
||||
func WGDefaultMaxBatchSize() uint32 {
|
||||
return wgdevice.MaxBatchSizeOverride
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
|
||||
11
client/firewall/firewalld/firewalld.go
Normal file
11
client/firewall/firewalld/firewalld.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||
// versions, which returns EPERM to any other process that tries to insert
|
||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||
// itself add the accept rules to its own chains by trusting the interface.
|
||||
package firewalld
|
||||
|
||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||
// should bypass firewalld filtering.
|
||||
const TrustedZone = "trusted"
|
||||
260
client/firewall/firewalld/firewalld_linux.go
Normal file
260
client/firewall/firewalld/firewalld_linux.go
Normal file
@@ -0,0 +1,260 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
dbusDest = "org.fedoraproject.FirewallD1"
|
||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||
|
||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||
errNotEnabled = "NOT_ENABLED"
|
||||
|
||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||
// A fresh context is created for each call so a slow DBus probe can't
|
||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||
callTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||
|
||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||
// Info level only for the first successful add per process; repeat adds
|
||||
// from other init paths are quieter.
|
||||
trustLogOnce sync.Once
|
||||
|
||||
parentCtxMu sync.RWMutex
|
||||
parentCtx context.Context = context.Background()
|
||||
)
|
||||
|
||||
// SetParentContext installs a parent context whose cancellation aborts any
|
||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||
// during engine shutdown when the engine context is already cancelled.
|
||||
func SetParentContext(ctx context.Context) {
|
||||
parentCtxMu.Lock()
|
||||
parentCtx = ctx
|
||||
parentCtxMu.Unlock()
|
||||
}
|
||||
|
||||
func getParentContext() context.Context {
|
||||
parentCtxMu.RLock()
|
||||
defer parentCtxMu.RUnlock()
|
||||
return parentCtx
|
||||
}
|
||||
|
||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||
// running. It is idempotent and best-effort: errors are returned so callers
|
||||
// can log, but a non-running firewalld is not an error. Only the first
|
||||
// successful call per process logs at Info. Respects the parent context set
|
||||
// via SetParentContext so startup-time cancellation unblocks it.
|
||||
func TrustInterface(iface string) error {
|
||||
parent := getParentContext()
|
||||
if !isRunning(parent) {
|
||||
return nil
|
||||
}
|
||||
if err := addTrusted(parent, iface); err != nil {
|
||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
trustLogOnce.Do(func() {
|
||||
log.Infof("added %s to firewalld trusted zone", iface)
|
||||
})
|
||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||
// during shutdown after the engine context has been cancelled.
|
||||
func UntrustInterface(iface string) error {
|
||||
if !isRunning(context.Background()) {
|
||||
return nil
|
||||
}
|
||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(parent, callTimeout)
|
||||
}
|
||||
|
||||
func isRunning(parent context.Context) bool {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
ok, err := isRunningDBus(ctx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return ok
|
||||
}
|
||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return isRunningCLI(ctx)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := addDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return addCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func removeTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := removeDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return removeCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
var zone string
|
||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isRunningCLI(ctx context.Context) bool {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return false
|
||||
}
|
||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||
}
|
||||
|
||||
func addDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||
if move.Err != nil {
|
||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func removeDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func addCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
// --change-interface (no --permanent) binds the interface for the
|
||||
// current runtime only; we do not want membership to persist across
|
||||
// reboots because netbird re-asserts it on every startup.
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbusErrContains(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var de dbus.Error
|
||||
if errors.As(err, &de) {
|
||||
for _, b := range de.Body {
|
||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), code)
|
||||
}
|
||||
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
func TestDBusErrContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
code string
|
||||
want bool
|
||||
}{
|
||||
{"nil error", nil, errZoneAlreadySet, false},
|
||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||
{
|
||||
"dbus.Error body match",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||
errZoneAlreadySet,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"dbus.Error body miss",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||
errAlreadyEnabled,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"dbus.Error non-string body falls back to Error()",
|
||||
dbus.Error{Name: "x", Body: []any{123}},
|
||||
"x",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := dbusErrContains(tc.err, tc.code)
|
||||
if got != tc.want {
|
||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
25
client/firewall/firewalld/firewalld_other.go
Normal file
25
client/firewall/firewalld/firewalld_other.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import "context"
|
||||
|
||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func SetParentContext(context.Context) {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
}
|
||||
|
||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func TrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func UntrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||
// interface in firewalld's trusted zone without a corresponding Close.
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
@@ -40,6 +41,8 @@ const (
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
}
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AllowNetbird()
|
||||
}
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() wgaddr.Address
|
||||
GetWGDevice() *wgdevice.Device
|
||||
|
||||
@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
NameFunc func() string
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() wgaddr.Address
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Name() string {
|
||||
if i.NameFunc == nil {
|
||||
return "wgtest"
|
||||
}
|
||||
return i.NameFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||
)
|
||||
|
||||
type resolvConf struct {
|
||||
|
||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
osManager, reason, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
||||
}
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||
resolved := isSystemdResolvedRunning()
|
||||
nss := isLibnssResolveUsed()
|
||||
stub := checkStub()
|
||||
|
||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||
// that go through nss-resolve, and in foreign mode they can loop back
|
||||
// through resolved as an upstream.
|
||||
if resolved && (nss || stub) {
|
||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||
}
|
||||
|
||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
if reason != "" {
|
||||
return mgr, reason, nil
|
||||
}
|
||||
|
||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||
if len(rejected) > 0 {
|
||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||
}
|
||||
return fileManager, fallback, nil
|
||||
}
|
||||
|
||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||
// matching manager. If reason is empty the caller should pick file mode and
|
||||
// use rejected for diagnostics.
|
||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
if cerr := file.Close(); cerr != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
var rejected []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
continue
|
||||
}
|
||||
if text[0] != '#' {
|
||||
return fileManager, nil
|
||||
break
|
||||
}
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||
return mgr, reason, nil, nil
|
||||
} else if rej != "" {
|
||||
rejected = append(rejected, rej)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return fileManager, nil
|
||||
return 0, "", rejected, nil
|
||||
}
|
||||
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||
}
|
||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||
}
|
||||
return resolvConfManager, "resolvconf header detected", ""
|
||||
}
|
||||
return 0, "", ""
|
||||
}
|
||||
|
||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||
// into file mode while resolved is active.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -139,3 +178,36 @@ func checkStub() bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||
func isLibnssResolveUsed() bool {
|
||||
bs, err := os.ReadFile(nsswitchConfPath)
|
||||
if err != nil {
|
||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||
return false
|
||||
}
|
||||
return parseNsswitchResolveAhead(bs)
|
||||
}
|
||||
|
||||
func parseNsswitchResolveAhead(data []byte) bool {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||
continue
|
||||
}
|
||||
for _, module := range fields[1:] {
|
||||
switch module {
|
||||
case "dns":
|
||||
return false
|
||||
case "resolve":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
76
client/internal/dns/host_unix_test.go
Normal file
76
client/internal/dns/host_unix_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "resolve before dns with action token",
|
||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "dns before resolve",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "debian default with only dns",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "neither resolve nor dns",
|
||||
in: "hosts: files myhostname\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no hosts line",
|
||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "comments and blank lines ignored",
|
||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "trailing inline comment",
|
||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "hosts token must be the first field",
|
||||
in: " hosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other db line mentioning resolve is ignored",
|
||||
in: "networks: resolve\nhosts: dns\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only resolve, no dns",
|
||||
in: "hosts: files resolve\n",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -570,7 +571,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
@@ -604,6 +605,8 @@ func (e *Engine) createFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
firewalld.SetParentContext(e.ctx)
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil {
|
||||
@@ -1871,6 +1874,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
// WGTuning bundles runtime-adjustable WireGuard pool knobs.
|
||||
// See Engine.SetWGTuning. Nil fields are ignored.
|
||||
type WGTuning struct {
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetWGTuning applies the given tuning to this engine's live Device.
|
||||
func (e *Engine) SetWGTuning(t WGTuning) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
if e.wgInterface == nil {
|
||||
return fmt.Errorf("wg interface not initialized")
|
||||
}
|
||||
dev := e.wgInterface.GetWGDevice()
|
||||
if dev == nil {
|
||||
return fmt.Errorf("wg device not initialized")
|
||||
}
|
||||
if t.PreallocatedBuffersPerPool != nil {
|
||||
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
forceRelay := IsForceRelayed()
|
||||
if !forceRelay {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
if !forceRelay {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.Close()
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
conn.dumpState.RemoteCandidate()
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
||||
//
|
||||
// The result is a tri-state:
|
||||
// - ConnStatusConnected: all available transports are up
|
||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
||||
// - ConnStatusDisconnected: no working transport
|
||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
defer func() {
|
||||
if !connected {
|
||||
if status == guard.ConnStatusDisconnected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
iceWorkerCreated := conn.workerICE != nil
|
||||
|
||||
var iceInProgress bool
|
||||
if iceWorkerCreated {
|
||||
iceInProgress = conn.workerICE.InProgress()
|
||||
}
|
||||
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return evalConnStatus(connStatusInputs{
|
||||
forceRelay: IsForceRelayed(),
|
||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
||||
iceWorkerCreated: iceWorkerCreated,
|
||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
||||
iceInProgress: iceInProgress,
|
||||
})
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
||||
|
||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
||||
if in.forceRelay {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
||||
// relay is the only possible transport.
|
||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
||||
|
||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
||||
|
||||
switch {
|
||||
case iceUp && relayOK:
|
||||
return guard.ConnStatusConnected
|
||||
case relayUsedAndUp:
|
||||
// Relay is up but ICE is down — partially connected.
|
||||
return guard.ConnStatusPartiallyConnected
|
||||
default:
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
}
|
||||
|
||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
||||
if connected {
|
||||
return guard.ConnStatusConnected
|
||||
}
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
|
||||
@@ -13,6 +13,20 @@ const (
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
||||
// without constructing full Worker/Handshaker objects.
|
||||
type connStatusInputs struct {
|
||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
201
client/internal/peer/conn_status_eval_test.go
Normal file
201
client/internal/peer/conn_status_eval_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
)
|
||||
|
||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "force relay, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: false,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE worker not yet created, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: false,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer does not use relay",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: false,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
||||
base := connStatusInputs{
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutator func(*connStatusInputs)
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "ICE connected, relay connected, peer uses relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE connected, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE InProgress only, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up, peer uses relay -> partial",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusPartiallyConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
||||
// falls into default: Disconnected.
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
in := base
|
||||
tc.mutator(&in)
|
||||
if got := evalConnStatus(in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
func IsForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -8,7 +8,19 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
// ConnStatus represents the connection state as seen by the guard.
|
||||
type ConnStatus int
|
||||
|
||||
const (
|
||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
||||
ConnStatusDisconnected ConnStatus = iota
|
||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
||||
ConnStatusPartiallyConnected
|
||||
// ConnStatusConnected means all required connections are established.
|
||||
ConnStatusConnected
|
||||
)
|
||||
|
||||
type connStatusFunc func() ConnStatus
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
log *log.Entry
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
isConnectedOnAllWay connStatusFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan struct{}
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
log: log,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
||||
//
|
||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
||||
// - Connected: no action, the peer is fully reachable.
|
||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
||||
//
|
||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
iceState := &iceRetryState{log: g.log}
|
||||
defer iceState.reset()
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
case <-tickerChannel:
|
||||
switch g.isConnectedOnAllWay() {
|
||||
case ConnStatusConnected:
|
||||
// all good, nothing to do
|
||||
case ConnStatusDisconnected:
|
||||
callback()
|
||||
case ConnStatusPartiallyConnected:
|
||||
if iceState.shouldRetry() {
|
||||
callback()
|
||||
} else {
|
||||
iceState.enterHourlyMode()
|
||||
ticker.Stop()
|
||||
tickerChannel = iceState.hourlyC()
|
||||
}
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
callback()
|
||||
}
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
|
||||
61
client/internal/peer/guard/ice_retry_state.go
Normal file
61
client/internal/peer/guard/ice_retry_state.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
||||
maxICERetries = 3
|
||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
||||
iceRetryInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
||||
type iceRetryState struct {
|
||||
log *log.Entry
|
||||
retries int
|
||||
hourly *time.Ticker
|
||||
}
|
||||
|
||||
func (s *iceRetryState) reset() {
|
||||
s.retries = 0
|
||||
if s.hourly != nil {
|
||||
s.hourly.Stop()
|
||||
s.hourly = nil
|
||||
}
|
||||
}
|
||||
|
||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
||||
func (s *iceRetryState) shouldRetry() bool {
|
||||
if s.hourly != nil {
|
||||
s.log.Debugf("hourly ICE retry attempt")
|
||||
return true
|
||||
}
|
||||
|
||||
s.retries++
|
||||
if s.retries <= maxICERetries {
|
||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
||||
func (s *iceRetryState) enterHourlyMode() {
|
||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
||||
s.hourly = time.NewTicker(iceRetryInterval)
|
||||
}
|
||||
|
||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
||||
if s.hourly == nil {
|
||||
return nil
|
||||
}
|
||||
return s.hourly.C
|
||||
}
|
||||
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTestRetryState() *iceRetryState {
|
||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
||||
}
|
||||
|
||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 0; i < maxICERetries; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
if s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if s.hourlyC() == nil {
|
||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
|
||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
s.enterHourlyMode()
|
||||
|
||||
s.reset()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
||||
}
|
||||
if s.retries != 0 {
|
||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
||||
}
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.reset()
|
||||
s.reset() // second call must not panic or re-stop a nil ticker
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC non-nil after double reset")
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start() {
|
||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
if !disableICEMonitor {
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +44,10 @@ type OfferAnswer struct {
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
|
||||
func (o *OfferAnswer) hasICECredentials() bool {
|
||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
@@ -59,6 +64,10 @@ type Handshaker struct {
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
||||
remoteICESupported atomic.Bool
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
@@ -66,7 +75,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
return &Handshaker{
|
||||
h := &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
// assume remote supports ICE until we learn otherwise from received offers
|
||||
h.remoteICESupported.Store(ice != nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handshaker) RemoteICESupported() bool {
|
||||
return h.remoteICESupported.Load()
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
|
||||
}
|
||||
|
||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
SessionID: &sid,
|
||||
}
|
||||
|
||||
if h.ice != nil && h.RemoteICESupported() {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
||||
answer.SessionID = &sid
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
||||
hasICE := offer.hasICECredentials()
|
||||
prev := h.remoteICESupported.Swap(hasICE)
|
||||
if prev != hasICE {
|
||||
if hasICE {
|
||||
h.log.Infof("remote peer started sending ICE credentials")
|
||||
} else {
|
||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
||||
if h.ice != nil {
|
||||
h.ice.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
var sessionIDBytes []byte
|
||||
if offerAnswer.SessionID != nil {
|
||||
var err error
|
||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
|
||||
@@ -119,6 +119,8 @@ server:
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
|
||||
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
|
||||
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
|
||||
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
|
||||
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.
|
||||
|
||||
4
go.mod
4
go.mod
@@ -314,7 +314,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
@@ -323,3 +323,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
|
||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||
|
||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
||||
|
||||
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0
|
||||
|
||||
8
go.sum
8
go.sum
@@ -400,8 +400,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
|
||||
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
@@ -449,6 +447,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
|
||||
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
||||
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
|
||||
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
@@ -459,8 +459,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 h1:6REpBYpJBLTTgqCcLGpTqvRDoEoLbA5r2nAXqMd2La0=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
_ "embed"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var action = activity.PolicyAdded
|
||||
var unchanged bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
||||
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
saveFunc := transaction.CreatePolicy
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
if policy.Equal(existingPolicy) {
|
||||
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||
unchanged = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = saveFunc(ctx, policy); err != nil {
|
||||
return err
|
||||
action = activity.PolicyUpdated
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
@@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if unchanged {
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
// validatePolicy validates the policy and its rules. For updates it returns
|
||||
// the existing policy loaded from the store so callers can avoid a second read.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||
var existingPolicy *types.Policy
|
||||
if policy.ID != "" {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
var err error
|
||||
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: Refactor to support multiple rules per policy
|
||||
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, rule := range policy.Rules {
|
||||
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||
}
|
||||
|
||||
return nil
|
||||
return existingPolicy, nil
|
||||
}
|
||||
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
|
||||
@@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy {
|
||||
return c
|
||||
}
|
||||
|
||||
func (p *Policy) Equal(other *Policy) bool {
|
||||
if p == nil || other == nil {
|
||||
return p == other
|
||||
}
|
||||
|
||||
if p.ID != other.ID ||
|
||||
p.AccountID != other.AccountID ||
|
||||
p.Name != other.Name ||
|
||||
p.Description != other.Description ||
|
||||
p.Enabled != other.Enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(p.Rules) != len(other.Rules) {
|
||||
return false
|
||||
}
|
||||
|
||||
otherRules := make(map[string]*PolicyRule, len(other.Rules))
|
||||
for _, r := range other.Rules {
|
||||
otherRules[r.ID] = r
|
||||
}
|
||||
for _, r := range p.Rules {
|
||||
otherRule, ok := otherRules[r.ID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !r.Equal(otherRule) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to this policy
|
||||
func (p *Policy) EventMeta() map[string]any {
|
||||
return map[string]any{"name": p.Name}
|
||||
|
||||
193
management/server/types/policy_test.go
Normal file
193
management/server/types/policy_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentRules(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
{ID: "r2", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc3"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
|
||||
base := Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Description: "desc",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
other := base
|
||||
other.Name = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Enabled = false
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Description = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_NilCases(t *testing.T) {
|
||||
var a *Policy
|
||||
var b *Policy
|
||||
assert.True(t, a.Equal(b))
|
||||
|
||||
a = &Policy{ID: "pol1"}
|
||||
assert.False(t, a.Equal(nil))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r2", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_FullScenario(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "Web Access",
|
||||
Description: "Allow web access",
|
||||
Enabled: true,
|
||||
SourcePostureChecks: []string{"pc2", "pc1"},
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "r1",
|
||||
PolicyID: "pol1",
|
||||
Name: "HTTP",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"g2", "g1"},
|
||||
Destinations: []string{"g4", "g3"},
|
||||
Ports: []string{"443", "80", "8080"},
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 8000, End: 9000},
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "Web Access",
|
||||
Description: "Allow web access",
|
||||
Enabled: true,
|
||||
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "r1",
|
||||
PolicyID: "pol1",
|
||||
Name: "HTTP",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"g1", "g2"},
|
||||
Destinations: []string{"g3", "g4"},
|
||||
Ports: []string{"80", "8080", "443"},
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 8000, End: 9000},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
|
||||
if pm == nil || other == nil {
|
||||
return pm == other
|
||||
}
|
||||
|
||||
if pm.ID != other.ID ||
|
||||
pm.PolicyID != other.PolicyID ||
|
||||
pm.Name != other.Name ||
|
||||
pm.Description != other.Description ||
|
||||
pm.Enabled != other.Enabled ||
|
||||
pm.Action != other.Action ||
|
||||
pm.Bidirectional != other.Bidirectional ||
|
||||
pm.Protocol != other.Protocol ||
|
||||
pm.SourceResource != other.SourceResource ||
|
||||
pm.DestinationResource != other.DestinationResource ||
|
||||
pm.AuthorizedUser != other.AuthorizedUser {
|
||||
return false
|
||||
}
|
||||
|
||||
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
|
||||
return false
|
||||
}
|
||||
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
|
||||
return false
|
||||
}
|
||||
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSlicesEqualUnordered(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
sorted1 := make([]string, len(a))
|
||||
sorted2 := make([]string, len(b))
|
||||
copy(sorted1, a)
|
||||
copy(sorted2, b)
|
||||
slices.Sort(sorted1)
|
||||
slices.Sort(sorted2)
|
||||
return slices.Equal(sorted1, sorted2)
|
||||
}
|
||||
|
||||
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
cmp := func(x, y RulePortRange) int {
|
||||
if x.Start != y.Start {
|
||||
if x.Start < y.Start {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if x.End != y.End {
|
||||
if x.End < y.End {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
sorted1 := make([]RulePortRange, len(a))
|
||||
sorted2 := make([]RulePortRange, len(b))
|
||||
copy(sorted1, a)
|
||||
copy(sorted2, b)
|
||||
slices.SortFunc(sorted1, cmp)
|
||||
slices.SortFunc(sorted2, cmp)
|
||||
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
|
||||
return x.Start == y.Start && x.End == y.End
|
||||
})
|
||||
}
|
||||
|
||||
func authorizedGroupsEqual(a, b map[string][]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, va := range a {
|
||||
vb, ok := b[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(va, vb) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
194
management/server/types/policyrule_test.go
Normal file
194
management/server/types/policyrule_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "80", "22"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"22", "443", "80"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "80"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "22"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g2", "g3"},
|
||||
Destinations: []string{"g4", "g5"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g3", "g1", "g2"},
|
||||
Destinations: []string{"g5", "g4"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g2"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g3"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 8000, End: 9000},
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 8000, End: 9000},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 443},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u1", "u2", "u3"},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u3", "u1", "u2"},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u1"},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g2": {"u1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
|
||||
base := PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Name: "test",
|
||||
Description: "desc",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
}
|
||||
|
||||
other := base
|
||||
other.Name = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Enabled = false
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Action = PolicyTrafficActionDrop
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Protocol = PolicyRuleProtocolUDP
|
||||
assert.False(t, base.Equal(&other))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_NilCases(t *testing.T) {
|
||||
var a *PolicyRule
|
||||
var b *PolicyRule
|
||||
assert.True(t, a.Equal(b))
|
||||
|
||||
a = &PolicyRule{ID: "rule1"}
|
||||
assert.False(t, a.Equal(nil))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{},
|
||||
Sources: nil,
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: nil,
|
||||
Sources: []string{},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
@@ -99,6 +99,35 @@ var debugStopCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugWGTuneCmd = &cobra.Command{
|
||||
Use: "wgtune",
|
||||
Short: "Inspect and live-tune WireGuard pool settings",
|
||||
}
|
||||
|
||||
var debugWGTuneGetCmd = &cobra.Command{
|
||||
Use: "get",
|
||||
Short: "Show pool cap and batch size defaults",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: runDebugWGTuneGet,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugWGTuneSetCmd = &cobra.Command{
|
||||
Use: "set <pool-cap>",
|
||||
Short: "Set the pool cap (new and live clients)",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugWGTuneSet,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugRuntimeCmd = &cobra.Command{
|
||||
Use: "runtime",
|
||||
Short: "Show runtime stats (heap, goroutines, RSS)",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: runDebugRuntime,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
|
||||
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
|
||||
@@ -119,6 +148,10 @@ func init() {
|
||||
debugCmd.AddCommand(debugLogCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
debugWGTuneCmd.AddCommand(debugWGTuneGetCmd)
|
||||
debugWGTuneCmd.AddCommand(debugWGTuneSetCmd)
|
||||
debugCmd.AddCommand(debugWGTuneCmd)
|
||||
debugCmd.AddCommand(debugRuntimeCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
}
|
||||
@@ -171,3 +204,19 @@ func runDebugStart(cmd *cobra.Command, args []string) error {
|
||||
func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugWGTuneGet(cmd *cobra.Command, _ []string) error {
|
||||
return getDebugClient(cmd).WGTuneGet(cmd.Context())
|
||||
}
|
||||
|
||||
func runDebugWGTuneSet(cmd *cobra.Command, args []string) error {
|
||||
n, err := strconv.ParseUint(args[0], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value %q: %w", args[0], err)
|
||||
}
|
||||
return getDebugClient(cmd).WGTuneSet(cmd.Context(), uint32(n))
|
||||
}
|
||||
|
||||
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
|
||||
return getDebugClient(cmd).Runtime(cmd.Context())
|
||||
}
|
||||
|
||||
@@ -15,11 +15,22 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// envWGPreallocatedBuffers caps the per-Device WireGuard buffer pool
|
||||
// size. Zero (unset) keeps the uncapped upstream default.
|
||||
envWGPreallocatedBuffers = "NB_WG_PREALLOCATED_BUFFERS"
|
||||
// envWGMaxBatchSize overrides the per-Device WireGuard batch size,
|
||||
// which controls how many buffers each receive/TUN worker eagerly
|
||||
// allocates. Zero (unset) keeps the bind+tun default.
|
||||
envWGMaxBatchSize = "NB_WG_MAX_BATCH_SIZE"
|
||||
)
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
// envProxyToken is the environment variable name for the proxy access token.
|
||||
@@ -145,6 +156,42 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
logger.Infof("configured log level: %s", level)
|
||||
|
||||
var wgPool, wgBatch uint64
|
||||
if raw := os.Getenv(envWGPreallocatedBuffers); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envWGPreallocatedBuffers, raw, err)
|
||||
}
|
||||
wgPool = n
|
||||
embed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n))
|
||||
logger.Infof("wireguard preallocated buffers per pool: %d", n)
|
||||
}
|
||||
if raw := os.Getenv(envWGMaxBatchSize); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envWGMaxBatchSize, raw, err)
|
||||
}
|
||||
wgBatch = n
|
||||
embed.SetWGDefaultMaxBatchSize(uint32(n))
|
||||
logger.Infof("wireguard max batch size override: %d", n)
|
||||
}
|
||||
if wgPool > 0 {
|
||||
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
|
||||
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
|
||||
// the lifetime of the Device. A pool cap below that floor blocks
|
||||
// the receive pipeline at startup.
|
||||
batch := wgBatch
|
||||
if batch == 0 {
|
||||
batch = 128
|
||||
}
|
||||
const recvGoroutines = 4
|
||||
floor := batch * recvGoroutines
|
||||
if wgPool < floor {
|
||||
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
|
||||
envWGPreallocatedBuffers, wgPool, floor, batch)
|
||||
}
|
||||
}
|
||||
|
||||
switch forwardedProto {
|
||||
case "auto", "http", "https":
|
||||
default:
|
||||
|
||||
@@ -272,6 +272,74 @@ func (c *Client) printLogLevelResult(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
// WGTuneGet fetches the current WireGuard pool cap.
|
||||
func (c *Client) WGTuneGet(ctx context.Context) error {
|
||||
return c.fetchAndPrint(ctx, "/debug/wgtune", c.printWGTuneGet)
|
||||
}
|
||||
|
||||
// WGTuneSet updates the WireGuard pool cap on the global default and all live clients.
|
||||
func (c *Client) WGTuneSet(ctx context.Context, value uint32) error {
|
||||
path := fmt.Sprintf("/debug/wgtune?value=%d", value)
|
||||
return c.fetchAndPrint(ctx, path, c.printWGTuneSet)
|
||||
}
|
||||
|
||||
func (c *Client) printWGTuneGet(data map[string]any) {
|
||||
def, _ := data["default"].(float64)
|
||||
batch, _ := data["batch_size"].(float64)
|
||||
_, _ = fmt.Fprintf(c.out, "Default: %d\n", uint32(def))
|
||||
_, _ = fmt.Fprintf(c.out, "Batch size: %d (0 = unset)\n", uint32(batch))
|
||||
}
|
||||
|
||||
func (c *Client) printWGTuneSet(data map[string]any) {
|
||||
if errMsg, ok := data["error"].(string); ok && errMsg != "" {
|
||||
c.printError(data)
|
||||
return
|
||||
}
|
||||
def, _ := data["default"].(float64)
|
||||
applied, _ := data["applied"].(float64)
|
||||
_, _ = fmt.Fprintf(c.out, "Default set to: %d\n", uint32(def))
|
||||
_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied))
|
||||
if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed:")
|
||||
for k, v := range failed {
|
||||
_, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runtime fetches runtime stats (heap, goroutines, RSS).
|
||||
func (c *Client) Runtime(ctx context.Context) error {
|
||||
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
|
||||
}
|
||||
|
||||
func (c *Client) printRuntime(data map[string]any) {
|
||||
i := func(k string) uint64 {
|
||||
v, _ := data[k].(float64)
|
||||
return uint64(v)
|
||||
}
|
||||
mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) }
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||
_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs")))
|
||||
_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines"))
|
||||
_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects"))
|
||||
_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns")))
|
||||
_, _ = fmt.Fprintln(c.out, "Heap:")
|
||||
_, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc")))
|
||||
_, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse")))
|
||||
_, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle")))
|
||||
_, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released")))
|
||||
_, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys")))
|
||||
_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys")))
|
||||
if _, ok := data["vm_rss"]; ok {
|
||||
_, _ = fmt.Fprintln(c.out, "Process:")
|
||||
_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss")))
|
||||
_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size")))
|
||||
_, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data")))
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started"))
|
||||
}
|
||||
|
||||
// StartClient starts a specific client.
|
||||
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"html/template"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -58,6 +60,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
|
||||
type clientProvider interface {
|
||||
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
|
||||
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||
ListClientsForStartup() map[types.AccountID]*nbembed.Client
|
||||
}
|
||||
|
||||
// healthChecker provides health probe state.
|
||||
@@ -139,6 +142,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleListClients(w, r, wantJSON)
|
||||
case "/debug/health":
|
||||
h.handleHealth(w, r, wantJSON)
|
||||
case "/debug/wgtune":
|
||||
h.handleWGTune(w, r)
|
||||
case "/debug/runtime":
|
||||
h.handleRuntime(w, r)
|
||||
default:
|
||||
if h.handleClientRoutes(w, r, path, wantJSON) {
|
||||
return
|
||||
@@ -230,10 +237,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
}
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
clientsJSON = append(clientsJSON, map[string]any{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
@@ -242,7 +249,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
resp := map[string]any{
|
||||
"version": version.NetbirdVersion(),
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
@@ -320,10 +327,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
sortedIDs := sortedAccountIDs(clients)
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
clientsJSON = append(clientsJSON, map[string]any{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
@@ -332,7 +339,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"clients": clientsJSON,
|
||||
@@ -418,7 +425,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
||||
})
|
||||
|
||||
if wantJSON {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"account_id": accountID,
|
||||
"status": overview.FullDetailSummary(),
|
||||
})
|
||||
@@ -501,20 +508,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
|
||||
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
host := r.URL.Query().Get("host")
|
||||
portStr := r.URL.Query().Get("port")
|
||||
if host == "" || portStr == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
|
||||
h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
|
||||
return
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
|
||||
h.writeJSON(w, map[string]any{"error": "invalid port"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -533,7 +540,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"host": host,
|
||||
"port": port,
|
||||
@@ -546,7 +553,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"host": host,
|
||||
"port": port,
|
||||
@@ -558,25 +565,25 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
level := r.URL.Query().Get("level")
|
||||
if level == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.SetLogLevel(level); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"level": level,
|
||||
})
|
||||
@@ -587,7 +594,7 @@ const clientActionTimeout = 30 * time.Second
|
||||
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -595,14 +602,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "client started",
|
||||
})
|
||||
@@ -611,7 +618,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
||||
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -619,19 +626,136 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
||||
defer cancel()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "client stopped",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleWGTune(w http.ResponseWriter, r *http.Request) {
|
||||
values, ok := r.URL.Query()["value"]
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]any{
|
||||
"default": nbembed.WGDefaultPreallocatedBuffersPerPool(),
|
||||
"batch_size": nbembed.WGDefaultMaxBatchSize(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(values) == 0 || values[0] == "" {
|
||||
http.Error(w, "value parameter must not be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
raw := values[0]
|
||||
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
nbembed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n))
|
||||
|
||||
applied := 0
|
||||
failed := map[string]string{}
|
||||
for accountID, client := range h.provider.ListClientsForStartup() {
|
||||
capN := uint32(n)
|
||||
if err := client.SetWGTuning(nbembed.WGTuning{PreallocatedBuffersPerPool: &capN}); err != nil {
|
||||
failed[string(accountID)] = err.Error()
|
||||
continue
|
||||
}
|
||||
applied++
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"success": true,
|
||||
"default": uint32(n),
|
||||
"batch_size": nbembed.WGDefaultMaxBatchSize(),
|
||||
"applied": applied,
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
resp["failed"] = failed
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
// handleRuntime returns cheap runtime and process stats. Safe to hit on a
|
||||
// running proxy; does not read pprof profiles.
|
||||
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
started := 0
|
||||
for _, c := range clients {
|
||||
if c.HasClient {
|
||||
started++
|
||||
}
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"num_cpu": runtime.NumCPU(),
|
||||
"gomaxprocs": runtime.GOMAXPROCS(0),
|
||||
"go_version": runtime.Version(),
|
||||
"heap_alloc": m.HeapAlloc,
|
||||
"heap_inuse": m.HeapInuse,
|
||||
"heap_idle": m.HeapIdle,
|
||||
"heap_released": m.HeapReleased,
|
||||
"heap_sys": m.HeapSys,
|
||||
"sys": m.Sys,
|
||||
"live_objects": m.Mallocs - m.Frees,
|
||||
"num_gc": m.NumGC,
|
||||
"pause_total_ns": m.PauseTotalNs,
|
||||
"clients": len(clients),
|
||||
"started": started,
|
||||
}
|
||||
|
||||
if proc := readProcStatus(); proc != nil {
|
||||
resp["vm_rss"] = proc["VmRSS"]
|
||||
resp["vm_size"] = proc["VmSize"]
|
||||
resp["vm_data"] = proc["VmData"]
|
||||
}
|
||||
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
// readProcStatus parses /proc/self/status on Linux and returns size fields
|
||||
// in bytes. Returns nil on non-Linux or read failure.
|
||||
func readProcStatus() map[string]uint64 {
|
||||
raw, err := os.ReadFile("/proc/self/status")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := map[string]uint64{}
|
||||
for _, line := range strings.Split(string(raw), "\n") {
|
||||
k, v, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if k != "VmRSS" && k != "VmSize" && k != "VmData" {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(v)
|
||||
if len(fields) < 1 {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.ParseUint(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Values are reported in kB.
|
||||
out[k] = n * 1024
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
|
||||
if !wantJSON {
|
||||
http.Redirect(w, r, "/debug", http.StatusSeeOther)
|
||||
@@ -685,7 +809,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
tmpl := h.getTemplates()
|
||||
if tmpl == nil {
|
||||
@@ -698,7 +822,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
Reference in New Issue
Block a user